Distributed Training with Ray AIR
The ML Service uses Ray AIR (AI Runtime) as its distributed compute framework, providing unified APIs for data processing, training, hyperparameter tuning, and model serving. Ray AIR orchestrates workloads across Ray clusters, enabling horizontal scaling from a single machine to hundreds of GPUs.
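Before any workload is submitted, a client process connects to the cluster head node via Ray Client. A minimal sketch; the address and namespace values mirror the defaults in the Cluster Management table below:
import ray

# Connect to the Ray cluster head node (illustrative defaults).
ray.init(address="ray://localhost:10001", namespace="matih-ml")

# Inspect what the cluster currently exposes (CPUs, GPUs, memory).
print(ray.cluster_resources())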
Ray AIR Components
Ray AIR Ecosystem
|
+-- Ray Data ---------> Distributed data loading and preprocessing
|
+-- Ray Train --------> Distributed model training
| +-- TorchTrainer (PyTorch distributed)
| +-- XGBoostTrainer (XGBoost distributed)
| +-- SklearnTrainer (scikit-learn)
| +-- HorovodTrainer (Horovod for TF/PyTorch)
|
+-- Ray Tune ---------> Hyperparameter optimization
| +-- ASHA Scheduler (Early stopping)
| +-- Optuna (Bayesian optimization)
| +-- PBT (Population-based training)
|
+-- Ray Serve --------> Online model serving
      +-- Deployment (Auto-scaling replicas)
      +-- Pipeline (Multi-model composition)
Ray AIR Orchestrator
The orchestrator.py module in ray_air/ coordinates the Ray AIR components:
class RayAIROrchestrator:
    """Orchestrates Ray AIR workflows."""

    def __init__(
        self,
        ray_address: str = "ray://localhost:10001",
        namespace: str = "matih-ml",
    ):
        self._address = ray_address
        self._namespace = namespace

    async def submit_training(
        self,
        tenant_id: str,
        config: TrainingConfig,
        data_config: DataConfig,
    ) -> str:
        """Submit a training job to Ray."""
        import ray
        from ray.train import ScalingConfig

        # Safe to call repeatedly; subsequent calls are no-ops.
        ray.init(
            address=self._address,
            namespace=self._namespace,
            ignore_reinit_error=True,
        )

        # Build scaling configuration
        scaling = ScalingConfig(
            num_workers=config.num_workers,
            use_gpu=config.use_gpu,
            resources_per_worker={
                "CPU": 4,
                "GPU": config.gpus_per_worker if config.use_gpu else 0,
            },
        )

        # Build trainer based on framework
        trainer = self._build_trainer(config, data_config, scaling)

        # Submit as Ray job
        job_id = ray.get_runtime_context().get_job_id()
        result = trainer.fit()  # blocks until the run completes
        return job_id
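A sketch of how a caller might drive the orchestrator; the TrainingConfig and DataConfig constructor arguments are assumptions inferred from the attributes read above, not the actual model definitions:
import asyncio

orchestrator = RayAIROrchestrator()
# Field names mirror the attributes the orchestrator reads; the real models may differ.
job_id = asyncio.run(
    orchestrator.submit_training(
        tenant_id="tenant-123",
        config=TrainingConfig(num_workers=4, use_gpu=True, gpus_per_worker=1),
        data_config=DataConfig(format="parquet", train_path="s3://bucket/train/"),
    )
)
print(f"Submitted Ray training job {job_id}")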
Data-Parallel Training
Data-parallel training distributes the dataset across workers, each running a copy of the model:
class DataParallelTrainer:
    """Data-parallel distributed training."""

    def train(
        self,
        config: TrainingConfig,
        data_config: DataConfig,
    ) -> TrainingResult:
        """Execute data-parallel training."""
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop(config):
            import tempfile

            import torch
            import ray.train
            from ray.train import Checkpoint, get_dataset_shard
            from ray.train.torch import prepare_model

            # Each worker gets a shard of the data
            dataset_shard = get_dataset_shard("train")
            # build_model and train_step are helpers defined elsewhere in the service
            model = build_model(config["model_config"])
            model = prepare_model(model)
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=config["learning_rate"],
            )
            for epoch in range(config["epochs"]):
                for batch in dataset_shard.iter_torch_batches(
                    batch_size=config["batch_size"]
                ):
                    loss = train_step(model, batch, optimizer)
                    # Report metrics
                    ray.train.report({"loss": float(loss), "epoch": epoch})
                # Checkpoint at the end of each epoch
                with tempfile.TemporaryDirectory() as tmpdir:
                    torch.save(model.state_dict(), f"{tmpdir}/model.pt")
                    ray.train.report(
                        {"loss": float(loss)},
                        checkpoint=Checkpoint.from_directory(tmpdir),
                    )

        trainer = TorchTrainer(
            train_loop_per_worker=train_loop,
            train_loop_config={
                "model_config": config.model_config,
                "learning_rate": config.learning_rate,
                "epochs": config.epochs,
                "batch_size": config.batch_size,
            },
            scaling_config=ScalingConfig(
                num_workers=config.num_workers,
                use_gpu=True,
            ),
            # ray_dataset is the Ray Dataset produced by RayDataService (see below)
            datasets={"train": ray_dataset},
        )
        return trainer.fit()
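trainer.fit() returns a ray.train.Result; a brief sketch of reading the final metrics and the last checkpoint from it (assuming config and data_config are already built):
import torch

result = DataParallelTrainer().train(config, data_config)
print(result.metrics["loss"])  # last value reported by the workers
with result.checkpoint.as_directory() as ckpt_dir:
    state_dict = torch.load(f"{ckpt_dir}/model.pt")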
Model-Parallel Training
For models that do not fit on a single GPU, the model itself is partitioned across devices:
from dataclasses import dataclass


@dataclass
class ModelParallelConfig:
    """Configuration for model-parallel training."""

    tensor_parallel_size: int = 2          # Split tensors across GPUs
    pipeline_parallel_size: int = 1        # Pipeline stages
    sequence_parallel: bool = False        # Sequence parallelism
    activation_checkpointing: bool = True  # Recompute activations to save memory
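The parallel sizes multiply: with data parallelism layered on top, the total GPU count for a job is the product of the three degrees. A small illustrative helper, not part of the service:
def required_gpus(cfg: ModelParallelConfig, data_parallel_size: int = 1) -> int:
    """Total GPUs = data_parallel x tensor_parallel x pipeline_parallel."""
    return data_parallel_size * cfg.tensor_parallel_size * cfg.pipeline_parallel_size

# Example: 4-way data parallel with the defaults above (tensor=2, pipeline=1) needs 8 GPUs.
print(required_gpus(ModelParallelConfig(), data_parallel_size=4))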
DeepSpeed/FSDP Integration
from typing import Any


class DeepSpeedFSDPTrainer:
    """Training with DeepSpeed or PyTorch FSDP."""

    def train_with_deepspeed(
        self,
        model: Any,
        config: dict,
    ) -> TrainingResult:
        """Train using DeepSpeed ZeRO optimization."""
        import deepspeed

        ds_config = {
            # DeepSpeed requires an explicit batch size setting
            "train_micro_batch_size_per_gpu": config.get("batch_size", 8),
            "zero_optimization": {
                "stage": 3,  # ZeRO Stage 3: partition optimizer states, gradients, and parameters
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "offload_param": {
                    "device": "cpu",
                    "pin_memory": True,
                },
            },
            "fp16": {"enabled": True},
            "gradient_accumulation_steps": config.get(
                "gradient_accumulation_steps", 4
            ),
            # Let DeepSpeed construct the (CPU-offloadable) optimizer
            "optimizer": {
                "type": "AdamW",
                "params": {"lr": config.get("learning_rate", 1e-4)},
            },
        }
        model_engine, optimizer, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config=ds_config,
        )
        ...
Trade-offs between the parallelism strategies:
| Parallelism | Memory Savings | Communication | Best For |
|---|---|---|---|
| Data Parallel | None | Gradients (AllReduce) | Small models, large data |
| ZeRO Stage 1 | Optimizer states | Gradients | Medium models |
| ZeRO Stage 2 | + Gradients | Parameters (Gather) | Large models |
| ZeRO Stage 3 | + Parameters | Parameters (Gather) | Very large models |
| Pipeline | Layer-wise | Activations | Deep models |
| Tensor | Tensor-wise | Tensor pieces | Wide layers |
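A hedged sketch of how the ZeRO stage chosen from the table might be turned into the corresponding DeepSpeed config section; the helper name and defaults are illustrative rather than the service's actual code:
def build_zero_config(stage: int, offload_to_cpu: bool = False) -> dict:
    """Assemble a DeepSpeed config for ZeRO stage 1, 2, or 3."""
    zero: dict = {"stage": stage}
    if offload_to_cpu:
        # Offloading trades GPU memory for CPU<->GPU transfer overhead.
        zero["offload_optimizer"] = {"device": "cpu", "pin_memory": True}
        if stage == 3:
            # Only stage 3 partitions parameters, so only stage 3 can offload them.
            zero["offload_param"] = {"device": "cpu", "pin_memory": True}
    return {
        "train_micro_batch_size_per_gpu": 8,
        "zero_optimization": zero,
        "fp16": {"enabled": True},
    }
Stage 1 partitions only optimizer states and stage 2 adds gradients, so the parameter-offload block is emitted only for stage 3, in line with the table above.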
Ray Data Integration
The ray_data_service.py manages distributed data loading:
import ray.data


class RayDataService:
    """Distributed data loading with Ray Data."""

    async def load_dataset(
        self,
        data_config: DataConfig,
    ) -> ray.data.Dataset:
        """Load dataset as a Ray Dataset."""
        if data_config.format == "parquet":
            ds = ray.data.read_parquet(data_config.train_path)
        elif data_config.format == "csv":
            ds = ray.data.read_csv(data_config.train_path)
        elif data_config.format == "json":
            ds = ray.data.read_json(data_config.train_path)
        else:
            raise ValueError(f"Unsupported data format: {data_config.format}")

        # Apply preprocessing
        if data_config.normalize:
            ds = ds.map_batches(self._normalize)
        if data_config.shuffle:
            ds = ds.random_shuffle()
        return ds
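A sketch of the wiring between RayDataService and the trainers above; ray_dataset is the name referenced in the TorchTrainer datasets argument, and the call pattern here is illustrative:
import asyncio

# Load the tenant's training data as a Ray Dataset, then hand it to Ray Train.
ray_dataset = asyncio.run(RayDataService().load_dataset(data_config))
print(ray_dataset.count(), ray_dataset.schema())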
Ray Serve Integration
The ray_serve_service.py manages model deployments:
from typing import Any


class RayServeService:
    """Manages Ray Serve model deployments."""

    async def deploy(
        self,
        model_name: str,
        model_path: str,
        num_replicas: int = 1,
        resources: dict | None = None,
    ) -> str:
        """Deploy model to Ray Serve."""
        ...

    async def scale(
        self,
        deployment_name: str,
        num_replicas: int,
    ) -> None:
        """Scale a deployment."""
        ...

    async def get_status(
        self,
        deployment_name: str,
    ) -> dict[str, Any]:
        """Get deployment status."""
        ...

    async def undeploy(self, deployment_name: str) -> None:
        """Remove a deployment."""
        ...
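The body of deploy() is elided above; a minimal sketch of the Ray Serve calls it would likely wrap. The deployment class, model path, and load_model helper are assumptions for illustration:
from ray import serve


@serve.deployment
class ModelDeployment:
    """Wraps a trained model behind an auto-scalable HTTP replica."""

    def __init__(self, model_path: str):
        self._model = load_model(model_path)  # hypothetical helper

    async def __call__(self, request):
        payload = await request.json()
        return {"prediction": self._model.predict(payload["features"])}


# deploy() would construct and run something like this with the requested replica count.
app = ModelDeployment.options(num_replicas=2).bind(model_path="s3://models/churn/v3")
serve.run(app, name="churn-model")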
Cluster Management
| Configuration | Default | Description |
|---|---|---|
| ray_address | ray://localhost:10001 | Ray cluster head node address |
| ray_namespace | matih-ml | Ray namespace for isolation |
| ray_dashboard_port | 8265 | Ray dashboard port |
| max_workers | 16 | Maximum Ray workers |
| worker_cpu | 4 | CPUs per worker |
| worker_gpu | 1 | GPUs per worker |
| worker_memory_gb | 16 | Memory per worker (GB) |
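These settings would typically be surfaced as environment-driven configuration; a sketch using a plain dataclass, where the field names mirror the table and the environment-variable mapping is an assumption:
import os
from dataclasses import dataclass


@dataclass
class RayClusterSettings:
    """Cluster defaults from the table above, overridable via environment variables."""

    ray_address: str = os.getenv("RAY_ADDRESS", "ray://localhost:10001")
    ray_namespace: str = os.getenv("RAY_NAMESPACE", "matih-ml")
    ray_dashboard_port: int = int(os.getenv("RAY_DASHBOARD_PORT", "8265"))
    max_workers: int = int(os.getenv("MAX_WORKERS", "16"))
    worker_cpu: int = 4
    worker_gpu: int = 1
    worker_memory_gb: int = 16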
Resource Quotas
Each tenant is assigned resource quotas so that no single tenant can monopolize the cluster:
| Quota | Free Tier | Standard | Enterprise |
|---|---|---|---|
| Max concurrent jobs | 1 | 5 | 20 |
| Max GPUs per job | 1 | 4 | 16 |
| Max training hours/day | 4 | 24 | Unlimited |
| Max workers | 2 | 8 | 32 |
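A hedged sketch of how these quotas might be enforced at submission time; the tier values mirror the table, while the checker itself is illustrative rather than the service's actual code:
TIER_QUOTAS = {
    # (max_concurrent_jobs, max_gpus_per_job, max_training_hours_per_day, max_workers)
    "free": (1, 1, 4, 2),
    "standard": (5, 4, 24, 8),
    "enterprise": (20, 16, None, 32),  # None = unlimited training hours
}


def check_quota(tier: str, running_jobs: int, requested_gpus: int, requested_workers: int) -> None:
    """Raise if the requested job would exceed the tenant's tier limits."""
    max_jobs, max_gpus, _, max_workers = TIER_QUOTAS[tier]
    if running_jobs >= max_jobs:
        raise RuntimeError(f"Tenant already has {running_jobs} running jobs (limit {max_jobs})")
    if requested_gpus > max_gpus:
        raise RuntimeError(f"Requested {requested_gpus} GPUs exceeds limit {max_gpus}")
    if requested_workers > max_workers:
        raise RuntimeError(f"Requested {requested_workers} workers exceeds limit {max_workers}")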
API Endpoints
POST /api/v1/ray/training # Submit training job
GET /api/v1/ray/training/{job_id} # Job status
POST /api/v1/ray/training/{job_id}/cancel # Cancel job
GET /api/v1/ray/cluster/status # Cluster status
GET /api/v1/ray/cluster/resources # Available resources
POST /api/v1/ray/serve/deploy # Deploy model
GET /api/v1/ray/serve/deployments # List deployments
DELETE /api/v1/ray/serve/{name} # Undeploy model
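A sketch of calling the training endpoint from Python; the base URL, request body fields, and response key are assumptions inferred from the configuration objects above, not a documented schema:
import requests

# Hypothetical payload; field names mirror TrainingConfig / DataConfig attributes.
resp = requests.post(
    "https://ml.example.com/api/v1/ray/training",
    json={
        "tenant_id": "tenant-123",
        "config": {"num_workers": 4, "use_gpu": True, "gpus_per_worker": 1},
        "data_config": {"format": "parquet", "train_path": "s3://bucket/train/"},
    },
    timeout=30,
)
job_id = resp.json()["job_id"]

# Poll job status until completion.
status = requests.get(
    f"https://ml.example.com/api/v1/ray/training/{job_id}", timeout=30
).json()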