2023-05-26 17:32:02 +04:00

285 lines
12 KiB
Swift

import Foundation
import UIKit
import Metal
import MetalPerformanceShaders
import simd
import CoreImage
struct TextureSize {
let width: Int
let height: Int
}
private final class EnhanceLightnessPass: DefaultRenderPass {
fileprivate var cachedTexture: MTLTexture?
override var fragmentShaderFunctionName: String {
return "rgbToLightnessFragmentShader"
}
override var pixelFormat: MTLPixelFormat {
return .r8Unorm
}
func process(input: MTLTexture, size: TextureSize, scale: simd_float2, device: MTLDevice, commandBuffer: MTLCommandBuffer) -> MTLTexture? {
self.setupVerticesBuffer(device: device)
let width = size.width
let height = size.height
if self.cachedTexture == nil {
let textureDescriptor = MTLTextureDescriptor()
textureDescriptor.textureType = .type2D
textureDescriptor.width = width
textureDescriptor.height = height
textureDescriptor.pixelFormat = .r8Unorm
textureDescriptor.storageMode = .private
textureDescriptor.usage = [.shaderRead, .shaderWrite, .renderTarget]
guard let texture = device.makeTexture(descriptor: textureDescriptor) else {
return nil
}
texture.label = "lightnessTexture"
self.cachedTexture = texture
}
let renderPassDescriptor = MTLRenderPassDescriptor()
renderPassDescriptor.colorAttachments[0].texture = self.cachedTexture!
renderPassDescriptor.colorAttachments[0].loadAction = .dontCare
renderPassDescriptor.colorAttachments[0].storeAction = .store
renderPassDescriptor.colorAttachments[0].clearColor = MTLClearColor(red: 0.0, green: 0.0, blue: 0.0, alpha: 0.0)
guard let renderCommandEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: renderPassDescriptor) else {
return nil
}
renderCommandEncoder.setViewport(MTLViewport(
originX: 0, originY: 0,
width: Double(size.width), height: Double(size.height),
znear: -1.0, zfar: 1.0)
)
let samplerDescriptor = MTLSamplerDescriptor()
samplerDescriptor.minFilter = .linear
samplerDescriptor.magFilter = .linear
samplerDescriptor.sAddressMode = .mirrorRepeat
samplerDescriptor.tAddressMode = .mirrorRepeat
samplerDescriptor.rAddressMode = .mirrorRepeat
guard let samplerState = device.makeSamplerState(descriptor: samplerDescriptor) else {
return nil
}
var scale = scale
renderCommandEncoder.setFragmentTexture(input, index: 0)
renderCommandEncoder.setFragmentBytes(&scale, length: MemoryLayout<simd_float2>.size, index: 0)
renderCommandEncoder.setFragmentSamplerState(samplerState, index: 0)
self.encodeDefaultCommands(using: renderCommandEncoder)
renderCommandEncoder.endEncoding()
return self.cachedTexture!
}
}
private let binCount = 256
struct MediaEditorEnhanceLUTGeneratorParameters {
var histogramBins: simd_uint1
var clipLimit: simd_uint1
var totalPixelCountPerTile: simd_uint1
var numberOfLUTs: simd_uint1
}
private final class EnhanceLUTGeneratorPass: RenderPass {
fileprivate var pipelineState: MTLComputePipelineState?
fileprivate var histogramBuffer: MTLBuffer?
fileprivate var calculation: MPSImageHistogram?
private var lutTexture: MTLTexture?
func setup(device: MTLDevice, library: MTLLibrary) {
}
func setup(gridSize: TextureSize, device: MTLDevice, library: MTLLibrary) {
var histogramInfo = MPSImageHistogramInfo(
numberOfHistogramEntries: binCount,
histogramForAlpha: false,
minPixelValue: vector_float4(0,0,0,0),
maxPixelValue: vector_float4(1,1,1,1)
)
let calculation = MPSImageHistogram(device: device, histogramInfo: &histogramInfo)
calculation.zeroHistogram = false
self.calculation = calculation
let pipelineDescriptor = MTLComputePipelineDescriptor()
pipelineDescriptor.computeFunction = library.makeFunction(name: "enhanceGenerateLUT")
do {
self.pipelineState = try device.makeComputePipelineState(descriptor: pipelineDescriptor, options: .argumentInfo, reflection: nil)
} catch {
print(error.localizedDescription)
}
}
func process(input: MTLTexture, device: MTLDevice, commandBuffer: MTLCommandBuffer) -> MTLTexture? {
return nil
}
func process(input: MTLTexture, gridSize: TextureSize, clipLimit: Float, device: MTLDevice, commandBuffer: MTLCommandBuffer) -> MTLTexture? {
let lutCount = gridSize.width * gridSize.height
let tileSize = TextureSize(width: input.width / gridSize.width, height: input.height / gridSize.height);
let clipLimitValue = max(1, clipLimit * Float(tileSize.width * tileSize.height) / Float(binCount))
if self.lutTexture == nil {
let textureDescriptor = MTLTextureDescriptor()
textureDescriptor.textureType = .type2D
textureDescriptor.width = binCount
textureDescriptor.height = lutCount
textureDescriptor.pixelFormat = .r8Unorm
textureDescriptor.storageMode = .private
textureDescriptor.usage = [.shaderRead, .shaderWrite, .renderTarget]
guard let texture = device.makeTexture(descriptor: textureDescriptor) else {
return nil
}
self.lutTexture = texture
texture.label = "lutTexture"
}
guard let calculation = self.calculation, let histogramBuffer = device.makeBuffer(length: calculation.histogramSize(forSourceFormat: .r8Unorm) * lutCount, options: [.storageModePrivate]) else {
return nil
}
let histogramSize = calculation.histogramSize(forSourceFormat: input.pixelFormat)
for i in 0 ..< lutCount {
let col = i % gridSize.width
let row = i / gridSize.width
calculation.clipRectSource = MTLRegionMake2D(col * tileSize.width, row * tileSize.height, tileSize.width, tileSize.height)
calculation.encode(to: commandBuffer, sourceTexture: input, histogram: histogramBuffer, histogramOffset: i * histogramSize)
}
guard let computeCommandEncoder = commandBuffer.makeComputeCommandEncoder() else {
return nil
}
guard let pipelineState = self.pipelineState else {
return nil
}
var parameters = MediaEditorEnhanceLUTGeneratorParameters(
histogramBins: UInt32(binCount),
clipLimit: UInt32(clipLimitValue),
totalPixelCountPerTile: UInt32(tileSize.width * tileSize.height),
numberOfLUTs: UInt32(lutCount)
)
computeCommandEncoder.setComputePipelineState(pipelineState)
computeCommandEncoder.setBuffer(histogramBuffer, offset: 0, index: 0)
computeCommandEncoder.setBytes(&parameters, length: MemoryLayout<MediaEditorEnhanceLUTGeneratorParameters>.size, index: 1)
computeCommandEncoder.setTexture(self.lutTexture, index: 0)
let w = pipelineState.threadExecutionWidth
let threadsPerThreadgroup = MTLSize(width: w, height: 1, depth: 1)
let threadgroupsPerGrid = MTLSize(width: (lutCount + w - 1) / w, height: 1, depth: 1)
computeCommandEncoder.dispatchThreadgroups(threadgroupsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
computeCommandEncoder.endEncoding()
return self.lutTexture!
}
}
private final class EnhanceLookupPass: DefaultRenderPass {
fileprivate var cachedTexture: MTLTexture?
override var fragmentShaderFunctionName: String {
return "enhanceColorLookupFragmentShader"
}
func process(input: MTLTexture, lookupTexture: MTLTexture, value: simd_float1, gridSize: simd_float2, device: MTLDevice, commandBuffer: MTLCommandBuffer) -> MTLTexture? {
self.setupVerticesBuffer(device: device)
let width = input.width
let height = input.height
if self.cachedTexture == nil {
let textureDescriptor = MTLTextureDescriptor()
textureDescriptor.textureType = .type2D
textureDescriptor.width = width
textureDescriptor.height = height
textureDescriptor.pixelFormat = input.pixelFormat
textureDescriptor.storageMode = .private
textureDescriptor.usage = [.shaderRead, .shaderWrite, .renderTarget]
guard let texture = device.makeTexture(descriptor: textureDescriptor) else {
return input
}
self.cachedTexture = texture
}
let renderPassDescriptor = MTLRenderPassDescriptor()
renderPassDescriptor.colorAttachments[0].texture = self.cachedTexture!
renderPassDescriptor.colorAttachments[0].loadAction = .dontCare
renderPassDescriptor.colorAttachments[0].storeAction = .store
renderPassDescriptor.colorAttachments[0].clearColor = MTLClearColor(red: 0.0, green: 0.0, blue: 0.0, alpha: 0.0)
guard let renderCommandEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: renderPassDescriptor) else {
return input
}
renderCommandEncoder.setViewport(MTLViewport(
originX: 0, originY: 0,
width: Double(width), height: Double(height),
znear: -1.0, zfar: 1.0)
)
var gridSize = gridSize
var value = value
renderCommandEncoder.setFragmentTexture(input, index: 0)
renderCommandEncoder.setFragmentTexture(lookupTexture, index: 1)
renderCommandEncoder.setFragmentBytes(&gridSize, length: MemoryLayout<simd_float2>.size, index: 0)
renderCommandEncoder.setFragmentBytes(&value, length: MemoryLayout<simd_float1>.size, index: 1)
self.encodeDefaultCommands(using: renderCommandEncoder)
renderCommandEncoder.endEncoding()
return self.cachedTexture!
}
}
final class EnhanceRenderPass: RenderPass {
private let lightnessPass = EnhanceLightnessPass()
private let lutGeneratorPass = EnhanceLUTGeneratorPass()
private let lookupPass = EnhanceLookupPass()
var value: simd_float1 = 0.0
let clipLimit: Float = 1.25
let tileGridSize: TextureSize = TextureSize(width: 4, height: 4)
func setup(device: MTLDevice, library: MTLLibrary) {
self.lightnessPass.setup(device: device, library: library)
self.lutGeneratorPass.setup(gridSize: self.tileGridSize, device: device, library: library)
self.lookupPass.setup(device: device, library: library)
}
func process(input: MTLTexture, device: MTLDevice, commandBuffer: MTLCommandBuffer) -> MTLTexture? {
guard self.value > 0.005 else {
return input
}
let dY = (self.tileGridSize.height - (input.height % self.tileGridSize.height)) % self.tileGridSize.height
let dX = (self.tileGridSize.width - (input.width % self.tileGridSize.width)) % self.tileGridSize.width
let lightnessSize = TextureSize(width: input.width + dX, height: input.height + dY)
let lightnessScale = simd_float2(Float(input.width + dX) / Float(input.width), Float(input.height + dY) / Float(input.height))
let lightness = self.lightnessPass.process(input: input, size: lightnessSize, scale: lightnessScale, device: device, commandBuffer: commandBuffer)
let lookupTexture = self.lutGeneratorPass.process(input: lightness!, gridSize: self.tileGridSize, clipLimit: self.clipLimit, device: device, commandBuffer: commandBuffer)
let gridSize = simd_float2(Float(self.tileGridSize.width), Float(self.tileGridSize.height))
let output = self.lookupPass.process(input: input, lookupTexture: lookupTexture!, value: self.value, gridSize: gridSize, device: device, commandBuffer: commandBuffer)
return output
}
}