GEMM - Part 1: Basics and a CPU Implementation Example

GEMM stands for GEneral Matrix Multiply. It is a "level 3" routine of BLAS (Basic Linear Algebra Subprograms), the standard set of building blocks for common linear algebra operations. GEMM is also widely used in areas like computer vision and machine learning.

The Formula

The formula of GEMM is:

$C=\alpha AB+\beta C$

where $A$, $B$, and $C$ are matrices, and $\alpha$ and $\beta$ are scalar constants.
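To make the formula concrete, here is a small worked example with made-up numbers. Note that with $\alpha=1$ and $\beta=0$ it reduces to a plain matrix product $C=AB$. With $\alpha=2$ and $\beta=1$:

$A=\begin{pmatrix}1&2\\3&4\end{pmatrix},\quad B=\begin{pmatrix}0&1\\1&0\end{pmatrix},\quad C=\begin{pmatrix}1&1\\1&1\end{pmatrix}$

$C \leftarrow 2\begin{pmatrix}1&2\\3&4\end{pmatrix}\begin{pmatrix}0&1\\1&0\end{pmatrix}+1\cdot\begin{pmatrix}1&1\\1&1\end{pmatrix}=2\begin{pmatrix}2&1\\4&3\end{pmatrix}+\begin{pmatrix}1&1\\1&1\end{pmatrix}=\begin{pmatrix}5&3\\9&7\end{pmatrix}$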

GEMM C Example

Here is a very neat GEMM implementation written in C, from the well-known neural network framework darknet.

darknet/gemm.c at master · pjreddie/darknet
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int i, j;
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            C[i*ldc + j] *= BETA;
        }
    }
    if(!TA && !TB)
        gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(TA && !TB)
        gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(!TA && TB)
        gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else
        gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
}

By default, matrices $A$ and $B$ are not transposed (TA=0 && TB=0), which means:

*A is a 1-d array which stores a $(M, K)$ matrix

*B is a 1-d array which stores a $(K, N)$ matrix

*C is a 1-d array which stores a $(M, N)$ matrix, and it will be used to store the final result.

$C=\beta C$ is computed first, which is a cheap pass over $C$; then we do the $C=C+\alpha AB$ part.
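For concreteness, here is a minimal, hypothetical driver for the plain case (TA=0, TB=0) with full row-major matrices, so lda=K, ldb=N, and ldc=N. It assumes the darknet functions shown above (gemm_cpu and the gemm_nn it dispatches to) are compiled into the same program; the sizes and values are arbitrary.

#include <stdio.h>

/* Hypothetical driver; assumes darknet's gemm_cpu (and gemm_nn etc.) are linked in. */
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc);

int main(void)
{
    enum { M = 2, K = 3, N = 2 };
    float A[M*K] = {1, 2, 3,
                    4, 5, 6};    /* (M, K) row-major, lda = K */
    float B[K*N] = {1, 0,
                    0, 1,
                    1, 1};       /* (K, N) row-major, ldb = N */
    float C[M*N] = {0};          /* (M, N) row-major, ldc = N */

    /* C = 1.0 * A*B + 0.0 * C */
    gemm_cpu(0, 0, M, N, K, 1.0f, A, K, B, N, 0.0f, C, N);

    for (int i = 0; i < M; ++i)
        printf("%5.1f %5.1f\n", C[i*N + 0], C[i*N + 1]);
    /* expected output:
         4.0   5.0
        10.0  11.0 */
    return 0;
}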

Matrices $A$, $B$, and $C$ are all stored in row-major order, which means elements of the same row are stored consecutively in memory. (This doesn't mean all elements of the matrix are stored consecutively in memory.)
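In other words (a minimal illustration, not darknet code), element $(i, j)$ of a row-major matrix whose consecutive rows are ld floats apart lives at index i*ld + j of the 1-d array:

#include <stdio.h>

int main(void)
{
    enum { ROWS = 3, COLS = 4 };
    float m[ROWS*COLS];                    /* a 3 x 4 matrix, rows stored back to back */

    for (int i = 0; i < ROWS; ++i)
        for (int j = 0; j < COLS; ++j)
            m[i*COLS + j] = 10.0f*i + j;   /* element (i, j); here ld == COLS */

    printf("%.0f\n", m[1*COLS + 2]);       /* element (1, 2) -> prints 12 */
    return 0;
}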

Leading Dimension

Elements of the matrix used in the gemm function are not necessarily stored consecutively in memory? A little counter-intuitive, right? To explain this, I need to introduce the leading dimension (the arguments lda, ldb, and ldc).

Actually, the elements of a matrix are stored consecutively in memory, but when multiplying matrices we sometimes want to use only part of an existing matrix as the input/output, not all of it.

Suppose we have a $(6, 8)$ matrix $Q$ in memory (in row-major order), and we want to do a matrix multiplication on part of it: a $(3, 4)$ submatrix $q$.

/*
Q Q Q Q Q Q Q Q
Q Q q q q q Q Q
Q Q q q q q Q Q
Q Q q q q q Q Q
Q Q Q Q Q Q Q Q
Q Q Q Q Q Q Q Q
*/
I want to use q in gemm, but I don't need to copy it explicitly.

Clearly, the elements of matrix $q$ are not stored consecutively in memory. Instead of copying the data first and then doing the gemm, we can do the gemm directly if we use the right parameters: *A, M, K, and most importantly, lda. In the example above:

TA=0 means matrix $Q$, and of course matrix $q$, is row-major, i.e. not transposed.

lda=8 means the leading dimension (the number of columns, in this row-major case) of the matrix as it is laid out in memory is $8$, which is the number of columns of the full matrix $Q$.

K=4 means the number of columns of the matrix actually used in the gemm is $4$, which is the number of columns of the submatrix $q$.

M=3 means the number of rows is $3$ for matrix $q$.

A=Q+10 (the pointer Q offset by 10) means the first element of matrix $q$ is the 11th element of matrix $Q$; the starting address and the offset are given together here.

This is all we need to describe one input/output of the gemm function. If you're familiar with numpy, here's an example in Python:

import numpy as np

# 1-d array Q, to get the idea how it is stored in memory
Q = np.arange(6 * 8)
print(Q)
# 2-d array QQ, how we understand the matrix, with 2-d shape information
QQ = Q.reshape(6, 8)
print(QQ)
# to help you understand the C explanation above
lda = QQ.shape[1] # 8
K = 4
M = 3
offset = 10
# these are all we need to get the q, or to use it directly in gemm function
q = QQ[offset//lda: offset//lda+M, offset%lda: offset%lda+K]
# q = QQ[1:4, 2:6]
print(q)
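And here is a rough C sketch of the same idea (the buffer Q, the loop, and the printing are illustrative, not from darknet): the submatrix q is never copied, since its element (r, c) is simply Q[offset + r*lda + c].

#include <stdio.h>

int main(void)
{
    /* Illustrative (6, 8) row-major matrix Q, filled like np.arange(6 * 8). */
    float Q[6*8];
    for (int i = 0; i < 6*8; ++i) Q[i] = (float)i;

    int lda = 8;          /* leading dimension: number of columns of the full matrix Q */
    int M = 3, K = 4;     /* rows and columns of the submatrix q                       */
    int offset = 10;      /* q starts at the 11th element of Q (row 1, column 2)       */

    /* Element (r, c) of q, read in place without copying anything: */
    for (int r = 0; r < M; ++r) {
        for (int c = 0; c < K; ++c)
            printf("%4.0f", Q[offset + r*lda + c]);
        printf("\n");
    }
    /* prints the same (3, 4) block as the numpy example:
       10 11 12 13
       18 19 20 21
       26 27 28 29 */
    return 0;
}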

IBM also provides an explanation of the leading dimension in the documentation for their ESSL (Engineering and Scientific Subroutine Library).

Matrix Multiplication

After the easy part, $C=\beta C$, is done, $\alpha AB$ is computed and accumulated into $C$. The storage order of matrices $A$ and $B$ must be considered and taken care of.

(For background, see the Wikipedia article "Row- and column-major order".)

Matrices can be stored in row-major order or column-major order. Row-major order is used for C-style arrays, which means that, by default, elements of the same row are stored consecutively. But in some cases (an example below will show a situation that benefits from it) we need to store matrices in column-major order. Storing a matrix in column-major order under a row-major convention is equivalent to storing the transpose of the original matrix in memory.
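A tiny illustration of that last sentence (the matrices are made up; this is not darknet code): flattening a (2, 3) matrix in column-major order produces exactly the same buffer as flattening its (3, 2) transpose in row-major order.

#include <stdio.h>

int main(void)
{
    /* A 2 x 3 matrix:  1 2 3
                        4 5 6  */
    float A[2][3]  = {{1, 2, 3}, {4, 5, 6}};
    float At[3][2] = {{1, 4}, {2, 5}, {3, 6}};   /* its transpose */

    float colmajor_A[6], rowmajor_At[6];
    int n;

    /* Flatten A column by column (column-major order). */
    n = 0;
    for (int j = 0; j < 3; ++j)
        for (int i = 0; i < 2; ++i)
            colmajor_A[n++] = A[i][j];

    /* Flatten A^T row by row (row-major order). */
    n = 0;
    for (int i = 0; i < 3; ++i)
        for (int j = 0; j < 2; ++j)
            rowmajor_At[n++] = At[i][j];

    /* Both buffers come out as: 1 4 2 5 3 6 */
    for (int k = 0; k < 6; ++k)
        printf("%.0f/%.0f ", colmajor_A[k], rowmajor_At[k]);
    printf("\n");
    return 0;
}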

Now you can see why we need the int TA and int TB parameters in our gemm function. In our simple example, TA=0 means matrix $A$ is stored in row-major order, and TA!=0 means matrix $A$ is stored in column-major order; equivalently, the transpose of $A$, which is $A^T$, is what is laid out in memory.

// if (TA == 0 && TB == 0)
void gemm_nn(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    #pragma omp parallel for
    for(i = 0; i < M; ++i){
        for(k = 0; k < K; ++k){
            register float A_PART = ALPHA*A[i*lda+k];
            for(j = 0; j < N; ++j){
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}

When TA==0 && TB==0, the snippet above is used to compute $C=C+\alpha AB$. Note the i-k-j loop order: A[i*lda+k] is hoisted into a register, and the innermost loop walks along row k of $B$ and row i of $C$, so both are accessed sequentially in memory.

A More Efficient Way

// if (TA == 0 && TB != 0)
void gemm_nt(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    #pragma omp parallel for
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            register float sum = 0;
            for(k = 0; k < K; ++k){
                sum += ALPHA*A[i*lda+k]*B[j*ldb + k];
            }
            C[i*ldc+j] += sum;
        }
    }
}
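
Presumably this variant is called more efficient because, with $B$ stored as its transpose, the innermost loop reads both A[i*lda+k] and B[j*ldb+k] sequentially and accumulates into a single register, instead of updating C[i*ldc+j] once per k as gemm_nn does. Here is a small, hypothetical check that the two variants agree when gemm_nt is given $B^T$ (it assumes gemm_nn and gemm_nt from above are compiled into the same program; the sizes and values are arbitrary):

#include <stdio.h>

/* Assumes darknet's gemm_nn and gemm_nt from above are linked in. */
void gemm_nn(int M, int N, int K, float ALPHA,
        float *A, int lda, float *B, int ldb, float *C, int ldc);
void gemm_nt(int M, int N, int K, float ALPHA,
        float *A, int lda, float *B, int ldb, float *C, int ldc);

int main(void)
{
    enum { M = 2, K = 3, N = 2 };
    float A[M*K]  = {1, 2, 3,
                     4, 5, 6};      /* (M, K) row-major                     */
    float B[K*N]  = {1, 0,
                     0, 1,
                     1, 1};         /* (K, N) row-major                     */
    float Bt[N*K] = {1, 0, 1,
                     0, 1, 1};      /* B transposed: (N, K) row-major       */
    float C1[M*N] = {0}, C2[M*N] = {0};

    gemm_nn(M, N, K, 1.0f, A, K, B,  N, C1, N);   /* TB == 0: B as stored   */
    gemm_nt(M, N, K, 1.0f, A, K, Bt, K, C2, N);   /* TB != 0: B^T in memory */

    for (int i = 0; i < M*N; ++i)
        printf("%.1f %.1f\n", C1[i], C2[i]);      /* identical results      */
    return 0;
}

If the same $B$ is reused many times, it can pay off to store it transposed once and then always take this path.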