From 1217240918073a7dbd2291d335f9494b15be5a7c Mon Sep 17 00:00:00 2001
From: Kourosh Hakhamaneshi
Date: Mon, 30 Dec 2024 16:48:52 -0800
Subject: [PATCH] drafting ray integration

Signed-off-by: Kourosh Hakhamaneshi

Former-commit-id: 163ddb680b6f84a4424a887a3b8a5d668044e87c
---
 examples/train_lora/llama3_lora_sft.yaml      |  8 +++
 src/llamafactory/cli.py                       |  5 +-
 src/llamafactory/hparams/parser.py            | 56 +++++++++++++------
 src/llamafactory/integrations/__init__.py     |  0
 src/llamafactory/integrations/ray/__init__.py |  0
 .../integrations/ray/ray_train.py             | 28 ++++++++++
 .../integrations/ray/ray_train_args.py        | 22 ++++++++
 .../integrations/ray/ray_utils.py             |  9 +++
 src/llamafactory/train/tuner.py               | 36 +++++++++++-
 9 files changed, 143 insertions(+), 21 deletions(-)
 create mode 100644 src/llamafactory/integrations/__init__.py
 create mode 100644 src/llamafactory/integrations/ray/__init__.py
 create mode 100644 src/llamafactory/integrations/ray/ray_train.py
 create mode 100644 src/llamafactory/integrations/ray/ray_train_args.py
 create mode 100644 src/llamafactory/integrations/ray/ray_utils.py

diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml
index 243f2445..558ac3e9 100644
--- a/examples/train_lora/llama3_lora_sft.yaml
+++ b/examples/train_lora/llama3_lora_sft.yaml
@@ -9,6 +9,7 @@ finetuning_type: lora
 lora_target: all
 
 ### dataset
+dataset_dir: /home/ray/default/lf/data/
 dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048
@@ -38,3 +39,10 @@ val_size: 0.1
 per_device_eval_batch_size: 1
 eval_strategy: steps
 eval_steps: 500
+
+
+### ray setup
+resources_per_worker:
+  GPU: 1
+num_workers: 4
+# placement_strategy: ...
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index ef5bd1dc..26d7a3df 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -27,7 +27,7 @@ from .extras.env import VERSION, print_env
 from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
-
+from .integrations.ray.ray_utils import should_use_ray
 
 USAGE = (
     "-" * 70
@@ -87,7 +87,8 @@ def main():
         export_model()
     elif command == Command.TRAIN:
         force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
-        if force_torchrun or get_device_count() > 1:
+        use_ray = should_use_ray()
+        if force_torchrun or (get_device_count() > 1 and not use_ray):
             master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
             master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 4a254367..06853ae8 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -19,6 +19,10 @@ import os
 import sys
 from typing import Any, Dict, Optional, Tuple
 
+import json
+import yaml
+from pathlib import Path
+
 import torch
 import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
@@ -37,39 +41,51 @@ from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
 
+from ..integrations.ray.ray_train_args import RayTrainArguments
 
 logger = logging.get_logger(__name__)
 
-
 check_dependencies()
 
 
-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
+_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 
 
-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
     if args is not None:
-        return parser.parse_dict(args)
+        return args
 
     if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+        # read yaml file
+        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # read json file
+        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    else:
+        return {}
 
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
 
-    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
 
-    if unknown_args:
-        print(parser.format_help())
-        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
-        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
+    args_dict = _read_args(args)
+
+    if args_dict:
+        return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
+    else:
+        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
 
-    return (*parsed_args,)
+        if unknown_args:
+            print(parser.format_help())
+            print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
+
+        return (*parsed_args,)
+
 
 
 def _set_transformers_logging() -> None:
@@ -161,9 +177,17 @@ def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     return _parse_args(parser, args)
 
 
-def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
+    parser = HfArgumentParser(RayTrainArguments)
+    ray_args = _parse_args(parser, args, allow_extra_keys=True)[0]
+    if ray_args.use_ray:
+        require_version("ray", "To fix: pip install ray")
+    return ray_args
+
 
+def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+    model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
+
     # Setup logging
     if training_args.should_log:
         _set_transformers_logging()
diff --git a/src/llamafactory/integrations/__init__.py b/src/llamafactory/integrations/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/llamafactory/integrations/ray/__init__.py b/src/llamafactory/integrations/ray/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/llamafactory/integrations/ray/ray_train.py b/src/llamafactory/integrations/ray/ray_train.py
new file mode 100644
index 00000000..4620bb5a
--- /dev/null
+++ b/src/llamafactory/integrations/ray/ray_train.py
@@ -0,0 +1,28 @@
+
+from typing import Any, Callable, Dict
+
+from ray.train.torch import TorchTrainer
+from ray.train import ScalingConfig
+
+from .ray_train_args import RayTrainArguments
+
+
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: RayTrainArguments,
+) -> TorchTrainer:
+
+    if not ray_args.use_ray:
+        raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
+
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            use_gpu=True,
+        ),
+    )
+    return trainer
\ No newline at end of file
diff --git a/src/llamafactory/integrations/ray/ray_train_args.py b/src/llamafactory/integrations/ray/ray_train_args.py
new file mode 100644
index 00000000..9ee9dc8e
--- /dev/null
+++ b/src/llamafactory/integrations/ray/ray_train_args.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, Literal, Optional
+
+from .ray_utils import should_use_ray
+
+
+@dataclass
+class RayTrainArguments:
+    r"""
+    Arguments pertaining to the Ray training.
+    """
+    resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."})
+    num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."})
+    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."})
+
+    @property
+    def use_ray(self) -> bool:
+        """
+        Always returns the value from the environment variable check.
+        This prevents manual setting of use_ray.
+        """
+        return should_use_ray()
+
diff --git a/src/llamafactory/integrations/ray/ray_utils.py b/src/llamafactory/integrations/ray/ray_utils.py
new file mode 100644
index 00000000..67ce2ed9
--- /dev/null
+++ b/src/llamafactory/integrations/ray/ray_utils.py
@@ -0,0 +1,9 @@
+
+import os
+
+
+def should_use_ray():
+    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
+
+
+
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index 6c79320e..507a1f14 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -23,6 +23,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..hparams import get_infer_args, get_train_args
+from ..hparams.parser import _parse_ray_args, _read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
@@ -36,12 +37,14 @@ from .trainer_utils import get_swanlab_callback
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-
+    
 
 logger = logging.get_logger(__name__)
 
-
-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+def training_function(config: Dict[str, Any]) -> None:
+    args = config.get("args", None)
+    callbacks = config.get("callbacks", [])
+
     callbacks.append(LogCallback())
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
 
@@ -68,6 +71,33 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
     else:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")
 
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+
+    args_dict = _read_args(args)
+    ray_args = _parse_ray_args(args_dict)
+
+    if ray_args.use_ray:
+        # Import lazily to avoid ray not installed error
+        from ..integrations.ray.ray_train import get_ray_trainer
+
+        # Initialize ray trainer
+        trainer = get_ray_trainer(
+            training_function=training_function,
+            train_loop_config={
+                "args": args_dict,
+                "callbacks": callbacks,
+            },
+            ray_args=ray_args,
+        )
+        trainer.fit()
+    else:
+        training_function(
+            config={
+                "args": args_dict,
+                "callbacks": callbacks,
+            }
+        )
+
 
 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     model_args, data_args, finetuning_args, _ = get_infer_args(args)
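
Usage sketch (assuming the example config above, the llamafactory-cli entry point, and a reachable Ray cluster; the model name and output path below are illustrative, not taken from this patch): with USE_RAY=1 exported, `llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml` should skip the torchrun branch in cli.py and let run_exp() dispatch through get_ray_trainer(). The same flow can be driven programmatically by passing the config keys as a dict:

    import os

    from llamafactory.train.tuner import run_exp

    # should_use_ray() reads USE_RAY at call time, so it must be set
    # before run_exp() decides between the Ray and local paths.
    os.environ["USE_RAY"] = "1"

    run_exp(
        args={
            "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative
            "stage": "sft",
            "do_train": True,
            "finetuning_type": "lora",
            "lora_target": "all",
            "dataset": "identity,alpaca_en_demo",
            "template": "llama3",
            "output_dir": "saves/llama3-8b/lora/sft",  # illustrative
            # parsed by RayTrainArguments (via _parse_ray_args with allow_extra_keys=True)
            "num_workers": 4,
            "resources_per_worker": {"GPU": 1},
        }
    )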