From bba52e258e7474a64b8080b159a3ce762e1ee67c Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Mon, 6 Jan 2025 23:55:56 +0000
Subject: [PATCH] run style check

Former-commit-id: 1e8e7be0a535e55888f58bbe2c38bc1c382e9012
---
 examples/train_lora/llama3_lora_sft.yaml |  2 +-
 src/llamafactory/cli.py                  |  3 +-
 src/llamafactory/hparams/parser.py       | 41 +++++++++++++------
 .../integrations/ray/ray_train.py        | 10 ++---
 .../integrations/ray/ray_train_args.py   | 16 ++++++--
 .../integrations/ray/ray_utils.py        |  4 --
 src/llamafactory/train/tuner.py          | 13 +++---
 7 files changed, 54 insertions(+), 35 deletions(-)

diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml
index 558ac3e9..dc4f5add 100644
--- a/examples/train_lora/llama3_lora_sft.yaml
+++ b/examples/train_lora/llama3_lora_sft.yaml
@@ -9,7 +9,7 @@ finetuning_type: lora
 lora_target: all
 
 ### dataset
-dataset_dir: /home/ray/default/lf/data/
+dataset_dir: /home/ray/default/LLaMA-Factory/data/
 dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index 26d7a3df..a2a8e29d 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -25,9 +25,10 @@ from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
 from .extras.misc import get_device_count
+from .integrations.ray.ray_utils import should_use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
-from .integrations.ray.ray_utils import should_use_ray
+
 
 USAGE = (
     "-" * 70
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 06853ae8..8cdfa7cb 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -15,16 +15,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
 import sys
-from typing import Any, Dict, Optional, Tuple
-
-import json
-import yaml
 from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import transformers
+import yaml
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
@@ -35,21 +34,35 @@ from transformers.utils.versions import require_version
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
 from ..extras.misc import check_dependencies, get_current_device
+from ..integrations.ray.ray_train_args import RayTrainArguments
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
-from ..integrations.ray.ray_train_args import RayTrainArguments
 
 
 logger = logging.get_logger(__name__)
 
 check_dependencies()
 
 
-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
+_TRAIN_ARGS = [
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    RayTrainArguments,
+]
+_TRAIN_CLS = Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    RayTrainArguments,
+]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
@@ -70,14 +83,17 @@ def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         return {}
 
 
-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
-
+def _parse_args(
+    parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
+) -> Tuple[Any]:
     args_dict = _read_args(args)
-    
+
     if args_dict:
         return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
     else:
-        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
+        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
+            args=args_dict, return_remaining_strings=True
+        )
 
         if unknown_args:
             print(parser.format_help())
@@ -85,7 +101,6 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
             raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
 
     return (*parsed_args,)
-
 
 
 def _set_transformers_logging() -> None:
@@ -187,7 +202,7 @@ def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
 
 def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
-    
+
     # Setup logging
     if training_args.should_log:
         _set_transformers_logging()
diff --git a/src/llamafactory/integrations/ray/ray_train.py b/src/llamafactory/integrations/ray/ray_train.py
index 4620bb5a..50a2927a 100644
--- a/src/llamafactory/integrations/ray/ray_train.py
+++ b/src/llamafactory/integrations/ray/ray_train.py
@@ -1,21 +1,19 @@
-
 from typing import Any, Callable, Dict
 
-from ray.train.torch import TorchTrainer
 from ray.train import ScalingConfig
+from ray.train.torch import TorchTrainer
 
 from .ray_train_args import RayTrainArguments
-
+
 
 def get_ray_trainer(
     training_function: Callable,
     train_loop_config: Dict[str, Any],
     ray_args: RayTrainArguments,
 ) -> TorchTrainer:
-
     if not ray_args.use_ray:
         raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
-    
+
     trainer = TorchTrainer(
         training_function,
         train_loop_config=train_loop_config,
@@ -25,4 +23,4 @@ def get_ray_trainer(
             use_gpu=True,
         ),
     )
-    return trainer
\ No newline at end of file
+    return trainer
diff --git a/src/llamafactory/integrations/ray/ray_train_args.py b/src/llamafactory/integrations/ray/ray_train_args.py
index 9ee9dc8e..afacbee1 100644
--- a/src/llamafactory/integrations/ray/ray_train_args.py
+++ b/src/llamafactory/integrations/ray/ray_train_args.py
@@ -3,14 +3,23 @@ from typing import Any, Dict, Literal, Optional
 
 from .ray_utils import should_use_ray
 
+
 @dataclass
 class RayTrainArguments:
     r"""
     Arguments pertaining to the Ray training.
     """
-    resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."})
-    num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."})
-    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."})
+
+    resources_per_worker: Optional[Dict[str, Any]] = field(
+        default_factory=lambda: {"GPU": 1},
+        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
+    )
+    num_workers: Optional[int] = field(
+        default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}
+    )
+    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(
+        default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}
+    )
 
     @property
     def use_ray(self) -> bool:
@@ -19,4 +28,3 @@ class RayTrainArguments:
         This prevents manual setting of use_ray.
""" return should_use_ray() - diff --git a/src/llamafactory/integrations/ray/ray_utils.py b/src/llamafactory/integrations/ray/ray_utils.py index 67ce2ed9..8b8b045e 100644 --- a/src/llamafactory/integrations/ray/ray_utils.py +++ b/src/llamafactory/integrations/ray/ray_utils.py @@ -1,9 +1,5 @@ - import os def should_use_ray(): return os.getenv("USE_RAY", "0").lower() in ["true", "1"] - - - diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 507a1f14..8c461890 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -37,14 +37,15 @@ from .trainer_utils import get_swanlab_callback if TYPE_CHECKING: from transformers import TrainerCallback - + logger = logging.get_logger(__name__) + def training_function(config: Dict[str, Any]) -> None: args = config.get("args", None) callbacks = config.get("callbacks", []) - + callbacks.append(LogCallback()) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) @@ -71,15 +72,15 @@ def training_function(config: Dict[str, Any]) -> None: else: raise ValueError(f"Unknown task: {finetuning_args.stage}.") + def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: - args_dict = _read_args(args) ray_args = _parse_ray_args(args_dict) - - if ray_args.use_ray: + + if ray_args.use_ray: # Import lazily to avoid ray not installed error from ..integrations.ray.ray_train import get_ray_trainer - + # Initialize ray trainer trainer = get_ray_trainer( training_function=training_function,