Model Registry (MLflow)
The ML Service uses MLflow as its model registry, providing experiment tracking, model versioning, artifact management, and model stage lifecycle management. The integration is implemented in registry/model_registry.py and provides a tenant-scoped abstraction over MLflow's native API.
MLflow Integration Architecture
ML Service
|
+-- ModelRegistry (abstraction layer)
| |
| +-- MLflow Tracking API
| | +-- Experiments
| | +-- Runs
| | +-- Metrics/Params
| |
| +-- MLflow Model Registry API
| | +-- Registered Models
| | +-- Model Versions
| | +-- Stage Transitions
| |
| +-- MLflow Artifact Store
| +-- MinIO/S3 Backend
| +-- Model Binaries
| +-- Training Artifacts
|
v
MLflow Server (HTTP)
|
+-- PostgreSQL (metadata)
+-- MinIO/S3 (artifacts)
Model Metadata
@dataclass
class ModelMetadata:
    """Model metadata stored in the registry.

    A tenant-scoped snapshot of a registered model: identity, framework,
    version history, lifecycle stage, and the latest logged metrics,
    parameters, and tags.
    """

    model_id: str        # registry identifier, unique within the tenant
    tenant_id: str       # owning tenant; all registry lookups are scoped by this
    name: str            # human-readable model name
    description: str
    model_type: str  # classification, regression, clustering
    framework: str  # sklearn, pytorch, tensorflow, onnx
    current_version: str   # version currently considered "latest"
    versions: list[str]    # all known version identifiers
    stage: str  # development, staging, production, archived
    metrics: dict[str, float]     # latest evaluation metrics for the run
    parameters: dict[str, Any]    # training hyperparameters
    tags: dict[str, str]          # free-form tags (includes tenant_id/framework)
    created_at: datetime
    updated_at: datetime
    created_by: str               # user or service that registered the model
    artifact_uri: str             # MLflow artifact store location of the model
Model Registry Operations
Register Model
class ModelRegistry:
    """MLflow-backed model registry with tenant isolation.

    All registry entities are namespaced per tenant: experiments live under
    ``tenants/{tenant_id}/models/{name}`` and registered models under
    ``{tenant_id}/{name}``, so one tenant can never see or mutate another's.
    """

    async def register_model(
        self,
        tenant_id: str,
        name: str,
        model: Any,
        framework: str,
        metrics: dict[str, float],
        parameters: dict[str, Any],
        tags: dict[str, str] | None = None,
    ) -> ModelMetadata:
        """Register a new model or create a new version.

        Args:
            tenant_id: Owning tenant; namespaces the experiment and the
                registered-model name.
            name: Model name, unique within the tenant.
            model: Trained model object (framework-specific type).
            framework: One of "sklearn", "pytorch", "tensorflow", "onnx".
            metrics: Final evaluation metrics to log on the run.
            parameters: Training hyperparameters to log on the run.
            tags: Optional extra run tags, merged over the built-in
                tenant_id/framework tags.

        Returns:
            Metadata for the newly registered model version.

        Raises:
            ValueError: If ``framework`` is not a supported MLflow flavor.
        """
        # Tenant-scoped experiment name keeps tenants isolated.
        experiment_name = f"tenants/{tenant_id}/models/{name}"
        # MLflow has no get_or_create_experiment(); emulate it with a
        # lookup followed by creation on miss.
        experiment = mlflow.get_experiment_by_name(experiment_name)
        experiment_id = (
            experiment.experiment_id
            if experiment is not None
            else mlflow.create_experiment(experiment_name)
        )
        with mlflow.start_run(experiment_id=experiment_id) as run:
            # Log parameters and metrics on the new run.
            mlflow.log_params(parameters)
            mlflow.log_metrics(metrics)
            # Log the model with the matching MLflow flavor; fail fast on an
            # unknown framework instead of silently registering no artifact.
            if framework == "sklearn":
                mlflow.sklearn.log_model(model, "model")
            elif framework == "pytorch":
                mlflow.pytorch.log_model(model, "model")
            elif framework == "tensorflow":
                mlflow.tensorflow.log_model(model, "model")
            elif framework == "onnx":
                mlflow.onnx.log_model(model, "model")
            else:
                raise ValueError(f"Unsupported framework: {framework}")
            # Register the run's model artifact under the tenant-scoped name.
            # Use the context-managed run handle rather than active_run().
            model_uri = f"runs:/{run.info.run_id}/model"
            result = mlflow.register_model(
                model_uri,
                f"{tenant_id}/{name}",
            )
            # Tag the run with tenant information for later filtering.
            mlflow.set_tags({
                "tenant_id": tenant_id,
                "framework": framework,
                **(tags or {}),
            })
        return self._build_metadata(result, tenant_id)
Get Model
async def get_model(
    self,
    tenant_id: str,
    model_id: str,
) -> ModelMetadata | None:
    """Get model metadata by ID.

    Looks up the tenant-scoped registered-model name
    ``{tenant_id}/{model_id}``; returns None when the registry raises
    (e.g. the model does not exist for this tenant).
    """
    try:
        model = mlflow.MlflowClient().get_registered_model(
            f"{tenant_id}/{model_id}"
        )
        return self._build_metadata(model, tenant_id)
    # NOTE(review): any MlflowException is treated as "not found" —
    # transient server errors are also mapped to None; confirm intended.
    except mlflow.exceptions.MlflowException:
        return None
List Models
async def list_models(
    self,
    tenant_id: str,
    stage: str | None = None,
    framework: str | None = None,
) -> list[ModelMetadata]:
    """List all models for a tenant, optionally filtered.

    Args:
        tenant_id: Tenant whose models are listed (name-prefix scoped).
        stage: Optional lifecycle stage filter, matched on the ``stage`` tag.
        framework: Optional framework filter, matched on the ``framework``
            tag. (Bug fix: this parameter was previously accepted but
            silently ignored.)

    Returns:
        Metadata for every registered model matching the filters.
    """
    # NOTE(review): values are interpolated into MLflow's filter DSL as-is;
    # tenant_id/stage/framework must come from trusted code, or be quoted,
    # since a stray single quote would break (or widen) the filter.
    filter_string = f"name LIKE '{tenant_id}/%'"
    if stage:
        filter_string += f" AND tags.stage = '{stage}'"
    if framework:
        filter_string += f" AND tags.framework = '{framework}'"
    models = mlflow.MlflowClient().search_registered_models(
        filter_string=filter_string,
    )
    return [self._build_metadata(m, tenant_id) for m in models]
Model Stage Lifecycle
DEVELOPMENT --> STAGING --> PRODUCTION --> ARCHIVED
| | |
v v v
(deleted) (deleted) DEVELOPMENT
(rollback)
Stage Transitions
async def transition_stage(
    self,
    tenant_id: str,
    model_id: str,
    version: str,
    target_stage: str,
) -> ModelMetadata:
    """Transition a model version to a new stage.

    Args:
        tenant_id: Owning tenant.
        model_id: Model name within the tenant.
        version: Version identifier to transition.
        target_stage: One of "development", "staging", "production",
            "archived".

    Propagates whatever ``_validate_transition`` raises when the requested
    transition is not allowed by the lifecycle rules.
    """
    client = mlflow.MlflowClient()
    registered_name = f"{tenant_id}/{model_id}"
    # Bug fix: the original validated against an undefined name
    # ``current_stage`` (NameError). Fetch the version's actual stage.
    current_stage = client.get_model_version(
        name=registered_name,
        version=version,
    ).current_stage
    # Validate the transition against the allowed lifecycle graph.
    self._validate_transition(current_stage, target_stage)
    # Only one version may be in production: archive the incumbent first.
    if target_stage == "production":
        current_prod = self._get_production_version(
            tenant_id, model_id
        )
        if current_prod:
            client.transition_model_version_stage(
                name=registered_name,
                version=current_prod,
                stage="archived",
            )
    # Execute the requested transition.
    client.transition_model_version_stage(
        name=registered_name,
        version=version,
        stage=target_stage,
    )
    return await self.get_model(tenant_id, model_id)
| Transition | Requirements |
|---|---|
| Development -> Staging | Validation pipeline passes |
| Staging -> Production | A/B test or shadow deployment passes |
| Production -> Archived | New production version promoted |
| Any -> Development | Rollback request |
Experiment Tracking
Logging Metrics
async def log_training_metrics(
    self,
    tenant_id: str,
    run_id: str,
    step: int,
    metrics: dict[str, float],
) -> None:
    """Log training metrics for a run.

    Args:
        tenant_id: Owning tenant (not used by the MLflow call itself;
            presumably kept for interface symmetry — TODO confirm).
        run_id: MLflow run the metrics are attached to.
        step: Training step/epoch index for the metric history.
        metrics: Metric name -> value pairs observed at this step.
    """
    client = mlflow.MlflowClient()
    # One log_metric call per entry so each metric keeps its own
    # per-step history in the tracking server.
    for key, value in metrics.items():
        client.log_metric(run_id, key, value, step=step)
Comparing Experiments
async def compare_runs(
    self,
    tenant_id: str,
    run_ids: list[str],
    metrics: list[str],
) -> list[dict[str, Any]]:
    """Compare multiple experiment runs.

    Builds one record per run containing its logged parameters and the
    requested metrics; a metric missing from a run comes back as None.
    """
    comparisons = []
    for current_run_id in run_ids:
        run_data = mlflow.get_run(current_run_id).data
        # Pull only the requested metrics, defaulting to None when absent.
        selected_metrics = {}
        for metric_name in metrics:
            selected_metrics[metric_name] = run_data.metrics.get(metric_name)
        comparison = {
            "run_id": current_run_id,
            "parameters": run_data.params,
            "metrics": selected_metrics,
        }
        comparisons.append(comparison)
    return comparisons
Artifact Management
Model artifacts are stored in MinIO/S3 and managed through MLflow:
| Artifact Type | Storage Path | Format |
|---|---|---|
| Model binary | s3://mlflow/{tenant_id}/models/{name}/{version}/ | Framework-specific |
| ONNX model | s3://mlflow/{tenant_id}/models/{name}/{version}/model.onnx | ONNX |
| Training logs | s3://mlflow/{tenant_id}/runs/{run_id}/logs/ | Text |
| Metrics history | s3://mlflow/{tenant_id}/runs/{run_id}/metrics/ | JSON |
| Requirements | s3://mlflow/{tenant_id}/models/{name}/{version}/requirements.txt | Text |
| Model card | s3://mlflow/{tenant_id}/models/{name}/{version}/model_card.md | Markdown |
API Endpoints
POST /api/v1/models # Register model
GET /api/v1/models # List models
GET /api/v1/models/{model_id} # Get model
DELETE /api/v1/models/{model_id} # Delete model
POST /api/v1/models/{model_id}/versions # Create version
PUT /api/v1/models/{model_id}/stage # Transition stage
GET /api/v1/experiments # List experiments
GET /api/v1/experiments/{exp_id}/runs # List runs
GET /api/v1/experiments/compare # Compare runs