Matplotlib: draw between subplots

Last week, I was preparing a data analysis report using Jupyter, Pandas and Matplotlib (to name only a few building blocks of this wonderful ecosystem). One of the figures had two subplots, the second being an enlargement of a region of the first. To make that obvious, and at the same time show the old MATLAB fan club how 90s they were, I decided to draw an arrow from the first subplot to the second.

Matplotlib figure showing two subplots with an arrow spanning over both

Alas, this simple idea kicked off three hours of searching the Internet, a lot of trial and error, even a bunch of crashes, and I came close to surrendering to the technology of the last century… No way!

Artists and Transformations

If you only need to show off, skip this section: I’m just going to explain roughly what’s under the hood. Figures, Axes, Axis: what are they? I cannot think of a better explanation than Matplotlib’s Artist Tutorial, but I’ll try to summarize it. In Matplotlib, you draw on a Canvas using a Renderer. The objects that know how to use a renderer to draw on the canvas are called Artists.

There are two types of Artists: primitives and containers. The primitives represent the standard graphical objects we want to paint onto our canvas: Line2D, Rectangle, Text, AxesImage, etc., and the containers are places to put them (Axis, Axes and Figure).
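
To make the primitive/container split concrete, here is a small sketch (not from the original post) that plots a line, adds a rectangle, and then checks where each object ends up. The values are arbitrary; only standard Matplotlib calls are used, and the Agg backend is selected so no window opens.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
import matplotlib.pyplot as plt
from matplotlib.artist import Artist
from matplotlib.patches import Rectangle

fig, ax = plt.subplots()
# Primitives: a Line2D created by plot() and a Rectangle added by hand
(line,) = ax.plot([0, 1, 2], [0, 1, 4])
rect = ax.add_patch(Rectangle((0.5, 0.5), 1.0, 1.0, alpha=0.3))

# Containers (Figure, Axes, Axis) and primitives are all Artists
for artist in (fig, ax, ax.xaxis, line, rect):
    print(f"{type(artist).__name__:>10}: Artist? {isinstance(artist, Artist)}")

# Each primitive is stored in its container
print("Axes in figure:", ax in fig.axes)
print("Line in axes:", line in ax.lines)
print("Rectangle in axes:", rect in ax.patches)
```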

From this, we deduce that we need to create a primitive, namely a Patch, or more precisely in my case a FancyArrowPatch, and add it to the figure. Easy.
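
As long as both ends of the arrow live in the same Axes, a FancyArrowPatch really is easy. This sketch, with arbitrary points chosen for illustration, adds one in the data coordinates of a single Axes using the same styling options as the article's final code. The difficulty only starts when the two ends belong to different subplots.

```python
import io

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

fig, ax = plt.subplots()
ax.plot([0, 10], [0, 10])

# Both ends are given in the data coordinates of this single Axes
arrow = FancyArrowPatch(
    (2., 8.), (6., 4.), transform=ax.transData,
    fc="g", connectionstyle="arc3,rad=0.2", arrowstyle="simple", alpha=0.3,
    mutation_scale=40.,
)
ax.add_patch(arrow)

# Render to an in-memory PNG to make sure the arrow actually draws
buf = io.BytesIO()
fig.savefig(buf, format="png")
```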

There is another aspect to take into account. As we are adding the patch to the figure, it is natural to expect it to be drawn in the figure’s coordinates. That makes placing the arrow properly harder than it sounds, because we know where its ends should go in terms of the plotted data, not in terms of the figure. To solve this, we need to look at how transformations work. There again, there is a very nice Transformations Tutorial, and it would be hard to explain it better. When dealing with coordinates in Matplotlib, you need to keep track of which coordinate system they apply to. Four coordinate systems matter here:

| Coordinate system | Transformation object | Description |
|---|---|---|
| data | ax.transData | The userland data coordinate system, controlled by the xlim and ylim. |
| axes | ax.transAxes | The coordinate system of the Axes; (0,0) is bottom left of the axes, and (1,1) is top right of the axes. |
| figure | fig.transFigure | The coordinate system of the Figure; (0,0) is bottom left of the figure, and (1,1) is top right of the figure. |
| display | None | The pixel coordinate system of the display; (0,0) is the bottom left of the display, and (width, height) is the top right of the display in pixels. Alternatively, the identity transform (matplotlib.transforms.IdentityTransform()) may be used instead of None. |
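
A quick way to see the table in action is to push a few corner points through each transformation object. The sketch below assumes a 4 x 3 inch figure at 100 dpi (so 400 x 300 pixels) and arbitrary axis limits; all four objects return display (pixel) coordinates.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.transforms import IdentityTransform

fig, ax = plt.subplots(figsize=(4, 3), dpi=100)  # 400 x 300 pixel figure
ax.set_xlim(0, 10)
ax.set_ylim(0, 5)

# Every transformation object maps its own system to display (pixel) coordinates
fig_top_right = fig.transFigure.transform((1, 1))   # figure corner -> pixels
axes_bottom_left = ax.transAxes.transform((0, 0))   # axes corner -> pixels
axes_top_right = ax.transAxes.transform((1, 1))
data_max = ax.transData.transform((10, 5))          # (xmax, ymax) in data
display_point = IdentityTransform().transform((120, 80))  # display -> display

print("figure (1, 1) ->", fig_top_right)
print("axes (0, 0)   ->", axes_bottom_left)
print("data (10, 5)  ->", data_max, "same as axes (1, 1):", axes_top_right)
print("display       ->", display_point)
```

The data point at the upper axis limits lands exactly on the top-right corner of the axes, which is what "controlled by the xlim and ylim" means in practice.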

The reference system is the display: each transformation object describes how to map coordinates from its own system to this reference system. We know where we want to position the arrow relative to the plotted data, i.e. its coordinates in the data system, but we have no idea where this maps onto the whole figure. Applying the data system’s transformation object gives us the coordinates in the display system. That is the first step. The second is simply to apply the inverse of the figure system’s transformation object, which converts from the display system back to the figure system.
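
The two-step conversion can be checked in isolation. This sketch uses arbitrary axis limits and the arrow's start point from the article; it also shows that Matplotlib transforms can be chained with `+` (`a + b` applies `a`, then `b`), which gives the same data-to-figure mapping in one object, and that the round trip recovers the original data point.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

fig, (ax0, ax1) = plt.subplots(2)
ax0.set_xlim(0, 500)
ax0.set_ylim(-20, 20)

point_data = (225.0, -10.0)
# Step 1: data -> display (pixels)
point_display = ax0.transData.transform(point_data)
# Step 2: display -> figure, using the inverted figure transform
point_figure = fig.transFigure.inverted().transform(point_display)

# The same chain as a single composite transform
data_to_figure = ax0.transData + fig.transFigure.inverted()
print(point_figure, data_to_figure.transform(point_data))

# Round trip: figure -> display -> data gives the original point back
back = (fig.transFigure + ax0.transData.inverted()).transform(point_figure)
print(back)
```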

One last thing to take care of is to specify the coordinate system we use when creating the patch, through the transform parameter.
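
To see why the transform parameter matters, here is a sketch with two rectangles built from the same numbers, one in data coordinates and one in axes coordinates (sizes and limits are arbitrary). Zooming the x axis changes the data-to-display mapping, so only the data-anchored rectangle changes size on screen.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)

# Same numbers, two different coordinate systems
in_data = ax.add_patch(Rectangle((0.1, 0.1), 0.2, 0.2, transform=ax.transData, color="g"))
in_axes = ax.add_patch(Rectangle((0.1, 0.1), 0.2, 0.2, transform=ax.transAxes, color="r"))

# Pixel bounds (x0, y0, width, height) before and after zooming in on x
before_data = in_data.get_window_extent().bounds
before_axes = in_axes.get_window_extent().bounds
ax.set_xlim(0, 1)
after_data = in_data.get_window_extent().bounds
after_axes = in_axes.get_window_extent().bounds

print("data-anchored width:", before_data[2], "->", after_data[2])
print("axes-anchored width:", before_axes[2], "->", after_axes[2])
```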

Show me the Code

Okay, get your data ready and set up the plot! In the example below, I use a Pandas DataFrame, but the arrow part works the same with NumPy arrays plotted through ax.plot().

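import matplotlib.pyplot as plt
import matplotlib.patches

# testdata is a Pandas DataFrame with a time column "mdfTime" (in seconds)
# and a signal column "psiDt"
fig, axes = plt.subplots(2)
# Show the full measurement for comparison on subplot 0
testdata.plot(ax=axes[0], x="mdfTime", y="psiDt")
# Place a little rectangle to mark the area of the enlargement
# (x from 200 to 250, y from -4 to 2, in the data coordinates of subplot 0)
axes[0].add_patch(matplotlib.patches.Rectangle((200., -4.), 50., 6., transform=axes[0].transData, alpha=0.3, color="g"))
# Enlarged measurement on subplot 1: rows 200/0.02 to 250/0.02, which assumes
# one sample every 0.02 s starting at t = 0 (i.e. 200 s to 250 s)
testdata.iloc[int(200./0.02):int(250./0.02)].plot(ax=axes[1], x="mdfTime", y="psiDt")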
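
The measurement itself is not part of the post. To run the setup code end to end, a made-up stand-in for testdata can be generated first. The sketch below is purely hypothetical data: the column names come from the code, while the 0.02 s sample period starting at t = 0 is the assumption implied by the slicing, and the signal is an arbitrary curve that stays inside the highlighted rectangle's y range.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the article's measurement: one sample every
# 0.02 s starting at t = 0, so row i corresponds to time i * 0.02 s
SAMPLE_PERIOD = 0.02
t = np.arange(0.0, 400.0, SAMPLE_PERIOD)
testdata = pd.DataFrame({
    "mdfTime": t,
    # arbitrary smooth signal, roughly between -2.2 and 0.2
    "psiDt": np.sin(t / 5.0) - 1.0 + 0.2 * np.sin(3.0 * t),
})

# Same positional slice as in the setup code: 200 s up to (not including) 250 s
zoom = testdata.iloc[int(200. / SAMPLE_PERIOD):int(250. / SAMPLE_PERIOD)]
print(zoom["mdfTime"].iloc[0], zoom["mdfTime"].iloc[-1])
```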

And finally, the really relevant piece of code:

# Create the arrow
# 1. Get transformation operators for both subplots and the figure.
#    Do this after any layout change (e.g. fig.tight_layout()): the figure
#    coordinates computed below are fixed and will not follow the Axes.
ax0tr = axes[0].transData  # Subplot 0 data -> Display
ax1tr = axes[1].transData  # Subplot 1 data -> Display
figtr = fig.transFigure.inverted()  # Display -> Figure
# 2. Transform arrow start point from subplot 0 data to figure coordinates
ptB = figtr.transform(ax0tr.transform((225., -10.)))
# 3. Transform arrow end point from subplot 1 data to figure coordinates
ptE = figtr.transform(ax1tr.transform((225., 1.)))
# 4. Create the patch
arrow = matplotlib.patches.FancyArrowPatch(
    ptB, ptE, transform=fig.transFigure,  # Place arrow in figure coord system
    fc="g", connectionstyle="arc3,rad=0.2", arrowstyle="simple", alpha=0.3,
    mutation_scale=40.
)
# 5. Add patch to list of objects to draw onto the figure
#    (recent Matplotlib versions also offer fig.add_artist(arrow))
fig.patches.append(arrow)
Matplotlib figure showing two subplots with an arrow spanning over both
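
The figure coordinates of the arrow above are computed once, so they go stale if the subplots move afterwards (tight_layout, resizing). As an alternative not used in the original post, Matplotlib's ConnectionPatch, a FancyArrowPatch subclass, accepts each end in its own coordinate system and redoes the conversion at every draw. A minimal sketch, assuming a recent Matplotlib version (for Figure.add_artist and transforms as coordsA/coordsB) and placeholder curves instead of the real measurement:

```python
import io

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch

fig, axes = plt.subplots(2)
# Placeholder curves standing in for the full and the enlarged measurement
axes[0].plot([0., 400.], [-12., 2.])
axes[1].plot([200., 250.], [-4., 2.])

# Each end is given in the data coordinates of its own Axes; Matplotlib
# converts them to display coordinates every time the figure is drawn
arrow = ConnectionPatch(
    xyA=(225., -10.), coordsA=axes[0].transData,
    xyB=(225., 1.), coordsB=axes[1].transData,
    fc="g", connectionstyle="arc3,rad=0.2", arrowstyle="simple", alpha=0.3,
    mutation_scale=40.,
)
fig.add_artist(arrow)

# Moving the Axes afterwards is fine: the arrow is recomputed at draw time
fig.tight_layout()
buf = io.BytesIO()
fig.savefig(buf, format="png")
```

The trade-off is mostly about versions: the manual transform chain works with any Matplotlib, while ConnectionPatch needs a release that supports these parameters.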

If you are dealing with lines, the process is the same, but you need to append your crafted object (e.g. a Line2D) to the lines list of the figure instead of its patches list. Have a look at this StackOverflow question.
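
For the line case, a minimal sketch of the same recipe: both end points go from each subplot's data coordinates to figure coordinates, and a Line2D is drawn in figure coordinates. The curves and points are placeholders, and it uses fig.add_artist() (recent Matplotlib versions); appending to fig.lines as described above should get the line drawn as well.

```python
import io

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

fig, axes = plt.subplots(2)
axes[0].plot([0., 400.], [-12., 2.])
axes[1].plot([200., 250.], [-4., 2.])

# Same recipe as the arrow: data -> display -> figure for both end points
to_figure = fig.transFigure.inverted()
start = to_figure.transform(axes[0].transData.transform((225., -10.)))
end = to_figure.transform(axes[1].transData.transform((225., 1.)))

# A dashed line drawn in figure coordinates, spanning both subplots
line = Line2D([start[0], end[0]], [start[1], end[1]],
              transform=fig.transFigure, color="g", linestyle="--")
fig.add_artist(line)

buf = io.BytesIO()
fig.savefig(buf, format="png")
```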