Robust AI
Focuses on building reliable ML systems with proper error handling, graceful degradation, and robustness to distribution shift and adversarial conditions.
- Explain how SRE principles apply to ML systems including SLOs for model quality and error budgets
- Design graceful degradation strategies with fallback systems for ML service failures
- Analyze types of distribution shift and implement appropriate detection and response mechanisms
- Evaluate behavioral and metamorphic testing approaches for validating ML model correctness
- Implement model calibration techniques to produce reliable confidence scores for downstream decision-making
01 Building Reliable ML Systems
Reliability in ML systems means consistently producing correct outputs under expected conditions and degrading gracefully under unexpected conditions. Unlike traditional software where correctness is binary, ML systems operate on a spectrum of quality, making reliability a more nuanced engineering challenge.
ML Reliability
The ability of an ML system to consistently produce correct outputs under expected conditions and degrade gracefully under unexpected conditions. Unlike traditional software where failures are binary (crash or no crash), ML reliability operates on a continuous spectrum of prediction quality.
Figure: ML Failure Modes & Mitigation Matrix
SRE Principles for ML
Reliability engineering for ML draws on principles from site reliability engineering (SRE) and adapts them for the unique characteristics of ML systems. Key practices include defining service level objectives for model quality, implementing automated testing, and establishing incident response procedures.
Traditional SLOs measure uptime and latency. ML SLOs must also measure prediction quality. Example: "The recommendation model shall maintain click-through rate within 10% of baseline for 99.5% of weekly measurement periods." Defining these ML-specific SLOs requires collaboration between ML engineers, product managers, and SRE teams.
| Metric | Traditional Software | ML System Equivalent |
|---|---|---|
| Uptime | 99.9% availability | 99.9% availability + model accuracy > threshold |
| Latency | P99 < 200ms | P99 < 200ms including model inference |
| Correctness | Returns expected output | Prediction accuracy > SLO target |
| Error Budget | Allowed downtime minutes | Allowed accuracy degradation periods |
| MTTR | Time to restore service | Time to detect + diagnose + rollback model |
Table 16.1: Translating SRE concepts to ML systems.
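As an illustration of tracking such an SLO, here is a minimal sketch, with hypothetical numbers, that checks weekly CTR compliance and error-budget consumption for the example objective above:

# Hypothetical weekly CTR measurements vs. baseline for the example SLO:
# "CTR within 10% of baseline for 99.5% of weekly measurement periods."
baseline_ctr = 0.040
weekly_ctr = [0.041, 0.039, 0.035, 0.040, 0.038, 0.042, 0.031, 0.040]

compliant = [abs(c - baseline_ctr) / baseline_ctr <= 0.10 for c in weekly_ctr]
compliance_rate = sum(compliant) / len(compliant)

slo_target = 0.995
error_budget_used = (1 - compliance_rate) / (1 - slo_target)
print(f"compliance: {compliance_rate:.1%}, budget used: {error_budget_used:.1f}x")
# A value above 1.0 means the error budget is exhausted: freeze risky changes.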
End-to-End Pipeline Reliability
The reliability of an ML system depends on every component in the pipeline: data collection, preprocessing, model inference, postprocessing, and integration with downstream systems. A failure at any point can corrupt outputs.
ML pipeline reliability follows a weakest-link principle: because components operate in series, their availabilities multiply, so overall reliability can be no better than that of the least reliable component. A model with 99.99% availability is worthless if the feature pipeline fails 1% of the time. Map every component in your pipeline, measure its failure rate independently, and prioritize improving the least reliable component.
MTBF and MTTR
Mean time between failures (MTBF) and mean time to recovery (MTTR) are key reliability metrics. For ML systems, "failure" might mean model accuracy dropping below a threshold rather than a complete system crash.
\text{Availability} = \frac{\text{MTBF}}{\text{MTBF} + \text{MTTR}}

For ML systems, preventing all failures is impractical — data distributions shift, upstream APIs change, and the real world is unpredictable. Focus investment on reducing MTTR: automated drift detection (fast detection), pre-built runbooks (fast diagnosis), and one-click rollback (fast recovery). A team that can recover in 5 minutes is more reliable than one that tries to prevent all failures but takes 4 hours to recover when one occurs.
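To make these two ideas concrete (availability from MTBF/MTTR and the weakest-link bound), here is a small sketch; the component names and numbers are illustrative, not from the text:

# Illustrative availability arithmetic: all numbers are hypothetical.
components = {
    "feature_pipeline": {"mtbf_h": 200.0, "mttr_h": 4.0},
    "model_server":     {"mtbf_h": 720.0, "mttr_h": 0.5},
    "postprocessing":   {"mtbf_h": 2000.0, "mttr_h": 1.0},
}

def availability(mtbf, mttr):
    return mtbf / (mtbf + mttr)

avail = {name: availability(c["mtbf_h"], c["mttr_h"])
         for name, c in components.items()}

# Assuming independent components in series, availabilities multiply.
end_to_end = 1.0
for a in avail.values():
    end_to_end *= a

weakest = min(avail, key=avail.get)
print(avail)        # per-component availability
print(end_to_end)   # bounded above by the weakest component
print(weakest)      # where to invest first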
ML Reliability
The ability of an ML system to consistently produce correct outputs under expected conditions and degrade gracefully under unexpected conditions.
MTTR
Mean time to recovery: the average time required to detect, diagnose, and resolve a system failure; a key reliability metric.
02 Error Handling and Graceful Degradation
ML systems must handle errors at multiple levels: infrastructure failures (GPU errors, network outages), data errors (missing features, invalid inputs), and model errors (low confidence predictions, out-of-distribution inputs). Each type requires appropriate handling strategies.
| Error Level | Examples | Detection | Handling Strategy |
|---|---|---|---|
| Infrastructure | GPU OOM, network timeout, disk full | Health checks, resource monitoring | Retry, failover to backup, circuit breaker |
| Data | Missing features, invalid types, NaN values | Schema validation, range checks | Default values, feature imputation, reject request |
| Model | Low confidence, OOD input, inconsistent output | Confidence thresholds, OOD detection | Fallback model, human review, cached response |
Table 16.2: Error types in ML systems and their handling strategies.
Graceful Degradation
Graceful Degradation
A system design strategy where components fail incrementally rather than catastrophically, with each fallback level providing reduced but still useful functionality. In ML systems, this typically means falling back to simpler models or heuristics when the primary model is unavailable or unreliable.
Graceful degradation means the system continues to provide value even when components fail. Each fallback level provides less value but maintains system availability.
A robust system doesn't need to be right all the time; it needs to know when it's wrong.

Deeper Insight: The most dangerous ML systems are those that fail silently with high confidence. A truly robust system distinguishes between inputs it can handle reliably and inputs where its predictions are uncertain. By detecting its own limitations — through calibrated confidence, OOD detection, or ensemble disagreement — the system can route uncertain cases to fallbacks or human review, maintaining overall reliability even when individual predictions would be unreliable.

Think of it like a seasoned doctor who says "I'm not sure — let me refer you to a specialist" versus an overconfident one who guesses. The first doctor's patients get better outcomes overall, even though the doctor personally handles fewer cases.
For example, a recommendation system might degrade through four levels. Level 1: Personalized ML model (best quality). Level 2: Collaborative filtering fallback (if the personalized model is down). Level 3: Popularity-based ranking (if all models are unavailable). Level 4: Curated editorial list (if all computation fails). Each level is pre-computed and cached, ensuring the system always has something to show the user.
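As a sketch of how such a chain can be wired, with each level simulated by a hypothetical placeholder function:

# Hypothetical fallback levels for a recommender; each may fail independently.
def personalized_model_predict(user_id):
    raise TimeoutError("model endpoint unavailable")  # simulate an outage

def collaborative_filtering_predict(user_id):
    raise ConnectionError("feature store unreachable")  # simulate an outage

def popularity_ranking(user_id):
    return ["item_42", "item_7", "item_19"]  # precomputed popular items

def cached_editorial_list(user_id):
    return ["staff_pick_1", "staff_pick_2"]  # static curated fallback

def recommend(user_id):
    """Try each level in descending quality; degrade instead of failing."""
    levels = [personalized_model_predict, collaborative_filtering_predict,
              popularity_ranking, cached_editorial_list]
    for level in levels:
        try:
            return level(user_id)
        except Exception:
            continue  # fall through to the next, simpler level
    return []  # last resort: empty but valid response

print(recommend("user_123"))  # -> popularity list, since the first two levels fail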
Confidence-Based Routing
Confidence-based routing sends predictions through different paths based on model confidence. High-confidence predictions are served directly. Low-confidence predictions might be routed to a more expensive model, queued for human review, or handled by a fallback system.
class="tok-comment"># Confidence-based routing pseudocode
def serve_prediction(input_data):
prediction, confidence = primary_model.predict(input_data)
if confidence >= class="tok-number">0.95:
return prediction class="tok-comment"># High confidence: serve directly
elif confidence >= class="tok-number">0.7:
class="tok-comment"># Medium confidence: use ensemble for verification
ensemble_pred = ensemble_model.predict(input_data)
return ensemble_pred
else:
class="tok-comment"># Low confidence: fall back or escalate
if is_time_sensitive:
return rule_based_fallback(input_data)
else:
queue_for_human_review(input_data)
return cached_default_response()Confidence-based routing only works if the model's confidence scores are well-calibrated — meaning a 90% confidence prediction should actually be correct 90% of the time. Apply temperature scaling or Platt scaling to calibrate confidence scores before using them for routing decisions.
Circuit Breaker Pattern
Circuit Breaker
A design pattern borrowed from electrical engineering and distributed systems that monitors for failures and automatically routes traffic to a fallback system when the primary system becomes unhealthy. The circuit breaker has three states: closed (normal), open (fallback), and half-open (testing recovery).
Circuit breaker patterns prevent cascading failures in ML serving systems. If a model serving endpoint begins returning errors or high latencies, the circuit breaker trips and routes traffic to a fallback system.
Without circuit breakers, a slow model endpoint causes upstream services to queue requests, consuming memory and threads. This can cascade through the entire system. A circuit breaker that trips after 5 consecutive failures or latency exceeding 3x the P99 baseline prevents this cascade by immediately routing to the fallback.
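A minimal sketch of the three-state pattern described above, using a consecutive-failure count and a recovery timeout (thresholds are illustrative, not prescriptive):

import time

class CircuitBreaker:
    """Minimal three-state circuit breaker: closed -> open -> half-open."""
    def __init__(self, failure_threshold=5, recovery_timeout_s=30.0):
        self.failure_threshold = failure_threshold
        self.recovery_timeout_s = recovery_timeout_s
        self.failures = 0
        self.opened_at = None  # None means the circuit is closed

    def call(self, primary, fallback, *args):
        if self.opened_at is not None:
            if time.monotonic() - self.opened_at < self.recovery_timeout_s:
                return fallback(*args)  # Open: short-circuit to the fallback
            # Half-open: let one trial request through to test recovery
        try:
            result = primary(*args)
        except Exception:
            self.failures += 1
            if self.failures >= self.failure_threshold:
                self.opened_at = time.monotonic()  # Trip (or re-trip) the breaker
            return fallback(*args)
        self.failures = 0
        self.opened_at = None  # Success closes the circuit again
        return result

# Usage sketch (names hypothetical):
# breaker = CircuitBreaker()
# result = breaker.call(primary_model_predict, rule_based_fallback, input_data)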
Graceful Degradation
A design strategy where system components fail incrementally, falling back to simpler alternatives rather than complete failure.
Circuit Breaker
A pattern that detects failures and prevents cascading damage by routing traffic to fallback systems when the primary system is unhealthy.
03 Robustness to Distribution Shift
Distribution shift occurs when the data encountered in production differs from the training data distribution. This is inevitable in real-world systems because the world changes over time, and no training dataset can perfectly represent all future conditions.
Distribution Shift
A change between the training data distribution and the data encountered in production. Distribution shift is inevitable in real-world systems and is the primary cause of model performance degradation over time. It encompasses covariate shift, label shift, and concept drift.
Covariate Shift
Covariate shift refers to changes in the input distribution while the relationship between inputs and outputs remains stable. Domain adaptation techniques address covariate shift by learning representations that are invariant across domains.
P_{train}(X) \neq P_{test}(X) \quad \text{but} \quad P_{train}(Y|X) = P_{test}(Y|X)

| Shift Type | What Changes | What Stays Constant | Example |
|---|---|---|---|
| Covariate Shift | Input distribution P(X) | Conditional P(Y|X) | Camera images at night vs. daytime |
| Label Shift | Class proportions P(Y) | Conditional P(X|Y) | Disease prevalence changes seasonally |
| Concept Drift | Relationship P(Y|X) | Potentially everything | Customer preferences evolve over time |
| Domain Shift | Both P(X) and P(Y|X) | Task structure | Model trained on hospital A, deployed at hospital B |
Table 16.3: Types of distribution shift in ML systems.
Simple domain adaptation includes input normalization and batch normalization, which center each batch to reduce covariate shift. More sophisticated methods include adversarial domain adaptation (learning features that a domain discriminator cannot distinguish) and distribution matching (minimizing Maximum Mean Discrepancy between source and target feature distributions).
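To illustrate distribution matching, here is a sketch of the (biased) RBF-kernel MMD estimate between source and target feature batches; the kernel bandwidth and synthetic data are illustrative:

import numpy as np

def rbf_kernel(a, b, gamma=0.5):
    # k(x, y) = exp(-gamma * ||x - y||^2), computed for all pairs
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * sq_dists)

def mmd_squared(source, target, gamma=0.5):
    """Biased estimate of MMD^2 between two feature batches."""
    k_ss = rbf_kernel(source, source, gamma).mean()
    k_tt = rbf_kernel(target, target, gamma).mean()
    k_st = rbf_kernel(source, target, gamma).mean()
    return k_ss + k_tt - 2.0 * k_st

rng = np.random.default_rng(0)
src = rng.normal(0.0, 1.0, size=(128, 2))   # source features
same = rng.normal(0.0, 1.0, size=(128, 2))  # same distribution
tgt = rng.normal(1.0, 1.0, size=(128, 2))   # shifted target features
print(mmd_squared(src, same))  # near zero: matched distributions
print(mmd_squared(src, tgt))   # clearly positive: distribution mismatch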
Label Shift
Label shift occurs when the class proportions change between training and deployment. A disease detection model trained on balanced data may encounter highly imbalanced real-world prevalence.
A skin cancer classifier trained on a balanced dataset (50% benign, 50% malignant) encounters real-world prevalence of 98% benign, 2% malignant. Without adjustment, the model's precision on malignant predictions will be much lower than expected from validation. Calibration and class-weight adjustment account for this prevalence difference, but the most robust approach is to evaluate and train with realistic class proportions.
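A standard prior-ratio correction, shown as a sketch (the numbers mirror the benign/malignant example above):

import numpy as np

def adjust_for_label_shift(probs, train_priors, deploy_priors):
    """Reweight predicted class probabilities for new class priors."""
    ratio = np.asarray(deploy_priors) / np.asarray(train_priors)
    adjusted = probs * ratio                           # reweight each class
    return adjusted / adjusted.sum(-1, keepdims=True)  # renormalize

# Trained on a balanced set, deployed where benign is 98% of cases.
p_model = np.array([0.30, 0.70])   # [benign, malignant] from the model
train_priors = [0.5, 0.5]
deploy_priors = [0.98, 0.02]
print(adjust_for_label_shift(p_model, train_priors, deploy_priors))
# Malignant probability drops sharply once real-world prevalence is applied.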
Out-of-Distribution Detection
Out-of-Distribution (OOD) Detection
Methods for identifying inputs that differ significantly from the training distribution, enabling the system to flag uncertain predictions for human review or route them to fallback systems rather than making unreliable predictions.
OOD detection identifies inputs that are fundamentally different from anything seen during training. Rather than making potentially unreliable predictions, the system can flag them for human review or route them to a fallback.
- Softmax confidence thresholds — Simple but unreliable; neural networks are often overconfident on OOD inputs.
- Energy-based methods — Use the energy score (log-sum-exp of logits) as an OOD indicator; more reliable than softmax.
- Mahalanobis distance — Compute distance from class centroids in feature space; effective for detecting semantic OOD.
- Ensemble disagreement — High variance across ensemble members suggests the input is unlike training data.
Standard neural networks are notoriously overconfident on out-of-distribution inputs. A classifier trained on cats and dogs may assign 99% confidence to an image of a car. Do not rely on raw softmax confidence for OOD detection. Use dedicated OOD scoring methods like energy scores or Mahalanobis distance, which are specifically designed to distinguish in-distribution from OOD inputs.
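To illustrate the energy-based method from the list above, here is a sketch of the energy score computed from raw logits (a common formulation; lower energy indicates more in-distribution):

import numpy as np

def energy_score(logits, temperature=1.0):
    """Energy E(x) = -T * logsumexp(logits / T); lower = more in-distribution."""
    z = np.asarray(logits) / temperature
    m = z.max(-1, keepdims=True)  # stabilize the exponentials
    lse = m + np.log(np.exp(z - m).sum(-1, keepdims=True))
    return (-temperature * lse).squeeze(-1)

in_dist_logits = np.array([[9.0, 1.0, 0.5]])  # one confident, peaked prediction
ood_logits     = np.array([[1.1, 0.9, 1.0]])  # flat, low-magnitude logits
print(energy_score(in_dist_logits))  # more negative (lower) energy
print(energy_score(ood_logits))      # higher energy -> flag as possible OOD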
Distribution Shift
The difference between the training data distribution and the data encountered in production, a primary cause of model performance degradation.
Out-of-Distribution Detection
Methods for identifying inputs that differ significantly from the training distribution, enabling appropriate handling rather than unreliable predictions.
04 Testing ML Systems
Testing ML systems requires approaches beyond traditional software testing because ML behavior is learned from data rather than explicitly programmed. A comprehensive testing strategy includes unit tests for data processing code, integration tests for the pipeline, and model-specific tests that validate learned behavior.
| Test Level | What It Tests | Example | When to Run |
|---|---|---|---|
| Unit Tests | Data processing functions, feature engineering | Verify feature scaler produces expected output | Every commit |
| Integration Tests | Pipeline components working together | Verify data flows from ingestion through prediction | Every PR merge |
| Model Tests | Learned model behavior | Behavioral tests, invariance tests, minimum functionality | Every model training run |
| System Tests | End-to-end with real infrastructure | Latency under load, failover behavior | Pre-deployment |
Table 16.4: ML testing pyramid from unit tests to system tests.
Behavioral Testing (CheckList)
Behavioral Testing
A testing methodology inspired by the CheckList framework that validates specific model capabilities through three test types: Minimum Functionality Tests (basic capability), Invariance Tests (robustness to irrelevant changes), and Directional Expectation Tests (correct response to meaningful changes).
Behavioral testing validates specific model capabilities through targeted test suites. Minimum functionality tests verify basic capabilities. Invariance tests check that irrelevant input changes do not affect predictions. Directional expectation tests verify expected directional output changes.
Minimum Functionality: "This movie is excellent" should be positive. "This movie is terrible" should be negative. Invariance: Changing "movie" to "film" should not change the prediction. Changing "I watched in the theater" to "I watched at home" should not change sentiment. Directional: Adding "but the ending was disappointing" to a positive review should reduce the positive score.
class="tok-comment"># Behavioral test suite example
def test_minimum_functionality():
class="tok-string">class="tok-string">""class="tok-string">"Basic capabilities the model MUST have."class="tok-string">""
assert model.predict(&class="tok-comment">#class="tok-number">39;This is great!class="tok-string">class="tok-number">39;) == class="tok-number">39;positiveclass="tok-string">class="tok-number">39;
assert model.predict(&class="tok-comment">#class="tok-number">39;This is terrible.class="tok-number">39;) == class="tok-string">class="tok-number">39;negativeclass="tok-number">39;
def test_invariance():
class="tok-string">class="tok-string">""class="tok-string">"Irrelevant changes should NOT affect predictions."class="tok-string">""
base = model.predict(&class="tok-comment">#class="tok-number">39;The food was excellent.class="tok-string">class="tok-number">39;)
assert model.predict(&class="tok-comment">#class="tok-number">39;The food was excellent!class="tok-number">39;) == base # Punctuation
assert model.predict(&class="tok-comment">#class="tok-number">39;The FOOD was excellent.class="tok-string">class="tok-number">39;) == base # Capitalization
def test_directional():
class="tok-string">"""Meaningful changes should produce expected direction."""
base_score = model.score(&class="tok-comment">#class="tok-number">39;Good product.class="tok-number">39;)
modified_score = model.score(&class="tok-comment">#class="tok-number">39;Good product but broke after a week.class="tok-number">39;)
assert modified_score < base_score class="tok-comment"># Adding negative context reduces scoreMetamorphic Testing
Metamorphic Testing
A testing strategy that verifies relationships between inputs and outputs rather than requiring knowledge of the correct output for each test case. Useful when ground truth is unavailable or expensive, which is common in ML applications.
Metamorphic testing addresses the oracle problem by testing relationships between inputs. If a model should be invariant to rotation, rotating an input and checking that the output is unchanged provides a test without needing to know the correct output.
Start by asking: "What transformations should NOT change the output?" and "What transformations should change the output in a predictable direction?" For image classifiers: rotation, cropping, and brightness should not change the class. For regression models: scaling inputs by a constant should scale outputs proportionally if the relationship is linear.
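As a concrete sketch of a rotation-invariance relation, using a trivially rotation-invariant stand-in classifier (hypothetical, for illustration only):

import numpy as np

def classify(image):
    """Hypothetical stand-in classifier: predicts by mean intensity,
    which happens to be exactly rotation-invariant."""
    return int(image.mean() > 0.5)

def test_rotation_invariance(image):
    """Metamorphic relation: 90-degree rotations must not change the class."""
    base = classify(image)
    for k in (1, 2, 3):
        assert classify(np.rot90(image, k)) == base

rng = np.random.default_rng(0)
test_rotation_invariance(rng.random((32, 32)))  # passes without ground-truth labels
print("rotation invariance holds")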
Stress Testing and Chaos Engineering
Stress testing and chaos engineering for ML systems deliberately inject failures, noise, and adversarial conditions to verify system resilience.
- Corrupted inputs — Feed null values, extreme outliers, wrong data types to test input validation.
- Missing features — Drop random features to verify the model handles missing data gracefully.
- Model serving failures — Kill model serving pods to test failover to backup models.
- Data pipeline outages — Disconnect the feature store to verify cached/default features are used.
- Latency injection — Add artificial delay to model inference to test timeout handling.
- Adversarial inputs — Test with known adversarial examples to verify detection and routing.
Chaos engineering is powerful but must be conducted in a controlled staging environment that mirrors production. Never inject failures in production without comprehensive safeguards, automatic rollback, and team readiness. Start with the simplest failure mode (e.g., single pod restart) and gradually increase severity.
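As a sketch of the latency- and failure-injection ideas above, here is a wrapper that randomly perturbs a prediction function (rates are illustrative; intended for staging use only):

import random
import time

def inject_faults(fn, drop_rate=0.05, latency_rate=0.10, extra_latency_s=0.2):
    """Wrap a prediction function with random failures and added latency."""
    def wrapped(*args, **kwargs):
        r = random.random()
        if r < drop_rate:
            raise ConnectionError("injected fault: simulated outage")
        if r < drop_rate + latency_rate:
            time.sleep(extra_latency_s)  # injected latency to exercise timeouts
        return fn(*args, **kwargs)
    return wrapped

# Usage sketch in staging (name hypothetical):
# flaky_predict = inject_faults(model.predict)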
Behavioral Testing
A testing approach that validates specific model capabilities through targeted test suites, including invariance, directional, and minimum functionality tests.
Metamorphic Testing
A testing strategy that verifies relationships between inputs and outputs rather than requiring knowledge of correct outputs for each test case.
05 Robustness Engineering Practices
Ensemble Methods
Ensemble Methods
Techniques that combine predictions from multiple independently trained models to improve accuracy and robustness. Because individual models tend to make different errors, averaging their predictions reduces variance and provides more stable outputs, especially on out-of-distribution inputs.
Ensemble methods improve robustness by combining predictions from multiple models. Individual models may be sensitive to specific types of distribution shift, but their errors tend to be uncorrelated.
\hat{y}_{ensemble} = \frac{1}{M} \sum_{m=1}^{M} f_m(x)

An ensemble of 5 diverse models (different architectures, training data splits, or random seeds) provides more robustness than 20 copies of the same architecture. Maximize diversity: use different model families, different feature subsets, or different training epochs. If all ensemble members agree on an OOD input, none of them detected the shift.
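A sketch of the averaging formula above, with ensemble disagreement used as the OOD signal the text describes (member predictions are synthetic):

import numpy as np

def ensemble_predict(member_probs):
    """member_probs: (M, num_classes) predictions from M models."""
    mean = member_probs.mean(axis=0)               # the averaged prediction
    disagreement = member_probs.std(axis=0).max()  # high std = members disagree
    return mean, disagreement

# Three hypothetical members agreeing vs. disagreeing on an input.
agree = np.array([[0.90, 0.10], [0.85, 0.15], [0.88, 0.12]])
disagree = np.array([[0.90, 0.10], [0.20, 0.80], [0.55, 0.45]])
print(ensemble_predict(agree))     # low disagreement: trust the average
print(ensemble_predict(disagree))  # high disagreement: route to fallback/review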
Data Augmentation for Robustness
Data augmentation during training improves robustness by exposing the model to a wider range of input variations. Augmentations should simulate the types of distribution shift expected in production.
| Augmentation | Simulates | Best For |
|---|---|---|
| Gaussian noise | Sensor noise, signal degradation | Audio, time series, sensor data |
| Color jittering | Lighting and white balance changes | Camera-based vision systems |
| Random crop and resize | Varying object scale and position | Object detection, classification |
| Mixup / CutMix | Class boundary smoothing | General robustness and calibration |
| Back-translation | Paraphrasing, style variation | NLP text classification |
| SpecAugment | Acoustic variability | Speech recognition |
Table 16.5: Data augmentation strategies matched to production distribution shifts.
The most effective augmentations simulate the actual distribution shifts your model will encounter. If deploying a vision model in outdoor environments, augment with weather effects (rain, fog, glare) and time-of-day variations. If deploying an NLP model across domains, augment with vocabulary and style variations. Generic augmentation helps, but targeted augmentation based on deployment analysis helps more.
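As an illustration, here is a plain-NumPy sketch of two augmentations from the table, Gaussian noise and brightness jitter (parameters are illustrative):

import numpy as np

rng = np.random.default_rng(0)

def add_gaussian_noise(x, sigma=0.05):
    """Simulates sensor noise (audio, time series, sensor data)."""
    return x + rng.normal(0.0, sigma, size=x.shape)

def jitter_brightness(image, max_delta=0.2):
    """Simulates lighting changes for camera-based vision systems."""
    delta = rng.uniform(-max_delta, max_delta)
    return np.clip(image + delta, 0.0, 1.0)

image = rng.random((32, 32, 3))  # toy image with values in [0, 1]
augmented = jitter_brightness(add_gaussian_noise(image))
print(augmented.shape, float(augmented.min()), float(augmented.max()))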
Model Calibration
Model Calibration
The process of adjusting a model's confidence scores so they accurately reflect the true probability of correct prediction. A perfectly calibrated model that reports 80% confidence should be correct approximately 80% of the time across all such predictions.
Calibration ensures that model confidence scores accurately reflect the true probability of correct prediction. Temperature scaling and Platt scaling are simple post-hoc calibration methods.
p_{calibrated} = \sigma\left(\frac{z}{T}\right)

A medical diagnostic model provides predictions with calibrated confidence. When it reports >95% confidence, doctors trust the prediction and move quickly. When it reports 60-80% confidence, doctors examine the case more carefully. When it reports <50% confidence, the case is flagged for specialist review. This tiered workflow is only possible because the confidence scores are calibrated — uncalibrated scores would mislead the clinical workflow.
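A minimal sketch of temperature scaling for a multi-class model, fitting T by grid search over validation negative log-likelihood; the data are synthetic, and grid search is one common, simple way to fit T, not the only one:

import numpy as np

def softmax(z):
    z = z - z.max(-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(-1, keepdims=True)

def nll(logits, labels, temperature):
    probs = softmax(logits / temperature)
    return -np.log(probs[np.arange(len(labels)), labels] + 1e-12).mean()

def fit_temperature(val_logits, val_labels):
    """Pick the T minimizing validation NLL; T > 1 softens overconfidence."""
    candidates = np.linspace(0.5, 5.0, 91)
    losses = [nll(val_logits, val_labels, t) for t in candidates]
    return candidates[int(np.argmin(losses))]

# Synthetic validation set: overconfident logits with 30% label noise.
rng = np.random.default_rng(0)
val_logits = rng.normal(0.0, 3.0, size=(500, 4))
val_labels = val_logits.argmax(-1)
flip = rng.random(500) < 0.3
val_labels[flip] = rng.integers(0, 4, size=int(flip.sum()))

T = fit_temperature(val_logits, val_labels)
print(T)  # apply softmax(logits / T) at serving time for calibrated confidence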
Continuous Monitoring and Retraining
Continuous monitoring and automated retraining form a feedback loop that maintains robustness over time. This closed-loop approach is more sustainable than manually triggered retraining.
Automated retraining pipelines must include validation gates that prevent deploying a model that is worse than the current production model. Always compare the retrained model against the production model on both recent data and a stable holdout set. Without these gates, retraining on noisy or corrupted recent data can degrade production quality.
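A sketch of such a validation gate, with all metric values supplied by hypothetical evaluation helpers:

def passes_validation_gate(candidate_acc_recent, prod_acc_recent,
                           candidate_acc_holdout, prod_acc_holdout,
                           tolerance=0.005):
    """Deploy only if the retrained model is no worse than production
    on BOTH recent data and the stable holdout set."""
    return (candidate_acc_recent >= prod_acc_recent - tolerance and
            candidate_acc_holdout >= prod_acc_holdout - tolerance)

# Example: better on recent data but regressed on the stable holdout -> blocked.
print(passes_validation_gate(0.91, 0.89, 0.84, 0.88))  # False: do not deploy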
Model Calibration
The process of adjusting model confidence scores so they accurately reflect the true probability of correct prediction.
Ensemble Methods
Techniques that combine predictions from multiple models to improve accuracy and robustness through error diversification.
Key Takeaways
1. ML reliability requires graceful degradation with fallback systems rather than brittle all-or-nothing designs.
2. Distribution shift is inevitable in production; systems must detect and adapt to changes in data distributions over time.
3. Behavioral and metamorphic testing complement traditional testing by validating learned model behaviors.
4. Model calibration enables reliable confidence-based routing and better human-AI decision collaboration.
5. Continuous monitoring and automated retraining form a feedback loop that maintains robustness over time.