From c83a2364acad3756ff7eefa1691a2c36bb4330e5 Mon Sep 17 00:00:00 2001 From: Peter <> Date: Mon, 21 Oct 2019 23:27:06 +0400 Subject: [PATCH] Calculate local unread count before applying notification --- BUCK | 7 +- .../NotificationService-Bridging-Header.h | 9 +- NotificationService/NotificationService.h | 9 +- NotificationService/NotificationService.m | 74 +- NotificationService/NotificationService.swift | 31 + NotificationService/Sync.h | 1 - NotificationService/Sync.m | 4 +- NotificationService/Sync.swift | 113 +- submodules/Database/Buffers/BUCK | 13 + .../Database/Buffers/Sources/Buffers.swift | 173 ++ .../Database/MessageHistoryMetadataTable/BUCK | 17 + .../Sources/MessageHistoryMetadataTable.swift | 342 +++ .../MessageHistoryReadStateTable/BUCK | 16 + .../MessageHistoryReadStateTable.swift | 576 +++++ submodules/Database/MurmurHash/BUCK | 20 + .../MurmurHash/Sources/MurMurHash32.h | 12 + .../MurmurHash/Sources/MurMurHash32.m | 120 + .../MurmurHash/Sources/MurmurHash.swift | 11 + submodules/Database/PostboxCoding/BUCK | 15 + .../PostboxCoding/Sources/PostboxCoding.swift | 1297 +++++++++++ submodules/Database/PostboxDataTypes/BUCK | 15 + .../Sources/ChatListTotalUnreadState.swift | 86 + .../PostboxDataTypes/Sources/MessageId.swift | 280 +++ .../Sources/MessageIndex.swift | 79 + .../Sources/PeerGroupId.swift | 23 + .../PostboxDataTypes/Sources/PeerId.swift | 89 + .../Sources/PeerReadState.swift | 152 ++ .../Sources/PeerSummaryCounterTags.swift | 32 + submodules/Database/Table/BUCK | 14 + submodules/Database/Table/Sources/Table.swift | 18 + submodules/Database/ValueBox/BUCK | 16 + .../Database/ValueBox/Sources/Database.swift | 72 + .../ValueBox/Sources/SqliteValueBox.swift | 2066 +++++++++++++++++ .../Database/ValueBox/Sources/ValueBox.swift | 95 + .../ValueBox/Sources/ValueBoxKey.swift | 252 ++ .../ValueBox/Sources/ValueBoxLogger.swift | 5 + .../Postbox/Postbox/ChatListIndexTable.swift | 6 +- .../Postbox/GroupMessageStatsTable.swift | 10 + .../Postbox/MessageHistoryMetadataTable.swift | 10 - submodules/Postbox/Postbox/Table.swift | 12 +- 40 files changed, 6091 insertions(+), 101 deletions(-) create mode 100644 NotificationService/NotificationService.swift create mode 100644 submodules/Database/Buffers/BUCK create mode 100644 submodules/Database/Buffers/Sources/Buffers.swift create mode 100644 submodules/Database/MessageHistoryMetadataTable/BUCK create mode 100644 submodules/Database/MessageHistoryMetadataTable/Sources/MessageHistoryMetadataTable.swift create mode 100644 submodules/Database/MessageHistoryReadStateTable/BUCK create mode 100644 submodules/Database/MessageHistoryReadStateTable/Sources/MessageHistoryReadStateTable.swift create mode 100644 submodules/Database/MurmurHash/BUCK create mode 100644 submodules/Database/MurmurHash/Sources/MurMurHash32.h create mode 100644 submodules/Database/MurmurHash/Sources/MurMurHash32.m create mode 100644 submodules/Database/MurmurHash/Sources/MurmurHash.swift create mode 100644 submodules/Database/PostboxCoding/BUCK create mode 100644 submodules/Database/PostboxCoding/Sources/PostboxCoding.swift create mode 100644 submodules/Database/PostboxDataTypes/BUCK create mode 100644 submodules/Database/PostboxDataTypes/Sources/ChatListTotalUnreadState.swift create mode 100644 submodules/Database/PostboxDataTypes/Sources/MessageId.swift create mode 100644 submodules/Database/PostboxDataTypes/Sources/MessageIndex.swift create mode 100644 submodules/Database/PostboxDataTypes/Sources/PeerGroupId.swift create mode 100644 submodules/Database/PostboxDataTypes/Sources/PeerId.swift create mode 100644 submodules/Database/PostboxDataTypes/Sources/PeerReadState.swift create mode 100644 submodules/Database/PostboxDataTypes/Sources/PeerSummaryCounterTags.swift create mode 100644 submodules/Database/Table/BUCK create mode 100644 submodules/Database/Table/Sources/Table.swift create mode 100644 submodules/Database/ValueBox/BUCK create mode 100644 submodules/Database/ValueBox/Sources/Database.swift create mode 100644 submodules/Database/ValueBox/Sources/SqliteValueBox.swift create mode 100644 submodules/Database/ValueBox/Sources/ValueBox.swift create mode 100644 submodules/Database/ValueBox/Sources/ValueBoxKey.swift create mode 100644 submodules/Database/ValueBox/Sources/ValueBoxLogger.swift diff --git a/BUCK b/BUCK index cebcf05495..58bf24ecf0 100644 --- a/BUCK +++ b/BUCK @@ -339,10 +339,13 @@ apple_binary( deps = [ "//submodules/BuildConfig:BuildConfig", "//submodules/MtProtoKit:MtProtoKit#shared", + "//submodules/SSignalKit/SwiftSignalKit:SwiftSignalKit#shared", "//submodules/EncryptionProvider:EncryptionProvider", + "//submodules/Database/ValueBox:ValueBox", + "//submodules/Database/PostboxDataTypes:PostboxDataTypes", + "//submodules/Database/MessageHistoryReadStateTable:MessageHistoryReadStateTable", + "//submodules/Database/MessageHistoryMetadataTable:MessageHistoryMetadataTable", "//submodules/sqlcipher:sqlcipher", - #"//submodules/Postbox:Postbox#shared", - #"//submodules/SyncCore:SyncCore#shared", ], frameworks = [ "$SDKROOT/System/Library/Frameworks/Foundation.framework", diff --git a/NotificationService/NotificationService-Bridging-Header.h b/NotificationService/NotificationService-Bridging-Header.h index 542d5fdd23..d6f4fefa20 100644 --- a/NotificationService/NotificationService-Bridging-Header.h +++ b/NotificationService/NotificationService-Bridging-Header.h @@ -1,13 +1,6 @@ #ifndef NotificationService_BridgingHeader_h #define NotificationService_BridgingHeader_h -#import -#import - -@protocol SyncProvider - -- (void)addIncomingMessageWithRootPath:(NSString * _Nonnull)rootPath accountId:(int64_t)accountId encryptionParameters:(DeviceSpecificEncryptionParameters * _Nonnull)encryptionParameters peerId:(int64_t)peerId messageId:(int32_t)messageId completion:(void (^)(int32_t))completion; - -@end +#import "NotificationService.h" #endif diff --git a/NotificationService/NotificationService.h b/NotificationService/NotificationService.h index 91142ad7de..2903b0bc03 100644 --- a/NotificationService/NotificationService.h +++ b/NotificationService/NotificationService.h @@ -1,9 +1,16 @@ #import #import +#import NS_ASSUME_NONNULL_BEGIN -@interface NotificationService : UNNotificationServiceExtension +@interface NotificationServiceImpl : NSObject + +- (instancetype)initWithCountIncomingMessage:(void (^)(NSString *, int64_t, DeviceSpecificEncryptionParameters *, int64_t, int32_t))countIncomingMessage; + +- (void)updateUnreadCount:(int32_t)unreadCount; +- (void)didReceiveNotificationRequest:(UNNotificationRequest *)request withContentHandler:(void (^)(UNNotificationContent * _Nonnull))contentHandler; +- (void)serviceExtensionTimeWillExpire; @end diff --git a/NotificationService/NotificationService.m b/NotificationService/NotificationService.m index 4acf92730f..59b0436ea2 100644 --- a/NotificationService/NotificationService.m +++ b/NotificationService/NotificationService.m @@ -1,5 +1,7 @@ #import "NotificationService.h" +#import + #import #import @@ -34,47 +36,43 @@ static int64_t makePeerId(int32_t namespace, int32_t value) { return (((int64_t)(namespace)) << 32) | ((int64_t)((uint64_t)((uint32_t)value))); } -@interface ParsedNotificationMessage : NSObject - -@property (nonatomic, readonly) int64_t accountId; -@property (nonatomic, readonly) int64_t peerId; -@property (nonatomic, readonly) int32_t messageId; - -@end - -@implementation ParsedNotificationMessage - -- (instancetype)initWithAccountId:(int64_t)accountId peerId:(int64_t)peerId messageId:(int64_t)messageId { - self = [super init]; - if (self != nil) { - _accountId = accountId; - _peerId = peerId; - _messageId = messageId; +static void reportMemory() { + struct task_basic_info info; + mach_msg_type_number_t size = TASK_BASIC_INFO_COUNT; + kern_return_t kerr = task_info(mach_task_self(), TASK_BASIC_INFO, (task_info_t)&info, &size); + if (kerr == KERN_SUCCESS) { + NSLog(@"Memory in use (in bytes): %lu", info.resident_size); + NSLog(@"Memory in use (in MiB): %f", ((CGFloat)info.resident_size / 1048576)); + } else { + NSLog(@"Error with task_info(): %s", mach_error_string(kerr)); } - return self; } -@end - -@interface NotificationService () { +@interface NotificationServiceImpl () { + void (^_countIncomingMessage)(NSString *, int64_t, DeviceSpecificEncryptionParameters *, int64_t, int32_t); + NSString * _Nullable _rootPath; + DeviceSpecificEncryptionParameters * _Nullable _deviceSpecificEncryptionParameters; NSString * _Nullable _baseAppBundleId; void (^_contentHandler)(UNNotificationContent *); UNMutableNotificationContent * _Nullable _bestAttemptContent; void (^_cancelFetch)(void); - ParsedNotificationMessage * _Nullable _parsedMessage; NSNumber * _Nullable _updatedUnreadCount; bool _contentReady; } @end -@implementation NotificationService +@implementation NotificationServiceImpl -- (instancetype)init { +- (instancetype)initWithCountIncomingMessage:(void (^)(NSString *, int64_t, DeviceSpecificEncryptionParameters *, int64_t, int32_t))countIncomingMessage { self = [super init]; if (self != nil) { + reportMemory(); + + _countIncomingMessage = [countIncomingMessage copy]; + NSString *appBundleIdentifier = [NSBundle mainBundle].bundleIdentifier; NSRange lastDotRange = [appBundleIdentifier rangeOfString:@"." options:NSBackwardsSearch]; if (lastDotRange.location != NSNotFound) { @@ -85,6 +83,9 @@ static int64_t makePeerId(int32_t namespace, int32_t value) { if (appGroupUrl != nil) { NSString *rootPath = [[appGroupUrl path] stringByAppendingPathComponent:@"telegram-data"]; _rootPath = rootPath; + if (rootPath != nil) { + _deviceSpecificEncryptionParameters = [BuildConfig deviceSpecificEncryptionParameters:rootPath baseAppBundleId:_baseAppBundleId]; + } } else { NSAssert(false, @"appGroupUrl == nil"); } @@ -97,6 +98,7 @@ static int64_t makePeerId(int32_t namespace, int32_t value) { - (void)completeWithBestAttemptContent { _contentReady = true; + //_updatedUnreadCount = @(-1); if (_contentReady && _updatedUnreadCount) { [self _internalComplete]; } @@ -110,6 +112,8 @@ static int64_t makePeerId(int32_t namespace, int32_t value) { } - (void)_internalComplete { + reportMemory(); + #ifdef __IPHONE_13_0 if (_baseAppBundleId != nil) { BGAppRefreshTaskRequest *request = [[BGAppRefreshTaskRequest alloc] initWithIdentifier:[_baseAppBundleId stringByAppendingString:@".refresh"]]; @@ -193,16 +197,9 @@ static int64_t makePeerId(int32_t namespace, int32_t value) { peerId = makePeerId(PeerNamespaceCloudChannel, [channelIdString intValue]); } - _parsedMessage = [[ParsedNotificationMessage alloc] initWithAccountId:account.accountId peerId:peerId messageId:messageId]; - - __weak NotificationService *weakSelf = self; - [self addUnreadMessage:_rootPath accountId:account.accountId encryptionParameters:nil peerId:peerId messageId:messageId completion:^(int32_t badge) { - __strong NotificationService *strongSelf = weakSelf; - if (strongSelf == nil) { - return; - } - [strongSelf updateUnreadCount:badge]; - }]; + if (_countIncomingMessage && _deviceSpecificEncryptionParameters) { + _countIncomingMessage(_rootPath, account.accountId, _deviceSpecificEncryptionParameters, peerId, messageId); + } NSString *silentString = decryptedPayload[@"silent"]; if ([silentString isKindOfClass:[NSString class]]) { @@ -375,10 +372,10 @@ static int64_t makePeerId(int32_t namespace, int32_t value) { } else { BuildConfig *buildConfig = [[BuildConfig alloc] initWithBaseAppBundleId:_baseAppBundleId]; - __weak NotificationService *weakSelf = self; + __weak typeof(self) weakSelf = self; _cancelFetch = fetchImage(buildConfig, accountInfos.proxy, account, inputFileLocation, fileDatacenterId, ^(NSData * _Nullable data) { dispatch_async(dispatch_get_main_queue(), ^{ - __strong NotificationService *strongSelf = weakSelf; + __strong typeof(weakSelf) strongSelf = weakSelf; if (strongSelf == nil) { return; } @@ -431,11 +428,4 @@ static int64_t makePeerId(int32_t namespace, int32_t value) { } } -- (void)addUnreadMessage:(NSString * _Nonnull)rootPath accountId:(int64_t)accountId encryptionParameters:(DeviceSpecificEncryptionParameters * _Nonnull)encryptionParameters peerId:(int64_t)peerId messageId:(int32_t)messageId completion:(void (^)(int32_t))completion { - - if (completion) { - completion(-1); - } -} - @end diff --git a/NotificationService/NotificationService.swift b/NotificationService/NotificationService.swift new file mode 100644 index 0000000000..ee86842c8f --- /dev/null +++ b/NotificationService/NotificationService.swift @@ -0,0 +1,31 @@ +import Foundation +import UserNotifications + +@available(iOSApplicationExtension 10.0, *) +@objc(NotificationService) +final class NotificationService: UNNotificationServiceExtension { + private let impl: NotificationServiceImpl + + override init() { + var completion: ((Int32) -> Void)? + self.impl = NotificationServiceImpl(countIncomingMessage: { rootPath, accountId, encryptionParameters, peerId, messageId in + SyncProviderImpl().addIncomingMessage(withRootPath: rootPath, accountId: accountId, encryptionParameters: encryptionParameters, peerId: peerId, messageId: messageId, completion: { count in + completion?(count) + }) + }) + + super.init() + + completion = { [weak self] count in + self?.impl.updateUnreadCount(count) + } + } + + override func didReceive(_ request: UNNotificationRequest, withContentHandler contentHandler: @escaping (UNNotificationContent) -> Void) { + self.impl.didReceive(request, withContentHandler: contentHandler) + } + + override func serviceExtensionTimeWillExpire() { + self.impl.serviceExtensionTimeWillExpire() + } +} diff --git a/NotificationService/Sync.h b/NotificationService/Sync.h index 646fcb770d..274ac473b4 100644 --- a/NotificationService/Sync.h +++ b/NotificationService/Sync.h @@ -1,4 +1,3 @@ #import -#import diff --git a/NotificationService/Sync.m b/NotificationService/Sync.m index f2b2d6a155..6bc986aefe 100644 --- a/NotificationService/Sync.m +++ b/NotificationService/Sync.m @@ -1 +1,3 @@ -#import "Sync.h" \ No newline at end of file +#import "Sync.h" +//#import + diff --git a/NotificationService/Sync.swift b/NotificationService/Sync.swift index 0de5fd0b6b..7ec237a2f3 100644 --- a/NotificationService/Sync.swift +++ b/NotificationService/Sync.swift @@ -1,32 +1,95 @@ -//import SwiftSignalKit -//import Postbox -//import SyncCore -//import BuildConfig +import Foundation +import SwiftSignalKit +import ValueBox +import MessageHistoryReadStateTable +import MessageHistoryMetadataTable +import PostboxDataTypes -@objc(SyncProviderImpl) -final class SyncProviderImpl: NSObject { +private func accountRecordIdPathName(_ id: Int64) -> String { + return "account-\(UInt64(bitPattern: id))" } -/*@objc(SyncProviderImpl) -final class SyncProviderImpl: NSObject, SyncProvider { - func addIncomingMessage(withRootPath rootPath: String, accountId: Int64, encryptionParameters: DeviceSpecificEncryptionParameters, peerId: Int64, messageId: Int32, completion: ((Int32) -> Void)!) { - let _ = (addIncomingMessageImpl(rootPath: rootPath, accountId: accountId, encryptionParameters: ValueBoxEncryptionParameters(forceEncryptionIfNoSet: false, key: ValueBoxEncryptionParameters.Key(data: encryptionParameters.key)!, salt: ValueBoxEncryptionParameters.Salt(data: encryptionParameters.salt)!), peerId: peerId, messageId: messageId) - |> deliverOnMainQueue).start(next: { result in - completion(Int32(clamping: result)) - }) +private final class ValueBoxLoggerImpl: ValueBoxLogger { + func log(_ what: String) { + print("ValueBox: \(what)") } } -private func addIncomingMessageImpl(rootPath: String, accountId: Int64, encryptionParameters: ValueBoxEncryptionParameters, peerId: Int64, messageId: Int32) -> Signal { - return accountTransaction(rootPath: rootPath, id: AccountRecordId(rawValue: accountId), encryptionParameters: encryptionParameters, transaction: { transaction -> Int in - transaction.countIncomingMessage(id: MessageId(peerId: PeerId(peerId), namespace: Namespaces.Message.Cloud, id: messageId)) - let totalUnreadState = transaction.getTotalUnreadState() - let totalCount = totalUnreadState.count(for: .filtered, in: .chats, with: [ - .regularChatsAndPrivateGroups, - .publicGroups, - .channels - ]) - return Int(totalCount) - }) +private extension PeerSummaryCounterTags { + static let regularChatsAndPrivateGroups = PeerSummaryCounterTags(rawValue: 1 << 0) + static let publicGroups = PeerSummaryCounterTags(rawValue: 1 << 1) + static let channels = PeerSummaryCounterTags(rawValue: 1 << 2) +} + +private struct Namespaces { + struct Message { + static let Cloud: Int32 = 0 + } + + struct Peer { + static let CloudUser: Int32 = 0 + static let CloudGroup: Int32 = 1 + static let CloudChannel: Int32 = 2 + static let SecretChat: Int32 = 3 + } +} + +final class SyncProviderImpl { + func addIncomingMessage(withRootPath rootPath: String, accountId: Int64, encryptionParameters: DeviceSpecificEncryptionParameters, peerId: Int64, messageId: Int32, completion: @escaping (Int32) -> Void) { + Queue.mainQueue().async { + let basePath = rootPath + "/" + accountRecordIdPathName(accountId) + "/postbox" + + let valueBox = SqliteValueBox(basePath: basePath + "/db", queue: Queue.mainQueue(), logger: ValueBoxLoggerImpl(), encryptionParameters: ValueBoxEncryptionParameters(forceEncryptionIfNoSet: false, key: ValueBoxEncryptionParameters.Key(data: encryptionParameters.key)!, salt: ValueBoxEncryptionParameters.Salt(data: encryptionParameters.salt)!), disableCache: true, upgradeProgress: { _ in + }) + + let metadataTable = MessageHistoryMetadataTable(valueBox: valueBox, table: MessageHistoryMetadataTable.tableSpec(10)) + let readStateTable = MessageHistoryReadStateTable(valueBox: valueBox, table: MessageHistoryReadStateTable.tableSpec(14), defaultMessageNamespaceReadStates: [:]) + + let peerId = PeerId(peerId) + + let initialCombinedState = readStateTable.getCombinedState(peerId) + let (combinedState, _) = readStateTable.addIncomingMessages(peerId, indices: Set([MessageIndex(id: MessageId(peerId: peerId, namespace: 0, id: messageId), timestamp: 1)])) + if let combinedState = combinedState { + let initialCount = initialCombinedState?.count ?? 0 + let updatedCount = combinedState.count + let deltaCount = max(0, updatedCount - initialCount) + + let tag: PeerSummaryCounterTags + if peerId.namespace == Namespaces.Peer.CloudChannel { + tag = .channels + } else { + tag = .regularChatsAndPrivateGroups + } + + var totalCount: Int32 = -1 + + var totalUnreadState = metadataTable.getChatListTotalUnreadState() + if var counters = totalUnreadState.absoluteCounters[tag] { + if initialCount == 0 && updatedCount > 0 { + counters.chatCount += 1 + } + counters.messageCount += deltaCount + totalUnreadState.absoluteCounters[tag] = counters + } + if var counters = totalUnreadState.filteredCounters[tag] { + if initialCount == 0 && updatedCount > 0 { + counters.chatCount += 1 + } + counters.messageCount += deltaCount + totalUnreadState.filteredCounters[tag] = counters + } + + totalCount = totalUnreadState.count(for: .filtered, in: .messages, with: [.channels, .publicGroups, .regularChatsAndPrivateGroups]) + metadataTable.setChatListTotalUnreadState(totalUnreadState) + metadataTable.setShouldReindexUnreadCounts(value: true) + + metadataTable.beforeCommit() + readStateTable.beforeCommit() + + completion(totalCount) + } else { + completion(-1) + } + } + } } -*/ diff --git a/submodules/Database/Buffers/BUCK b/submodules/Database/Buffers/BUCK new file mode 100644 index 0000000000..9a4ddc96ff --- /dev/null +++ b/submodules/Database/Buffers/BUCK @@ -0,0 +1,13 @@ +load("//Config:buck_rule_macros.bzl", "static_library") + +static_library( + name = "Buffers", + srcs = glob([ + "Sources/**/*.swift", + ]), + deps = [ + ], + frameworks = [ + "$SDKROOT/System/Library/Frameworks/Foundation.framework", + ], +) diff --git a/submodules/Database/Buffers/Sources/Buffers.swift b/submodules/Database/Buffers/Sources/Buffers.swift new file mode 100644 index 0000000000..d7544c07af --- /dev/null +++ b/submodules/Database/Buffers/Sources/Buffers.swift @@ -0,0 +1,173 @@ +import Foundation + +private let emptyMemory = malloc(1)! + +public class MemoryBuffer: Equatable, CustomStringConvertible { + public internal(set) var memory: UnsafeMutableRawPointer + var capacity: Int + public internal(set) var length: Int + var freeWhenDone: Bool + + public init(copyOf buffer: MemoryBuffer) { + self.memory = malloc(buffer.length) + memcpy(self.memory, buffer.memory, buffer.length) + self.capacity = buffer.length + self.length = buffer.length + self.freeWhenDone = true + } + + public init(memory: UnsafeMutableRawPointer, capacity: Int, length: Int, freeWhenDone: Bool) { + self.memory = memory + self.capacity = capacity + self.length = length + self.freeWhenDone = freeWhenDone + } + + public init(data: Data) { + if data.count == 0 { + self.memory = emptyMemory + self.capacity = 0 + self.length = 0 + self.freeWhenDone = false + } else { + self.memory = malloc(data.count)! + data.copyBytes(to: self.memory.assumingMemoryBound(to: UInt8.self), count: data.count) + self.capacity = data.count + self.length = data.count + self.freeWhenDone = false + } + } + + public init() { + self.memory = emptyMemory + self.capacity = 0 + self.length = 0 + self.freeWhenDone = false + } + + deinit { + if self.freeWhenDone { + free(self.memory) + } + } + + public var description: String { + let hexString = NSMutableString() + let bytes = self.memory.assumingMemoryBound(to: UInt8.self) + for i in 0 ..< self.length { + hexString.appendFormat("%02x", UInt(bytes[i])) + } + + return hexString as String + } + + public func makeData() -> Data { + if self.length == 0 { + return Data() + } else { + return Data(bytes: self.memory, count: self.length) + } + } + + public func withDataNoCopy(_ f: (Data) -> Void) { + f(Data(bytesNoCopy: self.memory, count: self.length, deallocator: .none)) + } + + public static func ==(lhs: MemoryBuffer, rhs: MemoryBuffer) -> Bool { + return lhs.length == rhs.length && memcmp(lhs.memory, rhs.memory, lhs.length) == 0 + } +} + +public final class WriteBuffer: MemoryBuffer { + public var offset = 0 + + public override init() { + super.init(memory: malloc(32), capacity: 32, length: 0, freeWhenDone: true) + } + + public func makeReadBufferAndReset() -> ReadBuffer { + let buffer = ReadBuffer(memory: self.memory, length: self.offset, freeWhenDone: true) + self.memory = malloc(32) + self.capacity = 32 + self.offset = 0 + return buffer + } + + public func readBufferNoCopy() -> ReadBuffer { + return ReadBuffer(memory: self.memory, length: self.offset, freeWhenDone: false) + } + + override public func makeData() -> Data { + return Data(bytes: self.memory.assumingMemoryBound(to: UInt8.self), count: self.offset) + } + + public func reset() { + self.offset = 0 + } + + public func write(_ data: UnsafeRawPointer, offset: Int, length: Int) { + if self.offset + length > self.capacity { + self.capacity = self.offset + length + 256 + if self.length == 0 { + self.memory = malloc(self.capacity)! + } else { + self.memory = realloc(self.memory, self.capacity) + } + } + memcpy(self.memory + self.offset, data + offset, length) + self.offset += length + self.length = self.offset + } + + public func write(_ data: Data) { + let length = data.count + if self.offset + length > self.capacity { + self.capacity = self.offset + length + 256 + if self.length == 0 { + self.memory = malloc(self.capacity)! + } else { + self.memory = realloc(self.memory, self.capacity) + } + } + data.copyBytes(to: self.memory.advanced(by: offset).assumingMemoryBound(to: UInt8.self), count: length) + self.offset += length + self.length = self.offset + } +} + +public final class ReadBuffer: MemoryBuffer { + public var offset = 0 + + override public init(data: Data) { + super.init(data: data) + } + + public init(memory: UnsafeMutableRawPointer, length: Int, freeWhenDone: Bool) { + super.init(memory: memory, capacity: length, length: length, freeWhenDone: freeWhenDone) + } + + public init(memoryBufferNoCopy: MemoryBuffer) { + super.init(memory: memoryBufferNoCopy.memory, capacity: memoryBufferNoCopy.length, length: memoryBufferNoCopy.length, freeWhenDone: false) + } + + public func dataNoCopy() -> Data { + return Data(bytesNoCopy: self.memory.assumingMemoryBound(to: UInt8.self), count: self.length, deallocator: .none) + } + + public func read(_ data: UnsafeMutableRawPointer, offset: Int, length: Int) { + memcpy(data + offset, self.memory.advanced(by: self.offset), length) + self.offset += length + } + + public func skip(_ length: Int) { + self.offset += length + } + + public func reset() { + self.offset = 0 + } + + public func sharedBufferNoCopy() -> ReadBuffer { + return ReadBuffer(memory: memory, length: length, freeWhenDone: false) + } +} diff --git a/submodules/Database/MessageHistoryMetadataTable/BUCK b/submodules/Database/MessageHistoryMetadataTable/BUCK new file mode 100644 index 0000000000..cc4d7c555c --- /dev/null +++ b/submodules/Database/MessageHistoryMetadataTable/BUCK @@ -0,0 +1,17 @@ +load("//Config:buck_rule_macros.bzl", "static_library") + +static_library( + name = "MessageHistoryMetadataTable", + srcs = glob([ + "Sources/**/*.swift", + ]), + deps = [ + "//submodules/Database/ValueBox:ValueBox", + "//submodules/Database/Table:Table", + "//submodules/Database/PostboxDataTypes:PostboxDataTypes", + "//submodules/Database/PostboxCoding:PostboxCoding", + ], + frameworks = [ + "$SDKROOT/System/Library/Frameworks/Foundation.framework", + ], +) diff --git a/submodules/Database/MessageHistoryMetadataTable/Sources/MessageHistoryMetadataTable.swift b/submodules/Database/MessageHistoryMetadataTable/Sources/MessageHistoryMetadataTable.swift new file mode 100644 index 0000000000..6314b2a861 --- /dev/null +++ b/submodules/Database/MessageHistoryMetadataTable/Sources/MessageHistoryMetadataTable.swift @@ -0,0 +1,342 @@ +import Foundation +import ValueBox +import Table +import PostboxCoding +import PostboxDataTypes +import Buffers + +private enum MetadataPrefix: Int8 { + case ChatListInitialized = 0 + case PeerNextMessageIdByNamespace = 2 + case NextStableMessageId = 3 + case ChatListTotalUnreadState = 4 + case NextPeerOperationLogIndex = 5 + case ChatListGroupInitialized = 6 + case GroupFeedIndexInitialized = 7 + case ShouldReindexUnreadCounts = 8 + case PeerHistoryInitialized = 9 +} + +public enum ChatListTotalUnreadStateCategory: Int32 { + case filtered = 0 + case raw = 1 +} + +public enum ChatListTotalUnreadStateStats: Int32 { + case messages = 0 + case chats = 1 +} + +private struct InitializedChatListKey: Hashable { + let groupId: PeerGroupId +} + +public final class MessageHistoryMetadataTable: Table { + public static func tableSpec(_ id: Int32) -> ValueBoxTable { + return ValueBoxTable(id: id, keyType: .binary, compactValuesOnCreation: true) + } + + private let sharedPeerHistoryInitializedKey = ValueBoxKey(length: 8 + 1) + private let sharedGroupFeedIndexInitializedKey = ValueBoxKey(length: 4 + 1) + private let sharedChatListGroupHistoryInitializedKey = ValueBoxKey(length: 4 + 1) + private let sharedPeerNextMessageIdByNamespaceKey = ValueBoxKey(length: 8 + 1 + 4) + private let sharedBuffer = WriteBuffer() + + private var initializedChatList = Set() + private var initializedHistoryPeerIds = Set() + private var initializedGroupFeedIndexIds = Set() + + private var peerNextMessageIdByNamespace: [PeerId: [MessageId.Namespace: MessageId.Id]] = [:] + private var updatedPeerNextMessageIdByNamespace: [PeerId: Set] = [:] + + private var nextMessageStableId: UInt32? + private var nextMessageStableIdUpdated = false + + private var chatListTotalUnreadState: ChatListTotalUnreadState? + private var chatListTotalUnreadStateUpdated = false + + private var nextPeerOperationLogIndex: UInt32? + private var nextPeerOperationLogIndexUpdated = false + + private var currentPinnedChatPeerIds: Set? + private var currentPinnedChatPeerIdsUpdated = false + + private func peerHistoryInitializedKey(_ id: PeerId) -> ValueBoxKey { + self.sharedPeerHistoryInitializedKey.setInt64(0, value: id.toInt64()) + self.sharedPeerHistoryInitializedKey.setInt8(8, value: MetadataPrefix.PeerHistoryInitialized.rawValue) + return self.sharedPeerHistoryInitializedKey + } + + private func groupFeedIndexInitializedKey(_ id: PeerGroupId) -> ValueBoxKey { + self.sharedGroupFeedIndexInitializedKey.setInt32(0, value: id.rawValue) + self.sharedGroupFeedIndexInitializedKey.setInt8(4, value: MetadataPrefix.GroupFeedIndexInitialized.rawValue) + return self.sharedGroupFeedIndexInitializedKey + } + + private func chatListGroupInitializedKey(_ key: InitializedChatListKey) -> ValueBoxKey { + self.sharedChatListGroupHistoryInitializedKey.setInt32(0, value: key.groupId.rawValue) + self.sharedChatListGroupHistoryInitializedKey.setInt8(4, value: MetadataPrefix.ChatListGroupInitialized.rawValue) + return self.sharedChatListGroupHistoryInitializedKey + } + + private func peerNextMessageIdByNamespaceKey(_ id: PeerId, namespace: MessageId.Namespace) -> ValueBoxKey { + self.sharedPeerNextMessageIdByNamespaceKey.setInt64(0, value: id.toInt64()) + self.sharedPeerNextMessageIdByNamespaceKey.setInt8(8, value: MetadataPrefix.PeerNextMessageIdByNamespace.rawValue) + self.sharedPeerNextMessageIdByNamespaceKey.setInt32(8 + 1, value: namespace) + + return self.sharedPeerNextMessageIdByNamespaceKey + } + + private func key(_ prefix: MetadataPrefix) -> ValueBoxKey { + let key = ValueBoxKey(length: 1) + key.setInt8(0, value: prefix.rawValue) + return key + } + + public func setInitializedChatList(groupId: PeerGroupId) { + switch groupId { + case .root: + self.valueBox.set(self.table, key: self.key(MetadataPrefix.ChatListInitialized), value: MemoryBuffer()) + case .group: + self.valueBox.set(self.table, key: self.chatListGroupInitializedKey(InitializedChatListKey(groupId: groupId)), value: MemoryBuffer()) + } + self.initializedChatList.insert(InitializedChatListKey(groupId: groupId)) + } + + public func isInitializedChatList(groupId: PeerGroupId) -> Bool { + let key = InitializedChatListKey(groupId: groupId) + if self.initializedChatList.contains(key) { + return true + } else { + switch groupId { + case .root: + if self.valueBox.exists(self.table, key: self.key(MetadataPrefix.ChatListInitialized)) { + self.initializedChatList.insert(key) + return true + } else { + return false + } + case .group: + if self.valueBox.exists(self.table, key: self.chatListGroupInitializedKey(key)) { + self.initializedChatList.insert(key) + return true + } else { + return false + } + } + } + } + + public func setShouldReindexUnreadCounts(value: Bool) { + if value { + self.valueBox.set(self.table, key: self.key(MetadataPrefix.ShouldReindexUnreadCounts), value: MemoryBuffer()) + } else { + self.valueBox.remove(self.table, key: self.key(MetadataPrefix.ShouldReindexUnreadCounts), secure: false) + } + } + + public func shouldReindexUnreadCounts() -> Bool { + if self.valueBox.exists(self.table, key: self.key(MetadataPrefix.ShouldReindexUnreadCounts)) { + return true + } else { + return false + } + } + + public func setInitialized(_ peerId: PeerId) { + self.initializedHistoryPeerIds.insert(peerId) + self.sharedBuffer.reset() + self.valueBox.set(self.table, key: self.peerHistoryInitializedKey(peerId), value: self.sharedBuffer) + } + + public func isInitialized(_ peerId: PeerId) -> Bool { + if self.initializedHistoryPeerIds.contains(peerId) { + return true + } else { + if self.valueBox.exists(self.table, key: self.peerHistoryInitializedKey(peerId)) { + self.initializedHistoryPeerIds.insert(peerId) + return true + } else { + return false + } + } + } + + public func setGroupFeedIndexInitialized(_ groupId: PeerGroupId) { + self.initializedGroupFeedIndexIds.insert(groupId) + self.sharedBuffer.reset() + self.valueBox.set(self.table, key: self.groupFeedIndexInitializedKey(groupId), value: self.sharedBuffer) + } + + public func isGroupFeedIndexInitialized(_ groupId: PeerGroupId) -> Bool { + if self.initializedGroupFeedIndexIds.contains(groupId) { + return true + } else { + if self.valueBox.exists(self.table, key: self.groupFeedIndexInitializedKey(groupId)) { + self.initializedGroupFeedIndexIds.insert(groupId) + return true + } else { + return false + } + } + } + + public func getNextMessageIdAndIncrement(_ peerId: PeerId, namespace: MessageId.Namespace) -> MessageId { + if let messageIdByNamespace = self.peerNextMessageIdByNamespace[peerId] { + if let nextId = messageIdByNamespace[namespace] { + self.peerNextMessageIdByNamespace[peerId]![namespace] = nextId + 1 + if updatedPeerNextMessageIdByNamespace[peerId] != nil { + updatedPeerNextMessageIdByNamespace[peerId]!.insert(namespace) + } else { + updatedPeerNextMessageIdByNamespace[peerId] = Set([namespace]) + } + return MessageId(peerId: peerId, namespace: namespace, id: nextId) + } else { + var nextId: Int32 = 1 + if let value = self.valueBox.get(self.table, key: self.peerNextMessageIdByNamespaceKey(peerId, namespace: namespace)) { + value.read(&nextId, offset: 0, length: 4) + } + self.peerNextMessageIdByNamespace[peerId]![namespace] = nextId + 1 + if updatedPeerNextMessageIdByNamespace[peerId] != nil { + updatedPeerNextMessageIdByNamespace[peerId]!.insert(namespace) + } else { + updatedPeerNextMessageIdByNamespace[peerId] = Set([namespace]) + } + return MessageId(peerId: peerId, namespace: namespace, id: nextId) + } + } else { + var nextId: Int32 = 1 + if let value = self.valueBox.get(self.table, key: self.peerNextMessageIdByNamespaceKey(peerId, namespace: namespace)) { + value.read(&nextId, offset: 0, length: 4) + } + + self.peerNextMessageIdByNamespace[peerId] = [namespace: nextId + 1] + if updatedPeerNextMessageIdByNamespace[peerId] != nil { + updatedPeerNextMessageIdByNamespace[peerId]!.insert(namespace) + } else { + updatedPeerNextMessageIdByNamespace[peerId] = Set([namespace]) + } + return MessageId(peerId: peerId, namespace: namespace, id: nextId) + } + } + + public func getNextStableMessageIndexId() -> UInt32 { + if let nextId = self.nextMessageStableId { + self.nextMessageStableId = nextId + 1 + self.nextMessageStableIdUpdated = true + return nextId + } else { + if let value = self.valueBox.get(self.table, key: self.key(.NextStableMessageId)) { + var nextId: UInt32 = 0 + value.read(&nextId, offset: 0, length: 4) + self.nextMessageStableId = nextId + 1 + self.nextMessageStableIdUpdated = true + return nextId + } else { + let nextId: UInt32 = 1 + self.nextMessageStableId = nextId + 1 + self.nextMessageStableIdUpdated = true + return nextId + } + } + } + + public func getNextPeerOperationLogIndex() -> UInt32 { + if let nextId = self.nextPeerOperationLogIndex { + self.nextPeerOperationLogIndex = nextId + 1 + self.nextPeerOperationLogIndexUpdated = true + return nextId + } else { + if let value = self.valueBox.get(self.table, key: self.key(.NextPeerOperationLogIndex)) { + var nextId: UInt32 = 0 + value.read(&nextId, offset: 0, length: 4) + self.nextPeerOperationLogIndex = nextId + 1 + self.nextPeerOperationLogIndexUpdated = true + return nextId + } else { + let nextId: UInt32 = 1 + self.nextPeerOperationLogIndex = nextId + 1 + self.nextPeerOperationLogIndexUpdated = true + return nextId + } + } + } + + public func getChatListTotalUnreadState() -> ChatListTotalUnreadState { + if let cached = self.chatListTotalUnreadState { + return cached + } else { + if let value = self.valueBox.get(self.table, key: self.key(.ChatListTotalUnreadState)), let state = PostboxDecoder(buffer: value).decodeObjectForKey("_", decoder: { + ChatListTotalUnreadState(decoder: $0) + }) as? ChatListTotalUnreadState { + self.chatListTotalUnreadState = state + return state + } else { + let state = ChatListTotalUnreadState(absoluteCounters: [:], filteredCounters: [:]) + self.chatListTotalUnreadState = state + return state + } + } + } + + public func setChatListTotalUnreadState(_ state: ChatListTotalUnreadState) { + let current = self.getChatListTotalUnreadState() + if current != state { + self.chatListTotalUnreadState = state + self.chatListTotalUnreadStateUpdated = true + } + } + + override public func clearMemoryCache() { + self.initializedChatList.removeAll() + self.initializedHistoryPeerIds.removeAll() + self.peerNextMessageIdByNamespace.removeAll() + self.updatedPeerNextMessageIdByNamespace.removeAll() + self.nextMessageStableId = nil + self.nextMessageStableIdUpdated = false + self.chatListTotalUnreadState = nil + self.chatListTotalUnreadStateUpdated = false + } + + override public func beforeCommit() { + let sharedBuffer = WriteBuffer() + for (peerId, namespaces) in self.updatedPeerNextMessageIdByNamespace { + for namespace in namespaces { + if let messageIdByNamespace = self.peerNextMessageIdByNamespace[peerId], let maxId = messageIdByNamespace[namespace] { + sharedBuffer.reset() + var mutableMaxId = maxId + sharedBuffer.write(&mutableMaxId, offset: 0, length: 4) + self.valueBox.set(self.table, key: self.peerNextMessageIdByNamespaceKey(peerId, namespace: namespace), value: sharedBuffer) + } else { + self.valueBox.remove(self.table, key: self.peerNextMessageIdByNamespaceKey(peerId, namespace: namespace), secure: false) + } + } + } + self.updatedPeerNextMessageIdByNamespace.removeAll() + + if self.nextMessageStableIdUpdated { + if let nextMessageStableId = self.nextMessageStableId { + var nextId: UInt32 = nextMessageStableId + self.valueBox.set(self.table, key: self.key(.NextStableMessageId), value: MemoryBuffer(memory: &nextId, capacity: 4, length: 4, freeWhenDone: false)) + self.nextMessageStableIdUpdated = false + } + } + + if self.nextPeerOperationLogIndexUpdated { + if let nextPeerOperationLogIndex = self.nextPeerOperationLogIndex { + var nextId: UInt32 = nextPeerOperationLogIndex + self.valueBox.set(self.table, key: self.key(.NextPeerOperationLogIndex), value: MemoryBuffer(memory: &nextId, capacity: 4, length: 4, freeWhenDone: false)) + self.nextPeerOperationLogIndexUpdated = false + } + } + + if self.chatListTotalUnreadStateUpdated { + if let state = self.chatListTotalUnreadState { + let buffer = PostboxEncoder() + buffer.encodeObject(state, forKey: "_") + self.valueBox.set(self.table, key: self.key(.ChatListTotalUnreadState), value: buffer.readBufferNoCopy()) + } + self.chatListTotalUnreadStateUpdated = false + } + } +} diff --git a/submodules/Database/MessageHistoryReadStateTable/BUCK b/submodules/Database/MessageHistoryReadStateTable/BUCK new file mode 100644 index 0000000000..39fda20a59 --- /dev/null +++ b/submodules/Database/MessageHistoryReadStateTable/BUCK @@ -0,0 +1,16 @@ +load("//Config:buck_rule_macros.bzl", "static_library") + +static_library( + name = "MessageHistoryReadStateTable", + srcs = glob([ + "Sources/**/*.swift", + ]), + deps = [ + "//submodules/Database/ValueBox:ValueBox", + "//submodules/Database/Table:Table", + "//submodules/Database/PostboxDataTypes:PostboxDataTypes", + ], + frameworks = [ + "$SDKROOT/System/Library/Frameworks/Foundation.framework", + ], +) diff --git a/submodules/Database/MessageHistoryReadStateTable/Sources/MessageHistoryReadStateTable.swift b/submodules/Database/MessageHistoryReadStateTable/Sources/MessageHistoryReadStateTable.swift new file mode 100644 index 0000000000..c50fcea739 --- /dev/null +++ b/submodules/Database/MessageHistoryReadStateTable/Sources/MessageHistoryReadStateTable.swift @@ -0,0 +1,576 @@ +import Foundation +import PostboxDataTypes +import Table +import ValueBox +import Buffers + +private let traceReadStates = false + +public enum ApplyInteractiveMaxReadIdResult { + case None + case Push(thenSync: Bool) +} + +private final class InternalPeerReadStates { + var namespaces: [MessageId.Namespace: PeerReadState] + + init(namespaces: [MessageId.Namespace: PeerReadState]) { + self.namespaces = namespaces + } +} + +public final class MessageHistoryReadStateTable: Table { + public static func tableSpec(_ id: Int32) -> ValueBoxTable { + return ValueBoxTable(id: id, keyType: .int64, compactValuesOnCreation: false) + } + + private let defaultMessageNamespaceReadStates: [MessageId.Namespace: PeerReadState] + + private var cachedPeerReadStates: [PeerId: InternalPeerReadStates?] = [:] + private var updatedInitialPeerReadStates: [PeerId: [MessageId.Namespace: PeerReadState]] = [:] + + private let sharedKey = ValueBoxKey(length: 8) + + private func key(_ id: PeerId) -> ValueBoxKey { + self.sharedKey.setInt64(0, value: id.toInt64()) + return self.sharedKey + } + + public init(valueBox: ValueBox, table: ValueBoxTable, defaultMessageNamespaceReadStates: [MessageId.Namespace: PeerReadState]) { + self.defaultMessageNamespaceReadStates = defaultMessageNamespaceReadStates + + super.init(valueBox: valueBox, table: table) + } + + private func get(_ id: PeerId) -> InternalPeerReadStates? { + if let states = self.cachedPeerReadStates[id] { + return states + } else { + if let value = self.valueBox.get(self.table, key: self.key(id)) { + var count: Int32 = 0 + value.read(&count, offset: 0, length: 4) + var stateByNamespace: [MessageId.Namespace: PeerReadState] = [:] + for _ in 0 ..< count { + var namespaceId: Int32 = 0 + value.read(&namespaceId, offset: 0, length: 4) + + let state: PeerReadState + var kind: Int8 = 0 + value.read(&kind, offset: 0, length: 1) + if kind == 0 { + var maxIncomingReadId: Int32 = 0 + var maxOutgoingReadId: Int32 = 0 + var maxKnownId: Int32 = 0 + var count: Int32 = 0 + + value.read(&maxIncomingReadId, offset: 0, length: 4) + value.read(&maxOutgoingReadId, offset: 0, length: 4) + value.read(&maxKnownId, offset: 0, length: 4) + value.read(&count, offset: 0, length: 4) + + var flags: Int32 = 0 + value.read(&flags, offset: 0, length: 4) + let markedUnread = (flags & (1 << 0)) != 0 + + state = .idBased(maxIncomingReadId: maxIncomingReadId, maxOutgoingReadId: maxOutgoingReadId, maxKnownId: maxKnownId, count: count, markedUnread: markedUnread) + } else { + var maxIncomingReadTimestamp: Int32 = 0 + var maxIncomingReadIdPeerId: Int64 = 0 + var maxIncomingReadIdNamespace: Int32 = 0 + var maxIncomingReadIdId: Int32 = 0 + + var maxOutgoingReadTimestamp: Int32 = 0 + var maxOutgoingReadIdPeerId: Int64 = 0 + var maxOutgoingReadIdNamespace: Int32 = 0 + var maxOutgoingReadIdId: Int32 = 0 + + var count: Int32 = 0 + + value.read(&maxIncomingReadTimestamp, offset: 0, length: 4) + value.read(&maxIncomingReadIdPeerId, offset: 0, length: 8) + value.read(&maxIncomingReadIdNamespace, offset: 0, length: 4) + value.read(&maxIncomingReadIdId, offset: 0, length: 4) + + value.read(&maxOutgoingReadTimestamp, offset: 0, length: 4) + value.read(&maxOutgoingReadIdPeerId, offset: 0, length: 8) + value.read(&maxOutgoingReadIdNamespace, offset: 0, length: 4) + value.read(&maxOutgoingReadIdId, offset: 0, length: 4) + + value.read(&count, offset: 0, length: 4) + + var flags: Int32 = 0 + value.read(&flags, offset: 0, length: 4) + let markedUnread = (flags & (1 << 0)) != 0 + + state = .indexBased(maxIncomingReadIndex: MessageIndex(id: MessageId(peerId: PeerId(maxIncomingReadIdPeerId), namespace: maxIncomingReadIdNamespace, id: maxIncomingReadIdId), timestamp: maxIncomingReadTimestamp), maxOutgoingReadIndex: MessageIndex(id: MessageId(peerId: PeerId(maxOutgoingReadIdPeerId), namespace: maxOutgoingReadIdNamespace, id: maxOutgoingReadIdId), timestamp: maxOutgoingReadTimestamp), count: count, markedUnread: markedUnread) + } + stateByNamespace[namespaceId] = state + } + let states = InternalPeerReadStates(namespaces: stateByNamespace) + self.cachedPeerReadStates[id] = states + return states + } else { + self.cachedPeerReadStates[id] = nil + return nil + } + } + } + + public func getCombinedState(_ peerId: PeerId) -> CombinedPeerReadState? { + if let states = self.get(peerId) { + return CombinedPeerReadState(states: states.namespaces.map({$0})) + } + return nil + } + + private func markReadStatesAsUpdated(_ peerId: PeerId, namespaces: [MessageId.Namespace: PeerReadState]) { + if self.updatedInitialPeerReadStates[peerId] == nil { + self.updatedInitialPeerReadStates[peerId] = namespaces + } + } + + public func resetStates(_ peerId: PeerId, namespaces: [MessageId.Namespace: PeerReadState]) -> CombinedPeerReadState? { + if traceReadStates { + print("[ReadStateTable] resetStates peerId: \(peerId), namespaces: \(namespaces)") + } + + if let states = self.get(peerId) { + var updated = false + for (namespace, state) in namespaces { + if states.namespaces[namespace] == nil || states.namespaces[namespace]! != state { + self.markReadStatesAsUpdated(peerId, namespaces: states.namespaces) + updated = true + } + states.namespaces[namespace] = state + } + if updated { + return CombinedPeerReadState(states: states.namespaces.map({$0})) + } else { + return nil + } + } else { + self.markReadStatesAsUpdated(peerId, namespaces: [:]) + let states = InternalPeerReadStates(namespaces: namespaces) + self.cachedPeerReadStates[peerId] = states + return CombinedPeerReadState(states: states.namespaces.map({$0})) + } + } + + + public func addIncomingMessages(_ peerId: PeerId, indices: Set) -> (CombinedPeerReadState?, Bool) { + var indicesByNamespace: [MessageId.Namespace: [MessageIndex]] = [:] + for index in indices { + if indicesByNamespace[index.id.namespace] != nil { + indicesByNamespace[index.id.namespace]!.append(index) + } else { + indicesByNamespace[index.id.namespace] = [index] + } + } + + if let states = self.get(peerId) { + if traceReadStates { + print("[ReadStateTable] addIncomingMessages peerId: \(peerId), indices: \(indices) (before: \(states.namespaces))") + } + + var updated = false + let invalidated = false + for (namespace, namespaceIndices) in indicesByNamespace { + let currentState = states.namespaces[namespace] ?? self.defaultMessageNamespaceReadStates[namespace] + + if let currentState = currentState { + var addedUnreadCount: Int32 = 0 + for index in namespaceIndices { + switch currentState { + case let .idBased(maxIncomingReadId, _, maxKnownId, _, _): + if index.id.id > maxKnownId && index.id.id > maxIncomingReadId { + addedUnreadCount += 1 + } + case let .indexBased(maxIncomingReadIndex, _, _, _): + if index > maxIncomingReadIndex { + addedUnreadCount += 1 + } + } + } + + if addedUnreadCount != 0 { + self.markReadStatesAsUpdated(peerId, namespaces: states.namespaces) + + states.namespaces[namespace] = currentState.withAddedCount(addedUnreadCount) + updated = true + + if traceReadStates { + print("[ReadStateTable] added \(addedUnreadCount)") + } + } + } + } + + return (updated ? CombinedPeerReadState(states: states.namespaces.map({$0})) : nil, invalidated) + } else { + if traceReadStates { + print("[ReadStateTable] addIncomingMessages peerId: \(peerId), just invalidated)") + } + return (nil, true) + } + } + + public func deleteMessages(_ peerId: PeerId, indices: [MessageIndex], incomingStatsInIndices: (PeerId, MessageId.Namespace, [MessageIndex]) -> (Int, Bool)) -> (CombinedPeerReadState?, Bool) { + var indicesByNamespace: [MessageId.Namespace: [MessageIndex]] = [:] + for index in indices { + if indicesByNamespace[index.id.namespace] != nil { + indicesByNamespace[index.id.namespace]!.append(index) + } else { + indicesByNamespace[index.id.namespace] = [index] + } + } + + if let states = self.get(peerId) { + if traceReadStates { + print("[ReadStateTable] deleteMessages peerId: \(peerId), ids: \(indices) (before: \(states.namespaces))") + } + + var updated = false + var invalidate = false + for (namespace, namespaceIndices) in indicesByNamespace { + if let currentState = states.namespaces[namespace] { + var unreadIndices: [MessageIndex] = [] + for index in namespaceIndices { + if !currentState.isIncomingMessageIndexRead(index) { + unreadIndices.append(index) + } + } + + let (knownCount, holes) = incomingStatsInIndices(peerId, namespace, unreadIndices) + if holes { + invalidate = true + } + + self.markReadStatesAsUpdated(peerId, namespaces: states.namespaces) + + var updatedState = currentState.withAddedCount(Int32(-knownCount)) + if updatedState.count < 0 { + invalidate = true + updatedState = currentState.withAddedCount(-updatedState.count) + } + + states.namespaces[namespace] = updatedState + updated = true + } else { + invalidate = true + } + } + + return (updated ? CombinedPeerReadState(states: states.namespaces.map({$0})) : nil, invalidate) + } else { + return (nil, true) + } + } + + public func applyIncomingMaxReadId(_ messageId: MessageId, incomingStatsInRange: (MessageId.Namespace, MessageId.Id, MessageId.Id) -> (count: Int, holes: Bool), topMessageId: (MessageId.Id, Bool)?) -> (CombinedPeerReadState?, Bool) { + if let states = self.get(messageId.peerId), let state = states.namespaces[messageId.namespace] { + if traceReadStates { + print("[ReadStateTable] applyMaxReadId peerId: \(messageId.peerId), maxReadId: \(messageId) (before: \(states.namespaces))") + } + + switch state { + case let .idBased(maxIncomingReadId, maxOutgoingReadId, maxKnownId, count, markedUnread): + if maxIncomingReadId < messageId.id || (topMessageId != nil && (messageId.id == topMessageId!.0 || topMessageId!.1) && state.count != 0) || markedUnread { + var (deltaCount, holes) = incomingStatsInRange(messageId.namespace, maxIncomingReadId + 1, messageId.id) + + if traceReadStates { + print("[ReadStateTable] applyMaxReadId after deltaCount: \(deltaCount), holes: \(holes)") + } + + if let topMessageId = topMessageId, (messageId.id == topMessageId.0 || topMessageId.1) { + if deltaCount != Int(state.count) { + deltaCount = Int(state.count) + holes = true + } + } + + self.markReadStatesAsUpdated(messageId.peerId, namespaces: states.namespaces) + + states.namespaces[messageId.namespace] = .idBased(maxIncomingReadId: messageId.id, maxOutgoingReadId: maxOutgoingReadId, maxKnownId: maxKnownId, count: max(0, count - Int32(deltaCount)), markedUnread: false) + return (CombinedPeerReadState(states: states.namespaces.map({$0})), holes) + } + case .indexBased: + assertionFailure() + break + } + } else { + return (nil, true) + } + + return (nil, false) + } + + public func applyIncomingMaxReadIndex(_ messageIndex: MessageIndex, topMessageIndex: MessageIndex?, incomingStatsInRange: (MessageIndex, MessageIndex) -> (count: Int, holes: Bool, readMesageIds: [MessageId])) -> (CombinedPeerReadState?, Bool, [MessageId]) { + if let states = self.get(messageIndex.id.peerId), let state = states.namespaces[messageIndex.id.namespace] { + if traceReadStates { + print("[ReadStateTable] applyIncomingMaxReadIndex peerId: \(messageIndex.id.peerId), maxReadIndex: \(messageIndex) (before: \(states.namespaces))") + } + + switch state { + case .idBased: + assertionFailure() + case let .indexBased(maxIncomingReadIndex, maxOutgoingReadIndex, count, markedUnread): + var readPastTopIndex = false + if let topMessageIndex = topMessageIndex, messageIndex >= topMessageIndex && count != 0 { + readPastTopIndex = true + } + if maxIncomingReadIndex < messageIndex || markedUnread || readPastTopIndex { + let (realDeltaCount, holes, messageIds) = incomingStatsInRange(maxIncomingReadIndex.successor(), messageIndex) + var deltaCount = realDeltaCount + if readPastTopIndex { + deltaCount = max(Int(count), deltaCount) + } + + if traceReadStates { + print("[ReadStateTable] applyIncomingMaxReadIndex after deltaCount: \(deltaCount), holes: \(holes)") + } + + self.markReadStatesAsUpdated(messageIndex.id.peerId, namespaces: states.namespaces) + + states.namespaces[messageIndex.id.namespace] = .indexBased(maxIncomingReadIndex: messageIndex, maxOutgoingReadIndex: maxOutgoingReadIndex, count: max(0, count - Int32(deltaCount)), markedUnread: false) + return (CombinedPeerReadState(states: states.namespaces.map({$0})), holes, messageIds) + } + } + } else { + return (nil, true, []) + } + + return (nil, false, []) + } + + public func applyOutgoingMaxReadId(_ messageId: MessageId) -> (CombinedPeerReadState?, Bool) { + if let states = self.get(messageId.peerId), let state = states.namespaces[messageId.namespace] { + switch state { + case let .idBased(maxIncomingReadId, maxOutgoingReadId, maxKnownId, count, markedUnread): + if maxOutgoingReadId < messageId.id { + self.markReadStatesAsUpdated(messageId.peerId, namespaces: states.namespaces) + states.namespaces[messageId.namespace] = .idBased(maxIncomingReadId: maxIncomingReadId, maxOutgoingReadId: messageId.id, maxKnownId: maxKnownId, count: count, markedUnread: markedUnread) + return (CombinedPeerReadState(states: states.namespaces.map({$0})), false) + } + case .indexBased: + assertionFailure() + break + } + } else { + return (nil, true) + } + + return (nil, false) + } + + public func applyOutgoingMaxReadIndex(_ messageIndex: MessageIndex, outgoingIndexStatsInRange: (MessageIndex, MessageIndex) -> [MessageId]) -> (CombinedPeerReadState?, Bool, [MessageId]) { + if let states = self.get(messageIndex.id.peerId), let state = states.namespaces[messageIndex.id.namespace] { + switch state { + case .idBased: + assertionFailure() + break + case let .indexBased(maxIncomingReadIndex, maxOutgoingReadIndex, count, markedUnread): + if maxOutgoingReadIndex < messageIndex { + let messageIds: [MessageId] = outgoingIndexStatsInRange(maxOutgoingReadIndex.successor(), messageIndex) + + self.markReadStatesAsUpdated(messageIndex.id.peerId, namespaces: states.namespaces) + states.namespaces[messageIndex.id.namespace] = .indexBased(maxIncomingReadIndex: maxIncomingReadIndex, maxOutgoingReadIndex: messageIndex, count: count, markedUnread: markedUnread) + return (CombinedPeerReadState(states: states.namespaces.map({$0})), false, messageIds) + } + } + } else { + return (nil, true, []) + } + + return (nil, false, []) + } + + public func applyInteractiveMaxReadIndex(messageIndex: MessageIndex, incomingStatsInRange: (MessageId.Namespace, MessageId.Id, MessageId.Id) -> (count: Int, holes: Bool), incomingIndexStatsInRange: (MessageIndex, MessageIndex) -> (count: Int, holes: Bool, readMesageIds: [MessageId]), topMessageId: (MessageId.Id, Bool)?, topMessageIndexByNamespace: (MessageId.Namespace) -> MessageIndex?) -> (combinedState: CombinedPeerReadState?, ApplyInteractiveMaxReadIdResult, readMesageIds: [MessageId]) { + if let states = self.get(messageIndex.id.peerId) { + if let state = states.namespaces[messageIndex.id.namespace] { + switch state { + case .idBased: + let (combinedState, holes) = self.applyIncomingMaxReadId(messageIndex.id, incomingStatsInRange: incomingStatsInRange, topMessageId: topMessageId) + + if let combinedState = combinedState { + return (combinedState, .Push(thenSync: holes), []) + } + + return (combinedState, holes ? .Push(thenSync: true) : .None, []) + case .indexBased: + let topMessageIndex: MessageIndex? = topMessageIndexByNamespace(messageIndex.id.namespace) + let (combinedState, holes, messageIds) = self.applyIncomingMaxReadIndex(messageIndex, topMessageIndex: topMessageIndex, incomingStatsInRange: incomingIndexStatsInRange) + + if let combinedState = combinedState { + return (combinedState, .Push(thenSync: holes), messageIds) + } + + return (combinedState, holes ? .Push(thenSync: true) : .None, messageIds) + } + } else { + for (namespace, state) in states.namespaces { + if let topIndex = topMessageIndexByNamespace(namespace), topIndex <= messageIndex { + switch state { + case .idBased: + let (combinedState, holes) = self.applyIncomingMaxReadId(topIndex.id, incomingStatsInRange: incomingStatsInRange, topMessageId: nil) + + if let combinedState = combinedState { + return (combinedState, .Push(thenSync: holes), []) + } + + return (combinedState, holes ? .Push(thenSync: true) : .None, []) + case .indexBased: + let (combinedState, holes, messageIds) = self.applyIncomingMaxReadIndex(topIndex, topMessageIndex: topMessageIndexByNamespace(namespace), incomingStatsInRange: incomingIndexStatsInRange) + + if let combinedState = combinedState { + return (combinedState, .Push(thenSync: holes), messageIds) + } + + return (combinedState, holes ? .Push(thenSync: true) : .None, messageIds) + } + } + } + return (nil, .Push(thenSync: true), []) + } + } else { + return (nil, .Push(thenSync: true), []) + } + } + + public func applyInteractiveMarkUnread(peerId: PeerId, namespace: MessageId.Namespace, value: Bool) -> CombinedPeerReadState? { + if let states = self.get(peerId), let state = states.namespaces[namespace] { + switch state { + case let .idBased(maxIncomingReadId, maxOutgoingReadId, maxKnownId, count, markedUnread): + if markedUnread != value { + self.markReadStatesAsUpdated(peerId, namespaces: states.namespaces) + + states.namespaces[namespace] = .idBased(maxIncomingReadId: maxIncomingReadId, maxOutgoingReadId: maxOutgoingReadId, maxKnownId: maxKnownId, count: count, markedUnread: value) + return CombinedPeerReadState(states: states.namespaces.map({$0})) + } else { + return nil + } + case let .indexBased(maxIncomingReadIndex, maxOutgoingReadIndex, count, markedUnread): + if markedUnread != value { + self.markReadStatesAsUpdated(peerId, namespaces: states.namespaces) + + states.namespaces[namespace] = .indexBased(maxIncomingReadIndex: maxIncomingReadIndex, maxOutgoingReadIndex: maxOutgoingReadIndex, count: count, markedUnread: value) + return CombinedPeerReadState(states: states.namespaces.map({$0})) + } else { + return nil + } + } + } else { + return nil + } + } + + public func transactionUnreadCountDeltas() -> [PeerId: Int32] { + var deltas: [PeerId: Int32] = [:] + for (id, initialNamespaces) in self.updatedInitialPeerReadStates { + var initialCount: Int32 = 0 + for (_, state) in initialNamespaces { + initialCount += state.count + } + + var updatedCount: Int32 = 0 + if let maybeStates = self.cachedPeerReadStates[id] { + if let states = maybeStates { + for (_, state) in states.namespaces { + updatedCount += state.count + } + } + } else { + assertionFailure() + } + + if initialCount != updatedCount { + deltas[id] = updatedCount - initialCount + } + } + return deltas + } + + public func transactionAlteredInitialPeerCombinedReadStates() -> [PeerId: CombinedPeerReadState] { + var result: [PeerId: CombinedPeerReadState] = [:] + for (peerId, namespacesAndStates) in self.updatedInitialPeerReadStates { + var states: [(MessageId.Namespace, PeerReadState)] = [] + for (namespace, state) in namespacesAndStates { + states.append((namespace, state)) + } + result[peerId] = CombinedPeerReadState(states: states) + } + return result + } + + override public func clearMemoryCache() { + self.cachedPeerReadStates.removeAll() + assert(self.updatedInitialPeerReadStates.isEmpty) + } + + override public func beforeCommit() { + if !self.updatedInitialPeerReadStates.isEmpty { + let sharedBuffer = WriteBuffer() + for (id, initialNamespaces) in self.updatedInitialPeerReadStates { + if let wrappedStates = self.cachedPeerReadStates[id], let states = wrappedStates { + sharedBuffer.reset() + var count: Int32 = Int32(states.namespaces.count) + sharedBuffer.write(&count, offset: 0, length: 4) + for (namespace, state) in states.namespaces { + var namespaceId: Int32 = namespace + sharedBuffer.write(&namespaceId, offset: 0, length: 4) + + switch state { + case .idBased(var maxIncomingReadId, var maxOutgoingReadId, var maxKnownId, var count, let markedUnread): + var kind: Int8 = 0 + sharedBuffer.write(&kind, offset: 0, length: 1) + + sharedBuffer.write(&maxIncomingReadId, offset: 0, length: 4) + sharedBuffer.write(&maxOutgoingReadId, offset: 0, length: 4) + sharedBuffer.write(&maxKnownId, offset: 0, length: 4) + sharedBuffer.write(&count, offset: 0, length: 4) + var flags: Int32 = 0 + if markedUnread { + flags |= (1 << 0) + } + sharedBuffer.write(&flags, offset: 0, length: 4) + case .indexBased(let maxIncomingReadIndex, let maxOutgoingReadIndex, var count, let markedUnread): + var kind: Int8 = 1 + sharedBuffer.write(&kind, offset: 0, length: 1) + + var maxIncomingReadTimestamp: Int32 = maxIncomingReadIndex.timestamp + var maxIncomingReadIdPeerId: Int64 = maxIncomingReadIndex.id.peerId.toInt64() + var maxIncomingReadIdNamespace: Int32 = maxIncomingReadIndex.id.namespace + var maxIncomingReadIdId: Int32 = maxIncomingReadIndex.id.id + + var maxOutgoingReadTimestamp: Int32 = maxOutgoingReadIndex.timestamp + var maxOutgoingReadIdPeerId: Int64 = maxOutgoingReadIndex.id.peerId.toInt64() + var maxOutgoingReadIdNamespace: Int32 = maxOutgoingReadIndex.id.namespace + var maxOutgoingReadIdId: Int32 = maxOutgoingReadIndex.id.id + + sharedBuffer.write(&maxIncomingReadTimestamp, offset: 0, length: 4) + sharedBuffer.write(&maxIncomingReadIdPeerId, offset: 0, length: 8) + sharedBuffer.write(&maxIncomingReadIdNamespace, offset: 0, length: 4) + sharedBuffer.write(&maxIncomingReadIdId, offset: 0, length: 4) + + sharedBuffer.write(&maxOutgoingReadTimestamp, offset: 0, length: 4) + sharedBuffer.write(&maxOutgoingReadIdPeerId, offset: 0, length: 8) + sharedBuffer.write(&maxOutgoingReadIdNamespace, offset: 0, length: 4) + sharedBuffer.write(&maxOutgoingReadIdId, offset: 0, length: 4) + + sharedBuffer.write(&count, offset: 0, length: 4) + + var flags: Int32 = 0 + if markedUnread { + flags |= 1 << 0 + } + sharedBuffer.write(&flags, offset: 0, length: 4) + } + } + self.valueBox.set(self.table, key: self.key(id), value: sharedBuffer) + } else { + self.valueBox.remove(self.table, key: self.key(id), secure: false) + } + } + self.updatedInitialPeerReadStates.removeAll() + } + } +} diff --git a/submodules/Database/MurmurHash/BUCK b/submodules/Database/MurmurHash/BUCK new file mode 100644 index 0000000000..645fb93634 --- /dev/null +++ b/submodules/Database/MurmurHash/BUCK @@ -0,0 +1,20 @@ +load("//Config:buck_rule_macros.bzl", "static_library") + +static_library( + name = "MurmurHash", + srcs = glob([ + "Sources/**/*.swift", + "Sources/**/*.m", + ]), + headers = glob([ + "Sources/**/*.h", + ]), + exported_headers = glob([ + "Sources/**/*.h", + ]), + deps = [ + ], + frameworks = [ + "$SDKROOT/System/Library/Frameworks/Foundation.framework", + ], +) diff --git a/submodules/Database/MurmurHash/Sources/MurMurHash32.h b/submodules/Database/MurmurHash/Sources/MurMurHash32.h new file mode 100644 index 0000000000..6b5ab89c3e --- /dev/null +++ b/submodules/Database/MurmurHash/Sources/MurMurHash32.h @@ -0,0 +1,12 @@ +#ifndef Postbox_MurMurHash32_h +#define Postbox_MurMurHash32_h + +#import +#import + +int32_t murMurHash32(void *bytes, int length); +int32_t murMurHash32Data(NSData *data); +int32_t murMurHashString32(const char *s); +NSString *postboxTransformedString(CFStringRef string, bool replaceWithTransliteratedVersion, bool appendTransliteratedVersion); + +#endif diff --git a/submodules/Database/MurmurHash/Sources/MurMurHash32.m b/submodules/Database/MurmurHash/Sources/MurMurHash32.m new file mode 100644 index 0000000000..87f695f017 --- /dev/null +++ b/submodules/Database/MurmurHash/Sources/MurMurHash32.m @@ -0,0 +1,120 @@ +#import "MurMurHash32.h" + +#include +#include + +#define FORCE_INLINE __attribute__((always_inline)) + +static inline uint32_t rotl32 ( uint32_t x, int8_t r ) +{ + return (x << r) | (x >> (32 - r)); +} + +#define ROTL32(x,y) rotl32(x,y) + +static FORCE_INLINE uint32_t getblock ( const uint32_t * p, int i ) +{ + return p[i]; +} + +static FORCE_INLINE uint32_t fmix ( uint32_t h ) +{ + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; +} + +static void murMurHash32Impl(const void *key, int len, uint32_t seed, void *out) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 4; + + uint32_t h1 = seed; + + const uint32_t c1 = 0xcc9e2d51; + const uint32_t c2 = 0x1b873593; + + //---------- + // body + + const uint32_t * blocks = (const uint32_t *)(data + nblocks*4); + + for(int i = -nblocks; i; i++) + { + uint32_t k1 = getblock(blocks,i); + + k1 *= c1; + k1 = ROTL32(k1,15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1,13); + h1 = h1*5+0xe6546b64; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*4); + + uint32_t k1 = 0; + + switch(len & 3) + { + case 3: k1 ^= tail[2] << 16; + case 2: k1 ^= tail[1] << 8; + case 1: k1 ^= tail[0]; + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; + + h1 = fmix(h1); + + *(uint32_t*)out = h1; +} + +int32_t murMurHash32(void *bytes, int length) +{ + int32_t result = 0; + murMurHash32Impl(bytes, length, -137723950, &result); + + return result; +} + +int32_t murMurHash32Data(NSData *data) { + return murMurHash32((void *)data.bytes, (int)data.length); +} + +int32_t murMurHashString32(const char *s) +{ + int32_t result = 0; + murMurHash32Impl(s, (int)strlen(s), -137723950, &result); + + return result; +} + +NSString *postboxTransformedString(CFStringRef string, bool replaceWithTransliteratedVersion, bool appendTransliteratedVersion) { + NSMutableString *mutableString = [[NSMutableString alloc] initWithString:(__bridge NSString * _Nonnull)(string)]; + CFStringTransform((CFMutableStringRef)mutableString, NULL, kCFStringTransformStripCombiningMarks, false); + + if (replaceWithTransliteratedVersion || appendTransliteratedVersion) { + NSMutableString *transliteratedString = [[NSMutableString alloc] initWithString:mutableString]; + CFStringTransform((CFMutableStringRef)transliteratedString, NULL, kCFStringTransformToLatin, false); + if (replaceWithTransliteratedVersion) { + return transliteratedString; + } else { + [mutableString appendString:@" "]; + [mutableString appendString:transliteratedString]; + } + } + + return mutableString; +} diff --git a/submodules/Database/MurmurHash/Sources/MurmurHash.swift b/submodules/Database/MurmurHash/Sources/MurmurHash.swift new file mode 100644 index 0000000000..a0f326a4eb --- /dev/null +++ b/submodules/Database/MurmurHash/Sources/MurmurHash.swift @@ -0,0 +1,11 @@ +import Foundation + +public enum HashFunctions { + public static func murMurHash32(_ s: String) -> Int32 { + return murMurHashString32(s) + } + + public static func murMurHash32(_ d: Data) -> Int32 { + return murMurHash32Data(d) + } +} diff --git a/submodules/Database/PostboxCoding/BUCK b/submodules/Database/PostboxCoding/BUCK new file mode 100644 index 0000000000..585d78b981 --- /dev/null +++ b/submodules/Database/PostboxCoding/BUCK @@ -0,0 +1,15 @@ +load("//Config:buck_rule_macros.bzl", "static_library") + +static_library( + name = "PostboxCoding", + srcs = glob([ + "Sources/**/*.swift", + ]), + deps = [ + "//submodules/Database/Buffers:Buffers", + "//submodules/Database/MurmurHash:MurmurHash", + ], + frameworks = [ + "$SDKROOT/System/Library/Frameworks/Foundation.framework", + ], +) diff --git a/submodules/Database/PostboxCoding/Sources/PostboxCoding.swift b/submodules/Database/PostboxCoding/Sources/PostboxCoding.swift new file mode 100644 index 0000000000..3566cfbf59 --- /dev/null +++ b/submodules/Database/PostboxCoding/Sources/PostboxCoding.swift @@ -0,0 +1,1297 @@ +import Foundation +import Buffers +import MurmurHash + +public protocol PostboxCoding { + init(decoder: PostboxDecoder) + func encode(_ encoder: PostboxEncoder) +} + +private final class EncodableTypeStore { + var dict: [Int32 : (PostboxDecoder) -> PostboxCoding] = [:] + + func decode(_ typeHash: Int32, decoder: PostboxDecoder) -> PostboxCoding? { + if let typeDecoder = self.dict[typeHash] { + return typeDecoder(decoder) + } else { + return nil + } + } +} + +private let _typeStore = EncodableTypeStore() +private let typeStore = { () -> EncodableTypeStore in + return _typeStore +}() + +public func declareEncodable(_ type: Any.Type, f: @escaping(PostboxDecoder) -> PostboxCoding) { + let string = "\(type)" + let hash = murMurHashString32(string) + if typeStore.dict[hash] != nil { + assertionFailure("Encodable type hash collision for \(type)") + } + typeStore.dict[murMurHashString32("\(type)")] = f +} + +public func declareEncodable(typeHash: Int32, _ f: @escaping(PostboxDecoder) -> PostboxCoding) { + if typeStore.dict[typeHash] != nil { + assertionFailure("Encodable type hash collision for \(typeHash)") + } + typeStore.dict[typeHash] = f +} + +public func persistentHash32(_ string: String) -> Int32 { + return murMurHashString32(string) +} + +private enum ValueType: Int8 { + case Int32 = 0 + case Int64 = 1 + case Bool = 2 + case Double = 3 + case String = 4 + case Object = 5 + case Int32Array = 6 + case Int64Array = 7 + case ObjectArray = 8 + case ObjectDictionary = 9 + case Bytes = 10 + case Nil = 11 + case StringArray = 12 + case BytesArray = 13 +} + +public final class PostboxEncoder { + private let buffer = WriteBuffer() + + public init() { + } + + public func memoryBuffer() -> MemoryBuffer { + return self.buffer + } + + public func makeReadBufferAndReset() -> ReadBuffer { + return self.buffer.makeReadBufferAndReset() + } + + public func readBufferNoCopy() -> ReadBuffer { + return self.buffer.readBufferNoCopy() + } + + public func makeData() -> Data { + return self.buffer.makeData() + } + + public func reset() { + self.buffer.reset() + } + + public func encodeKey(_ key: StaticString) { + var length: Int8 = Int8(key.utf8CodeUnitCount) + self.buffer.write(&length, offset: 0, length: 1) + self.buffer.write(key.utf8Start, offset: 0, length: Int(length)) + } + + public func encodeKey(_ key: String) { + let data = key.data(using: .utf8)! + data.withUnsafeBytes { (keyBytes: UnsafePointer) -> Void in + var length: Int8 = Int8(data.count) + self.buffer.write(&length, offset: 0, length: 1) + self.buffer.write(keyBytes, offset: 0, length: Int(length)) + } + } + + public func encodeNil(forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Nil.rawValue + self.buffer.write(&type, offset: 0, length: 1) + } + + public func encodeInt32(_ value: Int32, forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Int32.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var v = value + self.buffer.write(&v, offset: 0, length: 4) + } + + public func encodeInt32(_ value: Int32, forKey key: String) { + self.encodeKey(key) + var type: Int8 = ValueType.Int32.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var v = value + self.buffer.write(&v, offset: 0, length: 4) + } + + public func encodeInt64(_ value: Int64, forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Int64.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var v = value + self.buffer.write(&v, offset: 0, length: 8) + } + + public func encodeBool(_ value: Bool, forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Bool.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var v: Int8 = value ? 1 : 0 + self.buffer.write(&v, offset: 0, length: 1) + } + + public func encodeDouble(_ value: Double, forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Double.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var v = value + self.buffer.write(&v, offset: 0, length: 8) + } + + public func encodeString(_ value: String, forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.String.rawValue + self.buffer.write(&type, offset: 0, length: 1) + if let data = value.data(using: .utf8, allowLossyConversion: true) { + var length: Int32 = Int32(data.count) + self.buffer.write(&length, offset: 0, length: 4) + self.buffer.write(data) + } else { + var length: Int32 = 0 + self.buffer.write(&length, offset: 0, length: 4) + } + } + + public func encodeRootObject(_ value: PostboxCoding) { + self.encodeObject(value, forKey: "_") + } + + public func encodeObject(_ value: PostboxCoding, forKey key: StaticString) { + self.encodeKey(key) + var t: Int8 = ValueType.Object.rawValue + self.buffer.write(&t, offset: 0, length: 1) + + let string = "\(type(of: value))" + var typeHash: Int32 = murMurHashString32(string) + self.buffer.write(&typeHash, offset: 0, length: 4) + + let innerEncoder = PostboxEncoder() + value.encode(innerEncoder) + + var length: Int32 = Int32(innerEncoder.buffer.offset) + self.buffer.write(&length, offset: 0, length: 4) + self.buffer.write(innerEncoder.buffer.memory, offset: 0, length: Int(length)) + } + + public func encodeObjectWithEncoder(_ value: T, encoder: (PostboxEncoder) -> Void, forKey key: String) { + self.encodeKey(key) + var t: Int8 = ValueType.Object.rawValue + self.buffer.write(&t, offset: 0, length: 1) + + let string = "\(type(of: value))" + var typeHash: Int32 = murMurHashString32(string) + self.buffer.write(&typeHash, offset: 0, length: 4) + + let innerEncoder = PostboxEncoder() + encoder(innerEncoder) + + var length: Int32 = Int32(innerEncoder.buffer.offset) + self.buffer.write(&length, offset: 0, length: 4) + self.buffer.write(innerEncoder.buffer.memory, offset: 0, length: Int(length)) + } + + public func encodeInt32Array(_ value: [Int32], forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Int32Array.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var length: Int32 = Int32(value.count) + self.buffer.write(&length, offset: 0, length: 4) + value.withUnsafeBufferPointer { (data: UnsafeBufferPointer) -> Void in + self.buffer.write(UnsafeRawPointer(data.baseAddress!), offset: 0, length: Int(length) * 4) + return + } + } + + public func encodeInt64Array(_ value: [Int64], forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Int64Array.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var length: Int32 = Int32(value.count) + self.buffer.write(&length, offset: 0, length: 4) + value.withUnsafeBufferPointer { (data: UnsafeBufferPointer) -> Void in + self.buffer.write(UnsafeRawPointer(data.baseAddress!), offset: 0, length: Int(length) * 8) + return + } + } + + public func encodeObjectArray(_ value: [T], forKey key: StaticString) { + self.encodeKey(key) + var t: Int8 = ValueType.ObjectArray.rawValue + self.buffer.write(&t, offset: 0, length: 1) + var length: Int32 = Int32(value.count) + self.buffer.write(&length, offset: 0, length: 4) + let innerEncoder = PostboxEncoder() + for object in value { + var typeHash: Int32 = murMurHashString32("\(type(of: object))") + self.buffer.write(&typeHash, offset: 0, length: 4) + + innerEncoder.reset() + object.encode(innerEncoder) + + var length: Int32 = Int32(innerEncoder.buffer.offset) + self.buffer.write(&length, offset: 0, length: 4) + self.buffer.write(innerEncoder.buffer.memory, offset: 0, length: Int(length)) + } + } + + public func encodeObjectArrayWithEncoder(_ value: [T], forKey key: StaticString, encoder: (T, PostboxEncoder) -> Void) { + self.encodeKey(key) + var t: Int8 = ValueType.ObjectArray.rawValue + self.buffer.write(&t, offset: 0, length: 1) + var length: Int32 = Int32(value.count) + self.buffer.write(&length, offset: 0, length: 4) + let innerEncoder = PostboxEncoder() + for object in value { + var typeHash: Int32 = murMurHashString32("\(type(of: object))") + self.buffer.write(&typeHash, offset: 0, length: 4) + + innerEncoder.reset() + encoder(object, innerEncoder) + + var length: Int32 = Int32(innerEncoder.buffer.offset) + self.buffer.write(&length, offset: 0, length: 4) + self.buffer.write(innerEncoder.buffer.memory, offset: 0, length: Int(length)) + } + } + + public func encodeGenericObjectArray(_ value: [PostboxCoding], forKey key: StaticString) { + self.encodeKey(key) + var t: Int8 = ValueType.ObjectArray.rawValue + self.buffer.write(&t, offset: 0, length: 1) + var length: Int32 = Int32(value.count) + self.buffer.write(&length, offset: 0, length: 4) + let innerEncoder = PostboxEncoder() + for object in value { + var typeHash: Int32 = murMurHashString32("\(type(of: object))") + self.buffer.write(&typeHash, offset: 0, length: 4) + + innerEncoder.reset() + object.encode(innerEncoder) + + var length: Int32 = Int32(innerEncoder.buffer.offset) + self.buffer.write(&length, offset: 0, length: 4) + self.buffer.write(innerEncoder.buffer.memory, offset: 0, length: Int(length)) + } + } + + public func encodeStringArray(_ value: [String], forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.StringArray.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var length: Int32 = Int32(value.count) + self.buffer.write(&length, offset: 0, length: 4) + + for object in value { + let data = object.data(using: .utf8, allowLossyConversion: true) ?? (String("").data(using: .utf8)!) + var length: Int32 = Int32(data.count) + self.buffer.write(&length, offset: 0, length: 4) + self.buffer.write(data) + } + } + + public func encodeBytesArray(_ value: [MemoryBuffer], forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.BytesArray.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var length: Int32 = Int32(value.count) + self.buffer.write(&length, offset: 0, length: 4) + + for object in value { + var length: Int32 = Int32(object.length) + self.buffer.write(&length, offset: 0, length: 4) + self.buffer.write(object.memory, offset: 0, length: object.length) + } + } + + public func encodeObjectDictionary(_ value: [K : V], forKey key: StaticString) where K: PostboxCoding { + self.encodeKey(key) + var t: Int8 = ValueType.ObjectDictionary.rawValue + self.buffer.write(&t, offset: 0, length: 1) + var length: Int32 = Int32(value.count) + self.buffer.write(&length, offset: 0, length: 4) + + let innerEncoder = PostboxEncoder() + for record in value { + var keyTypeHash: Int32 = murMurHashString32("\(type(of: record.0))") + self.buffer.write(&keyTypeHash, offset: 0, length: 4) + innerEncoder.reset() + record.0.encode(innerEncoder) + var keyLength: Int32 = Int32(innerEncoder.buffer.offset) + self.buffer.write(&keyLength, offset: 0, length: 4) + self.buffer.write(innerEncoder.buffer.memory, offset: 0, length: Int(keyLength)) + + var valueTypeHash: Int32 = murMurHashString32("\(type(of: record.1))") + self.buffer.write(&valueTypeHash, offset: 0, length: 4) + innerEncoder.reset() + record.1.encode(innerEncoder) + var valueLength: Int32 = Int32(innerEncoder.buffer.offset) + self.buffer.write(&valueLength, offset: 0, length: 4) + self.buffer.write(innerEncoder.buffer.memory, offset: 0, length: Int(valueLength)) + } + } + + public func encodeObjectDictionary(_ value: [K : V], forKey key: StaticString, keyEncoder: (K, PostboxEncoder) -> Void) { + self.encodeKey(key) + var t: Int8 = ValueType.ObjectDictionary.rawValue + self.buffer.write(&t, offset: 0, length: 1) + var length: Int32 = Int32(value.count) + self.buffer.write(&length, offset: 0, length: 4) + + let innerEncoder = PostboxEncoder() + for record in value { + var keyTypeHash: Int32 = murMurHashString32("\(type(of: record.0))") + self.buffer.write(&keyTypeHash, offset: 0, length: 4) + innerEncoder.reset() + keyEncoder(record.0, innerEncoder) + var keyLength: Int32 = Int32(innerEncoder.buffer.offset) + self.buffer.write(&keyLength, offset: 0, length: 4) + self.buffer.write(innerEncoder.buffer.memory, offset: 0, length: Int(keyLength)) + + var valueTypeHash: Int32 = murMurHashString32("\(type(of: record.1))") + self.buffer.write(&valueTypeHash, offset: 0, length: 4) + innerEncoder.reset() + record.1.encode(innerEncoder) + var valueLength: Int32 = Int32(innerEncoder.buffer.offset) + self.buffer.write(&valueLength, offset: 0, length: 4) + self.buffer.write(innerEncoder.buffer.memory, offset: 0, length: Int(valueLength)) + } + } + + public func encodeBytes(_ bytes: WriteBuffer, forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Bytes.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var bytesLength: Int32 = Int32(bytes.offset) + self.buffer.write(&bytesLength, offset: 0, length: 4) + self.buffer.write(bytes.memory, offset: 0, length: bytes.offset) + } + + public func encodeBytes(_ bytes: ReadBuffer, forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Bytes.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var bytesLength: Int32 = Int32(bytes.offset) + self.buffer.write(&bytesLength, offset: 0, length: 4) + self.buffer.write(bytes.memory, offset: 0, length: bytes.offset) + } + + public func encodeBytes(_ bytes: MemoryBuffer, forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Bytes.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var bytesLength: Int32 = Int32(bytes.length) + self.buffer.write(&bytesLength, offset: 0, length: 4) + self.buffer.write(bytes.memory, offset: 0, length: bytes.length) + } + + public func encodeData(_ data: Data, forKey key: StaticString) { + self.encodeKey(key) + var type: Int8 = ValueType.Bytes.rawValue + self.buffer.write(&type, offset: 0, length: 1) + var bytesLength: Int32 = Int32(data.count) + self.buffer.write(&bytesLength, offset: 0, length: 4) + data.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + self.buffer.write(bytes, offset: 0, length: Int(bytesLength)) + } + } + + public let sharedWriteBuffer = WriteBuffer() +} + +public final class PostboxDecoder { + private let buffer: MemoryBuffer + private var offset: Int = 0 + + public init(buffer: MemoryBuffer) { + self.buffer = buffer + } + + private class func skipValue(_ bytes: UnsafePointer, offset: inout Int, length: Int, valueType: ValueType) { + switch valueType { + case .Int32: + offset += 4 + case .Int64: + offset += 8 + case .Bool: + offset += 1 + case .Double: + offset += 8 + case .String: + var length: Int32 = 0 + memcpy(&length, bytes + offset, 4) + offset += 4 + Int(length) + case .Object: + var length: Int32 = 0 + memcpy(&length, bytes + (offset + 4), 4) + offset += 8 + Int(length) + case .Int32Array: + var length: Int32 = 0 + memcpy(&length, bytes + offset, 4) + offset += 4 + Int(length) * 4 + case .Int64Array: + var length: Int32 = 0 + memcpy(&length, bytes + offset, 4) + offset += 4 + Int(length) * 8 + case .ObjectArray: + var length: Int32 = 0 + memcpy(&length, bytes + offset, 4) + offset += 4 + var i: Int32 = 0 + while i < length { + var objectLength: Int32 = 0 + memcpy(&objectLength, bytes + (offset + 4), 4) + offset += 8 + Int(objectLength) + i += 1 + } + case .ObjectDictionary: + var length: Int32 = 0 + memcpy(&length, bytes + offset, 4) + offset += 4 + var i: Int32 = 0 + while i < length { + var keyLength: Int32 = 0 + memcpy(&keyLength, bytes + (offset + 4), 4) + offset += 8 + Int(keyLength) + + var valueLength: Int32 = 0 + memcpy(&valueLength, bytes + (offset + 4), 4) + offset += 8 + Int(valueLength) + i += 1 + } + case .Bytes: + var length: Int32 = 0 + memcpy(&length, bytes + offset, 4) + offset += 4 + Int(length) + case .Nil: + break + case .StringArray, .BytesArray: + var length: Int32 = 0 + memcpy(&length, bytes + offset, 4) + offset += 4 + var i: Int32 = 0 + while i < length { + var stringLength: Int32 = 0 + memcpy(&stringLength, bytes + offset, 4) + offset += 4 + Int(stringLength) + i += 1 + } + } + } + + private class func positionOnKey(_ rawBytes: UnsafeRawPointer, offset: inout Int, maxOffset: Int, length: Int, key: StaticString, valueType: ValueType) -> Bool + { + let bytes = rawBytes.assumingMemoryBound(to: Int8.self) + + let startOffset = offset + + let keyLength: Int = key.utf8CodeUnitCount + while (offset < maxOffset) { + let readKeyLength = bytes[offset] + assert(readKeyLength >= 0) + offset += 1 + offset += Int(readKeyLength) + + let readValueType = bytes[offset] + offset += 1 + + if keyLength == Int(readKeyLength) && memcmp(bytes + (offset - Int(readKeyLength) - 1), key.utf8Start, keyLength) == 0 { + if readValueType == valueType.rawValue { + return true + } else if readValueType == ValueType.Nil.rawValue { + return false + } else { + skipValue(bytes, offset: &offset, length: length, valueType: ValueType(rawValue: readValueType)!) + } + } else { + skipValue(bytes, offset: &offset, length: length, valueType: ValueType(rawValue: readValueType)!) + } + } + + if (startOffset != 0) { + offset = 0 + return positionOnKey(bytes, offset: &offset, maxOffset: startOffset, length: length, key: key, valueType: valueType) + } + + return false + } + + private class func positionOnStringKey(_ rawBytes: UnsafeRawPointer, offset: inout Int, maxOffset: Int, length: Int, key: String, valueType: ValueType) -> Bool + { + let bytes = rawBytes.assumingMemoryBound(to: Int8.self) + + let startOffset = offset + + let keyData = key.data(using: .utf8)! + + return keyData.withUnsafeBytes { (keyBytes: UnsafePointer) -> Bool in + let keyLength: Int = keyData.count + while (offset < maxOffset) { + let readKeyLength = bytes[offset] + assert(readKeyLength >= 0) + offset += 1 + offset += Int(readKeyLength) + + let readValueType = bytes[offset] + offset += 1 + + if keyLength == Int(readKeyLength) && memcmp(bytes + (offset - Int(readKeyLength) - 1), keyBytes, keyLength) == 0 { + if readValueType == valueType.rawValue { + return true + } else if readValueType == ValueType.Nil.rawValue { + return false + } else { + skipValue(bytes, offset: &offset, length: length, valueType: ValueType(rawValue: readValueType)!) + } + } else { + skipValue(bytes, offset: &offset, length: length, valueType: ValueType(rawValue: readValueType)!) + } + } + + if (startOffset != 0) { + offset = 0 + return positionOnStringKey(bytes, offset: &offset, maxOffset: startOffset, length: length, key: key, valueType: valueType) + } + + return false + } + } + + private class func positionOnKey(_ bytes: UnsafePointer, offset: inout Int, maxOffset: Int, length: Int, key: Int16, valueType: ValueType) -> Bool + { + var keyValue = key + let startOffset = offset + + let keyLength: Int = 2 + while (offset < maxOffset) + { + let readKeyLength = bytes[offset] + offset += 1 + offset += Int(readKeyLength) + + let readValueType = bytes[offset] + offset += 1 + + if readValueType != valueType.rawValue || keyLength != Int(readKeyLength) || memcmp(bytes + (offset - Int(readKeyLength) - 1), &keyValue, keyLength) != 0 { + skipValue(bytes, offset: &offset, length: length, valueType: ValueType(rawValue: readValueType)!) + } else { + return true + } + } + + if (startOffset != 0) + { + offset = 0 + return positionOnKey(bytes, offset: &offset, maxOffset: startOffset, length: length, key: key, valueType: valueType) + } + + return false + } + + public func decodeInt32ForKey(_ key: StaticString, orElse: Int32) -> Int32 { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Int32) { + var value: Int32 = 0 + memcpy(&value, self.buffer.memory + self.offset, 4) + self.offset += 4 + return value + } else { + return orElse + } + } + + public func decodeInt32ForKey(_ key: String, orElse: Int32) -> Int32 { + if PostboxDecoder.positionOnStringKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Int32) { + var value: Int32 = 0 + memcpy(&value, self.buffer.memory + self.offset, 4) + self.offset += 4 + return value + } else { + return orElse + } + } + + public func decodeOptionalInt32ForKey(_ key: StaticString) -> Int32? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Int32) { + var value: Int32 = 0 + memcpy(&value, self.buffer.memory + self.offset, 4) + self.offset += 4 + return value + } else { + return nil + } + } + + public func decodeOptionalInt32ForKey(_ key: String) -> Int32? { + if PostboxDecoder.positionOnStringKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Int32) { + var value: Int32 = 0 + memcpy(&value, self.buffer.memory + self.offset, 4) + self.offset += 4 + return value + } else { + return nil + } + } + + public func decodeInt64ForKey(_ key: StaticString, orElse: Int64) -> Int64 { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Int64) { + var value: Int64 = 0 + memcpy(&value, self.buffer.memory + self.offset, 8) + self.offset += 8 + return value + } else { + return orElse + } + } + + public func decodeOptionalInt64ForKey(_ key: StaticString) -> Int64? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Int64) { + var value: Int64 = 0 + memcpy(&value, self.buffer.memory + self.offset, 8) + self.offset += 8 + return value + } else { + return nil + } + } + + public func decodeBoolForKey(_ key: StaticString, orElse: Bool) -> Bool { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Bool) { + var value: Int8 = 0 + memcpy(&value, self.buffer.memory + self.offset, 1) + self.offset += 1 + return value != 0 + } else { + return orElse + } + } + + public func decodeOptionalBoolForKey(_ key: StaticString) -> Bool? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Bool) { + var value: Int8 = 0 + memcpy(&value, self.buffer.memory + self.offset, 1) + self.offset += 1 + return value != 0 + } else { + return nil + } + } + + public func decodeDoubleForKey(_ key: StaticString, orElse: Double) -> Double { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Double) { + var value: Double = 0 + memcpy(&value, self.buffer.memory + self.offset, 8) + self.offset += 8 + return value + } else { + return orElse + } + } + + public func decodeOptionalDoubleForKey(_ key: StaticString) -> Double? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Double) { + var value: Double = 0 + memcpy(&value, self.buffer.memory + self.offset, 8) + self.offset += 8 + return value + } else { + return 0 + } + } + + public func decodeStringForKey(_ key: StaticString, orElse: String) -> String { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .String) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + let data = Data(bytes: self.buffer.memory.assumingMemoryBound(to: UInt8.self).advanced(by: self.offset + 4), count: Int(length)) + self.offset += 4 + Int(length) + return String(data: data, encoding: .utf8) ?? orElse + } else { + return orElse + } + } + + public func decodeOptionalStringForKey(_ key: StaticString) -> String? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .String) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + let data = Data(bytes: self.buffer.memory.assumingMemoryBound(to: UInt8.self).advanced(by: self.offset + 4), count: Int(length)) + self.offset += 4 + Int(length) + return String(data: data, encoding: .utf8) + } else { + return nil + } + } + + public func decodeRootObject() -> PostboxCoding? { + return self.decodeObjectForKey("_") + } + + public func decodeObjectForKey(_ key: StaticString) -> PostboxCoding? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Object) { + var typeHash: Int32 = 0 + memcpy(&typeHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + + let innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(length), freeWhenDone: false)) + self.offset += 4 + Int(length) + + return typeStore.decode(typeHash, decoder: innerDecoder) + } else { + return nil + } + } + + public func decodeObjectForKey(_ key: StaticString, decoder: (PostboxDecoder) -> PostboxCoding) -> PostboxCoding? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Object) { + var typeHash: Int32 = 0 + memcpy(&typeHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + + let innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(length), freeWhenDone: false)) + self.offset += 4 + Int(length) + + return decoder(innerDecoder) + } else { + return nil + } + } + + public func decodeAnyObjectForKey(_ key: StaticString, decoder: (PostboxDecoder) -> Any?) -> Any? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Object) { + var typeHash: Int32 = 0 + memcpy(&typeHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + + let innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(length), freeWhenDone: false)) + self.offset += 4 + Int(length) + + return decoder(innerDecoder) + } else { + return nil + } + } + + public func decodeObjectForKeyThrowing(_ key: StaticString, decoder: (PostboxDecoder) throws -> Any) throws -> Any? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Object) { + var typeHash: Int32 = 0 + memcpy(&typeHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + + let innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(length), freeWhenDone: false)) + self.offset += 4 + Int(length) + + return try decoder(innerDecoder) + } else { + return nil + } + } + + public func decodeInt32ArrayForKey(_ key: StaticString) -> [Int32] { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Int32Array) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + var array: [Int32] = [] + array.reserveCapacity(Int(length)) + var i: Int32 = 0 + while i < length { + var element: Int32 = 0 + memcpy(&element, self.buffer.memory + (self.offset + 4 + 4 * Int(i)), 4) + array.append(element) + i += 1 + } + self.offset += 4 + Int(length) * 4 + return array + } else { + return [] + } + } + + public func decodeInt64ArrayForKey(_ key: StaticString) -> [Int64] { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Int64Array) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + var array: [Int64] = [] + array.reserveCapacity(Int(length)) + var i: Int32 = 0 + while i < length { + var element: Int64 = 0 + memcpy(&element, self.buffer.memory + (self.offset + 4 + 8 * Int(i)), 8) + array.append(element) + i += 1 + } + self.offset += 4 + Int(length) * 8 + return array + } else { + return [] + } + } + + public func decodeObjectArrayWithDecoderForKey(_ key: StaticString) -> [T] where T: PostboxCoding { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .ObjectArray) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var array: [T] = [] + array.reserveCapacity(Int(length)) + + var i: Int32 = 0 + while i < length { + var typeHash: Int32 = 0 + memcpy(&typeHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var objectLength: Int32 = 0 + memcpy(&objectLength, self.buffer.memory + self.offset, 4) + + let innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(objectLength), freeWhenDone: false)) + self.offset += 4 + Int(objectLength) + + array.append(T(decoder: innerDecoder)) + + i += 1 + } + + return array + } else { + return [] + } + } + + public func decodeOptionalObjectArrayWithDecoderForKey(_ key: StaticString) -> [T]? where T: PostboxCoding { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .ObjectArray) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var array: [T] = [] + array.reserveCapacity(Int(length)) + + var i: Int32 = 0 + while i < length { + var typeHash: Int32 = 0 + memcpy(&typeHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var objectLength: Int32 = 0 + memcpy(&objectLength, self.buffer.memory + self.offset, 4) + + let innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(objectLength), freeWhenDone: false)) + self.offset += 4 + Int(objectLength) + + array.append(T(decoder: innerDecoder)) + + i += 1 + } + + return array + } else { + return nil + } + } + + public func decodeObjectArrayWithCustomDecoderForKey(_ key: StaticString, decoder: (PostboxDecoder) throws -> T) throws -> [T] { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .ObjectArray) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var array: [T] = [] + array.reserveCapacity(Int(length)) + + var i: Int32 = 0 + while i < length { + var typeHash: Int32 = 0 + memcpy(&typeHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var objectLength: Int32 = 0 + memcpy(&objectLength, self.buffer.memory + self.offset, 4) + + let innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(objectLength), freeWhenDone: false)) + self.offset += 4 + Int(objectLength) + + let value = try decoder(innerDecoder) + array.append(value) + + i += 1 + } + + return array + } else { + return [] + } + } + + public func decodeStringArrayForKey(_ key: StaticString) -> [String] { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .StringArray) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var array: [String] = [] + array.reserveCapacity(Int(length)) + + var i: Int32 = 0 + while i < length { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + let data = Data(bytes: self.buffer.memory.assumingMemoryBound(to: UInt8.self).advanced(by: self.offset + 4), count: Int(length)) + self.offset += 4 + Int(length) + if let string = String(data: data, encoding: .utf8) { + array.append(string) + } else { + assertionFailure() + array.append("") + } + + i += 1 + } + + return array + } else { + return [] + } + } + + public func decodeBytesArrayForKey(_ key: StaticString) -> [MemoryBuffer] { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .BytesArray) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var array: [MemoryBuffer] = [] + array.reserveCapacity(Int(length)) + + var i: Int32 = 0 + while i < length { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + let bytes = malloc(Int(length))! + memcpy(bytes, self.buffer.memory.advanced(by: self.offset + 4), Int(length)) + array.append(MemoryBuffer(memory: bytes, capacity: Int(length), length: Int(length), freeWhenDone: true)) + self.offset += 4 + Int(length) + + i += 1 + } + + return array + } else { + return [] + } + } + + public func decodeObjectArrayForKey(_ key: StaticString) -> [T] where T: PostboxCoding { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .ObjectArray) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var array: [T] = [] + array.reserveCapacity(Int(length)) + + var failed = false + var i: Int32 = 0 + while i < length { + var typeHash: Int32 = 0 + memcpy(&typeHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var objectLength: Int32 = 0 + memcpy(&objectLength, self.buffer.memory + self.offset, 4) + + let innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(objectLength), freeWhenDone: false)) + self.offset += 4 + Int(objectLength) + + if !failed { + if let object = typeStore.decode(typeHash, decoder: innerDecoder) as? T { + array.append(object) + } else { + failed = true + } + } + + i += 1 + } + + if failed { + return [] + } else { + return array + } + } else { + return [] + } + } + + public func decodeObjectArrayForKey(_ key: StaticString) -> [PostboxCoding] { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .ObjectArray) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var array: [PostboxCoding] = [] + array.reserveCapacity(Int(length)) + + var failed = false + var i: Int32 = 0 + while i < length { + var typeHash: Int32 = 0 + memcpy(&typeHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var objectLength: Int32 = 0 + memcpy(&objectLength, self.buffer.memory + self.offset, 4) + + let innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(objectLength), freeWhenDone: false)) + self.offset += 4 + Int(objectLength) + + if !failed { + if let object = typeStore.decode(typeHash, decoder: innerDecoder) { + array.append(object) + } else { + failed = true + } + } + + i += 1 + } + + if failed { + return [] + } else { + return array + } + } else { + return [] + } + } + + public func decodeObjectDictionaryForKey(_ key: StaticString) -> [K : V] where K: PostboxCoding, K: Hashable { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .ObjectDictionary) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var dictionary: [K : V] = [:] + + var failed = false + var i: Int32 = 0 + while i < length { + var keyHash: Int32 = 0 + memcpy(&keyHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var keyLength: Int32 = 0 + memcpy(&keyLength, self.buffer.memory + self.offset, 4) + + var innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(keyLength), freeWhenDone: false)) + self.offset += 4 + Int(keyLength) + + let key = failed ? nil : (typeStore.decode(keyHash, decoder: innerDecoder) as? K) + + var valueHash: Int32 = 0 + memcpy(&valueHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var valueLength: Int32 = 0 + memcpy(&valueLength, self.buffer.memory + self.offset, 4) + + innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(valueLength), freeWhenDone: false)) + self.offset += 4 + Int(valueLength) + + let value = failed ? nil : (typeStore.decode(valueHash, decoder: innerDecoder) as? V) + + if let key = key, let value = value { + dictionary[key] = value + } else { + failed = true + } + + i += 1 + } + + if failed { + return [:] + } else { + return dictionary + } + } else { + return [:] + } + } + + public func decodeObjectDictionaryForKey(_ key: StaticString, keyDecoder: (PostboxDecoder) -> K) -> [K : V] where K: Hashable { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .ObjectDictionary) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var dictionary: [K : V] = [:] + + var failed = false + var i: Int32 = 0 + while i < length { + var keyHash: Int32 = 0 + memcpy(&keyHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var keyLength: Int32 = 0 + memcpy(&keyLength, self.buffer.memory + self.offset, 4) + + var innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(keyLength), freeWhenDone: false)) + self.offset += 4 + Int(keyLength) + + var key: K? + if !failed { + key = keyDecoder(innerDecoder) + } + + var valueHash: Int32 = 0 + memcpy(&valueHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var valueLength: Int32 = 0 + memcpy(&valueLength, self.buffer.memory + self.offset, 4) + + innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(valueLength), freeWhenDone: false)) + self.offset += 4 + Int(valueLength) + + let value = failed ? nil : (typeStore.decode(valueHash, decoder: innerDecoder) as? V) + + if let key = key, let value = value { + dictionary[key] = value + } else { + failed = true + } + + i += 1 + } + + if failed { + return [:] + } else { + return dictionary + } + } else { + return [:] + } + } + + public func decodeObjectDictionaryForKey(_ key: StaticString, keyDecoder: (PostboxDecoder) -> K, valueDecoder: (PostboxDecoder) -> V) -> [K : V] where K: Hashable { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .ObjectDictionary) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var dictionary: [K : V] = [:] + + var failed = false + var i: Int32 = 0 + while i < length { + var keyHash: Int32 = 0 + memcpy(&keyHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var keyLength: Int32 = 0 + memcpy(&keyLength, self.buffer.memory + self.offset, 4) + + var innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(keyLength), freeWhenDone: false)) + self.offset += 4 + Int(keyLength) + + var key: K? + if !failed { + key = keyDecoder(innerDecoder) + } + + var valueHash: Int32 = 0 + memcpy(&valueHash, self.buffer.memory + self.offset, 4) + self.offset += 4 + + var valueLength: Int32 = 0 + memcpy(&valueLength, self.buffer.memory + self.offset, 4) + + innerDecoder = PostboxDecoder(buffer: ReadBuffer(memory: self.buffer.memory + (self.offset + 4), length: Int(valueLength), freeWhenDone: false)) + self.offset += 4 + Int(valueLength) + + let value = failed ? nil : (valueDecoder(innerDecoder) as V) + + if let key = key, let value = value { + dictionary[key] = value + } else { + failed = true + } + + i += 1 + } + + if failed { + return [:] + } else { + return dictionary + } + } else { + return [:] + } + } + + public func decodeBytesForKeyNoCopy(_ key: StaticString) -> ReadBuffer? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Bytes) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + Int(length) + return ReadBuffer(memory: self.buffer.memory.advanced(by: self.offset - Int(length)), length: Int(length), freeWhenDone: false) + } else { + return nil + } + } + + public func decodeBytesForKey(_ key: StaticString) -> ReadBuffer? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Bytes) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + Int(length) + let copyBytes = malloc(Int(length))! + memcpy(copyBytes, self.buffer.memory.advanced(by: self.offset - Int(length)), Int(length)) + return ReadBuffer(memory: copyBytes, length: Int(length), freeWhenDone: true) + } else { + return nil + } + } + + public func decodeDataForKey(_ key: StaticString) -> Data? { + if PostboxDecoder.positionOnKey(self.buffer.memory, offset: &self.offset, maxOffset: self.buffer.length, length: self.buffer.length, key: key, valueType: .Bytes) { + var length: Int32 = 0 + memcpy(&length, self.buffer.memory + self.offset, 4) + self.offset += 4 + Int(length) + var result = Data(count: Int(length)) + result.withUnsafeMutableBytes { (bytes: UnsafeMutablePointer) -> Void in + memcpy(bytes, self.buffer.memory.advanced(by: self.offset - Int(length)), Int(length)) + } + return result + } else { + return nil + } + } +} diff --git a/submodules/Database/PostboxDataTypes/BUCK b/submodules/Database/PostboxDataTypes/BUCK new file mode 100644 index 0000000000..837af3a0a7 --- /dev/null +++ b/submodules/Database/PostboxDataTypes/BUCK @@ -0,0 +1,15 @@ +load("//Config:buck_rule_macros.bzl", "static_library") + +static_library( + name = "PostboxDataTypes", + srcs = glob([ + "Sources/**/*.swift", + ]), + deps = [ + "//submodules/Database/ValueBox:ValueBox", + "//submodules/Database/PostboxCoding:PostboxCoding", + ], + frameworks = [ + "$SDKROOT/System/Library/Frameworks/Foundation.framework", + ], +) diff --git a/submodules/Database/PostboxDataTypes/Sources/ChatListTotalUnreadState.swift b/submodules/Database/PostboxDataTypes/Sources/ChatListTotalUnreadState.swift new file mode 100644 index 0000000000..ab6bce8721 --- /dev/null +++ b/submodules/Database/PostboxDataTypes/Sources/ChatListTotalUnreadState.swift @@ -0,0 +1,86 @@ +import Foundation +import PostboxCoding + +public enum ChatListTotalUnreadStateCategory: Int32 { + case filtered = 0 + case raw = 1 +} + +public enum ChatListTotalUnreadStateStats: Int32 { + case messages = 0 + case chats = 1 +} + +public struct ChatListTotalUnreadCounters: PostboxCoding, Equatable { + public var messageCount: Int32 + public var chatCount: Int32 + + public init(messageCount: Int32, chatCount: Int32) { + self.messageCount = messageCount + self.chatCount = chatCount + } + + public init(decoder: PostboxDecoder) { + self.messageCount = decoder.decodeInt32ForKey("m", orElse: 0) + self.chatCount = decoder.decodeInt32ForKey("c", orElse: 0) + } + + public func encode(_ encoder: PostboxEncoder) { + encoder.encodeInt32(self.messageCount, forKey: "m") + encoder.encodeInt32(self.chatCount, forKey: "c") + } +} + +public struct ChatListTotalUnreadState: PostboxCoding, Equatable { + public var absoluteCounters: [PeerSummaryCounterTags: ChatListTotalUnreadCounters] + public var filteredCounters: [PeerSummaryCounterTags: ChatListTotalUnreadCounters] + + public init(absoluteCounters: [PeerSummaryCounterTags: ChatListTotalUnreadCounters], filteredCounters: [PeerSummaryCounterTags: ChatListTotalUnreadCounters]) { + self.absoluteCounters = absoluteCounters + self.filteredCounters = filteredCounters + } + + public init(decoder: PostboxDecoder) { + self.absoluteCounters = decoder.decodeObjectDictionaryForKey("ad", keyDecoder: { decoder in + return PeerSummaryCounterTags(rawValue: decoder.decodeInt32ForKey("k", orElse: 0)) + }, valueDecoder: { decoder in + return ChatListTotalUnreadCounters(decoder: decoder) + }) + self.filteredCounters = decoder.decodeObjectDictionaryForKey("fd", keyDecoder: { decoder in + return PeerSummaryCounterTags(rawValue: decoder.decodeInt32ForKey("k", orElse: 0)) + }, valueDecoder: { decoder in + return ChatListTotalUnreadCounters(decoder: decoder) + }) + } + + public func encode(_ encoder: PostboxEncoder) { + encoder.encodeObjectDictionary(self.absoluteCounters, forKey: "ad", keyEncoder: { key, encoder in + encoder.encodeInt32(key.rawValue, forKey: "k") + }) + encoder.encodeObjectDictionary(self.filteredCounters, forKey: "fd", keyEncoder: { key, encoder in + encoder.encodeInt32(key.rawValue, forKey: "k") + }) + } + + public func count(for category: ChatListTotalUnreadStateCategory, in statsType: ChatListTotalUnreadStateStats, with tags: PeerSummaryCounterTags) -> Int32 { + let counters: [PeerSummaryCounterTags: ChatListTotalUnreadCounters] + switch category { + case .raw: + counters = self.absoluteCounters + case .filtered: + counters = self.filteredCounters + } + var result: Int32 = 0 + for tag in tags { + if let category = counters[tag] { + switch statsType { + case .messages: + result = result &+ category.messageCount + case .chats: + result = result &+ category.chatCount + } + } + } + return result + } +} diff --git a/submodules/Database/PostboxDataTypes/Sources/MessageId.swift b/submodules/Database/PostboxDataTypes/Sources/MessageId.swift new file mode 100644 index 0000000000..ff38f64dfe --- /dev/null +++ b/submodules/Database/PostboxDataTypes/Sources/MessageId.swift @@ -0,0 +1,280 @@ +import Foundation +import Buffers + +public struct MessageId: Hashable, Comparable, CustomStringConvertible { + public typealias Namespace = Int32 + public typealias Id = Int32 + + public let peerId: PeerId + public let namespace: Namespace + public let id: Id + + public var description: String { + get { + return "\(namespace)_\(id)" + } + } + + public init(peerId: PeerId, namespace: Namespace, id: Id) { + self.peerId = peerId + self.namespace = namespace + self.id = id + } + + public init(_ buffer: ReadBuffer) { + var peerIdNamespaceValue: Int32 = 0 + memcpy(&peerIdNamespaceValue, buffer.memory + buffer.offset, 4) + var peerIdIdValue: Int32 = 0 + memcpy(&peerIdIdValue, buffer.memory + (buffer.offset + 4), 4) + self.peerId = PeerId(namespace: peerIdNamespaceValue, id: peerIdIdValue) + + var namespaceValue: Int32 = 0 + memcpy(&namespaceValue, buffer.memory + (buffer.offset + 8), 4) + self.namespace = namespaceValue + var idValue: Int32 = 0 + memcpy(&idValue, buffer.memory + (buffer.offset + 12), 4) + self.id = idValue + + buffer.offset += 16 + } + + public func encodeToBuffer(_ buffer: WriteBuffer) { + var peerIdNamespace = self.peerId.namespace + var peerIdId = self.peerId.id + var namespace = self.namespace + var id = self.id + buffer.write(&peerIdNamespace, offset: 0, length: 4); + buffer.write(&peerIdId, offset: 0, length: 4); + buffer.write(&namespace, offset: 0, length: 4); + buffer.write(&id, offset: 0, length: 4); + } + + public static func encodeArrayToBuffer(_ array: [MessageId], buffer: WriteBuffer) { + var length: Int32 = Int32(array.count) + buffer.write(&length, offset: 0, length: 4) + for id in array { + id.encodeToBuffer(buffer) + } + } + + public static func decodeArrayFromBuffer(_ buffer: ReadBuffer) -> [MessageId] { + var length: Int32 = 0 + memcpy(&length, buffer.memory, 4) + buffer.offset += 4 + var i = 0 + var array: [MessageId] = [] + while i < Int(length) { + array.append(MessageId(buffer)) + i += 1 + } + return array + } + + public static func <(lhs: MessageId, rhs: MessageId) -> Bool { + if lhs.namespace == rhs.namespace { + if lhs.id == rhs.id { + return lhs.peerId < rhs.peerId + } else { + return lhs.id < rhs.id + } + } else { + return lhs.namespace < rhs.namespace + } + } +} + +public struct ChatListIndex: Comparable, Hashable { + public let pinningIndex: UInt16? + public let messageIndex: MessageIndex + + public init(pinningIndex: UInt16?, messageIndex: MessageIndex) { + self.pinningIndex = pinningIndex + self.messageIndex = messageIndex + } + + public static func <(lhs: ChatListIndex, rhs: ChatListIndex) -> Bool { + if let lhsPinningIndex = lhs.pinningIndex, let rhsPinningIndex = rhs.pinningIndex { + if lhsPinningIndex > rhsPinningIndex { + return true + } else if lhsPinningIndex < rhsPinningIndex { + return false + } + } else if lhs.pinningIndex != nil { + return false + } else if rhs.pinningIndex != nil { + return true + } + return lhs.messageIndex < rhs.messageIndex + } + + public var hashValue: Int { + return self.messageIndex.hashValue + } + + public static var absoluteUpperBound: ChatListIndex { + return ChatListIndex(pinningIndex: 0, messageIndex: MessageIndex.absoluteUpperBound()) + } + + public static var absoluteLowerBound: ChatListIndex { + return ChatListIndex(pinningIndex: nil, messageIndex: MessageIndex.absoluteLowerBound()) + } + + public var predecessor: ChatListIndex { + return ChatListIndex(pinningIndex: self.pinningIndex, messageIndex: self.messageIndex.predecessor()) + } + + public var successor: ChatListIndex { + return ChatListIndex(pinningIndex: self.pinningIndex, messageIndex: self.messageIndex.successor()) + } +} + +public struct MessageTags: OptionSet, Sequence, Hashable { + public var rawValue: UInt32 + + public init(rawValue: UInt32) { + self.rawValue = rawValue + } + + public init() { + self.rawValue = 0 + } + + public static let All = MessageTags(rawValue: 0xffffffff) + + public var containsSingleElement: Bool { + var hasOne = false + for i in 0 ..< 31 { + let tag = (self.rawValue >> UInt32(i)) & 1 + if tag != 0 { + if hasOne { + return false + } else { + hasOne = true + } + } + } + return hasOne + } + + public func makeIterator() -> AnyIterator { + var index = 0 + return AnyIterator { () -> MessageTags? in + while index < 31 { + let currentTags = self.rawValue >> UInt32(index) + let tag = MessageTags(rawValue: 1 << UInt32(index)) + index += 1 + if currentTags == 0 { + break + } + + if (currentTags & 1) != 0 { + return tag + } + } + return nil + } + } +} + +public struct GlobalMessageTags: OptionSet, Sequence, Hashable { + public var rawValue: UInt32 + + public init(rawValue: UInt32) { + self.rawValue = rawValue + } + + public init() { + self.rawValue = 0 + } + + var isSingleTag: Bool { + let t = Int32(bitPattern: self.rawValue) + return t != 0 && t == (t & (-t)) + } + + public func makeIterator() -> AnyIterator { + var index = 0 + return AnyIterator { () -> GlobalMessageTags? in + while index < 31 { + let currentTags = self.rawValue >> UInt32(index) + let tag = GlobalMessageTags(rawValue: 1 << UInt32(index)) + index += 1 + if currentTags == 0 { + break + } + + if (currentTags & 1) != 0 { + return tag + } + } + return nil + } + } + + public var hashValue: Int { + return self.rawValue.hashValue + } +} + +public struct LocalMessageTags: OptionSet, Sequence, Hashable { + public var rawValue: UInt32 + + public init(rawValue: UInt32) { + self.rawValue = rawValue + } + + public init() { + self.rawValue = 0 + } + + var isSingleTag: Bool { + let t = Int32(bitPattern: self.rawValue) + return t != 0 && t == (t & (-t)) + } + + public func makeIterator() -> AnyIterator { + var index = 0 + return AnyIterator { () -> LocalMessageTags? in + while index < 31 { + let currentTags = self.rawValue >> UInt32(index) + let tag = LocalMessageTags(rawValue: 1 << UInt32(index)) + index += 1 + if currentTags == 0 { + break + } + + if (currentTags & 1) != 0 { + return tag + } + } + return nil + } + } + + public var hashValue: Int { + return self.rawValue.hashValue + } +} + +public struct MessageFlags: OptionSet { + public var rawValue: UInt32 + + public init(rawValue: UInt32) { + self.rawValue = rawValue + } + + public init() { + self.rawValue = 0 + } + + public static let Unsent = MessageFlags(rawValue: 1) + public static let Failed = MessageFlags(rawValue: 2) + public static let Incoming = MessageFlags(rawValue: 4) + public static let TopIndexable = MessageFlags(rawValue: 16) + public static let Sending = MessageFlags(rawValue: 32) + public static let CanBeGroupedIntoFeed = MessageFlags(rawValue: 64) + public static let WasScheduled = MessageFlags(rawValue: 128) + public static let CountedAsIncoming = MessageFlags(rawValue: 256) + + public static let IsIncomingMask = MessageFlags([.Incoming, .CountedAsIncoming]) +} diff --git a/submodules/Database/PostboxDataTypes/Sources/MessageIndex.swift b/submodules/Database/PostboxDataTypes/Sources/MessageIndex.swift new file mode 100644 index 0000000000..0a1136c74f --- /dev/null +++ b/submodules/Database/PostboxDataTypes/Sources/MessageIndex.swift @@ -0,0 +1,79 @@ +import Foundation + +public struct MessageIndex: Comparable, Hashable { + public let id: MessageId + public let timestamp: Int32 + + public init(id: MessageId, timestamp: Int32) { + self.id = id + self.timestamp = timestamp + } + + public func predecessor() -> MessageIndex { + if self.id.id != 0 { + return MessageIndex(id: MessageId(peerId: self.id.peerId, namespace: self.id.namespace, id: self.id.id - 1), timestamp: self.timestamp) + } else if self.id.namespace != 0 { + return MessageIndex(id: MessageId(peerId: self.id.peerId, namespace: self.id.namespace - 1, id: Int32.max - 1), timestamp: self.timestamp) + } else if self.timestamp != 0 { + return MessageIndex(id: MessageId(peerId: self.id.peerId, namespace: Int32(Int8.max) - 1, id: Int32.max - 1), timestamp: self.timestamp - 1) + } else { + return self + } + } + + public func successor() -> MessageIndex { + return MessageIndex(id: MessageId(peerId: self.id.peerId, namespace: self.id.namespace, id: self.id.id == Int32.max ? self.id.id : (self.id.id + 1)), timestamp: self.timestamp) + } + + public var hashValue: Int { + return self.id.hashValue + } + + public static func absoluteUpperBound() -> MessageIndex { + return MessageIndex(id: MessageId(peerId: PeerId(namespace: Int32(Int8.max), id: Int32.max), namespace: Int32(Int8.max), id: Int32.max), timestamp: Int32.max) + } + + public static func absoluteLowerBound() -> MessageIndex { + return MessageIndex(id: MessageId(peerId: PeerId(namespace: 0, id: 0), namespace: 0, id: 0), timestamp: 0) + } + + public static func lowerBound(peerId: PeerId) -> MessageIndex { + return MessageIndex(id: MessageId(peerId: peerId, namespace: 0, id: 0), timestamp: 0) + } + + public static func lowerBound(peerId: PeerId, namespace: MessageId.Namespace) -> MessageIndex { + return MessageIndex(id: MessageId(peerId: peerId, namespace: namespace, id: 0), timestamp: 0) + } + + public static func upperBound(peerId: PeerId) -> MessageIndex { + return MessageIndex(id: MessageId(peerId: peerId, namespace: Int32(Int8.max), id: Int32.max), timestamp: Int32.max) + } + + public static func upperBound(peerId: PeerId, namespace: MessageId.Namespace) -> MessageIndex { + return MessageIndex(id: MessageId(peerId: peerId, namespace: namespace, id: Int32.max), timestamp: Int32.max) + } + + public static func upperBound(peerId: PeerId, timestamp: Int32, namespace: MessageId.Namespace) -> MessageIndex { + return MessageIndex(id: MessageId(peerId: peerId, namespace: namespace, id: Int32.max), timestamp: timestamp) + } + + func withPeerId(_ peerId: PeerId) -> MessageIndex { + return MessageIndex(id: MessageId(peerId: peerId, namespace: self.id.namespace, id: self.id.id), timestamp: self.timestamp) + } + + func withNamespace(_ namespace: MessageId.Namespace) -> MessageIndex { + return MessageIndex(id: MessageId(peerId: self.id.peerId, namespace: namespace, id: self.id.id), timestamp: self.timestamp) + } + + public static func <(lhs: MessageIndex, rhs: MessageIndex) -> Bool { + if lhs.timestamp != rhs.timestamp { + return lhs.timestamp < rhs.timestamp + } + + if lhs.id.namespace != rhs.id.namespace { + return lhs.id.namespace < rhs.id.namespace + } + + return lhs.id.id < rhs.id.id + } +} diff --git a/submodules/Database/PostboxDataTypes/Sources/PeerGroupId.swift b/submodules/Database/PostboxDataTypes/Sources/PeerGroupId.swift new file mode 100644 index 0000000000..5a3c56a135 --- /dev/null +++ b/submodules/Database/PostboxDataTypes/Sources/PeerGroupId.swift @@ -0,0 +1,23 @@ +import Foundation + +public enum PeerGroupId: Hashable, Equatable, RawRepresentable { + case root + case group(Int32) + + public var rawValue: Int32 { + switch self { + case .root: + return 0 + case let .group(id): + return id + } + } + + public init(rawValue: Int32) { + if rawValue == 0 { + self = .root + } else { + self = .group(rawValue) + } + } +} diff --git a/submodules/Database/PostboxDataTypes/Sources/PeerId.swift b/submodules/Database/PostboxDataTypes/Sources/PeerId.swift new file mode 100644 index 0000000000..30b8e02cda --- /dev/null +++ b/submodules/Database/PostboxDataTypes/Sources/PeerId.swift @@ -0,0 +1,89 @@ +import Foundation +import Buffers + +public struct PeerId: Hashable, CustomStringConvertible, Comparable { + public typealias Namespace = Int32 + public typealias Id = Int32 + + public let namespace: Namespace + public let id: Id + + public init(namespace: Namespace, id: Id) { + self.namespace = namespace + self.id = id + } + + public init(_ n: Int64) { + self.namespace = Int32((n >> 32) & 0x7fffffff) + self.id = Int32(bitPattern: UInt32(n & 0xffffffff)) + } + + public func toInt64() -> Int64 { + return (Int64(self.namespace) << 32) | Int64(bitPattern: UInt64(UInt32(bitPattern: self.id))) + } + + public static func encodeArrayToBuffer(_ array: [PeerId], buffer: WriteBuffer) { + var length: Int32 = Int32(array.count) + buffer.write(&length, offset: 0, length: 4) + for id in array { + var value = id.toInt64() + buffer.write(&value, offset: 0, length: 8) + } + } + + public static func decodeArrayFromBuffer(_ buffer: ReadBuffer) -> [PeerId] { + var length: Int32 = 0 + memcpy(&length, buffer.memory, 4) + buffer.offset += 4 + var i = 0 + var array: [PeerId] = [] + array.reserveCapacity(Int(length)) + while i < Int(length) { + var value: Int64 = 0 + buffer.read(&value, offset: 0, length: 8) + array.append(PeerId(value)) + i += 1 + } + return array + } + + public var hashValue: Int { + get { + return Int(self.id) + } + } + + public var description: String { + get { + return "\(namespace):\(id)" + } + } + + public init(_ buffer: ReadBuffer) { + var namespace: Int32 = 0 + var id: Int32 = 0 + memcpy(&namespace, buffer.memory, 4) + self.namespace = namespace + memcpy(&id, buffer.memory + 4, 4) + self.id = id + } + + public func encodeToBuffer(_ buffer: WriteBuffer) { + var namespace = self.namespace + var id = self.id + buffer.write(&namespace, offset: 0, length: 4); + buffer.write(&id, offset: 0, length: 4); + } + + public static func <(lhs: PeerId, rhs: PeerId) -> Bool { + if lhs.namespace != rhs.namespace { + return lhs.namespace < rhs.namespace + } + + if lhs.id != rhs.id { + return lhs.id < rhs.id + } + + return false + } +} diff --git a/submodules/Database/PostboxDataTypes/Sources/PeerReadState.swift b/submodules/Database/PostboxDataTypes/Sources/PeerReadState.swift new file mode 100644 index 0000000000..64b0499ed9 --- /dev/null +++ b/submodules/Database/PostboxDataTypes/Sources/PeerReadState.swift @@ -0,0 +1,152 @@ + +public enum PeerReadState: Equatable, CustomStringConvertible { + case idBased(maxIncomingReadId: MessageId.Id, maxOutgoingReadId: MessageId.Id, maxKnownId: MessageId.Id, count: Int32, markedUnread: Bool) + case indexBased(maxIncomingReadIndex: MessageIndex, maxOutgoingReadIndex: MessageIndex, count: Int32, markedUnread: Bool) + + public var count: Int32 { + switch self { + case let .idBased(_, _, _, count, _): + return count + case let .indexBased(_, _, count, _): + return count + } + } + + public var maxKnownId: MessageId.Id? { + switch self { + case let .idBased(_, _, maxKnownId, _, _): + return maxKnownId + case .indexBased: + return nil + } + } + + + public var isUnread: Bool { + switch self { + case let .idBased(_, _, _, count, markedUnread): + return count > 0 || markedUnread + case let .indexBased(_, _, count, markedUnread): + return count > 0 || markedUnread + } + } + + public var markedUnread: Bool { + switch self { + case let .idBased(_, _, _, _, markedUnread): + return markedUnread + case let .indexBased(_, _, _, markedUnread): + return markedUnread + } + } + + public func withAddedCount(_ value: Int32) -> PeerReadState { + switch self { + case let .idBased(maxIncomingReadId, maxOutgoingReadId, maxKnownId, count, markedUnread): + return .idBased(maxIncomingReadId: maxIncomingReadId, maxOutgoingReadId: maxOutgoingReadId, maxKnownId: maxKnownId, count: count + value, markedUnread: markedUnread) + case let .indexBased(maxIncomingReadIndex, maxOutgoingReadIndex, count, markedUnread): + return .indexBased(maxIncomingReadIndex: maxIncomingReadIndex, maxOutgoingReadIndex: maxOutgoingReadIndex, count: count + value, markedUnread: markedUnread) + } + } + + public var description: String { + switch self { + case let .idBased(maxIncomingReadId, maxOutgoingReadId, maxKnownId, count, markedUnread): + return "(PeerReadState maxIncomingReadId: \(maxIncomingReadId), maxOutgoingReadId: \(maxOutgoingReadId) maxKnownId: \(maxKnownId), count: \(count), markedUnread: \(markedUnread)" + case let .indexBased(maxIncomingReadIndex, maxOutgoingReadIndex, count, markedUnread): + return "(PeerReadState maxIncomingReadIndex: \(maxIncomingReadIndex), maxOutgoingReadIndex: \(maxOutgoingReadIndex), count: \(count), markedUnread: \(markedUnread)" + } + } + + public func isIncomingMessageIndexRead(_ index: MessageIndex) -> Bool { + switch self { + case let .idBased(maxIncomingReadId, _, _, _, _): + return maxIncomingReadId >= index.id.id + case let .indexBased(maxIncomingReadIndex, _, _, _): + return maxIncomingReadIndex >= index + } + } + + public func isOutgoingMessageIndexRead(_ index: MessageIndex) -> Bool { + switch self { + case let .idBased(_, maxOutgoingReadId, _, _, _): + return maxOutgoingReadId >= index.id.id + case let .indexBased(_, maxOutgoingReadIndex, _, _): + return maxOutgoingReadIndex >= index + } + } +} + +public struct CombinedPeerReadState: Equatable { + public var states: [(MessageId.Namespace, PeerReadState)] + + public init(states: [(MessageId.Namespace, PeerReadState)]) { + self.states = states + } + + public var count: Int32 { + var result: Int32 = 0 + for (_, state) in self.states { + result += state.count + } + return result + } + + public var markedUnread: Bool { + for (_, state) in self.states { + if state.markedUnread { + return true + } + } + return false + } + + public var isUnread: Bool { + for (_, state) in self.states { + if state.isUnread { + return true + } + } + return false + } + + public static func ==(lhs: CombinedPeerReadState, rhs: CombinedPeerReadState) -> Bool { + if lhs.states.count != rhs.states.count { + return false + } + for (lhsNamespace, lhsState) in lhs.states { + var rhsFound = false + inner: for (rhsNamespace, rhsState) in rhs.states { + if rhsNamespace == lhsNamespace { + if lhsState != rhsState { + return false + } + rhsFound = true + break inner + } + } + if !rhsFound { + return false + } + } + return true + } + + public func isOutgoingMessageIndexRead(_ index: MessageIndex) -> Bool { + for (namespace, readState) in self.states { + if namespace == index.id.namespace { + return readState.isOutgoingMessageIndexRead(index) + } + } + return false + } + + public func isIncomingMessageIndexRead(_ index: MessageIndex) -> Bool { + for (namespace, readState) in self.states { + if namespace == index.id.namespace { + return readState.isIncomingMessageIndexRead(index) + } + } + return false + } +} diff --git a/submodules/Database/PostboxDataTypes/Sources/PeerSummaryCounterTags.swift b/submodules/Database/PostboxDataTypes/Sources/PeerSummaryCounterTags.swift new file mode 100644 index 0000000000..3ba8967f06 --- /dev/null +++ b/submodules/Database/PostboxDataTypes/Sources/PeerSummaryCounterTags.swift @@ -0,0 +1,32 @@ +import Foundation + +public struct PeerSummaryCounterTags: OptionSet, Sequence, Hashable { + public var rawValue: Int32 + + public init(rawValue: Int32) { + self.rawValue = rawValue + } + + public init() { + self.rawValue = 0 + } + + public func makeIterator() -> AnyIterator { + var index = 0 + return AnyIterator { () -> PeerSummaryCounterTags? in + while index < 31 { + let currentTags = self.rawValue >> UInt32(index) + let tag = PeerSummaryCounterTags(rawValue: 1 << UInt32(index)) + index += 1 + if currentTags == 0 { + break + } + + if (currentTags & 1) != 0 { + return tag + } + } + return nil + } + } +} diff --git a/submodules/Database/Table/BUCK b/submodules/Database/Table/BUCK new file mode 100644 index 0000000000..fd0bf366e4 --- /dev/null +++ b/submodules/Database/Table/BUCK @@ -0,0 +1,14 @@ +load("//Config:buck_rule_macros.bzl", "static_library") + +static_library( + name = "Table", + srcs = glob([ + "Sources/**/*.swift", + ]), + deps = [ + "//submodules/Database/ValueBox:ValueBox", + ], + frameworks = [ + "$SDKROOT/System/Library/Frameworks/Foundation.framework", + ], +) diff --git a/submodules/Database/Table/Sources/Table.swift b/submodules/Database/Table/Sources/Table.swift new file mode 100644 index 0000000000..5ee6cf9ceb --- /dev/null +++ b/submodules/Database/Table/Sources/Table.swift @@ -0,0 +1,18 @@ +import Foundation +import ValueBox + +open class Table { + public final let valueBox: ValueBox + public final let table: ValueBoxTable + + public init(valueBox: ValueBox, table: ValueBoxTable) { + self.valueBox = valueBox + self.table = table + } + + open func clearMemoryCache() { + } + + open func beforeCommit() { + } +} diff --git a/submodules/Database/ValueBox/BUCK b/submodules/Database/ValueBox/BUCK new file mode 100644 index 0000000000..a12e9c0e07 --- /dev/null +++ b/submodules/Database/ValueBox/BUCK @@ -0,0 +1,16 @@ +load("//Config:buck_rule_macros.bzl", "static_library") + +static_library( + name = "ValueBox", + srcs = glob([ + "Sources/**/*.swift", + ]), + deps = [ + "//submodules/SSignalKit/SwiftSignalKit:SwiftSignalKit#shared", + "//submodules/sqlcipher:sqlcipher", + "//submodules/Database/Buffers:Buffers", + ], + frameworks = [ + "$SDKROOT/System/Library/Frameworks/Foundation.framework", + ], +) diff --git a/submodules/Database/ValueBox/Sources/Database.swift b/submodules/Database/ValueBox/Sources/Database.swift new file mode 100644 index 0000000000..33d0639b8c --- /dev/null +++ b/submodules/Database/ValueBox/Sources/Database.swift @@ -0,0 +1,72 @@ +// +// SQLite.swift +// https://github.com/stephencelis/SQLite.swift +// Copyright (c) 2014-2015 Stephen Celis. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// + +import Foundation +#if os(macOS) +import sqlciphermac +#else +import sqlcipher +#endif + +final class Database { + internal var handle: OpaquePointer? = nil + + init?(logger: ValueBoxLogger, location: String) { + if location != ":memory:" { + let _ = open(location + "-guard", O_WRONLY | O_CREAT | O_APPEND, S_IRUSR | S_IWUSR) + } + let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE | SQLITE_OPEN_FULLMUTEX + let res = sqlite3_open_v2(location, &self.handle, flags, nil) + if res != SQLITE_OK { + logger.log("sqlite3_open_v2: \(res)") + return nil + } + } + + deinit { + sqlite3_close(self.handle) + } // sqlite3_close_v2 in Yosemite/iOS 8? + + public func execute(_ SQL: String) -> Bool { + let res = sqlite3_exec(self.handle, SQL, nil, nil, nil) + if res == SQLITE_OK { + return true + } else { + if let error = sqlite3_errmsg(self.handle), let str = NSString(utf8String: error) { + print("SQL error \(res): \(str) on SQL") + } else { + print("SQL error \(res) on SQL") + } + return false + } + } + + public func currentError() -> String? { + if let error = sqlite3_errmsg(self.handle), let str = NSString(utf8String: error) { + return "SQL error \(str)" + } else { + return nil + } + } +} diff --git a/submodules/Database/ValueBox/Sources/SqliteValueBox.swift b/submodules/Database/ValueBox/Sources/SqliteValueBox.swift new file mode 100644 index 0000000000..4bfb081c1c --- /dev/null +++ b/submodules/Database/ValueBox/Sources/SqliteValueBox.swift @@ -0,0 +1,2066 @@ +import Foundation +import sqlcipher +import SwiftSignalKit +import Buffers + +private struct SqliteValueBoxTable { + let table: ValueBoxTable + let hasPrimaryKey: Bool +} + +let SQLITE_TRANSIENT = unsafeBitCast(-1, to: sqlite3_destructor_type.self) + +private func checkTableKey(_ table: ValueBoxTable, _ key: ValueBoxKey) { + switch table.keyType { + case .binary: + break + case .int64: + precondition(key.length == 8) + } +} + +struct SqlitePreparedStatement { + let logger: ValueBoxLogger + let statement: OpaquePointer? + + func bind(_ index: Int, data: UnsafeRawPointer, length: Int) { + sqlite3_bind_blob(statement, Int32(index), data, Int32(length), SQLITE_TRANSIENT) + } + + func bindText(_ index: Int, data: UnsafeRawPointer, length: Int) { + sqlite3_bind_text(statement, Int32(index), data.assumingMemoryBound(to: Int8.self), Int32(length), SQLITE_TRANSIENT) + } + + func bind(_ index: Int, number: Int64) { + sqlite3_bind_int64(statement, Int32(index), number) + } + + func bindNull(_ index: Int) { + sqlite3_bind_null(statement, Int32(index)) + } + + func bind(_ index: Int, number: Int32) { + sqlite3_bind_int(statement, Int32(index), number) + } + + func reset() { + sqlite3_reset(statement) + sqlite3_clear_bindings(statement) + } + + func step(handle: OpaquePointer?, _ initial: Bool = false, path: String?) -> Bool { + let res = sqlite3_step(statement) + if res != SQLITE_ROW && res != SQLITE_DONE { + if let error = sqlite3_errmsg(handle), let str = NSString(utf8String: error) { + self.logger.log("SQL error \(res): \(str) on step") + } else { + self.logger.log("SQL error \(res) on step") + } + + if res == SQLITE_CORRUPT { + if let path = path { + self.logger.log("Corrupted DB at step, dropping") + try? FileManager.default.removeItem(atPath: path) + preconditionFailure() + } + } + } + return res == SQLITE_ROW + } + + func tryStep(handle: OpaquePointer?, _ initial: Bool = false, path: String?) -> Bool { + let res = sqlite3_step(statement) + if res != SQLITE_ROW && res != SQLITE_DONE { + if res != SQLITE_MISUSE { + if let error = sqlite3_errmsg(handle), let str = NSString(utf8String: error) { + self.logger.log("SQL error \(res): \(str) on step") + } else { + self.logger.log("SQL error \(res) on step") + } + } + + if res == SQLITE_CORRUPT { + if let path = path { + self.logger.log("Corrupted DB at step, dropping") + try? FileManager.default.removeItem(atPath: path) + preconditionFailure() + } + } + } + return res == SQLITE_ROW || res == SQLITE_DONE + } + + func int32At(_ index: Int) -> Int32 { + return sqlite3_column_int(statement, Int32(index)) + } + + func int64At(_ index: Int) -> Int64 { + return sqlite3_column_int64(statement, Int32(index)) + } + + func valueAt(_ index: Int) -> ReadBuffer { + let valueLength = sqlite3_column_bytes(statement, Int32(index)) + let valueData = sqlite3_column_blob(statement, Int32(index)) + + let valueMemory = malloc(Int(valueLength))! + memcpy(valueMemory, valueData, Int(valueLength)) + return ReadBuffer(memory: valueMemory, length: Int(valueLength), freeWhenDone: true) + } + + func stringAt(_ index: Int) -> String? { + let valueLength = sqlite3_column_bytes(statement, Int32(index)) + if let valueData = sqlite3_column_blob(statement, Int32(index)) { + return String(data: Data(bytes: valueData, count: Int(valueLength)), encoding: .utf8) + } else { + return nil + } + } + + func keyAt(_ index: Int) -> ValueBoxKey { + let valueLength = sqlite3_column_bytes(statement, Int32(index)) + let valueData = sqlite3_column_blob(statement, Int32(index)) + + let key = ValueBoxKey(length: Int(valueLength)) + memcpy(key.memory, valueData, Int(valueLength)) + return key + } + + func int64KeyAt(_ index: Int) -> ValueBoxKey { + let value = sqlite3_column_int64(statement, Int32(index)) + + let key = ValueBoxKey(length: 8) + key.setInt64(0, value: value) + return key + } + + func int64KeyValueAt(_ index: Int) -> Int64 { + return sqlite3_column_int64(statement, Int32(index)) + } + + func destroy() { + sqlite3_finalize(statement) + } +} + +private let dabaseFileNames: [String] = [ + "db_sqlite", + "db_sqlite-shm", + "db_sqlite-wal" +] + +private struct TablePairKey: Hashable { + let table1: Int32 + let table2: Int32 +} + +public final class SqliteValueBox: ValueBox { + private let lock = NSRecursiveLock() + + fileprivate let basePath: String + private let logger: ValueBoxLogger + private let inMemory: Bool + private let disableCache: Bool + private let encryptionParameters: ValueBoxEncryptionParameters? + private let databasePath: String + private var database: Database! + private var tables: [Int32: SqliteValueBoxTable] = [:] + private var fullTextTables: [Int32: ValueBoxFullTextTable] = [:] + private var getStatements: [Int32 : SqlitePreparedStatement] = [:] + private var getRowIdStatements: [Int32 : SqlitePreparedStatement] = [:] + private var rangeKeyAscStatementsLimit: [Int32 : SqlitePreparedStatement] = [:] + private var rangeKeyAscStatementsNoLimit: [Int32 : SqlitePreparedStatement] = [:] + private var rangeKeyDescStatementsLimit: [Int32 : SqlitePreparedStatement] = [:] + private var rangeKeyDescStatementsNoLimit: [Int32 : SqlitePreparedStatement] = [:] + private var deleteRangeStatements: [Int32 : SqlitePreparedStatement] = [:] + private var rangeValueAscStatementsLimit: [Int32 : SqlitePreparedStatement] = [:] + private var rangeValueAscStatementsNoLimit: [Int32 : SqlitePreparedStatement] = [:] + private var rangeValueDescStatementsLimit: [Int32 : SqlitePreparedStatement] = [:] + private var rangeValueDescStatementsNoLimit: [Int32 : SqlitePreparedStatement] = [:] + private var scanStatements: [Int32 : SqlitePreparedStatement] = [:] + private var scanKeysStatements: [Int32 : SqlitePreparedStatement] = [:] + private var existsStatements: [Int32 : SqlitePreparedStatement] = [:] + private var updateStatements: [Int32 : SqlitePreparedStatement] = [:] + private var insertOrReplacePrimaryKeyStatements: [Int32 : SqlitePreparedStatement] = [:] + private var insertOrReplaceIndexKeyStatements: [Int32 : SqlitePreparedStatement] = [:] + private var deleteStatements: [Int32 : SqlitePreparedStatement] = [:] + private var moveStatements: [Int32 : SqlitePreparedStatement] = [:] + private var copyStatements: [TablePairKey : SqlitePreparedStatement] = [:] + private var fullTextInsertStatements: [Int32 : SqlitePreparedStatement] = [:] + private var fullTextDeleteStatements: [Int32 : SqlitePreparedStatement] = [:] + private var fullTextMatchGlobalStatements: [Int32 : SqlitePreparedStatement] = [:] + private var fullTextMatchCollectionStatements: [Int32 : SqlitePreparedStatement] = [:] + private var fullTextMatchCollectionTagsStatements: [Int32 : SqlitePreparedStatement] = [:] + + private var secureDeleteEnabled: Bool = false + + private let checkpoints = MetaDisposable() + + private let queue: Queue + + public init(basePath: String, queue: Queue, logger: ValueBoxLogger, encryptionParameters: ValueBoxEncryptionParameters?, disableCache: Bool = false, upgradeProgress: (Float) -> Void, inMemory: Bool = false) { + self.basePath = basePath + self.logger = logger + self.inMemory = inMemory + self.disableCache = disableCache + self.encryptionParameters = encryptionParameters + self.databasePath = basePath + "/db_sqlite" + self.queue = queue + self.database = self.openDatabase(encryptionParameters: encryptionParameters, upgradeProgress: upgradeProgress) + } + + deinit { + precondition(self.queue.isCurrent()) + self.clearStatements() + checkpoints.dispose() + } + + func internalClose() { + self.database = nil + } + + private func openDatabase(encryptionParameters: ValueBoxEncryptionParameters?, upgradeProgress: (Float) -> Void) -> Database { + precondition(self.queue.isCurrent()) + + checkpoints.set(nil) + lock.lock() + + let _ = try? FileManager.default.createDirectory(atPath: basePath, withIntermediateDirectories: true, attributes: nil) + let path = basePath + "/db_sqlite" + + #if DEBUG + let exists = FileManager.default.fileExists(atPath: path) + self.logger.log("Opening \(path), exists: \(exists)") + if exists { + do { + let data = try Data(contentsOf: URL(fileURLWithPath: path), options: .mappedIfSafe) + self.logger.log("\(path) size: \(data.count)") + } catch let e { + self.logger.log("Couldn't open database: \(e)") + } + } + let walExists = FileManager.default.fileExists(atPath: path + "-wal") + self.logger.log("Opening \(path)-wal, exists: \(walExists)") + if walExists { + do { + let data = try Data(contentsOf: URL(fileURLWithPath: path + "-wal"), options: .mappedIfSafe) + self.logger.log("\(path)-wal size: \(data.count)") + } catch let e { + self.logger.log("Couldn't open database: \(e)") + } + } + #endif + + var database: Database + if let result = Database(logger: self.logger, location: self.inMemory ? ":memory:" : path) { + database = result + } else { + self.logger.log("Couldn't open DB") + + let tempPath = basePath + "_test\(arc4random())" + enum TempError: Error { + case generic + } + do { + try FileManager.default.createDirectory(atPath: tempPath, withIntermediateDirectories: true, attributes: nil) + let testDatabase = Database(logger: self.logger, location: tempPath + "/test_db")! + var resultCode = testDatabase.execute("PRAGMA journal_mode=WAL") + if !resultCode { + throw TempError.generic + } + resultCode = testDatabase.execute("PRAGMA user_version=123") + if !resultCode { + throw TempError.generic + } + } catch { + let _ = try? FileManager.default.removeItem(atPath: tempPath) + self.logger.log("Don't have write access to database folder") + preconditionFailure("Don't have write access to database folder") + } + + let _ = try? FileManager.default.removeItem(atPath: path) + preconditionFailure("Couldn't open database") + } + + var resultCode: Bool = true + + resultCode = database.execute("PRAGMA cipher_plaintext_header_size=32") + assert(resultCode) + resultCode = database.execute("PRAGMA cipher_default_plaintext_header_size=32") + assert(resultCode) + + if self.isEncrypted(database) { + if let encryptionParameters = encryptionParameters { + precondition(encryptionParameters.salt.data.count == 16) + precondition(encryptionParameters.key.data.count == 32) + + let hexKey = hexString(encryptionParameters.key.data + encryptionParameters.salt.data) + + resultCode = database.execute("PRAGMA key=\"x'\(hexKey)'\"") + assert(resultCode) + + if self.isEncrypted(database) { + self.logger.log("Encryption key is invalid") + + for fileName in dabaseFileNames { + let _ = try? FileManager.default.removeItem(atPath: basePath + "/\(fileName)") + } + database = Database(logger: self.logger, location: path)! + + resultCode = database.execute("PRAGMA cipher_plaintext_header_size=32") + assert(resultCode) + resultCode = database.execute("PRAGMA cipher_default_plaintext_header_size=32") + assert(resultCode) + + resultCode = database.execute("PRAGMA key=\"x'\(hexKey)'\"") + assert(resultCode) + } + } else { + self.logger.log("Encryption key is required") + assert(false) + for fileName in dabaseFileNames { + let _ = try? FileManager.default.removeItem(atPath: basePath + "/\(fileName)") + } + database = Database(logger: self.logger, location: path)! + + resultCode = database.execute("PRAGMA cipher_plaintext_header_size=32") + assert(resultCode) + resultCode = database.execute("PRAGMA cipher_default_plaintext_header_size=32") + assert(resultCode) + } + } else if let encryptionParameters = encryptionParameters, encryptionParameters.forceEncryptionIfNoSet { + let hexKey = hexString(encryptionParameters.key.data + encryptionParameters.salt.data) + + if FileManager.default.fileExists(atPath: path) { + self.logger.log("Reencrypting database") + database = self.reencryptInPlace(database: database, encryptionParameters: encryptionParameters) + + if self.isEncrypted(database) { + self.logger.log("Reencryption failed") + + for fileName in dabaseFileNames { + let _ = try? FileManager.default.removeItem(atPath: basePath + "/\(fileName)") + } + database = Database(logger: self.logger, location: path)! + + resultCode = database.execute("PRAGMA cipher_plaintext_header_size=32") + assert(resultCode) + resultCode = database.execute("PRAGMA cipher_default_plaintext_header_size=32") + assert(resultCode) + + resultCode = database.execute("PRAGMA key=\"x'\(hexKey)'\"") + assert(resultCode) + } + } else { + precondition(encryptionParameters.salt.data.count == 16) + precondition(encryptionParameters.key.data.count == 32) + resultCode = database.execute("PRAGMA key=\"x'\(hexKey)'\"") + assert(resultCode) + + if self.isEncrypted(database) { + self.logger.log("Encryption setup failed") + //assert(false) + + for fileName in dabaseFileNames { + let _ = try? FileManager.default.removeItem(atPath: basePath + "/\(fileName)") + } + database = Database(logger: self.logger, location: path)! + + resultCode = database.execute("PRAGMA cipher_plaintext_header_size=32") + assert(resultCode) + resultCode = database.execute("PRAGMA cipher_default_plaintext_header_size=32") + assert(resultCode) + + resultCode = database.execute("PRAGMA key=\"x'\(hexKey)'\"") + assert(resultCode) + } + } + } + + sqlite3_busy_timeout(database.handle, 1000 * 10000) + + if self.disableCache { + database.execute("PRAGMA cache_size=32") + } + + resultCode = database.execute("PRAGMA mmap_size=0") + assert(resultCode) + resultCode = database.execute("PRAGMA synchronous=NORMAL") + assert(resultCode) + resultCode = database.execute("PRAGMA temp_store=MEMORY") + assert(resultCode) + resultCode = database.execute("PRAGMA journal_mode=WAL") + assert(resultCode) + resultCode = database.execute("PRAGMA cipher_memory_security = OFF") + assert(resultCode) + //resultCode = database.execute("PRAGMA wal_autocheckpoint=500") + //database.execute("PRAGMA journal_size_limit=1536") + + /*#if DEBUG + var statement: OpaquePointer? = nil + sqlite3_prepare_v2(database.handle, "PRAGMA integrity_check", -1, &statement, nil) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + while preparedStatement.step(handle: database.handle, path: self.databasePath) { + let value = preparedStatement.valueAt(0) + let text = String(data: Data(bytes: value.memory.assumingMemoryBound(to: UInt8.self), count: value.length), encoding: .utf8) + print("integrity_check: \(text ?? "")") + assert(text == "ok") + //let value = preparedStatement.stringAt(0) + //print("integrity_check: \(value)") + } + preparedStatement.destroy() + #endif*/ + + let _ = self.runPragma(database, "checkpoint_fullfsync = 1") + assert(self.runPragma(database, "checkpoint_fullfsync") == "1") + + self.beginInternal(database: database) + + let result = self.getUserVersion(database) + + if result < 3 { + resultCode = database.execute("CREATE TABLE __meta_fulltext_tables (name INTEGER)") + assert(resultCode) + } + + if result < 4 { + resultCode = database.execute("PRAGMA user_version=4") + assert(resultCode) + } + + for table in self.listTables(database) { + self.tables[table.table.id] = table + } + for table in self.listFullTextTables(database) { + self.fullTextTables[table.id] = table + } + + self.commitInternal(database: database) + + lock.unlock() + + return database + } + + public func beginStats() { + } + + public func endStats() { + } + + public func begin() { + precondition(self.queue.isCurrent()) + let resultCode = self.database.execute("BEGIN IMMEDIATE") + assert(resultCode) + } + + public func commit() { + precondition(self.queue.isCurrent()) + let resultCode = self.database.execute("COMMIT") + assert(resultCode) + } + + public func checkpoint() { + precondition(self.queue.isCurrent()) + let resultCode = self.database.execute("PRAGMA wal_checkpoint(PASSIVE)") + assert(resultCode) + } + + private func beginInternal(database: Database) { + precondition(self.queue.isCurrent()) + let resultCode = database.execute("BEGIN IMMEDIATE") + assert(resultCode) + } + + private func commitInternal(database: Database) { + precondition(self.queue.isCurrent()) + let resultCode = database.execute("COMMIT") + assert(resultCode) + } + + private func isEncrypted(_ database: Database) -> Bool { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(database.handle, "SELECT * FROM sqlite_master LIMIT 1", -1, &statement, nil) + if status == SQLITE_NOTADB { + return true + } + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + if !preparedStatement.tryStep(handle: database.handle, path: self.databasePath) { + preparedStatement.destroy() + return true + } + preparedStatement.destroy() + return status == SQLITE_NOTADB + } + + private func getUserVersion(_ database: Database) -> Int64 { + precondition(self.queue.isCurrent()) + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(database.handle, "PRAGMA user_version", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + let _ = preparedStatement.step(handle: database.handle, path: self.databasePath) + let value = preparedStatement.int64At(0) + preparedStatement.destroy() + return value + } + + private func runPragma(_ database: Database, _ pragma: String) -> String { + precondition(self.queue.isCurrent()) + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(database.handle, "PRAGMA \(pragma)", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + var result: String? + if preparedStatement.step(handle: database.handle, path: self.databasePath) { + result = preparedStatement.stringAt(0) + } + preparedStatement.destroy() + return result ?? "" + } + + private func listTables(_ database: Database) -> [SqliteValueBoxTable] { + precondition(self.queue.isCurrent()) + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(database.handle, "SELECT name, type, sql FROM sqlite_master", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + var tables: [SqliteValueBoxTable] = [] + + while preparedStatement.step(handle: database.handle, true, path: self.databasePath) { + guard let name = preparedStatement.stringAt(0) else { + assertionFailure() + continue + } + guard let type = preparedStatement.stringAt(1), type == "table" else { + continue + } + guard let sql = preparedStatement.stringAt(2) else { + assertionFailure() + continue + } + + if name.hasPrefix("t") { + if let intName = Int(String(name[name.index(after: name.startIndex)...])) { + let keyType: ValueBoxKeyType + var hasPrimaryKey = false + if sql.range(of: "(key INTEGER") != nil { + keyType = .int64 + hasPrimaryKey = true + } else if sql.range(of: "(key BLOB") != nil { + keyType = .binary + if sql.range(of: "(key BLOB PRIMARY KEY") != nil { + hasPrimaryKey = true + } + } else { + assertionFailure() + continue + } + let isCompact = sql.range(of: "WITHOUT ROWID") != nil + tables.append(SqliteValueBoxTable(table: ValueBoxTable(id: Int32(intName), keyType: keyType, compactValuesOnCreation: isCompact), hasPrimaryKey: hasPrimaryKey)) + } + } + } + preparedStatement.destroy() + + return tables + } + + private func listFullTextTables(_ database: Database) -> [ValueBoxFullTextTable] { + precondition(self.queue.isCurrent()) + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(database.handle, "SELECT name FROM __meta_fulltext_tables", -1, &statement, nil) + assert(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + var tables: [ValueBoxFullTextTable] = [] + + while preparedStatement.step(handle: database.handle, true, path: self.databasePath) { + let value = preparedStatement.int64At(0) + tables.append(ValueBoxFullTextTable(id: Int32(value))) + } + preparedStatement.destroy() + return tables + } + + private func checkTable(_ table: ValueBoxTable) -> SqliteValueBoxTable { + precondition(self.queue.isCurrent()) + if let currentTable = self.tables[table.id] { + precondition(currentTable.table.keyType == table.keyType) + return currentTable + } else { + self.createTable(database: self.database, table: table) + let resultTable = SqliteValueBoxTable(table: table, hasPrimaryKey: true) + self.tables[table.id] = resultTable + return resultTable + } + } + + private func createTable(database: Database, table: ValueBoxTable) { + switch table.keyType { + case .binary: + var resultCode: Bool + var createStatement = "CREATE TABLE t\(table.id) (key BLOB PRIMARY KEY, value BLOB)" + if table.compactValuesOnCreation { + createStatement += " WITHOUT ROWID" + } + resultCode = database.execute(createStatement) + assert(resultCode) + case .int64: + let resultCode = database.execute("CREATE TABLE t\(table.id) (key INTEGER PRIMARY KEY, value BLOB)") + assert(resultCode) + } + } + + private func checkFullTextTable(_ table: ValueBoxFullTextTable) { + precondition(self.queue.isCurrent()) + if let _ = self.fullTextTables[table.id] { + } else { + var resultCode = self.database.execute("CREATE VIRTUAL TABLE ft\(table.id) USING fts5(collectionId, itemId, contents, tags)") + precondition(resultCode) + self.fullTextTables[table.id] = table + resultCode = self.database.execute("INSERT INTO __meta_fulltext_tables(name) VALUES (\(table.id))") + precondition(resultCode) + } + } + + private func getStatement(_ table: ValueBoxTable, key: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, key) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.getStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT value FROM t\(table.id) WHERE key=?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.getStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: key.memory, length: key.length) + case .int64: + resultStatement.bind(1, number: key.getInt64(0)) + } + + return resultStatement + } + + private func getRowIdStatement(_ table: ValueBoxTable, key: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, key) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.getRowIdStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT rowid FROM t\(table.id) WHERE key=?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.getRowIdStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: key.memory, length: key.length) + case .int64: + resultStatement.bind(1, number: key.getInt64(0)) + } + + return resultStatement + } + + private func rangeKeyAscStatementLimit(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey, limit: Int) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, start) + checkTableKey(table, end) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.rangeKeyAscStatementsLimit[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key FROM t\(table.id) WHERE key > ? AND key < ? ORDER BY key ASC LIMIT ?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.rangeKeyAscStatementsLimit[table.id] = preparedStatement + resultStatement = preparedStatement + } + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: start.memory, length: start.length) + resultStatement.bind(2, data: end.memory, length: end.length) + case .int64: + resultStatement.bind(1, number: start.getInt64(0)) + resultStatement.bind(2, number: end.getInt64(0)) + } + resultStatement.bind(3, number: Int32(limit)) + + return resultStatement + } + + private func rangeKeyAscStatementNoLimit(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey) -> + SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, start) + checkTableKey(table, end) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.rangeKeyAscStatementsNoLimit[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key FROM t\(table.id) WHERE key > ? AND key < ? ORDER BY key ASC", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.rangeKeyAscStatementsNoLimit[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: start.memory, length: start.length) + resultStatement.bind(2, data: end.memory, length: end.length) + case .int64: + resultStatement.bind(1, number: start.getInt64(0)) + resultStatement.bind(2, number: end.getInt64(0)) + } + + return resultStatement + } + + private func rangeKeyDescStatementLimit(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey, limit: Int) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + let resultStatement: SqlitePreparedStatement + checkTableKey(table, start) + checkTableKey(table, end) + + if let statement = self.rangeKeyDescStatementsLimit[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key FROM t\(table.id) WHERE key > ? AND key < ? ORDER BY key DESC LIMIT ?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.rangeKeyDescStatementsLimit[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: start.memory, length: start.length) + resultStatement.bind(2, data: end.memory, length: end.length) + case .int64: + resultStatement.bind(1, number: start.getInt64(0)) + resultStatement.bind(2, number: end.getInt64(0)) + } + resultStatement.bind(3, number: Int32(limit)) + + return resultStatement + } + + private func rangeKeyDescStatementNoLimit(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + let resultStatement: SqlitePreparedStatement + checkTableKey(table, start) + checkTableKey(table, end) + + if let statement = self.rangeKeyDescStatementsNoLimit[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key FROM t\(table.id) WHERE key > ? AND key < ? ORDER BY key DESC", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.rangeKeyDescStatementsNoLimit[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: start.memory, length: start.length) + resultStatement.bind(2, data: end.memory, length: end.length) + case .int64: + resultStatement.bind(1, number: start.getInt64(0)) + resultStatement.bind(2, number: end.getInt64(0)) + } + + return resultStatement + } + + private func rangeDeleteStatement(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + let resultStatement: SqlitePreparedStatement + checkTableKey(table, start) + checkTableKey(table, end) + precondition(start <= end) + + if let statement = self.deleteRangeStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "DELETE FROM t\(table.id) WHERE key >= ? AND key <= ?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.deleteRangeStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: start.memory, length: start.length) + resultStatement.bind(2, data: end.memory, length: end.length) + case .int64: + resultStatement.bind(1, number: start.getInt64(0)) + resultStatement.bind(2, number: end.getInt64(0)) + } + + return resultStatement + } + + private func rangeValueAscStatementLimit(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey, limit: Int) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, start) + checkTableKey(table, end) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.rangeValueAscStatementsLimit[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key, value FROM t\(table.id) WHERE key > ? AND key < ? ORDER BY key ASC LIMIT ?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.rangeValueAscStatementsLimit[table.id] = preparedStatement + resultStatement = preparedStatement + } + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: start.memory, length: start.length) + resultStatement.bind(2, data: end.memory, length: end.length) + case .int64: + resultStatement.bind(1, number: start.getInt64(0)) + resultStatement.bind(2, number: end.getInt64(0)) + } + resultStatement.bind(3, number: Int32(limit)) + + return resultStatement + } + + private func rangeValueAscStatementNoLimit(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, start) + checkTableKey(table, end) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.rangeValueAscStatementsNoLimit[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key, value FROM t\(table.id) WHERE key > ? AND key < ? ORDER BY key ASC", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.rangeValueAscStatementsNoLimit[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: start.memory, length: start.length) + resultStatement.bind(2, data: end.memory, length: end.length) + case .int64: + resultStatement.bind(1, number: start.getInt64(0)) + resultStatement.bind(2, number: end.getInt64(0)) + } + + return resultStatement + } + + private func rangeValueDescStatementLimit(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey, limit: Int) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, start) + checkTableKey(table, end) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.rangeValueDescStatementsLimit[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key, value FROM t\(table.id) WHERE key > ? AND key < ? ORDER BY key DESC LIMIT ?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.rangeValueDescStatementsLimit[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: start.memory, length: start.length) + resultStatement.bind(2, data: end.memory, length: end.length) + case .int64: + resultStatement.bind(1, number: start.getInt64(0)) + resultStatement.bind(2, number: end.getInt64(0)) + } + resultStatement.bind(3, number: Int32(limit)) + + return resultStatement + } + + private func rangeValueDescStatementNoLimit(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, start) + checkTableKey(table, end) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.rangeValueDescStatementsNoLimit[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key, value FROM t\(table.id) WHERE key > ? AND key < ? ORDER BY key DESC", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.rangeValueDescStatementsNoLimit[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: start.memory, length: start.length) + resultStatement.bind(2, data: end.memory, length: end.length) + case .int64: + resultStatement.bind(1, number: start.getInt64(0)) + resultStatement.bind(2, number: end.getInt64(0)) + } + + return resultStatement + } + + private func scanStatement(_ table: ValueBoxTable) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.scanStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key, value FROM t\(table.id) ORDER BY key ASC", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.scanStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + return resultStatement + } + + private func scanKeysStatement(_ table: ValueBoxTable) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.scanKeysStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT key FROM t\(table.id) ORDER BY key ASC", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.scanKeysStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + return resultStatement + } + + private func existsStatement(_ table: ValueBoxTable, key: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, key) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.existsStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT rowid FROM t\(table.id) WHERE key=?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.existsStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: key.memory, length: key.length) + case .int64: + resultStatement.bind(1, number: key.getInt64(0)) + } + + return resultStatement + } + + private func updateStatement(_ table: ValueBoxTable, key: ValueBoxKey, value: MemoryBuffer) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, key) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.updateStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "UPDATE t\(table.id) SET value=? WHERE key=?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.updateStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + resultStatement.bind(1, data: value.memory, length: value.length) + switch table.keyType { + case .binary: + resultStatement.bind(2, data: key.memory, length: key.length) + case .int64: + resultStatement.bind(2, number: key.getInt64(0)) + } + + return resultStatement + } + + private func insertOrReplaceStatement(_ table: SqliteValueBoxTable, key: ValueBoxKey, value: MemoryBuffer) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table.table, key) + + let resultStatement: SqlitePreparedStatement + + if table.table.keyType == .int64 || table.hasPrimaryKey { + if let statement = self.insertOrReplacePrimaryKeyStatements[table.table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "INSERT INTO t\(table.table.id) (key, value) VALUES(?, ?) ON CONFLICT(key) DO UPDATE SET value=excluded.value", -1, &statement, nil) + if status != SQLITE_OK { + let errorText = self.database.currentError() ?? "Unknown error" + preconditionFailure(errorText) + } + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.insertOrReplacePrimaryKeyStatements[table.table.id] = preparedStatement + resultStatement = preparedStatement + } + } else { + if let statement = self.insertOrReplaceIndexKeyStatements[table.table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "INSERT INTO t\(table.table.id) (key, value) VALUES(?, ?)", -1, &statement, nil) + if status != SQLITE_OK { + let errorText = self.database.currentError() ?? "Unknown error" + preconditionFailure(errorText) + } + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.insertOrReplacePrimaryKeyStatements[table.table.id] = preparedStatement + resultStatement = preparedStatement + } + } + + resultStatement.reset() + + switch table.table.keyType { + case .binary: + resultStatement.bind(1, data: key.memory, length: key.length) + case .int64: + resultStatement.bind(1, number: key.getInt64(0)) + } + if value.length == 0 { + resultStatement.bindNull(2) + } else { + resultStatement.bind(2, data: value.memory, length: value.length) + } + + return resultStatement + } + + private func deleteStatement(_ table: ValueBoxTable, key: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, key) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.deleteStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "DELETE FROM t\(table.id) WHERE key=?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.deleteStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: key.memory, length: key.length) + case .int64: + resultStatement.bind(1, number: key.getInt64(0)) + } + + return resultStatement + } + + private func moveStatement(_ table: ValueBoxTable, from previousKey: ValueBoxKey, to updatedKey: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + checkTableKey(table, previousKey) + checkTableKey(table, updatedKey) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.moveStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "UPDATE t\(table.id) SET key=? WHERE key=?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.moveStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch table.keyType { + case .binary: + resultStatement.bind(1, data: previousKey.memory, length: previousKey.length) + resultStatement.bind(2, data: updatedKey.memory, length: updatedKey.length) + case .int64: + resultStatement.bind(1, number: previousKey.getInt64(0)) + resultStatement.bind(2, number: updatedKey.getInt64(0)) + } + + return resultStatement + } + + private func copyStatement(fromTable: ValueBoxTable, fromKey: ValueBoxKey, toTable: ValueBoxTable, toKey: ValueBoxKey) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + let _ = checkTable(fromTable) + let _ = checkTable(toTable) + checkTableKey(fromTable, fromKey) + checkTableKey(toTable, toKey) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.copyStatements[TablePairKey(table1: fromTable.id, table2: toTable.id)] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "INSERT INTO t\(toTable.id) (key, value) SELECT ?, t\(fromTable.id).value FROM t\(fromTable.id) WHERE t\(fromTable.id).key=?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.copyStatements[TablePairKey(table1: fromTable.id, table2: toTable.id)] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + switch toTable.keyType { + case .binary: + resultStatement.bind(1, data: toKey.memory, length: toKey.length) + case .int64: + resultStatement.bind(1, number: toKey.getInt64(0)) + } + + switch fromTable.keyType { + case .binary: + resultStatement.bind(2, data: fromKey.memory, length: fromKey.length) + case .int64: + resultStatement.bind(2, number: fromKey.getInt64(0)) + } + + return resultStatement + } + + private func fullTextInsertStatement(_ table: ValueBoxFullTextTable, collectionId: Data, itemId: Data, contents: Data, tags: Data) -> SqlitePreparedStatement { + precondition(self.queue.isCurrent()) + + let resultStatement: SqlitePreparedStatement + + if let statement = self.fullTextInsertStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "INSERT INTO ft\(table.id) (collectionId, itemId, contents, tags) VALUES(?, ?, ?, ?)", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.fullTextInsertStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + collectionId.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(1, data: bytes, length: collectionId.count) + } + + itemId.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(2, data: bytes, length: itemId.count) + } + + contents.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(3, data: bytes, length: contents.count) + } + + tags.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(4, data: bytes, length: tags.count) + } + + return resultStatement + } + + private func fullTextDeleteStatement(_ table: ValueBoxFullTextTable, itemId: Data) -> SqlitePreparedStatement { + let resultStatement: SqlitePreparedStatement + + if let statement = self.fullTextDeleteStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "DELETE FROM ft\(table.id) WHERE itemId=?", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.fullTextDeleteStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + itemId.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(1, data: bytes, length: itemId.count) + } + + return resultStatement + } + + private func fullTextMatchGlobalStatement(_ table: ValueBoxFullTextTable, contents: Data) -> SqlitePreparedStatement { + let resultStatement: SqlitePreparedStatement + + if let statement = self.fullTextMatchGlobalStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT collectionId, itemId FROM ft\(table.id) WHERE ft\(table.id) MATCH 'contents:\"' || ? || '\"'", -1, &statement, nil) + if status != SQLITE_OK { + self.printError() + assertionFailure() + } + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.fullTextMatchGlobalStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + contents.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(1, data: bytes, length: contents.count) + } + + return resultStatement + } + + private func fullTextMatchCollectionStatement(_ table: ValueBoxFullTextTable, collectionId: Data, contents: Data) -> SqlitePreparedStatement { + let resultStatement: SqlitePreparedStatement + + if let statement = self.fullTextMatchCollectionStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT collectionId, itemId FROM ft\(table.id) WHERE ft\(table.id) MATCH 'contents:\"' || ? || '\" AND collectionId:\"' || ? || '\"'", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.fullTextMatchCollectionStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + contents.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(1, data: bytes, length: contents.count) + } + + collectionId.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(2, data: bytes, length: collectionId.count) + } + + return resultStatement + } + + private func fullTextMatchCollectionTagsStatement(_ table: ValueBoxFullTextTable, collectionId: Data, contents: Data, tags: Data) -> SqlitePreparedStatement { + let resultStatement: SqlitePreparedStatement + + if let statement = self.fullTextMatchCollectionTagsStatements[table.id] { + resultStatement = statement + } else { + var statement: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT collectionId, itemId FROM ft\(table.id) WHERE ft\(table.id) MATCH 'contents:\"' || ? || '\" AND collectionId:\"' || ? || '\" AND tags:\"' || ? || '\"'", -1, &statement, nil) + precondition(status == SQLITE_OK) + let preparedStatement = SqlitePreparedStatement(logger: self.logger, statement: statement) + self.fullTextMatchCollectionTagsStatements[table.id] = preparedStatement + resultStatement = preparedStatement + } + + resultStatement.reset() + + contents.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(1, data: bytes, length: contents.count) + } + + collectionId.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(2, data: bytes, length: collectionId.count) + } + + tags.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + resultStatement.bindText(3, data: bytes, length: tags.count) + } + + return resultStatement + } + + public func get(_ table: ValueBoxTable, key: ValueBoxKey) -> ReadBuffer? { + precondition(self.queue.isCurrent()) + if let _ = self.tables[table.id] { + let statement = self.getStatement(table, key: key) + + var buffer: ReadBuffer? + + while statement.step(handle: self.database.handle, path: self.databasePath) { + buffer = statement.valueAt(0) + break + } + + statement.reset() + + return buffer + } + + withExtendedLifetime(key, {}) + + return nil + } + + public func read(_ table: ValueBoxTable, key: ValueBoxKey, _ process: (Int, (UnsafeMutableRawPointer, Int, Int) -> Void) -> Void) { + precondition(self.queue.isCurrent()) + if let _ = self.tables[table.id] { + let statement = self.getRowIdStatement(table, key: key) + + if statement.step(handle: self.database.handle, path: self.databasePath) { + let rowId = statement.int64At(0) + var blobHandle: OpaquePointer? + sqlite3_blob_open(database.handle, "main", "t\(table.id)", "value", rowId, 0, &blobHandle) + if let blobHandle = blobHandle { + let length = sqlite3_blob_bytes(blobHandle) + process(Int(length), { buffer, offset, length in + sqlite3_blob_read(blobHandle, buffer, Int32(length), Int32(offset)) + }) + sqlite3_blob_close(blobHandle) + } + } + statement.reset() + } + } + + public func readWrite(_ table: ValueBoxTable, key: ValueBoxKey, _ process: (Int, (UnsafeMutableRawPointer, Int, Int) -> Void, (UnsafeRawPointer, Int, Int) -> Void) -> Void) { + if let _ = self.tables[table.id] { + let statement = self.getRowIdStatement(table, key: key) + + if statement.step(handle: self.database.handle, path: self.databasePath) { + let rowId = statement.int64At(0) + var blobHandle: OpaquePointer? + sqlite3_blob_open(database.handle, "main", "t\(table.id)", "value", rowId, 1, &blobHandle) + if let blobHandle = blobHandle { + let length = sqlite3_blob_bytes(blobHandle) + process(Int(length), { buffer, offset, length in + sqlite3_blob_read(blobHandle, buffer, Int32(length), Int32(offset)) + }, { buffer, offset, length in + sqlite3_blob_write(blobHandle, buffer, Int32(length), Int32(offset)) + }) + sqlite3_blob_close(blobHandle) + } + } + statement.reset() + } + } + + public func exists(_ table: ValueBoxTable, key: ValueBoxKey) -> Bool { + precondition(self.queue.isCurrent()) + if let _ = self.get(table, key: key) { + return true + } + return false + } + + public func range(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey, values: (ValueBoxKey, ReadBuffer) -> Bool, limit: Int) { + precondition(self.queue.isCurrent()) + if start == end { + return + } + + if let _ = self.tables[table.id] { + let statement: SqlitePreparedStatement + + switch table.keyType { + case .binary: + if start < end { + if limit <= 0 { + statement = self.rangeValueAscStatementNoLimit(table, start: start, end: end) + } else { + statement = self.rangeValueAscStatementLimit(table, start: start, end: end, limit: limit) + } + } else { + if limit <= 0 { + statement = self.rangeValueDescStatementNoLimit(table, start: end, end: start) + } else { + statement = self.rangeValueDescStatementLimit(table, start: end, end: start, limit: limit) + } + } + + while statement.step(handle: self.database.handle, path: self.databasePath) { + let key = statement.keyAt(0) + let value = statement.valueAt(1) + + if !values(key, value) { + break + } + } + + statement.reset() + case .int64: + if start.reversed < end.reversed { + if limit <= 0 { + statement = self.rangeValueAscStatementNoLimit(table, start: start, end: end) + } else { + statement = self.rangeValueAscStatementLimit(table, start: start, end: end, limit: limit) + } + } else { + if limit <= 0 { + statement = self.rangeValueDescStatementNoLimit(table, start: end, end: start) + } else { + statement = self.rangeValueDescStatementLimit(table, start: end, end: start, limit: limit) + } + } + + while statement.step(handle: self.database.handle, path: self.databasePath) { + let key = statement.int64KeyAt(0) + let value = statement.valueAt(1) + + if !values(key, value) { + break + } + } + + statement.reset() + } + } + + withExtendedLifetime(start, {}) + withExtendedLifetime(end, {}) + } + + public func range(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey, keys: (ValueBoxKey) -> Bool, limit: Int) { + precondition(self.queue.isCurrent()) + if let _ = self.tables[table.id] { + let statement: SqlitePreparedStatement + + switch table.keyType { + case .binary: + if start < end { + if limit <= 0 { + statement = self.rangeKeyAscStatementNoLimit(table, start: start, end: end) + } else { + statement = self.rangeKeyAscStatementLimit(table, start: start, end: end, limit: limit) + } + } else { + if limit <= 0 { + statement = self.rangeKeyDescStatementNoLimit(table, start: end, end: start) + } else { + statement = self.rangeKeyDescStatementLimit(table, start: end, end: start, limit: limit) + } + } + + while statement.step(handle: self.database.handle, path: self.databasePath) { + let key = statement.keyAt(0) + + if !keys(key) { + break + } + } + + statement.reset() + case .int64: + if start.reversed < end.reversed { + if limit <= 0 { + statement = self.rangeKeyAscStatementNoLimit(table, start: start, end: end) + } else { + statement = self.rangeKeyAscStatementLimit(table, start: start, end: end, limit: limit) + } + } else { + if limit <= 0 { + statement = self.rangeKeyDescStatementNoLimit(table, start: end, end: start) + } else { + statement = self.rangeKeyDescStatementLimit(table, start: end, end: start, limit: limit) + } + } + + while statement.step(handle: self.database.handle, path: self.databasePath) { + let key = statement.int64KeyAt(0) + + if !keys(key) { + break + } + } + + statement.reset() + } + } + + withExtendedLifetime(start, {}) + withExtendedLifetime(end, {}) + } + + public func scan(_ table: ValueBoxTable, values: (ValueBoxKey, ReadBuffer) -> Bool) { + precondition(self.queue.isCurrent()) + + if let _ = self.tables[table.id] { + let statement: SqlitePreparedStatement = self.scanStatement(table) + + while statement.step(handle: self.database.handle, path: self.databasePath) { + let key = statement.keyAt(0) + let value = statement.valueAt(1) + + if !values(key, value) { + break + } + } + + statement.reset() + } + } + + public func scan(_ table: ValueBoxTable, keys: (ValueBoxKey) -> Bool) { + precondition(self.queue.isCurrent()) + + if let _ = self.tables[table.id] { + let statement: SqlitePreparedStatement = self.scanKeysStatement(table) + + while statement.step(handle: self.database.handle, path: self.databasePath) { + let key = statement.keyAt(0) + + if !keys(key) { + break + } + } + + statement.reset() + } + } + + public func scanInt64(_ table: ValueBoxTable, values: (Int64, ReadBuffer) -> Bool) { + precondition(self.queue.isCurrent()) + + if let _ = self.tables[table.id] { + let statement: SqlitePreparedStatement = self.scanStatement(table) + + while statement.step(handle: self.database.handle, path: self.databasePath) { + let key = statement.int64KeyValueAt(0) + let value = statement.valueAt(1) + + if !values(key, value) { + break + } + } + + statement.reset() + } + } + + public func scanInt64(_ table: ValueBoxTable, keys: (Int64) -> Bool) { + precondition(self.queue.isCurrent()) + + if let _ = self.tables[table.id] { + let statement: SqlitePreparedStatement = self.scanKeysStatement(table) + + while statement.step(handle: self.database.handle, path: self.databasePath) { + let key = statement.int64KeyValueAt(0) + + if !keys(key) { + break + } + } + + statement.reset() + } + } + + public func set(_ table: ValueBoxTable, key: ValueBoxKey, value: MemoryBuffer) { + precondition(self.queue.isCurrent()) + let sqliteTable = self.checkTable(table) + + if sqliteTable.hasPrimaryKey { + let statement = self.insertOrReplaceStatement(sqliteTable, key: key, value: value) + while statement.step(handle: self.database.handle, path: self.databasePath) { + } + statement.reset() + } else { + if self.exists(table, key: key) { + let statement = self.updateStatement(table, key: key, value: value) + while statement.step(handle: self.database.handle, path: self.databasePath) { + } + statement.reset() + } else { + let statement = self.insertOrReplaceStatement(sqliteTable, key: key, value: value) + while statement.step(handle: self.database.handle, path: self.databasePath) { + } + statement.reset() + } + } + } + + public func remove(_ table: ValueBoxTable, key: ValueBoxKey, secure: Bool) { + precondition(self.queue.isCurrent()) + if let _ = self.tables[table.id] { + if secure != self.secureDeleteEnabled { + self.secureDeleteEnabled = secure + let result = database.execute("PRAGMA secure_delete=\(secure ? 1 : 0)") + precondition(result) + } + + let statement = self.deleteStatement(table, key: key) + while statement.step(handle: self.database.handle, path: self.databasePath) { + } + statement.reset() + } + } + + public func removeRange(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey) { + precondition(self.queue.isCurrent()) + if let _ = self.tables[table.id] { + let statement = self.rangeDeleteStatement(table, start: min(start, end), end: max(start, end)) + while statement.step(handle: self.database.handle, path: self.databasePath) { + } + statement.reset() + } + } + + public func move(_ table: ValueBoxTable, from previousKey: ValueBoxKey, to updatedKey: ValueBoxKey) { + precondition(self.queue.isCurrent()) + if let _ = self.tables[table.id] { + let statement = self.moveStatement(table, from: previousKey, to: updatedKey) + while statement.step(handle: self.database.handle, path: self.databasePath) { + } + statement.reset() + } + } + + public func copy(fromTable: ValueBoxTable, fromKey: ValueBoxKey, toTable: ValueBoxTable, toKey: ValueBoxKey) { + precondition(self.queue.isCurrent()) + if let _ = self.tables[fromTable.id] { + let statement = self.copyStatement(fromTable: fromTable, fromKey: fromKey, toTable: toTable, toKey: toKey) + while statement.step(handle: self.database.handle, path: self.databasePath) { + } + statement.reset() + } + } + + public func renameTable(_ table: ValueBoxTable, to toTable: ValueBoxTable) { + let sqliteTable = self.checkTable(table) + let resultCode = database.execute("ALTER TABLE t\(table.id) RENAME TO t\(toTable.id)") + precondition(resultCode) + self.tables[toTable.id] = SqliteValueBoxTable(table: ValueBoxTable(id: toTable.id, keyType: sqliteTable.table.keyType, compactValuesOnCreation: sqliteTable.table.compactValuesOnCreation), hasPrimaryKey: sqliteTable.hasPrimaryKey) + self.tables.removeValue(forKey: table.id) + } + + public func fullTextMatch(_ table: ValueBoxFullTextTable, collectionId: String?, query: String, tags: String?, values: (String, String) -> Bool) { + if let _ = self.fullTextTables[table.id] { + guard let queryData = query.data(using: .utf8) else { + return + } + + var statement: SqlitePreparedStatement? + if let collectionId = collectionId { + if let collectionIdData = collectionId.data(using: .utf8) { + if let tags = tags { + if let tagsData = tags.data(using: .utf8) { + statement = self.fullTextMatchCollectionTagsStatement(table, collectionId: collectionIdData, contents: queryData, tags: tagsData) + } + } else { + statement = self.fullTextMatchCollectionStatement(table, collectionId: collectionIdData, contents: queryData) + } + } + } else { + statement = self.fullTextMatchGlobalStatement(table, contents: queryData) + } + + if let statement = statement { + while statement.step(handle: self.database.handle, path: self.databasePath) { + let resultCollectionId = statement.stringAt(0) + let resultItemId = statement.stringAt(1) + + if let resultCollectionId = resultCollectionId, let resultItemId = resultItemId { + if !values(resultCollectionId, resultItemId) { + break + } + } else { + assertionFailure() + } + } + + statement.reset() + } + } + } + + public func fullTextSet(_ table: ValueBoxFullTextTable, collectionId: String, itemId: String, contents: String, tags: String) { + self.checkFullTextTable(table) + + guard let collectionIdData = collectionId.data(using: .utf8), let itemIdData = itemId.data(using: .utf8), let contentsData = contents.data(using: .utf8), let tagsData = tags.data(using: .utf8) else { + return + } + + let statement = self.fullTextInsertStatement(table, collectionId: collectionIdData, itemId: itemIdData, contents: contentsData, tags: tagsData) + while statement.step(handle: self.database.handle, path: self.databasePath) { + } + statement.reset() + } + + public func fullTextRemove(_ table: ValueBoxFullTextTable, itemId: String) { + if let _ = self.fullTextTables[table.id] { + guard let itemIdData = itemId.data(using: .utf8) else { + return + } + + let statement = self.fullTextDeleteStatement(table, itemId: itemIdData) + while statement.step(handle: self.database.handle, path: self.databasePath) { + } + statement.reset() + } + } + + public func count(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey) -> Int { + let _ = self.checkTable(table) + + var statementImpl: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT COUNT(*) FROM t\(table.id) WHERE key > ? AND key < ?", -1, &statementImpl, nil) + precondition(status == SQLITE_OK) + let statement = SqlitePreparedStatement(logger: self.logger, statement: statementImpl) + switch table.keyType { + case .binary: + statement.bind(1, data: start.memory, length: start.length) + case .int64: + statement.bind(1, number: start.getInt64(0)) + } + switch table.keyType { + case .binary: + statement.bind(2, data: end.memory, length: end.length) + case .int64: + statement.bind(2, number: end.getInt64(0)) + } + + var result = 0 + while statement.step(handle: database.handle, true, path: self.databasePath) { + let value = statement.int32At(0) + result = Int(value) + } + statement.reset() + statement.destroy() + return result + } + + public func count(_ table: ValueBoxTable) -> Int { + let _ = self.checkTable(table) + + var statementImpl: OpaquePointer? = nil + let status = sqlite3_prepare_v2(self.database.handle, "SELECT COUNT(*) FROM t\(table.id)", -1, &statementImpl, nil) + precondition(status == SQLITE_OK) + let statement = SqlitePreparedStatement(logger: self.logger, statement: statementImpl) + + var result = 0 + while statement.step(handle: database.handle, true, path: self.databasePath) { + let value = statement.int32At(0) + result = Int(value) + } + statement.reset() + statement.destroy() + return result + } + + private func clearStatements() { + precondition(self.queue.isCurrent()) + for (_, statement) in self.getStatements { + statement.destroy() + } + self.getStatements.removeAll() + + for (_, statement) in self.getRowIdStatements { + statement.destroy() + } + self.getRowIdStatements.removeAll() + + for (_, statement) in self.rangeKeyAscStatementsLimit { + statement.destroy() + } + self.rangeKeyAscStatementsLimit.removeAll() + + for (_, statement) in self.rangeKeyAscStatementsNoLimit { + statement.destroy() + } + self.rangeKeyAscStatementsNoLimit.removeAll() + + for (_, statement) in self.rangeKeyDescStatementsLimit { + statement.destroy() + } + self.rangeKeyDescStatementsLimit.removeAll() + + for (_, statement) in self.rangeKeyDescStatementsNoLimit { + statement.destroy() + } + self.rangeKeyDescStatementsNoLimit.removeAll() + + for (_, statement) in self.deleteRangeStatements { + statement.destroy() + } + self.deleteRangeStatements.removeAll() + + for (_, statement) in self.rangeValueAscStatementsLimit { + statement.destroy() + } + self.rangeValueAscStatementsLimit.removeAll() + + for (_, statement) in self.rangeValueAscStatementsNoLimit { + statement.destroy() + } + self.rangeValueAscStatementsNoLimit.removeAll() + + for (_, statement) in self.rangeValueDescStatementsLimit { + statement.destroy() + } + self.rangeValueDescStatementsLimit.removeAll() + + for (_, statement) in self.rangeValueDescStatementsNoLimit { + statement.destroy() + } + self.rangeValueDescStatementsNoLimit.removeAll() + + for (_, statement) in self.scanStatements { + statement.destroy() + } + self.scanStatements.removeAll() + + for (_, statement) in self.scanKeysStatements { + statement.destroy() + } + self.scanKeysStatements.removeAll() + + for (_, statement) in self.existsStatements { + statement.destroy() + } + self.existsStatements.removeAll() + + for (_, statement) in self.updateStatements { + statement.destroy() + } + self.updateStatements.removeAll() + + for (_, statement) in self.insertOrReplaceIndexKeyStatements { + statement.destroy() + } + self.insertOrReplaceIndexKeyStatements.removeAll() + + for (_, statement) in self.insertOrReplacePrimaryKeyStatements { + statement.destroy() + } + self.insertOrReplacePrimaryKeyStatements.removeAll() + + for (_, statement) in self.deleteStatements { + statement.destroy() + } + self.deleteStatements.removeAll() + + for (_, statement) in self.moveStatements { + statement.destroy() + } + self.moveStatements.removeAll() + + for (_, statement) in self.copyStatements { + statement.destroy() + } + self.copyStatements.removeAll() + + for (_, statement) in self.fullTextInsertStatements { + statement.destroy() + } + self.fullTextInsertStatements.removeAll() + + for (_, statement) in self.fullTextDeleteStatements { + statement.destroy() + } + self.fullTextDeleteStatements.removeAll() + + for (_, statement) in self.fullTextMatchGlobalStatements { + statement.destroy() + } + self.fullTextMatchGlobalStatements.removeAll() + + for (_, statement) in self.fullTextMatchCollectionStatements { + statement.destroy() + } + self.fullTextMatchCollectionStatements.removeAll() + + for (_, statement) in self.fullTextMatchCollectionTagsStatements { + statement.destroy() + } + self.fullTextMatchCollectionTagsStatements.removeAll() + } + + public func removeAllFromTable(_ table: ValueBoxTable) { + let _ = self.database.execute("DELETE FROM t\(table.id)") + } + + public func removeTable(_ table: ValueBoxTable) { + let _ = self.database.execute("DROP TABLE t\(table.id)") + self.tables.removeValue(forKey: table.id) + } + + public func drop() { + precondition(self.queue.isCurrent()) + self.clearStatements() + + self.lock.lock() + self.database = nil + self.lock.unlock() + + self.logger.log("dropping DB") + + for fileName in dabaseFileNames { + let _ = try? FileManager.default.removeItem(atPath: self.basePath + "/\(fileName)") + } + + self.database = self.openDatabase(encryptionParameters: self.encryptionParameters, upgradeProgress: { _ in }) + + tables.removeAll() + } + + private func printError() { + if let error = sqlite3_errmsg(self.database.handle), let str = NSString(utf8String: error) { + print("SQL error \(str)") + } + } + + public func exportEncrypted(to exportBasePath: String, encryptionParameters: ValueBoxEncryptionParameters) { + self.exportEncrypted(database: self.database, to: exportBasePath, encryptionParameters: encryptionParameters) + } + + private func exportEncrypted(database: Database, to exportBasePath: String, encryptionParameters: ValueBoxEncryptionParameters) { + let _ = try? FileManager.default.createDirectory(atPath: exportBasePath, withIntermediateDirectories: true, attributes: nil) + let exportFilePath = "\(exportBasePath)/db_sqlite" + + let hexKey = hexString(encryptionParameters.key.data + encryptionParameters.salt.data) + + precondition(encryptionParameters.salt.data.count == 16) + precondition(encryptionParameters.key.data.count == 32) + + var resultCode = database.execute("ATTACH DATABASE '\(exportFilePath)' AS encrypted KEY \"x'\(hexKey)'\"") + assert(resultCode) + resultCode = database.execute("SELECT sqlcipher_export('encrypted')") + assert(resultCode) + let userVersion = self.getUserVersion(database) + resultCode = database.execute("PRAGMA encrypted.user_version=\(userVersion)") + resultCode = database.execute("DETACH DATABASE encrypted") + assert(resultCode) + } + + private func reencryptInPlace(database: Database, encryptionParameters: ValueBoxEncryptionParameters) -> Database { + let targetPath = self.basePath + "/db_export" + let _ = try? FileManager.default.removeItem(atPath: targetPath) + + self.exportEncrypted(database: database, to: targetPath, encryptionParameters: encryptionParameters) + + for name in dabaseFileNames { + let _ = try? FileManager.default.removeItem(atPath: self.basePath + "/\(name)") + let _ = try? FileManager.default.moveItem(atPath: targetPath + "/\(name)", toPath: self.basePath + "/\(name)") + } + let _ = try? FileManager.default.removeItem(atPath: targetPath) + + let updatedDatabase = Database(logger: self.logger, location: self.databasePath)! + + var resultCode = updatedDatabase.execute("PRAGMA cipher_plaintext_header_size=32") + assert(resultCode) + resultCode = updatedDatabase.execute("PRAGMA cipher_default_plaintext_header_size=32") + assert(resultCode) + + let hexKey = hexString(encryptionParameters.key.data + encryptionParameters.salt.data) + + resultCode = updatedDatabase.execute("PRAGMA key=\"x'\(hexKey)'\"") + assert(resultCode) + + return updatedDatabase + } + + public func vacuum() { + var resultCode = self.database.execute("VACUUM") + precondition(resultCode) + resultCode = self.database.execute("PRAGMA wal_checkpoint(TRUNCATE)") + precondition(resultCode) + } +} + +private func hexString(_ data: Data) -> String { + let hexString = NSMutableString() + data.withUnsafeBytes { (bytes: UnsafePointer) -> Void in + for i in 0 ..< data.count { + hexString.appendFormat("%02x", UInt(bytes.advanced(by: i).pointee)) + } + } + + return hexString as String +} diff --git a/submodules/Database/ValueBox/Sources/ValueBox.swift b/submodules/Database/ValueBox/Sources/ValueBox.swift new file mode 100644 index 0000000000..18cc7faeaa --- /dev/null +++ b/submodules/Database/ValueBox/Sources/ValueBox.swift @@ -0,0 +1,95 @@ +import Foundation +import Buffers + +public enum ValueBoxKeyType: Int32 { + case binary + case int64 +} + +public struct ValueBoxTable { + let id: Int32 + let keyType: ValueBoxKeyType + let compactValuesOnCreation: Bool + + public init(id: Int32, keyType: ValueBoxKeyType, compactValuesOnCreation: Bool) { + self.id = id + self.keyType = keyType + self.compactValuesOnCreation = compactValuesOnCreation + } +} + +public struct ValueBoxFullTextTable { + let id: Int32 +} + +public struct ValueBoxEncryptionParameters { + public struct Key { + public let data: Data + + public init?(data: Data) { + if data.count == 32 { + self.data = data + } else { + return nil + } + } + } + + public struct Salt { + public let data: Data + + public init?(data: Data) { + if data.count == 16 { + self.data = data + } else { + return nil + } + } + } + + public let forceEncryptionIfNoSet: Bool + public let key: Key + public let salt: Salt + + public init(forceEncryptionIfNoSet: Bool, key: Key, salt: Salt) { + self.forceEncryptionIfNoSet = forceEncryptionIfNoSet + self.key = key + self.salt = salt + } +} + +public protocol ValueBox { + func begin() + func commit() + func checkpoint() + + func beginStats() + func endStats() + + func range(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey, values: (ValueBoxKey, ReadBuffer) -> Bool, limit: Int) + func range(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey, keys: (ValueBoxKey) -> Bool, limit: Int) + func scan(_ table: ValueBoxTable, values: (ValueBoxKey, ReadBuffer) -> Bool) + func scan(_ table: ValueBoxTable, keys: (ValueBoxKey) -> Bool) + func scanInt64(_ table: ValueBoxTable, values: (Int64, ReadBuffer) -> Bool) + func scanInt64(_ table: ValueBoxTable, keys: (Int64) -> Bool) + func get(_ table: ValueBoxTable, key: ValueBoxKey) -> ReadBuffer? + func read(_ table: ValueBoxTable, key: ValueBoxKey, _ process: (Int, (UnsafeMutableRawPointer, Int, Int) -> Void) -> Void) + func readWrite(_ table: ValueBoxTable, key: ValueBoxKey, _ process: (Int, (UnsafeMutableRawPointer, Int, Int) -> Void, (UnsafeRawPointer, Int, Int) -> Void) -> Void) + func exists(_ table: ValueBoxTable, key: ValueBoxKey) -> Bool + func set(_ table: ValueBoxTable, key: ValueBoxKey, value: MemoryBuffer) + func remove(_ table: ValueBoxTable, key: ValueBoxKey, secure: Bool) + func move(_ table: ValueBoxTable, from previousKey: ValueBoxKey, to updatedKey: ValueBoxKey) + func copy(fromTable: ValueBoxTable, fromKey: ValueBoxKey, toTable: ValueBoxTable, toKey: ValueBoxKey) + func removeRange(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey) + func fullTextSet(_ table: ValueBoxFullTextTable, collectionId: String, itemId: String, contents: String, tags: String) + func fullTextMatch(_ table: ValueBoxFullTextTable, collectionId: String?, query: String, tags: String?, values: (String, String) -> Bool) + func fullTextRemove(_ table: ValueBoxFullTextTable, itemId: String) + func removeAllFromTable(_ table: ValueBoxTable) + func removeTable(_ table: ValueBoxTable) + func renameTable(_ table: ValueBoxTable, to toTable: ValueBoxTable) + func drop() + func count(_ table: ValueBoxTable, start: ValueBoxKey, end: ValueBoxKey) -> Int + func count(_ table: ValueBoxTable) -> Int + + func exportEncrypted(to exportBasePath: String, encryptionParameters: ValueBoxEncryptionParameters) +} diff --git a/submodules/Database/ValueBox/Sources/ValueBoxKey.swift b/submodules/Database/ValueBox/Sources/ValueBoxKey.swift new file mode 100644 index 0000000000..bfe36e4062 --- /dev/null +++ b/submodules/Database/ValueBox/Sources/ValueBoxKey.swift @@ -0,0 +1,252 @@ +import Foundation +import Buffers + +private final class ValueBoxKeyImpl { + let memory: UnsafeMutableRawPointer + + init(memory: UnsafeMutableRawPointer) { + self.memory = memory + } + + deinit { + free(self.memory) + } +} + +public struct ValueBoxKey: Equatable, Hashable, CustomStringConvertible, Comparable { + public let memory: UnsafeMutableRawPointer + public let length: Int + private let impl: ValueBoxKeyImpl + + public init(length: Int) { + self.memory = malloc(length)! + self.length = length + self.impl = ValueBoxKeyImpl(memory: self.memory) + } + + public init(_ value: String) { + let data = value.data(using: .utf8, allowLossyConversion: true) ?? Data() + self.memory = malloc(data.count) + self.length = data.count + self.impl = ValueBoxKeyImpl(memory: self.memory) + data.copyBytes(to: self.memory.assumingMemoryBound(to: UInt8.self), count: data.count) + } + + public init(_ buffer: MemoryBuffer) { + self.memory = malloc(buffer.length) + self.length = buffer.length + self.impl = ValueBoxKeyImpl(memory: self.memory) + memcpy(self.memory, buffer.memory, buffer.length) + } + + public func setInt32(_ offset: Int, value: Int32) { + var bigEndianValue = Int32(bigEndian: value) + memcpy(self.memory + offset, &bigEndianValue, 4) + } + + public func setUInt32(_ offset: Int, value: UInt32) { + var bigEndianValue = UInt32(bigEndian: value) + memcpy(self.memory + offset, &bigEndianValue, 4) + } + + public func setInt64(_ offset: Int, value: Int64) { + var bigEndianValue = Int64(bigEndian: value) + memcpy(self.memory + offset, &bigEndianValue, 8) + } + + public func setInt8(_ offset: Int, value: Int8) { + var varValue = value + memcpy(self.memory + offset, &varValue, 1) + } + + public func setUInt8(_ offset: Int, value: UInt8) { + var varValue = value + memcpy(self.memory + offset, &varValue, 1) + } + + public func setUInt16(_ offset: Int, value: UInt16) { + var varValue = value + memcpy(self.memory + offset, &varValue, 2) + } + + public func getInt32(_ offset: Int) -> Int32 { + var value: Int32 = 0 + memcpy(&value, self.memory + offset, 4) + return Int32(bigEndian: value) + } + + public func getUInt32(_ offset: Int) -> UInt32 { + var value: UInt32 = 0 + memcpy(&value, self.memory + offset, 4) + return UInt32(bigEndian: value) + } + + public func getInt64(_ offset: Int) -> Int64 { + var value: Int64 = 0 + memcpy(&value, self.memory + offset, 8) + return Int64(bigEndian: value) + } + + public func getInt8(_ offset: Int) -> Int8 { + var value: Int8 = 0 + memcpy(&value, self.memory + offset, 1) + return value + } + + public func getUInt8(_ offset: Int) -> UInt8 { + var value: UInt8 = 0 + memcpy(&value, self.memory + offset, 1) + return value + } + + public func getUInt16(_ offset: Int) -> UInt16 { + var value: UInt16 = 0 + memcpy(&value, self.memory + offset, 2) + return value + } + + public func prefix(_ length: Int) -> ValueBoxKey { + assert(length <= self.length, "length <= self.length") + let key = ValueBoxKey(length: length) + memcpy(key.memory, self.memory, length) + return key + } + + public func isPrefix(to other: ValueBoxKey) -> Bool { + if self.length == 0 { + return true + } else if self.length <= other.length { + return memcmp(self.memory, other.memory, self.length) == 0 + } else { + return false + } + } + + public var reversed: ValueBoxKey { + let key = ValueBoxKey(length: self.length) + let keyMemory = key.memory.assumingMemoryBound(to: UInt8.self) + let selfMemory = self.memory.assumingMemoryBound(to: UInt8.self) + var i = self.length - 1 + while i >= 0 { + keyMemory[i] = selfMemory[self.length - 1 - i] + i -= 1 + } + return key + } + + public var successor: ValueBoxKey { + let key = ValueBoxKey(length: self.length) + memcpy(key.memory, self.memory, self.length) + let memory = key.memory.assumingMemoryBound(to: UInt8.self) + var i = self.length - 1 + while i >= 0 { + var byte = memory[i] + if byte != 0xff { + byte += 1 + memory[i] = byte + break + } else { + byte = 0 + memory[i] = byte + } + i -= 1 + } + return key + } + + public var predecessor: ValueBoxKey { + let key = ValueBoxKey(length: self.length) + memcpy(key.memory, self.memory, self.length) + let memory = key.memory.assumingMemoryBound(to: UInt8.self) + var i = self.length - 1 + while i >= 0 { + var byte = memory[i] + if byte != 0x00 { + byte -= 1 + memory[i] = byte + break + } else { + if i == 0 { + assert(self.length > 1) + let previousKey = ValueBoxKey(length: self.length - 1) + memcpy(previousKey.memory, self.memory, self.length - 1) + return previousKey + } else { + byte = 0xff + memory[i] = byte + } + } + i -= 1 + } + return key + } + + public var description: String { + let string = NSMutableString() + let memory = self.memory.assumingMemoryBound(to: UInt8.self) + for i in 0 ..< self.length { + let byte: Int = Int(memory[i]) + string.appendFormat("%02x", byte) + } + return string as String + } + + public var stringValue: String { + if let string = String(data: Data(bytes: self.memory, count: self.length), encoding: .utf8) { + return string + } else { + return "" + } + } + + public func substringValue(_ range: Range) -> String? { + return String(data: Data(bytes: self.memory.advanced(by: range.lowerBound), count: range.count), encoding: .utf8) + } + + public var hashValue: Int { + var hash = 37 + let bytes = self.memory.assumingMemoryBound(to: Int8.self) + for i in 0 ..< self.length { + hash = (hash &* 54059) ^ (Int(bytes[i]) &* 76963) + } + return hash + } + + public static func ==(lhs: ValueBoxKey, rhs: ValueBoxKey) -> Bool { + return lhs.length == rhs.length && memcmp(lhs.memory, rhs.memory, lhs.length) == 0 + } + + public static func <(lhs: ValueBoxKey, rhs: ValueBoxKey) -> Bool { + return mdb_cmp_memn(lhs.memory, lhs.length, rhs.memory, rhs.length) < 0 + } + + public func toMemoryBuffer() -> MemoryBuffer { + let data = malloc(self.length)! + memcpy(data, self.memory, self.length) + return MemoryBuffer(memory: data, capacity: self.length, length: self.length, freeWhenDone: true) + } + + public static func +(lhs: ValueBoxKey, rhs: ValueBoxKey) -> ValueBoxKey { + let result = ValueBoxKey(length: lhs.length + rhs.length) + memcpy(result.memory, lhs.memory, lhs.length) + memcpy(result.memory.advanced(by: lhs.length), rhs.memory, rhs.length) + return result + } +} + +private func mdb_cmp_memn(_ a_memory: UnsafeMutableRawPointer, _ a_length: Int, _ b_memory: UnsafeMutableRawPointer, _ b_length: Int) -> Int +{ + var diff: Int = 0 + var len_diff: Int = 0 + var len: Int = 0 + + len = a_length + len_diff = a_length - b_length + if len_diff > 0 { + len = b_length + len_diff = 1 + } + + diff = Int(memcmp(a_memory, b_memory, len)) + return diff != 0 ? diff : len_diff < 0 ? -1 : len_diff +} diff --git a/submodules/Database/ValueBox/Sources/ValueBoxLogger.swift b/submodules/Database/ValueBox/Sources/ValueBoxLogger.swift new file mode 100644 index 0000000000..5e2f3e362f --- /dev/null +++ b/submodules/Database/ValueBox/Sources/ValueBoxLogger.swift @@ -0,0 +1,5 @@ +import Foundation + +public protocol ValueBoxLogger { + func log(_ what: String) +} diff --git a/submodules/Postbox/Postbox/ChatListIndexTable.swift b/submodules/Postbox/Postbox/ChatListIndexTable.swift index 2d0c86e700..9e1a41e51e 100644 --- a/submodules/Postbox/Postbox/ChatListIndexTable.swift +++ b/submodules/Postbox/Postbox/ChatListIndexTable.swift @@ -644,14 +644,10 @@ final class ChatListIndexTable: Table { if peerId.namespace == Int32.max { return } - /*guard let peer = postbox.peerTable.get(peerId) else { - return - }*/ guard let combinedState = postbox.readStateTable.getCombinedState(peerId) else { return } - /*let notificationPeerId: PeerId = peer.associatedPeerId ?? peerId - let notificationSettings = postbox.peerNotificationSettingsTable.getEffective(notificationPeerId)*/ + let inclusion = self.get(peerId: peerId) if let (inclusionGroupId, _) = inclusion.includedIndex(peerId: peerId), inclusionGroupId == groupId { for (namespace, state) in combinedState.states { diff --git a/submodules/Postbox/Postbox/GroupMessageStatsTable.swift b/submodules/Postbox/Postbox/GroupMessageStatsTable.swift index fb478000d7..feccfe59ba 100644 --- a/submodules/Postbox/Postbox/GroupMessageStatsTable.swift +++ b/submodules/Postbox/Postbox/GroupMessageStatsTable.swift @@ -77,6 +77,16 @@ public struct PeerGroupUnreadCountersCombinedSummary: PostboxCoding, Equatable { } } +public enum ChatListTotalUnreadStateCategory: Int32 { + case filtered = 0 + case raw = 1 +} + +public enum ChatListTotalUnreadStateStats: Int32 { + case messages = 0 + case chats = 1 +} + public struct ChatListTotalUnreadState: PostboxCoding, Equatable { public var absoluteCounters: [PeerSummaryCounterTags: ChatListTotalUnreadCounters] public var filteredCounters: [PeerSummaryCounterTags: ChatListTotalUnreadCounters] diff --git a/submodules/Postbox/Postbox/MessageHistoryMetadataTable.swift b/submodules/Postbox/Postbox/MessageHistoryMetadataTable.swift index c464fa7fb0..5fe138ec64 100644 --- a/submodules/Postbox/Postbox/MessageHistoryMetadataTable.swift +++ b/submodules/Postbox/Postbox/MessageHistoryMetadataTable.swift @@ -32,16 +32,6 @@ public struct ChatListTotalUnreadCounters: PostboxCoding, Equatable { } } -public enum ChatListTotalUnreadStateCategory: Int32 { - case filtered = 0 - case raw = 1 -} - -public enum ChatListTotalUnreadStateStats: Int32 { - case messages = 0 - case chats = 1 -} - private struct InitializedChatListKey: Hashable { let groupId: PeerGroupId } diff --git a/submodules/Postbox/Postbox/Table.swift b/submodules/Postbox/Postbox/Table.swift index 9f51743dab..a94415c3b5 100644 --- a/submodules/Postbox/Postbox/Table.swift +++ b/submodules/Postbox/Postbox/Table.swift @@ -1,17 +1,17 @@ import Foundation -class Table { - final let valueBox: ValueBox - final let table: ValueBoxTable +open class Table { + public final let valueBox: ValueBox + public final let table: ValueBoxTable - init(valueBox: ValueBox, table: ValueBoxTable) { + public init(valueBox: ValueBox, table: ValueBoxTable) { self.valueBox = valueBox self.table = table } - func clearMemoryCache() { + open func clearMemoryCache() { } - func beforeCommit() { + open func beforeCommit() { } }