Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

[ray] allow for specifying ray.init kwargs (i.e. runtime_env) (#7647)
* ray init kwargs
* Update trainer_utils.py
* fix ray args

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
parent 1c436c9f25
commit bb8d79bae2
@@ -31,10 +31,16 @@ report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
 ### ray
 ray_run_name: llama3_8b_sft_lora
 ray_storage_path: ./saves
-ray_num_workers: 4  # number of GPUs to use
-placement_strategy: PACK
+ray_num_workers: 4  # Number of GPUs to use.
 resources_per_worker:
   GPU: 1
+placement_strategy: PACK
+# ray_init_kwargs:
+#   runtime_env:
+#     env_vars:
+#       <YOUR-ENV-VAR-HERE>: "<YOUR-ENV-VAR-HERE>"
+#     pip:
+#       - emoji
 
 ### train
 per_device_train_batch_size: 1
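The commented-out `ray_init_kwargs` block above is forwarded verbatim to `ray.init` before training starts. As a rough sketch, uncommenting it amounts to the following call, where the environment variable name/value stands in for the `<YOUR-ENV-VAR-HERE>` placeholders:

```python
import ray

# Assumed equivalent of the uncommented YAML block: the parsed mapping is
# unpacked into ray.init(**ray_init_kwargs) before the Ray trainer is built.
ray.init(
    runtime_env={
        "env_vars": {"MY_ENV_VAR": "my-value"},  # placeholder name/value, not from the config
        "pip": ["emoji"],  # extra packages installed in each worker's environment
    }
)
```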
@@ -46,6 +46,10 @@ class RayArguments:
         default="PACK",
         metadata={"help": "The placement strategy for Ray training. Default is PACK."},
     )
+    ray_init_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
+    )
 
     def __post_init__(self):
         self.use_ray = use_ray()
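Because the new field is typed `Optional[dict]`, the nested YAML mapping from the example config arrives as a plain Python dict. A quick sketch of that step, using `yaml.safe_load` directly rather than the project's own argument-parsing path:

```python
import yaml

config_text = """
ray_init_kwargs:
  runtime_env:
    env_vars:
      MY_ENV_VAR: "my-value"
    pip:
      - emoji
"""

# The value assigned to RayArguments.ray_init_kwargs is just this nested dict.
ray_init_kwargs = yaml.safe_load(config_text)["ray_init_kwargs"]
# {'runtime_env': {'env_vars': {'MY_ENV_VAR': 'my-value'}, 'pip': ['emoji']}}
```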
@@ -48,6 +48,7 @@ if is_apollo_available():
 
 
 if is_ray_available():
+    import ray
     from ray.train import RunConfig, ScalingConfig
     from ray.train.torch import TorchTrainer
 
@@ -644,6 +645,9 @@ def get_ray_trainer(
     if not ray_args.use_ray:
         raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
 
+    if ray_args.ray_init_kwargs is not None:
+        ray.init(**ray_args.ray_init_kwargs)
+
     trainer = TorchTrainer(
         training_function,
         train_loop_config=train_loop_config,
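The `None` guard keeps the previous behaviour intact: when no kwargs are given, `ray.init` is not called here and Ray initializes itself with default settings once training runs. A standalone sketch of the same pattern (the helper name and the example env var are illustrative, not part of the change):

```python
from typing import Optional

import ray


def maybe_init_ray(ray_init_kwargs: Optional[dict]) -> None:
    """Apply user-provided ray.init kwargs, or let Ray self-initialize later."""
    if ray_init_kwargs is not None:
        ray.init(**ray_init_kwargs)


# Example: forward a runtime_env, as the new config option allows.
maybe_init_ray({"runtime_env": {"env_vars": {"TOKENIZERS_PARALLELISM": "false"}}})
```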