Mirror of https://github.com/Swiftgram/Telegram-iOS.git, synced 2025-12-24 07:05:35 +00:00
Sticker cut out for iOS < 17
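A minimal sketch of how a caller might drive the new module (illustrative only, not part of this commit; `context` and `sourceImage` stand in for an existing AccountContext and UIImage):

    // Assumes the call site links ImageObjectSeparation and SwiftSignalKit.
    let availabilityDisposable = cutoutAvailability(context: context).start(next: { availability in
        switch availability {
        case .available:
            // iOS 17+ uses Vision's foreground instance mask; iOS 14-16 falls back
            // to the downloaded u2netp CoreML model.
            let _ = cutoutStickerImage(from: sourceImage, context: context).start(next: { cutout in
                // `cutout` is nil when no foreground subject could be separated.
            })
        case let .progress(value):
            // The model archive is still downloading; `value` is in 0.0 ... 1.0.
            let _ = value
        case .unavailable:
            break
        }
    })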
@@ -0,0 +1,28 @@
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")

swift_library(
    name = "ImageObjectSeparation",
    module_name = "ImageObjectSeparation",
    srcs = glob([
        "Sources/**/*.swift",
    ]),
    copts = [
        "-warnings-as-errors",
    ],
    deps = [
        "//submodules/AsyncDisplayKit",
        "//submodules/Display",
        "//submodules/Postbox",
        "//submodules/TelegramCore",
        "//submodules/SSignalKit/SwiftSignalKit",
        "//submodules/AccountContext",
        "//submodules/AppBundle",
        "//submodules/ImageTransparency",
        "//submodules/TelegramUI/Components/AnimationCache/ImageDCT",
        "//submodules/FileMediaResourceStatus",
        "//third-party/ZipArchive:ZipArchive",
    ],
    visibility = [
        "//visibility:public",
    ],
)
@@ -0,0 +1,441 @@
import Foundation
import UIKit
import Display
import Vision
import CoreImage
import CoreImage.CIFilterBuiltins
import VideoToolbox
import SwiftSignalKit
import Postbox
import TelegramCore
import AccountContext
import FileMediaResourceStatus
import ZipArchive
import ImageTransparency

private let queue = Queue()

public enum CutoutAvailability {
    case available
    case progress(Float)
    case unavailable
}

private var forceCoreMLVariant: Bool {
    #if targetEnvironment(simulator)
    return true
    #else
    return false
    #endif
}

private func modelPath() -> String {
    return NSTemporaryDirectory() + "u2netp.mlmodelc"
}

public func cutoutAvailability(context: AccountContext) -> Signal<CutoutAvailability, NoError> {
    if #available(iOS 17.0, *), !forceCoreMLVariant {
        return .single(.available)
    } else if #available(iOS 14.0, *) {
        let compiledModelPath = modelPath()

        #if DEBUG
        // try? FileManager.default.removeItem(atPath: compiledModelPath)
        #endif

        if FileManager.default.fileExists(atPath: compiledModelPath) {
            return .single(.available)
        }
        return context.engine.peers.resolvePeerByName(name: "stickersbackgroundseparation")
        |> mapToSignal { result -> Signal<CutoutAvailability, NoError> in
            guard case let .result(maybePeer) = result else {
                return .complete()
            }
            guard let peer = maybePeer else {
                return .single(.unavailable)
            }

            return context.account.viewTracker.aroundMessageHistoryViewForLocation(.peer(peerId: peer.id, threadId: nil), index: .lowerBound, anchorIndex: .lowerBound, count: 5, fixedCombinedReadStates: nil)
            |> mapToSignal { view -> Signal<(TelegramMediaFile, EngineMessage)?, NoError> in
                if !view.0.isLoading {
                    if let message = view.0.entries.last?.message, let file = message.media.first(where: { $0 is TelegramMediaFile }) as? TelegramMediaFile {
                        return .single((file, EngineMessage(message)))
                    } else {
                        return .single(nil)
                    }
                } else {
                    return .complete()
                }
            }
            |> take(1)
            |> mapToSignal { maybeFileAndMessage -> Signal<CutoutAvailability, NoError> in
                if let (file, message) = maybeFileAndMessage {
                    let fetchedData = fetchedMediaResource(mediaBox: context.account.postbox.mediaBox, userLocation: .other, userContentType: .file, reference: FileMediaReference.message(message: MessageReference(message._asMessage()), media: file).resourceReference(file.resource))

                    enum FetchStatus {
                        case completed(String)
                        case progress(Float)
                        case failed
                    }

                    let fetchStatus = Signal<FetchStatus, NoError> { subscriber in
                        let fetchedDisposable = fetchedData.start()
                        let thumbnailDisposable = context.account.postbox.mediaBox.resourceData(file.resource, attemptSynchronously: false).start(next: { next in
                            if next.complete {
                                SSZipArchive.unzipFile(atPath: next.path, toDestination: NSTemporaryDirectory())
                                subscriber.putNext(.completed(compiledModelPath))
                                subscriber.putCompletion()
                            }
                        }, error: subscriber.putError, completed: subscriber.putCompletion)
                        let progressDisposable = messageFileMediaResourceStatus(context: context, file: file, message: message, isRecentActions: false).start(next: { status in
                            switch status.fetchStatus {
                            case let .Remote(progress), let .Fetching(_, progress), let .Paused(progress):
                                subscriber.putNext(.progress(progress))
                            default:
                                break
                            }
                        })
                        return ActionDisposable {
                            fetchedDisposable.dispose()
                            thumbnailDisposable.dispose()
                            progressDisposable.dispose()
                        }
                    }
                    return fetchStatus
                    |> mapToSignal { status -> Signal<CutoutAvailability, NoError> in
                        switch status {
                        case let .completed(path):
                            let _ = path
                            return .single(.available)
                        case let .progress(progress):
                            return .single(.progress(progress))
                        case .failed:
                            return .single(.unavailable)
                        }
                    }
                } else {
                    return .single(.unavailable)
                }
            }
        }
    } else {
        return .single(.unavailable)
    }
}

public func cutoutStickerImage(from image: UIImage, context: AccountContext? = nil, onlyCheck: Bool = false) -> Signal<UIImage?, NoError> {
    guard let cgImage = image.cgImage else {
        return .single(nil)
    }
    if #available(iOS 17.0, *), !forceCoreMLVariant {
        return Signal { subscriber in
            let ciContext = CIContext(options: nil)
            let inputImage = CIImage(cgImage: cgImage)
            let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
            let request = VNGenerateForegroundInstanceMaskRequest { [weak handler] request, error in
                guard let handler, let result = request.results?.first as? VNInstanceMaskObservation else {
                    subscriber.putNext(nil)
                    subscriber.putCompletion()
                    return
                }
                if onlyCheck {
                    subscriber.putNext(UIImage())
                    subscriber.putCompletion()
                } else {
                    let instances = instances(atPoint: nil, inObservation: result)
                    if let mask = try? result.generateScaledMaskForImage(forInstances: instances, from: handler) {
                        let filter = CIFilter.blendWithMask()
                        filter.inputImage = inputImage
                        filter.backgroundImage = CIImage(color: .clear)
                        filter.maskImage = CIImage(cvPixelBuffer: mask)
                        if let output = filter.outputImage, let cgImage = ciContext.createCGImage(output, from: inputImage.extent) {
                            let image = UIImage(cgImage: cgImage)
                            subscriber.putNext(image)
                            subscriber.putCompletion()
                            return
                        }
                    }
                    subscriber.putNext(nil)
                    subscriber.putCompletion()
                }
            }
            try? handler.perform([request])
            return ActionDisposable {
                request.cancel()
            }
        }
        |> runOn(queue)
    } else if #available(iOS 14.0, *), onlyCheck {
        return Signal { subscriber in
            U2netp.load(contentsOf: URL(fileURLWithPath: modelPath()), completionHandler: { result in
                switch result {
                case let .success(model):
                    let modelImageSize = CGSize(width: 320, height: 320)
                    if let squareImage = scaleImageToPixelSize(image: image, size: modelImageSize),
                       let pixelBuffer = buffer(from: squareImage),
                       let result = try? model.prediction(in_0: pixelBuffer),
                       let resultImage = UIImage(pixelBuffer: result.out_p1),
                       imageHasSubject(resultImage) {
                        subscriber.putNext(UIImage())
                    } else {
                        subscriber.putNext(nil)
                    }
                    subscriber.putCompletion()
                case .failure:
                    subscriber.putNext(nil)
                    subscriber.putCompletion()
                }
            })
            return EmptyDisposable
        }
        |> runOn(queue)
    } else {
        return .single(nil)
    }
}

public struct CutoutResult {
    public enum Image {
        case image(UIImage, CIImage)
        case pixelBuffer(CVPixelBuffer)
    }

    public let index: Int
    public let extractedImage: Image?
    public let edgesMaskImage: Image?
    public let maskImage: Image?
    public let backgroundImage: Image?
}

public enum CutoutTarget {
    case point(CGPoint?)
    case index(Int)
    case all
}


func refineEdges(_ maskImage: CIImage) -> CIImage? {
    let maskImage = maskImage.clampedToExtent()

    let blurFilter = CIFilter(name: "CIGaussianBlur")!
    blurFilter.setValue(maskImage, forKey: kCIInputImageKey)
    blurFilter.setValue(11.4, forKey: kCIInputRadiusKey)

    let controlsFilter = CIFilter(name: "CIColorControls")!
    controlsFilter.setValue(blurFilter.outputImage, forKey: kCIInputImageKey)
    controlsFilter.setValue(6.61, forKey: kCIInputContrastKey)

    let sharpenFilter = CIFilter(name: "CISharpenLuminance")!
    sharpenFilter.setValue(controlsFilter.outputImage, forKey: kCIInputImageKey)
    sharpenFilter.setValue(250.0, forKey: kCIInputSharpnessKey)

    return sharpenFilter.outputImage?.cropped(to: maskImage.extent)
}

public func cutoutImage(
    from image: UIImage,
    editedImage: UIImage? = nil,
    crop: (offset: CGPoint, rotation: CGFloat, scale: CGFloat)?,
    target: CutoutTarget,
    includeExtracted: Bool = true,
    completion: @escaping ([CutoutResult]) -> Void
) {
    guard #available(iOS 14.0, *), let cgImage = image.cgImage else {
        completion([])
        return
    }

    let ciContext = CIContext(options: nil)
    let inputImage = CIImage(cgImage: cgImage)
    var results: [CutoutResult] = []

    func process(instance: Int, mask originalMaskImage: CIImage) {
        let extractedImage: CutoutResult.Image?
        if includeExtracted {
            let filter = CIFilter.blendWithMask()
            filter.backgroundImage = CIImage(color: .clear)

            let dimensions: CGSize
            var maskImage = originalMaskImage
            if let editedImage = editedImage?.cgImage.flatMap({ CIImage(cgImage: $0) }) {
                filter.inputImage = editedImage
                dimensions = editedImage.extent.size

                if let (cropOffset, cropRotation, cropScale) = crop {
                    let initialScale: CGFloat
                    if maskImage.extent.height > maskImage.extent.width {
                        initialScale = dimensions.width / maskImage.extent.width
                    } else {
                        initialScale = dimensions.width / maskImage.extent.height
                    }

                    let dimensions = editedImage.extent.size
                    maskImage = maskImage.transformed(by: CGAffineTransform(translationX: -maskImage.extent.width / 2.0, y: -maskImage.extent.height / 2.0))

                    var transform = CGAffineTransform.identity
                    transform = transform.translatedBy(x: dimensions.width / 2.0 + cropOffset.x, y: dimensions.height / 2.0 + cropOffset.y * -1.0)
                    transform = transform.rotated(by: -cropRotation)
                    transform = transform.scaledBy(x: cropScale * initialScale, y: cropScale * initialScale)
                    maskImage = maskImage.transformed(by: transform)
                }
            } else {
                filter.inputImage = inputImage
                dimensions = inputImage.extent.size
            }
            filter.maskImage = maskImage

            if let output = filter.outputImage, let cgImage = ciContext.createCGImage(output, from: CGRect(origin: .zero, size: dimensions)) {
                extractedImage = .image(UIImage(cgImage: cgImage), output)
            } else {
                extractedImage = nil
            }
        } else {
            extractedImage = nil
        }

        let whiteImage = CIImage(color: .white)
        let blackImage = CIImage(color: .black)

        let maskFilter = CIFilter.blendWithMask()
        maskFilter.inputImage = whiteImage
        maskFilter.backgroundImage = blackImage
        maskFilter.maskImage = originalMaskImage

        let refinedMaskFilter = CIFilter.blendWithMask()
        refinedMaskFilter.inputImage = whiteImage
        refinedMaskFilter.backgroundImage = blackImage
        refinedMaskFilter.maskImage = refineEdges(originalMaskImage)

        let edgesMaskImage: CutoutResult.Image?
        let maskImage: CutoutResult.Image?
        if let maskOutput = maskFilter.outputImage?.cropped(to: inputImage.extent), let maskCgImage = ciContext.createCGImage(maskOutput, from: inputImage.extent), let refinedMaskOutput = refinedMaskFilter.outputImage?.cropped(to: inputImage.extent), let refinedMaskCgImage = ciContext.createCGImage(refinedMaskOutput, from: inputImage.extent) {
            edgesMaskImage = .image(UIImage(cgImage: maskCgImage), maskOutput)
            maskImage = .image(UIImage(cgImage: refinedMaskCgImage), refinedMaskOutput)
        } else {
            edgesMaskImage = nil
            maskImage = nil
        }

        if extractedImage != nil || maskImage != nil {
            results.append(CutoutResult(index: instance, extractedImage: extractedImage, edgesMaskImage: edgesMaskImage, maskImage: maskImage, backgroundImage: nil))
        }
    }

    if #available(iOS 17.0, *), !forceCoreMLVariant {
        queue.async {
            let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
            let request = VNGenerateForegroundInstanceMaskRequest { [weak handler] request, error in
                guard let handler, let result = request.results?.first as? VNInstanceMaskObservation else {
                    completion([])
                    return
                }

                let targetInstances: IndexSet
                switch target {
                case let .point(point):
                    targetInstances = instances(atPoint: point, inObservation: result)
                case let .index(index):
                    targetInstances = IndexSet([index])
                case .all:
                    targetInstances = result.allInstances
                }

                for instance in targetInstances {
                    if let mask = try? result.generateScaledMaskForImage(forInstances: IndexSet(integer: instance), from: handler) {
                        process(instance: instance, mask: CIImage(cvPixelBuffer: mask))
                    }
                }
                completion(results)
            }

            try? handler.perform([request])
        }
    } else {
        U2netp.load(contentsOf: URL(fileURLWithPath: modelPath()), completionHandler: { result in
            switch result {
            case let .success(model):
                let modelImageSize = CGSize(width: 320, height: 320)
                if let squareImage = scaleImageToPixelSize(image: image, size: modelImageSize), let pixelBuffer = buffer(from: squareImage), let result = try? model.prediction(in_0: pixelBuffer), let maskImage = UIImage(pixelBuffer: result.out_p1), let scaledMaskImage = scaleImageToPixelSize(image: maskImage, size: image.size), let ciImage = CIImage(image: scaledMaskImage) {
                    process(instance: 0, mask: ciImage)
                }
            case .failure:
                break
            }
            completion(results)
        })
    }
}

@available(iOS 17.0, *)
private func instances(atPoint maybePoint: CGPoint?, inObservation observation: VNInstanceMaskObservation) -> IndexSet {
    guard let point = maybePoint else {
        return observation.allInstances
    }

    let instanceMap = observation.instanceMask
    let coords = VNImagePointForNormalizedPoint(point, CVPixelBufferGetWidth(instanceMap) - 1, CVPixelBufferGetHeight(instanceMap) - 1)

    CVPixelBufferLockBaseAddress(instanceMap, .readOnly)
    guard let pixels = CVPixelBufferGetBaseAddress(instanceMap) else {
        fatalError()
    }
    let bytesPerRow = CVPixelBufferGetBytesPerRow(instanceMap)
    let instanceLabel = pixels.load(fromByteOffset: Int(coords.y) * bytesPerRow + Int(coords.x), as: UInt8.self)
    CVPixelBufferUnlockBaseAddress(instanceMap, .readOnly)

    return instanceLabel == 0 ? observation.allInstances : [Int(instanceLabel)]
}

private extension UIImage {
    convenience init?(pixelBuffer: CVPixelBuffer) {
        var cgImage: CGImage?
        VTCreateCGImageFromCVPixelBuffer(pixelBuffer, options: nil, imageOut: &cgImage)

        guard let cgImage = cgImage else {
            return nil
        }

        self.init(cgImage: cgImage)
    }
}

private func scaleImageToPixelSize(image: UIImage, size: CGSize) -> UIImage? {
    UIGraphicsBeginImageContextWithOptions(size, true, 1.0)
    image.draw(in: CGRect(origin: CGPoint(), size: size), blendMode: .copy, alpha: 1.0)
    let result = UIGraphicsGetImageFromCurrentImageContext()
    UIGraphicsEndImageContext()

    return result
}

private func buffer(from image: UIImage) -> CVPixelBuffer? {
    let attrs = [kCVPixelBufferCGImageCompatibilityKey: kCFBooleanTrue, kCVPixelBufferCGBitmapContextCompatibilityKey: kCFBooleanTrue] as CFDictionary
    var pixelBuffer: CVPixelBuffer?
    let status = CVPixelBufferCreate(kCFAllocatorDefault, Int(image.size.width), Int(image.size.height), kCVPixelFormatType_32ARGB, attrs, &pixelBuffer)
    guard status == kCVReturnSuccess else {
        return nil
    }

    guard let pixelBufferUnwrapped = pixelBuffer else {
        return nil
    }

    CVPixelBufferLockBaseAddress(pixelBufferUnwrapped, CVPixelBufferLockFlags(rawValue: 0))
    let pixelData = CVPixelBufferGetBaseAddress(pixelBufferUnwrapped)

    let rgbColorSpace = CGColorSpaceCreateDeviceRGB()

    guard let context = CGContext(data: pixelData, width: Int(image.size.width), height: Int(image.size.height), bitsPerComponent: 8, bytesPerRow: CVPixelBufferGetBytesPerRow(pixelBufferUnwrapped), space: rgbColorSpace, bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue) else {
        return nil
    }

    context.translateBy(x: 0, y: image.size.height)
    context.scaleBy(x: 1.0, y: -1.0)

    UIGraphicsPushContext(context)
    image.draw(in: CGRect(x: 0, y: 0, width: image.size.width, height: image.size.height))
    UIGraphicsPopContext()
    CVPixelBufferUnlockBaseAddress(pixelBufferUnwrapped, CVPixelBufferLockFlags(rawValue: 0))

    return pixelBufferUnwrapped
}
@@ -0,0 +1,252 @@
import CoreML

/// Model Prediction Input Type
@available(macOS 13.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *)
class U2netpInput : MLFeatureProvider {

    /// in_0 as color (kCVPixelFormatType_32BGRA) image buffer, 320 pixels wide by 320 pixels high
    var in_0: CVPixelBuffer

    var featureNames: Set<String> {
        get {
            return ["in_0"]
        }
    }

    func featureValue(for featureName: String) -> MLFeatureValue? {
        if (featureName == "in_0") {
            return MLFeatureValue(pixelBuffer: in_0)
        }
        return nil
    }

    init(in_0: CVPixelBuffer) {
        self.in_0 = in_0
    }

    convenience init(in_0With in_0: CGImage) throws {
        self.init(in_0: try MLFeatureValue(cgImage: in_0, pixelsWide: 320, pixelsHigh: 320, pixelFormatType: kCVPixelFormatType_32ARGB, options: nil).imageBufferValue!)
    }

    convenience init(in_0At in_0: URL) throws {
        self.init(in_0: try MLFeatureValue(imageAt: in_0, pixelsWide: 320, pixelsHigh: 320, pixelFormatType: kCVPixelFormatType_32ARGB, options: nil).imageBufferValue!)
    }

    func setIn_0(with in_0: CGImage) throws {
        self.in_0 = try MLFeatureValue(cgImage: in_0, pixelsWide: 320, pixelsHigh: 320, pixelFormatType: kCVPixelFormatType_32ARGB, options: nil).imageBufferValue!
    }

    func setIn_0(with in_0: URL) throws {
        self.in_0 = try MLFeatureValue(imageAt: in_0, pixelsWide: 320, pixelsHigh: 320, pixelFormatType: kCVPixelFormatType_32ARGB, options: nil).imageBufferValue!
    }

}


/// Model Prediction Output Type
@available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *)
class U2netpOutput : MLFeatureProvider {

    /// Source provided by CoreML
    private let provider : MLFeatureProvider

    /// out_p0 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
    lazy var out_p0: CVPixelBuffer = {
        [unowned self] in return self.provider.featureValue(for: "out_p0")!.imageBufferValue
    }()!

    /// out_p1 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
    lazy var out_p1: CVPixelBuffer = {
        [unowned self] in return self.provider.featureValue(for: "out_p1")!.imageBufferValue
    }()!

    /// out_p2 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
    lazy var out_p2: CVPixelBuffer = {
        [unowned self] in return self.provider.featureValue(for: "out_p2")!.imageBufferValue
    }()!

    /// out_p3 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
    lazy var out_p3: CVPixelBuffer = {
        [unowned self] in return self.provider.featureValue(for: "out_p3")!.imageBufferValue
    }()!

    /// out_p4 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
    lazy var out_p4: CVPixelBuffer = {
        [unowned self] in return self.provider.featureValue(for: "out_p4")!.imageBufferValue
    }()!

    /// out_p5 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
    lazy var out_p5: CVPixelBuffer = {
        [unowned self] in return self.provider.featureValue(for: "out_p5")!.imageBufferValue
    }()!

    /// out_p6 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
    lazy var out_p6: CVPixelBuffer = {
        [unowned self] in return self.provider.featureValue(for: "out_p6")!.imageBufferValue
    }()!

    var featureNames: Set<String> {
        return self.provider.featureNames
    }

    func featureValue(for featureName: String) -> MLFeatureValue? {
        return self.provider.featureValue(for: featureName)
    }

    init(out_p0: CVPixelBuffer, out_p1: CVPixelBuffer, out_p2: CVPixelBuffer, out_p3: CVPixelBuffer, out_p4: CVPixelBuffer, out_p5: CVPixelBuffer, out_p6: CVPixelBuffer) {
        self.provider = try! MLDictionaryFeatureProvider(dictionary: ["out_p0" : MLFeatureValue(pixelBuffer: out_p0), "out_p1" : MLFeatureValue(pixelBuffer: out_p1), "out_p2" : MLFeatureValue(pixelBuffer: out_p2), "out_p3" : MLFeatureValue(pixelBuffer: out_p3), "out_p4" : MLFeatureValue(pixelBuffer: out_p4), "out_p5" : MLFeatureValue(pixelBuffer: out_p5), "out_p6" : MLFeatureValue(pixelBuffer: out_p6)])
    }

    init(features: MLFeatureProvider) {
        self.provider = features
    }
}


/// Class for model loading and prediction
@available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *)
class U2netp {
    let model: MLModel

    /**
        Construct U2netp instance with an existing MLModel object.

        Usually the application does not use this initializer unless it makes a subclass of U2netp.
        Such application may want to use `MLModel(contentsOfURL:configuration:)` and `U2netp.urlOfModelInThisBundle` to create a MLModel object to pass-in.

        - parameters:
          - model: MLModel object
    */
    init(model: MLModel) {
        self.model = model
    }

    /**
        Construct U2netp instance with explicit path to mlmodelc file
        - parameters:
           - modelURL: the file url of the model

        - throws: an NSError object that describes the problem
    */
    convenience init(contentsOf modelURL: URL) throws {
        try self.init(model: MLModel(contentsOf: modelURL))
    }

    /**
        Construct a model with URL of the .mlmodelc directory and configuration

        - parameters:
           - modelURL: the file url of the model
           - configuration: the desired model configuration

        - throws: an NSError object that describes the problem
    */
    convenience init(contentsOf modelURL: URL, configuration: MLModelConfiguration) throws {
        try self.init(model: MLModel(contentsOf: modelURL, configuration: configuration))
    }

    /**
        Construct U2netp instance asynchronously with URL of the .mlmodelc directory with optional configuration.

        Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.

        - parameters:
          - modelURL: the URL to the model
          - configuration: the desired model configuration
          - handler: the completion handler to be called when the model loading completes successfully or unsuccessfully
    */
    class func load(contentsOf modelURL: URL, configuration: MLModelConfiguration = MLModelConfiguration(), completionHandler handler: @escaping (Swift.Result<U2netp, Error>) -> Void) {
        MLModel.load(contentsOf: modelURL, configuration: configuration) { result in
            switch result {
            case .failure(let error):
                handler(.failure(error))
            case .success(let model):
                handler(.success(U2netp(model: model)))
            }
        }
    }

    /**
        Construct U2netp instance asynchronously with URL of the .mlmodelc directory with optional configuration.

        Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.

        - parameters:
          - modelURL: the URL to the model
          - configuration: the desired model configuration
    */
    @available(macOS 12.0, iOS 15.0, tvOS 15.0, watchOS 8.0, *)
    class func load(contentsOf modelURL: URL, configuration: MLModelConfiguration = MLModelConfiguration()) async throws -> U2netp {
        let model = try await MLModel.load(contentsOf: modelURL, configuration: configuration)
        return U2netp(model: model)
    }

    /**
        Make a prediction using the structured interface

        - parameters:
           - input: the input to the prediction as U2netpInput

        - throws: an NSError object that describes the problem

        - returns: the result of the prediction as U2netpOutput
    */
    func prediction(input: U2netpInput) throws -> U2netpOutput {
        return try self.prediction(input: input, options: MLPredictionOptions())
    }

    /**
        Make a prediction using the structured interface

        - parameters:
           - input: the input to the prediction as U2netpInput
           - options: prediction options

        - throws: an NSError object that describes the problem

        - returns: the result of the prediction as U2netpOutput
    */
    func prediction(input: U2netpInput, options: MLPredictionOptions) throws -> U2netpOutput {
        let outFeatures = try model.prediction(from: input, options: options)
        return U2netpOutput(features: outFeatures)
    }

    /**
        Make a prediction using the convenience interface

        - parameters:
            - in_0 as color (kCVPixelFormatType_32BGRA) image buffer, 320 pixels wide by 320 pixels high

        - throws: an NSError object that describes the problem

        - returns: the result of the prediction as U2netpOutput
    */
    func prediction(in_0: CVPixelBuffer) throws -> U2netpOutput {
        let input_ = U2netpInput(in_0: in_0)
        return try self.prediction(input: input_)
    }

    /**
        Make a batch prediction using the structured interface

        - parameters:
           - inputs: the inputs to the prediction as [U2netpInput]
           - options: prediction options

        - throws: an NSError object that describes the problem

        - returns: the result of the prediction as [U2netpOutput]
    */
    func predictions(inputs: [U2netpInput], options: MLPredictionOptions = MLPredictionOptions()) throws -> [U2netpOutput] {
        let batchIn = MLArrayBatchProvider(array: inputs)
        let batchOut = try model.predictions(from: batchIn, options: options)
        var results : [U2netpOutput] = []
        results.reserveCapacity(inputs.count)
        for i in 0..<batchOut.count {
            let outProvider = batchOut.features(at: i)
            let result = U2netpOutput(features: outProvider)
            results.append(result)
        }
        return results
    }
}