From 0eeae9061c6ff002a5ab4d9ad2f00f6f35599a43 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 4 Jan 2025 07:25:19 +0000 Subject: [PATCH 01/10] update wechat Former-commit-id: 11a9d96a042e8afd972e0bf2fa3e51f95e4799ec --- .../{bug-report.yml => 1-bug-report.yml} | 30 +++++++------------ .github/ISSUE_TEMPLATE/2-feature-request.yml | 23 ++++++++++++++ .github/ISSUE_TEMPLATE/config.yml | 1 + docker/docker-cuda/Dockerfile | 28 ++++++++--------- docker/docker-npu/Dockerfile | 24 +++++++-------- docker/docker-rocm/Dockerfile | 28 ++++++++--------- 6 files changed, 74 insertions(+), 60 deletions(-) rename .github/ISSUE_TEMPLATE/{bug-report.yml => 1-bug-report.yml} (57%) create mode 100644 .github/ISSUE_TEMPLATE/2-feature-request.yml create mode 100644 .github/ISSUE_TEMPLATE/config.yml diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/1-bug-report.yml similarity index 57% rename from .github/ISSUE_TEMPLATE/bug-report.yml rename to .github/ISSUE_TEMPLATE/1-bug-report.yml index 58561329..4e5ffba0 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/1-bug-report.yml @@ -1,11 +1,11 @@ -name: "\U0001F41B Bug / Help" +name: "\U0001F41B Bug / help" description: Create a report to help us improve the LLaMA Factory body: - type: markdown attributes: value: | - Issues included in **FAQs** or those with **insufficient** information may be closed without a response. - 包含在**常见问题**内或提供信息**不完整**的 issues 可能不会被回复。 + Issues included in **[FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** or those with **insufficient** information may be closed without a response. + 包含在 **[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** 内或提供信息**不完整**的 issues 可能不会被回复。 - type: checkboxes id: reminder @@ -38,26 +38,16 @@ body: attributes: label: Reproduction description: | - Please provide code snippets, error messages and stack traces that reproduces the problem. - 请提供运行参数,错误信息以及异常堆栈以便于我们复现该问题。 - Remember to use Markdown tags to correctly format your code. - 请合理使用 Markdown 标签来格式化您的文本。 + Please provide entry arguments, error messages and stack traces that reproduces the problem. + 请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。 + Remember to wrap your log messages with \`\`\`. + 请务必使用 Markdown 标签 \`\`\` 来包裹您的日志信息。 - placeholder: | - ```bash - llamafactory-cli train ... + value: | + ```text + Put your message here. ``` - - type: textarea - id: expected-behavior - validations: - required: false - attributes: - label: Expected behavior - description: | - Please provide a clear and concise description of what you would expect to happen. - 请提供您原本的目的,即这段代码的期望行为。 - - type: textarea id: others validations: diff --git a/.github/ISSUE_TEMPLATE/2-feature-request.yml b/.github/ISSUE_TEMPLATE/2-feature-request.yml new file mode 100644 index 00000000..73676a7e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/2-feature-request.yml @@ -0,0 +1,23 @@ +name: "\U0001F680 Feature request" +description: Submit a request for a new feature +labels: ["enhancement"] +body: + - type: textarea + id: description + validations: + required: true + attributes: + label: Description + description: | + A clear and concise description of the feature proposal. + 请详细描述您希望加入的新功能特性。 + + - type: textarea + id: contribution + validations: + required: false + attributes: + label: Pull Request + description: | + Have you already created the relevant PR and submitted the code? + 您是否已经创建了相关 PR 并提交了代码? 
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..3ba13e0c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: false diff --git a/docker/docker-cuda/Dockerfile b/docker/docker-cuda/Dockerfile index b1914ed5..15ac6ec2 100644 --- a/docker/docker-cuda/Dockerfile +++ b/docker/docker-cuda/Dockerfile @@ -23,10 +23,10 @@ ARG HTTP_PROXY= WORKDIR /app # Set http proxy -RUN if [ -n "$HTTP_PROXY" ]; then \ - echo "Configuring proxy..."; \ - export http_proxy=$HTTP_PROXY; \ - export https_proxy=$HTTP_PROXY; \ +RUN if [ -n "$HTTP_PROXY" ]; then \ + echo "Configuring proxy..."; \ + export http_proxy=$HTTP_PROXY; \ + export https_proxy=$HTTP_PROXY; \ fi # Install the requirements @@ -34,10 +34,10 @@ COPY requirements.txt /app RUN pip config set global.index-url "$PIP_INDEX" && \ pip config set global.extra-index-url "$PIP_INDEX" && \ python -m pip install --upgrade pip && \ - if [ -n "$HTTP_PROXY" ]; then \ - python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \ - else \ - python -m pip install -r requirements.txt; \ + if [ -n "$HTTP_PROXY" ]; then \ + python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \ + else \ + python -m pip install -r requirements.txt; \ fi # Copy the rest of the application into the image @@ -63,10 +63,10 @@ RUN EXTRA_PACKAGES="metrics"; \ if [ "$INSTALL_EETQ" == "true" ]; then \ EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \ fi; \ - if [ -n "$HTTP_PROXY" ]; then \ - pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \ - else \ - pip install -e ".[$EXTRA_PACKAGES]"; \ + if [ -n "$HTTP_PROXY" ]; then \ + pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \ + else \ + pip install -e ".[$EXTRA_PACKAGES]"; \ fi # Rebuild flash attention @@ -76,8 +76,8 @@ RUN pip uninstall -y transformer-engine flash-attn && \ if [ -n "$HTTP_PROXY" ]; then \ pip install --proxy=$HTTP_PROXY ninja && \ pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \ - else \ - pip install ninja && \ + else \ + pip install ninja && \ pip install --no-cache-dir flash-attn --no-build-isolation; \ fi; \ fi diff --git a/docker/docker-npu/Dockerfile b/docker/docker-npu/Dockerfile index 15d4eee4..aa315806 100644 --- a/docker/docker-npu/Dockerfile +++ b/docker/docker-npu/Dockerfile @@ -18,10 +18,10 @@ ARG HTTP_PROXY= WORKDIR /app # Set http proxy -RUN if [ -n "$HTTP_PROXY" ]; then \ - echo "Configuring proxy..."; \ - export http_proxy=$HTTP_PROXY; \ - export https_proxy=$HTTP_PROXY; \ +RUN if [ -n "$HTTP_PROXY" ]; then \ + echo "Configuring proxy..."; \ + export http_proxy=$HTTP_PROXY; \ + export https_proxy=$HTTP_PROXY; \ fi # Install the requirements @@ -29,10 +29,10 @@ COPY requirements.txt /app RUN pip config set global.index-url "$PIP_INDEX" && \ pip config set global.extra-index-url "$TORCH_INDEX" && \ python -m pip install --upgrade pip && \ - if [ -n "$HTTP_PROXY" ]; then \ - python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \ - else \ - python -m pip install -r requirements.txt; \ + if [ -n "$HTTP_PROXY" ]; then \ + python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \ + else \ + python -m pip install -r requirements.txt; \ fi # Copy the rest of the application into the image @@ -43,10 +43,10 @@ RUN EXTRA_PACKAGES="torch-npu,metrics"; \ if [ "$INSTALL_DEEPSPEED" == "true" ]; then \ EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \ fi; \ - if [ -n "$HTTP_PROXY" ]; then \ - pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; 
\ - else \ - pip install -e ".[$EXTRA_PACKAGES]"; \ + if [ -n "$HTTP_PROXY" ]; then \ + pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \ + else \ + pip install -e ".[$EXTRA_PACKAGES]"; \ fi # Unset http proxy diff --git a/docker/docker-rocm/Dockerfile b/docker/docker-rocm/Dockerfile index 86e96a37..61eb68e5 100644 --- a/docker/docker-rocm/Dockerfile +++ b/docker/docker-rocm/Dockerfile @@ -19,10 +19,10 @@ ARG HTTP_PROXY= WORKDIR /app # Set http proxy -RUN if [ -n "$HTTP_PROXY" ]; then \ - echo "Configuring proxy..."; \ - export http_proxy=$HTTP_PROXY; \ - export https_proxy=$HTTP_PROXY; \ +RUN if [ -n "$HTTP_PROXY" ]; then \ + echo "Configuring proxy..."; \ + export http_proxy=$HTTP_PROXY; \ + export https_proxy=$HTTP_PROXY; \ fi # Install the requirements @@ -30,10 +30,10 @@ COPY requirements.txt /app RUN pip config set global.index-url "$PIP_INDEX" && \ pip config set global.extra-index-url "$PIP_INDEX" && \ python -m pip install --upgrade pip && \ - if [ -n "$HTTP_PROXY" ]; then \ - python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \ - else \ - python -m pip install -r requirements.txt; \ + if [ -n "$HTTP_PROXY" ]; then \ + python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \ + else \ + python -m pip install -r requirements.txt; \ fi # Copy the rest of the application into the image @@ -56,10 +56,10 @@ RUN EXTRA_PACKAGES="metrics"; \ if [ "$INSTALL_HQQ" == "true" ]; then \ EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \ fi; \ - if [ -n "$HTTP_PROXY" ]; then \ - pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \ - else \ - pip install -e ".[$EXTRA_PACKAGES]"; \ + if [ -n "$HTTP_PROXY" ]; then \ + pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \ + else \ + pip install -e ".[$EXTRA_PACKAGES]"; \ fi # Rebuild flash attention @@ -69,8 +69,8 @@ RUN pip uninstall -y transformer-engine flash-attn && \ if [ -n "$HTTP_PROXY" ]; then \ pip install --proxy=$HTTP_PROXY ninja && \ pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \ - else \ - pip install ninja && \ + else \ + pip install ninja && \ pip install --no-cache-dir flash-attn --no-build-isolation; \ fi; \ fi From d8bd46f1bf1a1942e46a9b5e4df5caed5e09e843 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 7 Jan 2025 06:30:44 +0000 Subject: [PATCH 02/10] fix #6546 Former-commit-id: 6fcf2f10faf3b1614896b091591eeef96d717e64 --- src/llamafactory/train/dpo/trainer.py | 4 +-- src/llamafactory/train/kto/trainer.py | 4 +-- src/llamafactory/train/trainer_utils.py | 34 ++++++++++++++++++++----- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index ad670385..770b32e5 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -31,7 +31,7 @@ from typing_extensions import override from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach if TYPE_CHECKING: @@ -193,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer): Otherwise the average log probabilities. 
""" if self.finetuning_args.use_ref_model: - batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error + batch = nested_detach(batch, clone=True) # avoid error all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"]) diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 3c6d1089..419de579 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -30,7 +30,7 @@ from typing_extensions import override from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach if TYPE_CHECKING: @@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer): r""" Runs forward pass and computes the log probabilities. """ - batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error + batch = nested_detach(batch, clone=True) # avoid error model_inputs = { "input_ids": batch[f"{prefix}input_ids"], "attention_mask": batch[f"{prefix}attention_mask"], diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index eb2421ce..4cd2337e 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -17,6 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Mapping from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import torch @@ -36,7 +37,7 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va if is_galore_available(): - from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit + from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore if TYPE_CHECKING: @@ -330,7 +331,7 @@ def _create_badam_optimizer( ] if finetuning_args.badam_mode == "layer": - from badam import BlockOptimizer + from badam import BlockOptimizer # type: ignore base_optimizer = optim_class(param_groups, **optim_kwargs) optimizer = BlockOptimizer( @@ -350,7 +351,7 @@ def _create_badam_optimizer( ) elif finetuning_args.badam_mode == "ratio": - from badam import BlockOptimizerRatio + from badam import BlockOptimizerRatio # type: ignore assert finetuning_args.badam_update_ratio > 1e-6 optimizer = BlockOptimizerRatio( @@ -374,7 +375,7 @@ def _create_adam_mini_optimizer( model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", ) -> "torch.optim.Optimizer": - from adam_mini import Adam_mini + from adam_mini import Adam_mini # type: ignore hidden_size = getattr(model.config, "hidden_size", None) num_q_head = getattr(model.config, "num_attention_heads", None) @@ -459,12 +460,33 @@ def get_batch_logps( return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1) +def nested_detach( + tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]], + clone: bool = False, +): + r""" + Detach `tensors` (even if it's a nested list/tuple/dict of tensors). 
+ """ + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_detach(t, clone=clone) for t in tensors) + elif isinstance(tensors, Mapping): + return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()}) + + if isinstance(tensors, torch.Tensor): + if clone: + return tensors.detach().clone() + else: + return tensors.detach() + else: + return tensors + + def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback": r""" Gets the callback for logging to SwanLab. """ - import swanlab - from swanlab.integration.transformers import SwanLabCallback + import swanlab # type: ignore + from swanlab.integration.transformers import SwanLabCallback # type: ignore if finetuning_args.swanlab_api_key is not None: swanlab.login(api_key=finetuning_args.swanlab_api_key) From 8683582300f24a296a3006932587ba4e864f9353 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 30 Dec 2024 16:48:52 -0800 Subject: [PATCH 03/10] drafting ray integration Signed-off-by: Kourosh Hakhamaneshi Former-commit-id: 19c12ddae9350f6e25a270fe3372f5b9094cf960 --- examples/train_lora/llama3_lora_sft.yaml | 8 +++ src/llamafactory/cli.py | 5 +- src/llamafactory/hparams/parser.py | 56 +++++++++++++------ src/llamafactory/integrations/__init__.py | 0 src/llamafactory/integrations/ray/__init__.py | 0 .../integrations/ray/ray_train.py | 28 ++++++++++ .../integrations/ray/ray_train_args.py | 22 ++++++++ .../integrations/ray/ray_utils.py | 9 +++ src/llamafactory/train/tuner.py | 36 +++++++++++- 9 files changed, 143 insertions(+), 21 deletions(-) create mode 100644 src/llamafactory/integrations/__init__.py create mode 100644 src/llamafactory/integrations/ray/__init__.py create mode 100644 src/llamafactory/integrations/ray/ray_train.py create mode 100644 src/llamafactory/integrations/ray/ray_train_args.py create mode 100644 src/llamafactory/integrations/ray/ray_utils.py diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml index 243f2445..558ac3e9 100644 --- a/examples/train_lora/llama3_lora_sft.yaml +++ b/examples/train_lora/llama3_lora_sft.yaml @@ -9,6 +9,7 @@ finetuning_type: lora lora_target: all ### dataset +dataset_dir: /home/ray/default/lf/data/ dataset: identity,alpaca_en_demo template: llama3 cutoff_len: 2048 @@ -38,3 +39,10 @@ val_size: 0.1 per_device_eval_batch_size: 1 eval_strategy: steps eval_steps: 500 + + +### ray setup +resources_per_worker: + GPU: 1 +num_workers: 4 +# placement_strategy: ... 
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index ef5bd1dc..26d7a3df 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -27,7 +27,7 @@ from .extras.env import VERSION, print_env from .extras.misc import get_device_count from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui - +from .integrations.ray.ray_utils import should_use_ray USAGE = ( "-" * 70 @@ -87,7 +87,8 @@ def main(): export_model() elif command == Command.TRAIN: force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"] - if force_torchrun or get_device_count() > 1: + use_ray = should_use_ray() + if force_torchrun or (get_device_count() > 1 and not use_ray): master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999))) logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}") diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 4a254367..06853ae8 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -19,6 +19,10 @@ import os import sys from typing import Any, Dict, Optional, Tuple +import json +import yaml +from pathlib import Path + import torch import transformers from transformers import HfArgumentParser, Seq2SeqTrainingArguments @@ -37,39 +41,51 @@ from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments +from ..integrations.ray.ray_train_args import RayTrainArguments logger = logging.get_logger(__name__) - check_dependencies() -_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] -_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] -def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: +def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: if args is not None: - return parser.parse_dict(args) + return args if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")): - return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + # read yaml file + return yaml.safe_load(Path(sys.argv[1]).absolute().read_text()) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # read json file + return json.loads(Path(sys.argv[1]).absolute().read_text()) + else: + return {} - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - return parser.parse_json_file(os.path.abspath(sys.argv[1])) - (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) +def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]: - if unknown_args: - 
print(parser.format_help()) - print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") - raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + args_dict = _read_args(args) + + if args_dict: + return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys) + else: + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True) - return (*parsed_args,) + if unknown_args: + print(parser.format_help()) + print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") + raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + + return (*parsed_args,) + def _set_transformers_logging() -> None: @@ -161,9 +177,17 @@ def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: return _parse_args(parser, args) -def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: - model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) +def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments: + parser = HfArgumentParser(RayTrainArguments) + ray_args = _parse_args(parser, args, allow_extra_keys=True)[0] + if ray_args.use_ray: + require_version("ray", "To fix: pip install ray") + return ray_args + +def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args) + # Setup logging if training_args.should_log: _set_transformers_logging() diff --git a/src/llamafactory/integrations/__init__.py b/src/llamafactory/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/integrations/ray/__init__.py b/src/llamafactory/integrations/ray/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/integrations/ray/ray_train.py b/src/llamafactory/integrations/ray/ray_train.py new file mode 100644 index 00000000..4620bb5a --- /dev/null +++ b/src/llamafactory/integrations/ray/ray_train.py @@ -0,0 +1,28 @@ + +from typing import Any, Callable, Dict + +from ray.train.torch import TorchTrainer +from ray.train import ScalingConfig + +from .ray_train_args import RayTrainArguments + + +def get_ray_trainer( + training_function: Callable, + train_loop_config: Dict[str, Any], + ray_args: RayTrainArguments, +) -> TorchTrainer: + + if not ray_args.use_ray: + raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.") + + trainer = TorchTrainer( + training_function, + train_loop_config=train_loop_config, + scaling_config=ScalingConfig( + num_workers=ray_args.num_workers, + resources_per_worker=ray_args.resources_per_worker, + use_gpu=True, + ), + ) + return trainer \ No newline at end of file diff --git a/src/llamafactory/integrations/ray/ray_train_args.py b/src/llamafactory/integrations/ray/ray_train_args.py new file mode 100644 index 00000000..9ee9dc8e --- /dev/null +++ b/src/llamafactory/integrations/ray/ray_train_args.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, Literal, Optional + +from .ray_utils import should_use_ray + +@dataclass +class RayTrainArguments: + r""" + Arguments pertaining to the Ray training. + """ + resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. 
Default is to use 1 GPU per worker."}) + num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}) + placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}) + + @property + def use_ray(self) -> bool: + """ + Always returns the value from the environment variable check. + This prevents manual setting of use_ray. + """ + return should_use_ray() + diff --git a/src/llamafactory/integrations/ray/ray_utils.py b/src/llamafactory/integrations/ray/ray_utils.py new file mode 100644 index 00000000..67ce2ed9 --- /dev/null +++ b/src/llamafactory/integrations/ray/ray_utils.py @@ -0,0 +1,9 @@ + +import os + + +def should_use_ray(): + return os.getenv("USE_RAY", "0").lower() in ["true", "1"] + + + diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 6c79320e..507a1f14 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -23,6 +23,7 @@ from ..data import get_template_and_fix_tokenizer from ..extras import logging from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..hparams import get_infer_args, get_train_args +from ..hparams.parser import _parse_ray_args, _read_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .dpo import run_dpo @@ -36,12 +37,14 @@ from .trainer_utils import get_swanlab_callback if TYPE_CHECKING: from transformers import TrainerCallback - + logger = logging.get_logger(__name__) - -def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: +def training_function(config: Dict[str, Any]) -> None: + args = config.get("args", None) + callbacks = config.get("callbacks", []) + callbacks.append(LogCallback()) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) @@ -68,6 +71,33 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb else: raise ValueError(f"Unknown task: {finetuning_args.stage}.") +def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: + + args_dict = _read_args(args) + ray_args = _parse_ray_args(args_dict) + + if ray_args.use_ray: + # Import lazily to avoid ray not installed error + from ..integrations.ray.ray_train import get_ray_trainer + + # Initialize ray trainer + trainer = get_ray_trainer( + training_function=training_function, + train_loop_config={ + "args": args_dict, + "callbacks": callbacks, + }, + ray_args=ray_args, + ) + trainer.fit() + else: + training_function( + config={ + "args": args_dict, + "callbacks": callbacks, + } + ) + def export_model(args: Optional[Dict[str, Any]] = None) -> None: model_args, data_args, finetuning_args, _ = get_infer_args(args) From 4f31ad997c2fcba1c20a2298659a8604f59cf336 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Mon, 6 Jan 2025 23:55:56 +0000 Subject: [PATCH 04/10] run style check Former-commit-id: 5ec33baf5f95df9fa2afe5523c825d3eda8a076b --- examples/train_lora/llama3_lora_sft.yaml | 2 +- src/llamafactory/cli.py | 3 +- src/llamafactory/hparams/parser.py | 41 +++++++++++++------ .../integrations/ray/ray_train.py | 10 ++--- .../integrations/ray/ray_train_args.py | 16 ++++++-- .../integrations/ray/ray_utils.py | 4 -- src/llamafactory/train/tuner.py | 13 +++--- 7 files changed, 54 insertions(+), 35 
deletions(-) diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml index 558ac3e9..dc4f5add 100644 --- a/examples/train_lora/llama3_lora_sft.yaml +++ b/examples/train_lora/llama3_lora_sft.yaml @@ -9,7 +9,7 @@ finetuning_type: lora lora_target: all ### dataset -dataset_dir: /home/ray/default/lf/data/ +dataset_dir: /home/ray/default/LLaMA-Factory/data/ dataset: identity,alpaca_en_demo template: llama3 cutoff_len: 2048 diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 26d7a3df..a2a8e29d 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -25,9 +25,10 @@ from .eval.evaluator import run_eval from .extras import logging from .extras.env import VERSION, print_env from .extras.misc import get_device_count +from .integrations.ray.ray_utils import should_use_ray from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui -from .integrations.ray.ray_utils import should_use_ray + USAGE = ( "-" * 70 diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 06853ae8..8cdfa7cb 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -15,16 +15,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import sys -from typing import Any, Dict, Optional, Tuple - -import json -import yaml from pathlib import Path +from typing import Any, Dict, Optional, Tuple import torch import transformers +import yaml from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint @@ -35,21 +34,35 @@ from transformers.utils.versions import require_version from ..extras import logging from ..extras.constants import CHECKPOINT_NAMES from ..extras.misc import check_dependencies, get_current_device +from ..integrations.ray.ray_train_args import RayTrainArguments from .data_args import DataArguments from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments -from ..integrations.ray.ray_train_args import RayTrainArguments logger = logging.get_logger(__name__) check_dependencies() -_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] -_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] +_TRAIN_ARGS = [ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + RayTrainArguments, +] +_TRAIN_CLS = Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + RayTrainArguments, +] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] @@ -70,14 +83,17 @@ def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: return {} -def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]: - +def _parse_args( + parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = 
None, allow_extra_keys: bool = False +) -> Tuple[Any]: args_dict = _read_args(args) - + if args_dict: return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys) else: - (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True) + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses( + args=args_dict, return_remaining_strings=True + ) if unknown_args: print(parser.format_help()) @@ -85,7 +101,6 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") return (*parsed_args,) - def _set_transformers_logging() -> None: @@ -187,7 +202,7 @@ def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments: def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args) - + # Setup logging if training_args.should_log: _set_transformers_logging() diff --git a/src/llamafactory/integrations/ray/ray_train.py b/src/llamafactory/integrations/ray/ray_train.py index 4620bb5a..50a2927a 100644 --- a/src/llamafactory/integrations/ray/ray_train.py +++ b/src/llamafactory/integrations/ray/ray_train.py @@ -1,21 +1,19 @@ - from typing import Any, Callable, Dict -from ray.train.torch import TorchTrainer from ray.train import ScalingConfig +from ray.train.torch import TorchTrainer from .ray_train_args import RayTrainArguments - + def get_ray_trainer( training_function: Callable, train_loop_config: Dict[str, Any], ray_args: RayTrainArguments, ) -> TorchTrainer: - if not ray_args.use_ray: raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.") - + trainer = TorchTrainer( training_function, train_loop_config=train_loop_config, @@ -25,4 +23,4 @@ def get_ray_trainer( use_gpu=True, ), ) - return trainer \ No newline at end of file + return trainer diff --git a/src/llamafactory/integrations/ray/ray_train_args.py b/src/llamafactory/integrations/ray/ray_train_args.py index 9ee9dc8e..afacbee1 100644 --- a/src/llamafactory/integrations/ray/ray_train_args.py +++ b/src/llamafactory/integrations/ray/ray_train_args.py @@ -3,14 +3,23 @@ from typing import Any, Dict, Literal, Optional from .ray_utils import should_use_ray + @dataclass class RayTrainArguments: r""" Arguments pertaining to the Ray training. """ - resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}) - num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}) - placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}) + + resources_per_worker: Optional[Dict[str, Any]] = field( + default_factory=lambda: {"GPU": 1}, + metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, + ) + num_workers: Optional[int] = field( + default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."} + ) + placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field( + default="PACK", metadata={"help": "The placement strategy for Ray training. 
Default is PACK."} + ) @property def use_ray(self) -> bool: @@ -19,4 +28,3 @@ class RayTrainArguments: This prevents manual setting of use_ray. """ return should_use_ray() - diff --git a/src/llamafactory/integrations/ray/ray_utils.py b/src/llamafactory/integrations/ray/ray_utils.py index 67ce2ed9..8b8b045e 100644 --- a/src/llamafactory/integrations/ray/ray_utils.py +++ b/src/llamafactory/integrations/ray/ray_utils.py @@ -1,9 +1,5 @@ - import os def should_use_ray(): return os.getenv("USE_RAY", "0").lower() in ["true", "1"] - - - diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 507a1f14..8c461890 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -37,14 +37,15 @@ from .trainer_utils import get_swanlab_callback if TYPE_CHECKING: from transformers import TrainerCallback - + logger = logging.get_logger(__name__) + def training_function(config: Dict[str, Any]) -> None: args = config.get("args", None) callbacks = config.get("callbacks", []) - + callbacks.append(LogCallback()) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) @@ -71,15 +72,15 @@ def training_function(config: Dict[str, Any]) -> None: else: raise ValueError(f"Unknown task: {finetuning_args.stage}.") + def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: - args_dict = _read_args(args) ray_args = _parse_ray_args(args_dict) - - if ray_args.use_ray: + + if ray_args.use_ray: # Import lazily to avoid ray not installed error from ..integrations.ray.ray_train import get_ray_trainer - + # Initialize ray trainer trainer = get_ray_trainer( training_function=training_function, From 944a2aec4d9205dd7cd177f92006f4345c5f609c Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 7 Jan 2025 08:54:41 +0000 Subject: [PATCH 05/10] refactor ray integration, support save ckpt Former-commit-id: 2f50b27e608b2092bfceab6c6e84e6631e973ee2 --- .env.local | 1 + examples/README.md | 6 ++ examples/README_zh.md | 6 ++ examples/train_lora/llama3_lora_sft.yaml | 8 -- examples/train_lora/llama3_lora_sft_ray.yaml | 48 ++++++++++++ src/llamafactory/cli.py | 6 +- src/llamafactory/extras/misc.py | 10 ++- src/llamafactory/extras/packages.py | 4 + src/llamafactory/hparams/__init__.py | 7 +- src/llamafactory/hparams/parser.py | 77 +++++++------------ src/llamafactory/hparams/training_args.py | 48 ++++++++++++ src/llamafactory/integrations/__init__.py | 0 src/llamafactory/integrations/ray/__init__.py | 0 .../integrations/ray/ray_train.py | 26 ------- .../integrations/ray/ray_train_args.py | 30 -------- .../integrations/ray/ray_utils.py | 5 -- src/llamafactory/train/trainer_utils.py | 53 ++++++++++--- src/llamafactory/train/tuner.py | 41 ++++------ 18 files changed, 215 insertions(+), 161 deletions(-) create mode 100644 examples/train_lora/llama3_lora_sft_ray.yaml create mode 100644 src/llamafactory/hparams/training_args.py delete mode 100644 src/llamafactory/integrations/__init__.py delete mode 100644 src/llamafactory/integrations/ray/__init__.py delete mode 100644 src/llamafactory/integrations/ray/ray_train.py delete mode 100644 src/llamafactory/integrations/ray/ray_train_args.py delete mode 100644 src/llamafactory/integrations/ray/ray_utils.py diff --git a/.env.local b/.env.local index 203aebaf..1d8f2a00 100644 --- a/.env.local +++ b/.env.local @@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS= LLAMAFACTORY_VERBOSITY= USE_MODELSCOPE_HUB= USE_OPENMIND_HUB= +USE_RAY= RECORD_VRAM= # torchrun FORCE_TORCHRUN= diff --git 
a/examples/README.md b/examples/README.md index cc2afc46..89f7d174 100644 --- a/examples/README.md +++ b/examples/README.md @@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` +#### Supervised Fine-Tuning with Ray on 4 GPUs + +```bash +USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml +``` + ### QLoRA Fine-Tuning #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended) diff --git a/examples/README_zh.md b/examples/README_zh.md index b41d7ab8..2c108e56 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` +#### 使用 Ray 在 4 张 GPU 上微调 + +```bash +USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml +``` + ### QLoRA 微调 #### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐) diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml index dc4f5add..243f2445 100644 --- a/examples/train_lora/llama3_lora_sft.yaml +++ b/examples/train_lora/llama3_lora_sft.yaml @@ -9,7 +9,6 @@ finetuning_type: lora lora_target: all ### dataset -dataset_dir: /home/ray/default/LLaMA-Factory/data/ dataset: identity,alpaca_en_demo template: llama3 cutoff_len: 2048 @@ -39,10 +38,3 @@ val_size: 0.1 per_device_eval_batch_size: 1 eval_strategy: steps eval_steps: 500 - - -### ray setup -resources_per_worker: - GPU: 1 -num_workers: 4 -# placement_strategy: ... diff --git a/examples/train_lora/llama3_lora_sft_ray.yaml b/examples/train_lora/llama3_lora_sft_ray.yaml new file mode 100644 index 00000000..4aac08bc --- /dev/null +++ b/examples/train_lora/llama3_lora_sft_ray.yaml @@ -0,0 +1,48 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path +trust_remote_code: true + +### method +stage: sft +do_train: true +finetuning_type: lora +lora_target: all + +### dataset +dataset: identity,alpaca_en_demo +dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path +template: llama3 +cutoff_len: 2048 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: tmp_dir +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +learning_rate: 1.0e-4 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500 + +### ray +ray_run_name: llama3_8b_sft_lora +ray_num_workers: 4 # number of GPUs to use +resources_per_worker: + GPU: 1 +placement_strategy: PACK diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index a2a8e29d..72085e2d 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -24,8 +24,7 @@ from .chat.chat_model import run_chat from .eval.evaluator import run_eval from .extras import logging from .extras.env import VERSION, print_env -from .extras.misc import get_device_count -from .integrations.ray.ray_utils import should_use_ray +from .extras.misc import get_device_count, use_ray from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui @@ -88,8 +87,7 @@ def main(): export_model() elif 
command == Command.TRAIN: force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"] - use_ray = should_use_ray() - if force_torchrun or (get_device_count() > 1 and not use_ray): + if force_torchrun or (get_device_count() > 1 and not use_ray()): master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999))) logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}") diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 735c5d63..11797f9f 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -229,7 +229,7 @@ def skip_check_imports() -> None: r""" Avoids flash attention import error in custom model files. """ - if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: + if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: transformers.dynamic_module_utils.check_imports = get_relative_imports @@ -275,8 +275,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str: def use_modelscope() -> bool: - return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] + return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] def use_openmind() -> bool: - return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"] + return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"] + + +def use_ray() -> bool: + return os.getenv("USE_RAY", "0").lower() in ["true", "1"] diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index 44b9bb8a..6b2bc3f3 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -62,6 +62,10 @@ def is_pillow_available(): return _is_package_available("PIL") +def is_ray_available(): + return _is_package_available("ray") + + def is_requests_available(): return _is_package_available("requests") diff --git a/src/llamafactory/hparams/__init__.py b/src/llamafactory/hparams/__init__.py index cfe448c1..254a845e 100644 --- a/src/llamafactory/hparams/__init__.py +++ b/src/llamafactory/hparams/__init__.py @@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments -from .parser import get_eval_args, get_infer_args, get_train_args +from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args +from .training_args import RayArguments, TrainingArguments __all__ = [ @@ -26,7 +27,11 @@ __all__ = [ "FinetuningArguments", "GeneratingArguments", "ModelArguments", + "RayArguments", + "TrainingArguments", "get_eval_args", "get_infer_args", + "get_ray_args", "get_train_args", + "read_args", ] diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 8cdfa7cb..62edbf78 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -19,12 +19,12 @@ import json import os import sys from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch import transformers import yaml -from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers import HfArgumentParser from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint from transformers.training_args import ParallelMode @@ -34,12 
+34,12 @@ from transformers.utils.versions import require_version from ..extras import logging from ..extras.constants import CHECKPOINT_NAMES from ..extras.misc import check_dependencies, get_current_device -from ..integrations.ray.ray_train_args import RayTrainArguments from .data_args import DataArguments from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments +from .training_args import RayArguments, TrainingArguments logger = logging.get_logger(__name__) @@ -47,60 +47,41 @@ logger = logging.get_logger(__name__) check_dependencies() -_TRAIN_ARGS = [ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments, - RayTrainArguments, -] -_TRAIN_CLS = Tuple[ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments, - RayTrainArguments, -] +_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] -def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: +def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]: if args is not None: return args if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")): - # read yaml file return yaml.safe_load(Path(sys.argv[1]).absolute().read_text()) elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # read json file return json.loads(Path(sys.argv[1]).absolute().read_text()) else: - return {} + return sys.argv[1:] def _parse_args( - parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False + parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False ) -> Tuple[Any]: - args_dict = _read_args(args) + args = read_args(args) + if isinstance(args, dict): + return parser.parse_dict(args, allow_extra_keys=allow_extra_keys) - if args_dict: - return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys) - else: - (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses( - args=args_dict, return_remaining_strings=True - ) + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True) - if unknown_args: - print(parser.format_help()) - print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") - raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + if unknown_args: + print(parser.format_help()) + print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") + raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") - return (*parsed_args,) + return (*parsed_args,) def _set_transformers_logging() -> None: @@ -141,7 +122,7 @@ def _verify_model_args( def _check_extra_dependencies( model_args: "ModelArguments", finetuning_args: "FinetuningArguments", - 
training_args: Optional["Seq2SeqTrainingArguments"] = None, + training_args: Optional["TrainingArguments"] = None, ) -> None: if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.") @@ -177,31 +158,29 @@ def _check_extra_dependencies( require_version("rouge_chinese", "To fix: pip install rouge-chinese") -def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: +def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: parser = HfArgumentParser(_TRAIN_ARGS) return _parse_args(parser, args) -def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: +def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: parser = HfArgumentParser(_INFER_ARGS) return _parse_args(parser, args) -def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: +def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: parser = HfArgumentParser(_EVAL_ARGS) return _parse_args(parser, args) -def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments: - parser = HfArgumentParser(RayTrainArguments) - ray_args = _parse_args(parser, args, allow_extra_keys=True)[0] - if ray_args.use_ray: - require_version("ray", "To fix: pip install ray") +def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments: + parser = HfArgumentParser(RayArguments) + (ray_args,) = _parse_args(parser, args, allow_extra_keys=True) return ray_args -def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: - model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args) +def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: + model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) # Setup logging if training_args.should_log: @@ -410,7 +389,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: return model_args, data_args, training_args, finetuning_args, generating_args -def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: +def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) _set_transformers_logging() @@ -443,7 +422,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: return model_args, data_args, finetuning_args, generating_args -def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: +def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) _set_transformers_logging() diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py new file mode 100644 index 00000000..9df24ace --- /dev/null +++ b/src/llamafactory/hparams/training_args.py @@ -0,0 +1,48 @@ +import json +from dataclasses import dataclass, field +from typing import Literal, Optional, Union + +from transformers import Seq2SeqTrainingArguments +from transformers.training_args import _convert_str_dict + +from ..extras.misc import use_ray + + +@dataclass +class RayArguments: + r""" + Arguments pertaining to the Ray training. 
+ """ + + ray_run_name: Optional[str] = field( + default=None, + metadata={"help": "The training results will be saved at `saves/ray_run_name`."}, + ) + ray_num_workers: int = field( + default=1, + metadata={"help": "The number of workers for Ray training. Default is 1 worker."}, + ) + resources_per_worker: Union[dict, str] = field( + default_factory=lambda: {"GPU": 1}, + metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, + ) + placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field( + default="PACK", + metadata={"help": "The placement strategy for Ray training. Default is PACK."}, + ) + + def __post_init__(self): + self.use_ray = use_ray() + if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"): + self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker)) + + +@dataclass +class TrainingArguments(RayArguments, Seq2SeqTrainingArguments): + r""" + Arguments pertaining to the trainer. + """ + + def __post_init__(self): + Seq2SeqTrainingArguments.__post_init__(self) + RayArguments.__post_init__(self) diff --git a/src/llamafactory/integrations/__init__.py b/src/llamafactory/integrations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/llamafactory/integrations/ray/__init__.py b/src/llamafactory/integrations/ray/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/llamafactory/integrations/ray/ray_train.py b/src/llamafactory/integrations/ray/ray_train.py deleted file mode 100644 index 50a2927a..00000000 --- a/src/llamafactory/integrations/ray/ray_train.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Any, Callable, Dict - -from ray.train import ScalingConfig -from ray.train.torch import TorchTrainer - -from .ray_train_args import RayTrainArguments - - -def get_ray_trainer( - training_function: Callable, - train_loop_config: Dict[str, Any], - ray_args: RayTrainArguments, -) -> TorchTrainer: - if not ray_args.use_ray: - raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.") - - trainer = TorchTrainer( - training_function, - train_loop_config=train_loop_config, - scaling_config=ScalingConfig( - num_workers=ray_args.num_workers, - resources_per_worker=ray_args.resources_per_worker, - use_gpu=True, - ), - ) - return trainer diff --git a/src/llamafactory/integrations/ray/ray_train_args.py b/src/llamafactory/integrations/ray/ray_train_args.py deleted file mode 100644 index afacbee1..00000000 --- a/src/llamafactory/integrations/ray/ray_train_args.py +++ /dev/null @@ -1,30 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, Dict, Literal, Optional - -from .ray_utils import should_use_ray - - -@dataclass -class RayTrainArguments: - r""" - Arguments pertaining to the Ray training. - """ - - resources_per_worker: Optional[Dict[str, Any]] = field( - default_factory=lambda: {"GPU": 1}, - metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, - ) - num_workers: Optional[int] = field( - default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."} - ) - placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field( - default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."} - ) - - @property - def use_ray(self) -> bool: - """ - Always returns the value from the environment variable check. 
- This prevents manual setting of use_ray. - """ - return should_use_ray() diff --git a/src/llamafactory/integrations/ray/ray_utils.py b/src/llamafactory/integrations/ray/ray_utils.py deleted file mode 100644 index 8b8b045e..00000000 --- a/src/llamafactory/integrations/ray/ray_utils.py +++ /dev/null @@ -1,5 +0,0 @@ -import os - - -def should_use_ray(): - return os.getenv("USE_RAY", "0").lower() in ["true", "1"] diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 4cd2337e..6aca53cf 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -18,7 +18,8 @@ # limitations under the License. from collections.abc import Mapping -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import Trainer @@ -31,7 +32,7 @@ from typing_extensions import override from ..extras import logging from ..extras.constants import IGNORE_INDEX -from ..extras.packages import is_galore_available +from ..extras.packages import is_galore_available, is_ray_available from ..hparams import FinetuningArguments, ModelArguments from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params @@ -40,11 +41,16 @@ if is_galore_available(): from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore +if is_ray_available(): + from ray.train import RunConfig, ScalingConfig + from ray.train.torch import TorchTrainer + + if TYPE_CHECKING: - from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback + from transformers import PreTrainedModel, TrainerCallback from trl import AutoModelForCausalLMWithValueHead - from ..hparams import DataArguments + from ..hparams import DataArguments, RayArguments, TrainingArguments logger = logging.get_logger(__name__) @@ -75,7 +81,7 @@ def create_modelcard_and_push( trainer: "Trainer", model_args: "ModelArguments", data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> None: kwargs = { @@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: def _create_galore_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": @@ -272,7 +278,7 @@ def _create_galore_optimizer( def _create_loraplus_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": default_lr = training_args.learning_rate @@ -312,7 +318,7 @@ def _create_loraplus_optimizer( def _create_badam_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": decay_params, nodecay_params = [], [] @@ -373,7 +379,7 @@ def _create_badam_optimizer( def _create_adam_mini_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", ) -> "torch.optim.Optimizer": from adam_mini import Adam_mini # type: ignore @@ -398,7 +404,7 @@ def 
_create_adam_mini_optimizer( def create_custom_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> Optional["torch.optim.Optimizer"]: if finetuning_args.use_galore: @@ -415,7 +421,7 @@ def create_custom_optimizer( def create_custom_scheduler( - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None, ) -> None: @@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall config={"Framework": "🦙LlamaFactory"}, ) return swanlab_callback + + +def get_ray_trainer( + training_function: Callable, + train_loop_config: Dict[str, Any], + ray_args: "RayArguments", +) -> "TorchTrainer": + if not ray_args.use_ray: + raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.") + + trainer = TorchTrainer( + training_function, + train_loop_config=train_loop_config, + scaling_config=ScalingConfig( + num_workers=ray_args.ray_num_workers, + resources_per_worker=ray_args.resources_per_worker, + placement_strategy=ray_args.placement_strategy, + use_gpu=True, + ), + run_config=RunConfig( + name=ray_args.ray_run_name, + storage_path=Path("./saves").absolute().as_posix(), + ), + ) + return trainer diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 8c461890..24620c87 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -22,8 +22,8 @@ from transformers import PreTrainedModel from ..data import get_template_and_fix_tokenizer from ..extras import logging from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..hparams import get_infer_args, get_train_args -from ..hparams.parser import _parse_ray_args, _read_args +from ..extras.packages import is_ray_available +from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .dpo import run_dpo @@ -32,7 +32,11 @@ from .ppo import run_ppo from .pt import run_pt from .rm import run_rm from .sft import run_sft -from .trainer_utils import get_swanlab_callback +from .trainer_utils import get_ray_trainer, get_swanlab_callback + + +if is_ray_available(): + from ray.train.huggingface.transformers import RayTrainReportCallback if TYPE_CHECKING: @@ -43,10 +47,8 @@ logger = logging.get_logger(__name__) def training_function(config: Dict[str, Any]) -> None: - args = config.get("args", None) - callbacks = config.get("callbacks", []) - - callbacks.append(LogCallback()) + args = config.get("args") + callbacks: List[Any] = config.get("callbacks") model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) if finetuning_args.pissa_convert: @@ -73,31 +75,22 @@ def training_function(config: Dict[str, Any]) -> None: raise ValueError(f"Unknown task: {finetuning_args.stage}.") -def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: - args_dict = _read_args(args) - ray_args = _parse_ray_args(args_dict) +def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None: + callbacks = callbacks or [] + callbacks.append(LogCallback()) + args = read_args(args) + ray_args = get_ray_args(args) if ray_args.use_ray: - # Import lazily to avoid ray not installed error - from 
..integrations.ray.ray_train import get_ray_trainer - - # Initialize ray trainer + callbacks.append(RayTrainReportCallback()) trainer = get_ray_trainer( training_function=training_function, - train_loop_config={ - "args": args_dict, - "callbacks": callbacks, - }, + train_loop_config={"args": args, "callbacks": callbacks}, ray_args=ray_args, ) trainer.fit() else: - training_function( - config={ - "args": args_dict, - "callbacks": callbacks, - } - ) + training_function(config={"args": args, "callbacks": callbacks}) def export_model(args: Optional[Dict[str, Any]] = None) -> None: From 0ef1f981da84e164f445fb0355244c0683699ff8 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 7 Jan 2025 09:59:24 +0000 Subject: [PATCH 06/10] fix llamaboard with ray Former-commit-id: bd8a432d6a980b1b24a551626304fe3d394b1baf --- src/llamafactory/train/callbacks.py | 6 +++--- src/llamafactory/train/tuner.py | 11 +++++------ src/llamafactory/webui/runner.py | 6 +++--- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 189c7533..4da4ec18 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -35,7 +35,7 @@ from typing_extensions import override from ..extras import logging from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..extras.misc import get_peak_memory +from ..extras.misc import get_peak_memory, use_ray if is_safetensors_available(): @@ -194,7 +194,7 @@ class LogCallback(TrainerCallback): self.do_train = False # Web UI self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] - if self.webui_mode: + if self.webui_mode and not use_ray(): signal.signal(signal.SIGABRT, self._set_abort) self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) logging.add_handler(self.logger_handler) @@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback): ) if self.finetuning_args.use_swanlab: - import swanlab + import swanlab # type: ignore swanlab.config.update( { diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 24620c87..bbbef1cf 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -46,11 +46,12 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def training_function(config: Dict[str, Any]) -> None: +def _training_function(config: Dict[str, Any]) -> None: args = config.get("args") callbacks: List[Any] = config.get("callbacks") model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) + callbacks.append(LogCallback()) if finetuning_args.pissa_convert: callbacks.append(PissaConvertCallback()) @@ -76,21 +77,19 @@ def training_function(config: Dict[str, Any]) -> None: def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None: - callbacks = callbacks or [] - callbacks.append(LogCallback()) - args = read_args(args) ray_args = get_ray_args(args) + callbacks = callbacks or [] if ray_args.use_ray: callbacks.append(RayTrainReportCallback()) trainer = get_ray_trainer( - training_function=training_function, + training_function=_training_function, train_loop_config={"args": args, "callbacks": callbacks}, ray_args=ray_args, ) trainer.fit() else: - training_function(config={"args": args, "callbacks": callbacks}) + _training_function(config={"args": args, "callbacks": callbacks}) def export_model(args: Optional[Dict[str, Any]] = None) -> None: diff 
--git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index f5aecaeb..dc91ad50 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional from transformers.trainer import TRAINING_ARGS_NAME from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES -from ..extras.misc import is_gpu_or_npu_available, torch_gc +from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46 from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config from .locales import ALERTS, LOCALES @@ -394,12 +394,12 @@ class Runner: continue if self.do_train: - if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)): + if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray(): finish_info = ALERTS["info_finished"][lang] else: finish_info = ALERTS["err_failed"][lang] else: - if os.path.exists(os.path.join(output_path, "all_results.json")): + if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray(): finish_info = get_eval_results(os.path.join(output_path, "all_results.json")) else: finish_info = ALERTS["err_failed"][lang] From 647c51a7727c06c8e9f6894f30e0b8010d94e82a Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 8 Jan 2025 09:56:10 +0000 Subject: [PATCH 07/10] imporve log Former-commit-id: a6abf375975ffea3d51e1b944c9855b5f62ffac8 --- src/llamafactory/chat/hf_engine.py | 2 +- src/llamafactory/data/data_utils.py | 4 +-- src/llamafactory/data/loader.py | 7 ++-- src/llamafactory/data/mm_plugin.py | 8 +++-- src/llamafactory/data/template.py | 4 +-- src/llamafactory/extras/logging.py | 8 ++--- src/llamafactory/extras/misc.py | 34 +++++++++++++------ src/llamafactory/hparams/parser.py | 30 +++++++--------- .../model/model_utils/attention.py | 6 ++-- .../model/model_utils/checkpointing.py | 2 +- .../model/model_utils/longlora.py | 4 +-- src/llamafactory/model/model_utils/moe.py | 5 +-- src/llamafactory/model/model_utils/packing.py | 4 +-- .../model/model_utils/quantization.py | 23 ++++++------- src/llamafactory/train/callbacks.py | 2 +- src/llamafactory/train/sft/workflow.py | 2 +- 16 files changed, 78 insertions(+), 67 deletions(-) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 1004544d..f63b6434 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -63,7 +63,7 @@ class HuggingfaceEngine(BaseEngine): try: asyncio.get_event_loop() except RuntimeError: - logger.warning_once("There is no current event loop, creating a new one.") + logger.warning_rank0_once("There is no current event loop, creating a new one.") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py index cbce026c..bd5d3587 100644 --- a/src/llamafactory/data/data_utils.py +++ b/src/llamafactory/data/data_utils.py @@ -56,12 +56,12 @@ def merge_dataset( return all_datasets[0] elif data_args.mix_strategy == "concat": if data_args.streaming: - logger.warning_once("The samples between different datasets will not be mixed in streaming mode.") + logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.") return concatenate_datasets(all_datasets) elif data_args.mix_strategy.startswith("interleave"): if not 
data_args.streaming: - logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.") + logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.") return interleave_datasets( datasets=all_datasets, diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 863b6492..3c7e34a4 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union import numpy as np from datasets import DatasetDict, load_dataset, load_from_disk -from transformers.utils.versions import require_version from ..extras import logging from ..extras.constants import FILEEXT2TYPE -from ..extras.misc import has_tokenized_data +from ..extras.misc import check_version, has_tokenized_data from .aligner import align_dataset from .data_utils import merge_dataset, split_dataset from .parser import get_dataset_list @@ -84,7 +83,7 @@ def _load_single_dataset( raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.") if dataset_attr.load_from == "ms_hub": - require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") + check_version("modelscope>=1.11.0", mandatory=True) from modelscope import MsDataset # type: ignore from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore @@ -103,7 +102,7 @@ def _load_single_dataset( dataset = dataset.to_hf_dataset() elif dataset_attr.load_from == "om_hub": - require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0") + check_version("openmind>=0.8.0", mandatory=True) from openmind import OmDataset # type: ignore from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 4e1f418a..a8c46d11 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -73,10 +73,14 @@ class BasePlugin: Validates if this model accepts the input modalities. """ if len(images) != 0 and self.image_token is None: - raise ValueError("This model does not support image input.") + raise ValueError( + "This model does not support image input. Please check whether the correct `template` is used." + ) if len(videos) != 0 and self.video_token is None: - raise ValueError("This model does not support video input.") + raise ValueError( + "This model does not support video input. Please check whether the correct `template` is used." 
+ ) def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": r""" diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 5768cf7b..ebe31553 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -15,10 +15,10 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union -from transformers.utils.versions import require_version from typing_extensions import override from ..extras import logging +from ..extras.misc import check_version from .data_utils import Role from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter from .mm_plugin import get_mm_plugin @@ -365,7 +365,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: raise ValueError(f"Template {data_args.template} does not exist.") if template.mm_plugin.__class__.__name__ != "BasePlugin": - require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0") + check_version("transformers>=4.45.0") if data_args.train_on_prompt and template.efficient_eos: raise ValueError("Current template does not support `train_on_prompt`.") diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py index 40889a88..8f98b055 100644 --- a/src/llamafactory/extras/logging.py +++ b/src/llamafactory/extras/logging.py @@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler): class _Logger(logging.Logger): r""" - A logger that supports info_rank0 and warning_once. + A logger that supports rank0 logging. """ def info_rank0(self, *args, **kwargs) -> None: @@ -77,7 +77,7 @@ class _Logger(logging.Logger): def warning_rank0(self, *args, **kwargs) -> None: self.warning(*args, **kwargs) - def warning_once(self, *args, **kwargs) -> None: + def warning_rank0_once(self, *args, **kwargs) -> None: self.warning(*args, **kwargs) @@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None: @lru_cache(None) -def warning_once(self: "logging.Logger", *args, **kwargs) -> None: +def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None: if int(os.getenv("LOCAL_RANK", "0")) == 0: self.warning(*args, **kwargs) logging.Logger.info_rank0 = info_rank0 logging.Logger.warning_rank0 = warning_rank0 -logging.Logger.warning_once = warning_once +logging.Logger.warning_rank0_once = warning_rank0_once diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 11797f9f..beaed725 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -73,19 +73,31 @@ class AverageMeter: self.avg = self.sum / self.count +def check_version(requirement: str, mandatory: bool = False) -> None: + r""" + Optionally checks the package version. + """ + if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory: + logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.") + return + + if mandatory: + hint = f"To fix: run `pip install {requirement}`." + else: + hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check." + + require_version(requirement, hint) + + def check_dependencies() -> None: r""" Checks the version of the required packages. 
""" - if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: - logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.") - return - - require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1") - require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0") - require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1") - require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0") - require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6") + check_version("transformers>=4.41.2,<=4.46.1") + check_version("datasets>=2.16.0,<=3.1.0") + check_version("accelerate>=0.34.0,<=1.0.1") + check_version("peft>=0.11.1,<=0.12.0") + check_version("trl>=0.8.6,<=0.9.6") def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float: @@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str: return model_args.model_name_or_path if use_modelscope(): - require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") + check_version("modelscope>=1.11.0", mandatory=True) from modelscope import snapshot_download # type: ignore revision = "master" if model_args.model_revision == "main" else model_args.model_revision @@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str: ) if use_openmind(): - require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0") + check_version("openmind>=0.8.0", mandatory=True) from openmind.utils.hub import snapshot_download # type: ignore return snapshot_download( diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 62edbf78..456e34a2 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -29,11 +29,10 @@ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint from transformers.training_args import ParallelMode from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available -from transformers.utils.versions import require_version from ..extras import logging from ..extras.constants import CHECKPOINT_NAMES -from ..extras.misc import check_dependencies, get_current_device +from ..extras.misc import check_dependencies, check_version, get_current_device from .data_args import DataArguments from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments @@ -124,38 +123,35 @@ def _check_extra_dependencies( finetuning_args: "FinetuningArguments", training_args: Optional["TrainingArguments"] = None, ) -> None: - if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: - logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.") - return - if model_args.use_unsloth: - require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth") + check_version("unsloth", mandatory=True) if model_args.enable_liger_kernel: - require_version("liger-kernel", "To fix: pip install liger-kernel") + check_version("liger-kernel", mandatory=True) if model_args.mixture_of_depths is not None: - require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") + check_version("mixture-of-depth>=1.1.6", mandatory=True) if model_args.infer_backend == "vllm": - 
require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7") + check_version("vllm>=0.4.3,<0.6.7") + check_version("vllm", mandatory=True) if finetuning_args.use_galore: - require_version("galore_torch", "To fix: pip install galore_torch") + check_version("galore_torch", mandatory=True) if finetuning_args.use_badam: - require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1") + check_version("badam>=1.2.1", mandatory=True) if finetuning_args.use_adam_mini: - require_version("adam-mini", "To fix: pip install adam-mini") + check_version("adam-mini", mandatory=True) if finetuning_args.plot_loss: - require_version("matplotlib", "To fix: pip install matplotlib") + check_version("matplotlib", mandatory=True) if training_args is not None and training_args.predict_with_generate: - require_version("jieba", "To fix: pip install jieba") - require_version("nltk", "To fix: pip install nltk") - require_version("rouge_chinese", "To fix: pip install rouge-chinese") + check_version("jieba", mandatory=True) + check_version("nltk", mandatory=True) + check_version("rouge_chinese", mandatory=True) def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py index bf243aaa..8ec74351 100644 --- a/src/llamafactory/model/model_utils/attention.py +++ b/src/llamafactory/model/model_utils/attention.py @@ -15,9 +15,9 @@ from typing import TYPE_CHECKING from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available -from transformers.utils.versions import require_version from ...extras import logging +from ...extras.misc import check_version if TYPE_CHECKING: @@ -35,8 +35,8 @@ def configure_attn_implementation( if getattr(config, "model_type", None) == "gemma2" and is_trainable: if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2": if is_flash_attn_2_available(): - require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4") - require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3") + check_version("transformers>=4.42.4") + check_version("flash_attn>=2.6.3") if model_args.flash_attn != "fa2": logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") model_args.flash_attn = "fa2" diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py index 0fad48cf..80a50f3e 100644 --- a/src/llamafactory/model/model_utils/checkpointing.py +++ b/src/llamafactory/model/model_utils/checkpointing.py @@ -122,7 +122,7 @@ def _gradient_checkpointing_enable( if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format self.apply(partial(self._set_gradient_checkpointing, value=True)) self.enable_input_require_grads() - logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.") + logger.warning_rank0_once("You are using the old GC format, some features (e.g. 
BAdam) will be invalid.") else: # have already enabled input require gradients self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py index 96a7b40e..89457846 100644 --- a/src/llamafactory/model/model_utils/longlora.py +++ b/src/llamafactory/model/model_utils/longlora.py @@ -31,10 +31,10 @@ from transformers.models.llama.modeling_llama import ( apply_rotary_pos_emb, repeat_kv, ) -from transformers.utils.versions import require_version from ...extras import logging from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN +from ...extras.misc import check_version from ...extras.packages import is_transformers_version_greater_than @@ -353,7 +353,7 @@ def llama_sdpa_attention_forward( def _apply_llama_patch() -> None: - require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1") + check_version("transformers>=4.41.2,<=4.46.1") LlamaAttention.forward = llama_attention_forward LlamaFlashAttention2.forward = llama_flash_attention_2_forward LlamaSdpaAttention.forward = llama_sdpa_attention_forward diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py index 642d164a..58039e2a 100644 --- a/src/llamafactory/model/model_utils/moe.py +++ b/src/llamafactory/model/model_utils/moe.py @@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence import torch from transformers.integrations import is_deepspeed_zero3_enabled -from transformers.utils.versions import require_version + +from ...extras.misc import check_version if TYPE_CHECKING: @@ -26,7 +27,7 @@ if TYPE_CHECKING: def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None: - require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") + check_version("deepspeed>=0.13.0") from deepspeed.utils import set_z3_leaf_modules # type: ignore set_z3_leaf_modules(model, leaf_modules) diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 014c8e87..34c3c55b 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -41,9 +41,9 @@ from typing import TYPE_CHECKING, Tuple import torch import torch.nn.functional as F -from transformers.utils.versions import require_version from ...extras import logging +from ...extras.misc import check_version from ...extras.packages import is_transformers_version_greater_than @@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None: if not is_trainable or not model_args.block_diag_attn: return - require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1") + check_version("transformers>=4.43.0,<=4.46.1") transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.") diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index 0739c566..e000ee23 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -26,11 +26,10 @@ from datasets import load_dataset from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig from transformers.integrations import is_deepspeed_zero3_enabled from 
transformers.modeling_utils import is_fsdp_enabled -from transformers.utils.versions import require_version from ...extras import logging from ...extras.constants import FILEEXT2TYPE -from ...extras.misc import get_current_device +from ...extras.misc import check_version, get_current_device if TYPE_CHECKING: @@ -118,15 +117,15 @@ def configure_quantization( quant_method = quantization_config.get("quant_method", "") if quant_method == QuantizationMethod.GPTQ: - require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + check_version("auto_gptq>=0.5.0", mandatory=True) quantization_config.pop("disable_exllama", None) # remove deprecated args quantization_config["use_exllama"] = False # disable exllama if quant_method == QuantizationMethod.AWQ: - require_version("autoawq", "To fix: pip install autoawq") + check_version("autoawq", mandatory=True) if quant_method == QuantizationMethod.AQLM: - require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0") + check_version("aqlm>=1.1.0", mandatory=True) quantization_config["bits"] = 2 quant_bits = quantization_config.get("bits", "?") @@ -136,8 +135,8 @@ def configure_quantization( if model_args.export_quantization_bit not in [8, 4, 3, 2]: raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.") - require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0") - require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + check_version("optimum>=1.17.0", mandatory=True) + check_version("auto_gptq>=0.5.0", mandatory=True) from accelerate.utils import get_max_memory if getattr(config, "model_type", None) == "chatglm": @@ -154,10 +153,10 @@ def configure_quantization( elif model_args.quantization_bit is not None: # on-the-fly if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value: if model_args.quantization_bit == 8: - require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + check_version("bitsandbytes>=0.37.0", mandatory=True) init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) elif model_args.quantization_bit == 4: - require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + check_version("bitsandbytes>=0.39.0", mandatory=True) init_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=model_args.compute_dtype, @@ -175,7 +174,7 @@ def configure_quantization( if model_args.quantization_bit != 4: raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.") - require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") + check_version("bitsandbytes>=0.43.0", mandatory=True) else: init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference @@ -187,7 +186,7 @@ def configure_quantization( if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.") - require_version("hqq", "To fix: pip install hqq") + check_version("hqq", mandatory=True) init_kwargs["quantization_config"] = HqqConfig( nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0 ) # use ATEN kernel (axis=0) for performance @@ -199,6 +198,6 @@ def configure_quantization( if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.") - require_version("eetq", "To fix: pip install eetq") + check_version("eetq", mandatory=True) 
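For reference, a minimal sketch of the two modes of the new `check_version` helper introduced in `extras/misc.py` above (the import path is taken from these diffs; the requirement strings are illustrative only):

```python
# Illustrative usage of the new check_version helper; semantics follow the
# extras/misc.py hunk above, and the version strings are examples only.
from llamafactory.extras.misc import check_version

check_version("transformers>=4.41.2,<=4.46.1")     # soft check: skipped when DISABLE_VERSION_CHECK=1
check_version("auto_gptq>=0.5.0", mandatory=True)  # hard check: always enforced
```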
init_kwargs["quantization_config"] = EetqConfig() logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.") diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 4da4ec18..5906a4a6 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -239,7 +239,7 @@ class LogCallback(TrainerCallback): and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG)) and args.overwrite_output_dir ): - logger.warning_once("Previous trainer log in this folder will be deleted.") + logger.warning_rank0_once("Previous trainer log in this folder will be deleted.") os.remove(os.path.join(args.output_dir, TRAINER_LOG)) @override diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 1ccfa9ef..f41f24cb 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -122,7 +122,7 @@ def run_sft( # Predict if training_args.do_predict: - logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.") + logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.") predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) From b6b53b61f738345fbe3afe6be2cf71b4afdb3ec2 Mon Sep 17 00:00:00 2001 From: zhubin Date: Wed, 8 Jan 2025 17:18:41 +0800 Subject: [PATCH 08/10] =?UTF-8?q?fix=20=C2=96get=20ray=20args=20when=20arg?= =?UTF-8?q?s=20not=20a=20dict?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Former-commit-id: 5e5398cd5b117b2378107172d3f91cfb0321e842 --- src/llamafactory/hparams/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 62edbf78..f3c40f68 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -76,7 +76,7 @@ def _parse_args( (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True) - if unknown_args: + if unknown_args and not allow_extra_keys: print(parser.format_help()) print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") From 867980196e9d230b876316dbb390cf6c60b2f28f Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 9 Jan 2025 18:27:20 +0000 Subject: [PATCH 09/10] improve template, add phi4 model Former-commit-id: a785b6796e445a3adba45c5b6947166a2ff99871 --- README.md | 11 +- README_zh.md | 11 +- src/llamafactory/data/template.py | 146 ++++++++++++--------------- src/llamafactory/extras/constants.py | 19 ++++ tests/data/test_template.py | 85 +++++++++------- 5 files changed, 147 insertions(+), 125 deletions(-) diff --git a/README.md b/README.md index 1eec5b96..739a13af 100644 --- a/README.md +++ b/README.md @@ -88,14 +88,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model. + [24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details. 
[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset. -[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage. -
Full Changelog +[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage. + [24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models. [24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR. @@ -211,8 +213,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi | +| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi | | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small | +| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 | | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | @@ -762,7 +765,7 @@ If you have a project that should be incorporated, please contact via email or c This repository is licensed under the [Apache-2.0 License](LICENSE). -Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +Please follow the model licenses to use the 
corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## Citation diff --git a/README_zh.md b/README_zh.md index 352a483d..e21560d2 100644 --- a/README_zh.md +++ b/README_zh.md @@ -89,14 +89,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 更新日志 +[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。 + [24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。 [24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。 -[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。 -
展开日志 +[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。 + [24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。 [24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。 @@ -212,8 +214,9 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi | +| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi | | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small | +| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 | | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | @@ -763,7 +766,7 @@ swanlab_run_name: test_run # 可选 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 -使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command 
R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## 引用 diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index ebe31553..490b571d 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -44,7 +44,6 @@ class Template: format_function: "Formatter" format_observation: "Formatter" format_tools: "Formatter" - format_separator: "Formatter" format_prefix: "Formatter" default_system: str stop_words: List[str] @@ -113,9 +112,6 @@ class Template: tool_text = self.format_tools.apply(content=tools)[0] if tools else "" elements += self.format_system.apply(content=(system + tool_text)) - if i > 0 and i % 2 == 0: - elements += self.format_separator.apply() - if message["role"] == Role.USER.value: elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) elif message["role"] == Role.ASSISTANT.value: @@ -180,9 +176,6 @@ class Llama2Template(Template): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" system_text = self.format_system.apply(content=(system + tool_text))[0] - if i > 0 and i % 2 == 0: - elements += self.format_separator.apply() - if message["role"] == Role.USER.value: elements += self.format_user.apply(content=system_text + message["content"]) elif message["role"] == Role.ASSISTANT.value: @@ -210,7 +203,6 @@ def _register_template( format_function: Optional["Formatter"] = None, format_observation: Optional["Formatter"] = None, format_tools: Optional["Formatter"] = None, - format_separator: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None, default_system: str = "", stop_words: Sequence[str] = [], @@ -224,34 +216,28 @@ def _register_template( To add the following chat template: ``` - [HUMAN]: - user prompt here - [AI]: - model response here - - [HUMAN]: - user prompt here - [AI]: - model response here + user prompt here + 
model response here + user prompt here + model response here ``` The corresponding code should be: ``` _register_template( name="custom", - format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]), - format_separator=EmptyFormatter(slots=["\n\n"]), - efficient_eos=True, + format_user=StringFormatter(slots=["{{content}}\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(""), ) ``` """ - template_class = Llama2Template if any(k in name for k in ("llama2", "mistral")) else Template + template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}] default_user_formatter = StringFormatter(slots=["{{content}}"]) default_assistant_formatter = StringFormatter(slots=default_slots) default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default") - default_separator_formatter = EmptyFormatter() default_prefix_formatter = EmptyFormatter() TEMPLATES[name] = template_class( format_user=format_user or default_user_formatter, @@ -260,7 +246,6 @@ def _register_template( format_function=format_function or default_function_formatter, format_observation=format_observation or format_user or default_user_formatter, format_tools=format_tools or default_tool_formatter, - format_separator=format_separator or default_separator_formatter, format_prefix=format_prefix or default_prefix_formatter, default_system=default_system, stop_words=stop_words, @@ -344,9 +329,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") jinja_template += "{{ " + user_message + " }}" jinja_template += "{% elif message['role'] == 'assistant' %}" - assistant_message = _convert_slots_to_jinja( - template.format_assistant.apply() + template.format_separator.apply(), tokenizer - ) + assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer) jinja_template += "{{ " + assistant_message + " }}" jinja_template += "{% endif %}" jinja_template += "{% endfor %}" @@ -411,7 +394,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: _register_template( name="alpaca", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), - format_separator=EmptyFormatter(slots=["\n\n"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), default_system=( "Below is an instruction that describes a task. " "Write a response that appropriately completes the request.\n\n" @@ -423,13 +406,13 @@ _register_template( _register_template( name="aquila", format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), - format_separator=EmptyFormatter(slots=["###"]), + format_assistant=StringFormatter(slots=["{{content}}###"]), + format_system=StringFormatter(slots=["System: {{content}}###"]), default_system=( "A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions." 
), stop_words=[""], - efficient_eos=True, ) @@ -459,7 +442,7 @@ _register_template( _register_template( name="belle", format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), - format_separator=EmptyFormatter(slots=["\n\n"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -481,7 +464,6 @@ _register_template( _register_template( name="chatglm2", format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), - format_separator=EmptyFormatter(slots=["\n\n"]), format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), efficient_eos=True, ) @@ -506,9 +488,9 @@ _register_template( _register_template( name="chatml", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>", "<|im_start|>"], replace_eos=True, replace_jinja_template=True, @@ -519,9 +501,9 @@ _register_template( _register_template( name="chatml_de", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_separator=EmptyFormatter(slots=["\n"]), default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", stop_words=["<|im_end|>", "<|im_start|>"], replace_eos=True, @@ -574,9 +556,11 @@ _register_template( ) +# copied from chatml template _register_template( name="cpm3", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=["<|im_end|>"], @@ -587,9 +571,9 @@ _register_template( _register_template( name="dbrx", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_separator=EmptyFormatter(slots=["\n"]), default_system=( "You are DBRX, created by Databricks. You were last updated in December 2023. " "You answer questions based on information available up to that point.\n" @@ -606,7 +590,6 @@ _register_template( "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY." 
), stop_words=["<|im_end|>"], - replace_eos=True, ) @@ -628,8 +611,7 @@ _register_template( _register_template( name="deepseekcoder", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), - format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]), - format_separator=EmptyFormatter(slots=["\n"]), + format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), default_system=( "You are an AI programming assistant, utilizing the DeepSeek Coder model, " @@ -643,8 +625,8 @@ _register_template( _register_template( name="default", format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]), - format_system=StringFormatter(slots=["{{content}}\n"]), - format_separator=EmptyFormatter(slots=["\n"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), + format_system=StringFormatter(slots=["System: {{content}}\n"]), ) @@ -657,22 +639,22 @@ _register_template( _register_template( name="exaone", format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]), - format_separator=EmptyFormatter(slots=["\n"]), ) _register_template( name="falcon", format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), - format_separator=EmptyFormatter(slots=["\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), efficient_eos=True, ) _register_template( name="fewshot", - format_separator=EmptyFormatter(slots=["\n\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n\n"]), efficient_eos=True, ) @@ -680,12 +662,11 @@ _register_template( _register_template( name="gemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), format_observation=StringFormatter( slots=["tool\n{{content}}\nmodel\n"] ), - format_separator=EmptyFormatter(slots=["\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), - efficient_eos=True, ) @@ -710,8 +691,8 @@ _register_template( "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" ] ), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]), format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]), - format_separator=EmptyFormatter(slots=["\n"]), ) @@ -726,22 +707,20 @@ _register_template( _register_template( name="intern", format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]), - format_separator=EmptyFormatter(slots=["\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=[""], - efficient_eos=True, # internlm tokenizer cannot set eos_token_id ) _register_template( name="intern2", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_separator=EmptyFormatter(slots=["<|im_end|>\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=["<|im_end|>"], - efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id ) @@ -872,6 +851,7 @@ _register_template( 
name="llava_next_mistral", format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"), format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]), format_tools=ToolFormatter(tool_format="mistral"), @@ -884,16 +864,15 @@ _register_template( _register_template( name="llava_next_qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"), + format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"), format_observation=StringFormatter( slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] ), format_tools=ToolFormatter(tool_format="qwen"), - format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], - replace_eos=True, mm_plugin=get_mm_plugin(name="llava_next", image_token=""), ) @@ -902,10 +881,9 @@ _register_template( _register_template( name="llava_next_yi", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>"], - replace_eos=True, mm_plugin=get_mm_plugin(name="llava_next", image_token=""), ) @@ -927,6 +905,7 @@ _register_template( name="llava_next_video_mistral", format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"), format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]), format_tools=ToolFormatter(tool_format="mistral"), @@ -939,10 +918,9 @@ _register_template( _register_template( name="llava_next_video_yi", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>"], - replace_eos=True, mm_plugin=get_mm_plugin(name="llava_next_video", image_token="", video_token="