support poll translation

This commit is contained in:
Mikhail Filimonov 2024-05-01 20:36:12 +04:00
parent 1f4297e0db
commit 5ac7dc9356
4 changed files with 157 additions and 36 deletions

View File

@ -210,6 +210,7 @@ private var declaredEncodables: Void = {
declareEncodable(MediaSpoilerMessageAttribute.self, f: { MediaSpoilerMessageAttribute(decoder: $0) })
declareEncodable(AuthSessionInfoAttribute.self, f: { AuthSessionInfoAttribute(decoder: $0) })
declareEncodable(TranslationMessageAttribute.self, f: { TranslationMessageAttribute(decoder: $0) })
declareEncodable(TranslationMessageAttribute.Additional.self, f: { TranslationMessageAttribute.Additional(decoder: $0) })
declareEncodable(SynchronizeAutosaveItemOperation.self, f: { SynchronizeAutosaveItemOperation(decoder: $0) })
declareEncodable(TelegramMediaStory.self, f: { TelegramMediaStory(decoder: $0) })
declareEncodable(SynchronizeViewStoriesOperation.self, f: { SynchronizeViewStoriesOperation(decoder: $0) })

View File

@ -1,10 +1,35 @@
import Postbox
public class TranslationMessageAttribute: MessageAttribute, Equatable {
public struct Additional : PostboxCoding, Equatable {
public let text: String
public let entities: [MessageTextEntity]
public init(text: String, entities: [MessageTextEntity]) {
self.text = text
self.entities = entities
}
public init(decoder: PostboxDecoder) {
self.text = decoder.decodeStringForKey("text", orElse: "")
self.entities = decoder.decodeObjectArrayWithDecoderForKey("entities")
}
public func encode(_ encoder: PostboxEncoder) {
encoder.encodeString(self.text, forKey: "text")
encoder.encodeObjectArray(self.entities, forKey: "entities")
}
}
public let text: String
public let entities: [MessageTextEntity]
public let toLang: String
public let additional:[Additional]
public var associatedPeerIds: [PeerId] {
return []
}
@ -12,16 +37,19 @@ public class TranslationMessageAttribute: MessageAttribute, Equatable {
public init(
text: String,
entities: [MessageTextEntity],
additional:[Additional] = [],
toLang: String
) {
self.text = text
self.entities = entities
self.toLang = toLang
self.additional = additional
}
required public init(decoder: PostboxDecoder) {
self.text = decoder.decodeStringForKey("text", orElse: "")
self.entities = decoder.decodeObjectArrayWithDecoderForKey("entities")
self.additional = decoder.decodeObjectArrayWithDecoderForKey("additional")
self.toLang = decoder.decodeStringForKey("toLang", orElse: "")
}
@ -29,6 +57,7 @@ public class TranslationMessageAttribute: MessageAttribute, Equatable {
encoder.encodeString(self.text, forKey: "text")
encoder.encodeObjectArray(self.entities, forKey: "entities")
encoder.encodeString(self.toLang, forKey: "toLang")
encoder.encodeObjectArray(self.additional, forKey: "additional")
}
public static func ==(lhs: TranslationMessageAttribute, rhs: TranslationMessageAttribute) -> Bool {
@ -41,6 +70,9 @@ public class TranslationMessageAttribute: MessageAttribute, Equatable {
if lhs.toLang != rhs.toLang {
return false
}
if lhs.additional != rhs.additional {
return false
}
return true
}
}

View File

@ -504,8 +504,12 @@ public extension TelegramEngine {
return EngineMessageReactionListContext(account: self.account, message: message, readStats: readStats, reaction: reaction)
}
public func translate(text: String, toLang: String) -> Signal<String?, TranslationError> {
return _internal_translate(network: self.account.network, text: text, toLang: toLang)
public func translate(text: String, toLang: String, entities: [MessageTextEntity] = []) -> Signal<(String, [MessageTextEntity])?, TranslationError> {
return _internal_translate(network: self.account.network, text: text, toLang: toLang, entities: entities)
}
public func translate(texts: [(String, [MessageTextEntity])], toLang: String) -> Signal<[(String, [MessageTextEntity])], TranslationError> {
return _internal_translate_texts(network: self.account.network, texts: texts, toLang: toLang)
}
public func translateMessages(messageIds: [EngineMessage.Id], toLang: String) -> Signal<Void, TranslationError> {

View File

@ -13,11 +13,11 @@ public enum TranslationError {
case limitExceeded
}
func _internal_translate(network: Network, text: String, toLang: String) -> Signal<String?, TranslationError> {
func _internal_translate(network: Network, text: String, toLang: String, entities: [MessageTextEntity] = []) -> Signal<(String, [MessageTextEntity])?, TranslationError> {
var flags: Int32 = 0
flags |= (1 << 1)
return network.request(Api.functions.messages.translateText(flags: flags, peer: nil, id: nil, text: [.textWithEntities(text: text, entities: [])], toLang: toLang))
return network.request(Api.functions.messages.translateText(flags: flags, peer: nil, id: nil, text: [.textWithEntities(text: text, entities: apiEntitiesFromMessageTextEntities(entities, associatedPeers: SimpleDictionary()))], toLang: toLang))
|> mapError { error -> TranslationError in
if error.errorDescription.hasPrefix("FLOOD_WAIT") {
return .limitExceeded
@ -33,11 +33,11 @@ func _internal_translate(network: Network, text: String, toLang: String) -> Sign
return .generic
}
}
|> mapToSignal { result -> Signal<String?, TranslationError> in
|> mapToSignal { result -> Signal<(String, [MessageTextEntity])?, TranslationError> in
switch result {
case let .translateResult(results):
if case let .textWithEntities(text, _) = results.first {
return .single(text)
if case let .textWithEntities(text, entities) = results.first {
return .single((text, messageTextEntitiesFromApiEntities(entities)))
} else {
return .single(nil)
}
@ -45,59 +45,143 @@ func _internal_translate(network: Network, text: String, toLang: String) -> Sign
}
}
func _internal_translate_texts(network: Network, texts: [(String, [MessageTextEntity])], toLang: String) -> Signal<[(String, [MessageTextEntity])], TranslationError> {
var flags: Int32 = 0
flags |= (1 << 1)
var apiTexts: [Api.TextWithEntities] = []
for text in texts {
apiTexts.append(.textWithEntities(text: text.0, entities: apiEntitiesFromMessageTextEntities(text.1, associatedPeers: SimpleDictionary())))
}
return network.request(Api.functions.messages.translateText(flags: flags, peer: nil, id: nil, text: apiTexts, toLang: toLang))
|> mapError { error -> TranslationError in
if error.errorDescription.hasPrefix("FLOOD_WAIT") {
return .limitExceeded
} else if error.errorDescription == "MSG_ID_INVALID" {
return .invalidMessageId
} else if error.errorDescription == "INPUT_TEXT_EMPTY" {
return .textIsEmpty
} else if error.errorDescription == "INPUT_TEXT_TOO_LONG" {
return .textTooLong
} else if error.errorDescription == "TO_LANG_INVALID" {
return .invalidLanguage
} else {
return .generic
}
}
|> mapToSignal { result -> Signal<[(String, [MessageTextEntity])], TranslationError> in
var texts: [(String, [MessageTextEntity])] = []
switch result {
case let .translateResult(results):
for result in results {
if case let .textWithEntities(text, entities) = result {
texts.append((text, messageTextEntitiesFromApiEntities(entities)))
}
}
}
return .single(texts)
}
}
func _internal_translateMessages(account: Account, messageIds: [EngineMessage.Id], toLang: String) -> Signal<Void, TranslationError> {
guard let peerId = messageIds.first?.peerId else {
return .never()
}
return account.postbox.transaction { transaction -> Api.InputPeer? in
return transaction.getPeer(peerId).flatMap(apiInputPeer)
return account.postbox.transaction { transaction -> (Api.InputPeer?, [Message]) in
return (transaction.getPeer(peerId).flatMap(apiInputPeer), messageIds.compactMap({ transaction.getMessage($0) }))
}
|> castError(TranslationError.self)
|> mapToSignal { inputPeer -> Signal<Void, TranslationError> in
|> mapToSignal { (inputPeer, messages) -> Signal<Void, TranslationError> in
guard let inputPeer = inputPeer else {
return .never()
}
let polls = messages.compactMap { msg in
if let poll = msg.media.first as? TelegramMediaPoll {
return (poll, msg.id)
} else {
return nil
}
}
let pollSignals = polls.map { (poll, id) in
var texts: [(String, [MessageTextEntity])] = []
texts.append((poll.text, poll.textEntities))
for option in poll.options {
texts.append((option.text, option.entities))
}
return _internal_translate_texts(network: account.network, texts: texts, toLang: toLang)
}
var flags: Int32 = 0
flags |= (1 << 0)
let id: [Int32] = messageIds.map { $0.id }
return account.network.request(Api.functions.messages.translateText(flags: flags, peer: inputPeer, id: id, text: nil, toLang: toLang))
|> mapError { error -> TranslationError in
if error.errorDescription.hasPrefix("FLOOD_WAIT") {
return .limitExceeded
} else if error.errorDescription == "MSG_ID_INVALID" {
return .invalidMessageId
} else if error.errorDescription == "INPUT_TEXT_EMPTY" {
return .textIsEmpty
} else if error.errorDescription == "INPUT_TEXT_TOO_LONG" {
return .textTooLong
} else if error.errorDescription == "TO_LANG_INVALID" {
return .invalidLanguage
} else {
return .generic
let msgs: Signal<Api.messages.TranslatedText?, TranslationError>
if id.isEmpty {
msgs = .single(nil)
} else {
msgs = account.network.request(Api.functions.messages.translateText(flags: flags, peer: inputPeer, id: id, text: nil, toLang: toLang))
|> map(Optional.init)
|> mapError { error -> TranslationError in
if error.errorDescription.hasPrefix("FLOOD_WAIT") {
return .limitExceeded
} else if error.errorDescription == "MSG_ID_INVALID" {
return .invalidMessageId
} else if error.errorDescription == "INPUT_TEXT_EMPTY" {
return .textIsEmpty
} else if error.errorDescription == "INPUT_TEXT_TOO_LONG" {
return .textTooLong
} else if error.errorDescription == "TO_LANG_INVALID" {
return .invalidLanguage
} else {
return .generic
}
}
}
|> mapToSignal { result -> Signal<Void, TranslationError> in
guard case let .translateResult(results) = result else {
return .complete()
}
return combineLatest(msgs, combineLatest(pollSignals))
|> mapToSignal { (result, pollResults) -> Signal<Void, TranslationError> in
return account.postbox.transaction { transaction in
var index = 0
for result in results {
let messageId = messageIds[index]
if case let .textWithEntities(text, entities) = result {
let updatedAttribute: TranslationMessageAttribute = TranslationMessageAttribute(text: text, entities: messageTextEntitiesFromApiEntities(entities), toLang: toLang)
transaction.updateMessage(messageId, update: { currentMessage in
if case let .translateResult(results) = result {
var index = 0
for result in results {
let messageId = messageIds[index]
if case let .textWithEntities(text, entities) = result {
let updatedAttribute: TranslationMessageAttribute = TranslationMessageAttribute(text: text, entities: messageTextEntitiesFromApiEntities(entities), toLang: toLang)
transaction.updateMessage(messageId, update: { currentMessage in
let storeForwardInfo = currentMessage.forwardInfo.flatMap(StoreMessageForwardInfo.init)
var attributes = currentMessage.attributes.filter { !($0 is TranslationMessageAttribute) }
attributes.append(updatedAttribute)
return .update(StoreMessage(id: currentMessage.id, globallyUniqueId: currentMessage.globallyUniqueId, groupingKey: currentMessage.groupingKey, threadId: currentMessage.threadId, timestamp: currentMessage.timestamp, flags: StoreMessageFlags(currentMessage.flags), tags: currentMessage.tags, globalTags: currentMessage.globalTags, localTags: currentMessage.localTags, forwardInfo: storeForwardInfo, authorId: currentMessage.author?.id, text: currentMessage.text, attributes: attributes, media: currentMessage.media))
})
}
index += 1
}
}
if !pollResults.isEmpty {
for (i, poll) in polls.enumerated() {
let result = pollResults[i]
transaction.updateMessage(poll.1, update: { currentMessage in
let storeForwardInfo = currentMessage.forwardInfo.flatMap(StoreMessageForwardInfo.init)
var attributes = currentMessage.attributes.filter { !($0 is TranslationMessageAttribute) }
var attrOptions: [TranslationMessageAttribute.Additional] = []
for (i, option) in poll.0.options.enumerated() {
let translated = result[i + 1]
attrOptions.append(.init(text: translated.0, entities: translated.1))
}
let updatedAttribute: TranslationMessageAttribute = TranslationMessageAttribute(text: result[0].0, entities: result[0].1, additional: attrOptions, toLang: toLang)
attributes.append(updatedAttribute)
return .update(StoreMessage(id: currentMessage.id, globallyUniqueId: currentMessage.globallyUniqueId, groupingKey: currentMessage.groupingKey, threadId: currentMessage.threadId, timestamp: currentMessage.timestamp, flags: StoreMessageFlags(currentMessage.flags), tags: currentMessage.tags, globalTags: currentMessage.globalTags, localTags: currentMessage.localTags, forwardInfo: storeForwardInfo, authorId: currentMessage.author?.id, text: currentMessage.text, attributes: attributes, media: currentMessage.media))
})
}
index += 1
}
}
|> castError(TranslationError.self)