Applying compute/kernel function to vertex buffer before vertex shader

I would like to use a compute shader to modify my vertices before they are passed to the vertex shader. I can’t find any examples or explanations of this, except that it seems to be mentioned in the question "Metal emulate geometry shaders using compute shaders". That doesn’t help me, as it doesn’t explain the CPU side of it.

I have seen many examples where a texture is read from and written to in a compute shader, but I need to read and modify the vertex buffer, which contains custom vertex structs with normals and is created by an MDLMesh. I would be forever grateful for some sample code!


What I actually want to achieve is to modify the vertex normals on the GPU. The other option would be to access the entire triangle from the vertex shader, as in the linked answer. For some reason I can only access a single vertex, using the stage_in attribute. Using the entire buffer does not work for me in this particular case; this is probably related to using a mesh provided by Model I/O and MDLMesh. When I create the vertices manually, I am able to access the vertex buffer array. Having said that, with that solution I would have to calculate the new vertex normal three times for each triangle, which seems wasteful, and in any case I want to be able to apply compute shaders to the vertex buffer!

Thanks to Ken Thomases' comments, I managed to find a solution. He made me realise it is quite straightforward:

I'm using a vertex struct that looks like this:

// Metal side
struct Vertex {
    float4 position;
    float4 normal;
    float4 color;
};

// Swift side
struct Vertex {
    var position: float4
    var normal: float4
    var color: float4
}
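
Because the compute kernel and the vertex function both index the same buffer, the Swift struct has to match the Metal struct byte for byte. Here is a minimal sketch of a layout check. It assumes each field is a `SIMD4<Float>`, which is what the older `float4` typealias from the simd module refers to. The names `vertexStride`, `normalOffset` and `colorOffset` are only for illustration and are not part of the original code:

```swift
// Mirror of the Metal-side Vertex struct: three float4 fields.
// SIMD4<Float> occupies 16 bytes and is 16-byte aligned, like float4 in Metal.
struct Vertex {
    var position: SIMD4<Float>
    var normal: SIMD4<Float>
    var color: SIMD4<Float>
}

// Stride is the distance in bytes between consecutive vertices in the buffer.
let vertexStride = MemoryLayout<Vertex>.stride
// Byte offsets of the fields the shaders read and write.
let normalOffset = MemoryLayout<Vertex>.offset(of: \Vertex.normal)!
let colorOffset = MemoryLayout<Vertex>.offset(of: \Vertex.color)!

print(vertexStride, normalOffset, colorOffset) // prints: 48 16 32
```

If these numbers differ from what the Metal struct implies (for example after switching a field to a three-component type), the kernel would read and write the wrong bytes.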

During setup, where I usually create a vertex buffer, an index buffer and a render pipeline state, I now also create a compute pipeline state:

// Vertex buffer
let dataSize = vertexData.count*MemoryLayout<Vertex>.stride
vertexBuffer = device.makeBuffer(bytes: vertexData, length: dataSize, options: [])!

// Index buffer
indexCount = indices.count
let indexSize = indexCount*MemoryLayout<UInt16>.stride
indexBuffer = device.makeBuffer(bytes: indices, length: indexSize, options: [])!

// Compute pipeline state
let adjustmentFunction = library.makeFunction(name: "adjustment_func")!
cps = try! device.makeComputePipelineState(function: adjustmentFunction)

// Render pipeline state
let rpld = MTLRenderPipelineDescriptor()
rpld.vertexFunction = library.makeFunction(name: "vertex_func")
rpld.fragmentFunction = library.makeFunction(name: "fragment_func")
rpld.colorAttachments[0].pixelFormat = .bgra8Unorm
rps = try! device.makeRenderPipelineState(descriptor: rpld)

commandQueue = device.makeCommandQueue()!

Then my render function looks like this:

let black = MTLClearColor(red: 0, green: 0, blue: 0, alpha: 1)
rpd.colorAttachments[0].texture = drawable.texture
rpd.colorAttachments[0].clearColor = black
rpd.colorAttachments[0].loadAction = .clear

let commandBuffer = commandQueue.makeCommandBuffer()!

// Compute pass: adjust the vertices in place before rendering
let computeCommandEncoder = commandBuffer.makeComputeCommandEncoder()!
computeCommandEncoder.setComputePipelineState(cps)
computeCommandEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)
computeCommandEncoder.dispatchThreadgroups(MTLSize(width: meshSize*meshSize, height: 1, depth: 1), threadsPerThreadgroup: MTLSize(width: 4, height: 1, depth: 1))
computeCommandEncoder.endEncoding() // must end before the next encoder is created

// Render pass: draw from the same, now modified, vertex buffer
let renderCommandEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: rpd)!
renderCommandEncoder.setRenderPipelineState(rps)
updateUniforms(aspect: Float(size.width/size.height))
renderCommandEncoder.setVertexBuffer(vertexBuffer, offset: 0, index: 0)
renderCommandEncoder.setVertexBuffer(uniformBuffer, offset: 0, index: 1)
renderCommandEncoder.setFragmentBuffer(uniformBuffer, offset: 0, index: 1)
renderCommandEncoder.drawIndexedPrimitives(type: .triangle, indexCount: indexCount, indexType: .uint16, indexBuffer: indexBuffer, indexBufferOffset: 0)
renderCommandEncoder.endEncoding()

commandBuffer.present(drawable)
commandBuffer.commit()
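
The dispatch above launches `meshSize*meshSize` threadgroups of 4 threads each, so the total number of threads is the product of the two sizes. If you want one thread per vertex instead, a rough sketch of the threadgroup count calculation could look like this. The helper `threadgroupCount` and the sample numbers are hypothetical and not part of my original code:

```swift
// Number of threadgroups needed so that groups * threadsPerGroup >= vertexCount
// (ceiling division). Assumes one thread processes one vertex.
func threadgroupCount(vertexCount: Int, threadsPerGroup: Int) -> Int {
    precondition(threadsPerGroup > 0, "threadsPerGroup must be positive")
    return (vertexCount + threadsPerGroup - 1) / threadsPerGroup
}

// Example: 10 vertices with 4 threads per group need 3 groups (12 threads).
let groups = threadgroupCount(vertexCount: 10, threadsPerGroup: 4)
let launchedThreads = groups * 4
print(groups, launchedThreads) // prints: 3 12
```

When the vertex count is not a multiple of the group size, the last group contains threads past the end of the buffer. The kernel would then need a bounds check (for instance against a vertex count passed in as an extra argument) before it writes anything.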


Finally my compute shader looks like this:

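kernel void adjustment_func(device Vertex *vertices [[buffer(0)]], uint2 gid [[thread_position_in_grid]]) {
    float4 pos = vertices[gid.x].position;          // pointer is not const: the kernel writes to the buffer
    vertices[gid.x].position = function(pos.xyz);   // `function` is a placeholder for your own adjustment
}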

and this is the signature of my vertex function:

vertex VertexOut vertex_func(const device Vertex *vertices [[buffer(0)]], uint i [[vertex_id]], constant Uniforms &uniforms [[buffer(1)]]) 
