How to plot in multiple subplots
Python, Pandas, Matplotlib, Seaborn, Subplot
Python Problem Overview
I am a little confused about how this code works:
fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()
How does the fig, axes work in this case? What does it do?
Also, why wouldn't this work to do the same thing?
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
Python Solutions
Solution 1 - Python
There are several ways to do it. The subplots method creates the figure along with the subplots, which are then stored in the ax array. For example:
import matplotlib.pyplot as plt
x = range(10)
y = range(10)
fig, ax = plt.subplots(nrows=2, ncols=2)
for row in ax:
    for col in row:
        col.plot(x, y)
plt.show()
However, something like this will also work. It's not as "clean", though, since you create a figure first and then add the subplots to it one at a time:
fig = plt.figure()
plt.subplot(2, 2, 1)
plt.plot(x, y)
plt.subplot(2, 2, 2)
plt.plot(x, y)
plt.subplot(2, 2, 3)
plt.plot(x, y)
plt.subplot(2, 2, 4)
plt.plot(x, y)
plt.show()
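To make the return value of plt.subplots in the first example concrete, here is a minimal sketch that inspects it. The Agg backend line is an assumption so that it runs without a display; drop it for interactive use.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit for interactive use
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(nrows=2, ncols=2)

# ax is a NumPy array of Axes objects, laid out like the grid
print(type(ax).__name__, ax.shape)  # prints: ndarray (2, 2)

# iterating over the array yields rows; iterating over a row yields Axes
for row in ax:
    for col in row:
        col.plot(range(10), range(10))
```

Because ax is an ordinary array, the usual NumPy indexing (ax[0, 1], ax[1]) applies to it.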
Solution 2 - Python
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 2)
ax[0, 0].plot(range(10), 'r') #row=0, col=0
ax[1, 0].plot(range(10), 'b') #row=1, col=0
ax[0, 1].plot(range(10), 'g') #row=0, col=1
ax[1, 1].plot(range(10), 'k') #row=1, col=1
plt.show()
Solution 3 - Python
- You can also unpack the axes in the subplots call
- And set whether you want to share the x and y axes between the subplots
Like this:
import matplotlib.pyplot as plt
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1.plot(range(10), 'r')
ax2.plot(range(10), 'b')
ax3.plot(range(10), 'g')
ax4.plot(range(10), 'k')
plt.show()
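As a rough illustration of what sharing axes means in practice, the sketch below changes the limits on one subplot and reads them back from another. The limit values are arbitrary, and the Agg backend is only assumed so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit for interactive use
import matplotlib.pyplot as plt

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
for ax in (ax1, ax2, ax3, ax4):
    ax.plot(range(10))

# with sharex/sharey, limits set on one subplot propagate to the others
ax1.set_xlim(2, 5)
ax1.set_ylim(0, 4)
xlim_ax4 = ax4.get_xlim()
ylim_ax4 = ax4.get_ylim()
print(xlim_ax4, ylim_ax4)
```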
Solution 4 - Python
You might be interested in the fact that as of matplotlib version 2.1 the second code from the question works fine as well.
From the change log:
> Figure class now has subplots method
> The Figure class now has a subplots() method which behaves the same as pyplot.subplots() but on an existing figure.
Example:
import matplotlib.pyplot as plt
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
plt.show()
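A slightly fuller sketch of the same idea, assuming matplotlib 2.1 or later: the axes returned by fig.subplots() form the same kind of 2x2 array and belong to the figure that was created first. The same_figure name is just for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit for interactive use
import matplotlib.pyplot as plt

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)  # requires matplotlib >= 2.1

for ax in axes.flat:
    ax.plot(range(10))

# every Axes in the array is attached to the existing figure
same_figure = all(ax.figure is fig for ax in axes.flat)
print(axes.shape, len(fig.axes), same_figure)
```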
Solution 5 - Python
Read the documentation: matplotlib.pyplot.subplots
pyplot.subplots() returns a tuple fig, ax, which is unpacked into two variables using the notation
fig, axes = plt.subplots(nrows=2, ncols=2)
The code:
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
does not work in matplotlib versions before 2.1, because there subplots() is only a function in pyplot, not a member of the Figure object.
Solution 6 - Python
Iterating through all subplots sequentially:
fig, axes = plt.subplots(nrows, ncols)
for ax in axes.flatten():
    ax.plot(x, y)
Accessing a specific index:
for row in range(nrows):
    for col in range(ncols):
        axes[row, col].plot(x[row], y[col])
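The snippets above leave nrows, ncols, x and y undefined. A self-contained sketch with made-up values follows. It also passes squeeze=False, which is not in the answer above: with that option plt.subplots always returns a 2D array, so both loops work even when there is only one row or one column.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit for interactive use
import matplotlib.pyplot as plt

nrows, ncols = 1, 3          # hypothetical grid size
x = list(range(10))          # hypothetical data
y = [v ** 2 for v in x]

# squeeze=False keeps axes two-dimensional, here with shape (1, 3)
fig, axes = plt.subplots(nrows, ncols, squeeze=False)

# iterate through all subplots sequentially
for ax in axes.flatten():
    ax.plot(x, y)

# access a specific index
for row in range(nrows):
    for col in range(ncols):
        axes[row, col].set_title(f"row {row}, col {col}")
```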
Solution 7 - Python
Subplots with pandas
- This answer is for subplots with pandas, which uses matplotlib as the default plotting backend.
- Here are four options to create subplots starting with a pandas.DataFrame
  - Implementation 1. and 2. are for the data in a wide format, creating subplots for each column.
  - Implementation 3. and 4. are for data in a long format, creating subplots for each unique value in a column.
- Tested in python 3.8.11, pandas 1.3.2, matplotlib 3.4.3, seaborn 0.11.2
Imports and Data
import seaborn as sns  # sample data, and the figure-level plot in option 4
import pandas as pd
import matplotlib.pyplot as plt
# wide dataframe
df = sns.load_dataset('planets').iloc[:, 2:5]
orbital_period mass distance
0 269.300 7.10 77.40
1 874.774 2.21 56.95
2 763.000 2.60 19.84
3 326.030 19.40 110.62
4 516.220 10.50 119.47
# long dataframe
dfm = sns.load_dataset('planets').iloc[:, 2:5].melt()
variable value
0 orbital_period 269.300
1 orbital_period 874.774
2 orbital_period 763.000
3 orbital_period 326.030
4 orbital_period 516.220
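For readers without the planets dataset at hand, the wide-to-long conversion can be sketched on a tiny made-up DataFrame. The column names a, b and c are placeholders.

```python
import pandas as pd

# wide format: one column per variable (hypothetical values)
wide = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0], "c": [5.0, 6.0]})

# long format: one row per (variable, value) pair
long = wide.melt()
print(long)
```

With no arguments, melt() names the two resulting columns variable and value, which is why the long-format options group by 'variable'.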
1. subplots=True and layout, for each column
- Use the parameters subplots=True and layout=(rows, cols) in pandas.DataFrame.plot
- This example uses kind='density', but there are different options for kind, and this applies to them all. Without specifying kind, a line plot is the default.
- axes is an array of AxesSubplot returned by pandas.DataFrame.plot
- See How to get a Figure object, if needed.
axes = df.plot(kind='density', subplots=True, layout=(2, 2), sharex=False, figsize=(10, 6))
# extract the figure object; only used for tight_layout in this example
fig = axes[0][0].get_figure()
# set the individual titles
for ax, title in zip(axes.ravel(), df.columns):
    ax.set_title(title)
fig.tight_layout()
plt.show()
2. plt.subplots, for each column
- Create an array of Axes with matplotlib.pyplot.subplots and then pass axes[i, j] or axes[n] to the ax parameter.
  - This option uses pandas.DataFrame.plot, but can use other axes level plot calls as a substitute (e.g. sns.kdeplot, plt.plot, etc.)
  - It's easiest to collapse the subplot array of Axes into one dimension with .ravel or .flatten. See .ravel vs .flatten.
  - Any variables applying to each axes, that need to be iterated through, are combined with zip (e.g. cols, axes, colors, palette, etc.). zip stops at the shortest object, so each object should have at least as many items as there are subplots to plot.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
cols = df.columns # create a list of dataframe columns to use
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for col, color, ax in zip(cols, colors, axes):
    df[col].plot(kind='density', ax=ax, color=color, label=col, title=col)
    ax.legend()
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
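The example above deletes the empty subplot by its hard-coded index, axes[3]. A possible generalization, sketched here with a made-up DataFrame and plain line plots instead of density plots, removes every subplot that did not receive a column. The df_demo name and its contents are assumptions for illustration only.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit for interactive use
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# hypothetical wide dataframe with three columns
rng = np.random.default_rng(0)
df_demo = pd.DataFrame(rng.normal(size=(50, 3)), columns=["a", "b", "c"])

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))
axes = axes.ravel()

# zip stops after the last column, leaving the fourth subplot empty
for col, ax in zip(df_demo.columns, axes):
    df_demo[col].plot(ax=ax, title=col)

# remove every subplot beyond the number of columns
for ax in axes[len(df_demo.columns):]:
    fig.delaxes(ax)

fig.tight_layout()
```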
Result for 1. and 2.
3. plt.subplots, for each group in .groupby
- This is similar to 2., except it zips color and axes to a .groupby object.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
dfg = dfm.groupby('variable') # get data for each unique value in the first column
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for (group, data), color, ax in zip(dfg, colors, axes):
    data.plot(kind='density', ax=ax, color=color, title=group, legend=False)
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
4. seaborn figure-level plot
- Use a seaborn figure-level plot, and use the col or row parameter. seaborn is a high-level API for matplotlib. See seaborn: API reference
p = sns.displot(data=dfm, kind='kde', col='variable', col_wrap=2, x='value', hue='variable',
facet_kws={'sharey': False, 'sharex': False}, height=3.5, aspect=1.75)
sns.move_legend(p, "upper left", bbox_to_anchor=(.55, .45))
Solution 8 - Python
The other answers are great. This answer combines them in a way that might be useful.
import numpy as np
import matplotlib.pyplot as plt
# Optional: define x for all the sub-plots
x = np.linspace(0,2*np.pi,100)
# (1) Prepare the figure infrastructure
fig, ax_array = plt.subplots(nrows=2, ncols=2)
# flatten the array of axes, which makes them easier to iterate through and assign
ax_array = ax_array.flatten()
# (2) Plot loop
for i, ax in enumerate(ax_array):
    ax.plot(x, np.sin(x + np.pi/2*i))
    #ax.set_title(f'plot {i}')
# Optional: main title
plt.suptitle('Plots')
Summary
- Prepare the figure infrastructure
  - Get ax_array, an array of the subplots
  - Flatten the array in order to use it in one 'for loop'
- Plot loop
  - Loop over the flattened ax_array to update the subplots
  - optional: use enumeration to track subplot number
- Once flattened, each element of ax_array can be individually indexed from 0 through nrows x ncols - 1 (e.g. ax_array[0], ax_array[1], ax_array[2], ax_array[3]).
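How a flat index relates to a grid position can be sketched as follows. flatten() goes row by row, so the subplot at (row, col) ends up at index row * ncols + col. The 2x3 grid size and the flat and matches names are arbitrary choices for this sketch.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit for interactive use
import matplotlib.pyplot as plt

nrows, ncols = 2, 3  # hypothetical grid size
fig, ax_array = plt.subplots(nrows=nrows, ncols=ncols)
flat = ax_array.flatten()

# the flat index of the subplot at (row, col) is row * ncols + col
matches = all(
    flat[row * ncols + col] is ax_array[row, col]
    for row in range(nrows)
    for col in range(ncols)
)

for i, ax in enumerate(flat):
    ax.set_title(f"plot {i}")
```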
Solution 9 - Python
Convert the axes array to 1D
- Generating subplots with plt.subplots(nrows, ncols), where both nrows and ncols are greater than 1, returns a nested array of <AxesSubplot:> objects.
  - It is not necessary to flatten axes in cases where either nrows=1 or ncols=1, because axes will already be 1 dimensional, which is a result of the default parameter squeeze=True
- The easiest way to access the objects is to convert the array to 1 dimension with .ravel(), .flatten(), or .flat.
  - .ravel vs. .flatten
    - flatten always returns a copy.
    - ravel returns a view of the original array whenever possible.
- Once the array of axes is converted to 1-d, there are a number of ways to plot.
import matplotlib.pyplot as plt
import numpy as np # sample data only
# example of data
rads = np.arange(0, 2*np.pi, 0.01)
y_data = np.array([np.sin(t*rads) for t in range(1, 5)])
x_data = [rads, rads, rads, rads]
# Generate figure and its subplots
fig, axes = plt.subplots(nrows=2, ncols=2)
# axes before
array([[<AxesSubplot:>, <AxesSubplot:>],
[<AxesSubplot:>, <AxesSubplot:>]], dtype=object)
# convert the array to 1 dimension
axes = axes.ravel()
# axes after
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
dtype=object)
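The view-versus-copy difference between .ravel and .flatten can be seen on a plain numeric array, as in this small sketch. The array contents are arbitrary.

```python
import numpy as np

grid = np.arange(6).reshape(2, 3)

raveled = grid.ravel()      # a view here, because grid is contiguous
flattened = grid.flatten()  # always a copy

ravel_shares = np.shares_memory(grid, raveled)      # True
flatten_shares = np.shares_memory(grid, flattened)  # False

raveled[0] = 99    # writes through to grid[0, 0]
flattened[1] = -1  # leaves grid[0, 1] unchanged
print(grid)
```

For an array of Axes the distinction rarely matters: either way the elements refer to the same Axes objects, so plotting through the 1D array draws on the original subplots.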
- Iterate through the flattened array
  - If there are more subplots than data, this will result in IndexError: list index out of range
    - Try option 3. instead, or select a subset of the axes (e.g. axes[:-2])
for i, ax in enumerate(axes):
    ax.plot(x_data[i], y_data[i])
- Access each axes by index
axes[0].plot(x_data[0], y_data[0])
axes[1].plot(x_data[1], y_data[1])
axes[2].plot(x_data[2], y_data[2])
axes[3].plot(x_data[3], y_data[3])
- Index the data and axes
for i in range(len(x_data)):
    axes[i].plot(x_data[i], y_data[i])
- zip the axes and data together and then iterate through the list of tuples
for ax, x, y in zip(axes, x_data, y_data):
    ax.plot(x, y)
Output
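When there are more subplots than data sets, the zip approach simply stops at the shortest input and leaves the remaining subplots empty. One possible way to tidy up afterwards, sketched here on an assumed 2x3 grid with four data sets, is to hide every subplot that Axes.has_data() reports as empty.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit for interactive use
import matplotlib.pyplot as plt
import numpy as np

rads = np.arange(0, 2 * np.pi, 0.01)
y_data = [np.sin(t * rads) for t in range(1, 5)]  # four data sets
x_data = [rads] * 4

fig, axes = plt.subplots(nrows=2, ncols=3)  # six subplots
axes = axes.ravel()

# zip stops after the fourth data set; no IndexError is raised
for ax, x, y in zip(axes, x_data, y_data):
    ax.plot(x, y)

# hide the subplots that received no data
for ax in axes:
    if not ax.has_data():
        ax.set_visible(False)

n_visible = sum(ax.get_visible() for ax in axes)
```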
Solution 10 - Python
Here is a simple solution:
fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=False)
for sp in fig.axes:
    sp.plot(range(10))
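This works because fig.axes is a flat list of every Axes in the figure, regardless of the grid shape. A short sketch that checks this, with illustrative titles added:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit for interactive use
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=False)

# fig.axes is a plain list, while ax is a 2x3 array
print(type(fig.axes).__name__, len(fig.axes))  # prints: list 6

for i, sp in enumerate(fig.axes):
    sp.plot(range(10))
    sp.set_title(f"subplot {i}")
```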
Solution 11 - Python
Go with the following if you really want to use a loop:
def plot(data):
    # assumption: each data[k] is a pandas Series (keys() gives the index, .values the data)
    fig = plt.figure(figsize=(100, 100))
    for idx, k in enumerate(data.keys(), 1):
        x, y = data[k].keys(), data[k].values
        plt.subplot(63, 10, idx)
        plt.bar(x, y)
    plt.show()
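The function above hard-codes a 63x10 grid. A variant that derives the number of rows from the number of items might look like the sketch below. The nested-dict data, the column count and the figure size are all assumptions made for illustration, not part of the answer above.

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit for interactive use
import matplotlib.pyplot as plt

# hypothetical input: seven named series, each a dict of label -> value
data = {f"series {i}": {"a": i, "b": i + 1, "c": i + 2} for i in range(7)}

ncols = 3
nrows = math.ceil(len(data) / ncols)  # enough rows to fit every item

fig = plt.figure(figsize=(4 * ncols, 3 * nrows))
for idx, (key, values) in enumerate(data.items(), 1):
    ax = fig.add_subplot(nrows, ncols, idx)  # subplot positions are 1-based
    ax.bar(list(values.keys()), list(values.values()))
    ax.set_title(key)

fig.tight_layout()
```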