Swiftgram/submodules/TelegramUI/Components/MediaEditor/Sources/HistogramCalculationPass.swift
2023-05-29 16:27:26 +04:00

101 lines
4.3 KiB
Swift

import Foundation
import Metal
import simd
import MetalPerformanceShaders
final class HistogramCalculationPass: DefaultRenderPass {
fileprivate var cachedTexture: MTLTexture?
fileprivate var histogramBuffer: MTLBuffer?
fileprivate var calculation: MPSImageHistogram?
var isEnabled = false
var updated: ((Data) -> Void)?
override var fragmentShaderFunctionName: String {
return "histogramPrepareFragmentShader"
}
override var pixelFormat: MTLPixelFormat {
return .r8Unorm
}
override func setup(device: MTLDevice, library: MTLLibrary) {
var histogramInfo = MPSImageHistogramInfo(
numberOfHistogramEntries: 256,
histogramForAlpha: false,
minPixelValue: vector_float4(0,0,0,0),
maxPixelValue: vector_float4(1,1,1,1)
)
let calculation = MPSImageHistogram(device: device, histogramInfo: &histogramInfo)
calculation.zeroHistogram = true
let histogramBufferLength = calculation.histogramSize(forSourceFormat: .bgra8Unorm)
let lumaHistogramBufferLength = calculation.histogramSize(forSourceFormat: .r8Unorm)
if let histogramBuffer = device.makeBuffer(length: histogramBufferLength + lumaHistogramBufferLength, options: [.storageModeShared]) {
self.calculation = calculation
self.histogramBuffer = histogramBuffer
}
super.setup(device: device, library: library)
}
override func process(input: MTLTexture, device: MTLDevice, commandBuffer: MTLCommandBuffer) -> MTLTexture? {
if self.isEnabled {
self.setupVerticesBuffer(device: device)
let width = input.width
let height = input.height
if self.cachedTexture == nil || self.cachedTexture?.width != width || self.cachedTexture?.height != height {
let textureDescriptor = MTLTextureDescriptor()
textureDescriptor.textureType = .type2D
textureDescriptor.width = width
textureDescriptor.height = height
textureDescriptor.pixelFormat = .r8Unorm
textureDescriptor.storageMode = .shared
textureDescriptor.usage = [.shaderRead, .shaderWrite, .renderTarget]
guard let texture = device.makeTexture(descriptor: textureDescriptor) else {
return input
}
self.cachedTexture = texture
texture.label = "lumaTexture"
}
let renderPassDescriptor = MTLRenderPassDescriptor()
renderPassDescriptor.colorAttachments[0].texture = self.cachedTexture!
renderPassDescriptor.colorAttachments[0].loadAction = .dontCare
renderPassDescriptor.colorAttachments[0].storeAction = .store
renderPassDescriptor.colorAttachments[0].clearColor = MTLClearColor(red: 0.0, green: 0.0, blue: 0.0, alpha: 0.0)
guard let renderCommandEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: renderPassDescriptor) else {
return input
}
renderCommandEncoder.setViewport(MTLViewport(
originX: 0, originY: 0,
width: Double(width), height: Double(height),
znear: -1.0, zfar: 1.0)
)
renderCommandEncoder.setFragmentTexture(input, index: 0)
self.encodeDefaultCommands(using: renderCommandEncoder)
renderCommandEncoder.endEncoding()
if let histogramBuffer = self.histogramBuffer, let calculation = self.calculation {
calculation.encode(to: commandBuffer, sourceTexture: input, histogram: histogramBuffer, histogramOffset: 0)
let lumaHistogramBufferLength = calculation.histogramSize(forSourceFormat: .r8Unorm)
calculation.encode(to: commandBuffer, sourceTexture: self.cachedTexture!, histogram: histogramBuffer, histogramOffset: histogramBuffer.length - lumaHistogramBufferLength)
let histogramData = Data(bytes: histogramBuffer.contents(), count: histogramBuffer.length)
self.updated?(histogramData)
}
}
return input
}
}