An Introduction to Subplots in Matplotlib

Lili Beit
Published in Analytics Vidhya
Jan 22, 2021

In this story, I’ll discuss:

  1. How to create subplots in matplotlib
  2. Adding labels and padding
  3. How to automate subplot creation using a for loop

As a Python beginner, I found myself unexpectedly flummoxed when I tried to create subplots. What’s the difference between the plt.subplot() and plt.subplots() functions? What do cryptic three-digit numbers such as plt.subplot(131) mean? And what does “fig, ax” mean? I will try to demystify subplot syntax below.

Photo by Amber Engle on Unsplash

What are subplots and why are they useful?

Using subplots simply means putting more than one plot in the same figure. For example, here is a figure containing six subplots. This data is from the King County House Sales dataset, available on Kaggle.

When presenting your data, it can be helpful to show multiple graphs at a glance, for comparison and a better understanding of the overall dataset. It’s also easier to add one image to a slide deck than to paste and align multiple images.

Basic Subplot Creation

Here is some basic code to create subplots:

# import pandas, matplotlib and seaborn
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# choose style for plots
sns.set_style("darkgrid")
# import data
df = pd.read_csv('kc_house_data.csv')
# define figure and axes for subplots
fig, ax = plt.subplots(figsize=(15,4), nrows=1, ncols=3);

Here, we are defining the figure (fig) and axes (ax) that make up the output of the subplots() function. A “figure” means the entire image containing all the subplots. “Axes” refers to the individual subplots. When there is more than one subplot, ax is a NumPy array holding one Axes object per subplot. Since we specified 1 row and 3 columns, our output will be a row of three empty subplots.
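To see what fig and ax actually are, here is a minimal sketch that inspects the return values of plt.subplots(). It uses no data. It assumes the non-interactive Agg backend so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

# one row, three columns, as in the code above
fig, ax = plt.subplots(figsize=(15, 4), nrows=1, ncols=3)

print(type(fig).__name__)  # Figure
print(type(ax).__name__)   # ndarray: a NumPy array holding the axes
print(ax.shape)            # (3,)

# every element of the array is one subplot (an Axes object)
for i, single_ax in enumerate(ax):
    print(i, isinstance(single_ax, Axes))
```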

Let’s add some data.

# add plots to each axis
fig, ax = plt.subplots(figsize=(15,4), nrows=1, ncols=3);
ax[0].hist(df['bedrooms'])
ax[0].set_title('bedrooms')
ax[1].hist(df['bathrooms'])
ax[1].set_title('bathrooms')
ax[2].hist(df['sqft_living'])
ax[2].set_title('sqft_living');

Here, we used the first, second, and third axes to plot histograms. We called the hist and set_title methods on each item in the array ‘ax’. Defining each axis in the array separately would yield the same result:

# add plots to each axis - alternative code
fig, [ax1, ax2, ax3] = plt.subplots(figsize=(15,4), nrows=1, ncols=3);
ax1.hist(df['bedrooms'])
ax1.set_title('bedrooms')
ax2.hist(df['bathrooms'])
ax2.set_title('bathrooms')
ax3.hist(df['sqft_living'])
ax3.set_title('sqft_living');
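Both versions give you the very same Axes objects. The second one just unpacks the array into three names. A quick sketch to confirm this, assuming the Agg backend and using no data:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=1, ncols=3)

# unpack the 1D array of axes into three separate names
ax1, ax2, ax3 = ax

# the names point at the same objects as the array elements
print(ax1 is ax[0], ax2 is ax[1], ax3 is ax[2])  # True True True
```

If the number of names does not match the number of subplots, Python raises a ValueError during unpacking. That makes this style a handy check that the grid is the size you expected.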

Ok, great. Now guess what — there is yet another way to create the exact same subplots. You may see subplots defined individually using plt.subplot():

# create axes individually using plt.subplot()
plt.figure(figsize=(18,6)); # need this step to set figure size
ax1 = plt.subplot(1, 3, 1)
df['bedrooms'].hist()  # pandas draws on the current axes (ax1)
plt.title('bedrooms')
ax2 = plt.subplot(1, 3, 2)
df['bathrooms'].hist()
plt.title('bathrooms')
ax3 = plt.subplot(1, 3, 3)
df['sqft_living'].hist()
plt.title('sqft_living')
# save the figure as a PNG (assumes an images/ folder already exists)
plt.savefig('images/hist_no_figure')

The numbers in the subplot function refer to the parameters nrows, ncols, and index. So

ax1 = plt.subplot(1, 3, 1)

means we are adding this plot to a figure with 1 row and 3 columns, and that this plot should be the first axis in the figure. The commas can also be omitted as long as two conditions hold. The numbers must be in the right order: total number of rows, total number of columns, and index of the plot. Each number must also be a single digit, so this shorthand only works for figures with at most 9 subplots. For example:

ax1 = plt.subplot(131)
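The sketch below checks that the two spellings describe the same grid position. It also shows why the plt.subplot() style works at all. Each call makes the new subplot the current axes (plt.gca()), and plt.title() and pandas’ .hist() draw on whatever axes is current. It compares positions with SubplotSpec.get_geometry() and assumes the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

plt.figure()
a = plt.subplot(131)      # three-digit shorthand
plt.figure()
b = plt.subplot(1, 3, 1)  # explicit nrows, ncols, index

# both subplots sit in the same cell of a 1x3 grid
same_cell = a.get_subplotspec().get_geometry() == b.get_subplotspec().get_geometry()
print(same_cell)                                   # True
print(a.get_subplotspec().get_gridspec().get_geometry())  # (1, 3)

# the most recently created subplot becomes the current axes
c = plt.subplot(1, 3, 2)
print(plt.gca() is c)     # True

# the shorthand breaks once any number needs two digits
try:
    plt.subplot(4310)     # meant as 4 rows, 3 columns, index 10
    shorthand_failed = False
except ValueError:
    shorthand_failed = True
print(shorthand_failed)   # True
```

For a fourth-row, tenth-subplot case, the comma form plt.subplot(4, 3, 10) works fine.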

Making Subplots Readable — Labels and Padding

Before we move on to automating subplot creation, let’s add labels to our subplots and adjust the white space between them. Above, we added titles to each subplot, but what if we want a title for the whole figure? To do this, we can call the suptitle() method on the figure.

fig.suptitle('Histograms', fontsize=15)

Additionally, x and y axis labels can be added to each subplot, as shown below.

ax[0].set_xlabel('number of bedrooms')
ax[0].set_ylabel('number of homes')

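The plt.subplots_adjust() function is useful for adjusting the white space between plots, with wspace for the horizontal gaps and hspace for the vertical gaps. It also controls the space between the figure title and the plots. For that, use the parameter “top”, which sets the top edge of the subplot area as a fraction of the figure height. Putting it all together: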

fig, ax = plt.subplots(figsize=(18,6), nrows=1, ncols=3, sharey=True);
fig.suptitle('Features of King County Homes Sold in 2014 and 2015', fontsize=21)
ax[0].hist(df['bedrooms'])
ax[0].set_title('Bedrooms Distribution', fontsize=18)
ax[0].set_xlabel('number of bedrooms', fontsize=15)
ax[0].set_ylabel('number of homes', fontsize=15)
ax[1].hist(df['bathrooms'])
ax[1].set_title('Bathrooms Distribution', fontsize=18)
ax[1].set_xlabel('number of bathrooms', fontsize=15)
ax[1].set_ylabel('number of homes', fontsize=15)
ax[2].hist(df['sqft_living'])
ax[2].set_title('Square Feet Distribution', fontsize=18)
ax[2].set_xlabel('living space square footage', fontsize=15)
ax[2].set_ylabel('number of homes', fontsize=15)
plt.subplots_adjust(top=0.85, wspace=0.15)

Note that the plt.subplots() function here has one additional parameter specified:

sharey=True

This makes all three graphs share the same y-axis range, for ease of comparison between them.
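To see what sharey=True and plt.subplots_adjust() actually change, the sketch below plots three histograms and then reads the settings back. The data is made up: random numbers stand in for the King County columns. Shared y-axes end up with identical limits, and the spacing values are stored on the figure’s subplotpars.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# hypothetical stand-ins for three very different columns
samples = [rng.normal(3, 1, 500),
           rng.normal(2, 0.5, 200),
           rng.normal(2000, 600, 1000)]

fig, ax = plt.subplots(figsize=(18, 6), nrows=1, ncols=3, sharey=True)
fig.suptitle('Shared y-axis demo', fontsize=21)
for a, data in zip(ax, samples):
    a.hist(data)

# with sharey=True every subplot reports the same y-limits
print(ax[0].get_ylim() == ax[1].get_ylim() == ax[2].get_ylim())  # True

# leave room for the suptitle and set the gap between subplots
plt.subplots_adjust(top=0.85, wspace=0.15)
print(fig.subplotpars.top, fig.subplotpars.wspace)  # 0.85 0.15
```

Figure.subplots_adjust() takes the same parameters, so fig.subplots_adjust(top=0.85, wspace=0.15) is equivalent when you already hold a reference to the figure.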

Automating Subplot Creation

Now for the fun part! What if you have many subplots to include? For example, let’s pick 6 columns from the King County home sales data and make a histogram for each, as we saw at the beginning of this story.

It would be time-consuming to assign a plot, title, and labels for six (or more) axes, so we can use a for loop. First, note that for figures with more than one row and more than one column of subplots, plt.subplots() returns a two-dimensional array of axes. Each axis is therefore identified by its row and column position. If we were going to write out code for the six plots the long way, we would start like this:

fig, ax = plt.subplots(figsize=(15,10), nrows=2, ncols=3, sharey=True);
fig.suptitle('Histograms', fontsize=15)
ax[0,0].hist(df['bedrooms'])
ax[0,0].set_title('Bedrooms Distribution')
ax[0,0].set_xlabel('number of bedrooms')
ax[0,0].set_ylabel('number of homes')
ax[0,1].hist(df['bathrooms'])
ax[0,1].set_title('Bathrooms Distribution')
ax[0,1].set_xlabel('number of bathrooms')
ax[0,1].set_ylabel('number of homes')
# I don't want to write these out anymore... let's automate them

Because this figure has more than one row and more than one column of subplots, the axes must be identified by [row, column] instead of just [column]. For example, [0,0] indicates the axis is in the first row and first column, and [0,1] indicates the axis is in the first row and second column. (Gotta love Python for starting its indexing at 0 instead of 1.) If we wanted to call the axis in the second row, first column, we would use [1,0]. In a figure with at least three rows, the third row, second column would be [2,1]. You get the idea!
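Whether ax needs one index or two depends on the grid shape. By default, plt.subplots() squeezes away any dimension of length 1. A single row therefore gives a 1D array, while a grid with several rows and several columns gives a 2D array. Passing squeeze=False always returns a 2D array, and ax.flat walks a 2D grid row by row. A small sketch, assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

_, ax_row = plt.subplots(nrows=1, ncols=3)
_, ax_grid = plt.subplots(nrows=2, ncols=3)
_, ax_2d = plt.subplots(nrows=1, ncols=3, squeeze=False)

print(ax_row.shape)   # (3,)   -> index as ax_row[col]
print(ax_grid.shape)  # (2, 3) -> index as ax_grid[row, col]
print(ax_2d.shape)    # (1, 3) -> index as ax_2d[0, col]

# ax.flat visits [0,0], [0,1], [0,2], [1,0], ... in row-major order
flat_list = list(ax_grid.flat)
print(flat_list[3] is ax_grid[1, 0])  # True
```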

Now, how do we automate the creation of many rows of subplots? To create each [row, column] identifier using a for loop, we can use this code:

# choose features for the subplots
features = ['bedrooms',
            'bathrooms',
            'sqft_living',
            'floors',
            'condition',
            'yr_built']
# choose how many columns you want
num_cols = 3
# set a number of rows
if len(features) % num_cols == 0:
    num_rows = len(features) // num_cols
else:
    num_rows = (len(features) // num_cols) + 1
# define the figure and axes
# (ax is a 2D array here because num_rows and num_cols are both > 1)
fig, ax = plt.subplots(figsize=(17,12),
                       nrows=num_rows,
                       ncols=num_cols)
# add a figure title
fig.suptitle('Features of King County Homes Sold in 2014 and 2015',
             fontsize=21)
# use a for loop to create each subplot:
for feat in features:
    row = features.index(feat) // num_cols
    col = features.index(feat) % num_cols

    ax[row, col].hist(df[feat], bins=20)
    ax[row, col].set_title(feat.title() + ' Distribution',
                           fontsize=18)
    ax[row, col].set_xlabel(feat.title(),
                            fontsize=15)
    ax[row, col].set_ylabel('Number of Homes',
                            fontsize=15)

plt.subplots_adjust(top=0.9, wspace=0.3, hspace=0.3)

Let’s look at the for loop first. Here, we assign each feature in our ‘features’ list to a particular subplot [row, column] based on the feature’s index position in the list.

To do this, we use the // and % operators.

The Floor Division operator // returns the quotient of two numbers, and then rounds down to the nearest integer. The Modulo operator % divides the first number by the second and returns the remainder.

Let’s think back to fourth grade for a moment. Remember when 11 divided by 3 was not called 3.67 but was called “3 remainder 2”? In the same way,

11 // 3 = 3

and

11 % 3 = 2

The Floor Division operator gives us the quotient without the remainder, and the Modulo operator gives us the remainder.
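Python’s built-in divmod() returns both results in one call. It is a handy way to check the arithmetic above:

```python
# floor division gives the quotient, modulo gives the remainder
print(11 // 3)        # 3
print(11 % 3)         # 2

# divmod() returns both at once as a (quotient, remainder) tuple
print(divmod(11, 3))  # (3, 2)

# quotient * divisor + remainder always rebuilds the original number
q, r = divmod(11, 3)
print(q * 3 + r)      # 11
```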

This is useful for our for loop. We can automate subplot creation by dividing the feature’s index in the ‘features’ list by the number of columns in our figure. The rounded-down quotient will be the feature’s row position, and the remainder will be the feature’s column position. So, for the sixth item in our ‘features’ list (index 5, since counting starts at 0),

ax[row, col]

is the same as saying:

ax[5 // 3, 5 % 3]

or

ax[1, 2]

which puts this subplot in the second row, third column.
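To check the mapping for every feature at once, this sketch prints the [row, column] position that each index lands on. It is an alternative spelling, not the code used for the figure. It swaps features.index(feat) for enumerate(), which yields each item’s index directly. Unlike .index(), which returns the first match, enumerate() would still place each item correctly if the list ever contained a duplicate name. divmod() does the // and % steps together.

```python
features = ['bedrooms', 'bathrooms', 'sqft_living',
            'floors', 'condition', 'yr_built']
num_cols = 3

positions = {}
for i, feat in enumerate(features):
    # same as (i // num_cols, i % num_cols)
    row, col = divmod(i, num_cols)
    positions[feat] = (row, col)
    print(i, feat, (row, col))
```

The last line printed is 5 yr_built (1, 2), matching the ax[1, 2] worked out above.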

The ‘if’ statement in the code above also uses the // and % operators to assign a number of rows to the figure, based on how many columns you have chosen:

# set a number of rows
if len(features) % num_cols == 0:
    num_rows = len(features) // num_cols
else:
    num_rows = (len(features) // num_cols) + 1

This code is saying: if the features list length divided by the number of columns has a remainder of 0, then the number of rows should be just that — the length of the features list divided by the number of columns. If the remainder is not 0, then add one extra row (to hold the remaining subplots).
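The if/else above is a ceiling division: it rounds up whenever there is a remainder. The sketch below writes it as -(-n // k), equivalent to math.ceil(n / k), and checks it against the if/else for a range of list lengths.

It then handles a side effect of that extra row. With, say, 7 features in 3 columns there are 9 axes but only 7 plots, so the leftover axes are removed with fig.delaxes(). The seventh feature name and the random data are hypothetical placeholders.

The sketch also passes squeeze=False, which keeps ax two-dimensional even when everything fits in one row. Without it, ax[row, col] would fail for 3 or fewer features.

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import numpy as np


def rows_if_else(n_features, num_cols):
    """The if/else logic from the article."""
    if n_features % num_cols == 0:
        return n_features // num_cols
    return n_features // num_cols + 1


def rows_ceil(n_features, num_cols):
    """Ceiling division using only integers."""
    return -(-n_features // num_cols)


# the two formulas agree for every list length tried here
for n in range(1, 20):
    assert rows_if_else(n, 3) == rows_ceil(n, 3) == math.ceil(n / 3)

# hypothetical 7-feature example: 3 columns -> 3 rows -> 9 axes
features = ['bedrooms', 'bathrooms', 'sqft_living',
            'floors', 'condition', 'yr_built', 'sqft_lot']
num_cols = 3
num_rows = rows_ceil(len(features), num_cols)

rng = np.random.default_rng(1)
# squeeze=False keeps ax 2D even if num_rows were 1
fig, ax = plt.subplots(figsize=(17, 12), nrows=num_rows,
                       ncols=num_cols, squeeze=False)
for i, feat in enumerate(features):
    row, col = divmod(i, num_cols)
    ax[row, col].hist(rng.normal(size=100), bins=20)  # placeholder data
    ax[row, col].set_title(feat.title() + ' Distribution')

# remove the axes that did not receive a feature
for unused in ax.flat[len(features):]:
    fig.delaxes(unused)

print(num_rows, len(fig.axes))  # 3 7
```

If you would rather keep the empty cells as blank space, calling unused.set_axis_off() instead of fig.delaxes(unused) hides them without removing them.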

And there we have it — multiple axes created using a for loop, and the // and % operators to set row and column positions for each axis. All code for this article can be found here. Thanks for reading, and I’d love to hear your comments and suggestions.
