The Art of Problem Solving: How to Create a One Line List Comprehension for Matrix Multiplication in Python

Photo by Dawid Małecki on Unsplash

This article assumes knowledge of Python (list comprehensions) and linear algebra (matrix multiplication). But, while this post is about how to write a one-line list comp for matrix multiplication, it’s also about the problem solving process that you can use to solve these kinds of problems.

(Quick note: This post has no relation to the phenomenal book “The Art of Problem Solving” about solving math problems—although I do love that book, and highly recommend it.)

This is the wonderful book that this post has no relationship to. https://artofproblemsolving.com/

The Problem Statement

First, let’s set up the problem.

The question put to us is this: “Using just built-in Python functions, can you write a one-line list comprehension to perform matrix multiplication on two matrices stored as lists of lists?”

Before reading the rest of the post, you might be interested in trying this yourself! Depending on your experience with Python and linear algebra, this might either be a fun challenge for the next 30 minutes or a grueling puzzle that takes you days. (Grueling puzzles can also be fun, if you’re into that kind of thing. I most certainly am.)

As an important side note—I will always, always, be sketching things out by hand when I’m trying to solve hard problems. I’m not going to share photos of my sketch work here, but know that there was a lot of whiteboard doodling going on as I worked through this problem. I highly recommend putting your thoughts into drawings when you’re problem solving. This also makes it much easier to communicate your thoughts to others!

Alright, so here’s an example of matrix multiplication:

https://en.wikipedia.org/wiki/Matrix_multiplication

Here are some helpful shortcuts to remember when dealing with matrix multiplication. Let’s say we’re multiplying A*B, where A is a (4x2) matrix and B is a (2x3) matrix (like in the example above).

  • The number of columns in A needs to match the number of rows in B. (We’re good in this case: there are 2 columns in A and 2 rows in B.) If this isn’t true, you can’t multiply the matrices together.
  • The resulting matrix will have as many rows as A and as many columns as B. So in this case, the result will be a (4x3) matrix.

(By the way—if you haven’t watched the 3Blue1Brown series on “The Essence of Linear Algebra”, it’s probably the best series on linear algebra I’ve ever seen. Each video, you’ll find yourself going “I had no idea that’s what all of this was about!” It gives you an amazing intuition for what’s going on in linear algebra.)

Moving on—here’s an example of list comprehensions in Python:
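As a quick refresher, here is a small illustrative sketch (the variable names are placeholders chosen for this example) showing a for loop next to the equivalent list comprehension:

```python
numbers = [1, 2, 3, 4, 5]

# Building a list of squares with an ordinary for loop.
squares_loop = []
for n in numbers:
    squares_loop.append(n ** 2)

# The same thing written as a list comprehension.
squares = [n ** 2 for n in numbers]

# A comprehension can also filter with a condition.
even_squares = [n ** 2 for n in numbers if n % 2 == 0]

print(squares)       # [1, 4, 9, 16, 25]
print(even_squares)  # [4, 16]
```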

And we want to create something like this…
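To make the target concrete, here is one possible set of sample inputs and the output we would expect. The matrices and the function name matrix_multiplication are assumptions made up for illustration, and the expected values were worked out by hand:

```python
# Hypothetical sample inputs: A is 4x2 and B is 2x3.
A = [[1, 2],
     [3, 4],
     [5, 6],
     [7, 8]]
B = [[1, 2, 3],
     [4, 5, 6]]

# The goal is a function whose body is a single list comprehension,
# called as matrix_multiplication(A, B), that returns this 4x3 product.
expected = [[9, 12, 15],
            [19, 26, 33],
            [29, 40, 51],
            [39, 54, 69]]
```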

So how do we go about solving this?

Problem Solving Part 1: Sketch Out a Solution

First, let’s think through how we’d solve this if there were no constraints.

Kind of cheating—using np.matmul()

If we weren’t restricted to just using built-in Python functions, and we didn’t need to use a list comprehension, the problem would be trivial—we could just use the NumPy matrix multiplication function.

import numpy as np

def numpy_matrix_multiplication(A, B):
    return np.matmul(A, B)  # same as "return A @ B" when A and B are NumPy arrays

But that’s obviously defeating the purpose of this puzzle.

Using a for loop and NumPy functions

So let’s say we can’t use np.matmul(), but anything else is fine. How would we approach it then?

Well, matrix multiplication can be thought of as taking a row of A, a column of B, doing the dot product between them, and then storing that result in the new matrix. Specifically, if we take the 2nd row of matrix A and the 3rd column of matrix B, and then we take the dot product of those two, then we’ll store the result in position (2, 3) in the new matrix (the value at row 2 column 3).

Let’s try to put this idea into code using for loops. We’ll still use NumPy for the matrix dot product for now, just so we don’t have to worry about it at first.
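One way this could look is sketched below. The function name is an assumption, and the code converts both inputs to NumPy arrays so that it can use B.T to walk over the columns of B and .dot() for the dot product:

```python
import numpy as np


def matrix_multiplication_numpy_loop(A, B):
    """Multiply A and B with explicit loops, leaning on NumPy for storage and dot products."""
    A = np.array(A)
    B = np.array(B)
    # The result has as many rows as A and as many columns as B.
    result = np.zeros((A.shape[0], B.shape[1]))
    for i, row in enumerate(A):
        # Iterating over B.T (the transpose) yields the columns of B.
        for j, column in enumerate(B.T):
            result[i, j] = row.dot(column)
    return result


A = [[1, 2], [3, 4], [5, 6], [7, 8]]
B = [[1, 2, 3], [4, 5, 6]]
print(matrix_multiplication_numpy_loop(A, B))
```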

This works! But we’re leaning pretty heavily on NumPy functions and objects currently (like NumPy arrays and the .dot() method), so in just a minute we’re going to see if we can write a for loop without some of this NumPy functionality.

(Also, notice how we’re using the built-in enumerate function to get the indices for where the value needs to go in our new matrix.)

Understanding the problem is sometimes easy and sometimes really difficult. It often helps to start with simpler versions of the problem, or to break the problem down into smaller and simpler pieces. Here are some examples of that process:

  • Before trying to implement matrix multiplication, make sure you really understand how it’s computed. Try it for two 2x2 matrices, working out the math by hand. Google around for good explanations of what matrix multiplication is doing.
  • Write the simplest list comprehension you can, and then increase the complexity slowly: you could add a for loop inside the list comp; then you could add a conditional; then maybe you could try two for loops; then two loops and a conditional; etc. Build things up in complexity, trying to develop an intuitive sense of what’s going on.
  • Before using list comprehensions, use for loops.
  • Before typing anything into a computer, use pencil and paper.
  • Sketch things and try things even if you don’t feel like you understand what’s going on. Experimentation is a wonderful way to learn!

Problem Solving Part 2: Improve the Solution Iteratively

Now that we basically know what we’re doing, we’re going to slowly improve our solution by getting it closer to the final product. In our case, this mostly means converting everything to use built-in Python functions and objects (rather than NumPy functions and objects). We’ll also do some code cleanup at the end.

This is where the bulk of the work actually happens

Writing our own dot product (replacing np.dot)

First, let’s get rid of the NumPy dot product function. To compute the dot product of two vectors of equal length, you essentially multiply the numbers that are at the same indices of each vector and then add them all together. So let’s say you have two vectors: (a1 a2) and (b1 b2). Then the dot product would be:

dot_product = a1*b1 + a2*b2

Let’s put this into our code.
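A sketch of this step, assuming a loop-based function that still uses NumPy arrays for storage and for B.T (the function name is illustrative). Only the dot product changes: it becomes a sum over pairwise products built with the built-in zip and sum functions:

```python
import numpy as np


def matrix_multiplication_own_dot(A, B):
    """Same loops as before, but the dot product is computed by hand."""
    A = np.array(A)
    B = np.array(B)
    result = np.zeros((A.shape[0], B.shape[1]))
    for i, row in enumerate(A):
        for j, column in enumerate(B.T):
            # Multiply the entries at matching indices, then add them all up.
            dot_product = sum(a * b for a, b in zip(row, column))
            result[i, j] = dot_product
    return result
```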

Creating the new matrix without relying on np arrays

Next, let’s think about how we can create the result matrix without using NumPy arrays to store the values at certain indices.

The way our for loops are nested, we take the dot product of a single row of A with every column of B before moving on to the next row of A. This means that the dot products for the first row of A all land in the first row of our resulting matrix.

(Remember that if we take the dot product of the 1st row in matrix A with the nth column in matrix B, we’ll store that value in position (1, n) in our new matrix.)

This means that each time we take a row in A and iterate through dot products of the columns in B, we can create a new list with all of those results. That list will become a new row in our resulting matrix! After going through each row in A, we’ll create a list of lists—which will be exactly the matrix we’re looking for.

If you’re not sure how this works, step through the for loop yourself and see what’s happening.

Here’s the code:
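A minimal sketch of this version, under the assumption that we still convert to NumPy arrays only so that B.T is available (the function name is made up for this example). The result is now an ordinary list of lists built one row at a time:

```python
import numpy as np


def matrix_multiplication_list_rows(A, B):
    """Build the result as a list of lists instead of a NumPy array."""
    A = np.array(A)
    B = np.array(B)  # still an array so that we can iterate over B.T
    result = []
    for i, row in enumerate(A):
        new_row = []
        for j, column in enumerate(B.T):
            dot_product = sum(a * b for a, b in zip(row, column))
            new_row.append(dot_product)
        # One finished row of the result for each row of A.
        result.append(new_row)
    return result
```

Notice that the indices i and j are no longer used for anything, which hints at a later cleanup.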

A brief interlude: what happened to the list comp?

Notice that we’ve effectively forgotten about the list comprehension part of the puzzle for now.

This is totally fine! In fact, it’s necessary—before solving the problem with certain parameters, we need to figure out how to solve the problem at all.

We’re currently figuring out how to do matrix multiplication using just built-in Python functions, and this isn’t a trivial task. After figuring this out using any approach we need (in this case, a for loop), we can then move on to crafting a solution that satisfies all of the requirements—namely, using a list comprehension.

Break down a problem into simpler parts and give yourself any resources you need before trying to solve the full thing perfectly.

Transposing B without using NumPy array functionality

Currently, we’re converting each matrix into a NumPy array at the beginning of the function so that we can transpose matrix B using “B.T”. We’re transposing it like this so that we can iterate through the columns of B using the normal Python “for item in my_list” syntax.

How can we iterate through the columns of B using built-in Python functionality?

This is a good time to mention that it’s incredibly helpful to test as you go along. As I’m solving this problem, I have an IPython terminal open and I’m constantly trying things out. I’ve created a matrix A and a matrix B to play with, and I’m testing my for loops and my functions as I go.

Test small things constantly—you’ll be learning with each small bit of code you write, and you won’t go too far in a bad direction.

So for answering this question—“how do we iterate through columns in matrix B without converting B to a NumPy array?”—I’m going to be trying all kinds of things in the terminal and seeing what works and what doesn’t. I’m also going to be Googling around for functions I may not know about.

In this case, a Google search for “python transpose list of lists” yields this StackOverflow result:

Perfect! This is a cool trick—using the * operator to “unpack” B into individual lists, then using zip() to put those lists together into tuples. Using the map function with “list” as the first argument returns those tuples as lists, rather than tuples. If we want to return a list of lists (rather than a map object), we can wrap this whole thing with “list()”.

If you’re not sure what that looks like, create a matrix B and try it yourself!

transposed_B = list(map(list, zip(*B)))

In our case, we don’t necessarily care about returning a list of lists—we’re fine with a list of tuples—so we’ll drop the “map(list, …)” part of this answer. We’re also fine with iterating through the zip object (rather than explicitly converting to a list of tuples), so we’ll just use zip(*B).

transposed_B = zip(*B)

Done!
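Here is a small sketch of what these expressions produce for a sample matrix (the values in B are made up for illustration). It also shows one thing to watch for: a zip object can only be looped over once, so it is safer to call zip(*B) each time the columns are needed than to store it and reuse it.

```python
B = [[1, 2, 3],
     [4, 5, 6]]

# zip(*B) is the same as zip([1, 2, 3], [4, 5, 6]): it groups the first
# entries together, then the second entries, and so on. Those are the columns.
print(list(zip(*B)))             # [(1, 4), (2, 5), (3, 6)]
print(list(map(list, zip(*B))))  # [[1, 4], [2, 5], [3, 6]]

# A zip object is exhausted after one pass.
columns = zip(*B)
first_pass = list(columns)
second_pass = list(columns)
print(second_pass)               # []
```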

Now that we don’t need to convert B to a NumPy array, let’s rewrite that functionality and see what our code looks like:
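A sketch of the loop at this stage, with an illustrative function name. No NumPy import is needed any more, and zip(*B) is called inside the outer loop so that a fresh set of columns is available for every row of A:

```python
def matrix_multiplication_no_numpy(A, B):
    """Loop-based matrix multiplication using only built-in functions."""
    result = []
    for i, row in enumerate(A):
        new_row = []
        # zip(*B) yields the columns of B as tuples.
        for j, column in enumerate(zip(*B)):
            dot_product = sum(a * b for a, b in zip(row, column))
            new_row.append(dot_product)
        result.append(new_row)
    return result
```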

Code cleanup—remove enumerate and dot_product variable

Our for loop code now computes the matrix multiplication of A and B without using any NumPy functions! We’re getting really close to the point of trying to convert this all into a one-line list comprehension.

First, let’s clean up the code a bit. We’ll do these three things:

  1. Remove enumerate. We’re no longer using our row and column indices for anything, so we can just iterate through the rows and columns themselves.
  2. Instead of storing the dot product as a new variable before appending it to new_row, we’ll just directly append it. This will make the conversion to a list comprehension a little easier.
  3. Add an “if __name__ == '__main__':” block to the bottom of the file as a place where we can test our new function out. This way, if we wanted, we could simply run the file (for example: $ python matrix_multiplication.py) and see if our function still works.

Here’s the code:
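A sketch of the cleaned-up version, assuming the three changes listed above. The function name and the sample matrices are illustrative choices:

```python
def matrix_multiplication_for_loop(A, B):
    """Loop-based matrix multiplication with no indices and no temporary variable."""
    result = []
    for row in A:
        new_row = []
        for column in zip(*B):
            # Append the dot product of this row and column directly.
            new_row.append(sum(a * b for a, b in zip(row, column)))
        result.append(new_row)
    return result


if __name__ == "__main__":
    # Sample 4x2 and 2x3 matrices for a quick manual check.
    A = [[1, 2], [3, 4], [5, 6], [7, 8]]
    B = [[1, 2, 3], [4, 5, 6]]
    print(matrix_multiplication_for_loop(A, B))
    # [[9, 12, 15], [19, 26, 33], [29, 40, 51], [39, 54, 69]]
```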

Problem Solving Part 3: Fulfill All Requirements

We’re almost there! Sort of.

We’ve written out matrix multiplication in Python using only built-in functions, but we’re currently using for loops. Now, we need to convert everything that we’ve written into a one-line list comprehension.

Nested for loops in list comprehensions

First, let’s think about this. We currently have two nested for loops. How would we write two nested for loops using a list comprehension?

Googling “python list comprehension nested loop” brings us to this handy-dandy StackOverflow answer:

So it looks like this is the order for nested loops in list comps:

  1. First, write the result you want to return: j+k
  2. Then, write the first loop: for j in s1
  3. Then, write the second loop: for k in s2
  4. Put it all together: [j+k for j in s1 for k in s2]
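The ordering described in these steps can be checked with a short sketch. The contents of s1 and s2 are assumptions chosen for illustration:

```python
s1 = [1, 2]
s2 = [10, 20, 30]

# Two nested for loops...
sums_loop = []
for j in s1:
    for k in s2:
        sums_loop.append(j + k)

# ...and the equivalent list comprehension, with the loops in the same order.
sums = [j + k for j in s1 for k in s2]

print(sums)  # [11, 21, 31, 12, 22, 32]
```

Note that this form produces one flat list, not a list of lists.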

What does it look like if we convert our for loops into this structure?

  1. Write the result: …hmm, is our result the dot product sum? Or is it the list append? What exactly is our result here?

Ok, so we need to do a little more thinking.

Turning our second for loop into a list comprehension

List comprehensions are basically just for loops in a different format. So what happens if we try squeezing our second for loop into a list comprehension before trying to do both loops?

We basically just take our second for loop and turn it into a list comp using the usual syntax—no nested looping required. Now we’re in a good position to return to our original question of what result we want to return in our final list comp.
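A sketch of this intermediate step, with an illustrative function name. The outer loop is still an ordinary for loop, and only the inner loop has become a list comprehension that builds new_row:

```python
def matrix_multiplication_inner_comp(A, B):
    """Outer for loop, inner list comprehension."""
    result = []
    for row in A:
        # One dot product per column of B, collected into a single row.
        new_row = [sum(a * b for a, b in zip(row, column)) for column in zip(*B)]
        result.append(new_row)
    return result
```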

Turning the outer for loop into a list comprehension

So what are we returning for our final list comprehension?

Well, it looks like we’re returning “new_row”. Does this make sense? If we think about it, it does: for each row in A, we want to return a new row in the new matrix. That new row just happens to be created through the process of the rather-complicated-looking list comprehension that we just created using the dot products with the columns of B.

So let’s try squeezing our outer loop into a single list comprehension now…
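One way to write this is sketched below, with an illustrative function name. The overall shape is [L for row in A], where L stands for the list comprehension that builds one row:

```python
def matrix_multiplication_comp(A, B):
    # Outer comprehension: one new row per row of A.
    # Inner comprehension: one dot product per column of B.
    return [[sum(a * b for a, b in zip(row, column)) for column in zip(*B)] for row in A]
```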

Since we’ve already done the work of squeezing our inner loop down into a list comp, this part actually seems pretty easy! (Other than the fact that the inner list comp is already pretty long and complicated looking.) We don’t even need the nested for loop list comprehension syntax, because our inner loop is “hidden” inside the inner list comprehension.

If you need to think about what this looks like without that complicated list comprehension staring at you, just replace it with L or some other variable, squeeze the outer for loop into the final list comp, and then replace L with the complicated inner loop list comprehension again.

This is another good tip for problem solving:

If complicated math is throwing you off, replace the complicated math with a short description of what the math is doing or a simple variable that represents the complicated math—and then keep working your way through the problem. Get an intuitive understanding for what’s going on before you bring in long, scary, complicated calculations.

Finalizing our solution

I want to say that we’re “bringing it all together” here, but we’ve already done that! The only thing left is to clean up our code and make sure our function docstring looks good.

Of course, at this point we could also do more extensive testing to make sure our solution works. It’s often considered best practice to write tests before starting development (for example, using unittest) so that you can think of your edge cases and desired functionality before getting too deep in the coding, and so that you can test yourself as you’re going along.

Here’s what the final code looks like:
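A sketch of how the finished file might look. The function name, the docstring wording, and the sample matrices are assumptions for illustration:

```python
def matrix_multiplication(A, B):
    """Return the matrix product of A and B as a list of lists.

    A and B are matrices stored as lists of rows. The number of columns
    in A is assumed to match the number of rows in B; this is not checked.
    """
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


if __name__ == "__main__":
    A = [[1, 2], [3, 4], [5, 6], [7, 8]]  # 4x2
    B = [[1, 2, 3], [4, 5, 6]]            # 2x3
    print(matrix_multiplication(A, B))
    # [[9, 12, 15], [19, 26, 33], [29, 40, 51], [39, 54, 69]]
```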

Yay! We’re done.

A couple qualifications

There are a couple things we should point out, since we’ve told ourselves we’re done with the problem.

First, our function as it currently exists doesn’t check to see if the matrices can actually be multiplied together. If our A matrix is a 4x2 and our B matrix is a 3x3, we aren’t going to be able to multiply them together—our code doesn’t currently account for that.

If we wanted to do this checking, we could add a couple very simple lines at the beginning of the function. Whether or not this breaks the “one line list comp” requirement is subject to debate:

# We check to see if the number of columns
# in A matches the number of rows in B.
if len(A[0]) != len(B):
    # Raise some kind of exception here.
    raise Exception
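To see why the check matters, here is a sketch that puts it inside a function. The function name and the choice of ValueError are assumptions for illustration. Without the check, zip() stops at the shorter of its inputs, so mismatched shapes produce a wrong answer silently instead of an error:

```python
def checked_matrix_multiplication(A, B):
    """One-line list comprehension guarded by a shape check."""
    if len(A[0]) != len(B):
        raise ValueError(
            "A has %d columns but B has %d rows" % (len(A[0]), len(B))
        )
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


A = [[1, 2], [3, 4], [5, 6], [7, 8]]   # 4x2
B = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # 3x3, cannot be multiplied with A

# The unchecked comprehension quietly ignores the third row of B.
unchecked = [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]
print(unchecked[0])  # [9, 12, 15]

try:
    checked_matrix_multiplication(A, B)
except ValueError as error:
    print("rejected:", error)
```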

Second—just to reinforce testing once more—we really should’ve been testing this from the beginning. Whenever you start a problem, it can help to ask: “What are examples of input to this problem, and expected output? What kinds of edge cases are there that might break my solution if I’m not careful?” Edge cases often exist for things like:

  • Input data we aren’t expecting (like dictionaries instead of lists, floats instead of ints, or strings instead of numbers).
  • Empty / null / zero input data (like [], {}, None, 0, and so on).
  • Negative input data (if we only test things with positive numbers, negative numbers could cause issues).
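A few of these cases can be turned into quick checks. The sketch below uses plain assert statements and a hypothetical run_checks helper. The one-line function is repeated so the block runs on its own, and the expected values were worked out by hand:

```python
def matrix_multiplication(A, B):
    """One-line list comprehension version, repeated here so the checks are self-contained."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def run_checks():
    # A 4x2 by 2x3 example.
    A = [[1, 2], [3, 4], [5, 6], [7, 8]]
    B = [[1, 2, 3], [4, 5, 6]]
    assert matrix_multiplication(A, B) == [[9, 12, 15], [19, 26, 33], [29, 40, 51], [39, 54, 69]]

    # Multiplying by the identity matrix leaves a matrix unchanged.
    assert matrix_multiplication([[1, 0], [0, 1]], [[5, 6], [7, 8]]) == [[5, 6], [7, 8]]

    # Negative numbers.
    assert matrix_multiplication([[-1, 2]], [[3], [-4]]) == [[-11]]

    # Floats instead of ints.
    assert matrix_multiplication([[0.5, 1.5]], [[2.0], [4.0]]) == [[7.0]]

    # Empty input: with no rows in A, the result is an empty list.
    assert matrix_multiplication([], []) == []


if __name__ == "__main__":
    run_checks()
    print("all checks passed")
```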

Conclusions

This was a complicated little puzzle, so like I mentioned in the beginning, your current level of experience is going to determine how much of this you understand.

No matter your level, the techniques here can help you solve all kinds of problems:

  • Define the Problem. Start with a clear understanding of your problem and what solutions look like.
  • Sketch the Solution. Using any tools and approaches you can think of, sketch out what a solution looks like. It won’t be the best solution—just find a way to get to something that works. Break the problem into smaller sub-problems if needed, or solve a much simpler version of the full problem. Don’t worry about using real code or math.
  • Iterate Toward Better Solutions. Look at one piece at a time and make it better. Don’t try to go from a bad solution to a perfect solution in one jump—and realize that there actually is no such thing as a “perfect” solution. Nothing is perfect: things are just better or worse when evaluated on a certain performance metric. (And things that are better on one metric might often be worse on another, like common tradeoffs of time vs. memory in algorithm design.)
  • Fulfill All Requirements. Make sure your final solution fulfills all the requirements set out in the beginning!
  • Test, Test, Test. Test small bits of the solution while you’re working. Test functions that you haven’t used before. Test lines of your code separately, then test it all together. Write a battery of tests that you can throw at your function anytime, so that you know it’s working correctly (or so you know when you’ve finally built something that works). Always be testing!

This last point is so important that I want to frame it another way.

Problem solving is experimentation. You experiment with various problem definitions; you experiment with general approaches to the solution; and you run small experiments constantly to validate each piece of what you’re doing. Experimenting means getting feedback, and the faster you can experiment the faster you can build something that meets your requirements. Experiment, experiment, experiment.

I hope you can find something useful in this post—something that takes your problem solving skills to the next level. Claps and shares are greatly appreciated!

And finally, if you have any good resources on problem solving, I would love to hear about them.

Happy problem solving!

Steven Rouk
