From b4407e4b0b1b3e0eaccddb9fd43b758fae39a57b Mon Sep 17 00:00:00 2001 From: Eric Tang <46737979+erictang000@users.noreply.github.com> Date: Sun, 27 Apr 2025 07:12:40 -0700 Subject: [PATCH] [ray] add storage filesystem to ray config (#7854) --- src/llamafactory/hparams/training_args.py | 15 +++++++++++++++ src/llamafactory/train/trainer_utils.py | 9 ++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py index ee9c4e93..fae9a6a3 100644 --- a/src/llamafactory/hparams/training_args.py +++ b/src/llamafactory/hparams/training_args.py @@ -34,6 +34,10 @@ class RayArguments: default="./saves", metadata={"help": "The storage path to save training results to"}, ) + ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field( + default=None, + metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."}, + ) ray_num_workers: int = field( default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}, @@ -55,6 +59,17 @@ class RayArguments: self.use_ray = use_ray() if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"): self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker)) + if self.ray_storage_filesystem is not None: + if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]: + raise ValueError( + f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}" + ) + import pyarrow.fs as fs + + if self.ray_storage_filesystem == "s3": + self.ray_storage_filesystem = fs.S3FileSystem() + elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs": + self.ray_storage_filesystem = fs.GcsFileSystem() @dataclass diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 12ee6bb3..e626c0f0 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -680,6 +680,12 @@ def get_ray_trainer( if ray_args.ray_init_kwargs is not None: ray.init(**ray_args.ray_init_kwargs) + if ray_args.ray_storage_filesystem is not None: + # this means we are using s3/gcs + storage_path = ray_args.ray_storage_path + else: + storage_path = Path(ray_args.ray_storage_path).absolute().as_posix() + trainer = TorchTrainer( training_function, train_loop_config=train_loop_config, @@ -691,7 +697,8 @@ def get_ray_trainer( ), run_config=RunConfig( name=ray_args.ray_run_name, - storage_path=Path(ray_args.ray_storage_path).absolute().as_posix(), + storage_filesystem=ray_args.ray_storage_filesystem, + storage_path=storage_path, ), ) return trainer