teri

Plotting and Data Visualization with Matplotlib

Raw data in a CSV (comma-separated values) file does not tell a story on its own. Plotting it with a visualization library like Matplotlib lets your readers connect the dots at a glance.

This article is an introduction to using Matplotlib for plotting and data visualizations.

GitHub Repo

Check the complete source code in this repo.

What is Matplotlib

Matplotlib is a Python plotting library that allows you to turn data into pretty visualizations, also known as plots or figures.

Matplotlib is useful for data scientists for the following reasons:

  • It is built on NumPy arrays (and Python)
  • It integrates directly with pandas
  • It can create basic or advanced plots

Importing Matplotlib

To start with Matplotlib, import it into your Jupyter Notebook like this:

%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

  • %matplotlib inline: this IPython magic command makes every Matplotlib plot render inside the notebook, directly below the cell that produced it
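
The %matplotlib inline magic only exists inside IPython and Jupyter. As a minimal sketch, assuming you run the same kind of code as a plain Python script instead, you would drop the magic and either call plt.show() or save the figure to a file. The output location below is a placeholder chosen for illustration, not a path from the article's repo.

```python
import os
import tempfile

import matplotlib.pyplot as plt  # no %matplotlib magic outside a notebook

fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4], [11, 22, 33, 44])

# Placeholder output location: a temporary folder, so nothing is overwritten
output_path = os.path.join(tempfile.mkdtemp(), "sample-plot.png")
fig.savefig(output_path)

# In an interactive desktop session you could call plt.show() here instead
saved_ok = os.path.exists(output_path)
```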

The simplest way to create a plot is with:

plt.plot();

plot figure

Let's add some data to the plot:

x = [1, 2, 3, 4]
y = [11, 22, 33, 44]
plt.plot(x, y);

This code plots the values of x along the x-axis against the values of y along the y-axis.

plotting a graph on the x and y axes

The recommended way of plotting is the object-oriented interface, which creates the figure and axes explicitly and gives the same result as the previous method:

fig, ax = plt.subplots()
ax.plot(x, y);

Note: Changing the values of x and y produces a different graph.
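
To check that the two interfaces really draw the same thing, here is a small sketch (not part of the original notebook) that compares the points stored in each line, then plots reversed y values to show that different data gives a different line.

```python
import matplotlib.pyplot as plt
import numpy as np

x = [1, 2, 3, 4]
y = [11, 22, 33, 44]

# pyplot interface: plt.plot() draws on the "current" axes and returns the lines
(pyplot_line,) = plt.plot(x, y)

# Object-oriented interface: create the figure and axes explicitly
fig, ax = plt.subplots()
(oo_line,) = ax.plot(x, y)

# Both lines hold the same (x, y) points
same_points = np.array_equal(pyplot_line.get_xydata(), oo_line.get_xydata())

# Changing the data changes the line that gets drawn
(reversed_line,) = ax.plot(x, y[::-1])
different_points = not np.array_equal(
    oo_line.get_xydata(), reversed_line.get_xydata()
)
```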

Anatomy of Matplotlib Plot

A typical Matplotlib plot is made up of:

  • An axes title
  • A legend
  • A y-axis label
  • An x-axis label

Let's see an example of the workflow.

# 0. import matplotlib and get it ready for plotting in Jupyter
%matplotlib inline 
import matplotlib.pyplot as plt

# 1. Prepare data 
x = [1, 2, 3, 4]
y = [11, 22, 33, 44]

# 2. Setup plot
fig, ax = plt.subplots(figsize=(10, 10))

# 3. Plot data  
ax.plot(x, y)

# 4. Customize plot
ax.set(title = "Simple plot",
       xlabel = "x-axis",
       ylabel = "y-axis")

# 5. Save the whole figure (the images/ folder must already exist)
fig.savefig("images/sample-plot.png")

The code above sets the title and axis labels with the ax.set() method and saves the figure as a .png file in the images folder.

matplotlib workflow
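
The anatomy list mentions a legend, which the workflow example does not draw. A minimal sketch, assuming you want one: give each line a label and call ax.legend(). The second line (y doubled) and the label strings are made up for illustration.

```python
import matplotlib.pyplot as plt

x = [1, 2, 3, 4]
y = [11, 22, 33, 44]

fig, ax = plt.subplots()

# label= is the text the legend shows for each line
ax.plot(x, y, label="y")
ax.plot(x, [value * 2 for value in y], label="2y (made-up second series)")

ax.set(title="Simple plot", xlabel="x-axis", ylabel="y-axis")

# Draw the legend from the labels given above
legend = ax.legend()
legend_labels = [text.get_text() for text in legend.get_texts()]
```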

Creating Figures with NumPy arrays

In this section, you will create line, scatter, bar, and histogram plots, and arrange several of them as subplots.

Copy and paste this code into your notebook:

import numpy as np
x = np.linspace(0, 10, 100)
x[:10]

np.linspace returns evenly spaced numbers over a specified interval; here, 100 values from 0 to 10. The slice x[:10] displays only the first ten of them.
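
To make the behavior of np.linspace concrete, this short check (a sketch, not part of the original notebook) confirms the number of values, the endpoints, and the constant spacing.

```python
import numpy as np

x = np.linspace(0, 10, 100)

# 100 values, starting at 0 and ending at 10 (the endpoint is included by default)
count = len(x)
first, last = x[0], x[-1]

# The gap between neighbouring values is constant: (10 - 0) / (100 - 1)
steps = np.diff(x)
evenly_spaced = np.allclose(steps, 10 / 99)
```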

Plot the data and create a line plot:

fig, ax = plt.subplots()
ax.plot(x, x**2);

You should see something like this:

line plot

For a scatter plot, use the same data from above:

fig, ax = plt.subplots()
ax.scatter(x, np.exp(x));

Note: For a scatter plot, call .scatter() on the axes instead of .plot().

Scatter plot

You can also make a bar plot from a dictionary, using its keys as the bar labels and its values as the bar heights:

nut_butter_prices = {"Almond butter": 10, "Peanut butter": 9, "Cashew butter": 5}

fig, ax = plt.subplots()
ax.bar(nut_butter_prices.keys(), nut_butter_prices.values())
ax.set(title = "Teri's Nut Butter Store",
       ylabel = "Price ($)"
      );

bar plot

Horizontal Bar
To draw the same bars horizontally, use .barh() instead of .bar():

fig, ax = plt.subplots()
ax.barh(list(nut_butter_prices.keys()), list(nut_butter_prices.values()));

horizontal bar
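
Bars are drawn in the order of the dictionary's items. If you would rather have them ordered by price, one option (a sketch of a design choice, not something the original notebook does) is to sort the items before plotting.

```python
import matplotlib.pyplot as plt

nut_butter_prices = {"Almond butter": 10, "Peanut butter": 9, "Cashew butter": 5}

# Sort (name, price) pairs by price, lowest first
sorted_items = sorted(nut_butter_prices.items(), key=lambda item: item[1])
names = [name for name, _ in sorted_items]
prices = [price for _, price in sorted_items]

fig, ax = plt.subplots()
bars = ax.barh(names, prices)
ax.set(title="Teri's Nut Butter Store", xlabel="Price ($)")

# For horizontal bars, each bar's width is the value it represents
bar_widths = [bar.get_width() for bar in bars]
```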

Subplots and Histograms
You can split a single figure into a 2-by-2 grid of four subplots with this code:

# Subplots option 1: unpack each axes into its own variable
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
    nrows = 2,
    ncols = 2,
    figsize = (10, 5)
)

# Plot to each different axes
ax1.plot(x, x / 2);
ax2.scatter(np.random.random(10), np.random.random(10));
ax3.bar(nut_butter_prices.keys(), nut_butter_prices.values());
ax4.hist(np.random.randn(1000));

subplots

# Subplots option 2: keep one array of axes and index it by [row, column]
fig, ax = plt.subplots(nrows = 2, 
                       ncols = 2, 
                       figsize = (10, 5))

# Plot to each different index
ax[0, 0].plot(x, x/2);
ax[0, 1].scatter(np.random.random(10), np.random.random(10));
ax[1, 0].bar(nut_butter_prices.keys(), nut_butter_prices.values());
ax[1, 1].hist(np.random.randn(1000));

subplots
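
In option 2, ax is a NumPy array of axes, so it can be indexed or looped over. The sketch below (with made-up titles, an arbitrary random seed, and an assumed bin count of 20) shows the array's shape, draws a reproducible histogram, and gives every subplot a title in one loop.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))

# With nrows=2 and ncols=2, ax is a 2 x 2 NumPy array of Axes objects
grid_shape = ax.shape

# Seeded generator so the histogram is reproducible (the seed is arbitrary)
rng = np.random.default_rng(42)
counts, bin_edges, _ = ax[1, 1].hist(rng.standard_normal(1000), bins=20)

# ax.flat visits the axes row by row: (0, 0), (0, 1), (1, 0), (1, 1)
for number, axes in enumerate(ax.flat, start=1):
    axes.set_title(f"Plot {number}")

# Stop titles and tick labels from overlapping between rows
fig.tight_layout()
```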

Plotting from Pandas DataFrame

This section shows how to visualize data loaded from a .csv file into a pandas DataFrame.

Download the car sales data

Before reading the downloaded file, first import the pandas library:

import pandas as pd

Make a DataFrame with this command:

car_sales = pd.read_csv("car_sales.csv")
car_sales

This assumes the car sales data is saved in the same directory as the notebook. If you save it in a different folder, pass that path to pd.read_csv().

Car sales data

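To remove the $ sign and the comma separators from the Price column and convert it to an integer data type, run this command. Passing regex=False makes pandas treat $ as a literal character instead of a regular-expression pattern: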

car_sales['Price']=car_sales['Price'].str.replace('$','',regex=False).str.replace(',','',regex=False).astype(float).astype(int)
car_sales

adjusted car sales price data
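
To see each step of that cleaning chain, here is a sketch on a small made-up DataFrame. It assumes the Price column holds text such as "$4,000.00"; the rows below are invented for illustration and are not taken from the car sales file.

```python
import pandas as pd

# Made-up rows, assuming prices are stored as text like "$4,000.00"
car_sales_example = pd.DataFrame(
    {"Price": ["$4,000.00", "$5,000.00", "$22,000.00"]}
)

car_sales_example["Price"] = (
    car_sales_example["Price"]
    .str.replace("$", "", regex=False)  # drop the literal dollar sign
    .str.replace(",", "", regex=False)  # drop the thousands separator
    .astype(float)                      # "4000.00" -> 4000.0
    .astype(int)                        # 4000.0 -> 4000
)

cleaned_prices = car_sales_example["Price"].tolist()
```

Converting to float first matters here: calling int directly on a string with a decimal point, such as "4000.00", would raise an error.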

Add a Sale Date Column:

car_sales["Sale Date"] = pd.date_range("1/1/2023", periods=len(car_sales))
car_sales

sale date column

Add a Total Sales column that holds the running (cumulative) sum of Price:

car_sales["Total Sales"] = car_sales["Price"].cumsum()
car_sales

Plot the Total Sales:

car_sales.plot(x = "Sale Date", y = "Total Sales")

total sales
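
The sketch below walks through the same three steps (sale dates, running total, line plot) on a tiny made-up DataFrame, so the numbers are easy to follow by hand. The prices are invented stand-ins for the cleaned Price column.

```python
import pandas as pd

# Made-up prices standing in for the cleaned Price column
sales_example = pd.DataFrame({"Price": [4000, 5000, 7000]})

# One consecutive day per row, starting on 1 January 2023
sales_example["Sale Date"] = pd.date_range("1/1/2023", periods=len(sales_example))

# Running total: each row is its own price plus all earlier prices
sales_example["Total Sales"] = sales_example["Price"].cumsum()

# DataFrame.plot() draws a line plot by default
ax = sales_example.plot(x="Sale Date", y="Total Sales")

running_totals = sales_example["Total Sales"].tolist()
plotted_totals = [int(value) for value in ax.lines[0].get_ydata()]
```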

You can plot any pair of columns the same way. For example, this draws a scatter plot of Price against Odometer (KM):

car_sales.plot(x="Odometer (KM)", y = "Price", kind = "scatter");

Scatter plot
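
Because pandas plotting is built on Matplotlib, DataFrame.plot() returns a Matplotlib axes that you can keep customizing. A minimal sketch with made-up rows that reuse the article's column names (the title text is also invented):

```python
import pandas as pd
from matplotlib.axes import Axes

# Made-up rows using the same column names as the car sales data
cars_example = pd.DataFrame(
    {"Odometer (KM)": [150043, 87899, 32549], "Price": [4000, 5000, 7000]}
)

# DataFrame.plot() returns the Matplotlib Axes it drew on
ax = cars_example.plot(x="Odometer (KM)", y="Price", kind="scatter")
returns_axes = isinstance(ax, Axes)

# So the usual Axes methods still work for customization
ax.set(title="Price vs. odometer reading", ylabel="Price ($)")
```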

In Summary

Matplotlib offers a rich set of options for turning data into clear, visually appealing plots, from a single line chart to a grid of subplots built from NumPy arrays or a pandas DataFrame.
