Intro to Matplotlib
Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. It is used along with
NumPy to provide an environment that is an effective open source alternative to MATLAB. Thus, this may feel very familiar to you if
you’ve ever used MATLAB. Please note that Matplotlib is not the only tool for creating data visualizations in Python. Other common
libraries include seaborn and plotly, and even pandas has some built-in functions for quick plotting. This lab will mainly focus on
Matplotlib, but I highly recommend checking out seaborn as well, since it can make some types of plots much easier to create. The
good news is that seaborn is built on top of Matplotlib, so it will be very easy to pick up after going through this lab. However,
you’re free to use any of the plotting tools mentioned above throughout the course. Just as there were many ways to solve a
problem in pandas, there are many ways to visualize our data.
Let’s begin by importing matplotlib and learning how to create some beautiful graphs. If for some reason you don’t have matplotlib
installed, you will first have to go to your terminal (or Anaconda Prompt if on Windows) and enter the following:
conda install matplotlib
Make sure you’ve already installed Anaconda
The cell below is optional. It just makes our figure sizes 15 x 5 and adds a darkgrid background by default. Feel free to experiment
with these values if you’d like, but you’ll have to restart the kernel and run all to see any changes.
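If you’d rather try out different values without touching the defaults, one option is to scope the settings to a single block. The sketch below assumes matplotlib’s rc_context manager; the 15 x 5 size matches the cell described above, the grid setting is just an illustrative extra, and the plotted data is a placeholder.

```python
import matplotlib.pyplot as plt
import numpy as np

# Remember the defaults so we can confirm they come back afterwards.
default_size = list(plt.rcParams["figure.figsize"])

# Settings passed to rc_context apply only to figures created inside the block.
with plt.rc_context({"figure.figsize": (15, 5), "axes.grid": True}):
    fig, ax = plt.subplots()
    ax.plot(np.arange(10), np.arange(10) ** 2)
    size_inside = list(fig.get_size_inches())

# Once the block ends, the previous defaults are restored.
size_after = list(plt.rcParams["figure.figsize"])
plt.show()
```

Figures created outside the with block keep whatever defaults were active before it.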
Plotting Numpy Arrays
Let’s begin by creating some numpy arrays and plotting them. We can create a simple x, y graph using plt.plot(x, y). We can then
display this by calling plt.show().
It would also be nice to add titles and axis labels to our graphs. We can accomplish this using plt.title() and plt.xlabel()/plt.ylabel().
Note: Make sure to add any titles/labels and such before calling plt.show(). Otherwise they will not show up.
We’d sometimes like to plot more than one function together, which can be done by calling plt.plot() a second time in the same
cell. However, make sure to add an additional label parameter when doing this to give each plot/function a unique label name.
That way we can then call plt.legend(), which will automatically add a legend for us and let us know which plot is which.
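As a small sketch of how the label parameter and plt.legend() fit together, the snippet below plots two functions and then reads the legend entries back. The variable names legend and legend_labels are only for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 20)

# Each plt.plot call adds a line to the same axes; `label` names it for the legend.
plt.plot(x, np.sin(x), label="y=sin(x)")
plt.plot(x, np.cos(x), label="y=cos(x)")
plt.title("Sine and Cosine Functions")

# plt.legend() builds the legend from the labels given above and returns it.
legend = plt.legend()
legend_labels = [text.get_text() for text in legend.get_texts()]

plt.show()
```

The legend entries appear in the order the plots were added.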
What if we instead wanted to plot sin(x) and cos(x) as separate plots, but side by side? This is where subplots(nrows, ncols)
comes in handy. subplots(nrows, ncols) will return to us a figure object and a set of axes. To better understand subplots, think of
a figure object as a large canvas on which we can place a set of axes. These axes are where we will plot our graphs.
Previously, we only had one single axis in our figure, and we were plotting 1 or many functions on that single axis. However, figures
can have many axes, which we’ll demonstrate using subplots(nrows, ncols). For example, in the code below we tell
subplots(nrows, ncols) that we would like a figure object with 1 row and 2 columns (essentially an array with 2 elements, where
each element is an axis that we can plot on). Since axes can be thought of as an array holding the figure axes, we can index it to
specify what we’d like to plot in each one.
Try to see if you can now make a figure object with 2 rows and 2 columns, and plot something on all 4 axes. This will work exactly
the same as above, except that we now index our axes as a 2d array.
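If you get stuck, here is one possible sketch. The four functions plotted are arbitrary choices made for illustration; any four plots would do.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(1, 21)

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14, 8))

# With 2 rows and 2 columns, `axes` is a 2D array indexed as axes[row, col].
axes[0, 0].plot(x, np.sin(x))
axes[0, 0].set_title("y=sin(x)")
axes[0, 1].plot(x, np.cos(x))
axes[0, 1].set_title("y=cos(x)")
axes[1, 0].plot(x, np.sqrt(x))
axes[1, 0].set_title("y=sqrt(x)")
axes[1, 1].plot(x, np.log(x))
axes[1, 1].set_title("y=log(x)")

# Adds spacing so the bottom titles don't collide with the plots above them.
fig.tight_layout()
plt.show()
```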
Custom Axes Range
Something else we’d sometimes like to do is change the x and y axis range. We can configure the ranges of the axes using the
set_ylim and set_xlim methods, or axis('tight') for automatically getting “tightly fitted” axes ranges:
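A minimal sketch of setting custom limits and reading them back is shown here. The limit values are arbitrary examples chosen to zoom in on part of the curves.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 11)

fig, ax = plt.subplots()
ax.plot(x, x**2, x, x**4)

# Zoom in on part of the data; these limits are example values only.
ax.set_xlim(2, 5)
ax.set_ylim(0, 60)

# get_xlim / get_ylim return the current (lower, upper) range of each axis.
x_limits = ax.get_xlim()
y_limits = ax.get_ylim()
plt.show()
```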
Final Note: The plt.plot() function supports additional arguments that can be used to specify a wide variety of line color
and style alternatives. Here are just a few examples, but please refer to the official documentation if you’d like to learn more.
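Besides keyword arguments, plt.plot() also accepts a short format string that combines color, marker, and line style. The sketch below shows two such strings next to an explicit keyword version; the variable names are only for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 11)

# A format string packs color, marker, and line style into one argument:
# "r--" is a red dashed line, "ko" is black circles with no connecting line.
(dashed,) = plt.plot(x, x + 1, "r--")
(circles,) = plt.plot(x, x + 2, "ko")

# The same dashed red line written with explicit keyword arguments.
(explicit,) = plt.plot(x, x + 3, color="red", linestyle="--")

plt.show()
```

Format strings are handy for quick plots, while keyword arguments tend to be easier to read once several properties are set at once.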
Plotting Pandas Columns
Great job! Let’s now see how we can plot pandas columns (which essentially behave like arrays), which is much more common
and what you’ll be doing when exploring your datasets. To do this, let’s revisit that pokemon dataset. Please refer back to the
pandas lab if you don’t recall the pokemon dataset or need help accessing it.
   #                   Name  Type 1  Type 2  Total  HP  Attack  Defense  Sp. Atk  Sp. Def  Speed  Generation  Legendary
0  1              Bulbasaur   Grass  Poison    318  45      49       49       65       65     45           1      False
1  2                Ivysaur   Grass  Poison    405  60      62       63       80       80     60           1      False
2  3               Venusaur   Grass  Poison    525  80      82       83      100      100     80           1      False
3  3  VenusaurMega Venusaur   Grass  Poison    625  80     100      123      122      120     80           1      False
4  4             Charmander    Fire     NaN    309  39      52       43       60       50     65           1      False
matplotlib has a wide variety of plot types. We’ve only seen simple line plots above, but let’s see how we can create a scatter plot
of pokemon HP and Attack. We can accomplish this using plt.scatter(x, y), where x and y can be arrays (as we saw earlier) or
pandas columns.
Optional plot: If you’ve installed seaborn and imported it, we can easily expand on the scatterplot above by coloring each point by
a specific category (Generation in this case). To do this, we simply specify the category we’d like to color by in an additional hue
parameter. This can also be achieved in matplotlib, but you have to create a dictionary and map each unique value to a color,
which can be a bit tedious.
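To make the matplotlib-only route concrete, here is a rough sketch of the dictionary approach. It uses just the five rows shown in the table above, and because those rows all share Generation 1, it colors by Type 1 instead. The names sample_df, color_map, and point_colors, as well as the chosen colors, are assumptions for illustration.

```python
import matplotlib.pyplot as plt
import pandas as pd

# The five rows shown in the table above, reduced to the columns needed here.
sample_df = pd.DataFrame({
    "Name": ["Bulbasaur", "Ivysaur", "Venusaur", "VenusaurMega Venusaur", "Charmander"],
    "Type 1": ["Grass", "Grass", "Grass", "Grass", "Fire"],
    "HP": [45, 60, 80, 80, 39],
    "Attack": [49, 62, 82, 100, 52],
})

# Map each unique category to a color, then look up a color for every row.
color_map = {"Grass": "green", "Fire": "red"}
point_colors = sample_df["Type 1"].map(color_map).tolist()

# x and y are pandas columns; c gives one color per point.
points = plt.scatter(sample_df["HP"], sample_df["Attack"], c=point_colors)
plt.title("HP vs. Attack")
plt.xlabel("HP")
plt.ylabel("Attack")
plt.show()
```

With many categories the dictionary grows quickly, which is the tedium that seaborn’s hue parameter avoids.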
We can visualize the distribution of any pandas column using plt.hist(x), where x is any pandas column.
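As a quick sketch, the snippet below builds a histogram from the HP values of the five rows shown earlier. The bins=5 setting is an arbitrary choice for such a small sample; with the full column you would pass pokemon_df['HP'] instead.

```python
import matplotlib.pyplot as plt
import pandas as pd

# HP values from the five rows shown earlier; a full column works the same way.
hp = pd.Series([45, 60, 80, 80, 39], name="HP")

# plt.hist returns the count per bin, the bin edges, and the drawn bars.
counts, bin_edges, bars = plt.hist(hp, bins=5)
plt.title("Distribution of HP")
plt.xlabel("HP")
plt.ylabel("Number of pokemon")
plt.show()
```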
Let’s see a few plots that require us to do a little more work to create. I say a little more work because in some cases we may need to gather
several pieces of information first.
Recall what .value_counts() from the pandas section gave us. Calling value_counts() on pokemon_df[‘Type 1’] will list out each
unique Type 1 pokemon type and how many pokemon of that type are in our dataset. However, this information looks rather
boring. Let’s see how we can turn this info into a bar plot.
Name: Type 1, dtype: int64
Calling .index off of .value_counts() will return just the index names. Calling .values off of .value_counts() will return just the
corresponding values. What we’ll do is save this information into variables and then use these variables as our x and y in
plt.bar(x, y) to create our bar plot. Notice how x is the index labels, and y is the corresponding value counts in the plot below.
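The steps above can be sketched on a tiny sample. The Series below holds only the Type 1 values of the five rows shown earlier, so the counts are much smaller than in the full dataset; calling value_counts() once and reusing the result is a small convenience, not a requirement.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Type 1 values of the five rows shown earlier.
type_1 = pd.Series(["Grass", "Grass", "Grass", "Grass", "Fire"], name="Type 1")

counts = type_1.value_counts()
type_1_names = counts.index    # category labels, used as x
type_1_values = counts.values  # how often each label occurs, used as y

bars = plt.bar(type_1_names, type_1_values)
plt.title("Type 1 Pokemon value counts")
plt.show()
```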
We can also take the top 5 types above (Water, Normal, Grass, Bug, Psychic) and represent this information in a pie chart using
plt.pie(sizes, colors). sizes is the size of each wedge in the pie chart. For sizes we’ll pass only the first 5 type_1_values. We’ll
also pass a colors array to give each wedge a unique color. shadow=True is optional (it just adds a slight shadow to our pie chart for
visual purposes), and startangle=90 is the angle by which the start of the pie is rotated, counterclockwise from the x‑axis.
Note that plt.pie(sizes, colors) returns two things: patches, which is a sequence of matplotlib.patches.Wedge instances, and
texts, a list of the label Text instances. It’s important that we save these into variables so that we can pass this information to
plt.legend(patches, names). We’ll pass patches so that the legend knows which patch/wedge corresponds to which color, and
the first 5 type_1_names as the corresponding name to each patch/wedge.
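Here is a small sketch of how the returned patches feed into the legend. The names and sizes are the Type 1 counts from the five rows shown earlier (not the top 5 types of the full dataset), and the two colors are arbitrary picks.

```python
import matplotlib.pyplot as plt

# Type 1 counts from the five rows shown earlier; only two wedges here.
names = ["Grass", "Fire"]
sizes = [4, 1]
colors = ["yellowgreen", "lightcoral"]

# Called this way, plt.pie returns the wedges and the (empty) label texts.
patches, texts = plt.pie(sizes, colors=colors, shadow=True, startangle=90)

# Pair each wedge with its name so the legend shows the matching color.
legend = plt.legend(patches, names, loc="best")
plt.axis("equal")  # equal aspect ratio keeps the pie circular
plt.show()

legend_names = [text.get_text() for text in legend.get_texts()]
```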
Let’s collect the first 4 pokemon names in our DataFrame along with their corresponding HP, Attack, and Defense.
We can compare these 4 pokemon using a stackplot. plt.stackplot(x, y1, y2, y3) works as follows in our case: x is the first 4
pokemon in our dataset, y1 is their corresponding HPs, y2 their corresponding Attacks, and y3 their corresponding defense
values. Notice how the three attributes are stacked on top of one another for each of the 4 pokemon.
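To see what “stacking” means numerically, the sketch below rebuilds the plot from the four rows shown in the table earlier. It places the pokemon at numeric x positions and attaches the names as tick labels, which is a choice made for this sketch rather than the only way to do it.

```python
import matplotlib.pyplot as plt

names = ["Bulbasaur", "Ivysaur", "Venusaur", "VenusaurMega Venusaur"]
hp = [45, 60, 80, 80]
attack = [49, 62, 82, 100]
defense = [49, 63, 83, 123]

# Numeric positions on the x axis; the names are attached as tick labels below.
positions = list(range(len(names)))

fig, ax = plt.subplots()
# Each series is drawn on top of the running total of the ones before it.
layers = ax.stackplot(positions, hp, attack, defense,
                      labels=["HP", "Attack", "Defense"])
ax.set_xticks(positions)
ax.set_xticklabels(names, rotation=90)
ax.legend(loc="upper left")
plt.show()

# Height of the full stack for each pokemon.
stack_totals = [h + a + d for h, a, d in zip(hp, attack, defense)]
```

The top edge of the plot at each pokemon equals the sum of its three attributes, stored here in stack_totals.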
We can also use our knowledge of subplots to represent this information in 3 separate subplots if we’d like. Recall that x is the
first 4 pokemon names in our dataset (Bulbasaur, Ivysaur, Venusaur, VenusaurMega Venusaur). So for each of these 3 subplots
we’re plotting the names along the x axis, and their corresponding HP or Attack or Defense value as the y value.
We can also make line plots of pandas columns, just as we did in the earlier sections when we were passing arrays. These
types of plots will come in handy if you decide to work with stock data for your project.
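As a minimal sketch, the snippet below draws two columns as lines against the DataFrame index. It uses the HP and Attack values of the five rows shown earlier; with stock data the index would typically be dates instead.

```python
import matplotlib.pyplot as plt
import pandas as pd

sample_df = pd.DataFrame({
    "HP": [45, 60, 80, 80, 39],
    "Attack": [49, 62, 82, 100, 52],
})

# Passing the index as x and a column as y draws one line per call.
(hp_line,) = plt.plot(sample_df.index, sample_df["HP"], label="HP")
(attack_line,) = plt.plot(sample_df.index, sample_df["Attack"], label="Attack")
plt.xlabel("Row index")
plt.legend()
plt.show()
```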
One last and very important plot that we’ll talk about is the heatmap. If we call .corr() on a DataFrame we’ll get back a correlation matrix.
Note: Make sure you don’t have categorical variables in your DataFrame or you’ll get an error.
               HP    Attack   Defense   Sp. Atk   Sp. Def     Speed
HP       1.000000  0.422386  0.239622  0.362380  0.378718  0.175952
Attack   0.422386  1.000000  0.438687  0.396362  0.263990  0.381240
Defense  0.239622  0.438687  1.000000  0.223549  0.510747  0.015227
Sp. Atk  0.362380  0.396362  0.223549  1.000000  0.506121  0.473018
Sp. Def  0.378718  0.263990  0.510747  0.506121  1.000000  0.259133
Speed    0.175952  0.381240  0.015227  0.473018  0.259133  1.000000
We can then pass this information into sns.heatmap() to visualize it.
Note: You can also use matplotlib for this, but it’s much more intuitive in Seaborn and just one line of code, so I’d really
recommend at least using Seaborn for heatmaps. You’ll usually see people mixing a combination of matplotlib and Seaborn since
it’s a really powerful combo. Since Seaborn is built on matplotlib, a lot of plots can be made with just one or two lines of code in
Seaborn, and then you can use matplotlib to tweak certain things to your liking that are hidden in Seaborn (hidden for
simplicity and ease of use).
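For comparison, here is roughly what the matplotlib-only version could look like, using imshow to draw the matrix as a colored grid. The values are copied from the correlation table above; the 0 to 1 color range and the title are assumptions made for this sketch.

```python
import matplotlib.pyplot as plt
import numpy as np

columns = ["HP", "Attack", "Defense", "Sp. Atk", "Sp. Def", "Speed"]
# Correlation values copied from the table above.
corr = np.array([
    [1.000000, 0.422386, 0.239622, 0.362380, 0.378718, 0.175952],
    [0.422386, 1.000000, 0.438687, 0.396362, 0.263990, 0.381240],
    [0.239622, 0.438687, 1.000000, 0.223549, 0.510747, 0.015227],
    [0.362380, 0.396362, 0.223549, 1.000000, 0.506121, 0.473018],
    [0.378718, 0.263990, 0.510747, 0.506121, 1.000000, 0.259133],
    [0.175952, 0.381240, 0.015227, 0.473018, 0.259133, 1.000000],
])

fig, ax = plt.subplots()
image = ax.imshow(corr, vmin=0, vmax=1)  # draw the matrix as a color grid
fig.colorbar(image, ax=ax)               # color scale shown beside the grid

# Label every row and column with its column name.
ticks = np.arange(len(columns))
ax.set_xticks(ticks)
ax.set_xticklabels(columns, rotation=90)
ax.set_yticks(ticks)
ax.set_yticklabels(columns)
ax.set_title("Correlation heatmap")
plt.show()
```

The tick labels and colorbar each need their own calls here, which is why the one-line Seaborn version is usually more convenient.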
Congrats! You now know enough to begin exploring and visualizing data on your own. As always, feel free to continue exploring
this dataset or repeat some of the steps above on a new dataset as practice. Also refer back to the documentation if we didn’t get
to cover a graph you need for your project this semester. Most visualization tools and libraries support just about every statistical
plot you can think of.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
# optional defaults: 15 x 5 figures with a darkgrid background
rcParams['figure.figsize'] = 15, 5
sns.set_style('darkgrid')

x = np.linspace(0, 20)
y = np.sin(x)
plt.plot(x, y)
plt.show()

plt.plot(x, y)
# add titles/labels before calling plt.show()
plt.title('My Graph Title')
plt.xlabel('My xaxis label')
plt.ylabel('My yaxis label')
plt.show()

plt.plot(x, np.sin(x), label='y=sin(x)')
plt.plot(x, np.cos(x), label='y=cos(x)')
plt.title('Sine and Cosine Functions')
# Assigns a color to each plot along with its corresponding label name to display in a legend.
plt.legend()
plt.show()

# one figure object with 1x2=2 axes
fig, axes = plt.subplots(nrows=1, ncols=2)
# plotting sin(x) in axes[0]: the left image, or 1st element/axis in the axes array.
axes[0].plot(x, np.sin(x))
# setting the title of axes[0] to be y=sin(x)
axes[0].set_title('y=sin(x)')
# plotting cos(x) in axes[1]: the right image, or 2nd element/axis in the axes array.
axes[1].plot(x, np.cos(x))
# setting the title of axes[1] to be y=cos(x)
axes[1].set_title('y=cos(x)')
plt.show()

x = np.linspace(1, 21)
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14, 8))

x = np.linspace(0, 5, 11)
fig, axes = plt.subplots(nrows=1, ncols=3)
# we can also plot multiple graphs using plt.plot(x1, y1, x2, y2).
axes[0].plot(x, x**2, x, x**4)
axes[0].set_title('default axis range')
axes[1].plot(x, x**2, x, x**4)
# tight happens to be the same here, but this will come in handy when you want to automatically align axes in a tightly
# fitted manner.
axes[1].axis('tight')
axes[1].set_title('tight axis range')
# notice the zoom in that happened.
axes[2].plot(x, x**2, x, x**4)
# example limits (assumed values; the originals are not shown in this copy)
axes[2].set_ylim([0, 60])
axes[2].set_xlim([2, 5])
axes[2].set_title("custom axes range");

x = np.linspace(0, 5, 11)
plt.plot(x, x+1, color="red", linewidth=0.25)
plt.plot(x, x+2, color="red", linewidth=0.50)
plt.plot(x, x+3, color="red", linewidth=1.00)
plt.plot(x, x+4, color="red", linewidth=2.00)
# possible linestyle options: '-', '--', '-.', ':' (for step plots, use drawstyle='steps')
plt.plot(x, x+5, color="green", lw=3, linestyle='-')
plt.plot(x, x+6, color="green", lw=3, ls='-.')
plt.plot(x, x+7, color="green", lw=3, ls=':')
# custom dash (not something we'll do often)
line, = plt.plot(x, x+8, color="black", lw=1.50)
line.set_dashes([5, 10, 15, 10]) # format: line length, space length, ...
# possible marker symbols: marker = '+', 'o', '*', 's', ',', '.', '1', '2', '3', '4', ...
plt.plot(x, x+ 9, color="blue", lw=3, ls='-', marker='+')
plt.plot(x, x+10, color="blue", lw=3, ls='-', marker='o')
plt.plot(x, x+11, color="blue", lw=3, ls='-', marker='s')
plt.plot(x, x+12, color="blue", lw=3, ls='-', marker='1')
# marker size and color
plt.plot(x, x+13, color="purple", lw=1, ls='-', marker='o', markersize=2)
plt.plot(x, x+14, color="purple", lw=1, ls='-', marker='o', markersize=4)
plt.plot(x, x+15, color="purple", lw=1, ls='-', marker='o', markersize=8, markerfacecolor="red")
plt.plot(x, x+16, color="purple", lw=1, ls='-', marker='s', markersize=8,
         markerfacecolor="yellow", markeredgewidth=3, markeredgecolor="green");

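pokemon_df = pd.read_csv(filepath_or_buffer='Pokemon.csv')
pokemon_df.head()
# scatter plot of HP (x) against Attack (y)
plt.scatter(pokemon_df['HP'], pokemon_df['Attack'])
plt.title('HP vs. Attack')
plt.show()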
sns.scatterplot(x='HP', y='Attack', hue='Generation', data=pokemon_df)
# collect index names (left side of .value_counts())
type_1_names = pokemon_df['Type 1'].value_counts().index
# collect corresponding value counts (right side of .value_counts())
type_1_values = pokemon_df['Type 1'].value_counts().values
# plt.bar(x, y)
plt.bar(type_1_names, type_1_values)
plt.title('Type 1 Pokemon value counts')
plt.show()

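# colors we'll use for each patch/wedge
colors = ['yellowgreen', 'gold', 'lightskyblue', 'lightcoral', 'brown']
patches, texts = plt.pie(type_1_values[:5], colors=colors, shadow=True, startangle=90)
# legend pairing each patch/wedge with the first 5 type names
plt.legend(patches, type_1_names[:5])
# will center our pie chart, but not necessary (assumed call: equal axis scaling)
plt.axis('equal')
plt.show()
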
x = pokemon_df['Name'][:4]
y1 = pokemon_df['HP'][:4]
y2 = pokemon_df['Attack'][:4]
y3 = pokemon_df['Defense'][:4]
labels = ["HP", "Attack", "Defense"]
plt.stackplot(x, y1, y2, y3, labels=labels)
plt.legend(loc='upper left')
# rotates the xlabels to appear vertical rather than horizontal by default.
plt.xticks(rotation=90)
plt.show()

# Create 3 subplots sharing y axis
fig, axes = plt.subplots(nrows=3, ncols=1, sharey=True)
# 'ko', 'r.', and ':' are just line styles and optional.
# notice how you can also set multiple things together in one single command rather than
# creating separate .set_title, .set_ylabel, .set_xlabel, etc. commands like we were doing earlier.
axes[0].plot(x, y1, 'ko')
axes[0].set(title='3 subplots', ylabel='HP')
axes[1].plot(x, y2, 'r.')
axes[1].set(ylabel='Attack')
axes[2].plot(x, y3, ':')
axes[2].set(ylabel='Defense')
plt.show()

pokemon_df[['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 'Speed']].corr()
sns.heatmap(pokemon_df[['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 'Speed']].corr())