Various improvements

This commit is contained in:
Isaac
2024-12-06 22:15:52 +08:00
parent 1abaeddfad
commit 4e964d4546
12 changed files with 570 additions and 1995 deletions

View File

@@ -77,18 +77,18 @@ private final class FetchImpl {
let partRange: Range<Int64>
let fetchRange: Range<Int64>
let fetchedData: Data
let decryptedData: Data
let cleanData: Data
init(
partRange: Range<Int64>,
fetchRange: Range<Int64>,
fetchedData: Data,
decryptedData: Data
cleanData: Data
) {
self.partRange = partRange
self.fetchRange = fetchRange
self.fetchedData = fetchedData
self.decryptedData = decryptedData
self.cleanData = cleanData
}
}
@@ -148,6 +148,48 @@ private final class FetchImpl {
case cdn(CdnData)
}
private final class DecryptionState {
let aesKey: Data
var aesIv: Data
let decryptedSize: Int64
var offset: Int = 0
init(aesKey: Data, aesIv: Data, decryptedSize: Int64) {
self.aesKey = aesKey
self.aesIv = aesIv
self.decryptedSize = decryptedSize
}
func tryDecrypt(data: Data, offset: Int, loggingIdentifier: String) -> Data? {
if offset == self.offset {
var decryptedData = data
if self.decryptedSize == 0 {
Logger.shared.log("FetchV2", "\(loggingIdentifier): not decrypting part \(offset) ..< \(offset + data.count) (decryptedSize == 0)")
return nil
}
if decryptedData.count % 16 != 0 {
Logger.shared.log("FetchV2", "\(loggingIdentifier): not decrypting part \(offset) ..< \(offset + data.count) (decryptedData.count % 16 != 0)")
}
let decryptedDataCount = decryptedData.count
decryptedData.withUnsafeMutableBytes { rawBytes -> Void in
let bytes = rawBytes.baseAddress!.assumingMemoryBound(to: UInt8.self)
self.aesIv.withUnsafeMutableBytes { rawIv -> Void in
let iv = rawIv.baseAddress!.assumingMemoryBound(to: UInt8.self)
MTAesDecryptBytesInplaceAndModifyIv(bytes, decryptedDataCount, self.aesKey, iv)
}
}
if self.offset + decryptedData.count > self.decryptedSize {
decryptedData.count = Int(self.decryptedSize) - self.offset
}
self.offset += decryptedData.count
Logger.shared.log("FetchV2", "\(loggingIdentifier): decrypted part \(offset) ..< \(offset + data.count) (new offset is \(self.offset))")
return decryptedData
} else {
return nil
}
}
}
private final class FetchingState {
let fetchLocation: FetchLocation
let partSize: Int64
@@ -160,6 +202,7 @@ private final class FetchImpl {
var pendingParts: [PendingPart] = []
var completedRanges = RangeSet<Int64>()
var decryptionState: DecryptionState?
var pendingReadyParts: [PendingReadyPart] = []
var completedHashRanges = RangeSet<Int64>()
var pendingHashRanges: [PendingHashRange] = []
@@ -174,7 +217,8 @@ private final class FetchImpl {
maxPartSize: Int64,
partAlignment: Int64,
partDivision: Int64,
maxPendingParts: Int
maxPendingParts: Int,
decryptionState: DecryptionState?
) {
self.fetchLocation = fetchLocation
self.partSize = partSize
@@ -183,6 +227,7 @@ private final class FetchImpl {
self.partAlignment = partAlignment
self.partDivision = partDivision
self.maxPendingParts = maxPendingParts
self.decryptionState = decryptionState
}
deinit {
@@ -373,6 +418,12 @@ private final class FetchImpl {
if self.state == nil {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): initializing to .datacenter(\(self.datacenterId))")
var decryptionState: DecryptionState?
if let encryptionKey = self.encryptionKey, let decryptedSize = self.decryptedSize {
decryptionState = DecryptionState(aesKey: encryptionKey.aesKey, aesIv: encryptionKey.aesIv, decryptedSize: decryptedSize)
self.onNext(.reset)
}
self.state = .fetching(FetchingState(
fetchLocation: .datacenter(self.datacenterId),
partSize: self.defaultPartSize,
@@ -380,7 +431,8 @@ private final class FetchImpl {
maxPartSize: 1 * 1024 * 1024,
partAlignment: 4 * 1024,
partDivision: 1 * 1024 * 1024,
maxPendingParts: 6
maxPendingParts: 6,
decryptionState: decryptionState
))
}
guard let state = self.state else {
@@ -396,55 +448,75 @@ private final class FetchImpl {
do {
var removedPendingReadyPartIndices: [Int] = []
for i in 0 ..< state.pendingReadyParts.count {
let pendingReadyPart = state.pendingReadyParts[i]
if state.completedHashRanges.isSuperset(of: RangeSet<Int64>(pendingReadyPart.fetchRange)) {
removedPendingReadyPartIndices.append(i)
var checkOffset: Int64 = 0
var checkFailed = false
while checkOffset < pendingReadyPart.fetchedData.count {
if let hashRange = state.hashRanges[pendingReadyPart.fetchRange.lowerBound + checkOffset] {
var clippedHashRange = hashRange.range
if pendingReadyPart.fetchRange.lowerBound + Int64(pendingReadyPart.fetchedData.count) < clippedHashRange.lowerBound {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): unable to check \(pendingReadyPart.fetchRange): data range \(clippedHashRange) out of bounds (0 ..< \(pendingReadyPart.fetchedData.count))")
checkFailed = true
break
}
clippedHashRange = clippedHashRange.lowerBound ..< min(clippedHashRange.upperBound, pendingReadyPart.fetchRange.lowerBound + Int64(pendingReadyPart.fetchedData.count))
let partLocalHashRange = (clippedHashRange.lowerBound - pendingReadyPart.fetchRange.lowerBound) ..< (clippedHashRange.upperBound - pendingReadyPart.fetchRange.lowerBound)
if partLocalHashRange.lowerBound < 0 || partLocalHashRange.upperBound > pendingReadyPart.fetchedData.count {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): unable to check \(pendingReadyPart.fetchRange): data range \(partLocalHashRange) out of bounds (0 ..< \(pendingReadyPart.fetchedData.count))")
checkFailed = true
break
}
let dataToHash = pendingReadyPart.decryptedData.subdata(in: Int(partLocalHashRange.lowerBound) ..< Int(partLocalHashRange.upperBound))
let localHash = MTSha256(dataToHash)
if localHash != hashRange.data {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): failed to verify \(pendingReadyPart.fetchRange): hash mismatch")
checkFailed = true
break
}
checkOffset += partLocalHashRange.upperBound - partLocalHashRange.lowerBound
} else {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): unable to find \(pendingReadyPart.fetchRange) hash range despite it being marked as ready")
checkFailed = true
break
if let decryptionState = state.decryptionState {
while true {
var removedSomePendingReadyPart = false
for i in 0 ..< state.pendingReadyParts.count {
if removedPendingReadyPartIndices.contains(i) {
continue
}
let pendingReadyPart = state.pendingReadyParts[i]
if let resultData = decryptionState.tryDecrypt(data: pendingReadyPart.cleanData, offset: Int(pendingReadyPart.fetchRange.lowerBound), loggingIdentifier: self.loggingIdentifier) {
removedPendingReadyPartIndices.append(i)
removedSomePendingReadyPart = true
self.commitPendingReadyPart(state: state, partRange: pendingReadyPart.partRange, fetchRange: pendingReadyPart.fetchRange, data: resultData)
}
}
if !checkFailed {
self.commitPendingReadyPart(state: state, partRange: pendingReadyPart.partRange, fetchRange: pendingReadyPart.fetchRange, data: pendingReadyPart.decryptedData)
} else {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): unable to find \(pendingReadyPart.fetchRange) hash check failed")
if !removedSomePendingReadyPart {
break
}
}
} else {
for i in 0 ..< state.pendingReadyParts.count {
let pendingReadyPart = state.pendingReadyParts[i]
if state.completedHashRanges.isSuperset(of: RangeSet<Int64>(pendingReadyPart.fetchRange)) {
removedPendingReadyPartIndices.append(i)
var checkOffset: Int64 = 0
var checkFailed = false
while checkOffset < pendingReadyPart.fetchedData.count {
if let hashRange = state.hashRanges[pendingReadyPart.fetchRange.lowerBound + checkOffset] {
var clippedHashRange = hashRange.range
if pendingReadyPart.fetchRange.lowerBound + Int64(pendingReadyPart.fetchedData.count) < clippedHashRange.lowerBound {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): unable to check \(pendingReadyPart.fetchRange): data range \(clippedHashRange) out of bounds (0 ..< \(pendingReadyPart.fetchedData.count))")
checkFailed = true
break
}
clippedHashRange = clippedHashRange.lowerBound ..< min(clippedHashRange.upperBound, pendingReadyPart.fetchRange.lowerBound + Int64(pendingReadyPart.fetchedData.count))
let partLocalHashRange = (clippedHashRange.lowerBound - pendingReadyPart.fetchRange.lowerBound) ..< (clippedHashRange.upperBound - pendingReadyPart.fetchRange.lowerBound)
if partLocalHashRange.lowerBound < 0 || partLocalHashRange.upperBound > pendingReadyPart.fetchedData.count {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): unable to check \(pendingReadyPart.fetchRange): data range \(partLocalHashRange) out of bounds (0 ..< \(pendingReadyPart.fetchedData.count))")
checkFailed = true
break
}
let dataToHash = pendingReadyPart.cleanData.subdata(in: Int(partLocalHashRange.lowerBound) ..< Int(partLocalHashRange.upperBound))
let localHash = MTSha256(dataToHash)
if localHash != hashRange.data {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): failed to verify \(pendingReadyPart.fetchRange): hash mismatch")
checkFailed = true
break
}
checkOffset += partLocalHashRange.upperBound - partLocalHashRange.lowerBound
} else {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): unable to find \(pendingReadyPart.fetchRange) hash range despite it being marked as ready")
checkFailed = true
break
}
}
if !checkFailed {
self.commitPendingReadyPart(state: state, partRange: pendingReadyPart.partRange, fetchRange: pendingReadyPart.fetchRange, data: pendingReadyPart.cleanData)
} else {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): unable to find \(pendingReadyPart.fetchRange) hash check failed")
}
}
}
}
for index in removedPendingReadyPartIndices.reversed() {
for index in removedPendingReadyPartIndices.sorted(by: >) {
state.pendingReadyParts.remove(at: index)
}
}
@@ -452,7 +524,9 @@ private final class FetchImpl {
var requiredHashRanges = RangeSet<Int64>()
for pendingReadyPart in state.pendingReadyParts {
//TODO:check if already have hashes
requiredHashRanges.formUnion(RangeSet<Int64>(pendingReadyPart.fetchRange))
if state.decryptionState == nil {
requiredHashRanges.formUnion(RangeSet<Int64>(pendingReadyPart.fetchRange))
}
}
requiredHashRanges.subtract(state.completedHashRanges)
for pendingHashRange in state.pendingHashRanges {
@@ -613,7 +687,8 @@ private final class FetchImpl {
maxPartSize: self.cdnPartSize * 2,
partAlignment: self.cdnPartSize,
partDivision: 1 * 1024 * 1024,
maxPendingParts: 6
maxPendingParts: 6,
decryptionState: nil
))
self.update()
}, error: { [weak self] error in
@@ -661,7 +736,8 @@ private final class FetchImpl {
maxPartSize: self.defaultPartSize,
partAlignment: 4 * 1024,
partDivision: 1 * 1024 * 1024,
maxPendingParts: 6
maxPendingParts: 6,
decryptionState: nil
))
self.update()
@@ -819,7 +895,16 @@ private final class FetchImpl {
partRange: partRange,
fetchRange: fetchRange,
fetchedData: verifyPartHashData.fetchedData,
decryptedData: data
cleanData: data
))
} else if state.decryptionState != nil {
Logger.shared.log("FetchV2", "\(self.loggingIdentifier): stashing data part \(partRange) (aligned as \(fetchRange)) for decryption")
state.pendingReadyParts.append(FetchImpl.PendingReadyPart(
partRange: partRange,
fetchRange: fetchRange,
fetchedData: data,
cleanData: data
))
} else {
self.commitPendingReadyPart(
@@ -837,7 +922,8 @@ private final class FetchImpl {
maxPartSize: self.cdnPartSize * 2,
partAlignment: self.cdnPartSize,
partDivision: 1 * 1024 * 1024,
maxPendingParts: 6
maxPendingParts: 6,
decryptionState: nil
))
case let .cdnRefresh(cdnData, refreshToken):
self.state = .reuploadingToCdn(ReuploadingToCdnState(