[feature] support using ray.remote to start distributed training. (#10109)
@@ -14,7 +14,6 @@
import json
from dataclasses import dataclass, field
from typing import Literal

from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
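As a side note, a minimal sketch (not part of the diff) of how these imports are combined further down in the file: a dict-valued argument arriving as a JSON string is parsed with json.loads, then _convert_str_dict casts string values such as "1" or "true" to native types, assuming the transformers helper behaves here as it does for other dict-typed training arguments. The keys and values below are invented for illustration.

import json

from transformers.training_args import _convert_str_dict

# A dict argument arriving as a CLI string, e.g. --resources_per_worker '{"GPU": "1"}'
raw = '{"GPU": "1", "use_gpu": "true"}'  # illustrative values only
parsed = _convert_str_dict(json.loads(raw))
print(parsed)  # expected: {"GPU": 1, "use_gpu": True}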
@@ -40,56 +39,29 @@ else:
@dataclass
class RayArguments:
    r"""Arguments pertaining to the Ray training."""

    ray_run_name: str | None = field(
        default=None,
        metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
    )
    ray_storage_path: str = field(
        default="./saves",
        metadata={"help": "The storage path to save training results to"},
    )
    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
        default=None,
        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
    )
    ray_num_workers: int = field(
        default=1,
        metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
    )
    resources_per_worker: dict | str = field(
        default_factory=lambda: {"GPU": 1},
        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
    )
    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
        default="PACK",
        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
    )
    ray_init_kwargs: dict | str | None = field(
        default=None,
        metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
    )
    master_addr: str | None = field(
        default=None,
        metadata={"help": "The master address for init_process_group"},
    )
    master_port: str | None = field(
        default=None,
        metadata={"help": "The master port for init_process_group"},
    )

    def __post_init__(self):
        self.use_ray = use_ray()
        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))

        if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
            self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))

        if self.ray_storage_filesystem is not None:
            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
                raise ValueError(
                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
                )

            import pyarrow.fs as fs

            if self.ray_storage_filesystem == "s3":
                self.ray_storage_filesystem = fs.S3FileSystem()
            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
                self.ray_storage_filesystem = fs.GcsFileSystem()


@dataclass
class Fp8Arguments:
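For context, a hypothetical usage sketch (not part of this commit) showing how __post_init__ normalizes the JSON-string forms these fields can arrive in, e.g. from CLI flags. The field values and the import path are assumptions for illustration.

from llamafactory.hparams import RayArguments  # import path assumed

args = RayArguments(
    ray_run_name="sft_demo",             # results saved at <ray_storage_path>/sft_demo
    ray_num_workers=2,                   # two Ray workers, 1 GPU each by default
    resources_per_worker='{"GPU": 1}',   # string form, parsed in __post_init__
    ray_init_kwargs='{"num_cpus": 8}',   # later forwarded to ray.init(...)
)
assert args.resources_per_worker == {"GPU": 1}  # json.loads + _convert_str_dict
assert args.ray_init_kwargs == {"num_cpus": 8}
print(args.use_ray)  # set from use_ray() in __post_init__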