Plot Anatomy
Every Matplotlib figure has a hierarchy: Figure contains Axes, Axes contain plot elements. Understanding this structure is key to creating and customizing plots.
Figure and Axes
A Figure is the entire window or canvas. An Axes is a single plot within the figure. One figure can contain multiple axes (subplots).
Why it matters: Knowing the difference lets you control layout and create complex multi-panel figures.
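The containment described above can be inspected directly. This is a minimal sketch that assumes the non-interactive Agg backend so it runs without a display. It builds one Figure with two Axes and then walks the hierarchy in both directions.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

# One Figure holding two Axes (side-by-side subplots)
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8, 3))

# Each Axes owns its own plot elements (here, one Line2D each)
ax_left.plot([1, 2, 3], [1, 4, 9])
ax_right.plot([1, 2, 3], [9, 4, 1])

# Walk the hierarchy: Figure -> Axes -> elements, and Axes -> parent Figure
print(len(fig.axes))                # 2
print(ax_left.get_figure() is fig)  # True
print(len(ax_left.lines))           # 1
```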
Two Interfaces
pyplot (Quick)
import matplotlib.pyplot as plt
# Quick and simple
plt.plot([1, 2, 3], [1, 4, 9])
plt.title('Quick Plot')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
Object-Oriented (Flexible)
import matplotlib.pyplot as plt
# More control
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
ax.set_title('OO Plot')
ax.set_xlabel('X')
ax.set_ylabel('Y')
plt.show()
The pyplot interface is quick for simple plots and works like MATLAB. The object-oriented interface gives more control and is recommended for complex figures. When you call plt.subplots(), it returns both the figure and axes objects, giving you direct access to all plot elements. This approach makes it easier to modify specific parts of your visualization and is essential when creating multiple subplots or customizing individual elements.
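The two interfaces are not separate systems: pyplot functions act on the "current" Axes, which plt.gca() returns. The sketch below (assuming the Agg backend, so no window opens) creates an Axes the object-oriented way, then draws on it with pyplot calls to show that both routes reach the same object.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the example
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# pyplot calls target the current Axes, which is the one just created
plt.plot([1, 2, 3], [1, 4, 9])
plt.title('Mixed interfaces')

current_is_ours = plt.gca() is ax
print(current_is_ours)   # True
print(ax.get_title())    # Mixed interfaces
print(len(ax.lines))     # 1
```

In practice it is clearer to pick one style per figure. Mixing them works, but it hides which Axes a call affects.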
Installing Matplotlib
Before creating visualizations, you need to install matplotlib. It comes with many data science distributions, but you can also install it separately using pip or conda.
# Install with pip
# pip install matplotlib
# Install with conda
# conda install matplotlib
# Verify installation
import matplotlib
print(matplotlib.__version__) # Output: e.g., 3.8.2
# Import the pyplot module (standard convention)
import matplotlib.pyplot as plt
import numpy as np # NumPy is commonly used with matplotlib
The standard convention is to import matplotlib.pyplot as plt. This gives you access to all the common plotting functions with a short, convenient alias. NumPy is typically imported alongside matplotlib because most visualizations involve numerical data that benefits from NumPy's array operations. The combination of these two libraries forms the foundation of Python's data visualization ecosystem.
The Plotting Workflow
Creating a visualization follows a consistent workflow: create figure, plot data, customize appearance, and display or save.
# Complete workflow example
import matplotlib.pyplot as plt
import numpy as np
# 1. Create figure and axes
fig, ax = plt.subplots(figsize=(10, 6))
# 2. Generate and plot data
x = np.linspace(0, 10, 100)
y = np.sin(x)
ax.plot(x, y, color='steelblue', linewidth=2, label='sin(x)')
# 3. Customize appearance
ax.set_title('Sine Wave Visualization', fontsize=14, fontweight='bold')
ax.set_xlabel('X-axis', fontsize=12)
ax.set_ylabel('Y-axis', fontsize=12)
ax.legend(loc='upper right')
# 4. Refine layout
ax.grid(True, alpha=0.3)
plt.tight_layout()
# 5. Save and/or show
plt.savefig('sine_wave.png', dpi=150, bbox_inches='tight')
plt.show()
This example demonstrates the complete plotting workflow from start to finish. First, we create a figure with specified dimensions using figsize. Then we generate data with NumPy and plot it with customized styling. The set_title, set_xlabel, and set_ylabel methods add descriptive text. The grid and tight_layout calls improve readability and spacing. Finally, savefig exports the figure before show displays it. Always save before showing: in a typical script, pyplot closes the figure once the window opened by plt.show() is dismissed, so a plt.savefig() call placed afterwards can write a blank image.
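One way to avoid depending on which figure pyplot considers "current" is to save through the Figure object itself. This sketch assumes the Agg backend and writes to an in-memory buffer (a stand-in for a filename such as 'sine_wave.png') so the result can be checked.

```python
import io
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(figsize=(6, 4))
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), label='sin(x)')
ax.legend()

# fig.savefig always targets this figure, whatever pyplot's current figure is
buffer = io.BytesIO()  # in-memory stand-in for a file path
fig.savefig(buffer, format='png', dpi=100)
png_bytes = buffer.getvalue()

print(png_bytes[:8] == b'\x89PNG\r\n\x1a\n')  # True: valid PNG signature
plt.close(fig)  # release the figure once saved (a script might call plt.show() here instead)
```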
Understanding Coordinates
Matplotlib uses a Cartesian coordinate system, with x increasing to the right and y increasing upward. In axes coordinates, (0, 0) is the bottom-left corner of the plot area; in data coordinates, the origin is wherever your data places it.
Data Coordinates
import matplotlib.pyplot as plt
# Data coordinates are your actual data values
x_data = [1, 2, 3, 4, 5]
y_data = [10, 20, 15, 25, 30]
plt.plot(x_data, y_data)
# Points appear at their data values
# (1, 10), (2, 20), (3, 15), etc.
plt.show()
Data coordinates match your actual data values.
Axes Coordinates
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
# Axes coordinates: 0-1 relative to axes
# (0, 0) = bottom-left corner
# (1, 1) = top-right corner
# (0.5, 0.5) = center of axes
ax.text(0.5, 0.5, 'Center', ha='center', va='center',
        transform=ax.transAxes)
plt.show()
Axes coordinates are normalized 0-1 values.
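Both coordinate systems are backed by transform objects: ax.transData maps data values to display pixels, and ax.transAxes maps 0-1 axes fractions to display pixels. This sketch (assuming the Agg backend) chains one with the inverse of the other to convert axes coordinates into data coordinates.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(0, 100)

# Compose: axes coords -> display pixels -> data coords
axes_to_data = ax.transAxes + ax.transData.inverted()

center_in_data = axes_to_data.transform((0.5, 0.5))
corner_in_data = axes_to_data.transform((1.0, 1.0))
print(center_in_data)  # approximately [ 5. 50.]
print(corner_in_data)  # approximately [ 10. 100.]
```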
Practice: Plot Anatomy
Task: Import matplotlib.pyplot and numpy, then print the matplotlib version.
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
print(f"Matplotlib version: {matplotlib.__version__}")
print(f"NumPy version: {np.__version__}")
Task: Create a figure with axes using plt.subplots() and display it.
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.set_title('Empty Figure')
plt.show()
Task: Create a simple line plot using the object-oriented interface (fig, ax = plt.subplots()).
import matplotlib.pyplot as plt
import numpy as np
# Create figure and axes
fig, ax = plt.subplots(figsize=(8, 5))
# Plot data
x = np.linspace(0, 5, 50)
y = x ** 2
ax.plot(x, y)
# Customize
ax.set_title('Quadratic Function')
ax.set_xlabel('x')
ax.set_ylabel('x squared')
ax.grid(True)
plt.show()
Task: Create a plot following the complete workflow: create, plot, customize, save, and show.
import matplotlib.pyplot as plt
import numpy as np
# 1. Create
fig, ax = plt.subplots(figsize=(10, 6))
# 2. Plot
x = np.linspace(0, 2 * np.pi, 100)
ax.plot(x, np.sin(x), label='sin(x)', color='blue')
ax.plot(x, np.cos(x), label='cos(x)', color='red')
# 3. Customize
ax.set_title('Trigonometric Functions', fontsize=14)
ax.set_xlabel('Radians')
ax.set_ylabel('Value')
ax.legend()
ax.grid(True, alpha=0.3)
# 4. Save
plt.tight_layout()
plt.savefig('trig_functions.png', dpi=150, bbox_inches='tight')
# 5. Show
plt.show()
Basic Plots
Start with line plots, then add markers, colors, and labels to make your visualizations informative and attractive. Line plots are the foundation of data visualization, perfect for showing trends and continuous data over time.
Line Plot Components
A line plot connects data points with straight lines. The essential components are: X values (horizontal position), Y values (vertical position), line style, color, and optional markers at each point.
When to use: Time series data, continuous measurements, trends over intervals, comparing multiple series.
Simple Line Plot
The most basic plot connects a series of points. Provide x and y coordinates as lists or arrays.
import matplotlib.pyplot as plt
import numpy as np
# Generate data
x = np.linspace(0, 10, 100) # 100 points from 0 to 10
y = np.sin(x) # Sine of each x value
# Create the plot
plt.plot(x, y)
# Add labels and title
plt.title('Sine Wave')
plt.xlabel('X')
plt.ylabel('sin(x)')
plt.grid(True)
plt.show()
This example creates a smooth sine wave by generating 100 evenly-spaced x values and computing their sine. The np.linspace function creates the x array, and np.sin applies the sine function element-wise. The plt.plot function connects these points with a line. Adding grid(True) overlays gridlines that make it easier to read values from the chart. The more points you generate, the smoother the curve appears.
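The claim that more points give a smoother curve can be quantified. Because plt.plot joins samples with straight segments, this sketch uses np.interp as a stand-in for those segments and measures the largest gap from the true sine over the same 0 to 10 range.

```python
import numpy as np

def max_line_error(n_points):
    """Largest gap between sin(x) and straight segments joining n_points samples."""
    x_samples = np.linspace(0, 10, n_points)
    y_samples = np.sin(x_samples)
    x_fine = np.linspace(0, 10, 10_000)
    # np.interp reproduces the straight segments a line plot draws between samples
    y_lines = np.interp(x_fine, x_samples, y_samples)
    return np.max(np.abs(y_lines - np.sin(x_fine)))

coarse_error = max_line_error(10)
fine_error = max_line_error(100)
print(f"10 points:  {coarse_error:.3f}")
print(f"100 points: {fine_error:.4f}")
```

For smooth curves the gap shrinks roughly with the square of the spacing, so 100 points already looks smooth at normal figure sizes.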
Multiple Lines on One Plot
Compare multiple data series by plotting them on the same axes. Use labels and a legend to identify each line.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
# Plot multiple lines
plt.plot(x, np.sin(x), label='sin(x)')
plt.plot(x, np.cos(x), label='cos(x)')
plt.plot(x, np.sin(x) + np.cos(x), label='sin(x) + cos(x)')
# Customize
plt.title('Trigonometric Functions')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend() # Display legend with labels
plt.grid(True, alpha=0.3)
plt.show()
When plotting multiple lines, the label parameter assigns a name to each line that appears in the legend. The plt.legend() call displays the legend box, which by default appears in the best location to avoid overlapping data. Each subsequent plot call adds a new line in a different color from matplotlib's default color cycle. The alpha parameter on the grid controls transparency, making the gridlines subtle so they do not distract from the data.
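The automatic colors come from the Axes property cycle, and the legend text comes from each line's label. This sketch (assuming the Agg backend and default rcParams) reads both back so you can see where they originate.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, default style
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')

# Colors drawn, in order, when no color is given
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
line_colors = [line.get_color() for line in ax.lines]

# The labels the legend will use
handles, labels = ax.get_legend_handles_labels()
print(labels)  # ['sin(x)', 'cos(x)']
```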
Line Styles and Colors
Customize appearance using format strings (quick) or keyword arguments (detailed control).
Format String Reference
Colors
- 'b': Blue
- 'g': Green
- 'r': Red
- 'c': Cyan
- 'm': Magenta
- 'y': Yellow
- 'k': Black
- 'w': White
Markers
- 'o': Circle
- 's': Square
- '^': Triangle up
- 'v': Triangle down
- '*': Star
- '+': Plus
- 'x': X
- '.': Point
Line Styles
- '-': Solid
- '--': Dashed
- '-.': Dash-dot
- ':': Dotted
- (omitted, with a marker given): No line
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 5, 20)
# Format strings: 'color marker linestyle'
plt.plot(x, x, 'r--') # Red dashed line
plt.plot(x, x**1.5, 'bo') # Blue circles only (no line)
plt.plot(x, x**2, 'g.-') # Green point markers with solid line
plt.plot(x, x**2.5, 'ms-') # Magenta squares with solid line
plt.title('Format String Examples')
plt.legend(['Linear', 'x^1.5', 'Quadratic', 'x^2.5'])
plt.show()
Format strings combine color, marker, and line style into a short code. The order of the parts generally does not matter: 'r--' and '--r' both produce a red dashed line. When you provide a marker but no line style (like 'bo'), Matplotlib draws the markers without connecting them. This is useful for scatter-like plots where you want to emphasize individual points rather than trends. Format strings are convenient for quick plots, but keyword arguments offer more options.
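To check how a format string was interpreted, you can read the resulting Line2D properties back. This sketch (assuming the Agg backend) plots a few format strings and inspects color, marker, and line style.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

fig, ax = plt.subplots()
x = [0, 1, 2, 3]

dashed, = ax.plot(x, x, 'r--')       # color + line style
circles, = ax.plot(x, x, 'bo')       # color + marker, no line
dots_solid, = ax.plot(x, x, 'g.-')   # color + point marker + solid line
same_dashed, = ax.plot(x, x, '--r')  # same as 'r--', order swapped

for name, line in [('r--', dashed), ('bo', circles), ('g.-', dots_solid), ('--r', same_dashed)]:
    print(name, line.get_color(), line.get_marker(), line.get_linestyle())
```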
Keyword Arguments for Fine Control
For precise control over appearance, use keyword arguments instead of format strings.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 50)
y = np.sin(x)
plt.plot(x, y,
color='#E74C3C', # Hex color
linestyle='--', # Dashed line
linewidth=2.5, # Line thickness
marker='o', # Circle markers
markersize=6, # Marker size
markerfacecolor='white', # Marker fill color
markeredgecolor='#E74C3C', # Marker border color
markeredgewidth=1.5, # Marker border thickness
alpha=0.8, # Transparency (0-1)
label='Custom Style')
plt.title('Styled Line Plot')
plt.xlabel('X Values')
plt.ylabel('Y Values')
plt.legend()
plt.grid(True, linestyle=':', alpha=0.5)
plt.show()
Keyword arguments provide granular control over every visual aspect of your plot. You can specify colors using hex codes for exact brand colors, control line and marker sizes independently, and set different colors for marker faces and edges. The alpha parameter adds transparency, which is useful when lines overlap. This level of customization is essential for publication-quality figures and presentations where visual consistency matters.
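For visual consistency across many plots, one option is to keep the keyword arguments in a plain dict and unpack it into each call. In this sketch `house_style` is a hypothetical name, the Agg backend is assumed, and a setter adjusts one line after it is drawn.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical house style: a plain dict of Line2D keyword arguments
house_style = dict(color='#E74C3C', linestyle='--', linewidth=2.5,
                   marker='o', markersize=6,
                   markerfacecolor='white', markeredgecolor='#E74C3C')

x = np.linspace(0, 10, 50)
fig, ax = plt.subplots()
line_a, = ax.plot(x, np.sin(x), label='sin(x)', **house_style)
# Override individual entries without touching the shared dict
line_b, = ax.plot(x, np.cos(x), label='cos(x)',
                  **{**house_style, 'color': '#2C3E50', 'markeredgecolor': '#2C3E50'})

# Setters can also adjust a line after it has been drawn
line_b.set_alpha(0.6)
ax.legend()
```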
Filling Areas
Use fill_between to shade the area between a line and a baseline, useful for showing ranges or emphasizing magnitude.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.sin(x)
plt.plot(x, y, color='blue', linewidth=2)
plt.fill_between(x, y, alpha=0.3, color='blue') # Fill to y=0
plt.fill_between(x, y, 1, where=(y > 0.5), alpha=0.3, color='green') # Conditional fill
plt.title('Area Fill Example')
plt.xlabel('X')
plt.ylabel('Y')
plt.axhline(y=0, color='black', linestyle='-', linewidth=0.5) # Baseline
plt.show()
The fill_between function shades the area between your data and a baseline (default y=0). The first argument is x values, the second is your y data, and an optional third argument specifies a different baseline. The where parameter allows conditional filling, shading only where a condition is true. This is powerful for highlighting specific regions like values above a threshold. The axhline function draws a horizontal reference line at y=0.
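With where, the filled polygon only has corners at the sampled x values, so on coarse data the shading stops short of the point where the curve crosses the threshold. Setting interpolate=True makes Matplotlib compute the crossing points and extend the fill to them. This sketch assumes the Agg backend and an arbitrary threshold of 0.5.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 30)  # deliberately coarse sampling
y = np.sin(x)
threshold = 0.5

fig, ax = plt.subplots()
ax.plot(x, y, color='blue')
# Fill between the curve and the threshold only where the curve is above it;
# interpolate=True extends each region to the computed crossing points
above = ax.fill_between(x, y, threshold, where=(y > threshold),
                        interpolate=True, alpha=0.3, color='green')
ax.axhline(threshold, color='gray', linestyle=':')
```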
Practice: Basic Plots
Task: Plot the squares of numbers 1-10 (x vs x squared).
import matplotlib.pyplot as plt
import numpy as np
x = np.arange(1, 11)
y = x ** 2
plt.plot(x, y)
plt.title('Squares')
plt.xlabel('x')
plt.ylabel('x squared')
plt.show()
Task: Create any plot and add a title, x-label, and y-label.
import matplotlib.pyplot as plt
plt.plot([1, 2, 3, 4], [1, 4, 2, 3])
plt.title('My First Plot')
plt.xlabel('Time (s)')
plt.ylabel('Value')
plt.show()
Task: Plot both x squared and x cubed on the same graph with a legend.
import matplotlib.pyplot as plt
import numpy as np
x = np.arange(1, 6)
plt.plot(x, x**2, label='x squared')
plt.plot(x, x**3, label='x cubed')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()
Task: Create a plot with red dashed line and circle markers using format strings.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 5, 10)
y = x ** 2
plt.plot(x, y, 'ro--') # Red circles, dashed line
plt.title('Format String Demo')
plt.xlabel('X')
plt.ylabel('Y')
plt.grid(True)
plt.show()
Task: Create a line plot with custom color, linewidth, marker size, and transparency.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 20)
y = np.sin(x)
plt.plot(x, y,
color='#3498DB',
linewidth=2,
marker='s',
markersize=8,
markerfacecolor='white',
markeredgecolor='#3498DB',
alpha=0.9,
label='Styled Sine')
plt.title('Keyword Arguments Demo')
plt.xlabel('X')
plt.ylabel('sin(x)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Task: Plot a curve and fill the area underneath with a semi-transparent color.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.exp(-x/3) * np.sin(x)
plt.plot(x, y, color='purple', linewidth=2, label='Damped sine')
plt.fill_between(x, y, alpha=0.3, color='purple')
plt.axhline(y=0, color='black', linewidth=0.5)
plt.title('Filled Area Plot')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Chart Types
Different data calls for different chart types. Choose the right visualization to communicate your insights effectively. Line plots excel at showing trends over time, bar charts compare discrete categories, scatter plots reveal relationships between variables, histograms display distributions, pie charts show proportions, box plots summarize statistical properties, and heatmaps display matrix data. Selecting the appropriate chart type matters: the wrong choice can mislead or confuse your audience, while the right one makes patterns and insights immediately clear.
Bar Chart
import matplotlib.pyplot as plt
categories = ['A', 'B', 'C', 'D']
values = [25, 40, 30, 55]
plt.bar(categories, values, color='steelblue')
plt.title('Sales by Category')
plt.xlabel('Category')
plt.ylabel('Sales')
plt.show()
Bar charts are ideal for comparing discrete categories. Use plt.barh() for horizontal bars.
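Both plt.bar and plt.barh return a container of rectangles, which is handy for labelling or checking the bars. This sketch (assuming the Agg backend and Matplotlib 3.4 or later for bar_label) draws the same data vertically and horizontally.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

categories = ['A', 'B', 'C', 'D']
values = [25, 40, 30, 55]

fig, (ax_v, ax_h) = plt.subplots(1, 2, figsize=(10, 4))
vertical = ax_v.bar(categories, values, color='steelblue')
horizontal = ax_h.barh(categories, values, color='steelblue')

# bar_label writes each bar's value next to it
ax_v.bar_label(vertical)
ax_h.bar_label(horizontal)

# Vertical bars encode values as heights, horizontal bars as widths
heights = [rect.get_height() for rect in vertical]
widths = [rect.get_width() for rect in horizontal]
```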
Scatter Plot
import matplotlib.pyplot as plt
import numpy as np
x = np.random.randn(100)
y = x + np.random.randn(100) * 0.5
plt.scatter(x, y, alpha=0.6, c='coral')
plt.title('Correlation Example')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
Use scatter plots to show relationships. The alpha parameter controls transparency (useful for overlapping points).
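A scatter plot shows a relationship visually; the Pearson correlation coefficient puts a number on how linear it is. This sketch regenerates data like the example above with a seeded generator (an assumption, for reproducibility) and computes the coefficient with np.corrcoef.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded so the numbers are reproducible
x = rng.standard_normal(100)
y = x + rng.standard_normal(100) * 0.5

# Pearson correlation: +1 is a perfect positive linear relationship, 0 is none
r = np.corrcoef(x, y)[0, 1]
print(f"r = {r:.2f}")
```

With noise of standard deviation 0.5, the expected value is about 0.89, which matches the tight upward band in the plot.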
Histogram
import matplotlib.pyplot as plt
import numpy as np
data = np.random.randn(1000)
plt.hist(data, bins=30, edgecolor='black', alpha=0.7)
plt.title('Distribution of Values')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.show()
Histograms show how data is distributed across value ranges. The bins parameter controls granularity: more bins show more detail but can look noisy, fewer bins smooth the distribution but may hide patterns. The edgecolor parameter adds borders to distinguish adjacent bars. Always consider your data size when choosing bins: with 1000 points, 30 bins gives about 33 points per bin on average, providing a reliable estimate of the distribution shape.
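plt.hist also returns the numbers behind the bars: the count per bin, the bin edges, and the drawn patches. This sketch (assuming the Agg backend and a seeded generator) uses them to confirm the bins-per-point arithmetic above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
data = rng.standard_normal(1000)

fig, ax = plt.subplots()
counts, bin_edges, patches = ax.hist(data, bins=30, edgecolor='black', alpha=0.7)

print(len(bin_edges))     # 31: one more edge than bins
print(int(counts.sum()))  # 1000: every value lands in exactly one bin
print(counts.mean())      # about 33.3 values per bin on average
```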
Pie Chart
Pie charts show parts of a whole. Use them sparingly, as bar charts are often clearer for comparisons.
import matplotlib.pyplot as plt
# Data for pie chart
sizes = [35, 25, 20, 15, 5]
labels = ['Python', 'JavaScript', 'Java', 'C++', 'Other']
colors = ['#3498DB', '#F39C12', '#E74C3C', '#9B59B6', '#95A5A6']
explode = (0.05, 0, 0, 0, 0) # Offset the first slice
plt.figure(figsize=(8, 8))
plt.pie(sizes, labels=labels, colors=colors, explode=explode,
autopct='%1.1f%%', startangle=90, shadow=True)
plt.title('Programming Language Popularity')
plt.axis('equal') # Equal aspect ratio for circular pie
plt.show()
Pie charts represent proportions of a whole, with each slice sized according to its value. The explode parameter pulls slices away from center for emphasis. autopct formats percentage labels on each slice. startangle rotates the chart so the first slice begins at that angle (90 degrees is the top). The axis('equal') call ensures the pie is circular rather than elliptical. While visually appealing, pie charts become hard to read with many categories or similar-sized slices.
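When autopct is set, pie returns the wedges, the label texts, and the percentage texts. Matplotlib normalizes the sizes by their total before formatting, so in this sketch (Agg backend assumed) the sizes already sum to 100 and the percentages equal the raw values.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

sizes = [35, 25, 20, 15, 5]
labels = ['Python', 'JavaScript', 'Java', 'C++', 'Other']

fig, ax = plt.subplots(figsize=(6, 6))
wedges, texts, autotexts = ax.pie(sizes, labels=labels,
                                  autopct='%1.1f%%', startangle=90)
ax.set_aspect('equal')  # keep the pie circular

percent_labels = [t.get_text() for t in autotexts]
print(percent_labels)  # ['35.0%', '25.0%', '20.0%', '15.0%', '5.0%']
```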
Box Plot
Box plots (box-and-whisker plots) show statistical distributions including median, quartiles, and outliers.
import matplotlib.pyplot as plt
import numpy as np
# Generate sample data for multiple groups
np.random.seed(42)
data = [np.random.normal(0, std, 100) for std in range(1, 5)]
fig, ax = plt.subplots(figsize=(10, 6))
bp = ax.boxplot(data, patch_artist=True)
ax.set_xticks([1, 2, 3, 4], labels=['A', 'B', 'C', 'D'])  # boxes sit at x = 1..4
# Color the boxes
colors = ['#3498DB', '#2ECC71', '#F39C12', '#E74C3C']
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)
ax.set_title('Distribution Comparison')
ax.set_xlabel('Group')
ax.set_ylabel('Value')
ax.grid(True, axis='y', alpha=0.3)
plt.show()
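Box plots summarize a distribution with a box spanning the first quartile (25th percentile) to the third quartile (75th percentile), a line at the median (50th percentile), and whiskers. By default, Matplotlib extends each whisker to the most extreme data point within 1.5 times the interquartile range (IQR = Q3 - Q1) of the box, and draws points beyond the whiskers individually as outliers (called fliers); the whiskers reach the minimum and maximum only when there are no outliers. The patch_artist=True parameter allows filling boxes with color. Box plots excel at comparing distributions across groups because you can quickly see differences in spread, center, and skewness. They are essential for exploratory data analysis and identifying unusual patterns in your data.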
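The whisker rule can be reproduced by hand and compared with what Matplotlib draws. This sketch assumes the Agg backend and the default whis=1.5, and uses a small made-up dataset with one obvious outlier.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 40])

# 1.5 x IQR rule, matching Matplotlib's default whis=1.5
q1, median, q3 = np.percentile(data, [25, 50, 75])
iqr = q3 - q1
upper_fence = q3 + 1.5 * iqr
lower_fence = q1 - 1.5 * iqr
expected_whisker_hi = data[data <= upper_fence].max()  # furthest point inside the fence
expected_whisker_lo = data[data >= lower_fence].min()
expected_outliers = data[(data > upper_fence) | (data < lower_fence)]

fig, ax = plt.subplots()
bp = ax.boxplot(data)

# bp['whiskers'] holds the lower then the upper whisker; y-data runs box edge -> whisker end
drawn_whisker_lo = bp['whiskers'][0].get_ydata()[1]
drawn_whisker_hi = bp['whiskers'][1].get_ydata()[1]
drawn_outliers = bp['fliers'][0].get_ydata()
print(median, expected_whisker_hi, expected_outliers)  # 5.5 9 [40]
```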
Heatmap
Heatmaps display matrix data using color intensity, perfect for correlation matrices or 2D data.
import matplotlib.pyplot as plt
import numpy as np
# Create a correlation-like matrix
np.random.seed(42)
data = np.random.rand(5, 5)
# Make it symmetric for a correlation matrix effect
data = (data + data.T) / 2
np.fill_diagonal(data, 1)
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(data, cmap='coolwarm', vmin=0, vmax=1)
# Add colorbar
cbar = plt.colorbar(im)
cbar.set_label('Correlation')
# Add labels
labels = ['A', 'B', 'C', 'D', 'E']
ax.set_xticks(range(len(labels)))
ax.set_yticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
# Add value annotations
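for i in range(len(labels)):
    for j in range(len(labels)):
        # coolwarm is dark at both ends and pale in the middle, so use white text only far from 0.5
        ax.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center',
                color='white' if abs(data[i, j] - 0.5) > 0.3 else 'black')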
ax.set_title('Correlation Heatmap')
plt.tight_layout()
plt.show()
Heatmaps use color to represent values in a matrix. The imshow function displays the matrix, and cmap selects the color scheme (coolwarm runs from blue through a pale gray to red). The vmin and vmax parameters set the color scale range. Adding a colorbar provides a legend for interpreting colors. Text annotations show exact values in each cell, with the text color chosen based on the background to ensure readability. Heatmaps are indispensable for visualizing correlation matrices, confusion matrices, and any 2D tabular data.
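Real correlation coefficients range from -1 to 1, so for a diverging colormap it usually makes sense to use a symmetric range that puts 0 at the neutral middle color. This sketch (Agg backend assumed) computes a correlation matrix with np.corrcoef from three hypothetical variables, one tracking a base signal, one opposing it, and one pure noise.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
base = rng.standard_normal(200)
# Hypothetical variables: X1 tracks base, X2 opposes it, X3 is independent noise
variables = np.vstack([
    base + 0.3 * rng.standard_normal(200),
    -base + 0.3 * rng.standard_normal(200),
    rng.standard_normal(200),
])
corr = np.corrcoef(variables)  # 3x3 Pearson correlation matrix, values in [-1, 1]

fig, ax = plt.subplots()
# Symmetric limits center the diverging colormap on zero correlation
im = ax.imshow(corr, cmap='coolwarm', vmin=-1, vmax=1)
fig.colorbar(im, ax=ax, label='Correlation')
names = ['X1', 'X2', 'X3']
ax.set_xticks(range(3), labels=names)
ax.set_yticks(range(3), labels=names)
```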
Chart Selection Guide
| Data Type | Best Chart | When to Use |
|---|---|---|
| Trends over time | Line Plot | Stock prices, temperature, continuous measurements |
| Category comparison | Bar Chart | Sales by region, survey responses, counts |
| Two variable relationship | Scatter Plot | Height vs weight, correlation analysis |
| Value distribution | Histogram | Age distribution, test scores, frequency |
| Parts of a whole | Pie Chart | Market share, budget allocation (limited categories) |
| Statistical summary | Box Plot | Compare distributions, identify outliers |
| Matrix/2D data | Heatmap | Correlation matrix, confusion matrix, schedules |
Practice: Chart Types
Task: Create a bar chart showing sales for 4 products.
import matplotlib.pyplot as plt
products = ['Widget', 'Gadget', 'Gizmo', 'Doodad']
sales = [120, 85, 200, 150]
plt.bar(products, sales)
plt.title('Product Sales')
plt.ylabel('Units Sold')
plt.show()
Task: Create a scatter plot with random x and y data.
import matplotlib.pyplot as plt
import numpy as np
x = np.random.rand(30)
y = np.random.rand(30)
plt.scatter(x, y, color='coral', s=50)
plt.title('Random Scatter')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
Task: Generate 500 random numbers and plot their histogram with 20 bins.
import matplotlib.pyplot as plt
import numpy as np
data = np.random.randn(500)
plt.hist(data, bins=20, edgecolor='black')
plt.title('Random Distribution')
plt.xlabel('Value')
plt.ylabel('Count')
plt.show()
Task: Create a horizontal bar chart showing programming language popularity.
import matplotlib.pyplot as plt
languages = ['Python', 'JavaScript', 'Java', 'C++', 'Go']
popularity = [85, 78, 65, 55, 45]
plt.barh(languages, popularity, color='steelblue')
plt.title('Programming Language Popularity')
plt.xlabel('Popularity Score')
plt.tight_layout()
plt.show()
Task: Create a scatter plot where point color depends on y value.
import matplotlib.pyplot as plt
import numpy as np
x = np.random.randn(50)
y = np.random.randn(50)
plt.scatter(x, y, c=y, cmap='viridis')
plt.colorbar(label='Y value')
plt.title('Colored Scatter')
plt.show()
Task: Create a box plot comparing 3 different distributions.
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
data = [
np.random.normal(0, 1, 100), # Mean 0, std 1
np.random.normal(2, 1.5, 100), # Mean 2, std 1.5
np.random.normal(-1, 0.5, 100) # Mean -1, std 0.5
]
plt.boxplot(data)
plt.xticks([1, 2, 3], ['Group A', 'Group B', 'Group C'])  # boxes sit at x = 1, 2, 3
plt.title('Distribution Comparison')
plt.ylabel('Value')
plt.grid(True, axis='y', alpha=0.3)
plt.show()
Customization
Control colors, fonts, axes limits, and more to create publication-quality visualizations.
What is Plot Customization?
Plot customization involves modifying the visual appearance of charts to improve clarity, match branding, or meet publication standards. Key customization areas include:
- Colors: Line colors, fill colors, colormaps for data-driven coloring
- Typography: Font family, size, weight for titles and labels
- Axes: Limits, ticks, tick labels, logarithmic scales
- Layout: Figure size, DPI, margins, tight layout
- Annotations: Text, arrows, shapes to highlight features
- Styles: Pre-built themes that change overall appearance
Colors and Styles
Matplotlib supports multiple ways to specify colors and provides built-in styles for quick theming.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
# Color specification methods
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# Named colors
axes[0, 0].plot(x, np.sin(x), color='crimson', linewidth=2)
axes[0, 0].set_title('Named Color: crimson')
# Hex colors
axes[0, 1].plot(x, np.sin(x), color='#3498DB', linewidth=2)
axes[0, 1].set_title('Hex Color: #3498DB')
# RGB tuple (0-1 scale)
axes[1, 0].plot(x, np.sin(x), color=(0.2, 0.6, 0.3), linewidth=2)
axes[1, 0].set_title('RGB Tuple: (0.2, 0.6, 0.3)')
# RGBA with transparency
axes[1, 1].plot(x, np.sin(x), color=(0.8, 0.2, 0.5, 0.7), linewidth=2)
axes[1, 1].set_title('RGBA with Alpha: 0.7')
plt.tight_layout()
plt.show()
Matplotlib accepts colors in many formats: named colors like 'crimson' or 'steelblue', hexadecimal codes for precise web colors, RGB tuples with values from 0 to 1, and RGBA tuples that include an alpha (transparency) channel. Using consistent colors across your visualizations creates a professional appearance and helps viewers associate colors with specific data categories.
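All of these spellings are converted to the same internal RGBA form, which you can inspect with the matplotlib.colors helpers. This short sketch converts a few of the formats mentioned above.

```python
import matplotlib.colors as mcolors

crimson_hex = mcolors.to_hex('crimson')          # named color -> hex string
hex_rgba = mcolors.to_rgba('#3498DB')            # hex -> RGBA floats in 0-1
rgb_rgba = mcolors.to_rgba((0.2, 0.6, 0.3))      # alpha defaults to 1 (fully opaque)
red_matches = mcolors.same_color('red', '#FF0000')

print(crimson_hex)   # #dc143c
print(hex_rgba)
print(rgb_rgba)
print(red_matches)   # True
```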
Common Named Colors
- red
- blue
- green
- orange
- purple
- cyan
- magenta
- yellow
- coral
- crimson
- steelblue
- teal
- gold
- navy
- olive
- salmon
Using Styles
Styles are pre-built themes that change multiple visual settings at once.
import matplotlib.pyplot as plt
import numpy as np
# See all available styles
print(plt.style.available)
# ['Solarize_Light2', 'bmh', 'classic', 'dark_background',
# 'fast', 'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn-v0_8', ...]
# Apply a style
plt.style.use('seaborn-v0_8-darkgrid')
x = np.linspace(0, 10, 100)
plt.figure(figsize=(10, 6))
plt.plot(x, np.sin(x), label='sin(x)')
plt.plot(x, np.cos(x), label='cos(x)')
plt.title('Seaborn Dark Grid Style')
plt.legend()
plt.show()
# Reset to default
plt.style.use('default')
Matplotlib comes with dozens of built-in styles that instantly transform your plots. Popular choices include 'ggplot' (mimics R's ggplot2), 'fivethirtyeight' (inspired by the news site's data journalism), 'dark_background' (white text on black), and 'seaborn-v0_8-darkgrid' (clean statistical style). Use plt.style.available to see all options. You can also create custom style sheets or use plt.style.context() for temporary style changes.
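plt.style.context() applies a style only inside a with-block and restores the previous settings afterwards, so you do not need a manual reset. This sketch (Agg backend assumed) checks one rcParams entry before, during, and after the block.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

before = plt.rcParams['axes.facecolor']

# The style applies only inside the with-block
with plt.style.context('ggplot'):
    inside = plt.rcParams['axes.facecolor']
    fig, ax = plt.subplots()  # figures created here keep the ggplot look
    x = np.linspace(0, 10, 100)
    ax.plot(x, np.sin(x))

after = plt.rcParams['axes.facecolor']
print(before, inside, after)
```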
Colormaps
Colormaps map data values to colors, essential for heatmaps, scatter plots with color-coded data, and contour plots.
import matplotlib.pyplot as plt
import numpy as np
# Create data for colormap demonstration
x = np.random.randn(100)
y = np.random.randn(100)
colors = np.sqrt(x**2 + y**2) # Distance from origin
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Different colormaps
cmaps = ['viridis', 'plasma', 'coolwarm', 'RdYlGn']
for ax, cmap in zip(axes.flat, cmaps):
    scatter = ax.scatter(x, y, c=colors, cmap=cmap, s=50, alpha=0.7)
    ax.set_title(f'Colormap: {cmap}')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    plt.colorbar(scatter, ax=ax, label='Distance')
plt.tight_layout()
plt.show()
Colormaps transform numerical values into colors. Sequential colormaps (viridis, plasma) work best for ordered data ranging from low to high. Diverging colormaps (coolwarm, RdYlGn) are ideal when data has a meaningful center point, with different colors for above and below. Categorical colormaps (Set1, tab10) provide distinct colors for discrete categories. The 'viridis' colormap is perceptually uniform and colorblind-friendly, making it the default choice for most applications.
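Under the hood, coloring by value is two steps: a Normalize object maps data values onto 0-1, and the colormap maps that fraction to an RGBA color. This sketch assumes Matplotlib 3.5 or later for the matplotlib.colormaps registry.

```python
import matplotlib
import matplotlib.colors as mcolors

# Step 1: Normalize maps data values onto 0-1 (here 0 -> 0.0, 10 -> 1.0)
norm = mcolors.Normalize(vmin=0, vmax=10)
# Step 2: the colormap maps that 0-1 value to an RGBA color
cmap = matplotlib.colormaps['viridis']

scaled = norm(5.0)
rgba = cmap(scaled)
print(scaled)     # 0.5
print(len(rgba))  # 4 (red, green, blue, alpha)

# Qualitative colormaps expose their discrete color list
tab10_colors = matplotlib.colormaps['tab10'].colors
print(len(tab10_colors))  # 10
```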
Axis Limits and Ticks
Control exactly what ranges and tick marks appear on your axes for clearer data presentation.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.sin(x)
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# Default axis limits
axes[0, 0].plot(x, y)
axes[0, 0].set_title('Default Axis Limits')
# Custom axis limits
axes[0, 1].plot(x, y)
axes[0, 1].set_xlim(0, 5) # X-axis range
axes[0, 1].set_ylim(-1.5, 1.5) # Y-axis range
axes[0, 1].set_title('Custom Limits: xlim(0,5), ylim(-1.5,1.5)')
# Custom tick locations
axes[1, 0].plot(x, y)
axes[1, 0].set_xticks([0, np.pi, 2*np.pi, 3*np.pi])
axes[1, 0].set_xticklabels(['0', 'π', '2π', '3π'])
axes[1, 0].set_title('Custom Tick Labels')
# Tick formatting
axes[1, 1].plot(x, y)
axes[1, 1].set_yticks([-1, -0.5, 0, 0.5, 1])
axes[1, 1].set_yticklabels(['Very Low', 'Low', 'Mid', 'High', 'Very High'])
axes[1, 1].set_title('Descriptive Tick Labels')
plt.tight_layout()
plt.show()
The set_xlim and set_ylim methods control the visible range of your axes, useful for zooming into interesting regions or maintaining consistent scales across multiple plots. The set_xticks and set_yticks methods specify exact tick positions, while set_xticklabels and set_yticklabels let you replace numeric ticks with custom text. This is particularly useful for displaying mathematical notation, categorical labels, or human-readable descriptions instead of raw numbers.
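set_xticks hard-codes tick positions. An alternative is a locator/formatter pair from matplotlib.ticker that generates ticks from a rule and keeps working if the limits change. In this sketch the helper `pi_label` is a hypothetical name, and the Agg backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter, MultipleLocator

def pi_label(value, pos):
    """Format a tick value as a multiple of pi, e.g. 2*pi -> '2π'."""
    n = int(round(value / np.pi))
    if n == 0:
        return '0'
    if n == 1:
        return 'π'
    return f'{n}π'

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
ax.set_xlim(0, 10)
ax.xaxis.set_major_locator(MultipleLocator(np.pi))    # a tick every π
ax.xaxis.set_major_formatter(FuncFormatter(pi_label))  # label it as a multiple of π
```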
Figure Size and DPI
Control the dimensions and resolution of your figures for different output needs.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.sin(x)
# Method 1: Set size when creating figure
plt.figure(figsize=(12, 4)) # Width=12 inches, Height=4 inches
plt.plot(x, y)
plt.title('Wide Figure (12x4 inches)')
plt.show()
# Method 2: With subplots and DPI
fig, ax = plt.subplots(figsize=(8, 6), dpi=150) # Higher DPI = sharper
ax.plot(x, y, linewidth=2)
ax.set_title('High DPI Figure (150 dpi)')
plt.show()
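# Checking a figure's size: query the fig object itself, because after
# plt.show() pyplot may have closed it and plt.gcf() would create a new default figure
print(f"Figure size: {fig.get_size_inches()}") # (width, height) in inches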
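The figsize parameter takes a tuple of (width, height) in inches. Common choices include (6, 4) for general use, (10, 6) for presentations, and (12, 8) for detailed visualizations. DPI (dots per inch) controls resolution: Matplotlib's default figure DPI is 100, which suits screens, 150-300 dpi is typical for print, and lower values such as 72 dpi keep web images small. A saved image's pixel dimensions are figsize multiplied by dpi (unless bbox_inches='tight' trims the canvas), so an 8x6-inch figure saved at 150 dpi is 1200x900 pixels. Higher DPI creates sharper images but larger files, so consider your final output medium when choosing these settings.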
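The figsize-times-dpi relationship can be verified by saving to memory and reading the pixel size from the PNG header. This sketch assumes the Agg backend and deliberately omits bbox_inches='tight' so the canvas is not trimmed; `saved_png_size` is a hypothetical helper name.

```python
import io
import struct
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

def saved_png_size(figsize, dpi):
    """Save a figure to memory and read its pixel size from the PNG header."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot([0, 1], [0, 1])
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png', dpi=dpi)  # no bbox_inches='tight', so size is exact
    plt.close(fig)
    header = buffer.getvalue()
    # PNG layout: 8-byte signature, 4-byte chunk length, b'IHDR', then width and height
    width, height = struct.unpack('>II', header[16:24])
    return width, height

print(saved_png_size((4, 3), dpi=100))  # (400, 300)
print(saved_png_size((4, 3), dpi=200))  # (800, 600)
```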
Annotations and Text
Add text, arrows, and shapes to highlight important features in your plots.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y, 'b-', linewidth=2)
# Add text at a specific location
ax.text(np.pi/2, 1.1, 'Maximum', fontsize=12, ha='center',
color='green', fontweight='bold')
# Add annotation with arrow
ax.annotate('Minimum', xy=(3*np.pi/2, -1), xytext=(3*np.pi/2, -1.5),
fontsize=12, ha='center',
arrowprops=dict(arrowstyle='->', color='red', lw=2))
# Add horizontal reference line
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
# Add vertical span to highlight a region
ax.axvspan(np.pi, 2*np.pi, alpha=0.2, color='yellow', label='Second half')
# Add text box
props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
ax.text(0.05, 0.95, 'Sine Wave: y = sin(x)', transform=ax.transAxes,
fontsize=10, verticalalignment='top', bbox=props)
ax.set_ylim(-1.8, 1.4) # leave room for the text placed above and below the curve
ax.set_title('Annotation Examples')
ax.set_xlabel('x (radians)')
ax.set_ylabel('sin(x)')
ax.legend()
plt.tight_layout()
plt.show()
Annotations turn raw visualizations into informative graphics. The text method places text at data coordinates, while annotate adds text with an arrow pointing to a specific location. The axhline and axvline methods draw horizontal and vertical reference lines, while axhspan and axvspan shade regions. The transform=ax.transAxes parameter places text in axes-relative coordinates (0-1 across the plot area), which is useful for notes that should stay fixed regardless of the data range.
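annotate can also place its text relative to the annotated point rather than at fixed data coordinates. With textcoords='offset points', the label keeps the same spacing from the point even if the axis limits change. This sketch assumes the Agg backend; the 20-point offset is an arbitrary choice.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# xy is in data coordinates; the text sits 20 points right and 20 points up from it
peak = ax.annotate('Maximum', xy=(np.pi / 2, 1), xycoords='data',
                   xytext=(20, 20), textcoords='offset points',
                   arrowprops=dict(arrowstyle='->', color='red'))

# Axes coordinates: (0.95, 0.05) stays near the bottom-right corner whatever the data
note = ax.text(0.95, 0.05, 'note', transform=ax.transAxes, ha='right')
```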
Grid and Spines
Customize grid lines and axis borders (spines) for cleaner visualizations.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.sin(x)
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# Grid customization
axes[0, 0].plot(x, y)
axes[0, 0].grid(True, linestyle='--', alpha=0.7)
axes[0, 0].set_title('Dashed Grid')
# Major and minor grid
axes[0, 1].plot(x, y)
axes[0, 1].grid(True, which='major', linestyle='-', alpha=0.7)
axes[0, 1].grid(True, which='minor', linestyle=':', alpha=0.4)
axes[0, 1].minorticks_on()
axes[0, 1].set_title('Major + Minor Grid')
# Remove spines (axis borders)
axes[1, 0].plot(x, y)
axes[1, 0].spines['top'].set_visible(False)
axes[1, 0].spines['right'].set_visible(False)
axes[1, 0].set_title('No Top/Right Spines')
# Move spines to center
axes[1, 1].plot(x, y)
axes[1, 1].spines['left'].set_position('center')
axes[1, 1].spines['bottom'].set_position('center')
axes[1, 1].spines['top'].set_visible(False)
axes[1, 1].spines['right'].set_visible(False)
axes[1, 1].set_title('Centered Spines (Math Style)')
plt.tight_layout()
plt.show()
Grid lines help viewers read exact values from your plot. Use linestyle ('--', ':', '-') and alpha to control appearance. Minor ticks provide finer divisions without cluttering major tick labels. Spines are the axis borders; removing top and right spines creates a cleaner, modern look common in data journalism. Moving spines to center creates traditional math-style axes that cross at the origin, useful for educational materials showing positive and negative values.
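Since Matplotlib 3.4, ax.spines can be indexed with a list to change several spines at once. This sketch (Agg backend assumed) hides the top and right spines in one call and offsets the remaining two outward by 5 points, an arbitrary value chosen for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# Index several spines at once and hide them together
ax.spines[['top', 'right']].set_visible(False)
# Nudge the remaining spines outward so they do not touch the data
for side in ('left', 'bottom'):
    ax.spines[side].set_position(('outward', 5))

visible = {name: spine.get_visible() for name, spine in ax.spines.items()}
print(visible)
```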
Practice: Customization
Task: Create a line plot with hex color #E74C3C.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 50)
y = np.sin(x)
plt.plot(x, y, color='#E74C3C', linewidth=2)
plt.title('Custom Hex Color')
plt.show()
Task: Apply the 'ggplot' style to a simple plot.
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('ggplot')
x = np.linspace(0, 10, 50)
plt.plot(x, np.sin(x), label='sin')
plt.plot(x, np.cos(x), label='cos')
plt.legend()
plt.title('ggplot Style')
plt.show()
plt.style.use('default') # Reset
Task: Create a sine plot and zoom into x=[0, π], y=[-0.5, 1.5].
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 2*np.pi, 100)
plt.plot(x, np.sin(x))
plt.xlim(0, np.pi)
plt.ylim(-0.5, 1.5)
plt.title('Zoomed View')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.show()
Task: Create a plot with dashed grid lines at 50% opacity.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 50)
plt.plot(x, x**2)
plt.grid(True, linestyle='--', alpha=0.5)
plt.title('Dashed Grid')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
Task: Plot sin(x) and annotate the maximum point with an arrow.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2)
plt.annotate('Maximum',
xy=(np.pi/2, 1), # Point to annotate
xytext=(np.pi/2 + 1, 0.7), # Text position
fontsize=12,
arrowprops=dict(arrowstyle='->', color='red'))
plt.title('Annotated Sine Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True, alpha=0.3)
plt.show()
Task: Create a plot with only bottom and left axis borders visible.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 50)
y = np.exp(-x/5) * np.sin(x)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y, 'steelblue', linewidth=2)
# Remove top and right spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_title('Clean Modern Style')
ax.set_xlabel('Time')
ax.set_ylabel('Amplitude')
plt.show()
Subplots and Saving
Create multiple plots in a single figure and save your visualizations to files.
What are Subplots?
Subplots allow you to display multiple plots within a single figure, enabling side-by-side comparisons or multi-panel visualizations. Key concepts include:
- Figure: The overall canvas that contains all subplots
- Axes: Individual plot areas within the figure
- Grid Layout: Organizing subplots in rows and columns
- GridSpec: Advanced layout control for irregular grids
- Shared Axes: Common x or y axes across subplots for easier comparison
Basic Subplots
The subplots function creates a figure with a grid of axes that you can fill with different plots.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
# Create a 2x2 grid of subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Top-left: Line plot
axes[0, 0].plot(x, np.sin(x), 'b-', linewidth=2)
axes[0, 0].set_title('Sine Wave')
axes[0, 0].set_xlabel('x')
axes[0, 0].set_ylabel('sin(x)')
axes[0, 0].grid(True, alpha=0.3)
# Top-right: Cosine plot
axes[0, 1].plot(x, np.cos(x), 'r-', linewidth=2)
axes[0, 1].set_title('Cosine Wave')
axes[0, 1].set_xlabel('x')
axes[0, 1].set_ylabel('cos(x)')
axes[0, 1].grid(True, alpha=0.3)
# Bottom-left: Bar chart
categories = ['A', 'B', 'C', 'D', 'E']
values = [23, 45, 56, 78, 32]
axes[1, 0].bar(categories, values, color='steelblue')
axes[1, 0].set_title('Category Values')
axes[1, 0].set_ylabel('Count')
# Bottom-right: Scatter plot
scatter_x = np.random.randn(50)
scatter_y = np.random.randn(50)
axes[1, 1].scatter(scatter_x, scatter_y, c='coral', alpha=0.6, s=50)
axes[1, 1].set_title('Random Scatter')
axes[1, 1].set_xlabel('X')
axes[1, 1].set_ylabel('Y')
# Prevent overlap and display
plt.tight_layout()
plt.show()
The subplots function returns a Figure object and an array of Axes objects. For a 2x2 grid, axes is a 2D array accessed with [row, col] indexing (zero-based). Each axes object has its own plotting methods like plot, bar, scatter, and its own set_title, set_xlabel, set_ylabel methods. The tight_layout function automatically adjusts spacing to prevent titles and labels from overlapping, which is essential for multi-panel figures.
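When every panel gets similar formatting, iterating over the grid is often simpler than indexing each [row, col] by hand. The sketch below, assuming the Agg backend and an arbitrary choice of NumPy functions, walks the 2D axes array with axes.flat.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
funcs = [np.sin, np.cos, np.tanh, np.sqrt]

fig, axes = plt.subplots(2, 2, figsize=(10, 8))
# axes.flat visits the 2D array row by row: [0, 0], [0, 1], [1, 0], [1, 1]
for ax, func in zip(axes.flat, funcs):
    ax.plot(x, func(x))
    ax.set_title(func.__name__)   # NumPy ufuncs expose their name
    ax.grid(True, alpha=0.3)
fig.tight_layout()
```

axes.ravel() gives the same row-by-row order if you prefer a 1D array you can index.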
Single Row or Column
When creating a single row or column of subplots, the axes array is 1D instead of 2D.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
# Single row (1x3 grid)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].plot(x, np.sin(x), 'b-')
axes[0].set_title('Sine')
axes[1].plot(x, np.cos(x), 'r-')
axes[1].set_title('Cosine')
axes[2].plot(x, np.tan(x), 'g-')
axes[2].set_ylim(-5, 5) # Limit y for tangent
axes[2].set_title('Tangent')
plt.tight_layout()
plt.show()
# Single column (3x1 grid)
fig, axes = plt.subplots(3, 1, figsize=(8, 10))
for i, (func, name) in enumerate([(np.sin, 'Sine'), (np.cos, 'Cosine'), (np.exp, 'Exponential')]):
    if name == 'Exponential':
        axes[i].plot(x, func(x/10), 'purple') # Scale x for exp
    else:
        axes[i].plot(x, func(x), 'teal')
    axes[i].set_title(name)
    axes[i].set_xlabel('x')
    axes[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
When subplots creates a single row (1, n) or single column (n, 1), the axes variable becomes a 1D array accessed with a single index like axes[0], axes[1]. This is more convenient than 2D indexing for simple layouts. The example also shows iterating through axes with enumerate, which is useful when applying similar formatting to multiple subplots. Each subplot maintains independent x and y limits, allowing you to zoom differently on each panel.
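Because the return type changes with the grid shape, code that loops over panels can break when the layout changes. The sketch below, assuming the Agg backend, shows the three shapes and the squeeze=False option, which always returns a 2D array.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt

row_fig, row_axes = plt.subplots(1, 3)
print(row_axes.shape)      # 1D array: (3,)

single_fig, single_ax = plt.subplots()   # 1x1 grid: a single Axes, not an array

grid_fig, grid_axes = plt.subplots(1, 3, squeeze=False)
print(grid_axes.shape)     # always 2D: (1, 3)
for ax in grid_axes.flat:  # the same loop works for any grid shape
    ax.set_xlabel('x')
```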
Shared Axes
Use shared axes to maintain the same scale across subplots, making comparisons easier.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
# Shared X axis (useful for time series at different scales)
fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
axes[0].plot(x, np.sin(x), 'b-', linewidth=2)
axes[0].set_ylabel('sin(x)')
axes[0].set_title('Shared X-Axis Example')
axes[0].grid(True, alpha=0.3)
axes[1].plot(x, np.sin(x) + np.random.randn(100) * 0.2, 'r-', alpha=0.7)
axes[1].set_ylabel('sin(x) + noise')
axes[1].set_xlabel('x')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Shared Y axis (useful for comparing distributions)
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)
data1 = np.random.normal(0, 1, 500)
data2 = np.random.normal(2, 1.5, 500)
data3 = np.random.normal(-1, 0.8, 500)
for ax, data, title in zip(axes, [data1, data2, data3], ['Group A', 'Group B', 'Group C']):
    ax.hist(data, bins=20, edgecolor='black', alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel('Value')
axes[0].set_ylabel('Frequency') # Only first needs y-label
plt.tight_layout()
plt.show()
The sharex=True parameter links the x-axes of all subplots so they zoom and pan together, which is essential for time series or sequential data. The sharey=True parameter does the same for the y-axes, which matters when comparing distributions or values that should be on the same scale. With shared axes created by plt.subplots, x tick labels appear only on the bottom row and y tick labels only on the leftmost column, reducing visual clutter; axis labels set with set_xlabel or set_ylabel still appear wherever you set them.
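The linking is live: changing the limits on one panel updates its siblings. This sketch, assuming the Agg backend, shows that and the 'row' option, which shares y only within each row of a grid.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, axes = plt.subplots(2, 1, sharex=True)
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))

# Zooming one panel zooms every panel that shares its x-axis
axes[1].set_xlim(2, 4)
print(axes[0].get_xlim())

# sharey='row': panels in the same row share y; different rows stay independent
fig2, grid = plt.subplots(2, 2, sharey='row')
grid[0, 0].set_ylim(-1, 1)   # also applies to grid[0, 1], not to row 1
```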
GridSpec for Complex Layouts
GridSpec provides fine-grained control for creating irregular subplot grids.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
x = np.linspace(0, 10, 100)
# Create figure with GridSpec
fig = plt.figure(figsize=(12, 8))
gs = gridspec.GridSpec(3, 3, figure=fig)
# Large plot spanning 2 rows and 2 columns
ax1 = fig.add_subplot(gs[0:2, 0:2])
ax1.plot(x, np.sin(x), 'b-', linewidth=2)
ax1.plot(x, np.cos(x), 'r-', linewidth=2)
ax1.set_title('Main Plot (2x2)', fontsize=14)
ax1.legend(['sin', 'cos'])
ax1.grid(True, alpha=0.3)
# Right column: two stacked plots
ax2 = fig.add_subplot(gs[0, 2])
ax2.bar(['A', 'B', 'C'], [3, 7, 5], color='steelblue')
ax2.set_title('Bar Chart')
ax3 = fig.add_subplot(gs[1, 2])
ax3.scatter(np.random.randn(30), np.random.randn(30), c='coral')
ax3.set_title('Scatter')
# Bottom row: three equal plots
ax4 = fig.add_subplot(gs[2, 0])
ax4.hist(np.random.randn(200), bins=15, color='purple', alpha=0.7)
ax4.set_title('Histogram 1')
ax5 = fig.add_subplot(gs[2, 1])
ax5.hist(np.random.randn(200), bins=15, color='green', alpha=0.7)
ax5.set_title('Histogram 2')
ax6 = fig.add_subplot(gs[2, 2])
ax6.hist(np.random.randn(200), bins=15, color='orange', alpha=0.7)
ax6.set_title('Histogram 3')
plt.tight_layout()
plt.show()
GridSpec creates a grid that you can slice like a NumPy array to create subplots of varying sizes. The syntax gs[row_slice, col_slice] specifies which grid cells the subplot occupies. For example, gs[0:2, 0:2] creates a subplot spanning rows 0-1 and columns 0-1 (a 2x2 area). This is invaluable for dashboards where a main visualization needs more space, surrounded by smaller supporting plots. GridSpec also supports gaps between subplots via hspace and wspace parameters.
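Besides slicing, GridSpec accepts width_ratios and height_ratios to make columns or rows unequal, and hspace/wspace to set the gaps. A sketch assuming the Agg backend and illustrative ratio and spacing values:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

x = np.linspace(0, 10, 100)
fig = plt.figure(figsize=(10, 6))
# Left column twice as wide as the right; extra space between rows and columns
gs = gridspec.GridSpec(2, 2, figure=fig, width_ratios=[2, 1],
                       hspace=0.4, wspace=0.3)

ax_main = fig.add_subplot(gs[:, 0])   # spans both rows of the wide column
ax_top = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, 1])

ax_main.plot(x, np.sin(x))
ax_top.plot(x, np.cos(x))
ax_bottom.plot(x, np.sin(x) * np.cos(x))
```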
Saving Figures
Save your visualizations to various file formats for reports, papers, or web publishing.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.sin(x)
# Create a figure
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2)
plt.title('Sine Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True, alpha=0.3)
# Save as PNG (raster format)
plt.savefig('sine_wave.png', dpi=150)
# Save as PDF (vector format - scales without pixelation)
plt.savefig('sine_wave.pdf')
# Save as SVG (vector format - editable in Illustrator)
plt.savefig('sine_wave.svg')
# Save with tight bounding box (removes extra whitespace)
plt.savefig('sine_wave_tight.png', dpi=200, bbox_inches='tight')
# Save with transparent background
plt.savefig('sine_wave_transparent.png', dpi=150, transparent=True)
# Save with custom background color
plt.savefig('sine_wave_dark.png', dpi=150, facecolor='#2C3E50')
plt.show()
The savefig function exports your figure to a file. PNG works well for the web and presentations at 150-200 dpi. PDF and SVG are vector formats that scale to any size without losing quality, which suits print publications and graphics you plan to edit. The bbox_inches='tight' parameter crops extra whitespace, while transparent=True makes the background transparent so the image can be overlaid on other content. Always call savefig before show(): in many setups show() closes the figure once it has been displayed, so a later plt.savefig() writes a new, empty figure.
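For raster formats, the pixel size is the figure size in inches multiplied by dpi. The sketch below checks this by saving to an in-memory buffer and reading the width and height from the PNG header; the png_size helper is hypothetical, and the Agg backend is assumed.

```python
import io
import struct
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

def png_size(data):
    """Hypothetical helper: read (width, height) from PNG bytes."""
    # 8-byte signature, 4-byte chunk length, b'IHDR', then width and height
    return struct.unpack('>II', data[16:24])

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(4, 3))   # size in inches
ax.plot(x, np.sin(x))

sizes = {}
for dpi in (100, 200):
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=dpi)   # pixels = inches * dpi
    sizes[dpi] = png_size(buf.getvalue())
print(sizes)  # {100: (400, 300), 200: (800, 600)}

# bbox_inches='tight' crops to the drawn content plus a small padding
tight_buf = io.BytesIO()
fig.savefig(tight_buf, format='png', dpi=100, bbox_inches='tight')
```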
File Format Reference
| Format | Type | Best For | Extension |
|---|---|---|---|
| PNG | Raster | Web, presentations, screenshots | .png |
| PDF | Vector | Print, papers, reports | .pdf |
| SVG | Vector | Web (scalable), editing | .svg |
| JPEG | Raster | Photos (lossy compression) | .jpg, .jpeg |
| EPS | Vector | LaTeX, professional printing | .eps |
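The format is normally inferred from the file extension. When saving to a file-like object there is no extension, so format= must be given. This sketch, assuming the Agg backend, lists the formats the canvas supports and checks each output's leading bytes.

```python
import io
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])

# Mapping of extension -> description, e.g. {'png': 'Portable Network Graphics', ...}
supported = fig.canvas.get_supported_filetypes()

data = {}
for fmt in ('png', 'pdf', 'svg'):
    buf = io.BytesIO()
    fig.savefig(buf, format=fmt)   # explicit format for an in-memory buffer
    data[fmt] = buf.getvalue()
```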
Practice: Subplots & Saving
Task: Create two plots side by side showing sine and cosine.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(x, np.sin(x))
axes[0].set_title('Sine')
axes[1].plot(x, np.cos(x))
axes[1].set_title('Cosine')
plt.tight_layout()
plt.show()
Task: Create two plots stacked vertically.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
fig, axes = plt.subplots(2, 1, figsize=(8, 8))
axes[0].plot(x, x**2)
axes[0].set_title('Quadratic')
axes[1].plot(x, np.sqrt(x))
axes[1].set_title('Square Root')
plt.tight_layout()
plt.show()
Task: Create a 2x2 grid with line, bar, scatter, and histogram.
import matplotlib.pyplot as plt
import numpy as np
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
# Line plot
x = np.linspace(0, 10, 50)
axes[0, 0].plot(x, np.sin(x))
axes[0, 0].set_title('Line Plot')
# Bar chart
axes[0, 1].bar(['A', 'B', 'C', 'D'], [4, 7, 2, 8])
axes[0, 1].set_title('Bar Chart')
# Scatter plot
axes[1, 0].scatter(np.random.rand(30), np.random.rand(30))
axes[1, 0].set_title('Scatter Plot')
# Histogram
axes[1, 1].hist(np.random.randn(200), bins=15)
axes[1, 1].set_title('Histogram')
plt.tight_layout()
plt.show()
Task: Create two stacked plots that share the x-axis.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
axes[0].plot(x, np.sin(x), 'b-')
axes[0].set_ylabel('sin(x)')
axes[0].set_title('Shared X-Axis')
axes[1].plot(x, np.cos(x), 'r-')
axes[1].set_ylabel('cos(x)')
axes[1].set_xlabel('x')
plt.tight_layout()
plt.show()
Task: Create a plot and save it as PNG at 200 DPI with tight bounding box.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
plt.figure(figsize=(10, 6))
plt.plot(x, np.sin(x), 'b-', linewidth=2)
plt.title('Saved Figure')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True, alpha=0.3)
plt.savefig('my_plot.png', dpi=200, bbox_inches='tight')
plt.show()
Task: Create a layout with one large plot on the left and two smaller plots stacked on the right.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
x = np.linspace(0, 10, 100)
fig = plt.figure(figsize=(12, 6))
gs = gridspec.GridSpec(2, 2, figure=fig)
# Large left plot
ax1 = fig.add_subplot(gs[:, 0])
ax1.plot(x, np.sin(x), 'b-', linewidth=2)
ax1.set_title('Main Plot')
# Top-right small plot
ax2 = fig.add_subplot(gs[0, 1])
ax2.bar(['A', 'B', 'C'], [3, 7, 5])
ax2.set_title('Bar')
# Bottom-right small plot
ax3 = fig.add_subplot(gs[1, 1])
ax3.scatter(np.random.rand(20), np.random.rand(20))
ax3.set_title('Scatter')
plt.tight_layout()
plt.show()
Advanced Techniques
Master advanced visualization techniques including twin axes, 3D plots, animations, and real-world data integration.
Advanced Visualization Concepts
Advanced techniques help you create sophisticated visualizations for complex data scenarios:
- Twin Axes: Plot two different scales on the same chart (e.g., temperature and precipitation)
- Logarithmic Scales: Handle data spanning multiple orders of magnitude
- 3D Plots: Visualize three-dimensional data relationships
- Pandas Integration: Direct plotting from DataFrames
- Real-time Updates: Create animated or updating visualizations
Twin Axes (Dual Y-Axis)
Display two different measurements with different scales on the same plot, sharing the x-axis.
import matplotlib.pyplot as plt
import numpy as np
# Sample data: temperature and rainfall over 12 months
months = np.arange(1, 13)
temperature = [5, 7, 12, 18, 23, 28, 31, 30, 25, 18, 11, 6] # Celsius
rainfall = [80, 65, 70, 55, 45, 30, 25, 35, 50, 75, 85, 90] # mm
fig, ax1 = plt.subplots(figsize=(12, 6))
# First axis: Temperature (line plot)
color1 = '#E74C3C'
ax1.set_xlabel('Month')
ax1.set_ylabel('Temperature (°C)', color=color1)
ax1.plot(months, temperature, color=color1, linewidth=2, marker='o', label='Temperature')
ax1.tick_params(axis='y', labelcolor=color1)
ax1.set_ylim(0, 35)
# Second axis: Rainfall (bar chart)
ax2 = ax1.twinx() # Create twin axis sharing x
color2 = '#3498DB'
ax2.set_ylabel('Rainfall (mm)', color=color2)
ax2.bar(months, rainfall, color=color2, alpha=0.5, label='Rainfall')
ax2.tick_params(axis='y', labelcolor=color2)
ax2.set_ylim(0, 100)
# Title and legend
plt.title('Monthly Temperature and Rainfall')
ax1.set_xticks(months)
ax1.set_xticklabels(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])
# Combined legend
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left')
plt.tight_layout()
plt.show()
The twinx() method creates a second y-axis that shares the same x-axis. This is essential when comparing variables with different units or scales, like temperature and rainfall. Color-code both the data and the y-axis labels to make it clear which axis corresponds to which data. The example also demonstrates combining legends from both axes into a single legend box. Use twin axes sparingly, as they can be confusing if overdone.
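The steps above can be wrapped in a reusable function. In this sketch the dual_axis helper is hypothetical and the Agg backend is assumed. One design choice goes beyond the original example: the twin axes is drawn on top by default, so the line's axes is raised and its background hidden to keep the bars from covering the line.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

def dual_axis(x, y_left, y_right, left_label, right_label,
              left_color='#E74C3C', right_color='#3498DB'):
    """Hypothetical helper: line on the left y-axis, bars on a twin right y-axis."""
    fig, ax_left = plt.subplots(figsize=(10, 5))
    ax_left.plot(x, y_left, color=left_color, marker='o', label=left_label)
    ax_left.set_ylabel(left_label, color=left_color)
    ax_left.tick_params(axis='y', labelcolor=left_color)

    ax_right = ax_left.twinx()              # same x-axis, independent y-axis
    ax_right.bar(x, y_right, color=right_color, alpha=0.5, label=right_label)
    ax_right.set_ylabel(right_label, color=right_color)
    ax_right.tick_params(axis='y', labelcolor=right_color)

    # Draw the line's axes above the bars and make its background transparent
    ax_left.set_zorder(ax_right.get_zorder() + 1)
    ax_left.patch.set_visible(False)

    # Merge handles from both axes into one legend box
    h1, l1 = ax_left.get_legend_handles_labels()
    h2, l2 = ax_right.get_legend_handles_labels()
    legend = ax_left.legend(h1 + h2, l1 + l2, loc='upper left')
    return fig, ax_left, ax_right, legend

months = np.arange(1, 13)
temperature = [5, 7, 12, 18, 23, 28, 31, 30, 25, 18, 11, 6]
rainfall = [80, 65, 70, 55, 45, 30, 25, 35, 50, 75, 85, 90]
fig, ax_t, ax_r, leg = dual_axis(months, temperature, rainfall,
                                 'Temperature (°C)', 'Rainfall (mm)')
```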
Logarithmic Scales
Use log scales when data spans multiple orders of magnitude or follows exponential patterns.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(1, 100, 100)
y_linear = x ** 2
y_exp = np.exp(x / 20)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Linear scale (both axes)
axes[0, 0].plot(x, y_exp, 'b-', linewidth=2)
axes[0, 0].set_title('Linear Scale')
axes[0, 0].set_xlabel('x')
axes[0, 0].set_ylabel('y')
axes[0, 0].grid(True, alpha=0.3)
# Logarithmic Y-axis
axes[0, 1].semilogy(x, y_exp, 'r-', linewidth=2) # Or use set_yscale('log')
axes[0, 1].set_title('Logarithmic Y-Axis (semilogy)')
axes[0, 1].set_xlabel('x')
axes[0, 1].set_ylabel('y (log scale)')
axes[0, 1].grid(True, alpha=0.3)
# Logarithmic X-axis
axes[1, 0].semilogx(x, y_linear, 'g-', linewidth=2)
axes[1, 0].set_title('Logarithmic X-Axis (semilogx)')
axes[1, 0].set_xlabel('x (log scale)')
axes[1, 0].set_ylabel('y')
axes[1, 0].grid(True, alpha=0.3)
# Log-log scale (both axes)
axes[1, 1].loglog(x, y_linear, 'm-', linewidth=2)
axes[1, 1].set_title('Log-Log Scale (loglog)')
axes[1, 1].set_xlabel('x (log scale)')
axes[1, 1].set_ylabel('y (log scale)')
axes[1, 1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Logarithmic scales compress large ranges and expand small ones, revealing patterns that are invisible on linear scales. Use semilogy when y values span many orders of magnitude (like population growth), semilogx when x values do (like frequency response), and loglog for power-law relationships. On a semilogy plot, exponential curves become straight lines; on a loglog plot, power laws (y = a * x**k) become straight lines with slope k. This makes log scales a quick way to identify the mathematical form of your data.
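The "straight line" claim can be checked numerically: fit a line to the transformed data that each log plot actually displays. This sketch reuses the example's functions, assumes the Agg backend, and uses np.polyfit for the fits.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(1, 100, 100)
y_power = x ** 2          # power law with exponent 2
y_exp = np.exp(x / 20)    # exponential growth

# loglog shows log10(y) against log10(x): for y = x**k the slope is k
k, _ = np.polyfit(np.log10(x), np.log10(y_power), 1)

# semilogy shows log10(y) against x: for y = exp(x/20) the slope is 1 / (20 * ln 10)
m, _ = np.polyfit(x, np.log10(y_exp), 1)

fig, (ax_semi, ax_log) = plt.subplots(1, 2, figsize=(10, 4))
ax_semi.semilogy(x, y_exp)
ax_log.loglog(x, y_power)
print(f"loglog slope = {k:.3f}, semilogy slope = {m:.5f}")
```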
3D Surface Plots
Visualize three-dimensional data using surface plots, wireframes, and scatter plots.
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
# Create mesh grid
x = np.linspace(-5, 5, 50)
y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2)) # Ripple pattern
fig = plt.figure(figsize=(15, 5))
# 3D Surface plot
ax1 = fig.add_subplot(131, projection='3d')
surf = ax1.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none', alpha=0.8)
ax1.set_title('3D Surface')
ax1.set_xlabel('X')
ax1.set_ylabel('Y')
ax1.set_zlabel('Z')
fig.colorbar(surf, ax=ax1, shrink=0.5, aspect=10)
# 3D Wireframe
ax2 = fig.add_subplot(132, projection='3d')
ax2.plot_wireframe(X, Y, Z, color='steelblue', linewidth=0.5)
ax2.set_title('3D Wireframe')
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
ax2.set_zlabel('Z')
# 3D Scatter
ax3 = fig.add_subplot(133, projection='3d')
n = 200
x_scatter = np.random.randn(n)
y_scatter = np.random.randn(n)
z_scatter = np.random.randn(n)
colors = np.sqrt(x_scatter**2 + y_scatter**2 + z_scatter**2)
scatter = ax3.scatter(x_scatter, y_scatter, z_scatter, c=colors, cmap='plasma', s=30)
ax3.set_title('3D Scatter')
ax3.set_xlabel('X')
ax3.set_ylabel('Y')
ax3.set_zlabel('Z')
plt.tight_layout()
plt.show()
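3D plotting requires specifying projection='3d' when creating axes. Older Matplotlib versions also needed an import of Axes3D from mpl_toolkits.mplot3d to register that projection; since Matplotlib 3.2 the import is optional, so the import line in the example is harmless but unnecessary. The meshgrid function builds coordinate matrices from the x and y vectors, which surface plots need. Surface plots (plot_surface) fill the surface with color, wireframes (plot_wireframe) show only the grid lines, and scatter (also available as scatter3D) displays individual points in 3D space. With an interactive backend, such as a Matplotlib window or Jupyter with ipympl, 3D plots can be rotated with the mouse to help viewers understand spatial relationships; the default inline backend in Jupyter shows a static image.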
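When the figure is exported as a static image, the camera angle has to be chosen in code. A sketch assuming Matplotlib 3.2 or later (no mpl_toolkits import) and the Agg backend; the elevation and azimuth values are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-5, 5, 40)
y = np.linspace(-5, 5, 40)
X, Y = np.meshgrid(x, y)            # both have shape (40, 40)
Z = np.sin(np.sqrt(X**2 + Y**2))

fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(projection='3d')   # registers automatically in Matplotlib >= 3.2
ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none')
ax.view_init(elev=35, azim=-60)         # camera elevation and azimuth in degrees
```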
Plotting with Pandas
Pandas DataFrames have built-in plotting that uses Matplotlib under the hood.
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Create sample DataFrame
np.random.seed(42)
dates = pd.date_range('2024-01-01', periods=100)
df = pd.DataFrame({
'Date': dates,
'Sales': np.random.randint(100, 500, 100) + np.sin(np.arange(100) / 10) * 100,
'Costs': np.random.randint(50, 250, 100),
'Profit': np.random.randint(20, 200, 100)
})
df.set_index('Date', inplace=True)
# Line plot directly from DataFrame
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Line plot: all columns
df.plot(ax=axes[0, 0], linewidth=1.5)
axes[0, 0].set_title('All Columns - Line Plot')
axes[0, 0].set_ylabel('Value')
axes[0, 0].legend(loc='upper left')
# Bar plot: single column
df['Sales'].resample('W').mean().plot(kind='bar', ax=axes[0, 1], color='steelblue')
axes[0, 1].set_title('Weekly Average Sales')
axes[0, 1].set_xlabel('Week')
axes[0, 1].tick_params(axis='x', rotation=45)
# Area plot: stacked
df[['Costs', 'Profit']].plot.area(ax=axes[1, 0], alpha=0.5)
axes[1, 0].set_title('Stacked Area Plot')
axes[1, 0].set_ylabel('Value')
# Scatter plot: two columns
df.plot.scatter(x='Costs', y='Sales', ax=axes[1, 1], alpha=0.6, c='Profit',
cmap='viridis', colorbar=True)
axes[1, 1].set_title('Costs vs Sales (colored by Profit)')
plt.tight_layout()
plt.show()
Pandas plotting methods like df.plot(), df.plot.bar(), df.plot.scatter() provide convenient shortcuts that automatically handle labels, legends, and date formatting. The ax parameter lets you place Pandas plots on specific matplotlib axes for multi-panel figures. Pandas also supports resampling (resample) for time series aggregation before plotting. This integration makes it easy to go from data analysis directly to visualization without manual data extraction.
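The pattern of plotting onto an existing Axes works because DataFrame.plot returns that Axes, so you can keep customizing it with Matplotlib calls. The sketch below uses synthetic data, assumes pandas is installed and the Agg backend, and reformats the weekly bar labels, which otherwise show full timestamps.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
dates = pd.date_range('2024-01-01', periods=60, freq='D')
df = pd.DataFrame({'Sales': rng.integers(100, 500, 60),
                   'Costs': rng.integers(50, 250, 60)}, index=dates)

fig, (ax_daily, ax_weekly) = plt.subplots(2, 1, figsize=(10, 8))

# DataFrame.plot draws on the Axes passed as ax= and returns that same Axes
returned = df.plot(ax=ax_daily, title='Daily')

weekly = df.resample('W').mean()   # aggregate to weekly means before plotting
weekly.plot(kind='bar', ax=ax_weekly, title='Weekly mean')
# Replace the long timestamp labels with short month/day labels
ax_weekly.set_xticklabels([d.strftime('%b %d') for d in weekly.index], rotation=45)
fig.tight_layout()
```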
Error Bars
Display uncertainty or variability in your data using error bars.
import matplotlib.pyplot as plt
import numpy as np
# Sample data with errors
categories = ['Group A', 'Group B', 'Group C', 'Group D', 'Group E']
means = [45, 62, 38, 55, 70]
std_devs = [5, 8, 4, 7, 6] # Standard deviations as error
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Vertical bar chart with error bars
axes[0].bar(categories, means, yerr=std_devs, capsize=5,
color='steelblue', edgecolor='black', alpha=0.8)
axes[0].set_title('Bar Chart with Error Bars')
axes[0].set_ylabel('Value')
axes[0].set_xlabel('Category')
axes[0].grid(True, axis='y', alpha=0.3)
# Line plot with error band
x = np.linspace(0, 10, 20)
y = np.sin(x) * 2 + 5
error = 0.5 + 0.3 * np.abs(np.random.randn(20)) # Error sizes must be non-negative
axes[1].plot(x, y, 'b-', linewidth=2, marker='o', label='Mean')
axes[1].fill_between(x, y - error, y + error, alpha=0.3, color='blue', label='±1 std')
axes[1].errorbar(x, y, yerr=error, fmt='none', ecolor='blue', capsize=3, alpha=0.5)
axes[1].set_title('Line Plot with Error Band')
axes[1].set_xlabel('X')
axes[1].set_ylabel('Y')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Error bars communicate uncertainty in your measurements, essential for scientific and statistical visualizations. The yerr parameter adds vertical error bars, while xerr handles horizontal errors. The capsize parameter controls the width of the error bar caps. For continuous data, fill_between creates error bands that are less cluttered than individual error bars. Always include error bars when presenting experimental or statistical data to help viewers assess the reliability of your findings.
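When the uncertainty is not symmetric, yerr can be a 2xN sequence of separate lower and upper lengths. This sketch uses illustrative numbers and assumes the Agg backend. Recent Matplotlib releases raise a ValueError for negative error values, so the lengths are given as positive distances from each mean.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

categories = ['Group A', 'Group B', 'Group C']
x = np.arange(len(categories))
means = np.array([45, 62, 38])
lower = np.array([4, 6, 3])    # distance below each mean
upper = np.array([6, 9, 5])    # distance above each mean

fig, ax = plt.subplots()
# Row 0 of yerr is the lower length, row 1 the upper length
container = ax.errorbar(x, means, yerr=[lower, upper], fmt='o', capsize=4)
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.set_ylabel('Value')
```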
Contour Plots
Visualize 3D data in 2D using contour lines or filled contours.
import matplotlib.pyplot as plt
import numpy as np
# Create data
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
Z = np.exp(-(X**2 + Y**2)) + np.exp(-((X-1.5)**2 + (Y-1.5)**2))
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# Contour lines
cs1 = axes[0].contour(X, Y, Z, levels=10, colors='black')
axes[0].clabel(cs1, inline=True, fontsize=8) # Add labels to contour lines
axes[0].set_title('Contour Lines')
axes[0].set_xlabel('X')
axes[0].set_ylabel('Y')
axes[0].set_aspect('equal')
# Filled contours
cs2 = axes[1].contourf(X, Y, Z, levels=20, cmap='viridis')
plt.colorbar(cs2, ax=axes[1], label='Z value')
axes[1].set_title('Filled Contours')
axes[1].set_xlabel('X')
axes[1].set_ylabel('Y')
axes[1].set_aspect('equal')
# Contour lines on filled background
cs3a = axes[2].contourf(X, Y, Z, levels=20, cmap='coolwarm', alpha=0.8)
cs3b = axes[2].contour(X, Y, Z, levels=10, colors='black', linewidths=0.5)
plt.colorbar(cs3a, ax=axes[2], label='Z value')
axes[2].set_title('Combined Contours')
axes[2].set_xlabel('X')
axes[2].set_ylabel('Y')
axes[2].set_aspect('equal')
plt.tight_layout()
plt.show()
Contour plots represent 3D surfaces on a 2D plane, like topographic maps. The contour function draws lines of constant value, while contourf fills the regions between contour levels with color. The clabel function adds numeric labels directly on contour lines. An integer levels value n lets Matplotlib choose up to n+1 "nice" levels, while a list of numbers sets the levels exactly; more levels show finer detail. Contour plots are invaluable for optimization landscapes, probability distributions, and any application where you need to show height or intensity patterns.
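To control exactly which values get a contour line, pass a list to levels. This sketch reuses the example's Z, assumes the Agg backend, and picks levels that lie within the data range; fmt= formats the inline labels.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
Z = np.exp(-(X**2 + Y**2)) + np.exp(-((X - 1.5)**2 + (Y - 1.5)**2))

levels = [0.2, 0.4, 0.6, 0.8]          # explicit values instead of a count
fig, ax = plt.subplots()
cs = ax.contour(X, Y, Z, levels=levels, colors='black')
ax.clabel(cs, fmt='%.1f', fontsize=8)  # one-decimal labels on each line
ax.set_aspect('equal')
```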
Practice: Advanced Techniques
Task: Plot temperature (line) and humidity (bars) on the same figure with different y-axes.
import matplotlib.pyplot as plt
import numpy as np
days = np.arange(1, 8)
temp = [22, 25, 28, 26, 24, 23, 21]
humidity = [60, 55, 50, 58, 62, 65, 70]
fig, ax1 = plt.subplots(figsize=(10, 6))
ax1.plot(days, temp, 'r-o', linewidth=2, label='Temperature')
ax1.set_xlabel('Day')
ax1.set_ylabel('Temperature (°C)', color='red')
ax1.tick_params(axis='y', labelcolor='red')
ax2 = ax1.twinx()
ax2.bar(days, humidity, color='blue', alpha=0.5, label='Humidity')
ax2.set_ylabel('Humidity (%)', color='blue')
ax2.tick_params(axis='y', labelcolor='blue')
plt.title('Weekly Weather')
plt.tight_layout()
plt.show()
Task: Plot exponential growth using semilogy to show a straight line.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.exp(x)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].plot(x, y)
axes[0].set_title('Linear Scale')
axes[0].set_xlabel('x')
axes[0].set_ylabel('exp(x)')
axes[1].semilogy(x, y)
axes[1].set_title('Logarithmic Y Scale')
axes[1].set_xlabel('x')
axes[1].set_ylabel('exp(x)')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Task: Create a 3D surface plot of z = sin(x) * cos(y).
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
x = np.linspace(-np.pi, np.pi, 50)
y = np.linspace(-np.pi, np.pi, 50)
X, Y = np.meshgrid(x, y)
Z = np.sin(X) * np.cos(Y)
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X, Y, Z, cmap='viridis')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('3D Surface: sin(x) * cos(y)')
fig.colorbar(surf, shrink=0.5)
plt.show()
Task: Create a bar chart with error bars representing standard deviation.
import matplotlib.pyplot as plt
import numpy as np
categories = ['A', 'B', 'C', 'D']
means = [25, 40, 35, 50]
std_devs = [3, 5, 4, 6]
plt.figure(figsize=(8, 6))
plt.bar(categories, means, yerr=std_devs, capsize=5,
color='steelblue', edgecolor='black')
plt.title('Bar Chart with Error Bars')
plt.xlabel('Category')
plt.ylabel('Value')
plt.grid(True, axis='y', alpha=0.3)
plt.show()
Interactive Demo
Visualization Playground
Line Plot
Best for: Trends over time, continuous data, comparing multiple series.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.sin(x)
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2, label='sin(x)')
plt.title('Line Plot Example')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Visualization Best Practices
DO:
- Always include titles and axis labels
- Choose colors with good contrast
- Use legends when showing multiple series
- Match chart type to data type
- Keep visualizations simple and focused
- Use consistent styling across related plots
DON'T:
- Use pie charts with too many slices
- Start y-axis at non-zero without good reason
- Use 3D when 2D is sufficient
- Overcrowd plots with too much data
- Use rainbow colormaps for sequential data
- Forget to add units to labels
Common Mistakes & Fixes
| Mistake | Problem | Fix |
|---|---|---|
| plt.show() then plt.savefig() | Empty file saved | Call savefig() before show() |
| Forgetting plt.tight_layout() | Overlapping labels | Add plt.tight_layout() before show() |
| Using plt.plot() for categories | Connected points | Use plt.bar() for categories |
| Too many colors in legend | Visual clutter | Group data or use subplots |
| Default figure size | Plots too small | Use figsize=(10, 6) or larger |
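The empty-file problem comes from plt.savefig() acting on whatever pyplot considers the current figure. In the sketch below, plt.close(fig) stands in for show() closing the figure (an assumption about typical setups), and the Agg backend is assumed. Keeping a reference to the Figure and calling fig.savefig() avoids the problem.

```python
import io
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])

plt.close(fig)        # stands in for show() closing the figure after display
blank = plt.gcf()     # no current figure left, so pyplot creates an empty one
# plt.savefig() at this point would save `blank`, not the plot

buf = io.BytesIO()
fig.savefig(buf, format='png')   # saving through the Figure handle still works
```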
Real-World Examples
See how data visualization is applied in real scenarios across different industries and use cases.
Stock Price Analysis
Visualize stock prices with moving averages to identify trends and trading signals.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Simulate stock data
np.random.seed(42)
dates = pd.date_range('2024-01-01', periods=100)
prices = 100 + np.cumsum(np.random.randn(100) * 2)
# Calculate moving averages
ma_20 = pd.Series(prices).rolling(20).mean()
ma_50 = pd.Series(prices).rolling(50).mean()
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(dates, prices, 'b-', alpha=0.6, label='Price')
ax.plot(dates, ma_20, 'r-', linewidth=2, label='20-day MA')
ax.plot(dates, ma_50, 'g-', linewidth=2, label='50-day MA')
# Highlight buy/sell signals
ax.fill_between(dates, prices, ma_20,
where=(prices > ma_20),
alpha=0.2, color='green', label='Above MA')
ax.fill_between(dates, prices, ma_20,
where=(prices < ma_20),
alpha=0.2, color='red', label='Below MA')
ax.set_title('Stock Price with Moving Averages')
ax.set_xlabel('Date')
ax.set_ylabel('Price ($)')
ax.legend(loc='upper left')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
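The green and red fills switch where the price moves from one side of the moving average to the other. A sketch of locating those switch points with pandas; the ma_crossovers helper is hypothetical, and "crossover" here simply means the price-above-MA flag changes between consecutive days on which the MA exists.

```python
import numpy as np
import pandas as pd

def ma_crossovers(prices, window=20):
    """Hypothetical helper: moving average plus positions where the
    price-above-MA flag changes."""
    s = pd.Series(prices, dtype=float)
    ma = s.rolling(window).mean()                # first window-1 values are NaN
    above = s > ma                               # comparisons with NaN give False
    valid = ma.notna()
    changed = above.ne(above.shift(fill_value=False))
    # Require the MA to exist on both the current and the previous row
    crossings = changed & valid & valid.shift(fill_value=False)
    return ma, list(crossings[crossings].index)

ma, idx = ma_crossovers([1, 2, 5, 6, 2, 1], window=2)
print(idx)   # [4]: the price drops below its 2-day average on day 4
```

With the section's simulated prices, ma_crossovers(prices, 20) would return the indices into dates where the fill color changes.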
Health Data Dashboard
Create a dashboard showing vital signs over time with normal range indicators.
import matplotlib.pyplot as plt
import numpy as np
# Simulate daily health data
days = np.arange(1, 31)
heart_rate = 70 + np.random.randn(30) * 8
blood_pressure_sys = 120 + np.random.randn(30) * 10
blood_pressure_dia = 80 + np.random.randn(30) * 5
fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
# Heart rate plot
axes[0].plot(days, heart_rate, 'r-o', markersize=4)
axes[0].axhspan(60, 100, alpha=0.2, color='green', label='Normal')
axes[0].set_ylabel('Heart Rate (bpm)')
axes[0].set_title('30-Day Health Monitoring')
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.3)
# Blood pressure plot
axes[1].plot(days, blood_pressure_sys, 'b-o', markersize=4, label='Systolic')
axes[1].plot(days, blood_pressure_dia, 'c-s', markersize=4, label='Diastolic')
axes[1].axhspan(90, 120, alpha=0.2, color='green', label='Normal Sys')
axes[1].set_xlabel('Day')
axes[1].set_ylabel('Blood Pressure (mmHg)')
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
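A shaded normal band shows the expected range, but flagged readings are easier to spot when they are marked explicitly. This sketch builds a boolean mask with a hypothetical out_of_range helper and overlays only the flagged days; the data are simulated and the Agg backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

def out_of_range(values, low, high):
    """Hypothetical helper: boolean mask of readings outside [low, high]."""
    values = np.asarray(values)
    return (values < low) | (values > high)

rng = np.random.default_rng(1)
days = np.arange(1, 31)
heart_rate = 70 + rng.normal(0, 12, 30)   # simulated readings

flagged = out_of_range(heart_rate, 60, 100)
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(days, heart_rate, 'k-', alpha=0.5)
ax.axhspan(60, 100, alpha=0.2, color='green', label='Normal')
# Red markers only on the flagged days, drawn above the line
ax.scatter(days[flagged], heart_rate[flagged], color='red', zorder=3,
           label='Out of range')
ax.set_ylabel('Heart Rate (bpm)')
ax.legend(loc='upper right')
```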
Sales Analysis
Analyze sales data by region and product category with grouped bar charts.
import matplotlib.pyplot as plt
import numpy as np
# Sales data
regions = ['North', 'South', 'East', 'West']
products = ['Electronics', 'Clothing', 'Food']
x = np.arange(len(regions))
width = 0.25
sales = {
'Electronics': [120, 90, 150, 110],
'Clothing': [80, 120, 100, 95],
'Food': [150, 130, 140, 160]
}
fig, ax = plt.subplots(figsize=(10, 6))
for i, (product, values) in enumerate(sales.items()):
    offset = width * i
    bars = ax.bar(x + offset, values, width, label=product)
    ax.bar_label(bars, padding=3, fontsize=8)
ax.set_xlabel('Region')
ax.set_ylabel('Sales (thousands $)')
ax.set_title('Q4 Sales by Region and Product')
ax.set_xticks(x + width)
ax.set_xticklabels(regions)
ax.legend(loc='upper left')
ax.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
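The example centers the ticks with x + width, which works for exactly three series. A variation that centers any number of series on each tick computes symmetric offsets; the group_offsets helper is hypothetical and the Agg backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get a plot window
import matplotlib.pyplot as plt
import numpy as np

def group_offsets(n_series, width):
    """Hypothetical helper: offsets that center n_series bars on each tick."""
    return (np.arange(n_series) - (n_series - 1) / 2) * width

regions = ['North', 'South', 'East', 'West']
sales = {'Electronics': [120, 90, 150, 110],
         'Clothing': [80, 120, 100, 95],
         'Food': [150, 130, 140, 160]}

x = np.arange(len(regions))
width = 0.25
fig, ax = plt.subplots(figsize=(10, 6))
for offset, (product, values) in zip(group_offsets(len(sales), width), sales.items()):
    bars = ax.bar(x + offset, values, width, label=product)
    ax.bar_label(bars, padding=3, fontsize=8)
ax.set_xticks(x)            # ticks sit at the group centers
ax.set_xticklabels(regions)
ax.legend(loc='upper left')
```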
Weather Analysis
Visualize temperature and precipitation patterns with dual-axis plots.
import matplotlib.pyplot as plt
import numpy as np
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
temp = [5, 7, 12, 18, 23, 28, 31, 30, 25, 18, 11, 6]
rain = [80, 65, 70, 55, 45, 30, 25, 35, 50, 75, 85, 90]
fig, ax1 = plt.subplots(figsize=(12, 6))
# Temperature line
color1 = '#E74C3C'
ax1.plot(months, temp, color=color1, marker='o', linewidth=2)
ax1.fill_between(months, temp, alpha=0.2, color=color1)
ax1.set_xlabel('Month')
ax1.set_ylabel('Temperature (°C)', color=color1)
ax1.tick_params(axis='y', labelcolor=color1)
ax1.set_ylim(0, 35)
# Rainfall bars
ax2 = ax1.twinx()
color2 = '#3498DB'
ax2.bar(months, rain, color=color2, alpha=0.5)
ax2.set_ylabel('Rainfall (mm)', color=color2)
ax2.tick_params(axis='y', labelcolor=color2)
ax2.set_ylim(0, 100)
ax1.set_title('Annual Temperature and Rainfall')
plt.tight_layout()
plt.show()
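In a twin-axis plot the line lives on ax1 and the bars on ax2, so a legend on either Axes lists only its own artists, and ax2 (created later) is drawn on top of ax1. The sketch below, which assumes the same data plus `label=` values that the original example does not set, merges both handle lists into one legend and raises ax1 above ax2.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
          'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
temp = [5, 7, 12, 18, 23, 28, 31, 30, 25, 18, 11, 6]
rain = [80, 65, 70, 55, 45, 30, 25, 35, 50, 75, 85, 90]

fig, ax1 = plt.subplots(figsize=(12, 6))
ax1.plot(months, temp, color='#E74C3C', marker='o', label='Temperature')
ax2 = ax1.twinx()
ax2.bar(months, rain, color='#3498DB', alpha=0.5, label='Rainfall')

# Draw ax1 (line and legend) above ax2 (bars); hide ax1's background so the bars stay visible
ax1.set_zorder(ax2.get_zorder() + 1)
ax1.patch.set_visible(False)

# Each Axes only reports its own artists, so combine both lists into a single legend
h1, l1 = ax1.get_legend_handles_labels()
h2, l2 = ax2.get_legend_handles_labels()
legend = ax1.legend(h1 + h2, l1 + l2, loc='upper center')
```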
Machine Learning Model Evaluation
Visualize model performance with confusion matrix and ROC curve.
import matplotlib.pyplot as plt
import numpy as np
# Simulate confusion matrix and ROC data
confusion_matrix = np.array([[85, 15], [10, 90]])
fpr = np.array([0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0])
tpr = np.array([0, 0.5, 0.7, 0.85, 0.92, 0.97, 1.0])
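# Area under the ROC curve via the trapezoidal rule (about 0.80 for these points)
auc = np.sum(np.diff(fpr) * (tpr[:-1] + tpr[1:]) / 2)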
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Confusion Matrix
im = axes[0].imshow(confusion_matrix, cmap='Blues')
axes[0].set_xticks([0, 1])
axes[0].set_yticks([0, 1])
axes[0].set_xticklabels(['Predicted Negative', 'Predicted Positive'])
axes[0].set_yticklabels(['Actual Negative', 'Actual Positive'])
axes[0].set_title('Confusion Matrix')
# Write each count in its cell: white text on dark cells, black on light ones
for i in range(2):
    for j in range(2):
        axes[0].text(j, i, str(confusion_matrix[i, j]),
                     ha='center', va='center', fontsize=20,
                     color='white' if confusion_matrix[i, j] > 50 else 'black')
plt.colorbar(im, ax=axes[0])
# ROC Curve
axes[1].plot(fpr, tpr, 'b-', linewidth=2, label=f'ROC (AUC = {auc:.2f})')
axes[1].plot([0, 1], [0, 1], 'r--', label='Random Classifier')
axes[1].fill_between(fpr, tpr, alpha=0.3)
axes[1].set_xlabel('False Positive Rate')
axes[1].set_ylabel('True Positive Rate')
axes[1].set_title('ROC Curve')
axes[1].legend(loc='lower right')
axes[1].grid(True, alpha=0.3)
axes[1].set_xlim([0, 1])
axes[1].set_ylim([0, 1])
plt.tight_layout()
plt.show()
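The confusion matrix and ROC points in this example are hard-coded. As a rough sketch of where such numbers come from, the NumPy-only code below sweeps a decision threshold over predicted scores to get (FPR, TPR) pairs, integrates them with the trapezoidal rule, and counts a 2×2 confusion matrix. The helper names, the 0/1 labels, the 0.5 threshold and the assumption of no tied scores are all illustrative; in practice `sklearn.metrics.roc_curve`, `roc_auc_score` and `confusion_matrix` handle these cases more carefully.

```python
import numpy as np

def roc_curve_points(y_true, scores):
    """Hypothetical helper: FPR/TPR after each sample, highest score first (assumes no ties)."""
    y_true = np.asarray(y_true)
    order = np.argsort(-np.asarray(scores), kind='stable')
    y_sorted = y_true[order]
    tps = np.cumsum(y_sorted == 1)  # true positives as the threshold drops
    fps = np.cumsum(y_sorted == 0)  # false positives as the threshold drops
    tpr = np.concatenate([[0.0], tps / tps[-1]])
    fpr = np.concatenate([[0.0], fps / fps[-1]])
    return fpr, tpr

def trapezoid_auc(fpr, tpr):
    """Area under the piecewise-linear ROC curve."""
    return float(np.sum(np.diff(fpr) * (tpr[:-1] + tpr[1:]) / 2))

def confusion_counts(y_true, y_pred):
    """Rows = actual (neg, pos), columns = predicted (neg, pos), as in the plot above."""
    cm = np.zeros((2, 2), dtype=int)
    for actual, predicted in zip(y_true, y_pred):
        cm[actual, predicted] += 1
    return cm

y_true = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])
fpr, tpr = roc_curve_points(y_true, scores)
auc = trapezoid_auc(fpr, tpr)  # 0.75 for this toy data
cm = confusion_counts(y_true, (scores >= 0.5).astype(int))
```

The resulting `fpr`, `tpr`, `auc` and `cm` can be passed straight to the plotting code above in place of the hard-coded arrays.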
Quick Reference
A cheat sheet of commonly used Matplotlib functions, parameters, and patterns.
Plot Functions
| Function | Description | Example |
|---|---|---|
| plt.plot() | Line plot | plt.plot(x, y, 'b-o') |
| plt.bar() | Vertical bar chart | plt.bar(cats, vals) |
| plt.barh() | Horizontal bar chart | plt.barh(cats, vals) |
| plt.scatter() | Scatter plot | plt.scatter(x, y, c=colors) |
| plt.hist() | Histogram | plt.hist(data, bins=30) |
| plt.pie() | Pie chart | plt.pie(sizes, labels=names) |
| plt.boxplot() | Box plot | plt.boxplot(data) |
| plt.contour() | Contour lines | plt.contour(X, Y, Z) |
| plt.contourf() | Filled contours | plt.contourf(X, Y, Z) |
| plt.imshow() | Display image/matrix | plt.imshow(data, cmap='viridis') |
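contour() and contourf() expect Z sampled on a 2D grid, which is usually built from two 1D ranges with np.meshgrid. A minimal sketch, assuming a Gaussian bump as the example surface:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-3, 3, 100)
y = np.linspace(-2, 2, 80)
X, Y = np.meshgrid(x, y)      # both have shape (len(y), len(x)) = (80, 100)
Z = np.exp(-(X**2 + Y**2))    # example surface: a single Gaussian bump

fig, ax = plt.subplots()
filled = ax.contourf(X, Y, Z, levels=10, cmap='viridis')  # filled bands
ax.contour(X, Y, Z, levels=filled.levels, colors='k', linewidths=0.5)  # outlines at the same levels
fig.colorbar(filled, ax=ax, label='Z')
```

Reusing `filled.levels` for the line contours keeps the outlines aligned with the color bands.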
Customization Functions
| Function | Description | Example |
|---|---|---|
| plt.title() | Set plot title | plt.title('My Plot', fontsize=14) |
| plt.xlabel() | Set x-axis label | plt.xlabel('Time (s)') |
| plt.ylabel() | Set y-axis label | plt.ylabel('Value') |
| plt.xlim() | Set x-axis range | plt.xlim(0, 10) |
| plt.ylim() | Set y-axis range | plt.ylim(-1, 1) |
| plt.xticks() | Set x tick positions | plt.xticks([0, 5, 10]) |
| plt.yticks() | Set y tick positions | plt.yticks([0, 0.5, 1]) |
| plt.legend() | Add legend | plt.legend(loc='upper left') |
| plt.grid() | Add grid lines | plt.grid(True, alpha=0.3) |
| plt.text() | Add text annotation | plt.text(x, y, 'label') |
| plt.annotate() | Annotate a point (arrow drawn when arrowprops is given) | plt.annotate('note', xy=(x, y), xytext=(tx, ty), arrowprops=dict(arrowstyle='->')) |
| plt.colorbar() | Add color scale bar | plt.colorbar(label='Value') |
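The difference between text() and annotate() is mostly the arrow: annotate() separates the point being labeled (`xy`) from where the text sits (`xytext`), and it draws an arrow between them only when `arrowprops` is passed. A short sketch labeling the peak of a sine curve:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 200)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y)
peak = np.argmax(y)
# xy = labeled point, xytext = text position, arrowprops = draw an arrow between them
ann = ax.annotate('peak', xy=(x[peak], y[peak]), xytext=(x[peak] + 1.5, 0.6),
                  arrowprops=dict(arrowstyle='->'))
# Without arrowprops the annotation is just text placed at xy
plain = ax.annotate('start', xy=(0, 0))
```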
Figure & Layout Functions
| Function | Description | Example |
|---|---|---|
| plt.figure() | Create new figure | plt.figure(figsize=(10, 6)) |
| plt.subplots() | Create figure with subplots | fig, axes = plt.subplots(2, 2) |
| plt.subplot() | Add single subplot | plt.subplot(2, 2, 1) |
| plt.tight_layout() | Adjust layout spacing | plt.tight_layout() |
| plt.savefig() | Save figure to file | plt.savefig('plot.png', dpi=150) |
| plt.show() | Display open figures | plt.show() |
| plt.clf() | Clear current figure | plt.clf() |
| plt.close() | Close one or all figures | plt.close('all') |
| ax.twinx() | Add a second y-axis sharing the x-axis | ax2 = ax.twinx() |
| ax.twiny() | Add a second x-axis sharing the y-axis | ax2 = ax.twiny() |
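pyplot keeps every figure it creates until it is closed, and by default warns once more than 20 are open (rcParams['figure.max_open_warning']). When generating many plots in a loop, saving and then closing each figure keeps memory flat. A sketch, assuming a temporary output directory and made-up file names:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import os
import tempfile

out_dir = tempfile.mkdtemp()          # illustrative output location
open_before = len(plt.get_fignums())  # figures already open, if any

for k in range(1, 4):
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(np.arange(10), np.arange(10) ** k)
    ax.set_title(f'Power {k}')
    fig.savefig(os.path.join(out_dir, f'power_{k}.png'), dpi=100)
    plt.close(fig)  # release this figure so pyplot stops tracking it

open_after = len(plt.get_fignums())
```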
Line Styles
| Code | Style |
|---|---|
'-' | Solid line ― |
'--' | Dashed line - - |
'-.' | Dash-dot line -· |
':' | Dotted line ··· |
'' | No line (markers only) |
Markers
| Code | Marker |
|---|---|
'o' | Circle ● |
's' | Square ■ |
'^' | Triangle ▲ |
'*' | Star ★ |
'+' | Plus + |
'x' | Cross × |
'D' | Diamond ◆ |
Color Codes
| Code | Color |
|---|---|
'b' | ■ Blue |
'g' | ■ Green |
'r' | ■ Red |
'c' | ■ Cyan |
'm' | ■ Magenta |
'y' | ■ Yellow |
'k' | ■ Black |
'w' | □ White |
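The line-style, marker and color codes combine into a single format string, as in 'b-o' from the Plot Functions table. The sketch below checks that a format string and the equivalent keyword arguments produce the same line properties, and that a marker code without a line-style code draws markers only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(5)
fig, ax = plt.subplots()
# 'r--o' = color 'r' + line style '--' + marker 'o'
(short,) = ax.plot(x, x, 'r--o')
# The same styling spelled out with keyword arguments
(explicit,) = ax.plot(x, x + 1, color='r', linestyle='--', marker='o')
# Color 'k' + marker 's' with no line-style code: markers only, no connecting line
(markers_only,) = ax.plot(x, x + 2, 'ks')
```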
Popular Colormaps
| Name | Use Case |
|---|---|
'viridis' | Default, perceptually uniform |
'plasma' | High contrast sequential |
'coolwarm' | Diverging (blue to red) |
'RdYlGn' | Diverging (red to green) |
'Blues' | Sequential blues |
'hot' | Black to white via red and yellow |
'gray' | Grayscale |
'tab10' | Categorical (10 colors) |
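A colormap is also a function: calling it with values in [0, 1] returns RGBA colors, which is a convenient way to color a family of lines consistently. A sketch, assuming Matplotlib 3.5 or newer for the `matplotlib.colormaps` registry:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

cmap = matplotlib.colormaps['viridis']   # registry lookup (Matplotlib >= 3.5)
colors = cmap(np.linspace(0, 1, 5))      # 5 evenly spaced RGBA rows, shape (5, 4)

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
for k, color in enumerate(colors):
    ax.plot(x, np.sin(x + k), color=color, label=f'shift {k}')

# Categorical maps such as 'tab10' are lists of distinct colors; integers index them directly
tab10 = matplotlib.colormaps['tab10']
first_three = [tab10(i) for i in range(3)]
```

Sequential maps like 'viridis' suit ordered data such as these phase shifts; 'tab10' suits unordered categories.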
Ready-to-Use Templates
Basic Plot Template
import matplotlib.pyplot as plt
import numpy as np
# Data
x = np.linspace(0, 10, 100)
y = np.sin(x)
# Create figure
plt.figure(figsize=(10, 6))
# Plot
plt.plot(x, y, 'b-', linewidth=2, label='Data')
# Customize
plt.title('Title', fontsize=14)
plt.xlabel('X Label')
plt.ylabel('Y Label')
plt.legend()
plt.grid(True, alpha=0.3)
# Save and show
plt.tight_layout()
plt.savefig('plot.png', dpi=150, bbox_inches='tight')
plt.show()
Multi-Panel Template
import matplotlib.pyplot as plt
import numpy as np
# Data
x = np.linspace(0, 10, 100)
# Create subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Plot each panel
axes[0, 0].plot(x, np.sin(x))
axes[0, 0].set_title('Panel 1')
axes[0, 1].plot(x, np.cos(x))
axes[0, 1].set_title('Panel 2')
axes[1, 0].bar(['A', 'B', 'C'], [1, 2, 3])
axes[1, 0].set_title('Panel 3')
axes[1, 1].scatter(np.random.rand(20), np.random.rand(20))
axes[1, 1].set_title('Panel 4')
# Adjust and show
plt.tight_layout()
plt.savefig('multi_panel.png', dpi=150)
plt.show()
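Indexing each panel by hand as axes[0, 0], axes[0, 1] and so on gets repetitive as the grid grows. Iterating over `axes.flat` visits the panels row by row; the curves below are made-up examples.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
curves = {
    'sin': np.sin(x),
    'cos': np.cos(x),
    'sin squared': np.sin(x) ** 2,
    'damped sine': np.exp(-x / 3) * np.sin(x),
}

fig, axes = plt.subplots(2, 2, figsize=(12, 10), sharex=True)
# axes is a 2x2 NumPy array; .flat walks it row by row
for ax, (name, y) in zip(axes.flat, curves.items()):
    ax.plot(x, y)
    ax.set_title(name)
for ax in axes[-1, :]:  # label only the bottom row, since x is shared
    ax.set_xlabel('x')
fig.tight_layout()
```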
Publication-Quality Template
import matplotlib.pyplot as plt
import numpy as np
# Set style for publications
plt.rcParams.update({
'font.size': 12,
'axes.labelsize': 14,
'axes.titlesize': 16,
'legend.fontsize': 11,
'xtick.labelsize': 11,
'ytick.labelsize': 11,
'figure.figsize': (8, 6),
'figure.dpi': 100,
'savefig.dpi': 300,
'savefig.bbox': 'tight'
})
# Data
x = np.linspace(0, 2*np.pi, 100)
# Create figure with object-oriented API
fig, ax = plt.subplots()
# Plot with styling
ax.plot(x, np.sin(x), 'b-', linewidth=2, label=r'$\sin(x)$')
ax.plot(x, np.cos(x), 'r--', linewidth=2, label=r'$\cos(x)$')
# Customize
ax.set_xlabel(r'$x$ (radians)')
ax.set_ylabel(r'$f(x)$')
ax.set_title('Trigonometric Functions')
ax.legend(loc='upper right', frameon=True)
ax.grid(True, alpha=0.3, linestyle='--')
# Remove top and right spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# Save for publication
plt.savefig('publication_plot.pdf')
plt.savefig('publication_plot.png')
plt.show()
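plt.rcParams.update() in the template changes the settings for the rest of the session. If only some figures should use the publication style, plt.rc_context() applies the settings inside a `with` block and restores the previous values afterwards. A sketch using a subset of the template's settings:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

publication_style = {'font.size': 12, 'axes.labelsize': 14}
default_size = plt.rcParams['font.size']

with plt.rc_context(publication_style):
    # Figures and Axes created here pick up the temporary settings
    fig, ax = plt.subplots()
    ax.plot(np.arange(5), np.arange(5) ** 2)
    ax.set_xlabel('x')
    size_inside = plt.rcParams['font.size']
    label_size = ax.xaxis.label.get_fontsize()

size_after = plt.rcParams['font.size']  # restored to the previous value
```

Settings that act at save time, such as savefig.dpi, only apply if savefig() is also called inside the block.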
Frequently Asked Questions
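How do I change the font size?
Use plt.rcParams.update({'font.size': 14}) to change the default font size for all text elements, or pass fontsize=14 to individual elements.
Why is my saved figure blank?
You most likely called plt.show() before plt.savefig(). Depending on the backend, pyplot may discard the figure once its window is closed, so the later savefig() call writes a new, empty figure. Always call savefig() before show().
How can I make my plots look more professional?
Try a built-in style such as plt.style.use('seaborn-v0_8') (Matplotlib 3.6+) or plt.style.use('ggplot'). Also increase the figure size, for example figsize=(10, 6), and remove the top and right spines for a cleaner look.
What is the difference between plt.plot() and ax.plot()?
plt.plot() belongs to the pyplot interface (simpler, implicit): it draws on the current Axes. ax.plot() belongs to the object-oriented interface (more control, explicit): it draws on the Axes you name. Use the OO style for complex plots and subplots.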
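The implicit/explicit distinction can be seen directly: pyplot functions act on the "current" Axes, which after plt.subplots() is the last one created, while ax methods act on exactly the Axes you call them on. A short sketch:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, (left, right) = plt.subplots(1, 2)

# Implicit: draws on the current Axes, which is the last one created (right)
plt.plot([1, 2, 3], [1, 4, 9])

# Explicit: draws on the Axes you choose
left.plot([1, 2, 3], [3, 2, 1])

# plt.sca() changes which Axes later pyplot calls target
plt.sca(left)
plt.title('Left panel')
```

This hidden "current Axes" state is why the OO style is easier to follow once a figure has several panels.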
Key Takeaways
Figure and Axes
Figure is the canvas, Axes is where data is plotted. One figure can have multiple axes.
Choose Right Chart
Line for trends, bar for categories, scatter for relationships, histogram for distributions.
Always Label
Add titles, axis labels, and legends to make plots self-explanatory.
Subplots for Comparison
Use subplots() to create multi-panel figures for side-by-side comparisons.
Save Before Show
Call savefig() before show() to save your plots to files.
Use Styles
plt.style.use() applies pre-made themes for consistent, attractive plots.