mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-05 10:22:15 +08:00
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Former-commit-id: 19c12ddae9350f6e25a270fe3372f5b9094cf960
28 lines
754 B
Python
28 lines
754 B
Python
|
|
from typing import Any, Callable, Dict
|
|
|
|
from ray.train.torch import TorchTrainer
|
|
from ray.train import ScalingConfig
|
|
|
|
from .ray_train_args import RayTrainArguments
|
|
|
|
|
|
def get_ray_trainer(
    training_function: Callable,
    train_loop_config: Dict[str, Any],
    ray_args: RayTrainArguments,
) -> TorchTrainer:
    """Build the Ray ``TorchTrainer`` that will run ``training_function``.

    The trainer is only constructed here; this function does not start it.

    Args:
        training_function: Callable passed to ``TorchTrainer`` as its
            first positional argument.
        train_loop_config: Mapping forwarded to ``TorchTrainer`` unchanged
            as ``train_loop_config``.
        ray_args: Ray settings. The attributes ``use_ray``, ``num_workers``
            and ``resources_per_worker`` are read from it.

    Returns:
        The configured ``TorchTrainer`` instance.

    Raises:
        ValueError: If ``ray_args.use_ray`` is falsy.
    """
    # Refuse to build anything unless Ray has been switched on.
    if not ray_args.use_ray:
        raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")

    # Worker count and per-worker resources come from the arguments;
    # GPU usage is always requested (``use_gpu`` is fixed to True here).
    scaling_config = ScalingConfig(
        num_workers=ray_args.num_workers,
        resources_per_worker=ray_args.resources_per_worker,
        use_gpu=True,
    )

    return TorchTrainer(
        training_function,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
    )