Strassen’s Matrix Multiplication Algorithm

Saahil Mahato
Published in The Startup, Aug 17, 2020

Strassen’s algorithm is an algorithm for matrix multiplication that is asymptotically faster than the naive matrix multiplication algorithm. To see how, let’s compare the two algorithms along with their implementations in C++.

Suppose we are multiplying two matrices A and B, both of dimensions n x n. In the naive algorithm, each entry of the resulting matrix C is obtained by the formula:

$$C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}$$

for i = 1, …, n and j = 1, …, n.

The C++ implementation of this formula is:

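    // Naive multiplication of two n x n matrices.
    // initializeMatrix() and setToZero() are helpers that are not shown in this article.
    int** multiply(int** A, int** B, int n) {
        int** C = initializeMatrix(n);  // allocate the n x n result
        setToZero(C, n);                // every C[i][j] starts at 0
        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                for(int k=0; k<n; k++)
                    C[i][j] += A[i][k] * B[k][j];  // row i of A times column j of B
        return C;
    }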
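The article uses initializeMatrix and setToZero without showing them. A minimal sketch of what they might look like, assuming a matrix is a heap-allocated array of n row pointers, is enough to make the multiply function compile on its own. The freeMatrix function is a hypothetical addition for releasing that memory.

```cpp
#include <cassert>

// Assumed behavior of a helper the article does not show:
// allocates an n x n matrix as n separately allocated rows.
// Entries are left uninitialized, which is why multiply() calls setToZero().
int** initializeMatrix(int n) {
    assert(n > 0);
    int** M = new int*[n];
    for (int i = 0; i < n; i++)
        M[i] = new int[n];
    return M;
}

// Assumed behavior of the second helper: sets every entry of an n x n matrix to 0.
void setToZero(int** M, int n) {
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            M[i][j] = 0;
}

// Hypothetical helper: releases a matrix created by initializeMatrix().
void freeMatrix(int** M, int n) {
    for (int i = 0; i < n; i++)
        delete[] M[i];
    delete[] M;
}
```

With these definitions in the same file, multiply(A, B, n) returns a newly allocated matrix, so the caller would release it with freeMatrix(C, n) once it is no longer needed.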

In this algorithm, the statement “C[i][j] += A[i][k] * B[k][j]” executes n³ times, as is evident from the three nested for loops, and it is the most costly operation in the algorithm. So, the time complexity of the naive algorithm is O(n³).

Now let’s take a look at Strassen’s algorithm. It is a recursive method for matrix multiplication: in each recursive step, we divide each matrix into four sub-matrices of dimensions n/2 x n/2.

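For example, consider two 4 x 4 matrices A and B that we need to multiply. A 4 x 4 matrix can be divided into four 2 x 2 matrices:

$$A = \begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}, \quad B = \begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix}, \quad C = \begin{pmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{pmatrix}$$

Here, Aᵢⱼ and Bᵢⱼ are 2 x 2 matrices.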

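Now, we can calculate the product of A and B (matrix C) with the following formulas:

$$\begin{aligned}
P_1 &= A_{11}(B_{12} - B_{22}) \\
P_2 &= (A_{11} + A_{12})B_{22} \\
P_3 &= (A_{21} + A_{22})B_{11} \\
P_4 &= A_{22}(B_{21} - B_{11}) \\
P_5 &= (A_{11} + A_{22})(B_{11} + B_{22}) \\
P_6 &= (A_{12} - A_{22})(B_{21} + B_{22}) \\
P_7 &= (A_{11} - A_{21})(B_{11} + B_{12}) \\[4pt]
C_{11} &= P_5 + P_4 - P_2 + P_6 \\
C_{12} &= P_1 + P_2 \\
C_{21} &= P_3 + P_4 \\
C_{22} &= P_5 + P_1 - P_3 - P_7
\end{aligned}$$

Source of the formulas: http://www.cs.utsa.edu/~wagner/CS3343/strassen/strassen_detail.html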

This version of the formulas is the one I find easiest to remember. You can also find a guide to remembering them at https://www.geeksforgeeks.org/easy-way-remember-strassens-matrix-equation/
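Because the formulas only use addition, subtraction, and multiplication, they also hold when each block is a single number. A quick spot check is to apply the seven products to 2 x 2 integer matrices and compare the result with the ordinary definition. The sketch below does this exhaustively for entries from -2 to 2; the names strassen2x2, naive2x2, and formulasAgree are illustrative. Scalars commute, so this does not by itself prove the block version, but it catches a wrong sign or index.

```cpp
#include <array>
#include <cassert>  // for assert(formulasAgree()) in a quick check

using Mat2 = std::array<std::array<long long, 2>, 2>;

// The seven Strassen products and four combinations, with 1 x 1 "blocks".
Mat2 strassen2x2(const Mat2& A, const Mat2& B) {
    long long P1 = A[0][0] * (B[0][1] - B[1][1]);
    long long P2 = (A[0][0] + A[0][1]) * B[1][1];
    long long P3 = (A[1][0] + A[1][1]) * B[0][0];
    long long P4 = A[1][1] * (B[1][0] - B[0][0]);
    long long P5 = (A[0][0] + A[1][1]) * (B[0][0] + B[1][1]);
    long long P6 = (A[0][1] - A[1][1]) * (B[1][0] + B[1][1]);
    long long P7 = (A[0][0] - A[1][0]) * (B[0][0] + B[0][1]);
    Mat2 C;
    C[0][0] = P5 + P4 - P2 + P6;
    C[0][1] = P1 + P2;
    C[1][0] = P3 + P4;
    C[1][1] = P5 + P1 - P3 - P7;
    return C;
}

// Ordinary definition: C[i][j] = sum over t of A[i][t] * B[t][j].
Mat2 naive2x2(const Mat2& A, const Mat2& B) {
    Mat2 C{};
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
            for (int t = 0; t < 2; t++)
                C[i][j] += A[i][t] * B[t][j];
    return C;
}

// Tries every pair of 2 x 2 matrices with entries in {-2, ..., 2} (5^8 cases).
bool formulasAgree() {
    for (int code = 0; code < 390625; code++) {
        int c = code;
        Mat2 A, B;
        for (int idx = 0; idx < 8; idx++) {
            long long v = c % 5 - 2;  // next base-5 digit, shifted to -2..2
            c /= 5;
            if (idx < 4) A[idx / 2][idx % 2] = v;
            else B[(idx - 4) / 2][(idx - 4) % 2] = v;
        }
        if (strassen2x2(A, B) != naive2x2(A, B))
            return false;
    }
    return true;
}
```

Calling assert(formulasAgree()) from a main function should pass.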

It isn’t clear how Strassen came up with these formulas, but we can verify that they work by expanding the products. This straightforward version of the algorithm also has two conditions on its input:

  1. Both input matrices should be of dimensions n x n.
  2. n should be a power of 2.

If these conditions are not satisfied, we must pad the matrices with zeros until they are.
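Padding can be done once, before the first call. The sketch below assumes a std::vector-based matrix rather than the article’s int** so that memory is managed automatically; nextPowerOfTwo and padTo are hypothetical names. It copies a matrix of any shape into the top-left corner of a zero-filled size x size matrix.

```cpp
#include <cassert>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Smallest power of two that is >= n.
int nextPowerOfTwo(int n) {
    assert(n >= 1);
    int p = 1;
    while (p < n)
        p *= 2;
    return p;
}

// Copies M (rows x cols) into the top-left corner of a size x size
// matrix of zeros. size must be at least rows and at least cols.
Matrix padTo(const Matrix& M, int size) {
    int rows = static_cast<int>(M.size());
    int cols = rows == 0 ? 0 : static_cast<int>(M[0].size());
    assert(size >= rows && size >= cols);
    Matrix P(size, std::vector<int>(size, 0));
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            P[i][j] = M[i][j];
    return P;
}
```

For an m x p matrix A and a p x q matrix B, both would be padded to the same size, nextPowerOfTwo of the largest of m, p, and q. Since the extra rows and columns are zero, the top-left m x q block of the padded product equals A times B.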

Now, let’s see the implementation of the algorithm in C++.

To implement the algorithm, we need two basic matrix operations, addition and subtraction, since the formulas use both. The code for these functions is:

    // Returns a new n x n matrix holding M1 + M2.
    int** add(int** M1, int** M2, int n) {
        int** temp = initializeMatrix(n);
        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                temp[i][j] = M1[i][j] + M2[i][j];
        return temp;
    }

    // Returns a new n x n matrix holding M1 - M2.
    int** subtract(int** M1, int** M2, int n) {
        int** temp = initializeMatrix(n);
        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                temp[i][j] = M1[i][j] - M2[i][j];
        return temp;
    }

Now, let’s implement Strassen’s matrix multiplication function, strassenMultiply. It divides each input matrix into four sub-matrices of dimension n/2 x n/2, calls itself recursively to compute the products P1 to P7, and combines them with the formulas above to obtain the full product. The recursion also needs a base case to stop: when the matrices are 1 x 1, the function simply returns the product of the two elements.

The code:

The base case:

    // Inside int** strassenMultiply(int** A, int** B, int n):
    // base case, two 1 x 1 matrices
    if (n == 1) {
        int** C = initializeMatrix(1);
        C[0][0] = A[0][0] * B[0][0];
        return C;
    }

Declaring C and calculating the dimension of the sub-matrices:

    int** C = initializeMatrix(n);  // the n x n result
    int k = n/2;                    // each sub-matrix is k x k

Allocating the sub-matrices and filling them with the four quadrants of A and B:

    int** A11 = initializeMatrix(k);
    int** A12 = initializeMatrix(k);
    int** A21 = initializeMatrix(k);
    int** A22 = initializeMatrix(k);
    int** B11 = initializeMatrix(k);
    int** B12 = initializeMatrix(k);
    int** B21 = initializeMatrix(k);
    int** B22 = initializeMatrix(k);
    for(int i=0; i<k; i++)
        for(int j=0; j<k; j++) {
            A11[i][j] = A[i][j];        // top-left quadrant
            A12[i][j] = A[i][k+j];      // top-right quadrant
            A21[i][j] = A[k+i][j];      // bottom-left quadrant
            A22[i][j] = A[k+i][k+j];    // bottom-right quadrant
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][k+j];
            B21[i][j] = B[k+i][j];
            B22[i][j] = B[k+i][k+j];
        }

Calculating the products P1 to P7 and the quadrants C11 to C22 of the result using the formulas:

    // Seven recursive multiplications instead of the eight a block-wise naive product needs.
    int** P1 = strassenMultiply(A11, subtract(B12, B22, k), k);
    int** P2 = strassenMultiply(add(A11, A12, k), B22, k);
    int** P3 = strassenMultiply(add(A21, A22, k), B11, k);
    int** P4 = strassenMultiply(A22, subtract(B21, B11, k), k);
    int** P5 = strassenMultiply(add(A11, A22, k), add(B11, B22, k), k);
    int** P6 = strassenMultiply(subtract(A12, A22, k), add(B21, B22, k), k);
    int** P7 = strassenMultiply(subtract(A11, A21, k), add(B11, B12, k), k);

    // Note: none of the temporary matrices created in this function (the results of
    // add() and subtract(), A11 to B22, P1 to P7, and C11 to C22) are freed,
    // so this version leaks memory on every call.
    int** C11 = subtract(add(add(P5, P4, k), P6, k), P2, k);
    int** C12 = add(P1, P2, k);
    int** C21 = add(P3, P4, k);
    int** C22 = subtract(subtract(add(P5, P1, k), P3, k), P7, k);

Copying the values to C and returning C:

    for(int i=0; i<k; i++)
        for(int j=0; j<k; j++) {
            C[i][j] = C11[i][j];
            C[i][j+k] = C12[i][j];
            C[k+i][j] = C21[i][j];
            C[k+i][k+j] = C22[i][j];
        }
    return C;
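The fragments above allocate many temporary matrices (the sub-matrices, the results of add and subtract, and P1 to P7) without freeing them. The sketch below is one way to assemble them into a single self-contained function that releases every temporary. initializeMatrix and freeMatrix are assumed helpers (the article’s own versions are not shown), and quadrant and multiplyAndFree are hypothetical names introduced here. The seven products and four quadrant formulas are unchanged; as a design choice, the quadrants of C are written directly instead of going through C11 to C22.

```cpp
#include <cassert>

// Assumed helper: allocates an n x n matrix of zeros as n separate rows.
int** initializeMatrix(int n) {
    int** M = new int*[n];
    for (int i = 0; i < n; i++)
        M[i] = new int[n]();  // () zero-initializes the row
    return M;
}

// Assumed helper: releases a matrix created by initializeMatrix().
void freeMatrix(int** M, int n) {
    for (int i = 0; i < n; i++)
        delete[] M[i];
    delete[] M;
}

int** add(int** M1, int** M2, int n) {
    int** temp = initializeMatrix(n);
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            temp[i][j] = M1[i][j] + M2[i][j];
    return temp;
}

int** subtract(int** M1, int** M2, int n) {
    int** temp = initializeMatrix(n);
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            temp[i][j] = M1[i][j] - M2[i][j];
    return temp;
}

// Hypothetical helper: copies the k x k block of M whose top-left corner is (row, col).
int** quadrant(int** M, int k, int row, int col) {
    int** Q = initializeMatrix(k);
    for (int i = 0; i < k; i++)
        for (int j = 0; j < k; j++)
            Q[i][j] = M[row + i][col + j];
    return Q;
}

int** strassenMultiply(int** A, int** B, int n);

// Hypothetical helper: multiplies X by Y, then frees whichever operand
// was a temporary produced by add() or subtract().
int** multiplyAndFree(int** X, bool freeX, int** Y, bool freeY, int k) {
    int** P = strassenMultiply(X, Y, k);
    if (freeX) freeMatrix(X, k);
    if (freeY) freeMatrix(Y, k);
    return P;
}

int** strassenMultiply(int** A, int** B, int n) {
    assert(n >= 1 && (n & (n - 1)) == 0);  // n must be a power of 2
    if (n == 1) {
        int** C = initializeMatrix(1);
        C[0][0] = A[0][0] * B[0][0];
        return C;
    }
    int k = n / 2;
    int** A11 = quadrant(A, k, 0, 0); int** A12 = quadrant(A, k, 0, k);
    int** A21 = quadrant(A, k, k, 0); int** A22 = quadrant(A, k, k, k);
    int** B11 = quadrant(B, k, 0, 0); int** B12 = quadrant(B, k, 0, k);
    int** B21 = quadrant(B, k, k, 0); int** B22 = quadrant(B, k, k, k);

    int** P1 = multiplyAndFree(A11, false, subtract(B12, B22, k), true, k);
    int** P2 = multiplyAndFree(add(A11, A12, k), true, B22, false, k);
    int** P3 = multiplyAndFree(add(A21, A22, k), true, B11, false, k);
    int** P4 = multiplyAndFree(A22, false, subtract(B21, B11, k), true, k);
    int** P5 = multiplyAndFree(add(A11, A22, k), true, add(B11, B22, k), true, k);
    int** P6 = multiplyAndFree(subtract(A12, A22, k), true, add(B21, B22, k), true, k);
    int** P7 = multiplyAndFree(subtract(A11, A21, k), true, add(B11, B12, k), true, k);

    // Write the four quadrants of C directly from P1 to P7.
    int** C = initializeMatrix(n);
    for (int i = 0; i < k; i++)
        for (int j = 0; j < k; j++) {
            C[i][j]         = P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j];
            C[i][j + k]     = P1[i][j] + P2[i][j];
            C[k + i][j]     = P3[i][j] + P4[i][j];
            C[k + i][k + j] = P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j];
        }

    // Release every k x k temporary created in this call.
    int** temps[] = {A11, A12, A21, A22, B11, B12, B21, B22,
                     P1, P2, P3, P4, P5, P6, P7};
    for (int** T : temps)
        freeMatrix(T, k);
    return C;
}
```

The caller still owns the returned matrix and would free it with freeMatrix(C, n).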

Now let’s calculate the time complexity of the algorithm using the master method. The algorithm makes seven recursive calls to calculate P1 to P7, so a = 7. The input matrix size is halved in each recursive call, so b = 2. The work done outside the recursive calls (adding, subtracting, and copying values into C) is O(n²), so d = 2.

So, the recurrence is:

$$T(n) = 7\,T(n/2) + O(n^2)$$

Since $a = 7 > b^d = 2^2 = 4$, the recursive calls dominate and the time complexity of Strassen’s matrix multiplication algorithm is $O(n^{\log_2 7}) \approx O(n^{2.81})$. So, Strassen’s algorithm is asymptotically faster than the naive algorithm.
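To make the exponent concrete, the sketch below counts only the scalar multiplications each version performs; the function names are illustrative. The naive triple loop does n³ of them, while the Strassen recursion with a 1 x 1 base case satisfies M(1) = 1 and M(n) = 7 M(n/2), which works out to 7^(log₂ n) = n^(log₂ 7).

```cpp
#include <cassert>
#include <cmath>

// Scalar multiplications done by the naive triple loop: n * n * n.
long long naiveMultiplications(long long n) {
    return n * n * n;
}

// Scalar multiplications done by the Strassen recursion with a 1 x 1 base
// case: M(1) = 1 and M(n) = 7 * M(n / 2), i.e. 7^(log2 n) = n^(log2 7).
long long strassenMultiplications(long long n) {
    assert(n >= 1 && (n & (n - 1)) == 0);  // n must be a power of 2
    return n == 1 ? 1 : 7 * strassenMultiplications(n / 2);
}

// The exponent in O(n^(log2 7)); log2(7) is about 2.807.
double strassenExponent() {
    return std::log2(7.0);
}
```

For n = 1024 this gives 282,475,249 multiplications for Strassen against 1,073,741,824 for the naive loop, roughly 3.8 times fewer. The count ignores additions and subtractions, of which Strassen performs more, so this ratio overstates the actual speedup.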

This might not seem like a significant improvement, but for large inputs the difference becomes significant. We can see the difference in the graph below.

As the graph shows, as the input grows, Strassen’s algorithm performs significantly better than the naive algorithm. Like the naive algorithm, Strassen’s algorithm can also be parallelized to improve performance further; for example, the seven products P1 to P7 are independent of each other and can be computed concurrently.

Strassen’s matrix multiplication algorithm also has a few disadvantages:

  1. The recursion and the temporary sub-matrices it allocates consume more memory.
  2. The recursive calls and the extra additions and subtractions add overhead.

For these reasons, the naive algorithm is the better option for smaller inputs, as the graph also shows.
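A common way to get the best of both is a hybrid: apply Strassen’s formulas only while the matrices are large, and fall back to the naive triple loop once n drops to a cutoff. The sketch below assumes std::vector matrices and a cutoff parameter whose default of 64 is an arbitrary placeholder; the best value depends on the machine and would have to be measured. The function names are illustrative.

```cpp
#include <cassert>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Naive O(n^3) product of two n x n matrices.
Matrix naiveMultiply(const Matrix& A, const Matrix& B) {
    int n = static_cast<int>(A.size());
    Matrix C(n, std::vector<long long>(n, 0));
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < n; k++)
                C[i][j] += A[i][k] * B[k][j];
    return C;
}

// Returns X + sign * Y, where sign is +1 or -1.
Matrix combine(const Matrix& X, const Matrix& Y, int sign) {
    int n = static_cast<int>(X.size());
    Matrix R(n, std::vector<long long>(n));
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            R[i][j] = X[i][j] + sign * Y[i][j];
    return R;
}

// Copies the k x k block of M whose top-left corner is (row, col).
Matrix block(const Matrix& M, int k, int row, int col) {
    Matrix Q(k, std::vector<long long>(k));
    for (int i = 0; i < k; i++)
        for (int j = 0; j < k; j++)
            Q[i][j] = M[row + i][col + j];
    return Q;
}

// Strassen recursion that switches to the naive loop once n <= cutoff.
Matrix hybridMultiply(const Matrix& A, const Matrix& B, int cutoff = 64) {
    int n = static_cast<int>(A.size());
    assert(n >= 1 && (n & (n - 1)) == 0);  // n must be a power of 2
    if (n <= cutoff)
        return naiveMultiply(A, B);
    int k = n / 2;
    Matrix A11 = block(A, k, 0, 0), A12 = block(A, k, 0, k);
    Matrix A21 = block(A, k, k, 0), A22 = block(A, k, k, k);
    Matrix B11 = block(B, k, 0, 0), B12 = block(B, k, 0, k);
    Matrix B21 = block(B, k, k, 0), B22 = block(B, k, k, k);

    Matrix P1 = hybridMultiply(A11, combine(B12, B22, -1), cutoff);
    Matrix P2 = hybridMultiply(combine(A11, A12, 1), B22, cutoff);
    Matrix P3 = hybridMultiply(combine(A21, A22, 1), B11, cutoff);
    Matrix P4 = hybridMultiply(A22, combine(B21, B11, -1), cutoff);
    Matrix P5 = hybridMultiply(combine(A11, A22, 1), combine(B11, B22, 1), cutoff);
    Matrix P6 = hybridMultiply(combine(A12, A22, -1), combine(B21, B22, 1), cutoff);
    Matrix P7 = hybridMultiply(combine(A11, A21, -1), combine(B11, B12, 1), cutoff);

    Matrix C(n, std::vector<long long>(n));
    for (int i = 0; i < k; i++)
        for (int j = 0; j < k; j++) {
            C[i][j] = P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j];
            C[i][j + k] = P1[i][j] + P2[i][j];
            C[k + i][j] = P3[i][j] + P4[i][j];
            C[k + i][k + j] = P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j];
        }
    return C;
}
```

With cutoff = 1 this behaves like the fully recursive version in the article; larger cutoffs trade some of the asymptotic advantage for less overhead on the small sub-problems.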

Thank you for reading! If you find any errors or have a way to optimize the code, please mention them in the comments and I will correct or implement them. If you want the full working program, you can find the straightforward (naive) and Strassen’s matrix multiplication algorithms implemented in C++, Java, and Python in my GitHub repository. The link is provided below:

Hi, I am a Computer Science undergrad. I enjoy programming and anything to do with technology in general. I write about things that I am curious about.