# Distributed Training (Ray)
The `DistributedTrainer` class provides multi-framework distributed training on Ray Train, with support for data-parallel, model-parallel, pipeline-parallel, and hybrid strategies.
## Training Configuration
```python
from src.training.distributed_trainer import (
    DistributedTrainer,
    TrainingConfig,
    DataConfig,
    DistributedStrategy,
)

training_config = TrainingConfig(
    model_class="sklearn.ensemble.RandomForestClassifier",
    model_params={"n_estimators": 200, "max_depth": 10},
    epochs=10,
    learning_rate=0.001,
    optimizer="adam",
    loss_function="cross_entropy",
    strategy=DistributedStrategy.DATA_PARALLEL,
    num_workers=4,
    use_gpu=False,
    resources_per_worker={"CPU": 2},
    checkpoint_frequency=1,
    keep_checkpoints=3,
    early_stopping=True,
    patience=5,
    min_delta=0.001,
    mlflow_tracking_uri="http://mlflow:5000",
    experiment_name="fraud-detection-v3",
)

data_config = DataConfig(
    train_path="s3://bucket/train.parquet",
    validation_path="s3://bucket/val.parquet",
    batch_size=32,
    shuffle=True,
    format="parquet",
    target_column="label",
)
```
## Running Training

```python
trainer = DistributedTrainer(ray_address="ray://localhost:10001")

result = await trainer.train(
    training_config=training_config,
    data_config=data_config,
    tenant_id="acme-corp",
    user_id="alice@acme.com",
)

print(result.status)                # "succeeded"
print(result.final_metrics)         # {"val_loss": 0.287, "val_accuracy": 0.912}
print(result.best_checkpoint_path)  # "/tmp/matih/training/acme-corp/.../checkpoint"
print(result.duration_seconds)      # 342.5
```
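Because `train()` is a coroutine, it has to run inside an event loop. A minimal sketch for a standalone script, assuming `training_config` and `data_config` are built as above (the failure handling here is illustrative; only the `"succeeded"` status is confirmed by the docs):

```python
import asyncio

from src.training.distributed_trainer import DistributedTrainer

async def main() -> None:
    # Connect to the Ray cluster via Ray Client and submit the job.
    trainer = DistributedTrainer(ray_address="ray://localhost:10001")
    result = await trainer.train(
        training_config=training_config,
        data_config=data_config,
        tenant_id="acme-corp",
        user_id="alice@acme.com",
    )
    if result.status != "succeeded":
        raise RuntimeError(f"Training ended with status {result.status!r}")
    print(result.final_metrics)

asyncio.run(main())
```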
## Distributed Strategies

| Strategy | Description | Best For |
|---|---|---|
| `DATA_PARALLEL` | Replicate model across workers, shard data | Most training tasks |
| `MODEL_PARALLEL` | Split model across workers | Very large models |
| `PIPELINE_PARALLEL` | Pipeline stages across workers | Deep sequential models |
| `HYBRID` | Combine data and model parallelism | Large model + large data |
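Switching strategies is a matter of changing the `strategy` field. For example, a hybrid run that shards both the model and the data might look like the following; the worker counts and GPU settings are illustrative, and this assumes the config object is mutable (otherwise, construct a new `TrainingConfig`):

```python
# Illustrative hybrid setup: shard both the model and the data.
training_config.strategy = DistributedStrategy.HYBRID
training_config.num_workers = 8
training_config.use_gpu = True
training_config.resources_per_worker = {"GPU": 1}
```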
## Framework-Specific Trainers
The trainer automatically selects the appropriate Ray trainer based on the model class name:
if "torch" in model_class: -> TorchTrainer
elif "tensorflow" in model_class: -> TensorflowTrainer
elif "xgboost" in model_class: -> XGBoostTrainer
elif "lightgbm" in model_class: -> LightGBMTrainer
else: -> Generic TrainerData Loading
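A runnable sketch of that dispatch as a helper function; `select_trainer_cls` is a hypothetical name, not the method in the source, while the Ray trainer imports are the real classes:

```python
from ray.train.lightgbm import LightGBMTrainer
from ray.train.tensorflow import TensorflowTrainer
from ray.train.torch import TorchTrainer
from ray.train.xgboost import XGBoostTrainer

def select_trainer_cls(model_class: str):
    """Map a fully qualified model class name to a Ray trainer class."""
    name = model_class.lower()
    if "torch" in name:
        return TorchTrainer
    if "tensorflow" in name:
        return TensorflowTrainer
    if "xgboost" in name:
        return XGBoostTrainer
    if "lightgbm" in name:
        return LightGBMTrainer
    return None  # caller falls back to the generic trainer
```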
## Data Loading

Data is loaded using Ray Data and automatically sharded across workers:
```python
# Supported formats: parquet, csv, json
train_ds = ray.data.read_parquet(data_config.train_path)

# Shard equally across workers
datasets["train"] = train_ds.split(world_size, equal=True)[world_rank]
```
## Early Stopping

Training automatically stops when validation loss stops improving:
```python
# Config
early_stopping=True
patience=5       # Stop after 5 epochs without improvement
min_delta=0.001  # Minimum improvement threshold
```
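Conceptually, the check behind those settings looks like this; a simplified sketch of the logic, not the actual implementation, with `run_epoch` as a stand-in for one training epoch:

```python
best_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(epochs):
    val_loss = run_epoch()  # hypothetical: trains one epoch, returns validation loss

    if best_loss - val_loss > min_delta:
        # Loss improved by at least min_delta: record it and reset patience.
        best_loss = val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break  # patience exhausted, stop training early
```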
## Source Files

| Component | Path |
|---|---|
| DistributedTrainer | data-plane/ml-service/src/training/distributed_trainer.py |
| DeepSpeed/FSDP Trainer | data-plane/ml-service/src/training/deepspeed_fsdp_trainer.py |
| Distributed Workflow | data-plane/ml-service/src/training/distributed_workflow_service.py |