Kourosh Hakhamaneshi 09a17b5415 drafting ray integration
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>

Former-commit-id: 19c12ddae9350f6e25a270fe3372f5b9094cf960
2025-01-07 08:55:44 +00:00

28 lines
754 B
Python

from typing import Any, Callable, Dict
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from .ray_train_args import RayTrainArguments
def get_ray_trainer(
    training_function: Callable,
    train_loop_config: Dict[str, Any],
    ray_args: RayTrainArguments,
    *,
    use_gpu: bool = True,
) -> TorchTrainer:
    """Build a Ray ``TorchTrainer`` from a training function and Ray settings.

    This function only constructs the trainer; it does not start training.

    Args:
        training_function: Per-worker training loop, passed as the first
            positional argument to ``TorchTrainer``.
        train_loop_config: Config dict forwarded to ``TorchTrainer`` as
            ``train_loop_config``.
        ray_args: Ray settings. This function reads ``use_ray``,
            ``num_workers`` and ``resources_per_worker`` from it.
        use_gpu: Whether each worker should request a GPU. Defaults to
            ``True``, which preserves the previously hard-coded behavior.
            Pass ``False`` for CPU-only runs.

    Returns:
        A configured ``TorchTrainer`` instance.

    Raises:
        ValueError: If ``ray_args.use_ray`` is falsy.
    """
    # Refuse to build a Ray trainer when Ray support is switched off, so a
    # misconfiguration surfaces here instead of deep inside Ray.
    if not ray_args.use_ray:
        raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")

    scaling_config = ScalingConfig(
        num_workers=ray_args.num_workers,
        resources_per_worker=ray_args.resources_per_worker,
        use_gpu=use_gpu,
    )
    trainer = TorchTrainer(
        training_function,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
    )
    return trainer