Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00.
[ray] allow for specifying ray.init kwargs (i.e. runtime_env) (#7647)
* ray init kwargs
* Update trainer_utils.py
* fix ray args

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
parent
ee840b4e01
commit
39c1e29ed7
@ -31,10 +31,16 @@ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
|||||||
### ray
|
### ray
|
||||||
ray_run_name: llama3_8b_sft_lora
|
ray_run_name: llama3_8b_sft_lora
|
||||||
ray_storage_path: ./saves
|
ray_storage_path: ./saves
|
||||||
ray_num_workers: 4 # number of GPUs to use
|
ray_num_workers: 4 # Number of GPUs to use.
|
||||||
|
placement_strategy: PACK
|
||||||
resources_per_worker:
|
resources_per_worker:
|
||||||
GPU: 1
|
GPU: 1
|
||||||
placement_strategy: PACK
|
# ray_init_kwargs:
|
||||||
|
# runtime_env:
|
||||||
|
# env_vars:
|
||||||
|
# <YOUR-ENV-VAR-HERE>: "<YOUR-ENV-VAR-HERE>"
|
||||||
|
# pip:
|
||||||
|
# - emoji
|
||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
|
@ -46,6 +46,10 @@ class RayArguments:
|
|||||||
default="PACK",
|
default="PACK",
|
||||||
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
|
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
|
||||||
)
|
)
|
||||||
|
ray_init_kwargs: Optional[dict] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.use_ray = use_ray()
|
self.use_ray = use_ray()
|
||||||
|
@ -48,6 +48,7 @@ if is_apollo_available():
|
|||||||
|
|
||||||
|
|
||||||
if is_ray_available():
|
if is_ray_available():
|
||||||
|
import ray
|
||||||
from ray.train import RunConfig, ScalingConfig
|
from ray.train import RunConfig, ScalingConfig
|
||||||
from ray.train.torch import TorchTrainer
|
from ray.train.torch import TorchTrainer
|
||||||
|
|
||||||
@ -644,6 +645,9 @@ def get_ray_trainer(
|
|||||||
if not ray_args.use_ray:
|
if not ray_args.use_ray:
|
||||||
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
|
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
|
||||||
|
|
||||||
|
if ray_args.ray_init_kwargs is not None:
|
||||||
|
ray.init(**ray_args.ray_init_kwargs)
|
||||||
|
|
||||||
trainer = TorchTrainer(
|
trainer = TorchTrainer(
|
||||||
training_function,
|
training_function,
|
||||||
train_loop_config=train_loop_config,
|
train_loop_config=train_loop_config,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user