diff --git a/.env.local b/.env.local index 203aebaf..1d8f2a00 100644 --- a/.env.local +++ b/.env.local @@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS= LLAMAFACTORY_VERBOSITY= USE_MODELSCOPE_HUB= USE_OPENMIND_HUB= +USE_RAY= RECORD_VRAM= # torchrun FORCE_TORCHRUN= diff --git a/examples/README.md b/examples/README.md index cc2afc46..89f7d174 100644 --- a/examples/README.md +++ b/examples/README.md @@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` +#### Supervised Fine-Tuning with Ray on 4 GPUs + +```bash +USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml +``` + ### QLoRA Fine-Tuning #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended) diff --git a/examples/README_zh.md b/examples/README_zh.md index b41d7ab8..2c108e56 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` +#### 使用 Ray 在 4 张 GPU 上微调 + +```bash +USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml +``` + ### QLoRA 微调 #### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐) diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml index dc4f5add..243f2445 100644 --- a/examples/train_lora/llama3_lora_sft.yaml +++ b/examples/train_lora/llama3_lora_sft.yaml @@ -9,7 +9,6 @@ finetuning_type: lora lora_target: all ### dataset -dataset_dir: /home/ray/default/LLaMA-Factory/data/ dataset: identity,alpaca_en_demo template: llama3 cutoff_len: 2048 @@ -39,10 +38,3 @@ val_size: 0.1 per_device_eval_batch_size: 1 eval_strategy: steps eval_steps: 500 - - -### ray setup -resources_per_worker: - GPU: 1 -num_workers: 4 -# placement_strategy: ... 
diff --git a/examples/train_lora/llama3_lora_sft_ray.yaml b/examples/train_lora/llama3_lora_sft_ray.yaml new file mode 100644 index 00000000..4aac08bc --- /dev/null +++ b/examples/train_lora/llama3_lora_sft_ray.yaml @@ -0,0 +1,48 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path +trust_remote_code: true + +### method +stage: sft +do_train: true +finetuning_type: lora +lora_target: all + +### dataset +dataset: identity,alpaca_en_demo +dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path +template: llama3 +cutoff_len: 2048 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: tmp_dir +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +learning_rate: 1.0e-4 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500 + +### ray +ray_run_name: llama3_8b_sft_lora +ray_num_workers: 4 # number of GPUs to use +resources_per_worker: + GPU: 1 +placement_strategy: PACK diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index a2a8e29d..72085e2d 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -24,8 +24,7 @@ from .chat.chat_model import run_chat from .eval.evaluator import run_eval from .extras import logging from .extras.env import VERSION, print_env -from .extras.misc import get_device_count -from .integrations.ray.ray_utils import should_use_ray +from .extras.misc import get_device_count, use_ray from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui @@ -88,8 +87,7 @@ def main(): export_model() elif command == Command.TRAIN: force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"] - use_ray = should_use_ray() - if 
force_torchrun or (get_device_count() > 1 and not use_ray): + if force_torchrun or (get_device_count() > 1 and not use_ray()): master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999))) logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}") diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 735c5d63..11797f9f 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -229,7 +229,7 @@ def skip_check_imports() -> None: r""" Avoids flash attention import error in custom model files. """ - if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: + if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: transformers.dynamic_module_utils.check_imports = get_relative_imports @@ -275,8 +275,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str: def use_modelscope() -> bool: - return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] + return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] def use_openmind() -> bool: - return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"] + return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"] + + +def use_ray() -> bool: + return os.getenv("USE_RAY", "0").lower() in ["true", "1"] diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index 44b9bb8a..6b2bc3f3 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -62,6 +62,10 @@ def is_pillow_available(): return _is_package_available("PIL") +def is_ray_available(): + return _is_package_available("ray") + + def is_requests_available(): return _is_package_available("requests") diff --git a/src/llamafactory/hparams/__init__.py b/src/llamafactory/hparams/__init__.py index cfe448c1..254a845e 100644 --- a/src/llamafactory/hparams/__init__.py +++ 
b/src/llamafactory/hparams/__init__.py @@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments -from .parser import get_eval_args, get_infer_args, get_train_args +from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args +from .training_args import RayArguments, TrainingArguments __all__ = [ @@ -26,7 +27,11 @@ __all__ = [ "FinetuningArguments", "GeneratingArguments", "ModelArguments", + "RayArguments", + "TrainingArguments", "get_eval_args", "get_infer_args", + "get_ray_args", "get_train_args", + "read_args", ] diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 8cdfa7cb..62edbf78 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -19,12 +19,12 @@ import json import os import sys from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch import transformers import yaml -from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers import HfArgumentParser from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint from transformers.training_args import ParallelMode @@ -34,12 +34,12 @@ from transformers.utils.versions import require_version from ..extras import logging from ..extras.constants import CHECKPOINT_NAMES from ..extras.misc import check_dependencies, get_current_device -from ..integrations.ray.ray_train_args import RayTrainArguments from .data_args import DataArguments from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments +from .training_args import RayArguments, TrainingArguments logger = 
logging.get_logger(__name__) @@ -47,60 +47,41 @@ logger = logging.get_logger(__name__) check_dependencies() -_TRAIN_ARGS = [ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments, - RayTrainArguments, -] -_TRAIN_CLS = Tuple[ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments, - RayTrainArguments, -] +_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] -def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: +def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]: if args is not None: return args if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")): - # read yaml file return yaml.safe_load(Path(sys.argv[1]).absolute().read_text()) elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # read json file return json.loads(Path(sys.argv[1]).absolute().read_text()) else: - return {} + return sys.argv[1:] def _parse_args( - parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False + parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False ) -> Tuple[Any]: - args_dict = _read_args(args) + args = read_args(args) + if isinstance(args, dict): + return parser.parse_dict(args, allow_extra_keys=allow_extra_keys) - if args_dict: - return 
parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys) - else: - (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses( - args=args_dict, return_remaining_strings=True - ) + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True) - if unknown_args: - print(parser.format_help()) - print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") - raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + if unknown_args: + print(parser.format_help()) + print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") + raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") - return (*parsed_args,) + return (*parsed_args,) def _set_transformers_logging() -> None: @@ -141,7 +122,7 @@ def _verify_model_args( def _check_extra_dependencies( model_args: "ModelArguments", finetuning_args: "FinetuningArguments", - training_args: Optional["Seq2SeqTrainingArguments"] = None, + training_args: Optional["TrainingArguments"] = None, ) -> None: if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.") @@ -177,31 +158,29 @@ def _check_extra_dependencies( require_version("rouge_chinese", "To fix: pip install rouge-chinese") -def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: +def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: parser = HfArgumentParser(_TRAIN_ARGS) return _parse_args(parser, args) -def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: +def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: parser = HfArgumentParser(_INFER_ARGS) return _parse_args(parser, args) -def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: +def 
_parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: parser = HfArgumentParser(_EVAL_ARGS) return _parse_args(parser, args) -def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments: - parser = HfArgumentParser(RayTrainArguments) - ray_args = _parse_args(parser, args, allow_extra_keys=True)[0] - if ray_args.use_ray: - require_version("ray", "To fix: pip install ray") +def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments: + parser = HfArgumentParser(RayArguments) + (ray_args,) = _parse_args(parser, args, allow_extra_keys=True) return ray_args -def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: - model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args) +def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: + model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) # Setup logging if training_args.should_log: @@ -410,7 +389,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: return model_args, data_args, training_args, finetuning_args, generating_args -def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: +def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) _set_transformers_logging() @@ -443,7 +422,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: return model_args, data_args, finetuning_args, generating_args -def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: +def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) _set_transformers_logging() diff --git a/src/llamafactory/hparams/training_args.py 
b/src/llamafactory/hparams/training_args.py new file mode 100644 index 00000000..9df24ace --- /dev/null +++ b/src/llamafactory/hparams/training_args.py @@ -0,0 +1,48 @@ +import json +from dataclasses import dataclass, field +from typing import Literal, Optional, Union + +from transformers import Seq2SeqTrainingArguments +from transformers.training_args import _convert_str_dict + +from ..extras.misc import use_ray + + +@dataclass +class RayArguments: + r""" + Arguments pertaining to the Ray training. + """ + + ray_run_name: Optional[str] = field( + default=None, + metadata={"help": "The training results will be saved at `saves/ray_run_name`."}, + ) + ray_num_workers: int = field( + default=1, + metadata={"help": "The number of workers for Ray training. Default is 1 worker."}, + ) + resources_per_worker: Union[dict, str] = field( + default_factory=lambda: {"GPU": 1}, + metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, + ) + placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field( + default="PACK", + metadata={"help": "The placement strategy for Ray training. Default is PACK."}, + ) + + def __post_init__(self): + self.use_ray = use_ray() + if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"): + self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker)) + + +@dataclass +class TrainingArguments(RayArguments, Seq2SeqTrainingArguments): + r""" + Arguments pertaining to the trainer. 
+ """ + + def __post_init__(self): + Seq2SeqTrainingArguments.__post_init__(self) + RayArguments.__post_init__(self) diff --git a/src/llamafactory/integrations/__init__.py b/src/llamafactory/integrations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/llamafactory/integrations/ray/__init__.py b/src/llamafactory/integrations/ray/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/llamafactory/integrations/ray/ray_train.py b/src/llamafactory/integrations/ray/ray_train.py deleted file mode 100644 index 50a2927a..00000000 --- a/src/llamafactory/integrations/ray/ray_train.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Any, Callable, Dict - -from ray.train import ScalingConfig -from ray.train.torch import TorchTrainer - -from .ray_train_args import RayTrainArguments - - -def get_ray_trainer( - training_function: Callable, - train_loop_config: Dict[str, Any], - ray_args: RayTrainArguments, -) -> TorchTrainer: - if not ray_args.use_ray: - raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.") - - trainer = TorchTrainer( - training_function, - train_loop_config=train_loop_config, - scaling_config=ScalingConfig( - num_workers=ray_args.num_workers, - resources_per_worker=ray_args.resources_per_worker, - use_gpu=True, - ), - ) - return trainer diff --git a/src/llamafactory/integrations/ray/ray_train_args.py b/src/llamafactory/integrations/ray/ray_train_args.py deleted file mode 100644 index afacbee1..00000000 --- a/src/llamafactory/integrations/ray/ray_train_args.py +++ /dev/null @@ -1,30 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, Dict, Literal, Optional - -from .ray_utils import should_use_ray - - -@dataclass -class RayTrainArguments: - r""" - Arguments pertaining to the Ray training. - """ - - resources_per_worker: Optional[Dict[str, Any]] = field( - default_factory=lambda: {"GPU": 1}, - metadata={"help": "The resources per worker for Ray training. 
Default is to use 1 GPU per worker."}, - ) - num_workers: Optional[int] = field( - default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."} - ) - placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field( - default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."} - ) - - @property - def use_ray(self) -> bool: - """ - Always returns the value from the environment variable check. - This prevents manual setting of use_ray. - """ - return should_use_ray() diff --git a/src/llamafactory/integrations/ray/ray_utils.py b/src/llamafactory/integrations/ray/ray_utils.py deleted file mode 100644 index 8b8b045e..00000000 --- a/src/llamafactory/integrations/ray/ray_utils.py +++ /dev/null @@ -1,5 +0,0 @@ -import os - - -def should_use_ray(): - return os.getenv("USE_RAY", "0").lower() in ["true", "1"] diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 4cd2337e..6aca53cf 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -18,7 +18,8 @@ # limitations under the License. 
from collections.abc import Mapping -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import Trainer @@ -31,7 +32,7 @@ from typing_extensions import override from ..extras import logging from ..extras.constants import IGNORE_INDEX -from ..extras.packages import is_galore_available +from ..extras.packages import is_galore_available, is_ray_available from ..hparams import FinetuningArguments, ModelArguments from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params @@ -40,11 +41,16 @@ if is_galore_available(): from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore +if is_ray_available(): + from ray.train import RunConfig, ScalingConfig + from ray.train.torch import TorchTrainer + + if TYPE_CHECKING: - from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback + from transformers import PreTrainedModel, TrainerCallback from trl import AutoModelForCausalLMWithValueHead - from ..hparams import DataArguments + from ..hparams import DataArguments, RayArguments, TrainingArguments logger = logging.get_logger(__name__) @@ -75,7 +81,7 @@ def create_modelcard_and_push( trainer: "Trainer", model_args: "ModelArguments", data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> None: kwargs = { @@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: def _create_galore_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": @@ -272,7 +278,7 @@ def 
_create_galore_optimizer( def _create_loraplus_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": default_lr = training_args.learning_rate @@ -312,7 +318,7 @@ def _create_loraplus_optimizer( def _create_badam_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": decay_params, nodecay_params = [], [] @@ -373,7 +379,7 @@ def _create_badam_optimizer( def _create_adam_mini_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", ) -> "torch.optim.Optimizer": from adam_mini import Adam_mini # type: ignore @@ -398,7 +404,7 @@ def _create_adam_mini_optimizer( def create_custom_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> Optional["torch.optim.Optimizer"]: if finetuning_args.use_galore: @@ -415,7 +421,7 @@ def create_custom_optimizer( def create_custom_scheduler( - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None, ) -> None: @@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall config={"Framework": "🦙LlamaFactory"}, ) return swanlab_callback + + +def get_ray_trainer( + training_function: Callable, + train_loop_config: Dict[str, Any], + ray_args: "RayArguments", +) -> "TorchTrainer": + if not ray_args.use_ray: + raise ValueError("Ray was not enabled. 
Please set `USE_RAY=1` to enable ray.") + + trainer = TorchTrainer( + training_function, + train_loop_config=train_loop_config, + scaling_config=ScalingConfig( + num_workers=ray_args.ray_num_workers, + resources_per_worker=ray_args.resources_per_worker, + placement_strategy=ray_args.placement_strategy, + use_gpu=True, + ), + run_config=RunConfig( + name=ray_args.ray_run_name, + storage_path=Path("./saves").absolute().as_posix(), + ), + ) + return trainer diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 8c461890..24620c87 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -22,8 +22,8 @@ from transformers import PreTrainedModel from ..data import get_template_and_fix_tokenizer from ..extras import logging from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..hparams import get_infer_args, get_train_args -from ..hparams.parser import _parse_ray_args, _read_args +from ..extras.packages import is_ray_available +from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .dpo import run_dpo @@ -32,7 +32,11 @@ from .ppo import run_ppo from .pt import run_pt from .rm import run_rm from .sft import run_sft -from .trainer_utils import get_swanlab_callback +from .trainer_utils import get_ray_trainer, get_swanlab_callback + + +if is_ray_available(): + from ray.train.huggingface.transformers import RayTrainReportCallback if TYPE_CHECKING: @@ -43,10 +47,8 @@ logger = logging.get_logger(__name__) def training_function(config: Dict[str, Any]) -> None: - args = config.get("args", None) - callbacks = config.get("callbacks", []) - - callbacks.append(LogCallback()) + args = config.get("args") + callbacks: List[Any] = config.get("callbacks") model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) if 
finetuning_args.pissa_convert: @@ -73,31 +75,22 @@ def training_function(config: Dict[str, Any]) -> None: raise ValueError(f"Unknown task: {finetuning_args.stage}.") -def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: - args_dict = _read_args(args) - ray_args = _parse_ray_args(args_dict) +def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None: + callbacks = callbacks or [] + callbacks.append(LogCallback()) + args = read_args(args) + ray_args = get_ray_args(args) if ray_args.use_ray: - # Import lazily to avoid ray not installed error - from ..integrations.ray.ray_train import get_ray_trainer - - # Initialize ray trainer + callbacks.append(RayTrainReportCallback()) trainer = get_ray_trainer( training_function=training_function, - train_loop_config={ - "args": args_dict, - "callbacks": callbacks, - }, + train_loop_config={"args": args, "callbacks": callbacks}, ray_args=ray_args, ) trainer.fit() else: - training_function( - config={ - "args": args_dict, - "callbacks": callbacks, - } - ) + training_function(config={"args": args, "callbacks": callbacks}) def export_model(args: Optional[Dict[str, Any]] = None) -> None: