metal, gpu, metal-performance-shaders

Minimum matrix sizes to benefit from matrix multiplication on the GPU


I am particularly interested in matrix multiplication using Metal Performance Shaders, but answers about other frameworks are also fine.

Matrix multiplication is, in theory, a highly parallelisable operation. I need to multiply a lot of matrices by their own transpose, i.e. compute A’A (where the apostrophe denotes transposition). The matrices A are about 4000 x 300. I was wondering whether it’s worth porting the multiplication code to the GPU given these sizes. As I understand it, multiplying on the GPU also involves copying the data from main memory to GPU memory (I’m using an eGPU, so memory is not shared). So there must be a trade-off between the extra cost of copying the data back and forth and the speed-up in the computation itself. My question is: at roughly what matrix sizes could I start to see a benefit from doing this on the GPU?
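A rough way to frame that trade-off: for C = A’A with A of size N x K, the work grows as 2·N·K² floating-point operations, while the data that has to cross the bus grows only as N·K + K² floats (A up, C back). The Swift sketch below turns this into a break-even estimate. The CPU and GPU throughputs and the bus bandwidth are placeholder assumptions to be replaced with measured values, and the model ignores per-dispatch overhead and any overlap of transfer with computation.

```swift
// Back-of-envelope model: is it worth shipping C = A'A to a discrete GPU?
// The rates passed to estimateSeconds are placeholder assumptions, not measurements.

/// Flops for C = A'A (A is n x k) done as a plain GEMM:
/// k * k outputs, each a length-n dot product (n multiplies + n adds).
func gemmFlops(n: Int, k: Int) -> Double {
    return 2.0 * Double(n) * Double(k) * Double(k)
}

/// Bytes crossing the bus: A (n x k) up to the GPU and C (k x k) back, as Float32.
func transferBytes(n: Int, k: Int) -> Double {
    let floatSize = Double(MemoryLayout<Float>.stride)
    return (Double(n * k) + Double(k * k)) * floatSize
}

/// Estimated seconds on each side; the GPU time adds transfer to compute (no overlap).
func estimateSeconds(n: Int, k: Int,
                     cpuFlopsPerSec: Double,
                     gpuFlopsPerSec: Double,
                     busBytesPerSec: Double) -> (cpu: Double, gpu: Double) {
    let flops = gemmFlops(n: n, k: k)
    let cpuTime = flops / cpuFlopsPerSec
    let gpuTime = flops / gpuFlopsPerSec + transferBytes(n: n, k: k) / busBytesPerSec
    return (cpu: cpuTime, gpu: gpuTime)
}

// Matrix size from the question; 100 GFLOP/s CPU, 1 TFLOP/s GPU and a
// 2 GB/s bus are hypothetical figures chosen only to illustrate the model.
let estimate = estimateSeconds(n: 4000, k: 300,
                               cpuFlopsPerSec: 100e9,
                               gpuFlopsPerSec: 1e12,
                               busBytesPerSec: 2e9)
print("flops:", gemmFlops(n: 4000, k: 300))
print("bytes:", transferBytes(n: 4000, k: 300))
print("CPU estimate (s):", estimate.cpu)
print("GPU estimate (s):", estimate.gpu)
```

Under this model the ratio of work to transferred bytes is N·K / (2·(N + K)) flops per byte, which approaches K/2 when N is much larger than K; for 4000 x 300 it is about 140 flops per byte. So the number of columns K, more than the number of rows N, decides whether the transfer cost gets amortised.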

P.S. There is also this article, which essentially says not to bother because the GPU doesn’t help, attributing this to slow cache memory (on GPUs in general): https://graphics.stanford.edu/papers/gpumatrixmult/gpumatrixmult.pdf


Solution

  • I ran a test, and for my case the GPU is significantly faster (about 8-9x; the run shown below works out to roughly 7.5x), even including all the memory copying from CPU to GPU and back. I compared float32 matrix multiplication performance, since Metal doesn't support float64.

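    import Foundation
    import Accelerate
    import Metal
    import MetalPerformanceShaders
    
    // Timing helper used below (assumed implementation; prints in the format shown in the output).
    func printTimeElapsedWhenRunningCode(title: String, operation: () -> Void) {
        let start = CFAbsoluteTimeGetCurrent()
        operation()
        let elapsed = CFAbsoluteTimeGetCurrent() - start
        print("Time elapsed for \(title): \(elapsed) s.")
    }
    
    // Assumption: prefer a removable (external) GPU if one is attached, else the default GPU.
    guard let device = MTLCopyAllDevices().first(where: { $0.isRemovable })
            ?? MTLCreateSystemDefaultDevice() else {
        fatalError("No Metal device found")
    }
    print(device.name)
    
    let count = 100    // number of repetitions timed for each backend
    
    let N = 7005       // rows of A
    let K = 700        // columns of A; C = A'A is K x K
    
    // Round N and K up to the next multiple of DIV; A is zero-padded to N2 x K2 on the GPU.
    let DIV = 8
    let K2 = (K + DIV - 1) / DIV * DIV
    let N2 = (N + DIV - 1) / DIV * DIV
    
    print(N2)
    print(K2)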
    
    printTimeElapsedWhenRunningCode(title: "vDSP(f)") {
        
        let ATf = [Float].init(repeating: Float(1), count: N*K)
        let Af = [Float].init(repeating: Float(1), count: N*K)
        var C = Array(repeating: Float(0), count: K*K)
    
        for _ in 0..<count {
    
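            // vDSP_mmul multiplies a row-major (M x P) matrix by a (P x N) one.
            // Arguments here: M = K, N = K, P = N (the code's N), so ATf plays the
            // role of A' stored explicitly as K x N, and C receives the K x K result.
            vDSP_mmul(ATf, 1,
                      Af, 1,
                      &C, 1,
                      vDSP_Length(K),
                      vDSP_Length(K),
                      vDSP_Length(N))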
        }
    }
    
    guard let bufferA = device.makeBuffer(length: K2 * N2 * MemoryLayout<Float>.stride,
                                          options: [.storageModeManaged]) else {
        fatalError("Could not make buffer A")
    }
    
    guard let bufferC = device.makeBuffer(length: K2 * K2 * MemoryLayout<Float>.stride,
                                          options: [.storageModeManaged]) else {
        fatalError("Could not make buffer C")
    }
    
    // A is stored row-major as N2 x K2; C is K2 x K2.
    let descA = MPSMatrixDescriptor(rows: N2,
                                    columns: K2,
                                    rowBytes: K2 * MemoryLayout<Float>.stride,
                                    dataType: .float32)
    
    let descC = MPSMatrixDescriptor(rows: K2,
                                    columns: K2,
                                    rowBytes: K2 * MemoryLayout<Float>.stride,
                                    dataType: .float32)
    
    let matrixA = MPSMatrix(buffer: bufferA, descriptor: descA)
    let matrixC = MPSMatrix(buffer: bufferC, descriptor: descC)
    
    let matrixMultiplication = MPSMatrixMultiplication(device: device,
                                                       transposeLeft: true,
                                                       transposeRight: false,
                                                       resultRows: K2,
                                                       resultColumns: K2,
                                                       interiorColumns: N2,
                                                       alpha: 1,
                                                       beta: 0)
    
    guard let commandQueue = device.makeCommandQueue() else {
        fatalError("Could not make command queue")
    }
    
    printTimeElapsedWhenRunningCode(title: "Metal") {
        
        let Af = [Float].init(repeating: Float(1), count: N*K)
        let zeros = [Float].init(repeating: Float(0), count: K2)
    
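        let floatSize = MemoryLayout<Float>.stride
    
        for i in 0..<count {
    
            // Copy A row by row into the padded N2 x K2 layout of bufferA.
            // copyMemory's byteCount and raw-pointer arithmetic are in bytes, not Floats.
            var dest = bufferA.contents()
            Af.withUnsafeBufferPointer { pA in
                var from = pA.baseAddress!
                for _ in 0..<N {
                    dest.copyMemory(from: from, byteCount: K * floatSize)
                    dest += K * floatSize
                    if K2 > K {
                        // Zero the padding columns at the end of the row.
                        dest.copyMemory(from: zeros, byteCount: (K2 - K) * floatSize)
                        dest += (K2 - K) * floatSize
                    }
                    from += K  // typed pointer: advances by K Floats
                }
            }
            
            // Zero the padding rows.
            for _ in 0..<(N2 - N) {
                dest.copyMemory(from: zeros, byteCount: K2 * floatSize)
                dest += K2 * floatSize
            }
            
            // Tell Metal which bytes of the managed buffer changed so they get uploaded.
            bufferA.didModifyRange(0..<(N2 * K2 * floatSize))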
            
            let commandBuffer = commandQueue.makeCommandBuffer()!
    
            matrixMultiplication.encode(commandBuffer: commandBuffer,
                                        leftMatrix: matrixA,
                                        rightMatrix: matrixA,
                                        resultMatrix: matrixC)
    
            let blitEncoder = commandBuffer.makeBlitCommandEncoder()!
            blitEncoder.synchronize(resource: bufferC)
            blitEncoder.endEncoding()
            
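            commandBuffer.commit()
    
            // Command buffers on one queue execute in the order they were committed,
            // so waiting on the last one marks the end of the whole batch. The next
            // iteration rewrites bufferA without waiting for the GPU; that is harmless
            // here (the data never changes), but real code would need to wait or use
            // a separate buffer per iteration.
            if i == count - 1 {
                commandBuffer.waitUntilCompleted()
            }
        }
    }
    
    print("DONE.")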
    

    Output:

    AMD Radeon RX 5700 XT
    7008
    704
    Time elapsed for vDSP(f): 5.156805992126465 s.
    Time elapsed for Metal: 0.6834449768066406 s.
    DONE.
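The reported times can be converted into throughput to put them in context. The sketch below counts only the useful (unpadded) work, 2·N·K² flops per product for N = 7005 and K = 700, repeated 100 times. The Metal figure is end-to-end, since its timed loop also includes the host-side copies and the synchronisation blit.

```swift
// Throughput implied by the timings in the output above (useful, unpadded work only).
let repetitions = 100.0
let n = 7005.0   // rows of A
let k = 700.0    // columns of A
let totalFlops = repetitions * 2.0 * n * k * k

let vdspSeconds = 5.156805992126465    // "Time elapsed for vDSP(f)"
let metalSeconds = 0.6834449768066406  // "Time elapsed for Metal" (includes host-side copies)

let vdspGflops = totalFlops / vdspSeconds / 1e9
let metalGflops = totalFlops / metalSeconds / 1e9
let speedup = vdspSeconds / metalSeconds

print("vDSP:  \(vdspGflops) GFLOP/s")
print("Metal: \(metalGflops) GFLOP/s")
print("speed-up: \(speedup)x")
```

With the numbers from the run above this gives roughly 133 GFLOP/s for vDSP, roughly 1000 GFLOP/s for Metal, and a speed-up of about 7.5x.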