Matplotlib Tutorial

Matplotlib - subplots() Function



The matplotlib.pyplot.subplots() function creates a figure and a grid of subplots in a single call, while giving reasonable control over how the individual plots are created. The syntax for this function is given below:

Syntax

matplotlib.pyplot.subplots(nrows=1, ncols=1, **fig_kw)

Parameters

nrows Optional. Specifies the number of rows in the subplot grid. Default is 1.
ncols Optional. Specifies the number of columns in the subplot grid. Default is 1.
**fig_kw Optional. Additional keyword arguments, which are passed on to the pyplot.figure call.
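
As a minimal sketch of how **fig_kw is forwarded, the snippet below passes figsize, which is a parameter of pyplot.figure rather than of subplots() itself. The Agg backend line is an assumption made only so that the sketch runs without a display; it is not part of the original tutorial.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

# figsize is not consumed by subplots(); it travels through **fig_kw
# to pyplot.figure(), which sets the figure size in inches.
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(9, 4))

print(ax.shape)               # (2, 3): two rows, three columns of Axes
print(fig.get_size_inches())  # width and height of the figure in inches

plt.close(fig)  # release the figure once it is no longer needed
```

Any other keyword accepted by pyplot.figure, such as dpi, can be supplied in the same way.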

Return Value

Returns a tuple (fig, ax), where fig is the Figure and ax is either a single Axes object or an array of Axes objects. Typical idioms for handling the return value are:

# using the variable ax for single Axes
fig, ax = plt.subplots()

# using the variable ax for multiple Axes
#axes can be accessed using ax[0][0], ax[0][1],
#and so on, OR ax[0, 0], ax[0, 1]
fig, ax = plt.subplots(2, 2)

# using tuple unpacking for multiple Axes
#axes can be accessed using ax1, ax2, etc.
fig, (ax1, ax2) = plt.subplots(1, 2)
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
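
The shape of the returned ax depends on the grid, which determines which of the idioms above applies. The sketch below inspects the return value for a few grid sizes. It relies on the squeeze parameter of subplots() (default True), which the simplified syntax above does not list, and it uses the Agg backend as an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

fig1, single = plt.subplots()        # one Axes object, not an array
fig2, row = plt.subplots(1, 2)       # 1-D array: use row[0], row[1]
fig3, grid = plt.subplots(2, 2)      # 2-D array: use grid[0, 0] or grid[0][0]
fig4, always_2d = plt.subplots(1, 2, squeeze=False)  # keep both dimensions

print(isinstance(single, np.ndarray))  # False
print(row.shape)                       # (2,)
print(grid.shape)                      # (2, 2)
print(always_2d.shape)                 # (1, 2)
print(grid[0, 1] is grid[0][1])        # True: both index styles give the same Axes

plt.close("all")
```

Note that for a single row or a single column the array is one-dimensional, so ax[0][0] would fail there unless squeeze=False is passed.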

Example: creating a grid of subplots

In the example below, a 2x2 grid of subplots is created, with each subplot displaying a different plot.

import matplotlib.pyplot as plt
import numpy as np

#creating an array of values from
#0.1 up to 10 in steps of 0.1
x = np.arange(0.1, 10, 0.1)
y1 = np.sin(x)
y2 = np.tan(x)
y3 = np.exp(x)
y4 = np.log(x)

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

#plotting curves
ax1.plot(x, y1) 
ax2.plot(x, y2) 
ax3.plot(x, y3) 
ax4.plot(x, y4) 

#formatting axes
ax1.set_title("Sine")
ax2.set_title("Tan")
ax3.set_title("Exp")
ax4.set_title("Log")

#displaying the figure
plt.tight_layout()
plt.show()

The output of the above code will be:

[Figure: 2x2 grid of subplots titled Sine, Tan, Exp, and Log]
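
When every subplot is handled the same way, the repeated calls in the example above can be replaced by a loop. The following is a sketch of one possible variant, not the tutorial's own method: it keeps the Axes as an array and iterates over ax.flat, which visits the grid row by row. The curves dictionary is a name introduced here for illustration, and the Agg backend is an assumption so that the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0.1, 10, 0.1)

# hypothetical helper mapping: subplot title -> y values
curves = {
    "Sine": np.sin(x),
    "Tan": np.tan(x),
    "Exp": np.exp(x),
    "Log": np.log(x),
}

fig, ax = plt.subplots(2, 2)

# ax.flat walks the 2x2 array in row-major order:
# ax[0, 0], ax[0, 1], ax[1, 0], ax[1, 1]
for axes, (title, y) in zip(ax.flat, curves.items()):
    axes.plot(x, y)
    axes.set_title(title)

fig.tight_layout()

titles = [a.get_title() for a in ax.flat]
print(titles)  # ['Sine', 'Tan', 'Exp', 'Log']

plt.close(fig)
```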

Super Title

The suptitle() function can be used to add a title to the entire figure.

Example: adding title to a figure

Consider the example below, where a super title is added to the above plot. Note also the two different ways of indexing the array of Axes returned by the subplots() function.

import matplotlib.pyplot as plt
import numpy as np

#creating an array of values from
#0.1 up to 10 in steps of 0.1
x = np.arange(0.1, 10, 0.1)
y1 = np.sin(x)
y2 = np.tan(x)
y3 = np.exp(x)
y4 = np.log(x)

fig, ax = plt.subplots(2, 2)

#plotting curves
#method 1: a single index with row and column
ax[0, 0].plot(x, y1)
ax[0, 1].plot(x, y2)
#method 2: chained indexing, row first and then column
ax[1][0].plot(x, y3)
ax[1][1].plot(x, y4)

#formatting axes
ax[0, 0].set_title("Sine")
ax[0, 1].set_title("Tan")
ax[1][0].set_title("Exp")
ax[1][1].set_title("Log")

#setting super title
plt.suptitle("Mathematical Functions")

#displaying the figure
plt.tight_layout()
plt.show()

The output of the above code will be:

[Figure: 2x2 grid of subplots titled Sine, Tan, Exp, and Log, under the super title "Mathematical Functions"]
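
plt.suptitle() acts on the current figure. As a hedged alternative, the same title can be set through the Figure object returned by subplots(), using its suptitle() method, which avoids relying on whichever figure happens to be current. The sketch below uses a smaller 1x2 grid for brevity and assumes the Agg backend so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0.1, 10, 0.1)

fig, ax = plt.subplots(1, 2)  # 1-D array of two Axes
ax[0].plot(x, np.sin(x))
ax[0].set_title("Sine")
ax[1].plot(x, np.log(x))
ax[1].set_title("Log")

# Figure.suptitle() returns the Text object holding the super title
title = fig.suptitle("Mathematical Functions")
fig.tight_layout()

print(title.get_text())  # Mathematical Functions

plt.close(fig)
```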
