Capstone Project 1

Iris Flower Classification

Build a complete machine learning pipeline to classify iris flower species. You will explore the famous Iris dataset, perform exploratory data analysis, train multiple classification models, evaluate their performance, and deploy a prediction function with proper documentation.

4-6 hours
Beginner
200 Points
What You Will Build
  • Exploratory data analysis notebook
  • Data visualization with matplotlib/seaborn
  • Multiple ML classification models
  • Model evaluation and comparison
  • Prediction function with saved model
Contents
01

Project Overview

This introductory project brings together the fundamental concepts from the Machine Learning Basics module. You will work with the famous Iris Dataset - one of the most well-known datasets in machine learning, containing 150 samples of iris flowers across 3 species (Setosa, Versicolor, and Virginica) with 4 features each (sepal length, sepal width, petal length, petal width). Your goal is to build a complete classification pipeline from data exploration to model deployment, demonstrating your understanding of supervised learning fundamentals.

Skills Applied: This project tests your proficiency in Python (pandas, numpy), data visualization (matplotlib, seaborn), machine learning (scikit-learn), and model evaluation metrics.
EDA

Explore data distributions, correlations, and patterns

Visualization

Create informative plots and pair plots

ML Models

Train and compare multiple classifiers

Evaluation

Assess accuracy, precision, recall, and F1

Learning Objectives

Technical Skills
  • Load and preprocess tabular data with pandas
  • Perform comprehensive exploratory data analysis
  • Create publication-quality visualizations
  • Implement train-test split and cross-validation
  • Train multiple classification algorithms
ML Workflow Skills
  • Understand the end-to-end ML pipeline
  • Compare model performance using metrics
  • Interpret confusion matrices and classification reports
  • Save and load trained models with joblib
  • Document your analysis for reproducibility
02

Project Scenario

FloraLab Research Institute

You have been hired as a Junior Machine Learning Engineer at FloraLab, a botanical research institute that specializes in plant species identification using AI. The team has collected measurements from 150 iris flowers and needs an automated classification system to identify species based on their physical characteristics.

"We need a reliable way to classify iris flowers by species using just their sepal and petal measurements. Can you build a machine learning model that achieves at least 90% accuracy? We also need visualizations to understand how the features relate to each species."

Dr. Sarah Chen, Lead Botanist

Tasks to Complete

Data Analysis
  • What is the distribution of each feature?
  • Are there any correlations between features?
  • Which features best separate the species?
  • Are there any outliers in the data?
Model Building
  • Which classification algorithm performs best?
  • What is the optimal train-test split?
  • How does cross-validation improve reliability?
  • What are the most important features?
Evaluation
  • What is the accuracy on the test set?
  • Which species is hardest to classify?
  • What does the confusion matrix reveal?
  • How confident is the model in its predictions?
Deployment
  • How to save the trained model?
  • How to load and use the model for new predictions?
  • How to create a simple prediction function?
  • How to document the model for others?
Pro Tip: Think like a data scientist! Your notebook should tell a story from data exploration to final model, with clear explanations at each step that even non-technical stakeholders can follow.
03

The Dataset

You will work with the famous Iris dataset, introduced by statistician Ronald Fisher in 1936. Download the CSV file containing all 150 samples:

Dataset Download

Download the Iris dataset CSV file and save it to your project folder. The file contains 150 samples with 4 features and 1 target variable.

Original Data Source

This project uses the Iris Dataset from Kaggle - the "Hello World" of machine learning. The dataset was originally collected by Edgar Anderson and made famous by Ronald Fisher's 1936 paper on discriminant analysis. It remains one of the most used datasets for learning classification.

Dataset Info: 150 rows × 5 columns | 3 Classes (50 each) | 4 Features | Balanced dataset | No missing values | Classic benchmark dataset
Dataset Schema
Column       | Type   | Description                     | Range
sepal_length | Float  | Length of sepal in centimeters  | 4.3 - 7.9 cm
sepal_width  | Float  | Width of sepal in centimeters   | 2.0 - 4.4 cm
petal_length | Float  | Length of petal in centimeters  | 1.0 - 6.9 cm
petal_width  | Float  | Width of petal in centimeters   | 0.1 - 2.5 cm
species      | String | Target variable: flower species | setosa, versicolor, virginica
Iris Setosa

50 samples | Smallest petals | Most distinct species, linearly separable

Iris Versicolor

50 samples | Medium-sized | Some overlap with Virginica

Iris Virginica

50 samples | Largest petals | Some overlap with Versicolor

Sample Data Preview

Here is what a typical record looks like from iris.csv:

sepal_length | sepal_width | petal_length | petal_width | species
5.1          | 3.5         | 1.4          | 0.2         | setosa
7.0          | 3.2         | 4.7          | 1.4         | versicolor
6.3          | 3.3         | 6.0          | 2.5         | virginica
Fun Fact: The Iris dataset is so popular that scikit-learn includes it as a built-in dataset. You can also load it with from sklearn.datasets import load_iris, but for this project, use the CSV file to practice real-world data loading.
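For reference, the built-in loader mentioned above can serve as a quick sanity check that your CSV matches the canonical data. A minimal sketch (note that the built-in copy uses different column names, e.g. "sepal length (cm)", plus a numeric target column):

```python
from sklearn.datasets import load_iris

# Load the built-in copy of the dataset as a pandas DataFrame
iris = load_iris(as_frame=True)
df = iris.frame                                  # 150 rows: 4 features + 'target'
df["species"] = iris.target_names[iris.target]   # map 0/1/2 to readable labels
print(df["species"].value_counts())              # expect 50 per species
```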
04

Project Requirements

Your project must include all of the following components. Structure your Jupyter notebook with clear markdown headers and code cells.

1
Data Loading and Exploration

Load the dataset and understand its structure:

  • Load iris.csv using pandas
  • Display the first 10 rows with df.head(10)
  • Check data types and shape with df.info() and df.shape
  • Verify there are no missing values
  • Display summary statistics with df.describe()
  • Check class distribution with df['species'].value_counts()
Deliverable: Markdown section explaining the dataset characteristics and any initial observations about the data.
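The loading and inspection steps above can be wrapped in a small helper, sketched here; the data/iris.csv path is an assumption that follows the required repository structure:

```python
import pandas as pd

def explore(csv_path) -> pd.DataFrame:
    """Load the iris CSV and run the step-1 checks."""
    df = pd.read_csv(csv_path)
    print(df.head(10))                       # first 10 rows
    df.info()                                # dtypes and non-null counts
    print("shape:", df.shape)                # expect (150, 5)
    print("missing per column:\n", df.isna().sum())
    print(df.describe())                     # summary statistics
    print(df["species"].value_counts())      # expect a balanced 50/50/50 split

    return df

# In the project notebook:
# df = explore("data/iris.csv")
```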
2
Exploratory Data Analysis (EDA)

Create visualizations to understand the data:

  • Distribution plots: Histograms for each feature by species
  • Box plots: Compare feature distributions across species
  • Pair plot: Use seaborn pairplot to visualize all feature combinations
  • Correlation heatmap: Show relationships between numeric features
  • Violin plots: Display distribution and density by species

Analysis questions to answer:

  • Which features show the clearest separation between species?
  • Is there any overlap between species? Which ones?
  • Are there any outliers in the dataset?
Deliverable: At least 5 visualizations with clear titles, labels, and markdown explanations of what each plot reveals.
3
Data Preprocessing

Prepare the data for machine learning:

  • Separate features (X) and target (y)
  • Encode target labels if necessary (LabelEncoder)
  • Split data into training (80%) and testing (20%) sets
  • Set a random_state for reproducibility
  • Optionally: Scale features using StandardScaler
Deliverable: Code cells showing the preprocessing steps with print statements confirming the shapes of training and test sets.
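The preprocessing steps above might look like the following sketch; the built-in loader stands in for the CSV so the snippet is self-contained:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Stand-in for pd.read_csv("data/iris.csv")
iris = load_iris(as_frame=True)
df = iris.frame
df["species"] = iris.target_names[iris.target]

X = df.drop(columns=["target", "species"])       # the four measurements
y = LabelEncoder().fit_transform(df["species"])  # setosa=0, versicolor=1, virginica=2

# 80/20 split; stratify keeps the 50/50/50 class balance in both sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print(X_train.shape, X_test.shape)               # (120, 4) (30, 4)

# Optional: scale features (helps KNN, SVM, and logistic regression)
scaler = StandardScaler().fit(X_train)           # fit on training data only
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)
```

Fitting the scaler on the training set only avoids leaking test-set statistics into training.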
4
Model Training

Train at least 3 different classification models:

  • Logistic Regression: Baseline linear classifier
  • K-Nearest Neighbors (KNN): Instance-based learning
  • Decision Tree: Tree-based classifier
  • Random Forest: Ensemble method (bonus)
  • Support Vector Machine: Margin-based classifier (bonus)

For each model:

  • Fit the model on training data
  • Make predictions on test data
  • Store predictions for evaluation
Deliverable: At least 3 trained models with clear documentation explaining each algorithm's approach.
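One compact pattern for the fit/predict/store loop above is to keep the models in a dict; this is a sketch, with the built-in loader standing in for the CSV:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)                # stand-in for the CSV data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

models = {
    "Logistic Regression": LogisticRegression(max_iter=200),
    "KNN (k=5)": KNeighborsClassifier(n_neighbors=5),
    "Decision Tree": DecisionTreeClassifier(random_state=42),
}

predictions = {}
for name, model in models.items():
    model.fit(X_train, y_train)                  # train on the 120-sample split
    predictions[name] = model.predict(X_test)    # stored for the evaluation step
```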
5
Model Evaluation

Evaluate and compare model performance:

  • Accuracy Score: Overall prediction accuracy
  • Classification Report: Precision, recall, F1-score per class
  • Confusion Matrix: Visualize prediction errors
  • Cross-Validation: 5-fold or 10-fold CV scores
  • Comparison Table: Compare all models side by side
Deliverable: A summary table comparing accuracy across all models, and confusion matrix visualizations for the best model.
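A sketch of the evaluation metrics above for a single model (logistic regression here; repeat for each trained model), again with the built-in loader standing in for the CSV:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)                # stand-in for the CSV data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

model = LogisticRegression(max_iter=200).fit(X_train, y_train)
y_pred = model.predict(X_test)

print("accuracy:", accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred))     # precision/recall/F1 per class
print(confusion_matrix(y_test, y_pred))          # rows = true, columns = predicted

# 5-fold cross-validation gives a more stable estimate than a single split
cv_scores = cross_val_score(LogisticRegression(max_iter=200), X, y, cv=5)
print("CV mean +/- std: %.3f +/- %.3f" % (cv_scores.mean(), cv_scores.std()))
```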
6
Model Saving and Prediction Function

Deploy the best model:

  • Select the best performing model based on evaluation
  • Save the model using joblib or pickle
  • Create a predict_species() function that takes measurements as input
  • Demonstrate the function with sample predictions
# Example prediction function (assumes a fitted `model` and a
# `species_names` list mapping encoded labels back to species strings)
def predict_species(sepal_length, sepal_width, petal_length, petal_width):
    """Predict iris species from flower measurements."""
    features = [[sepal_length, sepal_width, petal_length, petal_width]]
    prediction = model.predict(features)
    return species_names[prediction[0]]

# Test the function
print(predict_species(5.1, 3.5, 1.4, 0.2))  # Expected: 'setosa'
Deliverable: Saved model file (.pkl or .joblib) and a working prediction function with test examples.
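A sketch of the joblib save/load round trip; the logistic regression "best model", the models/best_model.pkl path, and the species_names ordering mirror the project structure but are assumptions to substitute with your own results:

```python
import os
import joblib
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

# Train a stand-in "best model" (substitute whichever model won your comparison)
X, y = load_iris(return_X_y=True)
best_model = LogisticRegression(max_iter=200).fit(X, y)

os.makedirs("models", exist_ok=True)
joblib.dump(best_model, "models/best_model.pkl")   # save to disk

# Later, or in another script: load and predict
model = joblib.load("models/best_model.pkl")
species_names = ["setosa", "versicolor", "virginica"]  # LabelEncoder's sorted order
sample = [[5.1, 3.5, 1.4, 0.2]]
print(species_names[model.predict(sample)[0]])         # setosa
```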
05

Model Specifications

Train the following classification models and compare their performance. Each model has different strengths suitable for this multiclass classification task.

Logistic Regression

Linear model that estimates class probabilities using the logistic function. Works well for linearly separable classes.

from sklearn.linear_model import LogisticRegression

lr_model = LogisticRegression(max_iter=200)
lr_model.fit(X_train, y_train)
lr_pred = lr_model.predict(X_test)
Strengths: interpretable, probabilistic, fast
K-Nearest Neighbors

Instance-based learning that classifies based on the majority class of the k nearest training samples.

from sklearn.neighbors import KNeighborsClassifier

knn_model = KNeighborsClassifier(n_neighbors=5)
knn_model.fit(X_train, y_train)
knn_pred = knn_model.predict(X_test)
Strengths: non-parametric, simple, intuitive
Decision Tree

Tree-based model that makes decisions by learning simple rules inferred from features. Highly interpretable.

from sklearn.tree import DecisionTreeClassifier

dt_model = DecisionTreeClassifier(random_state=42)
dt_model.fit(X_train, y_train)
dt_pred = dt_model.predict(X_test)
Strengths: interpretable, visual, no scaling needed
Random Forest (Bonus)

Ensemble of decision trees that reduces overfitting by averaging predictions from multiple trees trained on different subsets.

from sklearn.ensemble import RandomForestClassifier

rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
rf_model.fit(X_train, y_train)
rf_pred = rf_model.predict(X_test)
Strengths: ensemble, robust, provides feature importances
Expected Performance

For the Iris dataset, well-tuned models typically achieve:

Model               | Expected Accuracy | Notes
Logistic Regression | 95-100%           | Works well due to near-linear separability
K-Nearest Neighbors | 95-100%           | k=3 or k=5 typically optimal
Decision Tree       | 90-97%            | May overfit without pruning
Random Forest       | 95-100%           | Most robust, rarely overfits
SVM (RBF kernel)    | 95-100%           | Excellent for small datasets
Target: Your best model should achieve at least 90% accuracy on the test set. Scores above 95% are excellent!
06

Required Visualizations

Create at least 5 visualizations in your notebook. Each should have clear titles, axis labels, and legends where appropriate.

1. Pair Plot

Scatter plots of all feature pairs colored by species. Use sns.pairplot(df, hue='species')

Required
2. Correlation Heatmap

Heatmap showing correlations between numeric features. Use sns.heatmap(df.drop(columns='species').corr(), annot=True) (drop the non-numeric species column first; calling df.corr() on the full frame raises an error in recent pandas versions)

Required
3. Box Plots by Species

Box plots for each feature grouped by species to compare distributions. 2x2 subplot layout recommended.

Required
4. Confusion Matrix

Heatmap of the confusion matrix for your best model. Use sns.heatmap(confusion_matrix(y_test, y_pred), annot=True)

Required
5. Model Comparison Bar Chart

Bar chart comparing accuracy scores of all trained models. Include error bars from cross-validation if possible.

Required
6. Violin Plots (Bonus)

Violin plots showing distribution shape for each feature by species. More informative than box plots.

Bonus
Sample Visualization Code
import os
import matplotlib.pyplot as plt
import seaborn as sns

# Set style for all plots
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')
os.makedirs('figures', exist_ok=True)  # ensure the output folder exists

# 1. Pair Plot (sns.pairplot creates its own figure, so no plt.figure call is needed)
sns.pairplot(df, hue='species', markers=['o', 's', 'D'])
plt.suptitle('Iris Dataset - Pair Plot by Species', y=1.02)
plt.savefig('figures/pairplot.png', dpi=150)
plt.show()

# 2. Correlation Heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(df.drop('species', axis=1).corr(), 
            annot=True, cmap='coolwarm', center=0)
plt.title('Feature Correlation Heatmap')
plt.tight_layout()
plt.savefig('figures/correlation.png', dpi=150)
plt.show()
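A sketch in the same style for the model-comparison bar chart (required visualization 5), using 5-fold cross-validation scores for the error bars; the three models and the built-in loader are illustrative stand-ins:

```python
import os
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)   # stand-in for the CSV data
models = {
    'LogReg': LogisticRegression(max_iter=200),
    'KNN': KNeighborsClassifier(n_neighbors=5),
    'Tree': DecisionTreeClassifier(random_state=42),
}

# 5-fold CV per model: the mean becomes the bar height, the std the error bar
scores = {name: cross_val_score(m, X, y, cv=5) for name, m in models.items()}

os.makedirs('figures', exist_ok=True)
plt.figure(figsize=(8, 5))
plt.bar(list(scores), [s.mean() for s in scores.values()],
        yerr=[s.std() for s in scores.values()], capsize=5)
plt.ylim(0.8, 1.0)
plt.ylabel('Cross-validated accuracy')
plt.title('Model Comparison (5-fold CV)')
plt.tight_layout()
plt.savefig('figures/model_comparison.png', dpi=150)
plt.show()
```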
07

Submission Requirements

Create a public GitHub repository with the exact name shown below:

Required Repository Name
iris-flower-classification
github.com/<your-username>/iris-flower-classification
Required Project Structure
iris-flower-classification/
├── data/
│   └── iris.csv                    # Original dataset
├── notebooks/
│   └── iris_classification.ipynb   # Main analysis notebook
├── models/
│   └── best_model.pkl              # Saved trained model
├── figures/
│   ├── pairplot.png                # Pair plot visualization
│   ├── correlation.png             # Correlation heatmap
│   ├── boxplots.png                # Box plots by species
│   ├── confusion_matrix.png        # Confusion matrix
│   └── model_comparison.png        # Model accuracy comparison
└── README.md                       # Project documentation
README.md Required Sections
1. Project Header
  • Project title and description
  • Your full name and submission date
  • Course and project number
2. Dataset Description
  • Iris dataset overview
  • Features and target variable
  • Link to original source
3. Installation
  • Required packages (pandas, numpy, sklearn, etc.)
  • How to set up the environment
  • How to run the notebook
4. Results Summary
  • Best model and accuracy achieved
  • Key findings from EDA
  • Model comparison table
5. Visualizations
  • Include key figures inline
  • Brief caption for each
  • Use markdown image syntax
6. How to Use the Model
  • Code example for loading model
  • Sample prediction code
  • Expected input/output format
Do Include
  • All required files in correct folders
  • Well-commented notebook with markdown
  • Saved model file (.pkl or .joblib)
  • All visualization images
  • Comprehensive README
  • requirements.txt file
Do Not Include
  • Jupyter notebook checkpoints (.ipynb_checkpoints/)
  • Python cache files (__pycache__/)
  • Virtual environment folders (venv/, env/)
  • Large unnecessary files
  • Incomplete or broken code
Important: Before submitting, restart your notebook kernel and run all cells from top to bottom to ensure everything executes correctly.
Submit Your Project

Enter your GitHub username - we will verify your repository automatically

08

Grading Rubric

Your project will be graded on the following criteria. Total: 200 points.

Criteria                     | Points | Description
Data Loading and Exploration | 25     | Proper data loading, initial analysis, summary statistics
Exploratory Data Analysis    | 35     | At least 5 quality visualizations with clear insights
Data Preprocessing           | 20     | Proper train-test split, encoding, optional scaling
Model Training               | 35     | At least 3 different models trained correctly
Model Evaluation             | 35     | Accuracy, classification report, confusion matrix, comparison
Model Saving and Prediction  | 25     | Saved model file and working prediction function
Documentation                | 25     | README quality, code comments, notebook markdown
Total                        | 200    |
Grading Levels
Excellent
180-200

Exceeds all requirements with exceptional quality

Good
150-179

Meets all requirements with good quality

Satisfactory
120-149

Meets minimum requirements

Needs Work
< 120

Missing key requirements

Ready to Submit?

Make sure you have completed all requirements and reviewed the grading rubric above.

Submit Your Project
09

Pre-Submission Checklist

Use this checklist to verify you have completed all requirements before submitting your project.

Data and EDA
Model Training
Evaluation
Repository
Final Check: Restart your notebook kernel and run all cells (Kernel → Restart & Run All) to ensure everything executes without errors.