Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 06:12:50 +08:00)

[ray] add storage filesystem to ray config (#7854)
Commit b4407e4b0b (parent 036a76e9cb)

This commit adds a ray_storage_filesystem option to RayArguments ("s3", "gs", or "gcs"; default None for local storage), converts the chosen value to a pyarrow filesystem in __post_init__, and passes it through to Ray Train's RunConfig in get_ray_trainer, where the storage path is no longer resolved to a local absolute path when a remote filesystem is used.
@@ -34,6 +34,10 @@ class RayArguments:
         default="./saves",
         metadata={"help": "The storage path to save training results to"},
     )
+    ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
+        default=None,
+        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
+    )
     ray_num_workers: int = field(
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
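
The new field is typed Optional[Literal["s3", "gs", "gcs"]], so only those three values (or None for local storage) are accepted. Below is a minimal parsing sketch, assuming the dataclass is read with transformers.HfArgumentParser as LLaMA-Factory's other hparams are; the trimmed stand-in class and the CLI values are illustrative, not part of the commit:

from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class _RayStorageArgs:  # trimmed stand-in for RayArguments, storage fields only
    ray_storage_path: str = field(default="./saves")
    ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(default=None)


# Parse illustrative CLI flags; the Literal type maps to argparse choices,
# so an unsupported value such as "hdfs" is rejected at parse time.
(args,) = HfArgumentParser(_RayStorageArgs).parse_args_into_dataclasses(
    args=["--ray_storage_path", "my-bucket/saves", "--ray_storage_filesystem", "s3"]
)
print(args.ray_storage_filesystem)  # -> "s3" (the real RayArguments converts this in __post_init__, see below)
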
@@ -55,6 +59,17 @@ class RayArguments:
         self.use_ray = use_ray()
         if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
             self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
+        if self.ray_storage_filesystem is not None:
+            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
+                raise ValueError(
+                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
+                )
+            import pyarrow.fs as fs
+
+            if self.ray_storage_filesystem == "s3":
+                self.ray_storage_filesystem = fs.S3FileSystem()
+            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
+                self.ray_storage_filesystem = fs.GcsFileSystem()
 
 
 @dataclass
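
In __post_init__ the string value is replaced in place by a pyarrow filesystem object, which is what Ray Train ultimately expects. A standalone sketch of the same mapping, assuming pyarrow is installed; note that the commit constructs S3FileSystem() and GcsFileSystem() without arguments, so credentials are resolved from the environment (e.g. the AWS credential chain or GOOGLE_APPLICATION_CREDENTIALS) rather than being passed explicitly:

from typing import Optional, Union

import pyarrow.fs as fs


def resolve_storage_filesystem(name: Optional[str]) -> Optional[Union[fs.S3FileSystem, fs.GcsFileSystem]]:
    """Mirror of the conversion in RayArguments.__post_init__ (function name here is illustrative)."""
    if name is None:
        return None  # keep Ray's default local filesystem
    if name not in ["s3", "gs", "gcs"]:
        raise ValueError(f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {name}")
    if name == "s3":
        return fs.S3FileSystem()  # credentials picked up from the standard AWS chain
    return fs.GcsFileSystem()     # "gs" and "gcs" are treated identically
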
@@ -680,6 +680,12 @@ def get_ray_trainer(
     if ray_args.ray_init_kwargs is not None:
         ray.init(**ray_args.ray_init_kwargs)
 
+    if ray_args.ray_storage_filesystem is not None:
+        # this means we are using s3/gcs
+        storage_path = ray_args.ray_storage_path
+    else:
+        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
+
     trainer = TorchTrainer(
         training_function,
         train_loop_config=train_loop_config,
@@ -691,7 +697,8 @@ def get_ray_trainer(
         ),
         run_config=RunConfig(
             name=ray_args.ray_run_name,
-            storage_path=Path(ray_args.ray_storage_path).absolute().as_posix(),
+            storage_filesystem=ray_args.ray_storage_filesystem,
+            storage_path=storage_path,
         ),
     )
     return trainer
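
At trainer construction time, a remote filesystem changes how ray_storage_path is interpreted: when storage_filesystem is set, the path is passed to RunConfig unchanged and read as a location on that filesystem (typically a bucket prefix), whereas the local default is still resolved to an absolute POSIX path. A hedged usage sketch of the resulting RunConfig, assuming Ray Train 2.x; the run name and bucket prefix are placeholders:

import pyarrow.fs as fs
from ray.train import RunConfig

# Remote case: ray_storage_filesystem="s3", ray_storage_path="my-bucket/ray-saves".
run_config = RunConfig(
    name="llamafactory-run",               # ray_args.ray_run_name
    storage_filesystem=fs.S3FileSystem(),  # built in RayArguments.__post_init__
    storage_path="my-bucket/ray-saves",    # used as-is on the remote filesystem
)

# Local case (ray_storage_filesystem=None): storage_path is resolved to an absolute path,
# e.g. Path("./saves").absolute().as_posix(), and RunConfig receives storage_filesystem=None.
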