Experiment with local transcription

Ali 2022-05-31 02:00:09 +04:00
parent 4c19ffb361
commit 1b0b7660db
3 changed files with 72 additions and 15 deletions


@@ -7,6 +7,7 @@ private var sharedRecognizers: [String: NSObject] = [:]
private struct TranscriptionResult {
var text: String
var confidence: Float
var isFinal: Bool
}
private func transcribeAudio(path: String, locale: String) -> Signal<TranscriptionResult?, NoError> {
@@ -53,7 +54,7 @@ private func transcribeAudio(path: String, locale: String) -> Signal<Transcripti
let request = SFSpeechURLRecognitionRequest(url: URL(fileURLWithPath: tempFilePath))
request.requiresOnDeviceRecognition = speechRecognizer.supportsOnDeviceRecognition
request.shouldReportPartialResults = false
request.shouldReportPartialResults = true
let task = speechRecognizer.recognitionTask(with: request, resultHandler: { result, error in
if let result = result {
@@ -62,8 +63,11 @@ private func transcribeAudio(path: String, locale: String) -> Signal<Transcripti
confidence += segment.confidence
}
confidence /= Float(result.bestTranscription.segments.count)
subscriber.putNext(TranscriptionResult(text: result.bestTranscription.formattedString, confidence: confidence))
subscriber.putCompletion()
subscriber.putNext(TranscriptionResult(text: result.bestTranscription.formattedString, confidence: confidence, isFinal: result.isFinal))
if result.isFinal {
subscriber.putCompletion()
}
} else {
print("transcribeAudio: locale: \(locale), error: \(String(describing: error))")
@@ -91,7 +95,12 @@ private func transcribeAudio(path: String, locale: String) -> Signal<Transcripti
|> runOn(.mainQueue())
}
public func transcribeAudio(path: String, appLocale: String) -> Signal<String?, NoError> {
public struct LocallyTranscribedAudio {
public var text: String
public var isFinal: Bool
}
public func transcribeAudio(path: String, appLocale: String) -> Signal<LocallyTranscribedAudio?, NoError> {
var signals: [Signal<TranscriptionResult?, NoError>] = []
var locales: [String] = []
if !locales.contains(Locale.current.identifier) {
@@ -113,10 +122,12 @@ public func transcribeAudio(path: String, appLocale: String) -> Signal<String?,
}
return resultSignal
|> map { results -> String? in
|> map { results -> LocallyTranscribedAudio? in
let sortedResults = results.compactMap({ $0 }).sorted(by: { lhs, rhs in
return lhs.confidence > rhs.confidence
})
return sortedResults.first?.text
return sortedResults.first.flatMap { result -> LocallyTranscribedAudio in
return LocallyTranscribedAudio(text: result.text, isFinal: result.isFinal)
}
}
}
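With shouldReportPartialResults enabled, the signal above now emits intermediate transcriptions before completing on the final recognition result. A minimal consumer sketch, assuming SwiftSignalKit's start(next:completed:) and an illustrative audio path and locale (placeholders, not part of this commit):

let transcriptionDisposable = (transcribeAudio(path: "/tmp/audio.m4a", appLocale: "en")
|> deliverOnMainQueue).start(next: { transcription in
    guard let transcription = transcription else {
        // No recognizer produced a usable result.
        return
    }
    if transcription.isFinal {
        // Final text; the underlying recognition task is finished.
        print("final: \(transcription.text)")
    } else {
        // Partial text; further updates follow until isFinal is true.
        print("partial: \(transcription.text)")
    }
}, completed: {
    print("transcription finished")
})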


@@ -332,6 +332,20 @@ public extension TelegramEngine {
return _internal_transcribeAudio(postbox: self.account.postbox, network: self.account.network, audioTranscriptionManager: self.account.stateManager.audioTranscriptionManager, messageId: messageId)
}
public func storeLocallyTranscribedAudio(messageId: MessageId, text: String, isFinal: Bool) -> Signal<Never, NoError> {
return self.account.postbox.transaction { transaction -> Void in
transaction.updateMessage(messageId, update: { currentMessage in
let storeForwardInfo = currentMessage.forwardInfo.flatMap(StoreMessageForwardInfo.init)
var attributes = currentMessage.attributes.filter { !($0 is AudioTranscriptionMessageAttribute) }
attributes.append(AudioTranscriptionMessageAttribute(id: 0, text: text, isPending: !isFinal, didRate: false))
return .update(StoreMessage(id: currentMessage.id, globallyUniqueId: currentMessage.globallyUniqueId, groupingKey: currentMessage.groupingKey, threadId: currentMessage.threadId, timestamp: currentMessage.timestamp, flags: StoreMessageFlags(currentMessage.flags), tags: currentMessage.tags, globalTags: currentMessage.globalTags, localTags: currentMessage.localTags, forwardInfo: storeForwardInfo, authorId: currentMessage.author?.id, text: currentMessage.text, attributes: attributes, media: currentMessage.media))
})
}
|> ignoreValues
}
public func rateAudioTranscription(messageId: MessageId, id: Int64, isGood: Bool) -> Signal<Never, NoError> {
return _internal_rateAudioTranscription(postbox: self.account.postbox, network: self.account.network, messageId: messageId, id: id, isGood: isGood)
}
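storeLocallyTranscribedAudio persists a locally produced transcription on the message by replacing any existing AudioTranscriptionMessageAttribute; isFinal: false is stored as isPending: true so the UI can keep treating the text as provisional. A call-site sketch, assuming an available TelegramEngine instance `engine` and a placeholder messageId:

// Store an intermediate result; the attribute is marked pending so the
// bubble can render the text while recognition continues.
let _ = engine.messages.storeLocallyTranscribedAudio(
    messageId: messageId,
    text: "partial transcription so far",
    isFinal: false
).start()

// Later, store the final result; isPending becomes false.
let _ = engine.messages.storeLocallyTranscribedAudio(
    messageId: messageId,
    text: "final transcription",
    isFinal: true
).start()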


@@ -32,7 +32,7 @@ private struct FetchControls {
}
private enum TranscribedText {
case success(String)
case success(text: String, isPending: Bool)
case error
}
@@ -40,7 +40,7 @@ private func transcribedText(message: Message) -> TranscribedText? {
for attribute in message.attributes {
if let attribute = attribute as? AudioTranscriptionMessageAttribute {
if !attribute.text.isEmpty {
return .success(attribute.text)
return .success(text: attribute.text, isPending: attribute.isPending)
} else {
return .error
}
@@ -343,7 +343,21 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
return
}
if transcribedText(message: message) == nil {
var shouldBeginTranscription = false
var shouldExpandNow = false
if let result = transcribedText(message: message) {
shouldExpandNow = true
if case let .success(_, isPending) = result {
shouldBeginTranscription = isPending
} else {
shouldBeginTranscription = true
}
} else {
shouldBeginTranscription = true
}
if shouldBeginTranscription {
if self.transcribeDisposable == nil {
self.audioTranscriptionState = .inProgress
self.requestUpdateLayout(true)
@@ -351,7 +365,7 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
if context.sharedContext.immediateExperimentalUISettings.localTranscription {
let appLocale = presentationData.strings.baseLanguageCode
let signal: Signal<String?, NoError> = context.engine.data.get(TelegramEngine.EngineData.Item.Messages.Message(id: message.id))
let signal: Signal<LocallyTranscribedAudio?, NoError> = context.engine.data.get(TelegramEngine.EngineData.Item.Messages.Message(id: message.id))
|> mapToSignal { message -> Signal<String?, NoError> in
guard let message = message else {
return .single(nil)
@@ -376,14 +390,26 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
return TempBox.shared.tempFile(fileName: "audio.m4a").path
})
}
|> mapToSignal { result -> Signal<String?, NoError> in
|> mapToSignal { result -> Signal<LocallyTranscribedAudio?, NoError> in
guard let result = result else {
return .single(nil)
}
return transcribeAudio(path: result, appLocale: appLocale)
}
let _ = signal.start(next: { [weak self] result in
self.transcribeDisposable = (signal
|> deliverOnMainQueue).start(next: { [weak self] result in
guard let strongSelf = self, let arguments = strongSelf.arguments else {
return
}
if let result = result {
let _ = arguments.context.engine.messages.storeLocallyTranscribedAudio(messageId: arguments.message.id, text: result.text, isFinal: result.isFinal).start()
} else {
strongSelf.audioTranscriptionState = .collapsed
strongSelf.requestUpdateLayout(true)
}
}, completed: { [weak self] in
guard let strongSelf = self else {
return
}
@@ -399,7 +425,9 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
})
}
}
} else {
}
if shouldExpandNow {
switch self.audioTranscriptionState {
case .expanded:
self.audioTranscriptionState = .collapsed
@@ -615,8 +643,12 @@ final class ChatMessageInteractiveFileNode: ASDisplayNode {
if let transcribedText = transcribedText, case .expanded = effectiveAudioTranscriptionState {
switch transcribedText {
case let .success(text):
textString = NSAttributedString(string: text, font: textFont, textColor: messageTheme.primaryTextColor)
case let .success(text, isPending):
var resultText = text
if isPending {
resultText += " [...]"
}
textString = NSAttributedString(string: resultText, font: textFont, textColor: messageTheme.primaryTextColor)
case .error:
let errorTextFont = Font.regular(floor(arguments.presentationData.fontSize.baseDisplaySize * 15.0 / 17.0))
//TODO:localize