Model Registries and Drift Detection
Summary
A model in production without a registry is a model you cannot reproduce, audit, or roll back. This section builds a complete MLflow workflow: logging parameters, metrics, and artifacts during training, registering model versions with stage transitions, and loading production models by name. It then confronts the three ways models degrade silently. Data drift — when feature distributions shift between training and serving — is detected with the Kolmogorov-Smirnov test and the Population Stability Index. Concept drift — when the relationship between features and targets changes — is detected by monitoring prediction-outcome correlation over time. Feature degradation — when an upstream pipeline breaks and a feature becomes constant, null, or nonsensical — is caught with variance, null rate, and cardinality monitors. Each detector produces actionable alerts with thresholds calibrated to avoid both alert fatigue and missed regressions.
11.1 — Model Registries
Ask yourself four questions about the model currently serving production traffic. What hyperparameters was it trained with? What version of the training data did it use? How did it perform on the evaluation set? When was it last retrained?
If you cannot answer all four without digging through Slack messages, Jupyter notebooks, or someone’s memory, you have a registry problem. And a registry problem becomes an incident response problem the first time production predictions go wrong and you need to roll back to the last known-good model.
MLflow: Track Everything, Regret Nothing
MLflow is the de facto open-source standard for experiment tracking and model management. It solves three problems: recording what you tried (experiment tracking), storing what you produced (model registry), and loading what you need (model serving).
Here is a complete workflow — training a model, logging everything, registering the result, and loading it for inference:
```python
import mlflow
import mlflow.sklearn
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split


def train_and_register(
    n_estimators: int = 200,
    max_depth: int = 5,
    learning_rate: float = 0.1,
    experiment_name: str = "churn_prediction",
    model_name: str = "churn_model",
) -> str:
    """
    Train a model, log everything to MLflow, and register it.
    Returns the run ID for downstream reference.
    """
    mlflow.set_experiment(experiment_name)

    # Synthetic stand-in data (replace with your actual data loading)
    X, y = make_classification(
        n_samples=10_000, n_features=20, n_informative=12,
        n_redundant=3, weights=[0.7, 0.3], random_state=42,
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42,
    )

    with mlflow.start_run() as run:
        # Log training parameters: every knob you turned, plus data characteristics
        mlflow.log_params({
            "n_estimators": n_estimators,
            "max_depth": max_depth,
            "learning_rate": learning_rate,
            "train_samples": len(X_train),
            "test_samples": len(X_test),
            "positive_rate": float(y_train.mean()),
        })

        # Train
        model = GradientBoostingClassifier(
            n_estimators=n_estimators,
            max_depth=max_depth,
            learning_rate=learning_rate,
            random_state=42,
        )
        model.fit(X_train, y_train)

        # Evaluate on the held-out split
        y_pred = model.predict(X_test)
        metrics = {
            "f1": f1_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred),
            "recall": recall_score(y_test, y_pred),
        }
        mlflow.log_metrics(metrics)

        # Store the serialized model as a run artifact AND register it
        # as a new version of model_name in the model registry
        mlflow.sklearn.log_model(
            model, "model",
            registered_model_name=model_name,
        )

        # Log feature importance as a JSON artifact
        importance = dict(zip(
            [f"feature_{i}" for i in range(X_train.shape[1])],
            model.feature_importances_.tolist(),
        ))
        mlflow.log_dict(importance, "feature_importance.json")

    print(f"Run ID: {run.info.run_id}")
    print(f"Metrics: {metrics}")
    return run.info.run_id


def load_production_model(model_name: str = "churn_model") -> object:
    """
    Load the model version currently in the Production stage.
    Registry stages:
        None -> Staging -> Production -> Archived
    New versions start in None, so this URI resolves only after a version
    has been transitioned to Production. Recent MLflow releases deprecate
    stages in favor of aliases (models:/<name>@<alias>).
    """
    model_uri = f"models:/{model_name}/Production"
    return mlflow.sklearn.load_model(model_uri)


# Usage:
# run_id = train_and_register(n_estimators=300, max_depth=4)
# model = load_production_model()
# predictions = model.predict(new_data)
```
Three things to notice. First, log_params captures every decision you made: not just hyperparameters, but data characteristics like the positive class rate and sample counts. When you are debugging a model six months from now, this metadata is the difference between “I think we trained on about 10K samples” and knowing exactly what happened. Second, log_model with registered_model_name does two things in one call: it stores the serialized model as an artifact and registers a new version in the model registry. That new version starts with no stage, so it serves nothing until you promote it. Third, load_model with the stage URI (models:/churn_model/Production) means your inference code never references a file path or a version number. It references a logical name and a stage. You promote a model to Production in the MLflow UI (or through the client API), and the inference service picks up the new version the next time it loads that URI.
Beyond MLflow: Alternatives and Model Cards
Weights & Biases offers richer visualization and better team collaboration features, but it is primarily a hosted (SaaS) product with usage-based pricing. Neptune provides similar capabilities with a focus on experiment comparison. DVC is the lightweight alternative: it versions data and models using Git-like semantics, but on its own it does not ship a tracking server UI or a centralized model registry comparable to MLflow's. For teams under ten people with straightforward workflows, DVC plus a naming convention can be sufficient. For anything larger, the registry abstraction that MLflow provides saves more time than it costs.
Regardless of your registry choice, every registered model should have a model card: a document that states the model’s intended use, known limitations, training data characteristics, fairness considerations, and performance across subgroups. A model card is not bureaucracy — it is the documentation that prevents someone from using your churn model to make lending decisions, or deploying a model trained on US data to serve European users without revalidation.
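A model card is easiest to keep current when it is generated from structured fields rather than written freehand. The sketch below covers the fields listed above plus an explicit out-of-scope list (the churn-model-for-lending case); the `ModelCard` schema and its `to_markdown` method are hypothetical, not a standard format.

```python
from dataclasses import dataclass, field


@dataclass
class ModelCard:
    """Hypothetical minimal schema for the fields a model card should state."""
    model_name: str
    version: str
    intended_use: str
    out_of_scope_uses: list[str]
    known_limitations: list[str]
    training_data: str
    fairness_considerations: str
    # Subgroup name -> metric name -> value, e.g. {"US": {"f1": 0.81}}
    subgroup_metrics: dict[str, dict[str, float]] = field(default_factory=dict)

    def to_markdown(self) -> str:
        lines = [
            f"# Model card: {self.model_name} v{self.version}", "",
            "## Intended use", self.intended_use, "",
            "## Out-of-scope uses", *[f"- {use}" for use in self.out_of_scope_uses], "",
            "## Known limitations", *[f"- {item}" for item in self.known_limitations], "",
            "## Training data", self.training_data, "",
            "## Fairness considerations", self.fairness_considerations, "",
            "## Performance by subgroup",
        ]
        for group, metrics in self.subgroup_metrics.items():
            formatted = ", ".join(f"{name}={value:.3f}" for name, value in metrics.items())
            lines.append(f"- {group}: {formatted}")
        return "\n".join(lines)
```

With MLflow, the rendered card can travel with the model by calling `mlflow.log_text(card.to_markdown(), "model_card.md")` inside the training run.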
11.2 — Detecting Drift
Your model learned a function that maps features to a target. That function was correct for the training data distribution. The moment production data deviates from that distribution, the function’s guarantees evaporate. This deviation comes in three forms, each with different causes and different detectors.
Data Drift: Features Shift
Data drift occurs when the distribution of input features changes between training and serving. The model itself has not changed. The relationship between features and target has not changed. But the inputs the model receives in production look different from the inputs it was trained on.
Examples: a credit scoring model trained on pre-2020 income distributions receives post-pandemic applications where income volatility has doubled. A recommender trained on desktop browsing behavior receives mobile traffic with shorter sessions and different click patterns. A fraud detector trained on card-present transactions receives a surge of card-not-present transactions during a holiday sale.
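The Kolmogorov-Smirnov test is the workhorse for detecting univariate data drift in continuous features. It compares two samples and returns a statistic D, the maximum absolute difference between their empirical cumulative distribution functions, together with a p-value for the hypothesis that both samples come from the same distribution: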
```python
from dataclasses import dataclass

import numpy as np
from scipy import stats


@dataclass
class DriftResult:
    feature: str
    statistic: float
    p_value: float
    is_drifted: bool


def detect_data_drift(
    reference: dict[str, np.ndarray],
    current: dict[str, np.ndarray],
    p_threshold: float = 0.01,
) -> list[DriftResult]:
    """
    Compare feature distributions between reference (training) and current
    (production) data using the two-sample KS test.

    Args:
        reference: Feature name -> array of values from training data.
        current: Feature name -> array of values from production window.
        p_threshold: P-value below which drift is flagged. Use 0.01, not 0.05.
            With large production samples, 0.05 triggers on statistically
            significant but practically irrelevant shifts.

    Returns:
        List of DriftResult, sorted by KS statistic, largest (worst) first.
        Features with no non-null values on either side are skipped; the
        degradation checks cover them.
    """
    results: list[DriftResult] = []
    for feature_name in reference:
        if feature_name not in current:
            # Feature missing entirely: that is degradation, not drift
            results.append(DriftResult(
                feature=feature_name, statistic=1.0,
                p_value=0.0, is_drifted=True,
            ))
            continue
        # Drop nulls: KS compares value distributions; null spikes are
        # a degradation problem, not a drift problem
        ref_values = np.asarray(reference[feature_name], dtype=float)
        cur_values = np.asarray(current[feature_name], dtype=float)
        ref_values = ref_values[~np.isnan(ref_values)]
        cur_values = cur_values[~np.isnan(cur_values)]
        if ref_values.size == 0 or cur_values.size == 0:
            continue
        ks_stat, p_value = stats.ks_2samp(ref_values, cur_values)
        results.append(DriftResult(
            feature=feature_name,
            statistic=float(ks_stat),
            p_value=float(p_value),
            is_drifted=p_value < p_threshold,
        ))
    # Sort by KS statistic descending: worst drift first
    results.sort(key=lambda r: r.statistic, reverse=True)
    return results
```
The p-threshold of 0.01 deserves explanation. The KS p-value shrinks as sample size grows, so with production samples of 10,000+ observations a test at the conventional 0.05 threshold will flag tiny distribution shifts that have no practical impact on model performance. Use 0.01 or lower, and always pair statistical drift detection with a check on the actual prediction impact.
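One practical complement to a stricter p-value is an effect-size floor on the KS statistic itself. A minimal sketch, assuming a floor of `min_statistic=0.1`; both the floor and the `is_material_drift` name are illustrative assumptions, to be calibrated against your own incidents. On simulated data with 50,000 observations per sample, a shift of one tenth of a standard deviation already drives the p-value far below 0.01, while the statistic stays well under the floor.

```python
import numpy as np
from scipy import stats


def is_material_drift(reference, current, p_threshold=0.01, min_statistic=0.1):
    """Flag drift only if it is statistically significant AND large.

    min_statistic is an assumed floor on the KS statistic (the maximum gap
    between the two empirical CDFs). Returns (flag, KS result).
    """
    result = stats.ks_2samp(reference, current)
    flagged = result.pvalue < p_threshold and result.statistic >= min_statistic
    return flagged, result


rng = np.random.default_rng(0)
ref = rng.normal(0.0, 1.0, 50_000)
tiny_shift = rng.normal(0.1, 1.0, 50_000)  # 0.1 standard deviations
large_shift = rng.normal(1.0, 1.0, 50_000)  # 1 standard deviation

tiny_flag, tiny_result = is_material_drift(ref, tiny_shift)
large_flag, large_result = is_material_drift(ref, large_shift)
```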
Population Stability Index (PSI)
The PSI is more interpretable than a p-value for stakeholder reporting. It quantifies how much a distribution has shifted by comparing bin proportions between reference and current data:
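Written out, with $r_i$ and $c_i$ the proportions of reference and current observations in bin $i$ of $B$ bins, the index is:

$$
\mathrm{PSI} = \sum_{i=1}^{B} (c_i - r_i)\,\ln\frac{c_i}{r_i}
$$

Each term is non-negative, because $c_i - r_i$ and $\ln(c_i / r_i)$ always share a sign; the sum equals the symmetrized Kullback-Leibler divergence $\mathrm{KL}(c\,\|\,r) + \mathrm{KL}(r\,\|\,c)$. A two-bin worked example shows the scale of the thresholds:

$$
r = (0.5, 0.5),\; c = (0.6, 0.4):\quad \mathrm{PSI} = 0.1\ln 1.2 + (-0.1)\ln 0.8 \approx 0.0182 + 0.0223 \approx 0.041
$$

$$
r = (0.5, 0.5),\; c = (0.8, 0.2):\quad \mathrm{PSI} = 0.3\ln 1.6 + (-0.3)\ln 0.4 \approx 0.1410 + 0.2749 \approx 0.416
$$

A 10-point swing between two bins stays under 0.1; a 30-point swing is well past 0.2.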
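```python
import numpy as np


def compute_psi(
    reference: np.ndarray,
    current: np.ndarray,
    n_bins: int = 10,
    eps: float = 1e-4,
) -> float:
    """
    Population Stability Index.
    PSI < 0.1: no significant shift
    PSI 0.1–0.2: moderate shift, investigate
    PSI > 0.2: significant shift, action required
    Uses quantile-based binning from the reference distribution
    to handle skewed features correctly.
    """
    # Ignore nulls: np.percentile returns nan if any value is nan
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)
    reference = reference[~np.isnan(reference)]
    current = current[~np.isnan(current)]

    # Bin edges at reference quantiles (not uniform widths). np.unique drops
    # duplicate edges, which appear when a feature has many repeated values.
    quantiles = np.linspace(0, 100, n_bins + 1)
    edges = np.unique(np.percentile(reference, quantiles))
    # Open the outer bins so current values outside the reference range count
    bin_edges = np.concatenate(([-np.inf], edges[1:-1], [np.inf]))

    ref_counts = np.histogram(reference, bins=bin_edges)[0]
    cur_counts = np.histogram(current, bins=bin_edges)[0]

    # Normalize to proportions; epsilon keeps empty bins from causing
    # log(0) or division by zero
    ref_pct = ref_counts / len(reference) + eps
    cur_pct = cur_counts / len(current) + eps

    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return float(psi)
```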
Use quantile-based binning from the reference distribution, not uniform bins. With uniform-width bins, a skewed feature piles most of its observations into one or two bins and leaves the rest nearly empty, so the PSI ends up driven by a handful of sparse bins.
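The effect is easy to see on simulated data. This sketch assumes a log-normal feature as a stand-in for something right-skewed like transaction amounts, and compares how 10,000 observations spread across 10 uniform-width bins versus 10 reference-quantile bins; `bin_shares` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
# Simulated right-skewed feature (log-normal)
reference = rng.lognormal(mean=0.0, sigma=1.5, size=10_000)


def bin_shares(values: np.ndarray, edges: np.ndarray) -> np.ndarray:
    """Fraction of observations falling in each bin defined by `edges`."""
    return np.histogram(values, bins=edges)[0] / len(values)


uniform_edges = np.linspace(reference.min(), reference.max(), 11)
quantile_edges = np.percentile(reference, np.linspace(0, 100, 11))

uniform_shares = bin_shares(reference, uniform_edges)
quantile_shares = bin_shares(reference, quantile_edges)

print("uniform: ", np.round(uniform_shares, 3))
print("quantile:", np.round(quantile_shares, 3))
```

The uniform layout puts the overwhelming majority of observations in its first bin and leaves most other bins with under 1% each, while every quantile bin holds about 10%, which gives each PSI term a comparable base.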
Feature Degradation: The Silent Killer
Drift detection assumes features are present and valid but distributed differently. Feature degradation is worse — a feature has stopped being informative entirely. An upstream ETL pipeline changed its output format. A third-party API started returning nulls. A database migration silently set a column to its default value.
These failures produce no errors. The model receives a valid input, makes a prediction, and returns a response. The prediction is garbage because a critical feature is now constant, null, or nonsensical. Here is a detector that catches the three most common degradation patterns:
```python
from dataclasses import dataclass

import numpy as np


@dataclass
class DegradationAlert:
    feature: str
    alert_type: str  # "missing_feature", "null_rate", "low_variance", "cardinality_collapse"
    current_value: float
    baseline_value: float
    message: str


def detect_feature_degradation(
    reference: dict[str, np.ndarray],
    current: dict[str, np.ndarray],
    null_rate_threshold: float = 0.05,
    variance_ratio_threshold: float = 0.1,
    cardinality_ratio_threshold: float = 0.3,
) -> list[DegradationAlert]:
    """
    Detect features that have degraded: not shifted in distribution,
    but broken structurally. Assumes numeric features with NaN as null.
    Checks three failure modes (plus an entirely missing column):
    1. Null rate spike: null rate exceeds null_rate_threshold (5%) and is
       at least 3x the training null rate
    2. Variance collapse: variance dropped below 10% of training variance
    3. Cardinality collapse: unique values dropped below 30% of training,
       compared on equal-size samples so a smaller window is not penalized
    """
    alerts: list[DegradationAlert] = []
    rng = np.random.default_rng(0)  # fixed seed: reproducible subsampling
    for feature_name, ref_raw in reference.items():
        cur_raw = current.get(feature_name)
        if cur_raw is None:
            alerts.append(DegradationAlert(
                feature=feature_name, alert_type="missing_feature",
                current_value=0.0, baseline_value=1.0,
                message=f"Feature '{feature_name}' is entirely absent from "
                        f"production data.",
            ))
            continue
        ref_values = np.asarray(ref_raw, dtype=float)
        cur_values = np.asarray(cur_raw, dtype=float)

        # 1. Null rate spike
        ref_null_rate = float(np.isnan(ref_values).mean())
        cur_null_rate = float(np.isnan(cur_values).mean())
        if cur_null_rate > null_rate_threshold and cur_null_rate > ref_null_rate * 3:
            alerts.append(DegradationAlert(
                feature=feature_name, alert_type="null_rate",
                current_value=cur_null_rate, baseline_value=ref_null_rate,
                message=f"Null rate for '{feature_name}' jumped from "
                        f"{ref_null_rate:.1%} to {cur_null_rate:.1%}.",
            ))

        ref_clean = ref_values[~np.isnan(ref_values)]
        cur_clean = cur_values[~np.isnan(cur_values)]
        if ref_clean.size == 0 or cur_clean.size == 0:
            continue  # nothing left to measure; the null-rate check covers this

        # 2. Variance collapse
        ref_var = float(np.var(ref_clean))
        cur_var = float(np.var(cur_clean))
        if ref_var > 0:
            variance_ratio = cur_var / ref_var
            if variance_ratio < variance_ratio_threshold:
                alerts.append(DegradationAlert(
                    feature=feature_name, alert_type="low_variance",
                    current_value=cur_var, baseline_value=ref_var,
                    message=f"Variance of '{feature_name}' collapsed to "
                            f"{variance_ratio:.1%} of training variance. "
                            f"Feature may be constant.",
                ))

        # 3. Cardinality collapse, on equal-size random samples: a 1,000-row
        # window of a continuous feature cannot have as many distinct values
        # as 8,000 training rows, and that is not degradation
        if len(np.unique(ref_clean)) > 10:  # only features with meaningful cardinality
            n = min(ref_clean.size, cur_clean.size)
            ref_unique = len(np.unique(rng.choice(ref_clean, size=n, replace=False)))
            cur_unique = len(np.unique(rng.choice(cur_clean, size=n, replace=False)))
            cardinality_ratio = cur_unique / ref_unique
            if cardinality_ratio < cardinality_ratio_threshold:
                alerts.append(DegradationAlert(
                    feature=feature_name, alert_type="cardinality_collapse",
                    current_value=float(cur_unique), baseline_value=float(ref_unique),
                    message=f"Unique values for '{feature_name}' dropped from "
                            f"{ref_unique} to {cur_unique} (equal-size samples).",
                ))
    return alerts
```
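Concept drift, where the relationship between features and target changes, slips past every check so far: the inputs can keep exactly their training distribution while the outcomes they predict move. Catching it means monitoring prediction-outcome correlation once ground-truth labels arrive. A minimal sketch, assuming labels have already been joined to each window's predicted probabilities; the function names and the 50% relative-drop threshold are illustrative assumptions.

```python
import numpy as np


def prediction_outcome_correlation(scores: np.ndarray, outcomes: np.ndarray) -> float:
    """Pearson correlation between predicted probabilities and 0/1 outcomes
    (the point-biserial correlation). Returns nan if either side is constant."""
    if np.std(scores) == 0 or np.std(outcomes) == 0:
        return float("nan")
    return float(np.corrcoef(scores, outcomes)[0, 1])


def detect_concept_drift(baseline_corr, windows, max_relative_drop=0.5):
    """windows: iterable of (label, scores, outcomes) tuples.

    Flags windows whose correlation fell more than max_relative_drop below
    baseline_corr (for example, measured on the held-out test set), or
    could not be computed at all.
    """
    alerts = []
    for label, scores, outcomes in windows:
        corr = prediction_outcome_correlation(scores, outcomes)
        if np.isnan(corr) or corr < baseline_corr * (1 - max_relative_drop):
            alerts.append((label, corr))
    return alerts
```

Labels usually lag predictions (a churn outcome may take weeks to materialize), so this detector reports late by design; the input-side checks remain the early warning.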
Alert Thresholds: The Goldilocks Problem
Two failure modes of alerting are equally dangerous. Too sensitive: you alert on every minor distribution shift, your team develops alert fatigue, and when a real drift event occurs, nobody investigates because the alert channel is full of noise. Too lenient: you only alert on catastrophic drift, and by the time the alert fires, the model has been serving degraded predictions for weeks.
The calibration approach that works in practice:
- Start lenient. Set thresholds that would have caught the most severe drift event in your historical data, for example PSI > 0.25, KS p-value < 0.001, null rate > 10%.
- Track false negatives. Every time a model issue is discovered by someone other than the monitoring system, ask: which threshold would have caught this, and how much sooner?
- Tighten gradually. Lower thresholds until the monitoring system catches issues before humans do, then stop.
- Window size matters. Compare weekly production windows against the training baseline, not individual batches. Daily windows are noisy. Hourly windows are chaos.
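A sketch of the weekly-window comparison, assuming each production observation carries an integer day offset (a hypothetical schema). The PSI helper repeats the quantile-binning approach from earlier in this section with duplicate edges collapsed; `weekly_psi` and the 0.25 cutoff, taken from the lenient starting point above, are illustrative.

```python
import numpy as np

PSI_ALERT = 0.25  # lenient starting threshold; tighten as false negatives show up


def psi(reference: np.ndarray, current: np.ndarray, n_bins: int = 10, eps: float = 1e-4) -> float:
    """PSI with bins at reference quantiles and open-ended outer bins."""
    edges = np.unique(np.percentile(reference, np.linspace(0, 100, n_bins + 1)))
    edges = np.concatenate(([-np.inf], edges[1:-1], [np.inf]))
    ref = np.histogram(reference, bins=edges)[0] / len(reference) + eps
    cur = np.histogram(current, bins=edges)[0] / len(current) + eps
    return float(np.sum((cur - ref) * np.log(cur / ref)))


def weekly_psi(reference: np.ndarray, values: np.ndarray, days: np.ndarray,
               psi_alert: float = PSI_ALERT) -> dict[int, tuple[float, bool]]:
    """Compare each 7-day production window against the training baseline.

    days: integer day offset of each observation in `values`.
    Returns week index -> (PSI, flagged).
    """
    weeks = days // 7
    report = {}
    for week in np.unique(weeks):
        score = psi(reference, values[weeks == week])
        report[int(week)] = (score, score > psi_alert)
    return report
```

Pooling a week of observations gives each window thousands of samples, so sampling noise in the per-bin proportions stays small relative to a real shift.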
The goal is not zero drift. The goal is catching drift that matters — drift that degrades predictions enough to affect the business metric you care about — before the business notices.