
Applying compute/kernel function to vertex buffer before vertex shader

I would like to use a compute shader to modify my vertices before they are passed to the vertex shader. I can’t find any examples or explanations of this, except that it seems to be mentioned here: Metal emulate geometry shaders using compute shaders. This doesn’t help me as it doesn’t explain the CPU part of it.

I have seen many examples where a texture is read and written in a compute shader, but I need to read and modify the vertex buffer, which contains custom vertex structs with normals and is created by an MDLMesh. I would be forever grateful for some sample code!

BACKGROUND

What I actually want to achieve is to modify the vertex normals on the GPU. The other option would be to access the entire triangle from the vertex shader, as in the linked answer. For some reason I can only access a single vertex, using the stage_in attribute. Using the entire buffer does not work for me in this particular case; this is probably related to using a mesh provided by Model I/O and MDLMesh. When I create the vertices manually, I am able to access the vertex buffer array. Having said that, with that solution I would have to calculate the new vertex normal vector three times for each triangle, which seems wasteful, and in any case I want to be able to apply compute shaders to the vertex buffer!

Nils Nielsen asked Dec 29 '18 13:12



1 Answer

Thanks to Ken Thomases' comments, I managed to find a solution. He made me realise it is quite straightforward:

I'm using a vertex struct that looks like this:

// Metal side
struct Vertex {
    float4 position;
    float4 normal;
    float4 color;
};

// Swift side
struct Vertex {
    var position: float4
    var normal: float4
    var color: float4
}
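
Since the compute kernel and the vertex function both read this buffer as an array of the Metal struct, the Swift struct has to have the same byte layout. The following is a minimal sketch of how that could be checked on the CPU side. It assumes SIMD4<Float> in place of float4 (the type that alias stands for in the simd module), and the names vertexStride, positionOffset, normalOffset and colorOffset are only illustrative:

```swift
// Sketch: a Swift mirror of the Metal `Vertex` struct, written with
// SIMD4<Float> instead of the `float4` alias (assumption for portability).
struct Vertex {
    var position: SIMD4<Float>
    var normal: SIMD4<Float>
    var color: SIMD4<Float>
}

// Bytes from the start of one vertex to the start of the next in an array.
let vertexStride = MemoryLayout<Vertex>.stride

// Byte offset of each field inside one vertex (optional Int values).
let positionOffset = MemoryLayout<Vertex>.offset(of: \Vertex.position)
let normalOffset = MemoryLayout<Vertex>.offset(of: \Vertex.normal)
let colorOffset = MemoryLayout<Vertex>.offset(of: \Vertex.color)

// Three 16-byte vectors packed back to back, matching three Metal float4 fields.
precondition(vertexStride == 3 * MemoryLayout<SIMD4<Float>>.stride,
             "Swift Vertex layout no longer matches the Metal struct")
```

With four-component float vectors on both sides, each field occupies 16 bytes and no padding is inserted, so MemoryLayout<Vertex>.stride is a safe value for sizing the buffer.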

During setup, where I usually create a vertex buffer, index buffer and render pipeline state, I now also create a compute pipeline state:

// Vertex buffer
let dataSize = vertexData.count*MemoryLayout<Vertex>.stride
vertexBuffer = device.makeBuffer(bytes: vertexData, length: dataSize, options: [])!

// Index buffer
indexCount = indices.count
let indexSize = indexCount*MemoryLayout<UInt16>.stride
indexBuffer = device.makeBuffer(bytes: indices, length: indexSize, options: [])!

// Compute pipeline state
let adjustmentFunction = library.makeFunction(name: "adjustment_func")!
cps = try! device.makeComputePipelineState(function: adjustmentFunction)

// Render pipeline state
let rpld = MTLRenderPipelineDescriptor()
rpld.vertexFunction = library.makeFunction(name: "vertex_func")
rpld.fragmentFunction = library.makeFunction(name: "fragment_func")
rpld.colorAttachments[0].pixelFormat = .bgra8Unorm
rps = try! device.makeRenderPipelineState(descriptor: rpld)

commandQueue = device.makeCommandQueue()!

Then my render function looks like this:

let black = MTLClearColor(red: 0, green: 0, blue: 0, alpha: 1)
rpd.colorAttachments[0].texture = drawable.texture
rpd.colorAttachments[0].clearColor = black
rpd.colorAttachments[0].loadAction = .clear

let commandBuffer = commandQueue.makeCommandBuffer()!

let computeCommandEncoder = commandBuffer.makeComputeCommandEncoder()!
computeCommandEncoder.setComputePipelineState(cps)
computeCommandEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)
computeCommandEncoder.dispatchThreadgroups(MTLSize(width: meshSize*meshSize, height: 1, depth: 1), threadsPerThreadgroup: MTLSize(width: 4, height: 1, depth: 1))
computeCommandEncoder.endEncoding()

let renderCommandEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: rpd)!
renderCommandEncoder.setRenderPipelineState(rps)
renderCommandEncoder.setFrontFacing(.counterClockwise)
renderCommandEncoder.setCullMode(.back)

updateUniforms(aspect: Float(size.width/size.height))
renderCommandEncoder.setVertexBuffer(vertexBuffer, offset: 0, index: 0)
renderCommandEncoder.setVertexBuffer(uniformBuffer, offset: 0, index: 1)
renderCommandEncoder.setFragmentBuffer(uniformBuffer, offset: 0, index: 1)
renderCommandEncoder.drawIndexedPrimitives(type: .triangle, indexCount: indexCount, indexType: .uint16, indexBuffer: indexBuffer, indexBufferOffset: 0)
renderCommandEncoder.endEncoding()

commandBuffer.present(drawable)
commandBuffer.commit()
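
The dispatch above launches meshSize*meshSize threadgroups of 4 threads each, so 4*meshSize*meshSize threads in total. Whether that equals the number of vertices depends on how the mesh was built, which is not shown here. If the grid should instead be derived from the vertex count, the usual approach is a ceiling division, sketched below. The helper name threadgroupCount(forItems:threadsPerGroup:) is hypothetical and not part of Metal:

```swift
/// Hypothetical helper: the smallest number of threadgroups such that
/// groups * threadsPerGroup covers at least `itemCount` items.
func threadgroupCount(forItems itemCount: Int, threadsPerGroup: Int) -> Int {
    precondition(itemCount >= 0, "itemCount must not be negative")
    precondition(threadsPerGroup > 0, "threadsPerGroup must be positive")
    // Integer ceiling division.
    return (itemCount + threadsPerGroup - 1) / threadsPerGroup
}

// Example values: 10 vertices with 4 threads per group need 3 groups,
// which launches 12 threads, 2 more than there are vertices.
let groups = threadgroupCount(forItems: 10, threadsPerGroup: 4)
let launchedThreads = groups * 4
let surplusThreads = launchedThreads - 10
```

In the render function this could be used as MTLSize(width: threadgroupCount(forItems: vertexData.count, threadsPerGroup: 4), height: 1, depth: 1). Because rounding up can launch more threads than there are vertices, the kernel would then need to return early for indices past the end of the buffer, for example by comparing against a vertex count passed in an extra buffer argument (an assumption, not part of the code above). A common choice for the threadgroup width is the pipeline state's threadExecutionWidth rather than a fixed 4.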

Finally my compute shader looks like this:

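kernel void adjustment_func(device Vertex *vertices [[buffer(0)]], uint gid [[thread_position_in_grid]]) {
    // The kernel writes to the buffer, so the pointer must not be const.
    // `function` is a placeholder for whatever per-vertex adjustment you need.
    float4 pos = vertices[gid].position;
    vertices[gid].position = function(pos.xyz);
}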

and this is the signature of my vertex function:

vertex VertexOut vertex_func(const device Vertex *vertices [[buffer(0)]], uint i [[vertex_id]], constant Uniforms &uniforms [[buffer(1)]]) 
Nils Nielsen answered Oct 22 '22 10:10