Module 5.1

Matplotlib Fundamentals

Master the foundation of Python data visualization. Learn to create publication-quality charts, understand the figure and axes architecture, and build customized visualizations for data analysis.

55 min
Beginner
Hands-on
What You'll Learn
  • Understand figure and axes architecture
  • Create line, scatter, bar, and histogram plots
  • Customize colors, labels, and legends
  • Build complex subplot layouts
  • Save publication-quality figures
01

Introduction to Matplotlib

Matplotlib is the foundational visualization library in Python's data science ecosystem. Created by John Hunter in 2003, it provides a MATLAB-like interface for creating static, animated, and interactive visualizations. Whether you need a quick exploratory plot or a publication-ready figure, Matplotlib gives you complete control over every aspect of your visualization.

Why Matplotlib?

While newer libraries like Seaborn and Plotly offer higher-level interfaces, Matplotlib remains essential for several reasons. First, it provides the most granular control over plot elements, allowing you to customize everything from tick marks to annotation positions. Second, several widely used tools, including Seaborn and the pandas plotting API, are built on top of Matplotlib, so understanding it helps you customize their output as well (Plotly, by contrast, uses its own rendering stack). Third, it excels at creating publication-quality figures with precise formatting requirements.

Key Concept

Matplotlib

A comprehensive library for creating static, animated, and interactive visualizations in Python. It produces publication-quality figures in a variety of formats and interactive environments.

Getting Started

To use Matplotlib, you first need to install and import it. The most common approach is to import the pyplot module, which provides a convenient interface similar to MATLAB. By convention, we import it as plt, which you will see in virtually all Python data science code.

# Install matplotlib (run in terminal)
# pip install matplotlib

# Standard import convention
import matplotlib.pyplot as plt
import numpy as np

# Check version
import matplotlib
print(matplotlib.__version__)  # e.g. 3.8.2 (your installed version may differ)

Your First Plot

Creating a basic plot in Matplotlib is remarkably simple. You pass your data to a plotting function, add some labels, and display the result. The library handles all the complex rendering details behind the scenes, allowing you to focus on your data rather than the mechanics of drawing.

# Create sample data
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]

# Create a simple line plot
plt.plot(x, y)
plt.xlabel('X Values')
plt.ylabel('Y Values')
plt.title('My First Matplotlib Plot')
plt.show()

Two Interfaces: pyplot vs Object-Oriented

Matplotlib offers two ways to create plots. The pyplot interface (also called the state-based or implicit interface) is simpler and works well for quick plots. The object-oriented (explicit) interface gives you more control and is better for complex figures with multiple subplots. Because every call names the figure or axes it acts on, the object-oriented approach is generally the safer choice for reusable functions and production code.

# pyplot interface (state-based) - simpler but less control
plt.figure(figsize=(8, 4))
plt.plot([1, 2, 3], [1, 4, 9])
plt.title('pyplot Style')
plt.show()

# Object-oriented interface - more control, recommended for complex plots
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot([1, 2, 3], [1, 4, 9])
ax.set_title('Object-Oriented Style')
plt.show()

pyplot Interface
  • Quick exploratory plots
  • Single figure, simple layouts
  • Uses plt.function() syntax
  • State-based, tracks current figure

Object-Oriented
  • Complex multi-panel figures
  • Fine-grained customization
  • Uses ax.method() syntax
  • Explicit figure and axes objects
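
To make the "tracks current figure" point concrete, here is a minimal sketch with two figures. The Agg backend line is an assumption added only so the snippet runs without a display; it can be dropped in a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()  # the most recently created figure becomes "current"

# pyplot functions act on whatever is current, which is fig2/ax2 here
current_is_fig2 = plt.gcf() is fig2
plt.title("Set via pyplot")

# Object-oriented calls name their target explicitly
ax1.set_title("Set via ax1")

print(current_is_fig2)   # True
print(ax2.get_title())   # Set via pyplot
print(ax1.get_title())   # Set via ax1

plt.close("all")
```

With more than one figure open, the pyplot call lands on whichever axes happens to be current, while the ax1 call is unambiguous.
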
02

Figure & Axes Architecture

Understanding Matplotlib's architecture is crucial for creating effective visualizations. At its core, every Matplotlib plot consists of a Figure container that holds one or more Axes objects. The Figure is like a canvas, while each Axes is an individual plot within that canvas. Mastering this hierarchy unlocks the full power of the library.

The Matplotlib Hierarchy

Matplotlib organizes visualizations in a clear hierarchy. The Figure is the top-level container that represents the entire image. Inside the Figure, you have one or more Axes objects, which are the actual plots where data gets visualized. Each Axes contains additional elements like the x-axis, y-axis, title, and legend. Understanding this structure helps you manipulate any part of your visualization.

Key Concepts

Figure & Axes

Figure: The top-level container for all plot elements. Think of it as the window or page that holds your visualization.

Axes: The area where data is plotted, including the x-axis, y-axis, and all visual elements like lines and markers.

# Creating Figure and Axes explicitly
fig = plt.figure(figsize=(10, 6))  # Create a figure 10 inches wide, 6 inches tall
ax = fig.add_subplot(111)          # Add a single axes (1 row, 1 col, position 1)

# Plot on the axes
ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
ax.set_xlabel('X Axis Label')
ax.set_ylabel('Y Axis Label')
ax.set_title('Understanding Figure and Axes')
plt.show()
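
The hierarchy described above can also be inspected directly on the objects. The following is a small sketch, not an exhaustive tour; the two-panel layout and the title text are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2)
first = axes[0]
first.set_title("Left panel")

# Figure -> the list of Axes it contains
n_axes = len(fig.axes)

# Axes -> its parent Figure
parent_is_fig = first.figure is fig

# Axes -> its own x-axis, y-axis, and title objects
xaxis_type = type(first.xaxis).__name__
yaxis_type = type(first.yaxis).__name__
title_text = first.title.get_text()

print(n_axes, parent_is_fig)   # 2 True
print(xaxis_type, yaxis_type)  # XAxis YAxis
print(title_text)              # Left panel

plt.close(fig)
```
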

The subplots() Shortcut

While you can create figures and axes separately, the plt.subplots() function provides a convenient way to create both at once. This is the most common pattern you will see in professional code. It returns a tuple containing the figure and axes, which you can unpack directly into variables.

# The most common pattern: create figure and axes together
fig, ax = plt.subplots(figsize=(10, 6))

# Now use ax to plot and customize
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')
ax.legend()
ax.set_title('Sine and Cosine Waves')
plt.show()

Figure Properties

The Figure object controls overall properties like size, resolution, and background color. Setting the right figure size is important for both display and export. The figsize parameter takes a tuple of width and height in inches. The dpi parameter controls resolution, with higher values producing sharper but larger images.

# Figure with custom properties
fig, ax = plt.subplots(
    figsize=(12, 8),     # Width, height in inches
    dpi=100,             # Dots per inch (resolution)
    facecolor='white'    # Background color
)

ax.plot([1, 2, 3], [1, 2, 3])
ax.set_title('Custom Figure Properties')

# Access figure properties
print(f"Figure size: {fig.get_size_inches()}")  # [12.  8.]
print(f"DPI: {fig.get_dpi()}")                  # 100.0

plt.tight_layout()  # Adjust spacing to prevent overlap
plt.show()
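
The link between figsize and dpi is plain multiplication: the canvas measures width times dpi by height times dpi pixels. A small sketch with assumed values (6 by 4 inches at 150 dpi), using a non-interactive backend so that no display scaling is involved:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so dpi is not rescaled for the screen
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4), dpi=150)

width_in, height_in = fig.get_size_inches()
width_px = width_in * fig.get_dpi()
height_px = height_in * fig.get_dpi()

print(width_px, height_px)  # 900.0 600.0

plt.close(fig)
```

Doubling dpi doubles both pixel dimensions while text and line widths keep the same physical size, which is why a higher dpi looks sharper rather than bigger.
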

Axes Properties and Methods

The Axes object provides methods to customize every aspect of your plot. Methods starting with set_ configure properties like labels and limits. The Axes also manages the actual data visualization through plotting methods like plot(), scatter(), and bar(). Learning the key Axes methods is essential for effective visualization.

fig, ax = plt.subplots(figsize=(10, 6))

# Sample data
x = np.arange(1, 11)
y = x ** 2

# Plot the data
ax.plot(x, y, color='blue', linewidth=2, marker='o')

# Customize axes properties
ax.set_xlabel('X Values', fontsize=12)
ax.set_ylabel('Y Squared', fontsize=12)
ax.set_title('Customizing Axes Properties', fontsize=14, fontweight='bold')
ax.set_xlim(0, 12)                    # Set x-axis limits
ax.set_ylim(0, 120)                   # Set y-axis limits
ax.set_xticks(range(0, 13, 2))        # Custom tick positions
ax.grid(True, linestyle='--', alpha=0.7)  # Add grid

plt.show()

Method | Description | Example
ax.set_xlabel() | Set x-axis label | ax.set_xlabel('Time (s)')
ax.set_ylabel() | Set y-axis label | ax.set_ylabel('Value')
ax.set_title() | Set plot title | ax.set_title('My Plot')
ax.set_xlim() | Set x-axis range | ax.set_xlim(0, 100)
ax.set_ylim() | Set y-axis range | ax.set_ylim(-1, 1)
ax.legend() | Display legend | ax.legend(loc='best')
ax.grid() | Add grid lines | ax.grid(True)

Practice Questions

Task: Create a figure that is 8 inches wide and 5 inches tall, then plot the squares of numbers 1-5.

# Given: Numbers 1-5
numbers = [1, 2, 3, 4, 5]

# Your code here: Create figure with figsize, calculate squares, plot, add title

Expected Output: A line plot showing points (1,1), (2,4), (3,9), (4,16), (5,25) with title "Squares"

Solution:

numbers = [1, 2, 3, 4, 5]

fig, ax = plt.subplots(figsize=(8, 5))
squares = [n ** 2 for n in numbers]
ax.plot(numbers, squares, marker='o')
ax.set_title('Squares')
ax.set_xlabel('Number')
ax.set_ylabel('Square')
plt.show()

Task: Create a plot with custom x and y limits, a grid, and formatted labels.

# Given: Temperature data
hours = [6, 9, 12, 15, 18, 21]
temps = [15, 20, 28, 32, 25, 18]

# Your code: Plot temps, set xlim 0-24, ylim 10-40, add grid, label axes

Expected Output: Temperature plot with proper axis limits, grid, and descriptive labels

Solution:

hours = [6, 9, 12, 15, 18, 21]
temps = [15, 20, 28, 32, 25, 18]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(hours, temps, marker='s', color='red', linewidth=2)
ax.set_xlim(0, 24)
ax.set_ylim(10, 40)
ax.set_xlabel('Hour of Day', fontsize=12)
ax.set_ylabel('Temperature (°C)', fontsize=12)
ax.set_title('Daily Temperature Profile', fontsize=14)
ax.grid(True, linestyle='--', alpha=0.7)
ax.set_xticks([0, 6, 12, 18, 24])
plt.show()

Task: Plot three mathematical functions on the same axes with a legend.

# Given: x values from 0 to 2π
x = np.linspace(0, 2 * np.pi, 100)

# Your code: Plot sin(x), cos(x), and tan(x) clipped to [-2, 2]
# Add labels, legend (upper right), title, and grid

Expected Output: Three curves with different colors, a legend, and appropriate styling

Solution:

x = np.linspace(0, 2 * np.pi, 100)

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(x, np.sin(x), label='sin(x)', color='blue')
ax.plot(x, np.cos(x), label='cos(x)', color='red')
ax.plot(x, np.clip(np.tan(x), -2, 2), label='tan(x)', color='green', linestyle='--')

ax.set_xlabel('x (radians)')
ax.set_ylabel('f(x)')
ax.set_title('Trigonometric Functions')
ax.set_ylim(-2.5, 2.5)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
plt.show()
03

Basic Plot Types

Matplotlib provides a rich variety of plot types for different data visualization needs. From line plots showing trends over time to scatter plots revealing relationships between variables, each plot type serves a specific purpose. Mastering these fundamental chart types gives you the building blocks for more complex visualizations.

Line Plots

Line plots are ideal for showing continuous data and trends over time or ordered categories. They connect data points with lines, making it easy to see patterns, trends, and fluctuations. Use line plots when your x-axis represents a continuous variable like time, or when you want to emphasize the connection between sequential data points.

# Basic line plot
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
sales = [12000, 15000, 13500, 17000, 19000, 21000]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(months, sales, marker='o', linewidth=2, markersize=8, color='#3498db')
ax.set_xlabel('Month')
ax.set_ylabel('Sales ($)')
ax.set_title('Monthly Sales Performance')
ax.grid(True, axis='y', alpha=0.3)
plt.show()

# Multiple lines for comparison
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(months, sales, marker='o', label='2024', linewidth=2)
ax.plot(months, [11000, 13000, 14000, 15500, 17000, 18500], marker='s', label='2023', linewidth=2)
ax.legend()
ax.set_title('Year-over-Year Sales Comparison')
plt.show()

Scatter Plots

Scatter plots display individual data points as markers, making them perfect for exploring relationships between two numerical variables. They help identify correlations, clusters, and outliers in your data. You can also encode additional variables using color and size to create information-rich visualizations.

# Basic scatter plot
np.random.seed(42)
study_hours = np.random.uniform(1, 10, 50)
exam_scores = 50 + 5 * study_hours + np.random.normal(0, 5, 50)

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(study_hours, exam_scores, alpha=0.7, s=100, c='#e74c3c', edgecolors='white')
ax.set_xlabel('Study Hours')
ax.set_ylabel('Exam Score')
ax.set_title('Study Hours vs Exam Performance')
plt.show()

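# Scatter with color mapping (third variable)
# A new figure is needed here: the previous one was already shown
ages = np.random.randint(18, 25, 50)  # Ages 18-24 (upper bound is exclusive)
fig, ax = plt.subplots(figsize=(10, 6))
scatter = ax.scatter(study_hours, exam_scores, c=ages, cmap='viridis', s=100, alpha=0.7)
ax.set_xlabel('Study Hours')
ax.set_ylabel('Exam Score')
fig.colorbar(scatter, ax=ax, label='Student Age')
plt.show()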
Reference

Key Scatter Parameters

s sets the marker area in points squared (a single number or one value per point), c sets color (an array of numbers is mapped through the colormap), alpha controls transparency (0-1), and edgecolors sets the marker border color.
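
As a sketch of how these parameters combine, the example below passes one size and one color value per point. The data and the `weights` variable are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)   # hypothetical sample data
x = rng.uniform(0, 10, 20)
y = rng.uniform(0, 10, 20)
weights = rng.uniform(1, 5, 20)

fig, ax = plt.subplots(figsize=(8, 5))
points = ax.scatter(
    x, y,
    s=weights * 40,        # one marker size per point
    c=weights,             # numeric values mapped through the colormap
    cmap='viridis',
    alpha=0.7,
    edgecolors='white',
)
fig.colorbar(points, ax=ax, label='Weight')

n_sizes = len(points.get_sizes())
print(n_sizes)  # 20

plt.close(fig)
```
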

Bar Charts

Bar charts compare discrete categories using rectangular bars. They are excellent for showing quantities across different groups or categories. Matplotlib supports both vertical bars (bar()) and horizontal bars (barh()), as well as grouped and stacked variations for comparing multiple series.

# Vertical bar chart
categories = ['Electronics', 'Clothing', 'Food', 'Books', 'Sports']
revenue = [45000, 32000, 28000, 15000, 22000]

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(categories, revenue, color=['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6'])
ax.set_xlabel('Category')
ax.set_ylabel('Revenue ($)')
ax.set_title('Revenue by Product Category')

# Add value labels on bars
for bar, val in zip(bars, revenue):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 500, 
            f'${val:,}', ha='center', va='bottom', fontsize=10)
plt.show()

# Horizontal bar chart (good for long category names)
fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(categories, revenue, color='#3498db')
ax.set_xlabel('Revenue ($)')
ax.set_title('Revenue by Product Category')
plt.show()
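
Matplotlib 3.4 and later also provide Axes.bar_label, which places a label on each bar without manual coordinate arithmetic. A minimal sketch, assuming such a version is installed and reusing the revenue figures from above:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

categories = ['Electronics', 'Clothing', 'Food', 'Books', 'Sports']
revenue = [45000, 32000, 28000, 15000, 22000]

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(categories, revenue, color='#3498db')

# bar_label takes the container returned by ax.bar; padding is in points
labels = ax.bar_label(bars, labels=[f'${v:,}' for v in revenue], padding=3)
label_texts = [t.get_text() for t in labels]

print(label_texts[0])  # $45,000

plt.close(fig)
```

The manual ax.text loop remains useful when each label needs its own placement logic.
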

Grouped and Stacked Bar Charts

When comparing multiple data series across categories, grouped bars place bars side-by-side while stacked bars pile them on top of each other. Grouped bars are better for comparing individual values, while stacked bars show how parts contribute to a whole.

# Grouped bar chart
categories = ['Q1', 'Q2', 'Q3', 'Q4']
product_a = [25, 30, 35, 40]
product_b = [20, 28, 32, 38]

x = np.arange(len(categories))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.bar(x - width/2, product_a, width, label='Product A', color='#3498db')
bars2 = ax.bar(x + width/2, product_b, width, label='Product B', color='#e74c3c')

ax.set_xlabel('Quarter')
ax.set_ylabel('Sales (thousands)')
ax.set_title('Quarterly Sales Comparison')
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()
plt.show()

# Stacked bar chart
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(categories, product_a, label='Product A', color='#3498db')
ax.bar(categories, product_b, bottom=product_a, label='Product B', color='#e74c3c')
ax.set_ylabel('Total Sales (thousands)')
ax.set_title('Quarterly Sales - Stacked')
ax.legend()
plt.show()
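
The x - width/2 and x + width/2 offsets in the grouped example are a special case of centering n bars on each tick. The sketch below generalizes that arithmetic; `group_offsets` is a hypothetical helper written for this illustration, not a Matplotlib function.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def group_offsets(n_series, bar_width):
    """Hypothetical helper: x-offset of each series so the group is centered on its tick."""
    return (np.arange(n_series) - (n_series - 1) / 2) * bar_width

categories = ['Q1', 'Q2', 'Q3', 'Q4']
series = {
    'Product A': [25, 30, 35, 40],
    'Product B': [20, 28, 32, 38],
}

x = np.arange(len(categories))
width = 0.35
offsets = group_offsets(len(series), width)  # -0.175 and +0.175 for two series

fig, ax = plt.subplots(figsize=(10, 6))
for offset, (name, values) in zip(offsets, series.items()):
    ax.bar(x + offset, values, width, label=name)

ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()

plt.close(fig)
```

Adding a third series only requires another dictionary entry and a narrower width (for example 0.25) so the group still fits between ticks.
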

Histograms

Histograms visualize the distribution of numerical data by grouping values into bins. They show how frequently values occur within different ranges, helping you understand the shape of your data. Use histograms to identify patterns like normal distributions, skewness, or multimodal distributions.

# Basic histogram
np.random.seed(42)
exam_scores = np.random.normal(75, 10, 200)  # Mean=75, Std=10, 200 students

fig, ax = plt.subplots(figsize=(10, 6))
counts, bins, patches = ax.hist(exam_scores, bins=20, color='#3498db', 
                                 edgecolor='white', alpha=0.7)
ax.set_xlabel('Exam Score')
ax.set_ylabel('Number of Students')
ax.set_title('Distribution of Exam Scores')
ax.axvline(exam_scores.mean(), color='red', linestyle='--', label=f'Mean: {exam_scores.mean():.1f}')
ax.legend()
plt.show()

# Histogram with density curve
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(exam_scores, bins=20, density=True, alpha=0.7, color='#3498db', edgecolor='white')

# Overlay a fitted normal curve (requires SciPy: pip install scipy)
from scipy import stats
x_range = np.linspace(exam_scores.min(), exam_scores.max(), 100)
ax.plot(x_range, stats.norm.pdf(x_range, exam_scores.mean(), exam_scores.std()), 
        'r-', linewidth=2, label='Normal Distribution')
ax.set_xlabel('Exam Score')
ax.set_ylabel('Density')
ax.legend()
plt.show()
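
The values returned by ax.hist are worth a closer look, because they show what the bars mean. A short sketch using the same seeded exam scores: with 20 bins there are 21 bin edges, the counts add up to the sample size, and with density=True the bar areas add up to 1.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
exam_scores = np.random.normal(75, 10, 200)

fig, (ax_count, ax_density) = plt.subplots(1, 2, figsize=(12, 5))

# Raw counts per bin
counts, bins, patches = ax_count.hist(exam_scores, bins=20)
total = int(counts.sum())
n_edges = len(bins)

# Normalized so that the total bar area is 1
density, density_bins, _ = ax_density.hist(exam_scores, bins=20, density=True)
area = float(np.sum(density * np.diff(density_bins)))

print(total)    # 200
print(n_edges)  # 21

plt.close(fig)
```
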

Pie Charts

Pie charts show parts of a whole as slices of a circle. While often overused, they work well for displaying simple proportions with a small number of categories (typically 5 or fewer). Use the explode parameter to emphasize specific slices and autopct to display percentages.

# Pie chart with percentages
categories = ['Electronics', 'Clothing', 'Food', 'Books', 'Other']
market_share = [35, 25, 20, 10, 10]
colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#95a5a6']
explode = [0.05, 0, 0, 0, 0]  # Slightly separate first slice

fig, ax = plt.subplots(figsize=(10, 8))
wedges, texts, autotexts = ax.pie(market_share, labels=categories, colors=colors,
                                   autopct='%1.1f%%', explode=explode,
                                   shadow=True, startangle=90)

# Style the percentage labels
for autotext in autotexts:
    autotext.set_fontsize(11)
    autotext.set_fontweight('bold')

ax.set_title('Market Share by Category', fontsize=14, fontweight='bold')
plt.show()
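
Note that autopct formats each slice's share of the total, so the input values do not have to sum to 100. A minimal sketch with two made-up counts:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

counts = [30, 10]  # hypothetical raw counts, not percentages

fig, ax = plt.subplots()
wedges, texts, autotexts = ax.pie(counts, labels=['A', 'B'], autopct='%1.1f%%')
percent_labels = [t.get_text() for t in autotexts]

print(percent_labels)  # ['75.0%', '25.0%']

plt.close(fig)
```
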

Practice Questions

Task: Create a bar chart showing website traffic by day of week.

# Given: Daily traffic data
days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
visitors = [1200, 1500, 1400, 1600, 1800, 2200, 1900]

# Your code: Create bar chart with different color for weekends

Expected Output: Bar chart with weekdays in blue, weekends in green

Solution:

days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
visitors = [1200, 1500, 1400, 1600, 1800, 2200, 1900]
colors = ['#3498db'] * 5 + ['#2ecc71'] * 2  # Blue weekdays, green weekends

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(days, visitors, color=colors)
ax.set_xlabel('Day of Week')
ax.set_ylabel('Visitors')
ax.set_title('Website Traffic by Day')
plt.show()

Task: Create a scatter plot of house prices vs. square footage with a trend line.

# Given: Housing data
np.random.seed(42)
sqft = np.random.uniform(1000, 3000, 30)
price = 50000 + 150 * sqft + np.random.normal(0, 30000, 30)

# Your code: Scatter plot + linear trend line using np.polyfit

Expected Output: Scatter plot with red trend line, proper labels

Solution:

np.random.seed(42)
sqft = np.random.uniform(1000, 3000, 30)
price = 50000 + 150 * sqft + np.random.normal(0, 30000, 30)

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(sqft, price, alpha=0.7, s=80, c='#3498db')

# Add trend line
z = np.polyfit(sqft, price, 1)
p = np.poly1d(z)
x_line = np.linspace(sqft.min(), sqft.max(), 100)
ax.plot(x_line, p(x_line), 'r-', linewidth=2, label='Trend')

ax.set_xlabel('Square Footage')
ax.set_ylabel('Price ($)')
ax.set_title('House Prices vs. Square Footage')
ax.legend()
plt.show()

Task: Create a histogram of employee salaries with mean and median lines.

# Given: Salary data (right-skewed distribution)
np.random.seed(42)
salaries = np.random.exponential(scale=50000, size=500) + 30000

# Your code: Histogram with 25 bins, vertical lines for mean (red) and median (green)

Expected Output: Histogram with legend showing mean and median values

Solution:

np.random.seed(42)
salaries = np.random.exponential(scale=50000, size=500) + 30000

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(salaries, bins=25, color='#3498db', edgecolor='white', alpha=0.7)

mean_sal = salaries.mean()
median_sal = np.median(salaries)
ax.axvline(mean_sal, color='red', linestyle='--', linewidth=2, label=f'Mean: ${mean_sal:,.0f}')
ax.axvline(median_sal, color='green', linestyle='-', linewidth=2, label=f'Median: ${median_sal:,.0f}')

ax.set_xlabel('Salary ($)')
ax.set_ylabel('Frequency')
ax.set_title('Employee Salary Distribution')
ax.legend()
plt.show()

Task: Create a grouped bar chart comparing revenue across regions and quarters.

# Given: Revenue data by region
quarters = ['Q1', 'Q2', 'Q3', 'Q4']
north = [120, 145, 160, 180]
south = [100, 130, 140, 155]
west = [90, 110, 125, 145]

# Your code: Grouped bar chart with 3 bars per quarter, legend, and value labels

Expected Output: Three colored bars per quarter with values displayed on top

Solution:

quarters = ['Q1', 'Q2', 'Q3', 'Q4']
north = [120, 145, 160, 180]
south = [100, 130, 140, 155]
west = [90, 110, 125, 145]

x = np.arange(len(quarters))
width = 0.25

fig, ax = plt.subplots(figsize=(12, 6))
bars1 = ax.bar(x - width, north, width, label='North', color='#3498db')
bars2 = ax.bar(x, south, width, label='South', color='#e74c3c')
bars3 = ax.bar(x + width, west, width, label='West', color='#2ecc71')

# Add value labels
for bars in [bars1, bars2, bars3]:
    for bar in bars:
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2,
                f'{bar.get_height():.0f}', ha='center', va='bottom', fontsize=9)

ax.set_xlabel('Quarter')
ax.set_ylabel('Revenue (thousands)')
ax.set_title('Quarterly Revenue by Region')
ax.set_xticks(x)
ax.set_xticklabels(quarters)
ax.legend()
plt.show()
04

Customizing Your Visualizations

The true power of Matplotlib lies in its extensive customization options. You can control every visual aspect of your plots, from colors and fonts to line styles and marker shapes. Learning these customization techniques helps you create professional, publication-ready figures that effectively communicate your data story.

Colors in Matplotlib

Matplotlib accepts colors in multiple formats, giving you flexibility in how you specify them. You can use named colors like 'red' or 'steelblue', hex codes like '#3498db', RGB tuples like (0.2, 0.4, 0.6), or shorthand codes like 'b' for blue. Using consistent, accessible colors improves your visualization quality significantly.

# Different ways to specify colors
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]

# Named color
axes[0, 0].plot(x, y, color='steelblue', linewidth=3)
axes[0, 0].set_title("Named: 'steelblue'")

# Hex code
axes[0, 1].plot(x, y, color='#e74c3c', linewidth=3)
axes[0, 1].set_title("Hex: '#e74c3c'")

# RGB tuple (0-1 range)
axes[0, 2].plot(x, y, color=(0.2, 0.6, 0.3), linewidth=3)
axes[0, 2].set_title("RGB: (0.2, 0.6, 0.3)")

# Shorthand codes
axes[1, 0].plot(x, y, color='g', linewidth=3)  # green
axes[1, 0].set_title("Shorthand: 'g'")

# With alpha transparency
axes[1, 1].plot(x, y, color='purple', alpha=0.5, linewidth=10)
axes[1, 1].set_title("With alpha=0.5")

# Using colormaps
colors = plt.cm.viridis(np.linspace(0, 1, 5))
for i, c in enumerate(colors):
    axes[1, 2].plot([1, 2], [i, i+1], color=c, linewidth=3)
axes[1, 2].set_title("Colormap: viridis")

plt.tight_layout()
plt.show()
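
All of these formats resolve to the same internal RGBA representation, which can be checked with the matplotlib.colors helpers. A small sketch; the list of red spellings is just an illustration.

```python
import matplotlib.colors as mcolors

# Four spellings of the same color: named, shorthand, hex, RGB tuple
red_specs = ['red', 'r', '#ff0000', (1.0, 0.0, 0.0)]
rgba_values = [mcolors.to_rgba(spec) for spec in red_specs]
all_same = all(value == rgba_values[0] for value in rgba_values)

# Named colors can be converted to hex for use in other tools
steelblue_hex = mcolors.to_hex('steelblue')

print(all_same)       # True
print(steelblue_hex)  # #4682b4
```
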

Line Styles and Markers

Differentiating multiple data series requires varying line styles and markers. Matplotlib provides solid, dashed, dotted, and dash-dot line styles. Markers include circles, squares, triangles, and many more. Combining these options creates visually distinct lines that remain readable even in black-and-white printing.

# Line styles
fig, ax = plt.subplots(figsize=(12, 6))
x = np.linspace(0, 10, 50)

ax.plot(x, np.sin(x), linestyle='-', label='Solid (-)', linewidth=2)
ax.plot(x, np.sin(x + 0.5), linestyle='--', label='Dashed (--)', linewidth=2)
ax.plot(x, np.sin(x + 1), linestyle='-.', label='Dash-dot (-.)', linewidth=2)
ax.plot(x, np.sin(x + 1.5), linestyle=':', label='Dotted (:)', linewidth=2)

ax.legend()
ax.set_title('Line Style Options')
plt.show()

# Marker styles
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(1, 6)

markers = ['o', 's', '^', 'D', 'v', 'p', '*', 'h']
marker_names = ['Circle', 'Square', 'Triangle Up', 'Diamond', 'Triangle Down', 'Pentagon', 'Star', 'Hexagon']

for i, (marker, name) in enumerate(zip(markers, marker_names)):
    ax.plot(x, [i+1]*5, marker=marker, markersize=12, linestyle='', label=name)

ax.set_yticks(range(1, len(markers)+1))
ax.set_yticklabels(marker_names)
ax.set_title('Marker Style Options')
ax.legend(loc='upper right')
plt.show()
Pro Tip

Format String Shortcut

Combine color, marker, and line style in one string: plt.plot(x, y, 'ro--') creates red circles connected by dashed lines. Matplotlib documents the format as [marker][line][color], but other orderings such as [color][marker][line] are also accepted.
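
To see that the format string is only shorthand, the sketch below draws one line with 'ro--' and another with the equivalent keyword arguments, then compares the resulting line properties. The data values are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

fig, ax = plt.subplots()

# Format-string shorthand
(short,) = ax.plot([1, 2, 3], [1, 4, 9], 'ro--')

# The same styling spelled out with keyword arguments
(explicit,) = ax.plot([1, 2, 3], [2, 5, 10], color='r', marker='o', linestyle='--')

same_color = mcolors.to_rgba(short.get_color()) == mcolors.to_rgba(explicit.get_color())
same_marker = short.get_marker() == explicit.get_marker()
same_style = short.get_linestyle() == explicit.get_linestyle()

print(same_color, same_marker, same_style)  # True True True

plt.close(fig)
```

Keyword arguments are longer but accept any color format, including hex codes, which the format string does not.
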

Labels, Titles, and Text

Clear labels and titles are essential for understandable visualizations. Matplotlib provides extensive control over text properties including font size, weight, family, and color. You can add annotations to highlight specific data points and use mathematical notation with LaTeX-style formatting.

fig, ax = plt.subplots(figsize=(12, 7))

x = np.linspace(0, 10, 100)
y = np.sin(x)
ax.plot(x, y, 'b-', linewidth=2)

# Customized labels and title
ax.set_xlabel('Time (seconds)', fontsize=14, fontweight='bold', color='#333')
ax.set_ylabel('Amplitude', fontsize=14, fontweight='bold', color='#333')
ax.set_title('Sine Wave with Custom Styling', fontsize=18, fontweight='bold', 
             color='#2c3e50', pad=20)

# Add text annotation
ax.text(5, 0.5, 'Peak Region', fontsize=12, style='italic', 
        bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))

# Add annotation with arrow
ax.annotate('Maximum', xy=(np.pi/2, 1), xytext=(3, 0.7),
            fontsize=12, arrowprops=dict(arrowstyle='->', color='red'),
            color='red')

# Mathematical notation using LaTeX
ax.text(8, -0.5, r'$y = \sin(x)$', fontsize=14, 
        bbox=dict(facecolor='white', edgecolor='gray'))

plt.tight_layout()
plt.show()

Legends

Legends identify different data series in your plot. Matplotlib offers extensive legend customization including position, number of columns, frame styling, and font properties. A well-placed legend enhances readability without obscuring important data points.

fig, ax = plt.subplots(figsize=(12, 6))

x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), label='sin(x)', linewidth=2)
ax.plot(x, np.cos(x), label='cos(x)', linewidth=2)
ax.plot(x, np.sin(x) * np.cos(x), label='sin(x)·cos(x)', linewidth=2)

# Customized legend
ax.legend(
    loc='upper right',      # Position: 'best', 'upper left', 'lower right', etc.
    fontsize=11,            # Font size
    frameon=True,           # Show frame
    facecolor='white',      # Background color
    edgecolor='gray',       # Frame color
    framealpha=0.9,         # Frame transparency
    ncol=3,                 # Number of columns
    title='Functions',      # Legend title
    title_fontsize=12
)

ax.set_title('Legend Customization Example')
plt.show()

# Legend positions
positions = ['best', 'upper left', 'upper right', 'lower left', 
             'lower right', 'center left', 'center right', 'upper center', 
             'lower center', 'center']
print(f"Available legend locations: {positions}")

Styles and Themes

Matplotlib includes built-in style sheets that change the overall appearance of your plots. These provide consistent, professional looks without manually setting every property. You can also combine styles or create custom ones for your organization's branding.

# View available styles
print(plt.style.available)
# ['Solarize_Light2', 'bmh', 'classic', 'dark_background', 'fivethirtyeight',
#  'ggplot', 'grayscale', 'seaborn-v0_8', 'seaborn-v0_8-bright', ...]

# Apply a style (the 'seaborn-v0_8-*' names require Matplotlib 3.6+;
# older versions use 'seaborn-whitegrid')
plt.style.use('seaborn-v0_8-whitegrid')

fig, ax = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), linewidth=2)
ax.plot(x, np.cos(x), linewidth=2)
ax.set_title('Seaborn Whitegrid Style')
plt.show()

# Use style as context manager (temporary)
with plt.style.context('dark_background'):
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(x, np.sin(x), linewidth=2)
    ax.set_title('Dark Background Style')
    plt.show()

# Reset to default
plt.style.use('default')

Style | Description | Best For
seaborn-v0_8 | Clean, modern aesthetic | General data science
ggplot | R's ggplot2 inspired | Statistical graphics
fivethirtyeight | Bold, journalistic | Presentations
dark_background | Dark theme | Slides, dark UIs
bmh | Bayesian Methods for Hackers | Academic papers
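
A style sheet is essentially a bundle of rcParams settings, which is why the context manager can apply one temporarily. The following sketch checks a single setting (axes.facecolor) before, inside, and after a plt.style.context block:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

before = plt.rcParams['axes.facecolor']

with plt.style.context('dark_background'):
    inside = plt.rcParams['axes.facecolor']   # overridden by the style

after = plt.rcParams['axes.facecolor']        # restored on exit

print(inside)           # black
print(before == after)  # True
```

plt.style.use, by contrast, changes the settings for the rest of the session until another style (or 'default') is applied.
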

Practice Questions

Task: Create a styled line plot with custom colors, markers, and line style.

# Given: Monthly revenue data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
revenue = [45, 52, 48, 61, 55, 67]

# Your code: Plot with hex color #2ecc71, square markers, dashed line, linewidth 2

Expected Output: Green dashed line with square markers

Solution:

months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
revenue = [45, 52, 48, 61, 55, 67]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(months, revenue, color='#2ecc71', marker='s', linestyle='--', 
        linewidth=2, markersize=10)
ax.set_xlabel('Month')
ax.set_ylabel('Revenue (thousands)')
ax.set_title('Monthly Revenue')
plt.show()

Task: Create a plot with annotations highlighting the maximum value.

# Given: Stock price data (seeded so the result is reproducible)
np.random.seed(42)
days = np.arange(1, 31)
prices = 100 + np.cumsum(np.random.randn(30))

# Your code: Line plot with annotation arrow pointing to the max price

Expected Output: Line plot with "Peak: $X.XX" annotation at the maximum

Solution:

np.random.seed(42)
days = np.arange(1, 31)
prices = 100 + np.cumsum(np.random.randn(30))

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(days, prices, 'b-', linewidth=2)

max_idx = np.argmax(prices)
max_price = prices[max_idx]
max_day = days[max_idx]

ax.annotate(f'Peak: ${max_price:.2f}', 
            xy=(max_day, max_price),
            xytext=(max_day + 3, max_price - 2),
            fontsize=12, fontweight='bold',
            arrowprops=dict(arrowstyle='->', color='red', lw=2),
            color='red')
ax.scatter([max_day], [max_price], color='red', s=100, zorder=5)

ax.set_xlabel('Day')
ax.set_ylabel('Price ($)')
ax.set_title('Stock Price with Peak Annotation')
plt.show()

Task: Create a publication-quality comparison chart with full styling.

# Given: Performance metrics for 3 algorithms
categories = ['Accuracy', 'Speed', 'Memory', 'Scalability']
algo_a = [0.92, 0.78, 0.85, 0.70]
algo_b = [0.88, 0.95, 0.72, 0.88]
algo_c = [0.95, 0.65, 0.90, 0.75]

# Your code: Grouped bar chart with custom colors, legend, annotations
# Use fivethirtyeight style, add title with subtitle effect

Expected Output: Professional grouped bar chart with value labels

Solution:

categories = ['Accuracy', 'Speed', 'Memory', 'Scalability']
algo_a = [0.92, 0.78, 0.85, 0.70]
algo_b = [0.88, 0.95, 0.72, 0.88]
algo_c = [0.95, 0.65, 0.90, 0.75]

with plt.style.context('fivethirtyeight'):
    fig, ax = plt.subplots(figsize=(12, 7))
    
    x = np.arange(len(categories))
    width = 0.25
    
    bars1 = ax.bar(x - width, algo_a, width, label='Algorithm A', color='#3498db')
    bars2 = ax.bar(x, algo_b, width, label='Algorithm B', color='#e74c3c')
    bars3 = ax.bar(x + width, algo_c, width, label='Algorithm C', color='#2ecc71')
    
    # Value labels
    for bars in [bars1, bars2, bars3]:
        for bar in bars:
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                    f'{bar.get_height():.2f}', ha='center', fontsize=9)
    
    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Algorithm Performance Comparison\n', fontsize=16, fontweight='bold')
    ax.text(0.5, 1.02, 'Higher is better for all metrics', transform=ax.transAxes,
            ha='center', fontsize=10, style='italic', color='gray')
    
    ax.set_xticks(x)
    ax.set_xticklabels(categories)
    ax.set_ylim(0, 1.15)
    ax.legend(loc='upper right', framealpha=0.9)
    
    plt.tight_layout()
    plt.show()
05

Subplots & Figure Layouts

Complex data stories often require multiple visualizations displayed together. Matplotlib's subplot system lets you create grid layouts with multiple axes in a single figure. From simple side-by-side comparisons to sophisticated dashboard-style layouts, mastering subplots is essential for professional data visualization.

Basic Subplots with plt.subplots()

The plt.subplots() function creates a figure with a grid of axes. Specify the number of rows and columns, and it returns a figure object together with the axes. With the default settings, a 1x1 call returns a single Axes object, a single row or column returns a 1-dimensional array, and a larger grid returns a 2-dimensional array that you index by row and column.

# Create a 2x2 grid of subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

x = np.linspace(0, 10, 100)

# Access each subplot by row, column index
axes[0, 0].plot(x, np.sin(x), 'b-')
axes[0, 0].set_title('Sine Wave')

axes[0, 1].plot(x, np.cos(x), 'r-')
axes[0, 1].set_title('Cosine Wave')

axes[1, 0].plot(x, np.exp(-x/5), 'g-')
axes[1, 0].set_title('Exponential Decay')

axes[1, 1].plot(x, np.log(x + 1), 'm-')
axes[1, 1].set_title('Logarithmic Growth')

plt.tight_layout()  # Prevent overlapping
plt.show()

# For single row or column, axes is 1D
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))
axes[2].plot(x, np.tan(x))
plt.tight_layout()
plt.show()
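
The shape of the returned axes depends on the grid you request, which is a common source of indexing errors. The following is a minimal sketch of the three cases, plus two options that can make looping easier: squeeze=False (always return a 2-D array) and axes.flat (iterate over every panel in row-major order). The variable names and the Agg backend line are illustrative choices so the snippet runs without a display; drop the backend line in a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

x = np.linspace(0, 10, 100)

# 1x1 grid: a single Axes object, not an array
fig_single, ax_single = plt.subplots()
print(isinstance(ax_single, Axes))

# 1x3 grid: a 1-D array of Axes
fig_row, axes_row = plt.subplots(1, 3)
print(axes_row.shape)

# 2x2 grid: a 2-D array of Axes
fig_grid, axes_grid = plt.subplots(2, 2)
print(axes_grid.shape)

# squeeze=False keeps the result 2-D even for a single row
fig_sq, axes_sq = plt.subplots(1, 3, squeeze=False)
print(axes_sq.shape)

# axes.flat visits every panel regardless of the grid shape
for i, ax in enumerate(axes_grid.flat):
    ax.plot(x, np.sin(x + i))
    ax.set_title(f'Panel {i}')

plt.close('all')  # release the figures created above
```

Looping over axes.flat is handy when every panel gets the same treatment; explicit indices such as axes[0, 1] stay clearer when each panel is different.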

Sharing Axes

When comparing related data, sharing x or y axes helps viewers make accurate comparisons. The sharex and sharey parameters link the axis limits across subplots. This ensures that all plots use the same scale, making differences in the data immediately apparent.

# Shared x-axis for time series comparison
fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)

time = np.arange(100)
np.random.seed(42)

# Different metrics over the same time period
axes[0].plot(time, np.cumsum(np.random.randn(100)), 'b-', linewidth=2)
axes[0].set_ylabel('Revenue')
axes[0].set_title('Business Metrics Over Time')

axes[1].plot(time, 50 + np.cumsum(np.random.randn(100) * 0.5), 'g-', linewidth=2)
axes[1].set_ylabel('Customers')

axes[2].plot(time, 80 + np.cumsum(np.random.randn(100) * 0.3), 'r-', linewidth=2)
axes[2].set_ylabel('Satisfaction')
axes[2].set_xlabel('Days')

plt.tight_layout()
plt.show()

# Shared y-axis for comparing distributions
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

data1 = np.random.normal(50, 10, 1000)
data2 = np.random.normal(60, 15, 1000)

axes[0].hist(data1, bins=30, color='#3498db', edgecolor='white')
axes[0].set_title('Group A')
axes[0].set_xlabel('Value')
axes[0].set_ylabel('Frequency')

axes[1].hist(data2, bins=30, color='#e74c3c', edgecolor='white')
axes[1].set_title('Group B')
axes[1].set_xlabel('Value')

plt.tight_layout()
plt.show()
Parameter

sharex / sharey

When True (equivalent to 'all'), every subplot shares that axis. Use 'row' or 'col' to share only within each row or column. When the grid is created with plt.subplots(), tick labels on the inner subplots are hidden automatically.
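
To see what the 'row' and 'col' options do, the sketch below builds a 2x2 grid with sharex='col' and sharey='row', changes the limits on one panel, and checks which other panels follow. The limit values are arbitrary, and the Agg backend line is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

# Share x within each column and y within each row
fig, axes = plt.subplots(2, 2, sharex='col', sharey='row')

# Change the limits on the top-left panel only
axes[0, 0].set_xlim(0, 5)
axes[0, 0].set_ylim(-1, 1)

# Same column as the top-left panel: x limits follow
print(axes[1, 0].get_xlim() == (0, 5))   # True

# Same row as the top-left panel: y limits follow
print(axes[0, 1].get_ylim() == (-1, 1))  # True

# Bottom-right panel shares neither axis with the top-left panel
print(axes[1, 1].get_xlim() == (0, 5))   # False

plt.close(fig)
```

This combination suits small-multiple grids where each column covers a different x range but the rows should remain comparable on y.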

GridSpec for Complex Layouts

For layouts where subplots have different sizes, use GridSpec. This powerful tool lets you create grids where individual plots span multiple rows or columns. It is perfect for creating dashboard-style visualizations with a mix of large and small charts.

from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(14, 10))
gs = GridSpec(3, 3, figure=fig)

# Large plot spanning top row
ax1 = fig.add_subplot(gs[0, :])  # Row 0, all columns
x = np.linspace(0, 10, 100)
ax1.plot(x, np.sin(x) * np.exp(-x/10), linewidth=2)
ax1.set_title('Main Time Series (Spanning Full Width)', fontsize=14)

# Two medium plots in middle row
ax2 = fig.add_subplot(gs[1, :2])  # Row 1, columns 0-1
ax2.bar(['A', 'B', 'C', 'D'], [25, 40, 30, 35], color='#3498db')
ax2.set_title('Category Breakdown')

ax3 = fig.add_subplot(gs[1, 2])  # Row 1, column 2
ax3.pie([35, 25, 20, 20], labels=['Q1', 'Q2', 'Q3', 'Q4'], autopct='%1.0f%%')
ax3.set_title('Quarterly Split')

# Three small plots in bottom row
ax4 = fig.add_subplot(gs[2, 0])
ax4.scatter(np.random.rand(20), np.random.rand(20), c='#e74c3c', s=50)
ax4.set_title('Scatter')

ax5 = fig.add_subplot(gs[2, 1])
ax5.hist(np.random.randn(100), bins=15, color='#2ecc71', edgecolor='white')
ax5.set_title('Distribution')

ax6 = fig.add_subplot(gs[2, 2])
ax6.plot([1, 2, 3, 4], [1, 4, 2, 3], 'o-', color='#9b59b6', linewidth=2)
ax6.set_title('Trend')

plt.tight_layout()
plt.show()
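
GridSpec indexing follows NumPy slicing rules, so it can help to inspect which cells a slice actually covers before filling in the plots. The sketch below is an illustration under a few assumptions: the axes names are made up for this example, height_ratios is added to show uneven row heights, and the Agg backend line is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(8, 6))
# 3x3 grid where the first row is twice as tall as the other two
gs = GridSpec(3, 3, figure=fig, height_ratios=[2, 1, 1])

ax_top = fig.add_subplot(gs[0, :])    # row 0, all columns
ax_mid = fig.add_subplot(gs[1, :2])   # row 1, columns 0-1
ax_side = fig.add_subplot(gs[1:, 2])  # rows 1-2, column 2

# Each Axes remembers the grid cells it occupies
for name, ax in [('top', ax_top), ('mid', ax_mid), ('side', ax_side)]:
    spec = ax.get_subplotspec()
    print(name, 'rows', list(spec.rowspan), 'cols', list(spec.colspan))

plt.close(fig)
```

Slices can span rows as well as columns, as gs[1:, 2] does here, which is how a tall sidebar panel is usually built.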

Saving Figures

After creating your visualization, you will often need to save it for reports, presentations, or publications. The savefig() function exports figures in various formats including PNG, PDF, SVG, and EPS. Control resolution with dpi and use bbox_inches='tight' to remove extra whitespace.

# Create a figure to save
fig, ax = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), linewidth=2)
ax.set_title('Publication-Ready Figure')
ax.set_xlabel('X Axis')
ax.set_ylabel('Y Axis')

# Save in different formats
fig.savefig('my_plot.png', dpi=300, bbox_inches='tight', facecolor='white')
fig.savefig('my_plot.pdf', bbox_inches='tight')  # Vector format for papers
fig.savefig('my_plot.svg', bbox_inches='tight')  # Vector format for web

# With transparent background
fig.savefig('my_plot_transparent.png', dpi=300, bbox_inches='tight', transparent=True)

print("Figures saved successfully!")
plt.show()

# Common savefig parameters
# dpi: Resolution (300 for print, 150 for web)
# bbox_inches: 'tight' removes whitespace
# facecolor: Background color
# transparent: True for transparent background
# format: 'png', 'pdf', 'svg', 'eps', 'jpg'

Format | Type   | Best For               | Recommended DPI
-------|--------|------------------------|----------------
PNG    | Raster | Web, presentations     | 150-300
PDF    | Vector | Academic papers, print | N/A (scalable)
SVG    | Vector | Web, interactive       | N/A (scalable)
EPS    | Vector | LaTeX documents        | N/A (scalable)
JPG    | Raster | Photos (not charts)    | 150-300
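
For raster formats, the pixel dimensions of the saved file are the figure size in inches multiplied by dpi. The sketch below checks this by saving to an in-memory buffer and reading the width and height from the PNG header. The helper png_size is a hypothetical name written for this example, not part of Matplotlib, and the Agg backend line is only there so the snippet runs without a display.

```python
import io
import struct

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))  # 4 x 3 inches
ax.plot([0, 1, 2], [0, 1, 4])


def png_size(figure, dpi):
    """Save the figure as PNG in memory and return (width, height) in pixels."""
    buf = io.BytesIO()
    figure.savefig(buf, format='png', dpi=dpi)
    # PNG layout: 8-byte signature, 4-byte chunk length, 4-byte 'IHDR' tag,
    # then width and height as big-endian 32-bit integers
    return struct.unpack('>II', buf.getvalue()[16:24])


print(png_size(fig, 100))  # (400, 300)
print(png_size(fig, 300))  # (1200, 900)

plt.close(fig)
```

Note that bbox_inches='tight' crops the figure to its contents, so with that option the saved size will no longer be exactly figsize times dpi.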

Practice Questions

Task: Create a 1x2 subplot comparing two distributions.

# Given: Two datasets
np.random.seed(42)
before = np.random.normal(100, 15, 200)
after = np.random.normal(110, 12, 200)

# Your code: Side-by-side histograms with shared y-axis, proper titles

Expected Output: Two histograms showing "Before" and "After" with same y-scale

np.random.seed(42)
before = np.random.normal(100, 15, 200)
after = np.random.normal(110, 12, 200)

fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

axes[0].hist(before, bins=20, color='#e74c3c', edgecolor='white', alpha=0.7)
axes[0].set_title('Before Treatment')
axes[0].set_xlabel('Value')
axes[0].set_ylabel('Frequency')
axes[0].axvline(before.mean(), color='black', linestyle='--', label=f'Mean: {before.mean():.1f}')
axes[0].legend()

axes[1].hist(after, bins=20, color='#2ecc71', edgecolor='white', alpha=0.7)
axes[1].set_title('After Treatment')
axes[1].set_xlabel('Value')
axes[1].axvline(after.mean(), color='black', linestyle='--', label=f'Mean: {after.mean():.1f}')
axes[1].legend()

plt.suptitle('Treatment Effect Comparison', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

Task: Create a 2x2 subplot showing different views of sales data.

# Given: Monthly sales data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
online = [120, 150, 140, 180, 200, 220]
store = [80, 90, 95, 100, 85, 110]

# Create 4 plots: line chart, bar chart, stacked bar, pie chart

Expected Output: Four different visualizations of the same data

months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
online = [120, 150, 140, 180, 200, 220]
store = [80, 90, 95, 100, 85, 110]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Line chart
axes[0, 0].plot(months, online, 'o-', label='Online', linewidth=2)
axes[0, 0].plot(months, store, 's-', label='Store', linewidth=2)
axes[0, 0].set_title('Sales Trend')
axes[0, 0].legend()
axes[0, 0].set_ylabel('Sales (K)')

# Grouped bar chart
x = np.arange(len(months))
width = 0.35
axes[0, 1].bar(x - width/2, online, width, label='Online')
axes[0, 1].bar(x + width/2, store, width, label='Store')
axes[0, 1].set_title('Monthly Comparison')
axes[0, 1].set_xticks(x)
axes[0, 1].set_xticklabels(months)
axes[0, 1].legend()

# Stacked bar
axes[1, 0].bar(months, online, label='Online')
axes[1, 0].bar(months, store, bottom=online, label='Store')
axes[1, 0].set_title('Total Sales (Stacked)')
axes[1, 0].legend()

# Pie chart for totals
axes[1, 1].pie([sum(online), sum(store)], labels=['Online', 'Store'],
               autopct='%1.1f%%', colors=['#3498db', '#e74c3c'])
axes[1, 1].set_title('Channel Distribution')

plt.suptitle('Sales Dashboard', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

Task: Create a dashboard layout with one large plot on top and three smaller plots below.

# Given: Stock data
np.random.seed(42)
days = np.arange(1, 101)
stock_price = 100 + np.cumsum(np.random.randn(100) * 2)
volume = np.random.uniform(1, 5, 100)
returns = np.diff(stock_price) / stock_price[:-1] * 100

# Create: Large price chart on top, volume bar, returns histogram, returns scatter below

Expected Output: Professional dashboard with main chart and supporting visuals

from matplotlib.gridspec import GridSpec

np.random.seed(42)
days = np.arange(1, 101)
stock_price = 100 + np.cumsum(np.random.randn(100) * 2)
volume = np.random.uniform(1, 5, 100)
returns = np.diff(stock_price) / stock_price[:-1] * 100

fig = plt.figure(figsize=(16, 10))
gs = GridSpec(2, 3, figure=fig, height_ratios=[2, 1])

# Main price chart (top, spans all columns)
ax1 = fig.add_subplot(gs[0, :])
ax1.plot(days, stock_price, 'b-', linewidth=2)
ax1.fill_between(days, stock_price.min(), stock_price, alpha=0.2)
ax1.set_title('Stock Price Over Time', fontsize=14, fontweight='bold')
ax1.set_xlabel('Day')
ax1.set_ylabel('Price ($)')
ax1.grid(True, alpha=0.3)

# Volume bar chart (bottom left)
ax2 = fig.add_subplot(gs[1, 0])
colors = ['#2ecc71' if returns[i-1] >= 0 else '#e74c3c' for i in range(1, len(volume))]
colors.insert(0, '#3498db')
ax2.bar(days, volume, color=colors, width=1)
ax2.set_title('Trading Volume')
ax2.set_xlabel('Day')
ax2.set_ylabel('Volume (M)')

# Returns histogram (bottom center)
ax3 = fig.add_subplot(gs[1, 1])
ax3.hist(returns, bins=20, color='#9b59b6', edgecolor='white')
ax3.axvline(0, color='black', linestyle='--')
ax3.set_title('Returns Distribution')
ax3.set_xlabel('Daily Return (%)')
ax3.set_ylabel('Frequency')

# Returns scatter (bottom right)
ax4 = fig.add_subplot(gs[1, 2])
colors = ['#2ecc71' if r >= 0 else '#e74c3c' for r in returns]
ax4.scatter(days[1:], returns, c=colors, alpha=0.6, s=30)
ax4.axhline(0, color='black', linestyle='-', linewidth=0.5)
ax4.set_title('Daily Returns')
ax4.set_xlabel('Day')
ax4.set_ylabel('Return (%)')

# Let tight_layout() place the title; a manual y above 1.0 can push it off the canvas
plt.suptitle('Stock Analysis Dashboard', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

Key Takeaways

Figure & Axes Architecture

Every Matplotlib visualization has a Figure (canvas) containing one or more Axes (individual plots). Use fig, ax = plt.subplots() to create both at once. Calling methods on these objects (the object-oriented approach) gives you more control than relying on pyplot's implicit "current axes" functions such as plt.plot().

Choose the Right Plot Type

Use line plots for trends over time, scatter plots for relationships between variables, bar charts for categorical comparisons, and histograms for distributions. Match your chart type to your data and message.

Customize for Clarity

Use colors, markers, and line styles to differentiate data series. Add clear labels, titles, and legends. The format string shortcut ('ro--') combines color, marker, and linestyle in one argument.

Master Subplots

Create multi-panel figures with plt.subplots(rows, cols). Use sharex and sharey for consistent scales. For complex layouts, use GridSpec to span rows and columns.

Use Built-in Styles

Apply professional styling with plt.style.use('seaborn-v0_8'). Use styles as context managers for temporary changes. Explore available styles with plt.style.available.

Save Publication-Quality

Export figures with fig.savefig(). Use PNG for web (dpi=150) or high-resolution raster output (dpi=300), PDF for print, and SVG for scalable graphics on the web. Add bbox_inches='tight' to remove extra whitespace and facecolor='white' for clean backgrounds.

Knowledge Check

Quick Quiz

Test your understanding of Matplotlib fundamentals

1. What is the recommended way to create a figure and axes in Matplotlib?
2. Which plot type is best for showing the distribution of a continuous variable?
3. What does the format string 'ro--' represent in plt.plot(x, y, 'ro--')?
4. How do you create a 2x3 grid of subplots?
5. What parameter removes extra whitespace when saving a figure?
6. Which method is used to set the x-axis label in the object-oriented interface?