Experiment with local transcription

This commit is contained in:
Ali 2022-05-31 02:00:09 +04:00
parent 4c19ffb361
commit 1b0b7660db
3 changed files with 72 additions and 15 deletions

View File

@ -7,6 +7,7 @@ private var sharedRecognizers: [String: NSObject] = [:]
private struct TranscriptionResult { private struct TranscriptionResult {
var text: String var text: String
var confidence: Float var confidence: Float
var isFinal: Bool
} }
private func transcribeAudio(path: String, locale: String) -> Signal<TranscriptionResult?, NoError> { private func transcribeAudio(path: String, locale: String) -> Signal<TranscriptionResult?, NoError> {
@ -53,7 +54,7 @@ private func transcribeAudio(path: String, locale: String) -> Signal<Transcripti
let request = SFSpeechURLRecognitionRequest(url: URL(fileURLWithPath: tempFilePath)) let request = SFSpeechURLRecognitionRequest(url: URL(fileURLWithPath: tempFilePath))
request.requiresOnDeviceRecognition = speechRecognizer.supportsOnDeviceRecognition request.requiresOnDeviceRecognition = speechRecognizer.supportsOnDeviceRecognition
request.shouldReportPartialResults = false request.shouldReportPartialResults = true
let task = speechRecognizer.recognitionTask(with: request, resultHandler: { result, error in let task = speechRecognizer.recognitionTask(with: request, resultHandler: { result, error in
if let result = result { if let result = result {
@ -62,8 +63,11 @@ private func transcribeAudio(path: String, locale: String) -> Signal<Transcripti
confidence += segment.confidence confidence += segment.confidence
} }
confidence /= Float(result.bestTranscription.segments.count) confidence /= Float(result.bestTranscription.segments.count)
subscriber.putNext(TranscriptionResult(text: result.bestTranscription.formattedString, confidence: confidence)) subscriber.putNext(TranscriptionResult(text: result.bestTranscription.formattedString, confidence: confidence, isFinal: result.isFinal))
subscriber.putCompletion()
if result.isFinal {
subscriber.putCompletion()
}
} else { } else {
print("transcribeAudio: locale: \(locale), error: \(String(describing: error))") print("transcribeAudio: locale: \(locale), error: \(String(describing: error))")
@ -91,7 +95,12 @@ private func transcribeAudio(path: String, locale: String) -> Signal<Transcripti
|> runOn(.mainQueue()) |> runOn(.mainQueue())
} }
public func transcribeAudio(path: String, appLocale: String) -> Signal<String?, NoError> { public struct LocallyTranscribedAudio {
public var text: String
public var isFinal: Bool
}
public func transcribeAudio(path: String, appLocale: String) -> Signal<LocallyTranscribedAudio?, NoError> {
var signals: [Signal<TranscriptionResult?, NoError>] = [] var signals: [Signal<TranscriptionResult?, NoError>] = []
var locales: [String] = [] var locales: [String] = []
if !locales.contains(Locale.current.identifier) { if !locales.contains(Locale.current.identifier) {
@ -113,10 +122,12 @@ public func transcribeAudio(path: String, appLocale: String) -> Signal<String?,
} }
return resultSignal return resultSignal
|> map { results -> String? in |> map { results -> LocallyTranscribedAudio? in
let sortedResults = results.compactMap({ $0 }).sorted(by: { lhs, rhs in let sortedResults = results.compactMap({ $0 }).sorted(by: { lhs, rhs in
return lhs.confidence > rhs.confidence return lhs.confidence > rhs.confidence
}) })
return sortedResults.first?.text return sortedResults.first.flatMap { result -> LocallyTranscribedAudio in
return LocallyTranscribedAudio(text: result.text, isFinal: result.isFinal)
}
} }
} }

View File

@ -332,6 +332,20 @@ public extension TelegramEngine {
return _internal_transcribeAudio(postbox: self.account.postbox, network: self.account.network, audioTranscriptionManager: self.account.stateManager.audioTranscriptionManager, messageId: messageId) return _internal_transcribeAudio(postbox: self.account.postbox, network: self.account.network, audioTranscriptionManager: self.account.stateManager.audioTranscriptionManager, messageId: messageId)
} }
public func storeLocallyTranscribedAudio(messageId: MessageId, text: String, isFinal: Bool) -> Signal<Never, NoError> {
return self.account.postbox.transaction { transaction -> Void in
transaction.updateMessage(messageId, update: { currentMessage in
let storeForwardInfo = currentMessage.forwardInfo.flatMap(StoreMessageForwardInfo.init)
var attributes = currentMessage.attributes.filter { !($0 is AudioTranscriptionMessageAttribute) }
attributes.append(AudioTranscriptionMessageAttribute(id: 0, text: text, isPending: !isFinal, didRate: false))
return .update(StoreMessage(id: currentMessage.id, globallyUniqueId: currentMessage.globallyUniqueId, groupingKey: currentMessage.groupingKey, threadId: currentMessage.threadId, timestamp: currentMessage.timestamp, flags: StoreMessageFlags(currentMessage.flags), tags: currentMessage.tags, globalTags: currentMessage.globalTags, localTags: currentMessage.localTags, forwardInfo: storeForwardInfo, authorId: currentMessage.author?.id, text: currentMessage.text, attributes: attributes, media: currentMessage.media))
})
}
|> ignoreValues
}
public func rateAudioTranscription(messageId: MessageId, id: Int64, isGood: Bool) -> Signal<Never, NoError> { public func rateAudioTranscription(messageId: MessageId, id: Int64, isGood: Bool) -> Signal<Never, NoError> {
return _internal_rateAudioTranscription(postbox: self.account.postbox, network: self.account.network, messageId: messageId, id: id, isGood: isGood) return _internal_rateAudioTranscription(postbox: self.account.postbox, network: self.account.network, messageId: messageId, id: id, isGood: isGood)
} }

View File

@ -32,7 +32,7 @@ private struct FetchControls {
} }
private enum TranscribedText { private enum TranscribedText {
case success(String) case success(text: String, isPending: Bool)
case error case error
} }
@ -40,7 +40,7 @@ private func transcribedText(message: Message) -> TranscribedText? {
for attribute in message.attributes { for attribute in message.attributes {
if let attribute = attribute as? AudioTranscriptionMessageAttribute { if let attribute = attribute as? AudioTranscriptionMessageAttribute {
if !attribute.text.isEmpty { if !attribute.text.isEmpty {
return .success(attribute.text) return .success(text: attribute.text, isPending: attribute.isPending)
} else { } else {
return .error return .error
} }
@ -343,7 +343,21 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
return return
} }
if transcribedText(message: message) == nil { var shouldBeginTranscription = false
var shouldExpandNow = false
if let result = transcribedText(message: message) {
shouldExpandNow = true
if case let .success(_, isPending) = result {
shouldBeginTranscription = isPending
} else {
shouldBeginTranscription = true
}
} else {
shouldBeginTranscription = true
}
if shouldBeginTranscription {
if self.transcribeDisposable == nil { if self.transcribeDisposable == nil {
self.audioTranscriptionState = .inProgress self.audioTranscriptionState = .inProgress
self.requestUpdateLayout(true) self.requestUpdateLayout(true)
@ -351,7 +365,7 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
if context.sharedContext.immediateExperimentalUISettings.localTranscription { if context.sharedContext.immediateExperimentalUISettings.localTranscription {
let appLocale = presentationData.strings.baseLanguageCode let appLocale = presentationData.strings.baseLanguageCode
let signal: Signal<String?, NoError> = context.engine.data.get(TelegramEngine.EngineData.Item.Messages.Message(id: message.id)) let signal: Signal<LocallyTranscribedAudio?, NoError> = context.engine.data.get(TelegramEngine.EngineData.Item.Messages.Message(id: message.id))
|> mapToSignal { message -> Signal<String?, NoError> in |> mapToSignal { message -> Signal<String?, NoError> in
guard let message = message else { guard let message = message else {
return .single(nil) return .single(nil)
@ -376,14 +390,26 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
return TempBox.shared.tempFile(fileName: "audio.m4a").path return TempBox.shared.tempFile(fileName: "audio.m4a").path
}) })
} }
|> mapToSignal { result -> Signal<String?, NoError> in |> mapToSignal { result -> Signal<LocallyTranscribedAudio?, NoError> in
guard let result = result else { guard let result = result else {
return .single(nil) return .single(nil)
} }
return transcribeAudio(path: result, appLocale: appLocale) return transcribeAudio(path: result, appLocale: appLocale)
} }
let _ = signal.start(next: { [weak self] result in self.transcribeDisposable = (signal
|> deliverOnMainQueue).start(next: { [weak self] result in
guard let strongSelf = self, let arguments = strongSelf.arguments else {
return
}
if let result = result {
let _ = arguments.context.engine.messages.storeLocallyTranscribedAudio(messageId: arguments.message.id, text: result.text, isFinal: result.isFinal).start()
} else {
strongSelf.audioTranscriptionState = .collapsed
strongSelf.requestUpdateLayout(true)
}
}, completed: { [weak self] in
guard let strongSelf = self else { guard let strongSelf = self else {
return return
} }
@ -399,7 +425,9 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
}) })
} }
} }
} else { }
if shouldExpandNow {
switch self.audioTranscriptionState { switch self.audioTranscriptionState {
case .expanded: case .expanded:
self.audioTranscriptionState = .collapsed self.audioTranscriptionState = .collapsed
@ -615,8 +643,12 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
if let transcribedText = transcribedText, case .expanded = effectiveAudioTranscriptionState { if let transcribedText = transcribedText, case .expanded = effectiveAudioTranscriptionState {
switch transcribedText { switch transcribedText {
case let .success(text): case let .success(text, isPending):
textString = NSAttributedString(string: text, font: textFont, textColor: messageTheme.primaryTextColor) var resultText = text
if isPending {
resultText += " [...]"
}
textString = NSAttributedString(string: resultText, font: textFont, textColor: messageTheme.primaryTextColor)
case .error: case .error:
let errorTextFont = Font.regular(floor(arguments.presentationData.fontSize.baseDisplaySize * 15.0 / 17.0)) let errorTextFont = Font.regular(floor(arguments.presentationData.fontSize.baseDisplaySize * 15.0 / 17.0))
//TODO:localize //TODO:localize