Tags: hlsl, transformation-matrix, gpu-instancing

When drawing mesh instances on the GPU indirectly (no CPU involved), is it better to calculate transform matrices on CPU or GPU?


I am instancing tens of thousands of meshes on the GPU, and each mesh needs its own unique transform. Is it faster to calculate tens of thousands of matrices on the CPU and pass them to the GPU via a compute buffer, or to calculate each unique TRS matrix on the GPU itself (e.g., using a compute shader)?

I have tried implementing both, but I have not yet been able to calculate TRS matrices correctly in HLSL. Before trying further, I want to make sure that calculating them on the GPU is likely to be a good option, given that there are SO MANY instances.


Solution

  • Yes — if you have a very large number of instances, it can definitely be worth offloading the calculation to the GPU. Note that if you also want to perform culling, that will need to be done on the GPU as well (for instancing, this is not a complex operation).

    As for speed, there is a threshold above which doing the work on the GPU becomes faster than doing it on the CPU (you still need to upload the pose data and run a compute pass), and where that threshold lies varies with the architecture.

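    This is the compute shader code I was using to convert an SRT pose to a matrix. I also tried to optimize it reasonably, avoiding multiple matrix multiplications, even though those are blazing fast on the GPU. Note the conventions it assumes: the rotation angles are in turns (1.0 = 360°), and the output matrix is laid out for row vectors (translation in the fourth row), so a point is transformed with `mul(float4(p, 1), m)`; if your shader uses `mul(m, v)`, transpose the matrix first.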

    // Pi as a float literal (the nearest float to pi).
    // The previous definition, acos(-1.0f), relied on the compiler folding an
    // inverse-trig intrinsic; GPU inverse-trig implementations are not
    // guaranteed to be exact, whereas a literal is always a compile-time
    // constant rounded correctly to float precision.
    #define PI 3.14159265358979f
    
    // Per-instance pose read from InputBuffer: translation, per-axis scale and
    // Euler rotation. The field order below is the per-element memory layout
    // (three float3 values); presumably a CPU-side struct fills this buffer, so
    // keep any such mirror in the same order.
    struct PoseSRT
    {
        float3 position; // translation; srt_to_matrix writes it to the matrix's fourth row
        float3 scale;    // per-axis scale factors (x, y, z)
        float3 rotation; // Euler angles: x = pitch, y = yaw, z = roll, in turns (1.0 = 360 degrees)
    };
    
    // Builds a unit quaternion, stored as (x, y, z, w), from yaw, pitch and roll.
    //
    // Angle units: each input is multiplied by 2*PI and then halved, so the
    // angles are in turns: 1.0 is one full 360-degree revolution. Convert
    // radians with r / (2 * PI), or degrees with d / 360, before calling.
    //
    // Axes: yaw alone yields a rotation about Y, pitch alone about X, and roll
    // alone about Z. The combined result equals the Hamilton product
    // qYaw * qPitch * qRoll.
    float4 quat_yawpitchroll(float yaw, float pitch, float roll)
    {
        // Half angles in radians, packed as (yaw, pitch, roll).
        const float3 halfAngles = float3(yaw, pitch, roll) * 0.5f * PI * 2.0f;
        const float3 s = sin(halfAngles); // (sinYaw, sinPitch, sinRoll)
        const float3 c = cos(halfAngles); // (cosYaw, cosPitch, cosRoll)

        return float4(
            (c.x * s.y * c.z) + (s.x * c.y * s.z),  // x
            (s.x * c.y * c.z) - (c.x * s.y * s.z),  // y
            (c.x * c.y * s.z) - (s.x * s.y * c.z),  // z
            (c.x * c.y * c.z) + (s.x * s.y * s.z)); // w
    }
    
    // Builds a 4x4 affine transform from an SRT pose.
    //
    // Layout (row-vector convention): a point p is transformed with
    // mul(float4(p, 1.0f), m). Rows 1-3 hold the rotation basis, each scaled by
    // the matching component of pose.scale, so scale is applied before rotation.
    // Row 4 holds pose.position, and the fourth column is (0, 0, 0, 1). If the
    // consuming shader multiplies column vectors (mul(m, v)), transpose the result.
    //
    // pose.rotation is interpreted as (x = pitch, y = yaw, z = roll) and passed to
    // quat_yawpitchroll, so its angles are in turns (1.0 = 360 degrees).
    float4x4 srt_to_matrix(PoseSRT pose)
    {
        const float4 q = quat_yawpitchroll(pose.rotation.y, pose.rotation.x, pose.rotation.z);

        // Quaternion products used by the rotation matrix.
        const float3 sq    = q.xyz * q.xyz; // (xx, yy, zz)
        const float3 mixed = q.xyz * q.yzx; // (xy, yz, zx)
        const float3 withW = q.xyz * q.w;   // (xw, yw, zw)

        const float3 row0 = float3(1.0f - (2.0f * (sq.y + sq.z)),
                                   2.0f * (mixed.x + withW.z),
                                   2.0f * (mixed.z - withW.y)) * pose.scale.x;

        const float3 row1 = float3(2.0f * (mixed.x - withW.z),
                                   1.0f - (2.0f * (sq.z + sq.x)),
                                   2.0f * (mixed.y + withW.x)) * pose.scale.y;

        const float3 row2 = float3(2.0f * (mixed.z + withW.y),
                                   2.0f * (mixed.y - withW.x),
                                   1.0f - (2.0f * (sq.y + sq.x))) * pose.scale.z;

        // The float4x4 constructor fills the matrix row by row.
        return float4x4(float4(row0, 0.0f),
                        float4(row1, 0.0f),
                        float4(row2, 0.0f),
                        float4(pose.position, 1.0f));
    }
    
    // Input: one PoseSRT per instance, bound as a read-only resource at t0.
    StructuredBuffer<PoseSRT> InputBuffer : register(t0);
    
    // Output: one transform matrix per instance, bound as a read/write resource at
    // u0; CS writes each matrix at the same index as its input pose.
    // NOTE(review): the in-memory element order of float4x4 (row- vs column-major)
    // depends on the compiler's matrix packing settings - confirm it before
    // reading this buffer on the CPU.
    RWStructuredBuffer<float4x4> OutputBuffer : register(u0);
    
    // Constants supplied by the host at b0.
    cbuffer cbSettings : register(b0)
    {
        // Number of valid poses; CS skips threads whose index is >= this value.
        uint elementCount;
    }
    
    // Compute entry point: one thread per instance, 128 threads per group along X.
    // Each thread converts InputBuffer[i] with srt_to_matrix and stores the result
    // in OutputBuffer[i]. Threads whose index is at or beyond elementCount write
    // nothing, so the dispatch can be rounded up to whole groups (presumably
    // ceil(elementCount / 128) groups along X - confirm on the host side).
    [numthreads(128, 1, 1)]
    void CS(uint3 tid : SV_DispatchThreadID)
    {
        const uint instance = tid.x;
        if (instance < elementCount)
        {
            OutputBuffer[instance] = srt_to_matrix(InputBuffer[instance]);
        }
    }