Model Serialization
Training a model takes time — you don't want to redo it every time someone wants a prediction! Think of serialization like freezing a meal you cooked: you do all the hard work once, then save it so you can "reheat" it later instantly. Without saving your model, you'd have to retrain from scratch every time you restart your program. That's like cooking dinner from scratch every night instead of using leftovers!
Serialization
In simple terms: Serialization is like taking a photo of your trained model's brain. You save everything it learned (the patterns, the weights, the settings) to a file. Later, you can load this file and the model "wakes up" exactly as you left it — ready to make predictions without any retraining. The file is typically small (kilobytes to megabytes) and portable.
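As a minimal stdlib illustration of the idea (using pickle on a plain dictionary rather than a real model):

```python
import pickle

# A stand-in for "everything the model learned": plain Python data
learned_state = {"weights": [0.2, -1.3, 0.7], "bias": 0.05}

# Serialize ("take the photo") to bytes
blob = pickle.dumps(learned_state)

# Deserialize ("wake it up") later, possibly in another process
restored = pickle.loads(blob)
print(restored == learned_state)  # the state survives the round trip
```

Real models work the same way: the fitted parameters are just Python objects that get written to disk and read back.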
Using Joblib (Recommended)
# Step 1: Train a model to save
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import joblib
# Load data and train
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Create pipeline with preprocessing
pipeline = Pipeline([
('scaler', StandardScaler()),
('classifier', RandomForestClassifier(n_estimators=100, random_state=42))
])
pipeline.fit(X_train, y_train)
print(f"Model accuracy: {pipeline.score(X_test, y_test):.4f}")
Why use a pipeline? Imagine you trained your model on standardized data (each feature rescaled to mean 0, standard deviation 1), but in production you receive raw values (like 150 or 3000). The predictions would be completely wrong!
- Pipeline = preprocessing + model bundled together
- When you save the pipeline, it remembers HOW to scale (the mean and standard deviation learned from the training data)
- New data goes through the same transformations automatically
- Pro tip: Always save pipelines, never just the model!
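To make the "remembers how to scale" point concrete, here is a toy pure-Python sketch (not the sklearn API) of a scaler that stores its training statistics:

```python
import statistics

class ToyScaler:
    """Toy stand-in for StandardScaler: learns mean/stdev at fit time."""
    def fit(self, values):
        self.mean_ = statistics.mean(values)
        self.std_ = statistics.pstdev(values)
        return self

    def transform(self, values):
        # Reuses the statistics learned at training time -- this is
        # exactly what gets saved when you serialize a pipeline
        return [(v - self.mean_) / self.std_ for v in values]

scaler = ToyScaler().fit([150, 200, 250])  # "training" data
print(scaler.transform([200]))             # new raw value, scaled consistently
```

Serializing the pipeline saves `mean_` and `std_` along with the model, so production data is transformed exactly like the training data was.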
# Step 2: Save the entire pipeline with joblib
model_path = 'models/iris_classifier.joblib'
# Save the model (make sure the directory exists first)
import os
os.makedirs('models', exist_ok=True)
joblib.dump(pipeline, model_path)
print(f"Model saved to {model_path}")
# Check file size
import os
file_size = os.path.getsize(model_path) / 1024 # KB
print(f"File size: {file_size:.2f} KB")
What's happening:
- joblib.dump(pipeline, path): Saves everything to a file — like pressing "Save Game" in a video game
- Why joblib over pickle? Joblib is optimized for objects that contain large NumPy arrays (like fitted sklearn models): it serializes them faster and supports on-disk compression
- File extension: Use .joblib to make it clear what's inside
- Model size: Simple models are tiny (~KB). Deep learning models can be gigabytes!
# Step 3: Load and use the saved model
# In production, this is all you need
loaded_pipeline = joblib.load(model_path)
# Make predictions with loaded model
sample = [[5.1, 3.5, 1.4, 0.2]] # New iris sample
prediction = loaded_pipeline.predict(sample)
probabilities = loaded_pipeline.predict_proba(sample)
print(f"Prediction: {prediction[0]} (class: {['setosa', 'versicolor', 'virginica'][prediction[0]]})")
print(f"Probabilities: {probabilities[0].round(4)}")
Loading in production — this is all you need:
- joblib.load(path): Reads the file and reconstructs your pipeline exactly
- No training data needed: The loaded model already knows everything — just give it new samples!
- predict(): Returns the class (0, 1, 2, etc.)
- predict_proba(): Returns confidence for each class (e.g., 95% setosa, 3% versicolor, 2% virginica)
Speed: Loading takes milliseconds. Prediction takes milliseconds. That's why we preload at startup!
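One common preloading pattern is to cache the loaded model so the file is read only once per process. A stdlib sketch (the get_model helper and the pickle stand-in file are illustrative, not part of joblib):

```python
import functools
import os
import pickle
import tempfile

# Write a tiny stand-in "model" to disk (in real code this is your .joblib file)
path = os.path.join(tempfile.gettempdir(), "demo_model.pkl")
with open(path, "wb") as f:
    pickle.dump({"kind": "demo"}, f)

@functools.lru_cache(maxsize=1)
def get_model(model_path):
    """Load the model once; later calls return the cached object instantly."""
    with open(model_path, "rb") as f:
        return pickle.load(f)

m1 = get_model(path)
m2 = get_model(path)
print(m1 is m2)  # same object: the file was read only once
```

In a web service you would call get_model() at startup so the first request doesn't pay the loading cost.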
Saving Model Metadata
Always save metadata about your model for reproducibility and debugging:
# Step 4: Save model with metadata
import json
from datetime import datetime
import sklearn
metadata = {
'model_name': 'iris_classifier',
'version': '1.0.0',
'created_at': datetime.now().isoformat(),
'sklearn_version': sklearn.__version__,
'training_accuracy': float(pipeline.score(X_train, y_train)),
'test_accuracy': float(pipeline.score(X_test, y_test)),
'features': ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'],
'classes': ['setosa', 'versicolor', 'virginica'],
'hyperparameters': pipeline.get_params()
}
# Save metadata alongside model
with open('models/iris_classifier_metadata.json', 'w') as f:
json.dump(metadata, f, indent=2, default=str)
print("Metadata saved successfully")
Why save metadata? Future you will thank you!
- version: "Is this the model from last week or last month?"
- sklearn_version: Critical! Models may not load with different sklearn versions
- test_accuracy: "How good was this model supposed to be?"
- features: "Wait, what order were the features in?"
- created_at: "When did we train this?"
Real story: Many production bugs happen because the feature order changed between training and prediction. Metadata prevents these nightmares!
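A small sketch of how recorded feature order can be enforced at prediction time (row_from_payload is a hypothetical helper, not part of sklearn):

```python
# Feature order recorded at training time (from the metadata file)
TRAINED_FEATURES = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

def row_from_payload(payload, feature_order=TRAINED_FEATURES):
    """Build a feature row in the exact order the model was trained on."""
    missing = [f for f in feature_order if f not in payload]
    if missing:
        raise ValueError(f"Missing features: {missing}")
    return [payload[f] for f in feature_order]

# Even if the client sends keys in a different order, the row comes out right
payload = {"petal_width": 0.2, "sepal_length": 5.1, "petal_length": 1.4, "sepal_width": 3.5}
print(row_from_payload(payload))  # [5.1, 3.5, 1.4, 0.2]
```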
Practice Questions
Task: Save a model with gzip compression and compare file sizes.
Show Solution
import joblib
import os
# Save without compression
joblib.dump(pipeline, 'model_uncompressed.joblib')
# Save with compression (compress accepts 0-9; higher = smaller but slower)
joblib.dump(pipeline, 'model_compressed.joblib', compress=3)
# Compare sizes
size_uncompressed = os.path.getsize('model_uncompressed.joblib') / 1024
size_compressed = os.path.getsize('model_compressed.joblib') / 1024
print(f"Uncompressed: {size_uncompressed:.2f} KB")
print(f"Compressed: {size_compressed:.2f} KB")
print(f"Reduction: {(1 - size_compressed/size_uncompressed)*100:.1f}%")
# Load compressed model
loaded = joblib.load('model_compressed.joblib')
print(f"Model works: {loaded.score(X_test, y_test):.4f}")
Task: Train a simple model, save it with joblib, load it back, and verify it still works.
Show Solution
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# Train model
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
print(f"Original model accuracy: {model.score(X_test, y_test):.4f}")
# Save
joblib.dump(model, 'my_model.joblib')
print("Model saved!")
# Load
loaded_model = joblib.load('my_model.joblib')
print(f"Loaded model accuracy: {loaded_model.score(X_test, y_test):.4f}")
# Verify they're the same!
assert model.score(X_test, y_test) == loaded_model.score(X_test, y_test)
print("✓ Model works identically after loading!")
Task: Create a pipeline with StandardScaler + model, save it, and verify preprocessing works correctly after loading.
Show Solution
import joblib
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
# Create pipeline
pipeline = Pipeline([
('scaler', StandardScaler()),
('classifier', LogisticRegression(random_state=42))
])
# Train
pipeline.fit(X_train, y_train)
# Save
joblib.dump(pipeline, 'full_pipeline.joblib')
# Load and test with RAW (unscaled) data
loaded_pipeline = joblib.load('full_pipeline.joblib')
# This should work - pipeline handles scaling internally!
raw_sample = [[5.1, 3.5, 1.4, 0.2]] # Raw values, not scaled
prediction = loaded_pipeline.predict(raw_sample)
print(f"Prediction from raw input: {prediction}")
# Verify scaler has correct statistics
scaler = loaded_pipeline.named_steps['scaler']
print(f"Scaler mean (from training): {scaler.mean_.round(2)}")
print("✓ Pipeline remembers how to scale!")
Task: Save a model along with metadata (version, date, metrics, features) in a JSON file.
Show Solution
import joblib
import json
import os
import platform
from datetime import datetime
import sklearn
# Train and save model
os.makedirs('models', exist_ok=True)
pipeline.fit(X_train, y_train)
joblib.dump(pipeline, 'models/model_v1.joblib')
# Create metadata
metadata = {
'model_name': 'iris_classifier',
'version': '1.0.0',
'created_at': datetime.now().isoformat(),
'sklearn_version': sklearn.__version__,
'python_version': platform.python_version(),
'metrics': {
'train_accuracy': float(pipeline.score(X_train, y_train)),
'test_accuracy': float(pipeline.score(X_test, y_test))
},
'features': ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'],
'classes': ['setosa', 'versicolor', 'virginica'],
'notes': 'First production model'
}
# Save metadata
with open('models/model_v1_metadata.json', 'w') as f:
json.dump(metadata, f, indent=2)
print("Model and metadata saved!")
print(json.dumps(metadata, indent=2))
Task: Write a safe model loader that checks sklearn version compatibility before loading.
Show Solution
import joblib
import json
import sklearn
def safe_load_model(model_path, metadata_path):
"""Load model with version compatibility check."""
# Load metadata first
with open(metadata_path, 'r') as f:
metadata = json.load(f)
saved_version = metadata.get('sklearn_version', 'unknown')
current_version = sklearn.__version__
print(f"Model was saved with sklearn {saved_version}")
print(f"Current sklearn version: {current_version}")
# Check major.minor version match
saved_major_minor = '.'.join(saved_version.split('.')[:2])
current_major_minor = '.'.join(current_version.split('.')[:2])
if saved_major_minor != current_major_minor:
print(f"WARNING: Version mismatch!")
print(f" Model: {saved_major_minor}.x, Current: {current_major_minor}.x")
print(" This may cause errors. Consider retraining.")
# Ask for confirmation
proceed = input("Load anyway? (y/n): ")
if proceed.lower() != 'y':
raise ValueError("Loading cancelled due to version mismatch")
# Load the model
try:
model = joblib.load(model_path)
print("✓ Model loaded successfully!")
return model, metadata
except Exception as e:
print(f"✗ Failed to load: {e}")
raise
# Usage
model, meta = safe_load_model('models/model_v1.joblib', 'models/model_v1_metadata.json')
Task: Create a simple model registry that tracks multiple model versions and can load any version on demand.
Show Solution
import joblib
import json
from pathlib import Path
from datetime import datetime
class ModelRegistry:
"""Simple local model registry."""
def __init__(self, registry_dir='model_registry'):
self.registry_dir = Path(registry_dir)
self.registry_dir.mkdir(exist_ok=True)
self.registry_file = self.registry_dir / 'registry.json'
self.registry = self._load_registry()
def _load_registry(self):
if self.registry_file.exists():
with open(self.registry_file, 'r') as f:
return json.load(f)
return {'models': {}}
def _save_registry(self):
with open(self.registry_file, 'w') as f:
json.dump(self.registry, f, indent=2)
def register(self, model, name, version, metrics=None):
"""Save and register a model version."""
model_path = self.registry_dir / f"{name}_{version}.joblib"
joblib.dump(model, model_path)
if name not in self.registry['models']:
self.registry['models'][name] = {}
self.registry['models'][name][version] = {
'path': str(model_path),
'created_at': datetime.now().isoformat(),
'metrics': metrics or {}
}
self._save_registry()
print(f"Registered {name} v{version}")
def load(self, name, version='latest'):
"""Load a specific model version."""
if name not in self.registry['models']:
raise ValueError(f"Model '{name}' not found")
versions = self.registry['models'][name]
if version == 'latest':
version = max(versions.keys())  # note: lexicographic max ('v10' sorts before 'v2'); use real version parsing at scale
if version not in versions:
raise ValueError(f"Version '{version}' not found")
path = versions[version]['path']
return joblib.load(path)
def list_versions(self, name):
"""List all versions of a model."""
return list(self.registry['models'].get(name, {}).keys())
# Usage
registry = ModelRegistry()
registry.register(pipeline, 'iris_classifier', 'v1.0', {'accuracy': 0.95})
registry.register(pipeline, 'iris_classifier', 'v1.1', {'accuracy': 0.97})
print(registry.list_versions('iris_classifier'))
model = registry.load('iris_classifier', 'latest')
Flask REST API
Your model is saved — but how do people actually USE it? They can't run Python! You need an API (Application Programming Interface) — think of it as a waiter at a restaurant. Customers (other apps) don't go into the kitchen (your Python code). They tell the waiter what they want, the waiter brings it to the kitchen, and returns with the result. Flask is a simple, beginner-friendly framework to create this "waiter" for your model.
Basic Flask Prediction API
# app.py - Basic Flask ML API
from flask import Flask, request, jsonify
import joblib
import numpy as np
app = Flask(__name__)
# Load model at startup (not on every request!)
model = joblib.load('models/iris_classifier.joblib')
@app.route('/health', methods=['GET'])
def health():
"""Health check endpoint."""
return jsonify({'status': 'healthy', 'model': 'iris_classifier'})
@app.route('/predict', methods=['POST'])
def predict():
"""Make predictions from JSON input."""
try:
# Get JSON data
data = request.get_json()
# Extract features (expecting list of lists)
features = np.array(data['features'])
# Make prediction
predictions = model.predict(features)
probabilities = model.predict_proba(features)
# Format response
results = []
class_names = ['setosa', 'versicolor', 'virginica']
for pred, probs in zip(predictions, probabilities):
results.append({
'prediction': int(pred),
'class_name': class_names[pred],
'probabilities': {name: float(p) for name, p in zip(class_names, probs)}
})
return jsonify({'predictions': results})
except Exception as e:
return jsonify({'error': str(e)}), 400
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)  # debug=True is for development only
Anatomy of a Flask ML API:
- Load model ONCE at startup: Loading takes time. Do it once, not per request (imagine reheating your entire kitchen for each order!)
- @app.route('/health'): A simple endpoint to check "is the server alive?" — like pinging "are you there?"
- @app.route('/predict', methods=['POST']): POST = sending data TO the server (the customer's order)
- request.get_json(): Parse the customer's order (JSON data they sent)
- jsonify(): Convert Python dict to JSON response (the meal we send back)
- try/except: Always handle errors gracefully — don't crash the whole kitchen!
To run: python app.py then visit http://localhost:5000/health
Testing the API
# test_api.py - Test the Flask API
import requests
import json
# API endpoint
url = 'http://localhost:5000/predict'
# Test data
payload = {
'features': [
[5.1, 3.5, 1.4, 0.2], # Should be setosa
[6.2, 2.9, 4.3, 1.3], # Should be versicolor
[7.2, 3.0, 5.8, 1.6] # Should be virginica
]
}
# Make request
response = requests.post(
url,
json=payload,
headers={'Content-Type': 'application/json'}
)
# Print results
print(f"Status: {response.status_code}")
print(json.dumps(response.json(), indent=2))
How to test your API:
- requests library: Python's way to "be the customer" and call APIs
- requests.post(url, json=...): Send a POST request with JSON data
- Batch predictions: Send multiple samples at once — more efficient than one-by-one!
- response.status_code: 200 = success, 400 = bad request (your fault), 500 = server error (API's fault)
Pro tip: Test with known samples (like training data) first to verify predictions are correct!
# Run with curl
curl -X POST http://localhost:5000/predict \
-H "Content-Type: application/json" \
-d '{"features": [[5.1, 3.5, 1.4, 0.2]]}'
Adding Input Validation
# Enhanced prediction endpoint with validation
from flask import Flask, request, jsonify
import joblib
import numpy as np
app = Flask(__name__)
model = joblib.load('models/iris_classifier.joblib')
EXPECTED_FEATURES = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
@app.route('/predict', methods=['POST'])
def predict_validated():
"""Make predictions with input validation."""
try:
data = request.get_json()
if 'features' not in data:
return jsonify({'error': 'Missing "features" in request body'}), 400
features = np.array(data['features'])
# Validate shape
if features.ndim == 1:
features = features.reshape(1, -1)
if features.shape[1] != 4:
return jsonify({
'error': f'Expected 4 features, got {features.shape[1]}',
'expected_features': EXPECTED_FEATURES
}), 400
# Validate values
if np.any(np.isnan(features)) or np.any(np.isinf(features)):
return jsonify({'error': 'Features contain NaN or Inf values'}), 400
# Make prediction
predictions = model.predict(features)
probabilities = model.predict_proba(features)
return jsonify({
'predictions': predictions.tolist(),
'probabilities': probabilities.tolist()
})
except ValueError as e:
return jsonify({'error': f'Invalid input format: {str(e)}'}), 400
except Exception as e:
return jsonify({'error': f'Server error: {str(e)}'}), 500
Why validation matters — users will send you ANYTHING:
- Missing 'features' key: Someone sends {"data": [...]} instead of {"features": [...]}
- Wrong number of features: Model expects 4 features, someone sends 3 or 5
- NaN/Inf values: Broken data that makes predictions meaningless
- Helpful error messages: Don't just say "error" — tell them WHAT went wrong and HOW to fix it
- HTTP status codes: 400 = client's fault (bad data), 500 = server's fault (bug in your code)
Golden rule: Never trust user input. Validate everything!
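The same checks can live in a plain, framework-agnostic helper that is easy to unit test (validate_features is a hypothetical function mirroring the endpoint above):

```python
import math

def validate_features(data, n_features=4):
    """Return (ok, error_message) for a request body, mirroring the checks above."""
    if not isinstance(data, dict) or "features" not in data:
        return False, 'Missing "features" in request body'
    rows = data["features"]
    if not isinstance(rows, list) or not rows:
        return False, '"features" must be a non-empty list of rows'
    for row in rows:
        if len(row) != n_features:
            return False, f"Expected {n_features} features, got {len(row)}"
        if any(not isinstance(v, (int, float)) or math.isnan(v) or math.isinf(v) for v in row):
            return False, "Features must be finite numbers"
    return True, None

print(validate_features({"features": [[5.1, 3.5, 1.4, 0.2]]}))  # (True, None)
print(validate_features({"features": [[5.1, 3.5]]}))            # (False, 'Expected 4 features, got 2')
```

Keeping validation in its own function means the endpoint stays short and the rules can be tested without spinning up a server.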
Practice Questions
Task: Add request logging that captures timestamp, request body, prediction, and latency.
Show Solution
import logging
from datetime import datetime
import time
# Configure logging
logging.basicConfig(
filename='predictions.log',
level=logging.INFO,
format='%(asctime)s - %(message)s'
)
@app.route('/predict', methods=['POST'])
def predict_with_logging():
start_time = time.time()
data = request.get_json()
try:
features = np.array(data['features'])
predictions = model.predict(features)
# Calculate latency
latency_ms = (time.time() - start_time) * 1000
# Log request
logging.info(
f"REQUEST - samples: {len(features)}, "
f"predictions: {predictions.tolist()}, "
f"latency_ms: {latency_ms:.2f}"
)
return jsonify({
'predictions': predictions.tolist(),
'latency_ms': latency_ms
})
except Exception as e:
logging.error(f"ERROR - {str(e)}")
return jsonify({'error': str(e)}), 400
Task: Create a minimal Flask API with just a health check endpoint. Run and test it.
Show Solution
# hello_api.py
from flask import Flask, jsonify
app = Flask(__name__)
@app.route('/')
def home():
return "Hello! The API is running."
@app.route('/health')
def health():
return jsonify({
'status': 'healthy',
'message': 'API is up and running!'
})
if __name__ == '__main__':
app.run(debug=True, port=5000)
# Run: python hello_api.py
# Test:
# curl http://localhost:5000/
# curl http://localhost:5000/health
Task: Write a Python script that tests your ML API by sending a POST request with sample data.
Show Solution
# test_my_api.py
import requests
import json
# API URL
BASE_URL = 'http://localhost:5000'
# Test 1: Health check
print("Testing health endpoint...")
response = requests.get(f'{BASE_URL}/health')
print(f"Status: {response.status_code}")
print(f"Response: {response.json()}")
# Test 2: Single prediction
print("\nTesting single prediction...")
payload = {
'features': [[5.1, 3.5, 1.4, 0.2]] # One iris sample
}
response = requests.post(
f'{BASE_URL}/predict',
json=payload
)
print(f"Status: {response.status_code}")
print(f"Prediction: {json.dumps(response.json(), indent=2)}")
# Test 3: Batch prediction
print("\nTesting batch prediction...")
payload = {
'features': [
[5.1, 3.5, 1.4, 0.2],
[6.2, 2.9, 4.3, 1.3],
[7.2, 3.0, 5.8, 1.6]
]
}
response = requests.post(f'{BASE_URL}/predict', json=payload)
print(f"Predictions: {response.json()}")
Task: Add comprehensive input validation to your Flask predict endpoint with helpful error messages.
Show Solution
from flask import Flask, request, jsonify
import numpy as np
import joblib
app = Flask(__name__)
model = joblib.load('model.joblib')
EXPECTED_FEATURES = 4
FEATURE_NAMES = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
@app.route('/predict', methods=['POST'])
def predict():
# Check content type
if not request.is_json:
return jsonify({'error': 'Content-Type must be application/json'}), 400
data = request.get_json()
# Check for 'features' key
if 'features' not in data:
return jsonify({
'error': 'Missing "features" key',
'expected_format': {'features': [[1.0, 2.0, 3.0, 4.0]]}
}), 400
try:
features = np.array(data['features'])
except Exception:
return jsonify({'error': 'Features must be a list of lists of numbers'}), 400
# Check dimensions
if features.ndim == 1:
features = features.reshape(1, -1)
if features.shape[1] != EXPECTED_FEATURES:
return jsonify({
'error': f'Expected {EXPECTED_FEATURES} features, got {features.shape[1]}',
'feature_names': FEATURE_NAMES
}), 400
# Check for NaN/Inf
if np.any(np.isnan(features)) or np.any(np.isinf(features)):
return jsonify({'error': 'Features cannot contain NaN or Inf values'}), 400
# Make prediction
predictions = model.predict(features)
return jsonify({'predictions': predictions.tolist()})
Task: Add simple rate limiting to prevent API abuse (max 10 requests per minute per IP).
Show Solution
from flask import Flask, request, jsonify
from collections import defaultdict
import time
import numpy as np
import joblib
app = Flask(__name__)
model = joblib.load('models/iris_classifier.joblib')
# Simple in-memory rate limiter
request_counts = defaultdict(list)
RATE_LIMIT = 10 # requests per minute
def check_rate_limit(ip):
"""Check if IP has exceeded rate limit."""
now = time.time()
minute_ago = now - 60
# Clean old requests
request_counts[ip] = [t for t in request_counts[ip] if t > minute_ago]
if len(request_counts[ip]) >= RATE_LIMIT:
return False, len(request_counts[ip])
# Record this request
request_counts[ip].append(now)
return True, len(request_counts[ip])
@app.before_request
def rate_limit_check():
"""Check rate limit before every request."""
ip = request.remote_addr
allowed, count = check_rate_limit(ip)
if not allowed:
return jsonify({
'error': 'Rate limit exceeded',
'limit': f'{RATE_LIMIT} requests per minute',
'retry_after': '60 seconds'
}), 429
@app.route('/predict', methods=['POST'])
def predict():
# Your normal prediction logic
data = request.get_json()
features = np.array(data['features'])
predictions = model.predict(features)
return jsonify({'predictions': predictions.tolist()})
# Note: For production, use Flask-Limiter or Redis-based solutions
FastAPI Service
FastAPI is like Flask's younger, smarter sibling: it does everything Flask does, with superpowers. The biggest wins: (1) it automatically validates your data (catching bad inputs before they reach your model), (2) it generates interactive documentation pages automatically (no extra work!), and (3) it supports async request handling, so it can serve more concurrent requests. If you're building something serious, FastAPI is the modern choice.
FastAPI with Pydantic Models
# main.py - FastAPI ML Service
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, field_validator
from typing import List
import joblib
import numpy as np
app = FastAPI(
title="Iris Classification API",
description="ML model serving with FastAPI",
version="1.0.0"
)
# Load model at startup
model = joblib.load('models/iris_classifier.joblib')
# Pydantic models for request/response validation
class IrisFeatures(BaseModel):
sepal_length: float = Field(..., ge=0, le=10, description="Sepal length in cm")
sepal_width: float = Field(..., ge=0, le=10, description="Sepal width in cm")
petal_length: float = Field(..., ge=0, le=10, description="Petal length in cm")
petal_width: float = Field(..., ge=0, le=10, description="Petal width in cm")
@field_validator('*')
@classmethod
def no_nan(cls, v):
if v is None or np.isnan(v) or np.isinf(v):
raise ValueError('Values cannot be NaN or Inf')
return v
def to_array(self):
return [self.sepal_length, self.sepal_width, self.petal_length, self.petal_width]
class PredictionRequest(BaseModel):
samples: List[IrisFeatures]
class PredictionResult(BaseModel):
prediction: int
class_name: str
confidence: float
probabilities: dict
class PredictionResponse(BaseModel):
predictions: List[PredictionResult]
Pydantic = your data bouncer — it checks everyone at the door!
- BaseModel: Define what your data SHOULD look like
- Field(..., ge=0, le=10): Value must be ≥ 0 AND ≤ 10 (constraints!)
- @field_validator (Pydantic v2's replacement for @validator): Custom checks (like "no NaN values allowed")
- Automatic rejection: If data doesn't match, FastAPI returns a nice error automatically!
Why this is amazing: In Flask, you write validation code manually. In FastAPI, you define the rules ONCE and validation happens automatically. Less code, fewer bugs!
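To see what Pydantic is doing for you, here is a hand-rolled stdlib approximation of Field(..., ge=0, le=10) using a dataclass (illustrative only; Pydantic replaces all of this):

```python
from dataclasses import dataclass, fields

@dataclass
class IrisFeatures:
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float

    def __post_init__(self):
        # Hand-rolled version of Field(..., ge=0, le=10) on every field
        for f in fields(self):
            value = getattr(self, f.name)
            if not 0 <= value <= 10:
                raise ValueError(f"{f.name}={value} out of range [0, 10]")

IrisFeatures(5.1, 3.5, 1.4, 0.2)       # accepted
try:
    IrisFeatures(100, 3.5, 1.4, 0.2)   # rejected, like FastAPI's automatic 422
except ValueError as e:
    print(e)
```

With Pydantic you never write this boilerplate: the constraints live in the type annotations and the framework enforces them on every request.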
# Define endpoints
@app.get("/health")
async def health_check():
"""Check if the service is running."""
return {"status": "healthy", "model_loaded": model is not None}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
"""Make predictions for iris flower classification."""
try:
# Convert to numpy array
features = np.array([sample.to_array() for sample in request.samples])
# Get predictions
predictions = model.predict(features)
probabilities = model.predict_proba(features)
# Format results
class_names = ['setosa', 'versicolor', 'virginica']
results = []
for pred, probs in zip(predictions, probabilities):
results.append(PredictionResult(
prediction=int(pred),
class_name=class_names[pred],
confidence=float(max(probs)),
probabilities={name: float(p) for name, p in zip(class_names, probs)}
))
return PredictionResponse(predictions=results)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
FastAPI endpoint breakdown:
- async def: Async = can handle many requests simultaneously without waiting (like a waiter serving multiple tables)
- response_model=PredictionResponse: "I promise to return data matching this format" — FastAPI validates your OUTPUT too!
- HTTPException: Proper way to return errors with status codes
- request: PredictionRequest: FastAPI automatically parses AND validates the incoming JSON
Magic: Just by defining the function signature, FastAPI knows what data to expect and how to validate it!
# Startup and shutdown events (note: on_event is deprecated in newer FastAPI in favor of lifespan handlers)
@app.on_event("startup")
async def startup_event():
"""Run on application startup."""
print("Loading model...")
# Model is already loaded globally
print("Model ready!")
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup on shutdown."""
print("Shutting down...")
# Run with: uvicorn main:app --reload --host 0.0.0.0 --port 8000
Lifecycle events — setup and cleanup:
- startup event: Run code BEFORE the first request (load models, connect to databases)
- shutdown event: Clean up AFTER the last request (close connections, save logs)
- uvicorn: The server that runs FastAPI apps (like how a car engine runs the car)
- --reload: Auto-restart when code changes (development mode — don't use in production!)
- --host 0.0.0.0: Accept connections from any IP (needed for Docker)
To run: uvicorn main:app --reload then visit http://localhost:8000/docs for free documentation!
Automatic API Documentation
FastAPI automatically generates interactive documentation:
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc
- OpenAPI JSON: http://localhost:8000/openapi.json
Practice Questions
Task: Add API key authentication using FastAPI's dependency injection.
Show Solution
from fastapi import FastAPI, HTTPException, Depends, Security
from fastapi.security import APIKeyHeader
import os
app = FastAPI()
# API key configuration
API_KEY = os.getenv("API_KEY", "your-secret-key")
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def verify_api_key(api_key: str = Security(api_key_header)):
"""Verify the API key."""
if api_key is None:
raise HTTPException(status_code=401, detail="API key missing")
if api_key != API_KEY:
raise HTTPException(status_code=403, detail="Invalid API key")
return api_key
@app.post("/predict", dependencies=[Depends(verify_api_key)])
async def predict(request: PredictionRequest):
"""Protected endpoint - requires valid API key."""
# ... prediction logic
pass
# Test with:
# curl -X POST http://localhost:8000/predict \
# -H "X-API-Key: your-secret-key" \
# -H "Content-Type: application/json" \
# -d '{"samples": [{"sepal_length": 5.1, ...}]}'
Task: Create a minimal FastAPI app and explore the automatic documentation.
Show Solution
# main.py
from fastapi import FastAPI
app = FastAPI(
title="My First FastAPI",
description="A simple API to learn FastAPI",
version="1.0.0"
)
@app.get("/")
async def home():
return {"message": "Hello from FastAPI!"}
@app.get("/health")
async def health():
return {"status": "healthy"}
@app.get("/greet/{name}")
async def greet(name: str):
return {"message": f"Hello, {name}!"}
# Run with: uvicorn main:app --reload
# Visit: http://localhost:8000/docs ← Free Swagger UI!
# Also try: http://localhost:8000/redoc
Task: Create a Pydantic model for iris prediction input and see how validation works.
Show Solution
from fastapi import FastAPI
from pydantic import BaseModel, Field
app = FastAPI()
# Define input schema with validation
class IrisSample(BaseModel):
sepal_length: float = Field(..., ge=0, le=10, description="Sepal length in cm")
sepal_width: float = Field(..., ge=0, le=10, description="Sepal width in cm")
petal_length: float = Field(..., ge=0, le=10, description="Petal length in cm")
petal_width: float = Field(..., ge=0, le=10, description="Petal width in cm")
class Config:
json_schema_extra = {
"example": {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
}
@app.post("/predict")
async def predict(sample: IrisSample):
# Validation happens automatically!
# If values are out of range, FastAPI returns 422 error
return {
"received": sample.model_dump(),
"message": "Validation passed!"
}
# Test with invalid data to see automatic error:
# curl -X POST localhost:8000/predict -H "Content-Type: application/json" \
# -d '{"sepal_length": 100}' ← Will fail validation!
Task: Create a Pydantic response model and use it with your prediction endpoint.
Show Solution
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Dict
import time
app = FastAPI()
# Response model
class PredictionResult(BaseModel):
prediction: int
class_name: str
confidence: float
probabilities: Dict[str, float]
class PredictionResponse(BaseModel):
predictions: List[PredictionResult]
latency_ms: float
# The endpoint (PredictionRequest is the input model defined earlier in this chapter)
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
start = time.time()
# ... make predictions ...
results = [
PredictionResult(
prediction=0,
class_name="setosa",
confidence=0.95,
probabilities={"setosa": 0.95, "versicolor": 0.03, "virginica": 0.02}
)
]
return PredictionResponse(
predictions=results,
latency_ms=(time.time() - start) * 1000
)
# FastAPI validates output matches PredictionResponse
# Documentation shows exact response format
Task: Use FastAPI lifespan events to load the model on startup and verify it's ready.
Show Solution
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
import joblib
# Global model variable
ml_model = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load model on startup, cleanup on shutdown."""
global ml_model
# STARTUP
print("Loading model...")
try:
ml_model = joblib.load('models/iris_classifier.joblib')
print("✓ Model loaded successfully!")
except Exception as e:
print(f"✗ Failed to load model: {e}")
raise
yield # App runs here
# SHUTDOWN
print("Shutting down...")
ml_model = None
app = FastAPI(lifespan=lifespan)
@app.get("/health")
async def health():
return {
"status": "healthy",
"model_loaded": ml_model is not None
}
@app.post("/predict")
async def predict(request: PredictionRequest):
if ml_model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
# Use ml_model for predictions
predictions = ml_model.predict(...)
return {"predictions": predictions.tolist()}
Docker Containerization
Ever heard "it works on my machine!" but fails everywhere else? Docker solves this forever. Think of Docker as a shipping container for your code. Just like shipping containers work with any ship or truck, Docker containers run the same way on any computer. You package your app, Python, all libraries, and the model into ONE box. Someone else can run that box without installing ANYTHING — they just need Docker. No more "install Python 3.10, then pip install 50 packages..."!
Project Structure
ml-api/
├── app/
│ ├── __init__.py
│ ├── main.py # FastAPI application
│ └── models/
│ └── iris_classifier.joblib
├── requirements.txt
├── Dockerfile
├── docker-compose.yml
└── .dockerignore
Dockerfile
# Dockerfile for ML API
FROM python:3.10-slim
# Set working directory
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first (for layer caching)
COPY requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY app/ ./app/
# Set environment variables
ENV PYTHONPATH=/app
ENV MODEL_PATH=/app/app/models/iris_classifier.joblib
# Expose port
EXPOSE 8000
# Run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
Dockerfile explained — it's like a recipe for your container:
- FROM python:3.10-slim: Start with a lightweight Python image (the base ingredients)
- WORKDIR /app: Create a folder inside the container and work from there
- COPY requirements.txt FIRST: This is a trick! Docker caches each step. If requirements don't change, it skips pip install on rebuilds (saves minutes!)
- RUN pip install: Install Python packages inside the container
- COPY app/: Copy your code into the container
- EXPOSE 8000: Document which port the app uses (doesn't actually open it)
- CMD: The command that runs when you start the container
Why slim? Full Python image is ~1GB. Slim is ~150MB. Your container downloads faster!
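The project structure above lists a .dockerignore, which is worth filling in: it keeps the build context small (faster builds) and keeps secrets and junk out of the image. This is a typical starting point, not an exhaustive list; adjust it to your repo:

```
# .dockerignore — excluded from the Docker build context
__pycache__/
*.pyc
.git/
.venv/
venv/
logs/
.env
tests/
*.md
```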
requirements.txt
# requirements.txt
fastapi==0.109.0
uvicorn[standard]==0.27.0
scikit-learn==1.4.0
joblib==1.3.2
numpy==1.26.3
pydantic==2.5.3
Pin your versions or face chaos:
- Use ==, not >=: scikit-learn==1.4.0, not scikit-learn>=1.4
- Why? Your model was trained with sklearn 1.4.0. If someone installs 1.5.0, it might not load!
- How to get versions: Run pip freeze > requirements.txt in your working environment
- Common mistake: Forgetting to include numpy, pandas (they're dependencies, but sometimes versions matter)
Critical: The sklearn version MUST match what you used to train. Different versions = potential "unpickling error"!
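One way to catch such a mismatch early is to compare installed package versions against the pins recorded at training time, before loading the model. This is a sketch using only the standard library; check_pinned_versions is a hypothetical helper name, not part of any framework:

```python
from importlib.metadata import PackageNotFoundError, version

def check_pinned_versions(pins):
    """Return a list of mismatch messages; an empty list means all pins are satisfied."""
    problems = []
    for package, expected in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            problems.append(f"{package}: not installed (expected {expected})")
            continue
        if installed != expected:
            problems.append(f"{package}: installed {installed}, expected {expected}")
    return problems

# Example: refuse to start if the environment drifted from training
# problems = check_pinned_versions({"scikit-learn": "1.4.0", "joblib": "1.3.2"})
# if problems:
#     raise RuntimeError("Environment mismatch: " + "; ".join(problems))
```

Calling this at API startup turns a cryptic unpickling error into a clear, actionable message.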
Build and Run
# Build the Docker image
docker build -t ml-api:1.0 .
# Run the container
docker run -d -p 8000:8000 --name ml-api ml-api:1.0
# Check logs
docker logs ml-api
# Test the API
curl http://localhost:8000/health
# Stop and remove
docker stop ml-api
docker rm ml-api
Docker Compose for Multi-Container Setup
# docker-compose.yml
version: '3.8'
services:
ml-api:
build: .
ports:
- "8000:8000"
environment:
- MODEL_PATH=/app/app/models/iris_classifier.joblib
- LOG_LEVEL=INFO
volumes:
- ./logs:/app/logs
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
# Optional: Add monitoring
prometheus:
image: prom/prometheus
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
depends_on:
- ml-api
Docker Compose — when you need multiple containers working together:
- services: Define multiple containers (API, database, monitoring, etc.)
- ports: Map container port to your machine's port (8000:8000 means access via localhost:8000)
- environment: Pass configuration variables (like settings, but from outside)
- volumes: Persist data outside the container (logs, databases — containers are disposable!)
- healthcheck: Automatically restart if the service becomes unhealthy
- depends_on: Start services in order (database before API)
One command to rule them all: docker-compose up starts everything!
# Start all services
docker-compose up -d
# View logs
docker-compose logs -f ml-api
# Stop all services
docker-compose down
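The compose file above mounts a ./prometheus.yml that hasn't been shown. A minimal scrape config might look like this (the job name is an assumption; the target must match your compose service name and port, and the API must expose a /metrics endpoint, as covered in the monitoring section below):

```yaml
# prometheus.yml — minimal scrape configuration (assumed example)
global:
  scrape_interval: 15s        # how often Prometheus pulls metrics

scrape_configs:
  - job_name: ml-api
    metrics_path: /metrics    # the endpoint the API exposes
    static_configs:
      - targets: ['ml-api:8000']  # compose service name : container port
```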
Practice Questions
Task: Create a multi-stage Dockerfile that first builds dependencies in a larger image, then copies only what's needed to a slim final image.
Show Solution
# Multi-stage Dockerfile
# Stage 1: Build
FROM python:3.10 AS builder
WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y gcc
# Install Python packages
COPY requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt
# Stage 2: Production
FROM python:3.10-slim
WORKDIR /app
# Copy installed packages from builder
COPY --from=builder /root/.local /root/.local
# Make sure scripts are in PATH
ENV PATH=/root/.local/bin:$PATH
# Copy only application code
COPY app/ ./app/
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
# This produces a smaller image since build tools aren't included
Task: Create a simple Dockerfile for your ML API and understand each line.
Show Solution
# Dockerfile
# Line 1: Start from a Python image (the "base")
FROM python:3.10-slim
# Line 2: Create a folder in the container
WORKDIR /app
# Line 3: Copy requirements file
COPY requirements.txt .
# Line 4: Install Python packages
RUN pip install --no-cache-dir -r requirements.txt
# Line 5: Copy your application code
COPY . .
# Line 6: Tell Docker what port your app uses
EXPOSE 8000
# Line 7: Command to run when container starts
CMD ["python", "app.py"]
# Build: docker build -t my-ml-api .
# Run: docker run -p 8000:8000 my-ml-api
# Test: curl http://localhost:8000/health
Task: Practice the essential Docker commands: build, run, stop, logs, and cleanup.
Show Solution
# BUILD: Create an image from Dockerfile
docker build -t ml-api:v1 .
# -t = tag/name the image
# . = use Dockerfile in current directory
# LIST IMAGES: See what images exist
docker images
# RUN: Start a container from the image
docker run -d -p 8000:8000 --name my-api ml-api:v1
# -d = run in background (detached)
# -p = map port (host:container)
# --name = name the container
# CHECK RUNNING CONTAINERS
docker ps
# VIEW LOGS
docker logs my-api
docker logs -f my-api # -f = follow (live)
# STOP the container
docker stop my-api
# REMOVE the container
docker rm my-api
# REMOVE the image
docker rmi ml-api:v1
# PRO TIP: One-liner to stop, remove, rebuild
docker stop my-api; docker rm my-api; docker build -t ml-api:v1 . && docker run -d -p 8000:8000 --name my-api ml-api:v1
Task: Pass configuration via environment variables instead of hardcoding.
Show Solution
# Dockerfile with environment variables
FROM python:3.10-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# Set default environment variables
ENV MODEL_PATH=/app/models/model.joblib
ENV LOG_LEVEL=INFO
ENV PORT=8000
EXPOSE ${PORT}
CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT}"]
# In your Python app, read environment variables
import os
MODEL_PATH = os.getenv('MODEL_PATH', 'models/model.joblib')
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
PORT = int(os.getenv('PORT', 8000))
# Override env vars when running
docker run -d \
-p 8080:8080 \
-e MODEL_PATH=/app/models/v2.joblib \
-e LOG_LEVEL=DEBUG \
-e PORT=8080 \
ml-api:v1
Task: Create a docker-compose.yml with health checks and automatic restart policies.
Show Solution
# docker-compose.yml
version: '3.8'
services:
ml-api:
build: .
ports:
- "8000:8000"
environment:
- MODEL_PATH=/app/models/model.joblib
- LOG_LEVEL=INFO
volumes:
- ./logs:/app/logs # Persist logs outside container
restart: unless-stopped # Auto-restart unless manually stopped
# Health check - restart if unhealthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s # Check every 30 seconds
timeout: 10s # Fail if no response in 10s
retries: 3 # Mark unhealthy after 3 failures
start_period: 40s # Wait 40s before first check
# Resource limits
deploy:
resources:
limits:
memory: 512M
reservations:
memory: 256M
# Run
docker-compose up -d
# Check health status
docker-compose ps
# View logs
docker-compose logs -f ml-api
# Stop
docker-compose down
Monitoring & Logging
Your API is live — but how do you know it's working well? You can't stare at the logs 24/7! Monitoring is like having security cameras and sensors in your restaurant. You track: Is the kitchen working fast enough? (latency) Are orders failing? (errors) Is the food quality consistent? (prediction distribution) Are customer tastes changing? (data drift). Without monitoring, your model could be making terrible predictions and you wouldn't know until customers complain!
Structured Logging
# logging_config.py - Production logging setup
import logging
import json
from datetime import datetime
import sys
class JSONFormatter(logging.Formatter):
"""Format logs as JSON for easy parsing."""
def format(self, record):
log_record = {
'timestamp': datetime.utcnow().isoformat(),
'level': record.levelname,
'message': record.getMessage(),
'module': record.module,
'function': record.funcName,
'line': record.lineno
}
if hasattr(record, 'prediction_data'):
log_record['prediction_data'] = record.prediction_data
if record.exc_info:
log_record['exception'] = self.formatException(record.exc_info)
return json.dumps(log_record)
# Configure logger
def setup_logging():
logger = logging.getLogger('ml_api')
logger.setLevel(logging.INFO)
# Console handler with JSON format
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(JSONFormatter())
logger.addHandler(handler)
return logger
logger = setup_logging()
Why JSON logs instead of plain text?
- Plain text: 2024-01-15 10:30:00 - Made prediction: [1, 2, 0] (hard to search and filter)
- JSON: {"timestamp": "2024-01-15T10:30:00", "predictions": [1, 2, 0], "latency_ms": 15}
- Easy queries: "Show me all predictions slower than 100ms" or "Find errors from last hour"
- Structured = searchable: Filter by any field, create dashboards, set up alerts
Pro tip: Include prediction_data in logs so you can debug "why did the model predict X?"
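To make the "easy queries" point concrete, here is a small sketch that scans JSON log lines for slow predictions. slow_predictions is a hypothetical helper, and it assumes the log shape produced by the JSONFormatter above:

```python
import json

def slow_predictions(log_lines, threshold_ms=100):
    """Return parsed log records whose prediction latency exceeds threshold_ms."""
    slow = []
    for line in log_lines:
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue  # skip any non-JSON lines mixed into the stream
        data = record.get('prediction_data', {})
        if data.get('latency_ms', 0) > threshold_ms:
            slow.append(record)
    return slow

# Example: scan a log file for requests slower than 100 ms
# with open('api.log') as f:
#     for record in slow_predictions(f):
#         print(record['timestamp'], record['prediction_data']['latency_ms'])
```

With plain-text logs, the same query would need fragile regex parsing; with JSON, it's one dictionary lookup.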
# main.py - Using structured logging
from logging_config import logger
import time
@app.post("/predict")
async def predict(request: PredictionRequest):
start_time = time.time()
try:
# Make prediction
features = np.array([s.to_array() for s in request.samples])
predictions = model.predict(features)
latency_ms = (time.time() - start_time) * 1000
# Log prediction with structured data
logger.info(
"Prediction completed",
extra={'prediction_data': {
'num_samples': len(request.samples),
'predictions': predictions.tolist(),
'latency_ms': round(latency_ms, 2)
}}
)
return PredictionResponse(predictions=format_predictions(predictions))
except Exception as e:
logger.error(f"Prediction failed: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
Prometheus Metrics
# metrics.py - Prometheus metrics for ML API
from prometheus_client import Counter, Histogram, Gauge, generate_latest
from fastapi import Response
# Define metrics
PREDICTIONS_TOTAL = Counter(
'predictions_total',
'Total number of predictions',
['model_name', 'prediction_class']
)
PREDICTION_LATENCY = Histogram(
'prediction_latency_seconds',
'Prediction latency in seconds',
buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0]
)
MODEL_CONFIDENCE = Histogram(
'model_confidence',
'Prediction confidence distribution',
buckets=[0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1.0]
)
FEATURE_VALUE = Gauge(
'feature_value',
'Feature value distribution',
['feature_name']
)
Prometheus metrics — your monitoring dashboard data:
- Counter: Things that only go UP (total requests, total predictions). "How many times did X happen?"
- Histogram: Distribution of values (latency). Gives you p50, p95, p99 automatically. "How long do requests usually take?"
- Gauge: Current value that goes up AND down (active connections, memory usage). "What is X right now?"
- Labels: Add dimensions like model_name or prediction_class to slice data
The payoff: Connect to Grafana for beautiful dashboards. Set alerts like "if p95 latency > 500ms, notify me!"
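To see why a histogram gives you percentiles "for free", here is a toy sketch of how Prometheus-style cumulative buckets work. This is illustrative only, not the prometheus_client implementation:

```python
class MiniHistogram:
    """Toy cumulative histogram: bucket 'le' counts observations <= that bound."""

    def __init__(self, buckets):
        self.buckets = sorted(buckets)
        self.counts = dict.fromkeys(self.buckets, 0)
        self.total = 0  # plays the role of the implicit +Inf bucket

    def observe(self, value):
        self.total += 1
        for bound in self.buckets:
            if value <= bound:
                self.counts[bound] += 1  # cumulative: every bucket >= value

h = MiniHistogram([0.05, 0.1, 0.5])
h.observe(0.03)  # counted in every bucket
h.observe(0.2)   # counted only in le=0.5
```

From these cumulative counts, Prometheus can estimate any percentile at query time: here, half of all observations fall at or below 0.05s, so the estimated p50 is in the first bucket.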
# Instrument the prediction endpoint
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
@app.post("/predict")
async def predict(request: PredictionRequest):
with PREDICTION_LATENCY.time(): # Automatically record latency
features = np.array([s.to_array() for s in request.samples])
predictions = model.predict(features)
probabilities = model.predict_proba(features)
# Record metrics
class_names = ['setosa', 'versicolor', 'virginica']
for pred, probs in zip(predictions, probabilities):
PREDICTIONS_TOTAL.labels(
model_name='iris_classifier',
prediction_class=class_names[pred]
).inc()
MODEL_CONFIDENCE.observe(max(probs))
return format_response(predictions, probabilities)
@app.get("/metrics")
async def metrics():
"""Expose Prometheus metrics."""
return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
Data Drift Detection
# drift_monitor.py - Simple data drift detection
import numpy as np
from scipy import stats
from collections import deque
class DriftMonitor:
"""Monitor for data drift in production."""
def __init__(self, reference_data, window_size=1000):
self.reference_data = reference_data
self.recent_data = deque(maxlen=window_size)
self.feature_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
def add_sample(self, sample):
"""Add a new prediction sample to the window."""
self.recent_data.append(sample)
def check_drift(self, threshold=0.05):
"""Check for distribution drift using KS test."""
if len(self.recent_data) < 100:
return {'status': 'insufficient_data', 'samples': len(self.recent_data)}
recent = np.array(self.recent_data)
drift_detected = {}
for i, name in enumerate(self.feature_names):
# Kolmogorov-Smirnov test
statistic, p_value = stats.ks_2samp(
self.reference_data[:, i],
recent[:, i]
)
drift_detected[name] = {
'statistic': float(statistic),
'p_value': float(p_value),
'drift': p_value < threshold
}
return {
'status': 'checked',
'samples_analyzed': len(self.recent_data),
'features': drift_detected,
'any_drift': any(f['drift'] for f in drift_detected.values())
}
# Initialize with training data
drift_monitor = DriftMonitor(X_train)
Data drift — when reality changes but your model doesn't:
- The problem: You trained on 2023 data. It's now 2025. Customer behavior changed!
- KS test: Compares "do these two datasets come from the same distribution?"
- Low p-value (<0.05): The distributions are DIFFERENT. Production data doesn't look like training data!
- Why it matters: Models only work well on data similar to what they trained on
- Example: Trained on ages 20-40, now seeing ages 50-70 → predictions may be wrong
When drift is detected: Time to retrain your model with fresh data!
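What ks_2samp actually measures is the largest gap between the two samples' empirical CDFs. A stdlib-only sketch of the statistic itself (scipy additionally computes the p-value from it):

```python
from bisect import bisect_right

def ks_statistic(sample_a, sample_b):
    """Two-sample KS statistic: max distance between the empirical CDFs."""
    a, b = sorted(sample_a), sorted(sample_b)
    na, nb = len(a), len(b)
    # Empirical CDF at x = fraction of the sample <= x; check every observed point
    return max(
        abs(bisect_right(a, x) / na - bisect_right(b, x) / nb)
        for x in a + b
    )

# Identical distributions give a statistic of 0;
# completely disjoint ranges give the maximum, 1.0
```

A statistic near 0 means the production window looks like the training data; a statistic near 1 means the distributions barely overlap at all.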
- Set up alerts for latency spikes and error rate increases
- Monitor prediction distribution - sudden changes may indicate data issues
- Implement A/B testing for model updates
- Keep model versioning and rollback capability
Practice Questions
Task: Create an endpoint that can load and switch between different model versions at runtime.
Show Solution
from fastapi import FastAPI, HTTPException
import joblib
from pathlib import Path
import threading
app = FastAPI()
class ModelManager:
"""Manage multiple model versions."""
def __init__(self, models_dir: str):
self.models_dir = Path(models_dir)
self.current_model = None
self.current_version = None
self.lock = threading.Lock()
def list_versions(self):
"""List available model versions."""
return [f.stem for f in self.models_dir.glob("*.joblib")]
def load_version(self, version: str):
"""Load a specific model version."""
model_path = self.models_dir / f"{version}.joblib"
if not model_path.exists():
raise ValueError(f"Model version {version} not found")
with self.lock:
self.current_model = joblib.load(model_path)
self.current_version = version
return True
def predict(self, features):
"""Make prediction with current model."""
with self.lock:
if self.current_model is None:
raise ValueError("No model loaded")
return self.current_model.predict(features)
model_manager = ModelManager("models/")
model_manager.load_version("v1.0")
@app.get("/models/versions")
async def list_versions():
return {"versions": model_manager.list_versions(), "current": model_manager.current_version}
@app.post("/models/switch/{version}")
async def switch_model(version: str):
try:
model_manager.load_version(version)
return {"message": f"Switched to {version}", "current": model_manager.current_version}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
Task: Set up basic file logging for your ML API to track predictions.
Show Solution
import logging
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('api.log'), # Write to file
logging.StreamHandler() # Also print to console
]
)
logger = logging.getLogger(__name__)
# Usage in a Flask API (assumes: from flask import Flask, request, jsonify)
@app.route('/predict', methods=['POST'])
def predict():
logger.info(f"Received prediction request from {request.remote_addr}")
try:
features = request.get_json()['features']
prediction = model.predict(features)
logger.info(f"Prediction successful: {prediction}")
return jsonify({'prediction': prediction.tolist()})
except Exception as e:
logger.error(f"Prediction failed: {str(e)}", exc_info=True)
return jsonify({'error': str(e)}), 500
# api.log will contain all logs with timestamps
Task: Add latency tracking to measure how long predictions take.
Show Solution
import time
from functools import wraps
# Decorator to measure execution time
def track_latency(func):
@wraps(func)
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
latency_ms = (time.time() - start) * 1000
print(f"{func.__name__} took {latency_ms:.2f}ms")
return result
return wrapper
# Usage with Flask
@app.route('/predict', methods=['POST'])
def predict():
start_time = time.time()
# Your prediction code
features = request.get_json()['features']
prediction = model.predict(features)
# Calculate and log latency
latency_ms = (time.time() - start_time) * 1000
return jsonify({
'prediction': prediction.tolist(),
'latency_ms': round(latency_ms, 2)
})
# Variation: track latencies over time (replaces the handler above; Flask won't register /predict twice)
latencies = []
@app.route('/predict', methods=['POST'])
def predict():
start = time.time()
# ... prediction logic ...
latency = (time.time() - start) * 1000
latencies.append(latency)
# Alert if latency is too high
if latency > 100: # 100ms threshold
logging.warning(f"High latency detected: {latency:.2f}ms")
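Averages hide tail behavior, so it's worth summarizing the recorded latencies as percentiles rather than a mean. A small sketch using only the standard library (latency_percentiles is a hypothetical helper):

```python
from statistics import quantiles

def latency_percentiles(samples):
    """Summarize recorded latencies (ms) as p50/p95/p99."""
    cuts = quantiles(samples, n=100)  # 99 cut points between percentiles
    return {
        'p50': round(cuts[49], 2),
        'p95': round(cuts[94], 2),
        'p99': round(cuts[98], 2),
    }

# Example: expose the summary alongside the raw list (Flask-style)
# @app.route('/monitoring/latency')
# def latency_stats():
#     return jsonify(latency_percentiles(latencies))
```

A healthy p50 with a bad p99 means most requests are fine but a few users wait a long time, which a simple average would mask.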
Task: Monitor the distribution of predictions to detect if model behavior changes.
Show Solution
from collections import Counter
from datetime import datetime
class PredictionMonitor:
"""Track prediction distribution over time."""
def __init__(self):
self.hourly_counts = {} # {hour: Counter}
self.total_counts = Counter()
def record(self, prediction):
"""Record a prediction."""
hour = datetime.now().strftime('%Y-%m-%d %H:00')
if hour not in self.hourly_counts:
self.hourly_counts[hour] = Counter()
self.hourly_counts[hour][prediction] += 1
self.total_counts[prediction] += 1
def get_distribution(self):
"""Get overall prediction distribution."""
total = sum(self.total_counts.values())
return {k: v/total for k, v in self.total_counts.items()}
def check_drift(self, expected_distribution, threshold=0.1):
"""Check if current distribution differs from expected."""
current = self.get_distribution()
alerts = []
for class_name, expected_pct in expected_distribution.items():
actual_pct = current.get(class_name, 0)
diff = abs(actual_pct - expected_pct)
if diff > threshold:
alerts.append(f"{class_name}: expected {expected_pct:.1%}, got {actual_pct:.1%}")
return alerts
# Usage
monitor = PredictionMonitor()
@app.post("/predict")
async def predict(request: PredictionRequest):
predictions = model.predict(...)
for pred in predictions:
monitor.record(pred)
return {"predictions": predictions.tolist()}
@app.get("/monitoring/distribution")
async def get_distribution():
return monitor.get_distribution()
Task: Implement basic data drift detection to alert when input data distribution changes.
Show Solution
import numpy as np
from collections import deque
from scipy import stats
class DriftDetector:
"""Detect when input data distribution shifts."""
def __init__(self, reference_data, window_size=1000, threshold=0.05):
self.reference_data = reference_data
self.window = deque(maxlen=window_size)
self.threshold = threshold
self.feature_names = ['feature_1', 'feature_2', 'feature_3', 'feature_4']
def add_sample(self, sample):
"""Add a new sample to the monitoring window."""
self.window.append(sample)
def check_drift(self):
"""Check if current data distribution differs from reference."""
if len(self.window) < 100:
return {"status": "insufficient_data", "samples": len(self.window)}
current_data = np.array(self.window)
drift_detected = {}
for i, name in enumerate(self.feature_names):
# Kolmogorov-Smirnov test
stat, p_value = stats.ks_2samp(
self.reference_data[:, i],
current_data[:, i]
)
is_drift = p_value < self.threshold
drift_detected[name] = {
"p_value": round(p_value, 4),
"drift": is_drift
}
any_drift = any(d["drift"] for d in drift_detected.values())
return {
"status": "drift_detected" if any_drift else "ok",
"features": drift_detected
}
# Initialize with training data
drift_detector = DriftDetector(X_train)
@app.post("/predict")
async def predict(request: PredictionRequest):
features = np.array([s.to_array() for s in request.samples])
# Record samples for drift detection
for sample in features:
drift_detector.add_sample(sample)
return {"predictions": model.predict(features).tolist()}
@app.get("/monitoring/drift")
async def check_drift():
return drift_detector.check_drift()
Key Takeaways
Save Pipelines, Not Just Models
Use joblib to save the entire pipeline including preprocessing. This ensures consistent transformations during prediction.
FastAPI for Production
FastAPI provides automatic validation, documentation, and async support. Use Pydantic models for request/response schemas.
Containerize Everything
Docker ensures consistent environments. Pin all dependency versions and use multi-stage builds for smaller images.
Monitor Model Health
Track prediction distribution, latency, and error rates. Implement drift detection to know when to retrain.
Structured Logging
Use JSON logging for easy parsing by log aggregation systems. Include prediction metadata for debugging.
Version Your Models
Track model versions with metadata. Enable easy rollback if a new model underperforms in production.
Knowledge Check
Why is joblib preferred over pickle for scikit-learn models?
What is the main advantage of FastAPI over Flask for ML APIs?
Why should you copy requirements.txt before application code in a Dockerfile?
What does data drift detection help you identify?
What should you do when loading a model at API startup?
What type of Prometheus metric would you use to track prediction latency?