Sticker cut out for iOS < 17

Ilya Laktyushin 2024-04-26 14:23:15 +04:00
parent ff3e512869
commit b1f80d475c
13 changed files with 804 additions and 270 deletions

View File

@@ -108,6 +108,7 @@ swift_library(
"//submodules/TelegramUI/Components/DustEffect",
"//submodules/TelegramUI/Components/DynamicCornerRadiusView",
"//submodules/TelegramUI/Components/StickerPickerScreen",
"//submodules/TelegramUI/Components/MediaEditor/ImageObjectSeparation",
],
visibility = [
"//visibility:public",

View File

@@ -25,6 +25,7 @@ import TelegramUIPreferences
import FastBlur
import MediaEditor
import StickerPickerScreen
import ImageObjectSeparation
public struct DrawingResultData {
public let data: Data?

View File

@@ -56,6 +56,21 @@ private func generateHistogram(cgImage: CGImage) -> ([[vImagePixelCount]], Int)?
return ([histogramBinZero, histogramBinOne, histogramBinTwo, histogramBinThree], alphaBinIndex)
}
public func imageHasSubject(_ image: UIImage) -> Bool {
guard let cgImage = image.cgImage, cgImage.bitsPerComponent == 8, cgImage.bitsPerPixel == 32 else {
return false
}
if let (histogramBins, _) = generateHistogram(cgImage: cgImage) {
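// Sum bins 0 ... 254 of channel 1; callers pass the model's grayscale mask, so any color channel carries the mask values.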
var totalCount: vImagePixelCount = 0
for i in 0 ..< 255 {
totalCount += histogramBins[1][i]
}
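// Bin 255 counts fully white (subject) pixels; the mask counts as having a subject once they exceed ~5% of the remaining pixels.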
let opaqueCount: vImagePixelCount = histogramBins[1][255]
return Double(opaqueCount) / Double(totalCount) > 0.05
}
return false
}
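// Editor's note (illustrative, not part of the commit): MediaEditor branches on these
// two predicates for sticker input, roughly:
//
// if imageHasTransparency(image) {
//     // already sticker-like: cutout is skipped, outline stays available
// } else {
//     // opaque photo: consult cutoutAvailability(context:) and offer background removal
// }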
public func imageHasTransparency(_ image: UIImage) -> Bool {
guard let cgImage = image.cgImage, cgImage.bitsPerComponent == 8, cgImage.bitsPerPixel == 32 else {
return false

View File

@@ -72,6 +72,8 @@ swift_library(
"//submodules/ImageTransparency",
"//submodules/FFMpegBinding",
"//submodules/TelegramUI/Components/AnimationCache/ImageDCT",
"//submodules/FileMediaResourceStatus",
"//submodules/TelegramUI/Components/MediaEditor/ImageObjectSeparation",
],
visibility = [
"//visibility:public",

View File

@@ -0,0 +1,28 @@
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
swift_library(
name = "ImageObjectSeparation",
module_name = "ImageObjectSeparation",
srcs = glob([
"Sources/**/*.swift",
]),
copts = [
"-warnings-as-errors",
],
deps = [
"//submodules/AsyncDisplayKit",
"//submodules/Display",
"//submodules/Postbox",
"//submodules/TelegramCore",
"//submodules/SSignalKit/SwiftSignalKit",
"//submodules/AccountContext",
"//submodules/AppBundle",
"//submodules/ImageTransparency",
"//submodules/TelegramUI/Components/AnimationCache/ImageDCT",
"//submodules/FileMediaResourceStatus",
"//third-party/ZipArchive:ZipArchive",
],
visibility = [
"//visibility:public",
],
)

View File

@@ -0,0 +1,441 @@
import Foundation
import UIKit
import Display
import Vision
import CoreImage
import CoreImage.CIFilterBuiltins
import VideoToolbox
import SwiftSignalKit
import Postbox
import TelegramCore
import AccountContext
import FileMediaResourceStatus
import ZipArchive
import ImageTransparency
private let queue = Queue()
public enum CutoutAvailability {
case available
case progress(Float)
case unavailable
}
private var forceCoreMLVariant: Bool {
#if targetEnvironment(simulator)
return true
#else
return false
#endif
}
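// Editor's note (not part of the commit): the compiled model at this path is produced
// by cutoutAvailability below, which downloads an archive via the
// "stickersbackgroundseparation" peer and unzips it into NSTemporaryDirectory().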
private func modelPath() -> String {
return NSTemporaryDirectory() + "u2netp.mlmodelc"
}
public func cutoutAvailability(context: AccountContext) -> Signal<CutoutAvailability, NoError> {
if #available(iOS 17.0, *), !forceCoreMLVariant {
return .single(.available)
} else if #available(iOS 14.0, *) {
let compiledModelPath = modelPath()
#if DEBUG
// try? FileManager.default.removeItem(atPath: compiledModelPath)
#endif
if FileManager.default.fileExists(atPath: compiledModelPath) {
return .single(.available)
}
return context.engine.peers.resolvePeerByName(name: "stickersbackgroundseparation")
|> mapToSignal { result -> Signal<CutoutAvailability, NoError> in
guard case let .result(maybePeer) = result else {
return .complete()
}
guard let peer = maybePeer else {
return .single(.unavailable)
}
return context.account.viewTracker.aroundMessageHistoryViewForLocation(.peer(peerId: peer.id, threadId: nil), index: .lowerBound, anchorIndex: .lowerBound, count: 5, fixedCombinedReadStates: nil)
|> mapToSignal { view -> Signal<(TelegramMediaFile, EngineMessage)?, NoError> in
if !view.0.isLoading {
if let message = view.0.entries.last?.message, let file = message.media.first(where: { $0 is TelegramMediaFile }) as? TelegramMediaFile {
return .single((file, EngineMessage(message)))
} else {
return .single(nil)
}
} else {
return .complete()
}
}
|> take(1)
|> mapToSignal { maybeFileAndMessage -> Signal<CutoutAvailability, NoError> in
if let (file, message) = maybeFileAndMessage {
let fetchedData = fetchedMediaResource(mediaBox: context.account.postbox.mediaBox, userLocation: .other, userContentType: .file, reference: FileMediaReference.message(message: MessageReference(message._asMessage()), media: file).resourceReference(file.resource))
enum FetchStatus {
case completed(String)
case progress(Float)
case failed
}
let fetchStatus = Signal<FetchStatus, NoError> { subscriber in
let fetchedDisposable = fetchedData.start()
let thumbnailDisposable = context.account.postbox.mediaBox.resourceData(file.resource, attemptSynchronously: false).start(next: { next in
if next.complete {
SSZipArchive.unzipFile(atPath: next.path, toDestination: NSTemporaryDirectory())
subscriber.putNext(.completed(compiledModelPath))
subscriber.putCompletion()
}
}, error: subscriber.putError, completed: subscriber.putCompletion)
let progressDisposable = messageFileMediaResourceStatus(context: context, file: file, message: message, isRecentActions: false).start(next: { status in
switch status.fetchStatus {
case let .Remote(progress), let .Fetching(_, progress), let .Paused(progress):
subscriber.putNext(.progress(progress))
default:
break
}
})
return ActionDisposable {
fetchedDisposable.dispose()
thumbnailDisposable.dispose()
progressDisposable.dispose()
}
}
return fetchStatus
|> mapToSignal { status -> Signal<CutoutAvailability, NoError> in
switch status {
case let .completed(path):
let _ = path
return .single(.available)
case let .progress(progress):
return .single(.progress(progress))
case .failed:
return .single(.unavailable)
}
}
} else {
return .single(.unavailable)
}
}
}
} else {
return .single(.unavailable)
}
}
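// Editor's illustrative sketch (not part of the commit): driving UI state from the
// availability signal. `updateCutoutButton` is a hypothetical callback taking
// (isAvailable, preparationProgress).
private func observeCutoutAvailability(context: AccountContext, updateCutoutButton: @escaping (Bool, Float?) -> Void) -> Disposable {
    return (cutoutAvailability(context: context)
    |> deliverOnMainQueue).start(next: { availability in
        switch availability {
        case .available:
            updateCutoutButton(true, nil)
        case let .progress(progress):
            updateCutoutButton(false, progress)
        case .unavailable:
            updateCutoutButton(false, nil)
        }
    })
}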
public func cutoutStickerImage(from image: UIImage, context: AccountContext? = nil, onlyCheck: Bool = false) -> Signal<UIImage?, NoError> {
guard let cgImage = image.cgImage else {
return .single(nil)
}
if #available(iOS 17.0, *), !forceCoreMLVariant {
return Signal { subscriber in
let ciContext = CIContext(options: nil)
let inputImage = CIImage(cgImage: cgImage)
let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
let request = VNGenerateForegroundInstanceMaskRequest { [weak handler] request, error in
guard let handler, let result = request.results?.first as? VNInstanceMaskObservation else {
subscriber.putNext(nil)
subscriber.putCompletion()
return
}
if onlyCheck {
subscriber.putNext(UIImage())
subscriber.putCompletion()
} else {
let instances = instances(atPoint: nil, inObservation: result)
if let mask = try? result.generateScaledMaskForImage(forInstances: instances, from: handler) {
let filter = CIFilter.blendWithMask()
filter.inputImage = inputImage
filter.backgroundImage = CIImage(color: .clear)
filter.maskImage = CIImage(cvPixelBuffer: mask)
if let output = filter.outputImage, let cgImage = ciContext.createCGImage(output, from: inputImage.extent) {
let image = UIImage(cgImage: cgImage)
subscriber.putNext(image)
subscriber.putCompletion()
return
}
}
subscriber.putNext(nil)
subscriber.putCompletion()
}
}
try? handler.perform([request])
return ActionDisposable {
request.cancel()
}
}
|> runOn(queue)
} else if #available(iOS 14.0, *), onlyCheck {
return Signal { subscriber in
U2netp.load(contentsOf: URL(fileURLWithPath: modelPath()), completionHandler: { result in
switch result {
case let .success(model):
let modelImageSize = CGSize(width: 320, height: 320)
if let squareImage = scaleImageToPixelSize(image: image, size: modelImageSize),
let pixelBuffer = buffer(from: squareImage),
let result = try? model.prediction(in_0: pixelBuffer),
let resultImage = UIImage(pixelBuffer: result.out_p1),
imageHasSubject(resultImage) {
subscriber.putNext(UIImage())
} else {
subscriber.putNext(nil)
}
subscriber.putCompletion()
case .failure:
subscriber.putNext(nil)
subscriber.putCompletion()
}
})
return EmptyDisposable
}
|> runOn(queue)
} else {
return .single(nil)
}
}
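// Editor's illustrative sketch (not part of the commit): requesting a cutout once
// availability is known. Note that below iOS 17 this function only supports the
// onlyCheck probe; full extraction goes through cutoutImage(from:...) further down.
private func extractSticker(from sourceImage: UIImage, context: AccountContext, completion: @escaping (UIImage?) -> Void) -> Disposable {
    return (cutoutStickerImage(from: sourceImage, context: context)
    |> deliverOnMainQueue).start(next: { cutout in
        // `cutout` is the foreground on a transparent background, or nil on failure
        completion(cutout)
    })
}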
public struct CutoutResult {
public enum Image {
case image(UIImage, CIImage)
case pixelBuffer(CVPixelBuffer)
}
public let index: Int
public let extractedImage: Image?
public let edgesMaskImage: Image?
public let maskImage: Image?
public let backgroundImage: Image?
}
public enum CutoutTarget {
case point(CGPoint?)
case index(Int)
case all
}
func refineEdges(_ maskImage: CIImage) -> CIImage? {
let maskImage = maskImage.clampedToExtent()
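// Blur the mask, then push contrast hard and re-sharpen: this smooths jagged edges and re-binarizes the boundary.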
let blurFilter = CIFilter(name: "CIGaussianBlur")!
blurFilter.setValue(maskImage, forKey: kCIInputImageKey)
blurFilter.setValue(11.4, forKey: kCIInputRadiusKey)
let controlsFilter = CIFilter(name: "CIColorControls")!
controlsFilter.setValue(blurFilter.outputImage, forKey: kCIInputImageKey)
controlsFilter.setValue(6.61, forKey: kCIInputContrastKey)
let sharpenFilter = CIFilter(name: "CISharpenLuminance")!
sharpenFilter.setValue(controlsFilter.outputImage, forKey: kCIInputImageKey)
sharpenFilter.setValue(250.0, forKey: kCIInputSharpnessKey)
return sharpenFilter.outputImage?.cropped(to: maskImage.extent)
}
public func cutoutImage(
from image: UIImage,
editedImage: UIImage? = nil,
crop: (offset: CGPoint, rotation: CGFloat, scale: CGFloat)?,
target: CutoutTarget,
includeExtracted: Bool = true,
completion: @escaping ([CutoutResult]) -> Void
) {
guard #available(iOS 14.0, *), let cgImage = image.cgImage else {
completion([])
return
}
let ciContext = CIContext(options: nil)
let inputImage = CIImage(cgImage: cgImage)
var results: [CutoutResult] = []
func process(instance: Int, mask originalMaskImage: CIImage) {
let extractedImage: CutoutResult.Image?
if includeExtracted {
let filter = CIFilter.blendWithMask()
filter.backgroundImage = CIImage(color: .clear)
let dimensions: CGSize
var maskImage = originalMaskImage
if let editedImage = editedImage?.cgImage.flatMap({ CIImage(cgImage: $0) }) {
filter.inputImage = editedImage
dimensions = editedImage.extent.size
if let (cropOffset, cropRotation, cropScale) = crop {
let initialScale: CGFloat
if maskImage.extent.height > maskImage.extent.width {
initialScale = dimensions.width / maskImage.extent.width
} else {
initialScale = dimensions.width / maskImage.extent.height
}
let dimensions = editedImage.extent.size
maskImage = maskImage.transformed(by: CGAffineTransform(translationX: -maskImage.extent.width / 2.0, y: -maskImage.extent.height / 2.0))
var transform = CGAffineTransform.identity
transform = transform.translatedBy(x: dimensions.width / 2.0 + cropOffset.x, y: dimensions.height / 2.0 + cropOffset.y * -1.0)
transform = transform.rotated(by: -cropRotation)
transform = transform.scaledBy(x: cropScale * initialScale, y: cropScale * initialScale)
maskImage = maskImage.transformed(by: transform)
}
} else {
filter.inputImage = inputImage
dimensions = inputImage.extent.size
}
filter.maskImage = maskImage
if let output = filter.outputImage, let cgImage = ciContext.createCGImage(output, from: CGRect(origin: .zero, size: dimensions)) {
extractedImage = .image(UIImage(cgImage: cgImage), output)
} else {
extractedImage = nil
}
} else {
extractedImage = nil
}
let whiteImage = CIImage(color: .white)
let blackImage = CIImage(color: .black)
let maskFilter = CIFilter.blendWithMask()
maskFilter.inputImage = whiteImage
maskFilter.backgroundImage = blackImage
maskFilter.maskImage = originalMaskImage
let refinedMaskFilter = CIFilter.blendWithMask()
refinedMaskFilter.inputImage = whiteImage
refinedMaskFilter.backgroundImage = blackImage
refinedMaskFilter.maskImage = refineEdges(originalMaskImage)
let edgesMaskImage: CutoutResult.Image?
let maskImage: CutoutResult.Image?
if let maskOutput = maskFilter.outputImage?.cropped(to: inputImage.extent),
let maskCgImage = ciContext.createCGImage(maskOutput, from: inputImage.extent),
let refinedMaskOutput = refinedMaskFilter.outputImage?.cropped(to: inputImage.extent),
let refinedMaskCgImage = ciContext.createCGImage(refinedMaskOutput, from: inputImage.extent) {
edgesMaskImage = .image(UIImage(cgImage: maskCgImage), maskOutput)
maskImage = .image(UIImage(cgImage: refinedMaskCgImage), refinedMaskOutput)
} else {
edgesMaskImage = nil
maskImage = nil
}
if extractedImage != nil || maskImage != nil {
results.append(CutoutResult(index: instance, extractedImage: extractedImage, edgesMaskImage: edgesMaskImage, maskImage: maskImage, backgroundImage: nil))
}
}
if #available(iOS 17.0, *), !forceCoreMLVariant {
queue.async {
let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
let request = VNGenerateForegroundInstanceMaskRequest { [weak handler] request, error in
guard let handler, let result = request.results?.first as? VNInstanceMaskObservation else {
completion([])
return
}
let targetInstances: IndexSet
switch target {
case let .point(point):
targetInstances = instances(atPoint: point, inObservation: result)
case let .index(index):
targetInstances = IndexSet([index])
case .all:
targetInstances = result.allInstances
}
for instance in targetInstances {
if let mask = try? result.generateScaledMaskForImage(forInstances: IndexSet(integer: instance), from: handler) {
process(instance: instance, mask: CIImage(cvPixelBuffer: mask))
}
}
completion(results)
}
try? handler.perform([request])
}
} else {
U2netp.load(contentsOf: URL(fileURLWithPath: modelPath()), completionHandler: { result in
switch result {
case let .success(model):
let modelImageSize = CGSize(width: 320, height: 320)
if let squareImage = scaleImageToPixelSize(image: image, size: modelImageSize),
let pixelBuffer = buffer(from: squareImage),
let result = try? model.prediction(in_0: pixelBuffer),
let maskImage = UIImage(pixelBuffer: result.out_p1),
let scaledMaskImage = scaleImageToPixelSize(image: maskImage, size: image.size),
let ciImage = CIImage(image: scaledMaskImage) {
process(instance: 0, mask: ciImage)
}
case .failure:
break
}
completion(results)
})
}
}
@available(iOS 17.0, *)
private func instances(atPoint maybePoint: CGPoint?, inObservation observation: VNInstanceMaskObservation) -> IndexSet {
guard let point = maybePoint else {
return observation.allInstances
}
let instanceMap = observation.instanceMask
let coords = VNImagePointForNormalizedPoint(point, CVPixelBufferGetWidth(instanceMap) - 1, CVPixelBufferGetHeight(instanceMap) - 1)
CVPixelBufferLockBaseAddress(instanceMap, .readOnly)
guard let pixels = CVPixelBufferGetBaseAddress(instanceMap) else {
fatalError()
}
let bytesPerRow = CVPixelBufferGetBytesPerRow(instanceMap)
let instanceLabel = pixels.load(fromByteOffset: Int(coords.y) * bytesPerRow + Int(coords.x), as: UInt8.self)
CVPixelBufferUnlockBaseAddress(instanceMap, .readOnly)
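// A label of 0 marks background; a tap that hits no instance falls back to selecting all instances.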
return instanceLabel == 0 ? observation.allInstances : [Int(instanceLabel)]
}
private extension UIImage {
convenience init?(pixelBuffer: CVPixelBuffer) {
var cgImage: CGImage?
VTCreateCGImageFromCVPixelBuffer(pixelBuffer, options: nil, imageOut: &cgImage)
guard let cgImage = cgImage else {
return nil
}
self.init(cgImage: cgImage)
}
}
private func scaleImageToPixelSize(image: UIImage, size: CGSize) -> UIImage? {
UIGraphicsBeginImageContextWithOptions(size, true, 1.0)
image.draw(in: CGRect(origin: CGPoint(), size: size), blendMode: .copy, alpha: 1.0)
let result = UIGraphicsGetImageFromCurrentImageContext()
UIGraphicsEndImageContext()
return result
}
private func buffer(from image: UIImage) -> CVPixelBuffer? {
let attrs = [kCVPixelBufferCGImageCompatibilityKey: kCFBooleanTrue, kCVPixelBufferCGBitmapContextCompatibilityKey: kCFBooleanTrue] as CFDictionary
var pixelBuffer : CVPixelBuffer?
let status = CVPixelBufferCreate(kCFAllocatorDefault, Int(image.size.width), Int(image.size.height), kCVPixelFormatType_32ARGB, attrs, &pixelBuffer)
guard (status == kCVReturnSuccess) else {
return nil
}
guard let pixelBufferUnwrapped = pixelBuffer else {
return nil
}
CVPixelBufferLockBaseAddress(pixelBufferUnwrapped, CVPixelBufferLockFlags(rawValue: 0))
let pixelData = CVPixelBufferGetBaseAddress(pixelBufferUnwrapped)
let rgbColorSpace = CGColorSpaceCreateDeviceRGB()
guard let context = CGContext(data: pixelData, width: Int(image.size.width), height: Int(image.size.height), bitsPerComponent: 8, bytesPerRow: CVPixelBufferGetBytesPerRow(pixelBufferUnwrapped), space: rgbColorSpace, bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue) else {
return nil
}
context.translateBy(x: 0, y: image.size.height)
context.scaleBy(x: 1.0, y: -1.0)
UIGraphicsPushContext(context)
image.draw(in: CGRect(x: 0, y: 0, width: image.size.width, height: image.size.height))
UIGraphicsPopContext()
CVPixelBufferUnlockBaseAddress(pixelBufferUnwrapped, CVPixelBufferLockFlags(rawValue: 0))
return pixelBufferUnwrapped
}

View File

@@ -0,0 +1,252 @@
import CoreML
/// Model Prediction Input Type
@available(macOS 13.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *)
class U2netpInput : MLFeatureProvider {
/// in_0 as color (kCVPixelFormatType_32BGRA) image buffer, 320 pixels wide by 320 pixels high
var in_0: CVPixelBuffer
var featureNames: Set<String> {
get {
return ["in_0"]
}
}
func featureValue(for featureName: String) -> MLFeatureValue? {
if (featureName == "in_0") {
return MLFeatureValue(pixelBuffer: in_0)
}
return nil
}
init(in_0: CVPixelBuffer) {
self.in_0 = in_0
}
convenience init(in_0With in_0: CGImage) throws {
self.init(in_0: try MLFeatureValue(cgImage: in_0, pixelsWide: 320, pixelsHigh: 320, pixelFormatType: kCVPixelFormatType_32ARGB, options: nil).imageBufferValue!)
}
convenience init(in_0At in_0: URL) throws {
self.init(in_0: try MLFeatureValue(imageAt: in_0, pixelsWide: 320, pixelsHigh: 320, pixelFormatType: kCVPixelFormatType_32ARGB, options: nil).imageBufferValue!)
}
func setIn_0(with in_0: CGImage) throws {
self.in_0 = try MLFeatureValue(cgImage: in_0, pixelsWide: 320, pixelsHigh: 320, pixelFormatType: kCVPixelFormatType_32ARGB, options: nil).imageBufferValue!
}
func setIn_0(with in_0: URL) throws {
self.in_0 = try MLFeatureValue(imageAt: in_0, pixelsWide: 320, pixelsHigh: 320, pixelFormatType: kCVPixelFormatType_32ARGB, options: nil).imageBufferValue!
}
}
/// Model Prediction Output Type
@available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *)
class U2netpOutput : MLFeatureProvider {
/// Source provided by CoreML
private let provider : MLFeatureProvider
/// out_p0 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
lazy var out_p0: CVPixelBuffer = {
[unowned self] in return self.provider.featureValue(for: "out_p0")!.imageBufferValue
}()!
/// out_p1 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
lazy var out_p1: CVPixelBuffer = {
[unowned self] in return self.provider.featureValue(for: "out_p1")!.imageBufferValue
}()!
/// out_p2 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
lazy var out_p2: CVPixelBuffer = {
[unowned self] in return self.provider.featureValue(for: "out_p2")!.imageBufferValue
}()!
/// out_p3 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
lazy var out_p3: CVPixelBuffer = {
[unowned self] in return self.provider.featureValue(for: "out_p3")!.imageBufferValue
}()!
/// out_p4 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
lazy var out_p4: CVPixelBuffer = {
[unowned self] in return self.provider.featureValue(for: "out_p4")!.imageBufferValue
}()!
/// out_p5 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
lazy var out_p5: CVPixelBuffer = {
[unowned self] in return self.provider.featureValue(for: "out_p5")!.imageBufferValue
}()!
/// out_p6 as grayscale (kCVPixelFormatType_OneComponent8) image buffer, 320 pixels wide by 320 pixels high
lazy var out_p6: CVPixelBuffer = {
[unowned self] in return self.provider.featureValue(for: "out_p6")!.imageBufferValue
}()!
var featureNames: Set<String> {
return self.provider.featureNames
}
func featureValue(for featureName: String) -> MLFeatureValue? {
return self.provider.featureValue(for: featureName)
}
init(out_p0: CVPixelBuffer, out_p1: CVPixelBuffer, out_p2: CVPixelBuffer, out_p3: CVPixelBuffer, out_p4: CVPixelBuffer, out_p5: CVPixelBuffer, out_p6: CVPixelBuffer) {
self.provider = try! MLDictionaryFeatureProvider(dictionary: ["out_p0" : MLFeatureValue(pixelBuffer: out_p0), "out_p1" : MLFeatureValue(pixelBuffer: out_p1), "out_p2" : MLFeatureValue(pixelBuffer: out_p2), "out_p3" : MLFeatureValue(pixelBuffer: out_p3), "out_p4" : MLFeatureValue(pixelBuffer: out_p4), "out_p5" : MLFeatureValue(pixelBuffer: out_p5), "out_p6" : MLFeatureValue(pixelBuffer: out_p6)])
}
init(features: MLFeatureProvider) {
self.provider = features
}
}
/// Class for model loading and prediction
@available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *)
class U2netp {
let model: MLModel
/**
Construct U2netp instance with an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of U2netp.
Such an application may want to use `MLModel(contentsOfURL:configuration:)` and `U2netp.urlOfModelInThisBundle` to create an MLModel object to pass in.
- parameters:
- model: MLModel object
*/
init(model: MLModel) {
self.model = model
}
/**
Construct U2netp instance with explicit path to mlmodelc file
- parameters:
- modelURL: the file url of the model
- throws: an NSError object that describes the problem
*/
convenience init(contentsOf modelURL: URL) throws {
try self.init(model: MLModel(contentsOf: modelURL))
}
/**
Construct a model with URL of the .mlmodelc directory and configuration
- parameters:
- modelURL: the file url of the model
- configuration: the desired model configuration
- throws: an NSError object that describes the problem
*/
convenience init(contentsOf modelURL: URL, configuration: MLModelConfiguration) throws {
try self.init(model: MLModel(contentsOf: modelURL, configuration: configuration))
}
/**
Construct U2netp instance asynchronously with URL of the .mlmodelc directory with optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
- parameters:
- modelURL: the URL to the model
- configuration: the desired model configuration
- handler: the completion handler to be called when the model loading completes successfully or unsuccessfully
*/
class func load(contentsOf modelURL: URL, configuration: MLModelConfiguration = MLModelConfiguration(), completionHandler handler: @escaping (Swift.Result<U2netp, Error>) -> Void) {
MLModel.load(contentsOf: modelURL, configuration: configuration) { result in
switch result {
case .failure(let error):
handler(.failure(error))
case .success(let model):
handler(.success(U2netp(model: model)))
}
}
}
/**
Construct U2netp instance asynchronously with URL of the .mlmodelc directory with optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
- parameters:
- modelURL: the URL to the model
- configuration: the desired model configuration
*/
@available(macOS 12.0, iOS 15.0, tvOS 15.0, watchOS 8.0, *)
class func load(contentsOf modelURL: URL, configuration: MLModelConfiguration = MLModelConfiguration()) async throws -> U2netp {
let model = try await MLModel.load(contentsOf: modelURL, configuration: configuration)
return U2netp(model: model)
}
/**
Make a prediction using the structured interface
- parameters:
- input: the input to the prediction as U2netpInput
- throws: an NSError object that describes the problem
- returns: the result of the prediction as U2netpOutput
*/
func prediction(input: U2netpInput) throws -> U2netpOutput {
return try self.prediction(input: input, options: MLPredictionOptions())
}
/**
Make a prediction using the structured interface
- parameters:
- input: the input to the prediction as U2netpInput
- options: prediction options
- throws: an NSError object that describes the problem
- returns: the result of the prediction as U2netpOutput
*/
func prediction(input: U2netpInput, options: MLPredictionOptions) throws -> U2netpOutput {
let outFeatures = try model.prediction(from: input, options:options)
return U2netpOutput(features: outFeatures)
}
/**
Make a prediction using the convenience interface
- parameters:
- in_0 as color (kCVPixelFormatType_32BGRA) image buffer, 320 pixels wide by 320 pixels high
- throws: an NSError object that describes the problem
- returns: the result of the prediction as U2netpOutput
*/
func prediction(in_0: CVPixelBuffer) throws -> U2netpOutput {
let input_ = U2netpInput(in_0: in_0)
return try self.prediction(input: input_)
}
/**
Make a batch prediction using the structured interface
- parameters:
- inputs: the inputs to the prediction as [U2netpInput]
- options: prediction options
- throws: an NSError object that describes the problem
- returns: the result of the prediction as [U2netpOutput]
*/
func predictions(inputs: [U2netpInput], options: MLPredictionOptions = MLPredictionOptions()) throws -> [U2netpOutput] {
let batchIn = MLArrayBatchProvider(array: inputs)
let batchOut = try model.predictions(from: batchIn, options: options)
var results : [U2netpOutput] = []
results.reserveCapacity(inputs.count)
for i in 0..<batchOut.count {
let outProvider = batchOut.features(at: i)
let result = U2netpOutput(features: outProvider)
results.append(result)
}
return results
}
}
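// Editor's illustrative sketch (not part of the commit): loading the compiled model
// and reading the out_p1 mask, mirroring the call sites above. `inputBuffer` is a
// hypothetical 320x320 pixel buffer prepared the way buffer(from:) does.
private func runSaliencyModel(on inputBuffer: CVPixelBuffer, completion: @escaping (CVPixelBuffer?) -> Void) {
    U2netp.load(contentsOf: URL(fileURLWithPath: NSTemporaryDirectory() + "u2netp.mlmodelc"), completionHandler: { result in
        if case let .success(model) = result, let output = try? model.prediction(in_0: inputBuffer) {
            completion(output.out_p1) // 320x320 grayscale foreground mask
        } else {
            completion(nil)
        }
    })
}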

View File

@@ -1,245 +0,0 @@
import Foundation
import UIKit
import Display
import Vision
import CoreImage
import CoreImage.CIFilterBuiltins
import SwiftSignalKit
import VideoToolbox
private let queue = Queue()
public func cutoutStickerImage(from image: UIImage, onlyCheck: Bool = false) -> Signal<UIImage?, NoError> {
if #available(iOS 17.0, *) {
guard let cgImage = image.cgImage else {
return .single(nil)
}
return Signal { subscriber in
let ciContext = CIContext(options: nil)
let inputImage = CIImage(cgImage: cgImage)
let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
let request = VNGenerateForegroundInstanceMaskRequest { [weak handler] request, error in
guard let handler, let result = request.results?.first as? VNInstanceMaskObservation else {
subscriber.putNext(nil)
subscriber.putCompletion()
return
}
if onlyCheck {
subscriber.putNext(UIImage())
subscriber.putCompletion()
} else {
let instances = instances(atPoint: nil, inObservation: result)
if let mask = try? result.generateScaledMaskForImage(forInstances: instances, from: handler) {
let filter = CIFilter.blendWithMask()
filter.inputImage = inputImage
filter.backgroundImage = CIImage(color: .clear)
filter.maskImage = CIImage(cvPixelBuffer: mask)
if let output = filter.outputImage, let cgImage = ciContext.createCGImage(output, from: inputImage.extent) {
let image = UIImage(cgImage: cgImage)
subscriber.putNext(image)
subscriber.putCompletion()
return
}
}
subscriber.putNext(nil)
subscriber.putCompletion()
}
}
try? handler.perform([request])
return ActionDisposable {
request.cancel()
}
}
|> runOn(queue)
} else {
return .single(nil)
}
}
public struct CutoutResult {
public enum Image {
case image(UIImage, CIImage)
case pixelBuffer(CVPixelBuffer)
}
public let index: Int
public let extractedImage: Image?
public let edgesMaskImage: Image?
public let maskImage: Image?
public let backgroundImage: Image?
}
public enum CutoutTarget {
case point(CGPoint?)
case index(Int)
case all
}
func refineEdges(_ maskImage: CIImage) -> CIImage? {
let maskImage = maskImage.clampedToExtent()
let blurFilter = CIFilter(name: "CIGaussianBlur")!
blurFilter.setValue(maskImage, forKey: kCIInputImageKey)
blurFilter.setValue(11.4, forKey: kCIInputRadiusKey)
let controlsFilter = CIFilter(name: "CIColorControls")!
controlsFilter.setValue(blurFilter.outputImage, forKey: kCIInputImageKey)
controlsFilter.setValue(6.61, forKey: kCIInputContrastKey)
let sharpenFilter = CIFilter(name: "CISharpenLuminance")!
sharpenFilter.setValue(controlsFilter.outputImage, forKey: kCIInputImageKey)
sharpenFilter.setValue(250.0, forKey: kCIInputSharpnessKey)
return sharpenFilter.outputImage?.cropped(to: maskImage.extent)
}
public func cutoutImage(
from image: UIImage,
editedImage: UIImage? = nil,
values: MediaEditorValues?,
target: CutoutTarget,
includeExtracted: Bool = true,
completion: @escaping ([CutoutResult]) -> Void
) {
if #available(iOS 17.0, *), let cgImage = image.cgImage {
let ciContext = CIContext(options: nil)
let inputImage = CIImage(cgImage: cgImage)
queue.async {
let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
let request = VNGenerateForegroundInstanceMaskRequest { [weak handler] request, error in
guard let handler, let result = request.results?.first as? VNInstanceMaskObservation else {
completion([])
return
}
let targetInstances: IndexSet
switch target {
case let .point(point):
targetInstances = instances(atPoint: point, inObservation: result)
case let .index(index):
targetInstances = IndexSet([index])
case .all:
targetInstances = result.allInstances
}
var results: [CutoutResult] = []
for instance in targetInstances {
if let mask = try? result.generateScaledMaskForImage(forInstances: IndexSet(integer: instance), from: handler) {
let extractedImage: CutoutResult.Image?
if includeExtracted {
let filter = CIFilter.blendWithMask()
filter.backgroundImage = CIImage(color: .clear)
let dimensions: CGSize
var maskImage = CIImage(cvPixelBuffer: mask)
if let editedImage = editedImage?.cgImage.flatMap({ CIImage(cgImage: $0) }) {
filter.inputImage = editedImage
dimensions = editedImage.extent.size
if let values {
let initialScale: CGFloat
if maskImage.extent.height > maskImage.extent.width {
initialScale = dimensions.width / maskImage.extent.width
} else {
initialScale = dimensions.width / maskImage.extent.height
}
let dimensions = editedImage.extent.size
maskImage = maskImage.transformed(by: CGAffineTransform(translationX: -maskImage.extent.width / 2.0, y: -maskImage.extent.height / 2.0))
var transform = CGAffineTransform.identity
let position = values.cropOffset
let rotation = values.cropRotation
let scale = values.cropScale
transform = transform.translatedBy(x: dimensions.width / 2.0 + position.x, y: dimensions.height / 2.0 + position.y * -1.0)
transform = transform.rotated(by: -rotation)
transform = transform.scaledBy(x: scale * initialScale, y: scale * initialScale)
maskImage = maskImage.transformed(by: transform)
}
} else {
filter.inputImage = inputImage
dimensions = inputImage.extent.size
}
filter.maskImage = maskImage
if let output = filter.outputImage, let cgImage = ciContext.createCGImage(output, from: CGRect(origin: .zero, size: dimensions)) {
extractedImage = .image(UIImage(cgImage: cgImage), output)
} else {
extractedImage = nil
}
} else {
extractedImage = nil
}
let whiteImage = CIImage(color: .white)
let blackImage = CIImage(color: .black)
let maskFilter = CIFilter.blendWithMask()
maskFilter.inputImage = whiteImage
maskFilter.backgroundImage = blackImage
maskFilter.maskImage = CIImage(cvPixelBuffer: mask)
let refinedMaskFilter = CIFilter.blendWithMask()
refinedMaskFilter.inputImage = whiteImage
refinedMaskFilter.backgroundImage = blackImage
refinedMaskFilter.maskImage = refineEdges(CIImage(cvPixelBuffer: mask))
let edgesMaskImage: CutoutResult.Image?
let maskImage: CutoutResult.Image?
if let maskOutput = maskFilter.outputImage?.cropped(to: inputImage.extent), let maskCgImage = ciContext.createCGImage(maskOutput, from: inputImage.extent), let refinedMaskOutput = refinedMaskFilter.outputImage?.cropped(to: inputImage.extent), let refinedMaskCgImage = ciContext.createCGImage(refinedMaskOutput, from: inputImage.extent) {
edgesMaskImage = .image(UIImage(cgImage: maskCgImage), maskOutput)
maskImage = .image(UIImage(cgImage: refinedMaskCgImage), refinedMaskOutput)
} else {
edgesMaskImage = nil
maskImage = nil
}
if extractedImage != nil || maskImage != nil {
results.append(CutoutResult(index: instance, extractedImage: extractedImage, edgesMaskImage: edgesMaskImage, maskImage: maskImage, backgroundImage: nil))
}
}
}
completion(results)
}
try? handler.perform([request])
}
} else {
completion([])
}
}
@available(iOS 17.0, *)
private func instances(atPoint maybePoint: CGPoint?, inObservation observation: VNInstanceMaskObservation) -> IndexSet {
guard let point = maybePoint else {
return observation.allInstances
}
let instanceMap = observation.instanceMask
let coords = VNImagePointForNormalizedPoint(point, CVPixelBufferGetWidth(instanceMap) - 1, CVPixelBufferGetHeight(instanceMap) - 1)
CVPixelBufferLockBaseAddress(instanceMap, .readOnly)
guard let pixels = CVPixelBufferGetBaseAddress(instanceMap) else {
fatalError()
}
let bytesPerRow = CVPixelBufferGetBytesPerRow(instanceMap)
let instanceLabel = pixels.load(fromByteOffset: Int(coords.y) * bytesPerRow + Int(coords.x), as: UInt8.self)
CVPixelBufferUnlockBaseAddress(instanceMap, .readOnly)
return instanceLabel == 0 ? observation.allInstances : [Int(instanceLabel)]
}
private extension UIImage {
convenience init?(pixelBuffer: CVPixelBuffer) {
var cgImage: CGImage?
VTCreateCGImageFromCVPixelBuffer(pixelBuffer, options: nil, imageOut: &cgImage)
guard let cgImage = cgImage else {
return nil
}
self.init(cgImage: cgImage)
}
}

View File

@@ -12,6 +12,7 @@ import TelegramPresentationData
import FastBlur
import AccountContext
import ImageTransparency
import ImageObjectSeparation
public struct MediaEditorPlayerState: Equatable {
public struct Track: Equatable {
@@ -190,8 +191,27 @@ public final class MediaEditor {
}
}
public private(set) var canCutout: Bool = false
public var canCutoutUpdated: (Bool, Bool) -> Void = { _, _ in }
public enum CutoutStatus: Equatable {
public enum Availability: Equatable {
case available
case preparing(progress: Float)
case unavailable
}
case unknown
case known(canCutout: Bool, availability: Availability, hasTransparency: Bool)
}
private let cutoutDisposable = MetaDisposable()
private var cutoutStatusValue: CutoutStatus = .unknown {
didSet {
self.cutoutStatusPromise.set(self.cutoutStatusValue)
}
}
private let cutoutStatusPromise = ValuePromise<CutoutStatus>(.unknown)
public var cutoutStatus: Signal<CutoutStatus, NoError> {
return self.cutoutStatusPromise.get()
}
public var maskUpdated: (UIImage, Bool) -> Void = { _, _ in }
public var classificationUpdated: ([(String, Float)]) -> Void = { _ in }
@@ -482,6 +502,7 @@ public final class MediaEditor {
}
deinit {
self.cutoutDisposable.dispose()
self.textureSourceDisposable?.dispose()
self.invalidateTimeObservers()
}
@@ -726,19 +747,29 @@ public final class MediaEditor {
if case .sticker = self.mode {
if !imageHasTransparency(image) {
let _ = (cutoutStickerImage(from: image, onlyCheck: true)
|> deliverOnMainQueue).start(next: { [weak self] result in
self.cutoutDisposable.set((cutoutAvailability(context: self.context)
|> mapToSignal { availability -> Signal<MediaEditor.CutoutStatus, NoError> in
switch availability {
case .available:
return cutoutStickerImage(from: image, context: context, onlyCheck: true)
|> map { result in
return .known(canCutout: result != nil, availability: .available, hasTransparency: false)
}
case let .progress(progress):
return .single(.known(canCutout: false, availability: .preparing(progress: progress), hasTransparency: false))
case .unavailable:
return .single(.known(canCutout: false, availability: .unavailable, hasTransparency: false))
}
}
|> deliverOnMainQueue).start(next: { [weak self] status in
guard let self else {
return
}
let canCutout = result != nil
self.canCutout = canCutout
self.canCutoutUpdated(canCutout, false)
})
self.cutoutStatusValue = status
}))
self.maskUpdated(image, false)
} else {
self.canCutout = false
self.canCutoutUpdated(false, true)
self.cutoutStatusValue = .known(canCutout: false, availability: .unavailable, hasTransparency: true)
if let maskImage = generateTintedImage(image: image, color: .white, backgroundColor: .black) {
self.maskUpdated(maskImage, true)

View File

@@ -448,6 +448,10 @@ public final class MediaEditorValues: Codable, Equatable {
return self.qualityPreset == .sticker
}
public var cropValues: (offset: CGPoint, rotation: CGFloat, scale: CGFloat) {
return (self.cropOffset, self.cropRotation, self.cropScale)
}
public init(
peerId: EnginePeer.Id,
originalDimensions: PixelDimensions,

View File

@@ -58,6 +58,7 @@ swift_library(
"//submodules/TelegramUI/Components/Stickers/StickerPackEditTitleController",
"//submodules/TelegramUI/Components/StickerPickerScreen",
"//submodules/UIKitRuntimeUtils",
"//submodules/TelegramUI/Components/MediaEditor/ImageObjectSeparation",
],
visibility = [
"//visibility:public",

View File

@@ -17,6 +17,7 @@ import LottieAnimationComponent
import MessageInputPanelComponent
import DustEffect
import PlainButtonComponent
import ImageObjectSeparation
private final class MediaCutoutScreenComponent: Component {
typealias EnvironmentType = ViewControllerComponentContainer.Environment
@@ -118,7 +119,7 @@ private final class MediaCutoutScreenComponent: Component {
}
component.mediaEditor.processImage { [weak self] originalImage, _ in
cutoutImage(from: originalImage, values: nil, target: .point(point), includeExtracted: false, completion: { [weak self] results in
cutoutImage(from: originalImage, crop: nil, target: .point(point), includeExtracted: false, completion: { [weak self] results in
Queue.mainQueue().async {
if let self, let _ = self.component, let result = results.first, let maskImage = result.maskImage, let controller = self.environment?.controller() as? MediaCutoutScreen {
if case let .image(mask, _) = maskImage {
@@ -427,7 +428,7 @@ private final class MediaCutoutScreenComponent: Component {
if isFirstTime {
let values = component.mediaEditor.values
component.mediaEditor.processImage { originalImage, editedImage in
cutoutImage(from: originalImage, editedImage: editedImage, values: values, target: .all, completion: { results in
cutoutImage(from: originalImage, editedImage: editedImage, crop: values.cropValues, target: .all, completion: { results in
Queue.mainQueue().async {
if !results.isEmpty {
for result in results {

View File

@@ -47,6 +47,7 @@ import StickerPeekUI
import StickerPackEditTitleController
import StickerPickerScreen
import UIKitRuntimeUtils
import ImageObjectSeparation
private let playbackButtonTag = GenericComponentViewTag()
private let muteButtonTag = GenericComponentViewTag()
@@ -2008,7 +2009,7 @@ final class MediaEditorScreenComponent: Component {
if let subject = controller.node.subject, case .empty = subject {
} else if let canCutout = controller.node.canCutout {
} else if case let .known(canCutout, _, hasTransparency) = controller.node.stickerCutoutStatus {
if controller.node.isCutout || controller.node.stickerMaskDrawingView?.internalState.canUndo == true {
hasUndoButton = true
}
@@ -2020,7 +2021,7 @@ final class MediaEditorScreenComponent: Component {
hasRestoreButton = true
}
}
if hasUndoButton || controller.node.hasTransparency {
if hasUndoButton || hasTransparency {
hasOutlineButton = true
}
}
@@ -2537,8 +2538,8 @@ public final class MediaEditorScreen: ViewController, UIDropInteractionDelegate
private var isDismissed = false
private var isDismissBySwipeSuppressed = false
fileprivate var canCutout: Bool?
fileprivate var hasTransparency = false
fileprivate var stickerCutoutStatus: MediaEditor.CutoutStatus = .unknown
private var stickerCutoutStatusDisposable: Disposable?
fileprivate var isCutout = false
private (set) var hasAnyChanges = false
@@ -2807,6 +2808,7 @@ public final class MediaEditorScreen: ViewController, UIDropInteractionDelegate
self.appInForegroundDisposable?.dispose()
self.playbackPositionDisposable?.dispose()
self.availableReactionsDisposable?.dispose()
self.stickerCutoutStatusDisposable?.dispose()
}
private func setup(with subject: MediaEditorScreen.Subject) {
@@ -2939,15 +2941,15 @@ public final class MediaEditorScreen: ViewController, UIDropInteractionDelegate
}
controller.requestLayout(transition: .animated(duration: 0.25, curve: .easeInOut))
}
}
mediaEditor.canCutoutUpdated = { [weak self] canCutout, hasTransparency in
}
self.stickerCutoutStatusDisposable = (mediaEditor.cutoutStatus
|> deliverOnMainQueue).start(next: { [weak self] cutoutStatus in
guard let self else {
return
}
self.canCutout = canCutout
self.hasTransparency = hasTransparency
self.stickerCutoutStatus = cutoutStatus
self.requestLayout(forceUpdate: true, transition: .easeInOut(duration: 0.25))
}
})
mediaEditor.maskUpdated = { [weak self] mask, apply in
guard let self else {
return
@@ -4922,8 +4924,8 @@ public final class MediaEditorScreen: ViewController, UIDropInteractionDelegate
}
}
},
cutoutUndo: { [weak self, weak controller] in
if let self, let controller, let mediaEditor = self.mediaEditor, let stickerMaskDrawingView = self.stickerMaskDrawingView {
cutoutUndo: { [weak self] in
if let self, let mediaEditor = self.mediaEditor, let stickerMaskDrawingView = self.stickerMaskDrawingView {
if self.entitiesView.hasSelection {
self.entitiesView.selectEntity(nil)
}
@@ -4934,12 +4936,12 @@ public final class MediaEditorScreen: ViewController, UIDropInteractionDelegate
mediaEditor.setSegmentationMask(drawingImage)
}
if self.isDisplayingTool == .cutoutRestore && !stickerMaskDrawingView.internalState.canUndo && !controller.node.isCutout {
if self.isDisplayingTool == .cutoutRestore && !stickerMaskDrawingView.internalState.canUndo && !self.isCutout {
self.cutoutScreen?.mode = .erase
self.isDisplayingTool = .cutoutErase
self.requestLayout(forceUpdate: true, transition: .easeInOut(duration: 0.25))
}
} else if controller.node.isCutout {
} else if self.isCutout {
let action = { [weak self, weak mediaEditor] in
guard let self, let mediaEditor else {
return