MATIH Platform is in active MVP development. Documentation reflects current implementation status.
13. ML Service & MLOps
Distributed Training (Ray)

The DistributedTrainer class provides multi-framework distributed training on Ray Train, with support for data-parallel, model-parallel, pipeline-parallel, and hybrid strategies.

Training Configuration

from src.training.distributed_trainer import (
    DistributedTrainer, TrainingConfig, DataConfig, DistributedStrategy
)
 
training_config = TrainingConfig(
    model_class="sklearn.ensemble.RandomForestClassifier",
    model_params={"n_estimators": 200, "max_depth": 10},
    epochs=10,
    learning_rate=0.001,
    optimizer="adam",
    loss_function="cross_entropy",
    strategy=DistributedStrategy.DATA_PARALLEL,
    num_workers=4,
    use_gpu=False,
    resources_per_worker={"CPU": 2},
    checkpoint_frequency=1,
    keep_checkpoints=3,
    early_stopping=True,
    patience=5,
    min_delta=0.001,
    mlflow_tracking_uri="http://mlflow:5000",
    experiment_name="fraud-detection-v3",
)
 
data_config = DataConfig(
    train_path="s3://bucket/train.parquet",
    validation_path="s3://bucket/val.parquet",
    batch_size=32,
    shuffle=True,
    format="parquet",
    target_column="label",
)

Running Training

trainer = DistributedTrainer(ray_address="ray://localhost:10001")
 
result = await trainer.train(
    training_config=training_config,
    data_config=data_config,
    tenant_id="acme-corp",
    user_id="alice@acme.com",
)
 
print(result.status)               # "succeeded"
print(result.final_metrics)        # {"val_loss": 0.287, "val_accuracy": 0.912}
print(result.best_checkpoint_path) # "/tmp/matih/training/acme-corp/.../checkpoint"
print(result.duration_seconds)     # 342.5
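
trainer.train is a coroutine, so the await above assumes an async context such as a notebook or an async framework. From a plain script you would drive it with asyncio; a minimal sketch, reusing the config objects defined under Training Configuration:

import asyncio

async def main() -> None:
    trainer = DistributedTrainer(ray_address="ray://localhost:10001")
    result = await trainer.train(
        training_config=training_config,
        data_config=data_config,
        tenant_id="acme-corp",
        user_id="alice@acme.com",
    )
    print(result.status, result.duration_seconds)

asyncio.run(main())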

Distributed Strategies

Strategy            Description                                   Best For
DATA_PARALLEL       Replicate model across workers, shard data    Most training tasks
MODEL_PARALLEL      Split model across workers                    Very large models
PIPELINE_PARALLEL   Pipeline stages across workers                Deep sequential models
HYBRID              Combine data and model parallelism            Large model + large data
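
Switching strategies is a configuration change. A minimal sketch, assuming TrainingConfig's remaining fields can be left at their defaults (the model class path and resource numbers here are hypothetical, not taken from the MATIH source):

# Hypothetical large-model config; MODEL_PARALLEL splits the model itself
# across the eight GPU workers instead of replicating it on each one.
large_model_config = TrainingConfig(
    model_class="torch_models.WideTransformer",  # hypothetical torch model path
    model_params={"num_layers": 96},             # illustrative value
    strategy=DistributedStrategy.MODEL_PARALLEL,
    num_workers=8,
    use_gpu=True,
    resources_per_worker={"GPU": 1},
)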

Framework-Specific Trainers

The trainer automatically selects the appropriate Ray trainer based on the model class name:

if "torch" in model_class:     -> TorchTrainer
elif "tensorflow" in model_class: -> TensorflowTrainer
elif "xgboost" in model_class: -> XGBoostTrainer
elif "lightgbm" in model_class: -> LightGBMTrainer
else:                           -> Generic Trainer
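
Written out as runnable Python, the dispatch might look like the sketch below. The four Ray Train trainer classes are Ray's real API; select_trainer_cls itself is illustrative, not the actual MATIH implementation:

from ray.train.lightgbm import LightGBMTrainer
from ray.train.tensorflow import TensorflowTrainer
from ray.train.torch import TorchTrainer
from ray.train.xgboost import XGBoostTrainer

def select_trainer_cls(model_class: str):
    """Map a dotted model class path to a framework-specific Ray trainer,
    or None to fall back to the generic trainer path."""
    if "torch" in model_class:
        return TorchTrainer
    if "tensorflow" in model_class:
        return TensorflowTrainer
    if "xgboost" in model_class:
        return XGBoostTrainer
    if "lightgbm" in model_class:
        return LightGBMTrainer
    return None

select_trainer_cls("sklearn.ensemble.RandomForestClassifier")  # -> None (generic)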

Data Loading

Data is loaded using Ray Data and automatically sharded across workers:

# Supported formats: parquet, csv, json
train_ds = ray.data.read_parquet(data_config.train_path)
# Shard equally across workers
datasets["train"] = train_ds.split(world_size, equal=True)[world_rank]

Early Stopping

Training stops early once validation loss fails to improve by at least min_delta for patience consecutive epochs:

# Config
early_stopping=True
patience=5        # Stop after 5 epochs without improvement
min_delta=0.001   # Minimum improvement threshold
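
The counter logic behind those two knobs is a standard patience loop. A self-contained sketch of the semantics (illustrative, not the MATIH source):

class EarlyStopping:
    """Stop after `patience` epochs without a `min_delta` improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True to stop training."""
        if self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Loss plateaus after epoch 2, so the fifth non-improving epoch triggers the stop.
stopper = EarlyStopping(patience=5, min_delta=0.001)
for epoch, loss in enumerate([0.9, 0.5, 0.3, 0.299, 0.299, 0.299, 0.299, 0.299]):
    if stopper.step(loss):
        print(f"early stop at epoch {epoch}")  # epoch 7
        break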

Source Files

File                     Path
DistributedTrainer       data-plane/ml-service/src/training/distributed_trainer.py
DeepSpeed/FSDP Trainer   data-plane/ml-service/src/training/deepspeed_fsdp_trainer.py
Distributed Workflow     data-plane/ml-service/src/training/distributed_workflow_service.py