From 1217240918073a7dbd2291d335f9494b15be5a7c Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 30 Dec 2024 16:48:52 -0800 Subject: [PATCH 1/4] drafting ray integration Signed-off-by: Kourosh Hakhamaneshi Former-commit-id: 163ddb680b6f84a4424a887a3b8a5d668044e87c --- examples/train_lora/llama3_lora_sft.yaml | 8 +++ src/llamafactory/cli.py | 5 +- src/llamafactory/hparams/parser.py | 56 +++++++++++++------ src/llamafactory/integrations/__init__.py | 0 src/llamafactory/integrations/ray/__init__.py | 0 .../integrations/ray/ray_train.py | 28 ++++++++++ .../integrations/ray/ray_train_args.py | 22 ++++++++ .../integrations/ray/ray_utils.py | 9 +++ src/llamafactory/train/tuner.py | 36 +++++++++++- 9 files changed, 143 insertions(+), 21 deletions(-) create mode 100644 src/llamafactory/integrations/__init__.py create mode 100644 src/llamafactory/integrations/ray/__init__.py create mode 100644 src/llamafactory/integrations/ray/ray_train.py create mode 100644 src/llamafactory/integrations/ray/ray_train_args.py create mode 100644 src/llamafactory/integrations/ray/ray_utils.py diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml index 243f2445..558ac3e9 100644 --- a/examples/train_lora/llama3_lora_sft.yaml +++ b/examples/train_lora/llama3_lora_sft.yaml @@ -9,6 +9,7 @@ finetuning_type: lora lora_target: all ### dataset +dataset_dir: /home/ray/default/lf/data/ dataset: identity,alpaca_en_demo template: llama3 cutoff_len: 2048 @@ -38,3 +39,10 @@ val_size: 0.1 per_device_eval_batch_size: 1 eval_strategy: steps eval_steps: 500 + + +### ray setup +resources_per_worker: + GPU: 1 +num_workers: 4 +# placement_strategy: ... diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index ef5bd1dc..26d7a3df 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -27,7 +27,7 @@ from .extras.env import VERSION, print_env from .extras.misc import get_device_count from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui - +from .integrations.ray.ray_utils import should_use_ray USAGE = ( "-" * 70 @@ -87,7 +87,8 @@ def main(): export_model() elif command == Command.TRAIN: force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"] - if force_torchrun or get_device_count() > 1: + use_ray = should_use_ray() + if force_torchrun or (get_device_count() > 1 and not use_ray): master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999))) logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}") diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 4a254367..06853ae8 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -19,6 +19,10 @@ import os import sys from typing import Any, Dict, Optional, Tuple +import json +import yaml +from pathlib import Path + import torch import transformers from transformers import HfArgumentParser, Seq2SeqTrainingArguments @@ -37,39 +41,51 @@ from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments +from ..integrations.ray.ray_train_args import RayTrainArguments logger = logging.get_logger(__name__) - check_dependencies() -_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] -_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, 
FinetuningArguments, GeneratingArguments] +_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] -def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: +def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: if args is not None: - return parser.parse_dict(args) + return args if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")): - return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + # read yaml file + return yaml.safe_load(Path(sys.argv[1]).absolute().read_text()) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # read json file + return json.loads(Path(sys.argv[1]).absolute().read_text()) + else: + return {} - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - return parser.parse_json_file(os.path.abspath(sys.argv[1])) - (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) +def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]: - if unknown_args: - print(parser.format_help()) - print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") - raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + args_dict = _read_args(args) + + if args_dict: + return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys) + else: + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True) - return (*parsed_args,) + if unknown_args: + print(parser.format_help()) + print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") + raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + + return (*parsed_args,) + def _set_transformers_logging() -> None: @@ -161,9 +177,17 @@ def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: return _parse_args(parser, args) -def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: - model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) +def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments: + parser = HfArgumentParser(RayTrainArguments) + ray_args = _parse_args(parser, args, allow_extra_keys=True)[0] + if ray_args.use_ray: + require_version("ray", "To fix: pip install ray") + return ray_args + +def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args) + # Setup logging if training_args.should_log: _set_transformers_logging() diff --git a/src/llamafactory/integrations/__init__.py b/src/llamafactory/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/integrations/ray/__init__.py b/src/llamafactory/integrations/ray/__init__.py 
new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/integrations/ray/ray_train.py b/src/llamafactory/integrations/ray/ray_train.py new file mode 100644 index 00000000..4620bb5a --- /dev/null +++ b/src/llamafactory/integrations/ray/ray_train.py @@ -0,0 +1,28 @@ + +from typing import Any, Callable, Dict + +from ray.train.torch import TorchTrainer +from ray.train import ScalingConfig + +from .ray_train_args import RayTrainArguments + + +def get_ray_trainer( + training_function: Callable, + train_loop_config: Dict[str, Any], + ray_args: RayTrainArguments, +) -> TorchTrainer: + + if not ray_args.use_ray: + raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.") + + trainer = TorchTrainer( + training_function, + train_loop_config=train_loop_config, + scaling_config=ScalingConfig( + num_workers=ray_args.num_workers, + resources_per_worker=ray_args.resources_per_worker, + use_gpu=True, + ), + ) + return trainer \ No newline at end of file diff --git a/src/llamafactory/integrations/ray/ray_train_args.py b/src/llamafactory/integrations/ray/ray_train_args.py new file mode 100644 index 00000000..9ee9dc8e --- /dev/null +++ b/src/llamafactory/integrations/ray/ray_train_args.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, Literal, Optional + +from .ray_utils import should_use_ray + +@dataclass +class RayTrainArguments: + r""" + Arguments pertaining to the Ray training. + """ + resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}) + num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}) + placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}) + + @property + def use_ray(self) -> bool: + """ + Always returns the value from the environment variable check. + This prevents manual setting of use_ray. 
+ """ + return should_use_ray() + diff --git a/src/llamafactory/integrations/ray/ray_utils.py b/src/llamafactory/integrations/ray/ray_utils.py new file mode 100644 index 00000000..67ce2ed9 --- /dev/null +++ b/src/llamafactory/integrations/ray/ray_utils.py @@ -0,0 +1,9 @@ + +import os + + +def should_use_ray(): + return os.getenv("USE_RAY", "0").lower() in ["true", "1"] + + + diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 6c79320e..507a1f14 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -23,6 +23,7 @@ from ..data import get_template_and_fix_tokenizer from ..extras import logging from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..hparams import get_infer_args, get_train_args +from ..hparams.parser import _parse_ray_args, _read_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .dpo import run_dpo @@ -36,12 +37,14 @@ from .trainer_utils import get_swanlab_callback if TYPE_CHECKING: from transformers import TrainerCallback - + logger = logging.get_logger(__name__) - -def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: +def training_function(config: Dict[str, Any]) -> None: + args = config.get("args", None) + callbacks = config.get("callbacks", []) + callbacks.append(LogCallback()) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) @@ -68,6 +71,33 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb else: raise ValueError(f"Unknown task: {finetuning_args.stage}.") +def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: + + args_dict = _read_args(args) + ray_args = _parse_ray_args(args_dict) + + if ray_args.use_ray: + # Import lazily to avoid ray not installed error + from ..integrations.ray.ray_train import get_ray_trainer + + # Initialize ray trainer + trainer = get_ray_trainer( + training_function=training_function, + train_loop_config={ + "args": args_dict, + "callbacks": callbacks, + }, + ray_args=ray_args, + ) + trainer.fit() + else: + training_function( + config={ + "args": args_dict, + "callbacks": callbacks, + } + ) + def export_model(args: Optional[Dict[str, Any]] = None) -> None: model_args, data_args, finetuning_args, _ = get_infer_args(args) From bba52e258e7474a64b8080b159a3ce762e1ee67c Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Mon, 6 Jan 2025 23:55:56 +0000 Subject: [PATCH 2/4] run style check Former-commit-id: 1e8e7be0a535e55888f58bbe2c38bc1c382e9012 --- examples/train_lora/llama3_lora_sft.yaml | 2 +- src/llamafactory/cli.py | 3 +- src/llamafactory/hparams/parser.py | 41 +++++++++++++------ .../integrations/ray/ray_train.py | 10 ++--- .../integrations/ray/ray_train_args.py | 16 ++++++-- .../integrations/ray/ray_utils.py | 4 -- src/llamafactory/train/tuner.py | 13 +++--- 7 files changed, 54 insertions(+), 35 deletions(-) diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml index 558ac3e9..dc4f5add 100644 --- a/examples/train_lora/llama3_lora_sft.yaml +++ b/examples/train_lora/llama3_lora_sft.yaml @@ -9,7 +9,7 @@ finetuning_type: lora lora_target: all ### dataset -dataset_dir: /home/ray/default/lf/data/ +dataset_dir: /home/ray/default/LLaMA-Factory/data/ dataset: identity,alpaca_en_demo template: llama3 cutoff_len: 2048 diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py 
index 26d7a3df..a2a8e29d 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -25,9 +25,10 @@ from .eval.evaluator import run_eval from .extras import logging from .extras.env import VERSION, print_env from .extras.misc import get_device_count +from .integrations.ray.ray_utils import should_use_ray from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui -from .integrations.ray.ray_utils import should_use_ray + USAGE = ( "-" * 70 diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 06853ae8..8cdfa7cb 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -15,16 +15,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import sys -from typing import Any, Dict, Optional, Tuple - -import json -import yaml from pathlib import Path +from typing import Any, Dict, Optional, Tuple import torch import transformers +import yaml from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint @@ -35,21 +34,35 @@ from transformers.utils.versions import require_version from ..extras import logging from ..extras.constants import CHECKPOINT_NAMES from ..extras.misc import check_dependencies, get_current_device +from ..integrations.ray.ray_train_args import RayTrainArguments from .data_args import DataArguments from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments -from ..integrations.ray.ray_train_args import RayTrainArguments logger = logging.get_logger(__name__) check_dependencies() -_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] -_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] +_TRAIN_ARGS = [ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + RayTrainArguments, +] +_TRAIN_CLS = Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + RayTrainArguments, +] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] @@ -70,14 +83,17 @@ def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: return {} -def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]: - +def _parse_args( + parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False +) -> Tuple[Any]: args_dict = _read_args(args) - + if args_dict: return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys) else: - (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True) + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses( + args=args_dict, return_remaining_strings=True + ) if unknown_args: print(parser.format_help()) @@ -85,7 +101,6 @@ def _parse_args(parser: "HfArgumentParser", args: 
Optional[Dict[str, Any]] = Non raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") return (*parsed_args,) - def _set_transformers_logging() -> None: @@ -187,7 +202,7 @@ def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments: def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args) - + # Setup logging if training_args.should_log: _set_transformers_logging() diff --git a/src/llamafactory/integrations/ray/ray_train.py b/src/llamafactory/integrations/ray/ray_train.py index 4620bb5a..50a2927a 100644 --- a/src/llamafactory/integrations/ray/ray_train.py +++ b/src/llamafactory/integrations/ray/ray_train.py @@ -1,21 +1,19 @@ - from typing import Any, Callable, Dict -from ray.train.torch import TorchTrainer from ray.train import ScalingConfig +from ray.train.torch import TorchTrainer from .ray_train_args import RayTrainArguments - + def get_ray_trainer( training_function: Callable, train_loop_config: Dict[str, Any], ray_args: RayTrainArguments, ) -> TorchTrainer: - if not ray_args.use_ray: raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.") - + trainer = TorchTrainer( training_function, train_loop_config=train_loop_config, @@ -25,4 +23,4 @@ def get_ray_trainer( use_gpu=True, ), ) - return trainer \ No newline at end of file + return trainer diff --git a/src/llamafactory/integrations/ray/ray_train_args.py b/src/llamafactory/integrations/ray/ray_train_args.py index 9ee9dc8e..afacbee1 100644 --- a/src/llamafactory/integrations/ray/ray_train_args.py +++ b/src/llamafactory/integrations/ray/ray_train_args.py @@ -3,14 +3,23 @@ from typing import Any, Dict, Literal, Optional from .ray_utils import should_use_ray + @dataclass class RayTrainArguments: r""" Arguments pertaining to the Ray training. """ - resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}) - num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}) - placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}) + + resources_per_worker: Optional[Dict[str, Any]] = field( + default_factory=lambda: {"GPU": 1}, + metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, + ) + num_workers: Optional[int] = field( + default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."} + ) + placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field( + default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."} + ) @property def use_ray(self) -> bool: @@ -19,4 +28,3 @@ class RayTrainArguments: This prevents manual setting of use_ray. 
""" return should_use_ray() - diff --git a/src/llamafactory/integrations/ray/ray_utils.py b/src/llamafactory/integrations/ray/ray_utils.py index 67ce2ed9..8b8b045e 100644 --- a/src/llamafactory/integrations/ray/ray_utils.py +++ b/src/llamafactory/integrations/ray/ray_utils.py @@ -1,9 +1,5 @@ - import os def should_use_ray(): return os.getenv("USE_RAY", "0").lower() in ["true", "1"] - - - diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 507a1f14..8c461890 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -37,14 +37,15 @@ from .trainer_utils import get_swanlab_callback if TYPE_CHECKING: from transformers import TrainerCallback - + logger = logging.get_logger(__name__) + def training_function(config: Dict[str, Any]) -> None: args = config.get("args", None) callbacks = config.get("callbacks", []) - + callbacks.append(LogCallback()) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) @@ -71,15 +72,15 @@ def training_function(config: Dict[str, Any]) -> None: else: raise ValueError(f"Unknown task: {finetuning_args.stage}.") + def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: - args_dict = _read_args(args) ray_args = _parse_ray_args(args_dict) - - if ray_args.use_ray: + + if ray_args.use_ray: # Import lazily to avoid ray not installed error from ..integrations.ray.ray_train import get_ray_trainer - + # Initialize ray trainer trainer = get_ray_trainer( training_function=training_function, From b4174021d68f46d5be2a28e5a8fcb2dcfae3810b Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 7 Jan 2025 08:54:41 +0000 Subject: [PATCH 3/4] refactor ray integration, support save ckpt Former-commit-id: d8cac6f54663e6cffeddf2c65e3da454e7b86a75 --- .env.local | 1 + examples/README.md | 6 ++ examples/README_zh.md | 6 ++ examples/train_lora/llama3_lora_sft.yaml | 8 -- examples/train_lora/llama3_lora_sft_ray.yaml | 48 ++++++++++++ src/llamafactory/cli.py | 6 +- src/llamafactory/extras/misc.py | 10 ++- src/llamafactory/extras/packages.py | 4 + src/llamafactory/hparams/__init__.py | 7 +- src/llamafactory/hparams/parser.py | 77 +++++++------------ src/llamafactory/hparams/training_args.py | 48 ++++++++++++ src/llamafactory/integrations/__init__.py | 0 src/llamafactory/integrations/ray/__init__.py | 0 .../integrations/ray/ray_train.py | 26 ------- .../integrations/ray/ray_train_args.py | 30 -------- .../integrations/ray/ray_utils.py | 5 -- src/llamafactory/train/trainer_utils.py | 53 ++++++++++--- src/llamafactory/train/tuner.py | 41 ++++------ 18 files changed, 215 insertions(+), 161 deletions(-) create mode 100644 examples/train_lora/llama3_lora_sft_ray.yaml create mode 100644 src/llamafactory/hparams/training_args.py delete mode 100644 src/llamafactory/integrations/__init__.py delete mode 100644 src/llamafactory/integrations/ray/__init__.py delete mode 100644 src/llamafactory/integrations/ray/ray_train.py delete mode 100644 src/llamafactory/integrations/ray/ray_train_args.py delete mode 100644 src/llamafactory/integrations/ray/ray_utils.py diff --git a/.env.local b/.env.local index 203aebaf..1d8f2a00 100644 --- a/.env.local +++ b/.env.local @@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS= LLAMAFACTORY_VERBOSITY= USE_MODELSCOPE_HUB= USE_OPENMIND_HUB= +USE_RAY= RECORD_VRAM= # torchrun FORCE_TORCHRUN= diff --git a/examples/README.md b/examples/README.md index cc2afc46..89f7d174 100644 --- a/examples/README.md +++ b/examples/README.md @@ -95,6 +95,12 @@ 
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` +#### Supervised Fine-Tuning with Ray on 4 GPUs + +```bash +USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml +``` + ### QLoRA Fine-Tuning #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended) diff --git a/examples/README_zh.md b/examples/README_zh.md index b41d7ab8..2c108e56 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` +#### 使用 Ray 在 4 张 GPU 上微调 + +```bash +USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml +``` + ### QLoRA 微调 #### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐) diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml index dc4f5add..243f2445 100644 --- a/examples/train_lora/llama3_lora_sft.yaml +++ b/examples/train_lora/llama3_lora_sft.yaml @@ -9,7 +9,6 @@ finetuning_type: lora lora_target: all ### dataset -dataset_dir: /home/ray/default/LLaMA-Factory/data/ dataset: identity,alpaca_en_demo template: llama3 cutoff_len: 2048 @@ -39,10 +38,3 @@ val_size: 0.1 per_device_eval_batch_size: 1 eval_strategy: steps eval_steps: 500 - - -### ray setup -resources_per_worker: - GPU: 1 -num_workers: 4 -# placement_strategy: ... diff --git a/examples/train_lora/llama3_lora_sft_ray.yaml b/examples/train_lora/llama3_lora_sft_ray.yaml new file mode 100644 index 00000000..4aac08bc --- /dev/null +++ b/examples/train_lora/llama3_lora_sft_ray.yaml @@ -0,0 +1,48 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path +trust_remote_code: true + +### method +stage: sft +do_train: true +finetuning_type: lora +lora_target: all + +### dataset +dataset: identity,alpaca_en_demo +dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path +template: llama3 +cutoff_len: 2048 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: tmp_dir +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +learning_rate: 1.0e-4 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500 + +### ray +ray_run_name: llama3_8b_sft_lora +ray_num_workers: 4 # number of GPUs to use +resources_per_worker: + GPU: 1 +placement_strategy: PACK diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index a2a8e29d..72085e2d 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -24,8 +24,7 @@ from .chat.chat_model import run_chat from .eval.evaluator import run_eval from .extras import logging from .extras.env import VERSION, print_env -from .extras.misc import get_device_count -from .integrations.ray.ray_utils import should_use_ray +from .extras.misc import get_device_count, use_ray from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui @@ -88,8 +87,7 @@ def main(): export_model() elif command == Command.TRAIN: force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"] - use_ray = should_use_ray() - if 
force_torchrun or (get_device_count() > 1 and not use_ray): + if force_torchrun or (get_device_count() > 1 and not use_ray()): master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999))) logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}") diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 735c5d63..11797f9f 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -229,7 +229,7 @@ def skip_check_imports() -> None: r""" Avoids flash attention import error in custom model files. """ - if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: + if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: transformers.dynamic_module_utils.check_imports = get_relative_imports @@ -275,8 +275,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str: def use_modelscope() -> bool: - return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] + return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] def use_openmind() -> bool: - return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"] + return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"] + + +def use_ray() -> bool: + return os.getenv("USE_RAY", "0").lower() in ["true", "1"] diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index 44b9bb8a..6b2bc3f3 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -62,6 +62,10 @@ def is_pillow_available(): return _is_package_available("PIL") +def is_ray_available(): + return _is_package_available("ray") + + def is_requests_available(): return _is_package_available("requests") diff --git a/src/llamafactory/hparams/__init__.py b/src/llamafactory/hparams/__init__.py index cfe448c1..254a845e 100644 --- a/src/llamafactory/hparams/__init__.py +++ b/src/llamafactory/hparams/__init__.py @@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments -from .parser import get_eval_args, get_infer_args, get_train_args +from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args +from .training_args import RayArguments, TrainingArguments __all__ = [ @@ -26,7 +27,11 @@ __all__ = [ "FinetuningArguments", "GeneratingArguments", "ModelArguments", + "RayArguments", + "TrainingArguments", "get_eval_args", "get_infer_args", + "get_ray_args", "get_train_args", + "read_args", ] diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 8cdfa7cb..62edbf78 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -19,12 +19,12 @@ import json import os import sys from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch import transformers import yaml -from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers import HfArgumentParser from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint from transformers.training_args import ParallelMode @@ -34,12 +34,12 @@ from transformers.utils.versions import require_version from ..extras import logging from ..extras.constants import 
CHECKPOINT_NAMES from ..extras.misc import check_dependencies, get_current_device -from ..integrations.ray.ray_train_args import RayTrainArguments from .data_args import DataArguments from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments +from .training_args import RayArguments, TrainingArguments logger = logging.get_logger(__name__) @@ -47,60 +47,41 @@ logger = logging.get_logger(__name__) check_dependencies() -_TRAIN_ARGS = [ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments, - RayTrainArguments, -] -_TRAIN_CLS = Tuple[ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments, - RayTrainArguments, -] +_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] -def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: +def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]: if args is not None: return args if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")): - # read yaml file return yaml.safe_load(Path(sys.argv[1]).absolute().read_text()) elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # read json file return json.loads(Path(sys.argv[1]).absolute().read_text()) else: - return {} + return sys.argv[1:] def _parse_args( - parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False + parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False ) -> Tuple[Any]: - args_dict = _read_args(args) + args = read_args(args) + if isinstance(args, dict): + return parser.parse_dict(args, allow_extra_keys=allow_extra_keys) - if args_dict: - return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys) - else: - (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses( - args=args_dict, return_remaining_strings=True - ) + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True) - if unknown_args: - print(parser.format_help()) - print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") - raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + if unknown_args: + print(parser.format_help()) + print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") + raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") - return (*parsed_args,) + return (*parsed_args,) def _set_transformers_logging() -> None: @@ -141,7 +122,7 @@ def _verify_model_args( def _check_extra_dependencies( model_args: "ModelArguments", finetuning_args: "FinetuningArguments", - training_args: Optional["Seq2SeqTrainingArguments"] = None, + training_args: Optional["TrainingArguments"] = None, ) -> None: if 
os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.") @@ -177,31 +158,29 @@ def _check_extra_dependencies( require_version("rouge_chinese", "To fix: pip install rouge-chinese") -def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: +def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: parser = HfArgumentParser(_TRAIN_ARGS) return _parse_args(parser, args) -def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: +def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: parser = HfArgumentParser(_INFER_ARGS) return _parse_args(parser, args) -def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: +def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: parser = HfArgumentParser(_EVAL_ARGS) return _parse_args(parser, args) -def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments: - parser = HfArgumentParser(RayTrainArguments) - ray_args = _parse_args(parser, args, allow_extra_keys=True)[0] - if ray_args.use_ray: - require_version("ray", "To fix: pip install ray") +def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments: + parser = HfArgumentParser(RayArguments) + (ray_args,) = _parse_args(parser, args, allow_extra_keys=True) return ray_args -def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: - model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args) +def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: + model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) # Setup logging if training_args.should_log: @@ -410,7 +389,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: return model_args, data_args, training_args, finetuning_args, generating_args -def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: +def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) _set_transformers_logging() @@ -443,7 +422,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: return model_args, data_args, finetuning_args, generating_args -def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: +def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) _set_transformers_logging() diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py new file mode 100644 index 00000000..9df24ace --- /dev/null +++ b/src/llamafactory/hparams/training_args.py @@ -0,0 +1,48 @@ +import json +from dataclasses import dataclass, field +from typing import Literal, Optional, Union + +from transformers import Seq2SeqTrainingArguments +from transformers.training_args import _convert_str_dict + +from ..extras.misc import use_ray + + +@dataclass +class RayArguments: + r""" + Arguments pertaining to the Ray training. 
+ """ + + ray_run_name: Optional[str] = field( + default=None, + metadata={"help": "The training results will be saved at `saves/ray_run_name`."}, + ) + ray_num_workers: int = field( + default=1, + metadata={"help": "The number of workers for Ray training. Default is 1 worker."}, + ) + resources_per_worker: Union[dict, str] = field( + default_factory=lambda: {"GPU": 1}, + metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, + ) + placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field( + default="PACK", + metadata={"help": "The placement strategy for Ray training. Default is PACK."}, + ) + + def __post_init__(self): + self.use_ray = use_ray() + if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"): + self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker)) + + +@dataclass +class TrainingArguments(RayArguments, Seq2SeqTrainingArguments): + r""" + Arguments pertaining to the trainer. + """ + + def __post_init__(self): + Seq2SeqTrainingArguments.__post_init__(self) + RayArguments.__post_init__(self) diff --git a/src/llamafactory/integrations/__init__.py b/src/llamafactory/integrations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/llamafactory/integrations/ray/__init__.py b/src/llamafactory/integrations/ray/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/llamafactory/integrations/ray/ray_train.py b/src/llamafactory/integrations/ray/ray_train.py deleted file mode 100644 index 50a2927a..00000000 --- a/src/llamafactory/integrations/ray/ray_train.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Any, Callable, Dict - -from ray.train import ScalingConfig -from ray.train.torch import TorchTrainer - -from .ray_train_args import RayTrainArguments - - -def get_ray_trainer( - training_function: Callable, - train_loop_config: Dict[str, Any], - ray_args: RayTrainArguments, -) -> TorchTrainer: - if not ray_args.use_ray: - raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.") - - trainer = TorchTrainer( - training_function, - train_loop_config=train_loop_config, - scaling_config=ScalingConfig( - num_workers=ray_args.num_workers, - resources_per_worker=ray_args.resources_per_worker, - use_gpu=True, - ), - ) - return trainer diff --git a/src/llamafactory/integrations/ray/ray_train_args.py b/src/llamafactory/integrations/ray/ray_train_args.py deleted file mode 100644 index afacbee1..00000000 --- a/src/llamafactory/integrations/ray/ray_train_args.py +++ /dev/null @@ -1,30 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, Dict, Literal, Optional - -from .ray_utils import should_use_ray - - -@dataclass -class RayTrainArguments: - r""" - Arguments pertaining to the Ray training. - """ - - resources_per_worker: Optional[Dict[str, Any]] = field( - default_factory=lambda: {"GPU": 1}, - metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, - ) - num_workers: Optional[int] = field( - default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."} - ) - placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field( - default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."} - ) - - @property - def use_ray(self) -> bool: - """ - Always returns the value from the environment variable check. 
- This prevents manual setting of use_ray. - """ - return should_use_ray() diff --git a/src/llamafactory/integrations/ray/ray_utils.py b/src/llamafactory/integrations/ray/ray_utils.py deleted file mode 100644 index 8b8b045e..00000000 --- a/src/llamafactory/integrations/ray/ray_utils.py +++ /dev/null @@ -1,5 +0,0 @@ -import os - - -def should_use_ray(): - return os.getenv("USE_RAY", "0").lower() in ["true", "1"] diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 4cd2337e..6aca53cf 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -18,7 +18,8 @@ # limitations under the License. from collections.abc import Mapping -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import Trainer @@ -31,7 +32,7 @@ from typing_extensions import override from ..extras import logging from ..extras.constants import IGNORE_INDEX -from ..extras.packages import is_galore_available +from ..extras.packages import is_galore_available, is_ray_available from ..hparams import FinetuningArguments, ModelArguments from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params @@ -40,11 +41,16 @@ if is_galore_available(): from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore +if is_ray_available(): + from ray.train import RunConfig, ScalingConfig + from ray.train.torch import TorchTrainer + + if TYPE_CHECKING: - from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback + from transformers import PreTrainedModel, TrainerCallback from trl import AutoModelForCausalLMWithValueHead - from ..hparams import DataArguments + from ..hparams import DataArguments, RayArguments, TrainingArguments logger = logging.get_logger(__name__) @@ -75,7 +81,7 @@ def create_modelcard_and_push( trainer: "Trainer", model_args: "ModelArguments", data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> None: kwargs = { @@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: def _create_galore_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": @@ -272,7 +278,7 @@ def _create_galore_optimizer( def _create_loraplus_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": default_lr = training_args.learning_rate @@ -312,7 +318,7 @@ def _create_loraplus_optimizer( def _create_badam_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": decay_params, nodecay_params = [], [] @@ -373,7 +379,7 @@ def _create_badam_optimizer( def _create_adam_mini_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", ) -> "torch.optim.Optimizer": from adam_mini import Adam_mini # type: ignore @@ -398,7 +404,7 @@ def 
_create_adam_mini_optimizer( def create_custom_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> Optional["torch.optim.Optimizer"]: if finetuning_args.use_galore: @@ -415,7 +421,7 @@ def create_custom_optimizer( def create_custom_scheduler( - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None, ) -> None: @@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall config={"Framework": "🦙LlamaFactory"}, ) return swanlab_callback + + +def get_ray_trainer( + training_function: Callable, + train_loop_config: Dict[str, Any], + ray_args: "RayArguments", +) -> "TorchTrainer": + if not ray_args.use_ray: + raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.") + + trainer = TorchTrainer( + training_function, + train_loop_config=train_loop_config, + scaling_config=ScalingConfig( + num_workers=ray_args.ray_num_workers, + resources_per_worker=ray_args.resources_per_worker, + placement_strategy=ray_args.placement_strategy, + use_gpu=True, + ), + run_config=RunConfig( + name=ray_args.ray_run_name, + storage_path=Path("./saves").absolute().as_posix(), + ), + ) + return trainer diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 8c461890..24620c87 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -22,8 +22,8 @@ from transformers import PreTrainedModel from ..data import get_template_and_fix_tokenizer from ..extras import logging from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..hparams import get_infer_args, get_train_args -from ..hparams.parser import _parse_ray_args, _read_args +from ..extras.packages import is_ray_available +from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .dpo import run_dpo @@ -32,7 +32,11 @@ from .ppo import run_ppo from .pt import run_pt from .rm import run_rm from .sft import run_sft -from .trainer_utils import get_swanlab_callback +from .trainer_utils import get_ray_trainer, get_swanlab_callback + + +if is_ray_available(): + from ray.train.huggingface.transformers import RayTrainReportCallback if TYPE_CHECKING: @@ -43,10 +47,8 @@ logger = logging.get_logger(__name__) def training_function(config: Dict[str, Any]) -> None: - args = config.get("args", None) - callbacks = config.get("callbacks", []) - - callbacks.append(LogCallback()) + args = config.get("args") + callbacks: List[Any] = config.get("callbacks") model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) if finetuning_args.pissa_convert: @@ -73,31 +75,22 @@ def training_function(config: Dict[str, Any]) -> None: raise ValueError(f"Unknown task: {finetuning_args.stage}.") -def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: - args_dict = _read_args(args) - ray_args = _parse_ray_args(args_dict) +def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None: + callbacks = callbacks or [] + callbacks.append(LogCallback()) + args = read_args(args) + ray_args = get_ray_args(args) if ray_args.use_ray: - # Import lazily to avoid ray not installed error - from 
..integrations.ray.ray_train import get_ray_trainer - - # Initialize ray trainer + callbacks.append(RayTrainReportCallback()) trainer = get_ray_trainer( training_function=training_function, - train_loop_config={ - "args": args_dict, - "callbacks": callbacks, - }, + train_loop_config={"args": args, "callbacks": callbacks}, ray_args=ray_args, ) trainer.fit() else: - training_function( - config={ - "args": args_dict, - "callbacks": callbacks, - } - ) + training_function(config={"args": args, "callbacks": callbacks}) def export_model(args: Optional[Dict[str, Any]] = None) -> None: From 0c1ad5f3fb496fabafc2954e4e4f42da06b22daf Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 7 Jan 2025 09:59:24 +0000 Subject: [PATCH 4/4] fix llamaboard with ray Former-commit-id: c46675d5e56d175c27d705ef0068fb47dc89a872 --- src/llamafactory/train/callbacks.py | 6 +++--- src/llamafactory/train/tuner.py | 11 +++++------ src/llamafactory/webui/runner.py | 6 +++--- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 189c7533..4da4ec18 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -35,7 +35,7 @@ from typing_extensions import override from ..extras import logging from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..extras.misc import get_peak_memory +from ..extras.misc import get_peak_memory, use_ray if is_safetensors_available(): @@ -194,7 +194,7 @@ class LogCallback(TrainerCallback): self.do_train = False # Web UI self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] - if self.webui_mode: + if self.webui_mode and not use_ray(): signal.signal(signal.SIGABRT, self._set_abort) self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) logging.add_handler(self.logger_handler) @@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback): ) if self.finetuning_args.use_swanlab: - import swanlab + import swanlab # type: ignore swanlab.config.update( { diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 24620c87..bbbef1cf 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -46,11 +46,12 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def training_function(config: Dict[str, Any]) -> None: +def _training_function(config: Dict[str, Any]) -> None: args = config.get("args") callbacks: List[Any] = config.get("callbacks") model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) + callbacks.append(LogCallback()) if finetuning_args.pissa_convert: callbacks.append(PissaConvertCallback()) @@ -76,21 +77,19 @@ def training_function(config: Dict[str, Any]) -> None: def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None: - callbacks = callbacks or [] - callbacks.append(LogCallback()) - args = read_args(args) ray_args = get_ray_args(args) + callbacks = callbacks or [] if ray_args.use_ray: callbacks.append(RayTrainReportCallback()) trainer = get_ray_trainer( - training_function=training_function, + training_function=_training_function, train_loop_config={"args": args, "callbacks": callbacks}, ray_args=ray_args, ) trainer.fit() else: - training_function(config={"args": args, "callbacks": callbacks}) + _training_function(config={"args": args, "callbacks": callbacks}) def export_model(args: Optional[Dict[str, Any]] = None) -> None: diff 
--git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index f5aecaeb..dc91ad50 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional from transformers.trainer import TRAINING_ARGS_NAME from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES -from ..extras.misc import is_gpu_or_npu_available, torch_gc +from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46 from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config from .locales import ALERTS, LOCALES @@ -394,12 +394,12 @@ class Runner: continue if self.do_train: - if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)): + if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray(): finish_info = ALERTS["info_finished"][lang] else: finish_info = ALERTS["err_failed"][lang] else: - if os.path.exists(os.path.join(output_path, "all_results.json")): + if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray(): finish_info = get_eval_results(os.path.join(output_path, "all_results.json")) else: finish_info = ALERTS["err_failed"][lang]
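
---

Usage note (not part of the patches above): the sketch below is a minimal, untested example of driving the Ray path that this series introduces programmatically rather than through `USE_RAY=1 llamafactory-cli train ...`. It assumes `ray[train]` and LLaMA-Factory with these patches are installed and that the Ray cluster can provide 4 GPUs; the argument names are copied from the `llama3_lora_sft_ray.yaml` example added in PATCH 3/4, and `run_exp` accepts them as a plain dict via `read_args`.

```python
import os

# Read at runtime by use_ray() in extras/misc.py and RayArguments.__post_init__
os.environ["USE_RAY"] = "1"

from llamafactory.train.tuner import run_exp

run_exp(
    args={
        # model / method (same values as llama3_lora_sft_ray.yaml)
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
        "stage": "sft",
        "do_train": True,
        "finetuning_type": "lora",
        "lora_target": "all",
        # dataset / output
        "dataset": "identity,alpaca_en_demo",
        "template": "llama3",
        "output_dir": "tmp_dir",
        # ray-specific keys, parsed into RayArguments via TrainingArguments
        "ray_run_name": "llama3_8b_sft_lora",
        "ray_num_workers": 4,
        "resources_per_worker": {"GPU": 1},
        "placement_strategy": "PACK",
    }
)
```

With `USE_RAY=1`, `run_exp` appends `RayTrainReportCallback`, builds a `TorchTrainer` through `get_ray_trainer`, and runs `_training_function` inside each Ray worker; checkpoints land under `saves/<ray_run_name>` per the `RunConfig` added in `trainer_utils.py`.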