From 39c1e29ed747026432f885a9c2e6d4c9c7c7e6d8 Mon Sep 17 00:00:00 2001
From: Eric Tang <46737979+erictang000@users.noreply.github.com>
Date: Wed, 9 Apr 2025 20:31:05 -0700
Subject: [PATCH] [ray] allow for specifying ray.init kwargs (i.e. runtime_env)
 (#7647)

* ray init kwargs

* Update trainer_utils.py

* fix ray args

---------

Co-authored-by: hoshi-hiyouga
---
 examples/train_lora/llama3_lora_sft_ray.yaml | 10 ++++++++--
 src/llamafactory/hparams/training_args.py    |  4 ++++
 src/llamafactory/train/trainer_utils.py      |  4 ++++
 3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/examples/train_lora/llama3_lora_sft_ray.yaml b/examples/train_lora/llama3_lora_sft_ray.yaml
index e7e4b390..8c03bf9e 100644
--- a/examples/train_lora/llama3_lora_sft_ray.yaml
+++ b/examples/train_lora/llama3_lora_sft_ray.yaml
@@ -31,10 +31,16 @@ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
 ### ray
 ray_run_name: llama3_8b_sft_lora
 ray_storage_path: ./saves
-ray_num_workers: 4 # number of GPUs to use
+ray_num_workers: 4 # Number of GPUs to use.
+placement_strategy: PACK
 resources_per_worker:
   GPU: 1
-placement_strategy: PACK
+# ray_init_kwargs:
+#   runtime_env:
+#     env_vars:
+#       <env_var_name>: "<env_var_value>"
+#     pip:
+#       - emoji
 
 ### train
 per_device_train_batch_size: 1
diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py
index dfc41caf..ee9c4e93 100644
--- a/src/llamafactory/hparams/training_args.py
+++ b/src/llamafactory/hparams/training_args.py
@@ -46,6 +46,10 @@ class RayArguments:
         default="PACK",
         metadata={"help": "The placement strategy for Ray training. Default is PACK."},
     )
+    ray_init_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
+    )
 
     def __post_init__(self):
         self.use_ray = use_ray()
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index e9e58d7e..89459a82 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -48,6 +48,7 @@ if is_apollo_available():
 
 
 if is_ray_available():
+    import ray
     from ray.train import RunConfig, ScalingConfig
     from ray.train.torch import TorchTrainer
 
@@ -644,6 +645,9 @@ def get_ray_trainer(
     if not ray_args.use_ray:
         raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
 
+    if ray_args.ray_init_kwargs is not None:
+        ray.init(**ray_args.ray_init_kwargs)
+
     trainer = TorchTrainer(
         training_function,
         train_loop_config=train_loop_config,
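
Usage sketch (not part of the patch): a minimal, self-contained illustration of how
the new `ray_init_kwargs` option flows into Ray, assuming the YAML example above has
been parsed into a plain dict. `runtime_env`, `env_vars`, and `pip` are standard
`ray.init` runtime-environment options; the concrete `TOKENIZERS_PARALLELISM` value
below is illustrative, not taken from the patch.

import ray

# What the commented YAML block above would parse to: a kwargs
# mapping handed straight to ray.init().
ray_init_kwargs = {
    "runtime_env": {
        "env_vars": {"TOKENIZERS_PARALLELISM": "false"},  # illustrative env var
        "pip": ["emoji"],  # extra packages installed on every Ray worker
    }
}

# Mirrors the guard added in get_ray_trainer(): ray.init() is called
# explicitly only when kwargs were provided; otherwise Ray is left to
# initialize lazily when the TorchTrainer starts.
if ray_init_kwargs is not None:
    ray.init(**ray_init_kwargs)

Calling ray.init() before constructing the TorchTrainer ensures the runtime
environment (env vars, pip dependencies) applies to all Ray workers; when the
option is omitted, behavior is unchanged from before the patch.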