From 7c1640ed5f47599b68f17ef3bab1f2a6effb13ce Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 12 Mar 2025 00:08:41 +0800 Subject: [PATCH] [misc] upgrade format to py39 (#7256) --- data/belle_multiturn/belle_multiturn.py | 2 +- data/hh_rlhf_en/hh_rlhf_en.py | 3 +- data/ultra_chat/ultra_chat.py | 7 +- evaluation/ceval/ceval.py | 5 +- evaluation/cmmlu/cmmlu.py | 5 +- evaluation/mmlu/mmlu.py | 5 +- pyproject.toml | 31 ++- scripts/api_example/test_toolcall.py | 2 +- scripts/convert_ckpt/llamafy_baichuan2.py | 12 +- scripts/convert_ckpt/llamafy_qwen.py | 14 +- scripts/llama_pro.py | 10 +- scripts/loftq_init.py | 6 +- scripts/pissa_init.py | 6 +- scripts/stat_utils/cal_flops.py | 4 +- scripts/stat_utils/cal_lr.py | 9 +- scripts/stat_utils/cal_mfu.py | 12 +- scripts/stat_utils/cal_ppl.py | 25 +-- scripts/stat_utils/length_cdf.py | 4 +- scripts/vllm_infer.py | 4 +- setup.py | 5 +- src/llamafactory/__init__.py | 3 +- src/llamafactory/api/app.py | 4 +- src/llamafactory/api/chat.py | 5 +- src/llamafactory/api/common.py | 4 +- src/llamafactory/api/protocol.py | 26 +-- src/llamafactory/chat/base_engine.py | 34 ++- src/llamafactory/chat/chat_model.py | 56 ++--- src/llamafactory/chat/hf_engine.py | 47 ++-- src/llamafactory/chat/vllm_engine.py | 19 +- src/llamafactory/data/__init__.py | 8 +- src/llamafactory/data/collator.py | 38 ++-- src/llamafactory/data/converter.py | 49 ++-- src/llamafactory/data/data_utils.py | 26 +-- src/llamafactory/data/formatter.py | 15 +- src/llamafactory/data/loader.py | 25 +-- src/llamafactory/data/mm_plugin.py | 211 ++++++++---------- src/llamafactory/data/parser.py | 19 +- src/llamafactory/data/processor/__init__.py | 2 +- src/llamafactory/data/processor/feedback.py | 15 +- src/llamafactory/data/processor/pairwise.py | 13 +- src/llamafactory/data/processor/pretrain.py | 6 +- .../data/processor/processor_utils.py | 35 +-- src/llamafactory/data/processor/supervised.py | 15 +- .../data/processor/unsupervised.py | 13 +- src/llamafactory/data/template.py | 110 ++++----- src/llamafactory/data/tool_utils.py | 76 +++---- src/llamafactory/eval/evaluator.py | 10 +- src/llamafactory/eval/template.py | 19 +- src/llamafactory/extras/constants.py | 4 +- src/llamafactory/extras/logging.py | 28 +-- src/llamafactory/extras/misc.py | 73 ++---- src/llamafactory/extras/ploting.py | 20 +- src/llamafactory/hparams/data_args.py | 8 +- src/llamafactory/hparams/evaluation_args.py | 4 +- src/llamafactory/hparams/finetuning_args.py | 44 ++-- src/llamafactory/hparams/generating_args.py | 12 +- src/llamafactory/hparams/model_args.py | 29 +-- src/llamafactory/hparams/parser.py | 47 ++-- src/llamafactory/hparams/training_args.py | 8 +- src/llamafactory/model/__init__.py | 2 +- src/llamafactory/model/adapter.py | 10 +- src/llamafactory/model/loader.py | 23 +- .../model/model_utils/checkpointing.py | 30 ++- .../model/model_utils/embedding.py | 4 +- .../model/model_utils/longlora.py | 36 +-- src/llamafactory/model/model_utils/misc.py | 14 +- src/llamafactory/model/model_utils/moe.py | 7 +- src/llamafactory/model/model_utils/packing.py | 13 +- .../model/model_utils/quantization.py | 22 +- src/llamafactory/model/model_utils/unsloth.py | 24 +- .../model/model_utils/valuehead.py | 7 +- src/llamafactory/model/model_utils/visual.py | 51 ++--- src/llamafactory/model/patcher.py | 4 +- src/llamafactory/train/callbacks.py | 33 +-- src/llamafactory/train/dpo/trainer.py | 59 ++--- src/llamafactory/train/dpo/workflow.py | 4 +- src/llamafactory/train/kto/trainer.py | 54 ++--- src/llamafactory/train/kto/workflow.py | 4 +- src/llamafactory/train/ppo/ppo_utils.py | 24 +- src/llamafactory/train/ppo/trainer.py | 53 ++--- src/llamafactory/train/ppo/workflow.py | 6 +- src/llamafactory/train/pt/trainer.py | 4 +- src/llamafactory/train/pt/workflow.py | 4 +- src/llamafactory/train/rm/metric.py | 10 +- src/llamafactory/train/rm/trainer.py | 18 +- src/llamafactory/train/rm/workflow.py | 4 +- src/llamafactory/train/sft/metric.py | 21 +- src/llamafactory/train/sft/trainer.py | 22 +- src/llamafactory/train/sft/workflow.py | 4 +- src/llamafactory/train/test_utils.py | 9 +- src/llamafactory/train/trainer_utils.py | 55 ++--- src/llamafactory/train/tuner.py | 10 +- src/llamafactory/webui/chatter.py | 32 ++- src/llamafactory/webui/common.py | 86 +++---- src/llamafactory/webui/components/chatbot.py | 8 +- src/llamafactory/webui/components/data.py | 16 +- src/llamafactory/webui/components/eval.py | 4 +- src/llamafactory/webui/components/export.py | 9 +- src/llamafactory/webui/components/infer.py | 4 +- src/llamafactory/webui/components/top.py | 4 +- src/llamafactory/webui/components/train.py | 8 +- src/llamafactory/webui/control.py | 37 ++- src/llamafactory/webui/engine.py | 22 +- src/llamafactory/webui/interface.py | 2 +- src/llamafactory/webui/manager.py | 42 ++-- src/llamafactory/webui/runner.py | 72 +++--- tests/data/processor/test_pairwise.py | 3 +- tests/data/processor/test_processor_utils.py | 3 +- tests/data/test_formatter.py | 6 +- tests/data/test_mm_plugin.py | 19 +- tests/data/test_template.py | 10 +- tests/model/model_utils/test_checkpointing.py | 2 +- tests/train/test_sft_trainer.py | 6 +- 113 files changed, 984 insertions(+), 1407 deletions(-) diff --git a/data/belle_multiturn/belle_multiturn.py b/data/belle_multiturn/belle_multiturn.py index 7b2f5449..2267c7ce 100644 --- a/data/belle_multiturn/belle_multiturn.py +++ b/data/belle_multiturn/belle_multiturn.py @@ -10,7 +10,7 @@ _DESCRIPTION = "BELLE multiturn chat dataset." _CITATION = """\ @article{belle2023exploring, - title={Exploring the Impact of Instruction Data Scaling on Large Language Models: An Empirical Study on Real-World Use Cases}, + title={Exploring the Impact of Instruction Data Scaling on Large Language Models}, author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li}, journal={arXiv preprint arXiv:2303.14742}, year={2023} diff --git a/data/hh_rlhf_en/hh_rlhf_en.py b/data/hh_rlhf_en/hh_rlhf_en.py index b94fbfd1..083130f1 100644 --- a/data/hh_rlhf_en/hh_rlhf_en.py +++ b/data/hh_rlhf_en/hh_rlhf_en.py @@ -1,6 +1,5 @@ import json import os -from typing import List import datasets @@ -50,7 +49,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder): datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}), ] - def _generate_examples(self, filepaths: List[str]): + def _generate_examples(self, filepaths: list[str]): key = 0 for filepath in filepaths: with open(filepath, encoding="utf-8") as f: diff --git a/data/ultra_chat/ultra_chat.py b/data/ultra_chat/ultra_chat.py index c7e12a03..9eafa2ef 100644 --- a/data/ultra_chat/ultra_chat.py +++ b/data/ultra_chat/ultra_chat.py @@ -1,6 +1,5 @@ import json import os -from typing import List import datasets @@ -11,7 +10,7 @@ _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dia _CITATION = """\ @misc{UltraChat, - author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen}, + author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and others}, title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data}, year = {2023}, publisher = {GitHub}, @@ -40,7 +39,7 @@ class UltraChat(datasets.GeneratorBasedBuilder): file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})] - def _generate_examples(self, filepaths: List[str]): + def _generate_examples(self, filepaths: list[str]): for filepath in filepaths: with open(filepath, encoding="utf-8") as f: for row in f: @@ -49,7 +48,7 @@ class UltraChat(datasets.GeneratorBasedBuilder): except Exception: continue key: int = data["id"] - content: List[str] = data["data"] + content: list[str] = data["data"] if len(content) % 2 == 1: content.pop(-1) if len(content) < 2: diff --git a/evaluation/ceval/ceval.py b/evaluation/ceval/ceval.py index 48442d50..e18be8ee 100644 --- a/evaluation/ceval/ceval.py +++ b/evaluation/ceval/ceval.py @@ -21,14 +21,15 @@ import pandas as pd _CITATION = """\ @article{huang2023ceval, title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models}, - author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian}, + author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and others}, journal={arXiv preprint arXiv:2305.08322}, year={2023} } """ _DESCRIPTION = """\ -C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels. +C-Eval is a comprehensive Chinese evaluation suite for foundation models. +It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels. """ _HOMEPAGE = "https://cevalbenchmark.com" diff --git a/evaluation/cmmlu/cmmlu.py b/evaluation/cmmlu/cmmlu.py index 5ff548a4..517d63f8 100644 --- a/evaluation/cmmlu/cmmlu.py +++ b/evaluation/cmmlu/cmmlu.py @@ -21,14 +21,15 @@ import pandas as pd _CITATION = """\ @article{li2023cmmlu, title={CMMLU: Measuring massive multitask language understanding in Chinese}, - author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin}, + author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and others, journal={arXiv preprint arXiv:2306.09212}, year={2023} } """ _DESCRIPTION = """\ -CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context. +CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge +and reasoning abilities of LLMs within the Chinese language and cultural context. """ _HOMEPAGE = "https://github.com/haonan-li/CMMLU" diff --git a/evaluation/mmlu/mmlu.py b/evaluation/mmlu/mmlu.py index e83fdab5..63547757 100644 --- a/evaluation/mmlu/mmlu.py +++ b/evaluation/mmlu/mmlu.py @@ -21,14 +21,15 @@ import pandas as pd _CITATION = """\ @article{hendryckstest2021, title={Measuring Massive Multitask Language Understanding}, - author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt}, + author={Dan Hendrycks and Collin Burns and others}, journal={Proceedings of the International Conference on Learning Representations (ICLR)}, year={2021} } """ _DESCRIPTION = """\ -Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021). +Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, +Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021). """ _HOMEPAGE = "https://github.com/hendrycks/test" diff --git a/pyproject.toml b/pyproject.toml index 97084dc0..cf011762 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,13 +19,35 @@ dynamic = [ ] [tool.ruff] -target-version = "py38" +target-version = "py39" line-length = 119 indent-width = 4 [tool.ruff.lint] -ignore = ["C408", "C901", "E501", "E731", "E741", "W605"] -select = ["C", "E", "F", "I", "W"] +ignore = [ + "C408", # collection + "C901", # complex + "E731", # lambda function + "E741", # ambiguous var name + "D100", # no doc public module + "D101", # no doc public class + "D102", # no doc public method + "D103", # no doc public function + "D104", # no doc public package + "D105", # no doc magic method + "D107", # no doc __init__ +] +extend-select = [ + "C", # complexity + "E", # error + "F", # pyflakes + "I", # isort + "W", # warning + "UP", # pyupgrade + "D", # pydocstyle + "PT009", # pytest assert + "RUF022", # sort __all__ +] [tool.ruff.lint.isort] lines-after-imports = 2 @@ -41,6 +63,9 @@ known-third-party = [ "trl" ] +[tool.ruff.lint.pydocstyle] +convention = "google" + [tool.ruff.format] quote-style = "double" indent-style = "space" diff --git a/scripts/api_example/test_toolcall.py b/scripts/api_example/test_toolcall.py index 6a0a6f38..2dff4aab 100644 --- a/scripts/api_example/test_toolcall.py +++ b/scripts/api_example/test_toolcall.py @@ -14,7 +14,7 @@ import json import os -from typing import Sequence +from collections.abc import Sequence from openai import OpenAI from transformers.utils.versions import require_version diff --git a/scripts/convert_ckpt/llamafy_baichuan2.py b/scripts/convert_ckpt/llamafy_baichuan2.py index 75e849b2..3dbeff49 100644 --- a/scripts/convert_ckpt/llamafy_baichuan2.py +++ b/scripts/convert_ckpt/llamafy_baichuan2.py @@ -15,7 +15,7 @@ import json import os from collections import OrderedDict -from typing import Any, Dict +from typing import Any import fire import torch @@ -29,13 +29,13 @@ CONFIG_NAME = "config.json" def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool): - baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict() + baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict() for filepath in tqdm(os.listdir(input_dir), desc="Load weights"): if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"): shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu") baichuan2_state_dict.update(shard_weight) - llama_state_dict: Dict[str, torch.Tensor] = OrderedDict() + llama_state_dict: dict[str, torch.Tensor] = OrderedDict() for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"): if "W_pack" in key: proj_size = value.size(0) // 3 @@ -75,7 +75,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso def save_config(input_dir: str, output_dir: str): with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f: - llama2_config_dict: Dict[str, Any] = json.load(f) + llama2_config_dict: dict[str, Any] = json.load(f) llama2_config_dict["architectures"] = ["LlamaForCausalLM"] llama2_config_dict.pop("auto_map", None) @@ -94,8 +94,8 @@ def llamafy_baichuan2( shard_size: str = "2GB", save_safetensors: bool = True, ): - r""" - Converts the Baichuan2-7B model in the same format as LLaMA2-7B. + r"""Convert the Baichuan2-7B model in the same format as LLaMA2-7B. + Usage: python llamafy_baichuan2.py --input_dir input --output_dir output Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied """ diff --git a/scripts/convert_ckpt/llamafy_qwen.py b/scripts/convert_ckpt/llamafy_qwen.py index bb3fe519..599b0f12 100644 --- a/scripts/convert_ckpt/llamafy_qwen.py +++ b/scripts/convert_ckpt/llamafy_qwen.py @@ -15,7 +15,7 @@ import json import os from collections import OrderedDict -from typing import Any, Dict +from typing import Any import fire import torch @@ -37,14 +37,14 @@ CONFIG_NAME = "config.json" def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str: - qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict() + qwen_state_dict: dict[str, torch.Tensor] = OrderedDict() for filepath in tqdm(os.listdir(input_dir), desc="Load weights"): if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"): with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f: for key in f.keys(): qwen_state_dict[key] = f.get_tensor(key) - llama_state_dict: Dict[str, torch.Tensor] = OrderedDict() + llama_state_dict: dict[str, torch.Tensor] = OrderedDict() torch_dtype = None for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"): if torch_dtype is None: @@ -112,9 +112,9 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso def save_config(input_dir: str, output_dir: str, torch_dtype: str): with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f: - qwen_config_dict: Dict[str, Any] = json.load(f) + qwen_config_dict: dict[str, Any] = json.load(f) - llama2_config_dict: Dict[str, Any] = OrderedDict() + llama2_config_dict: dict[str, Any] = OrderedDict() llama2_config_dict["architectures"] = ["LlamaForCausalLM"] llama2_config_dict["hidden_act"] = "silu" llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"] @@ -147,8 +147,8 @@ def llamafy_qwen( shard_size: str = "2GB", save_safetensors: bool = False, ): - r""" - Converts the Qwen models in the same format as LLaMA2. + r"""Convert the Qwen models in the same format as LLaMA2. + Usage: python llamafy_qwen.py --input_dir input --output_dir output Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied """ diff --git a/scripts/llama_pro.py b/scripts/llama_pro.py index dd10b525..7e4b9448 100644 --- a/scripts/llama_pro.py +++ b/scripts/llama_pro.py @@ -18,7 +18,7 @@ import json import os from collections import OrderedDict -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING import fire import torch @@ -44,11 +44,11 @@ def block_expansion( shard_size: str = "5GB", save_safetensors: bool = True, ): - r""" - Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models. + r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models. + Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8 """ - config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) + config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) num_layers = getattr(config, "num_hidden_layers") if num_layers % num_expand != 0: raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.") @@ -70,7 +70,7 @@ def block_expansion( split = num_layers // num_expand layer_cnt = 0 state_dict = model.state_dict() - output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict() + output_state_dict: dict[str, torch.Tensor] = OrderedDict() for i in range(num_layers): for key, value in state_dict.items(): if f".{i:d}." in key: diff --git a/scripts/loftq_init.py b/scripts/loftq_init.py index 83e38e88..3a793388 100644 --- a/scripts/loftq_init.py +++ b/scripts/loftq_init.py @@ -38,8 +38,8 @@ def quantize_loftq( lora_target: tuple = ("q_proj", "v_proj"), save_safetensors: bool = True, ): - r""" - Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ) + r"""Initialize LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ). + Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir """ if isinstance(lora_target, str): @@ -72,7 +72,7 @@ def quantize_loftq( print(f"Adapter weights saved in {loftq_dir}") # Save base model - base_model: "PreTrainedModel" = peft_model.unload() + base_model: PreTrainedModel = peft_model.unload() base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) tokenizer.save_pretrained(output_dir) print(f"Model weights saved in {output_dir}") diff --git a/scripts/pissa_init.py b/scripts/pissa_init.py index 3be11fbf..405a1472 100644 --- a/scripts/pissa_init.py +++ b/scripts/pissa_init.py @@ -37,8 +37,8 @@ def quantize_pissa( lora_target: tuple = ("q_proj", "v_proj"), save_safetensors: bool = True, ): - r""" - Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA) + r"""Initialize LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA). + Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir """ if isinstance(lora_target, str): @@ -67,7 +67,7 @@ def quantize_pissa( print(f"Adapter weights saved in {pissa_dir}") # Save base model - base_model: "PreTrainedModel" = peft_model.unload() + base_model: PreTrainedModel = peft_model.unload() base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) tokenizer.save_pretrained(output_dir) print(f"Model weights saved in {output_dir}") diff --git a/scripts/stat_utils/cal_flops.py b/scripts/stat_utils/cal_flops.py index a9eb033f..3dc04995 100644 --- a/scripts/stat_utils/cal_flops.py +++ b/scripts/stat_utils/cal_flops.py @@ -29,8 +29,8 @@ def calculate_flops( seq_length: int = 512, flash_attn: str = "auto", ): - r""" - Calculates the flops of pre-trained models. + r"""Calculate the flops of pre-trained models. + Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 """ with get_accelerator().device(0): diff --git a/scripts/stat_utils/cal_lr.py b/scripts/stat_utils/cal_lr.py index 85921d90..eb35c47e 100644 --- a/scripts/stat_utils/cal_lr.py +++ b/scripts/stat_utils/cal_lr.py @@ -45,8 +45,8 @@ def calculate_lr( is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate, packing: bool = False, ): - r""" - Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. + r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. + Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16 """ @@ -89,9 +89,8 @@ def calculate_lr( lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size) lr = lr / 6.0 if is_mistral_or_gemma else lr print( - "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format( - lr, valid_ratio * 100, token_batch_size - ) + f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} " + f"and effective token batch size {token_batch_size:.2f}" ) diff --git a/scripts/stat_utils/cal_mfu.py b/scripts/stat_utils/cal_mfu.py index b1aea710..f1d4446e 100644 --- a/scripts/stat_utils/cal_mfu.py +++ b/scripts/stat_utils/cal_mfu.py @@ -34,9 +34,7 @@ def compute_model_flops( include_recompute: bool = False, include_flashattn: bool = False, ) -> int: - r""" - Calculates the FLOPs of model per forward/backward pass. - """ + r"""Calculate the FLOPs of model per forward/backward pass.""" config = AutoConfig.from_pretrained(model_name_or_path) hidden_size = getattr(config, "hidden_size", None) vocab_size = getattr(config, "vocab_size", None) @@ -86,9 +84,7 @@ def compute_model_flops( def compute_device_flops(world_size: int) -> float: - r""" - Calculates the FLOPs of the device capability per second. - """ + r"""Calculate the FLOPs of the device capability per second.""" device_name = torch.cuda.get_device_name() if "H100" in device_name or "H800" in device_name: return 989 * 1e12 * world_size @@ -114,8 +110,8 @@ def calculate_mfu( liger_kernel: bool = False, unsloth_gc: bool = False, ) -> float: - r""" - Calculates MFU for given model and hyper-params. + r"""Calculate MFU for given model and hyper-params. + Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024 """ args = { diff --git a/scripts/stat_utils/cal_ppl.py b/scripts/stat_utils/cal_ppl.py index f4bce61f..a318ee46 100644 --- a/scripts/stat_utils/cal_ppl.py +++ b/scripts/stat_utils/cal_ppl.py @@ -13,8 +13,9 @@ # limitations under the License. import json +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Dict, Literal, Optional, Sequence +from typing import Any, Literal, Optional import fire import torch @@ -30,16 +31,12 @@ from llamafactory.model import load_model, load_tokenizer @dataclass class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): - r""" - Data collator for pairwise data. - """ + r"""Data collator for pairwise data.""" train_on_prompt: bool = False - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: - r""" - Pads batched data to the longest sequence in the batch. - """ + def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, torch.Tensor]: + r"""Pad batched data to the longest sequence in the batch.""" chosen_features = [] for feature in features: chosen_features.append( @@ -68,8 +65,8 @@ def calculate_ppl( max_samples: Optional[int] = None, train_on_prompt: bool = False, ): - r""" - Calculates the ppl on the dataset of the pre-trained models. + r"""Calculate the ppl on the dataset of the pre-trained models. + Usage: export CUDA_VISIBLE_DEVICES=0 python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json """ @@ -111,17 +108,17 @@ def calculate_ppl( criterion = torch.nn.CrossEntropyLoss(reduction="none") total_ppl = 0 perplexities = [] - batch: Dict[str, "torch.Tensor"] + batch: dict[str, torch.Tensor] with torch.no_grad(): for batch in tqdm(dataloader, desc="Computing perplexities"): batch = batch.to(model.device) outputs = model(**batch) - shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :] - shift_labels: "torch.Tensor" = batch["labels"][..., 1:] + shift_logits: torch.Tensor = outputs["logits"][..., :-1, :] + shift_labels: torch.Tensor = batch["labels"][..., 1:] loss_mask = shift_labels != IGNORE_INDEX flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1) flatten_labels = shift_labels.contiguous().view(-1) - token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels) + token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels) token_logps = token_logps.contiguous().view(shift_logits.size(0), -1) sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) total_ppl += sentence_logps.exp().sum().item() diff --git a/scripts/stat_utils/length_cdf.py b/scripts/stat_utils/length_cdf.py index 275549ba..c459c8fa 100644 --- a/scripts/stat_utils/length_cdf.py +++ b/scripts/stat_utils/length_cdf.py @@ -29,8 +29,8 @@ def length_cdf( template: str = "default", interval: int = 1000, ): - r""" - Calculates the distribution of the input lengths in the dataset. + r"""Calculate the distribution of the input lengths in the dataset. + Usage: export CUDA_VISIBLE_DEVICES=0 python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default """ diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py index 02d20ee5..24334911 100644 --- a/scripts/vllm_infer.py +++ b/scripts/vllm_infer.py @@ -52,8 +52,8 @@ def vllm_infer( image_max_pixels: int = 768 * 768, image_min_pixels: int = 32 * 32, ): - r""" - Performs batch generation using vLLM engine, which supports tensor parallelism. + r"""Perform batch generation using vLLM engine, which supports tensor parallelism. + Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo """ check_version("vllm>=0.4.3,<=0.7.3") diff --git a/setup.py b/setup.py index 0fa66dc7..d9b4a905 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,6 @@ import os import re -from typing import List from setuptools import find_packages, setup @@ -27,14 +26,14 @@ def get_version() -> str: return version -def get_requires() -> List[str]: +def get_requires() -> list[str]: with open("requirements.txt", encoding="utf-8") as f: file_content = f.read() lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] return lines -def get_console_scripts() -> List[str]: +def get_console_scripts() -> list[str]: console_scripts = ["llamafactory-cli = llamafactory.cli:main"] if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]: console_scripts.append("lmf = llamafactory.cli:main") diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py index 2cfc5a38..093de3db 100644 --- a/src/llamafactory/__init__.py +++ b/src/llamafactory/__init__.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -r""" -Efficient fine-tuning of large language models. +r"""Efficient fine-tuning of large language models. Level: api, webui > chat, eval, train > data, model > hparams > extras diff --git a/src/llamafactory/api/app.py b/src/llamafactory/api/app.py index e86e647f..e0621d80 100644 --- a/src/llamafactory/api/app.py +++ b/src/llamafactory/api/app.py @@ -16,9 +16,7 @@ import asyncio import os from contextlib import asynccontextmanager from functools import partial -from typing import Optional - -from typing_extensions import Annotated +from typing import Annotated, Optional from ..chat import ChatModel from ..extras.constants import EngineName diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index a0edd8e0..ed40e8f8 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -18,7 +18,8 @@ import json import os import re import uuid -from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Optional from ..data import Role as DataRole from ..extras import logging @@ -71,7 +72,7 @@ ROLE_MAPPING = { def _process_request( request: "ChatCompletionRequest", -) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]: +) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]: if is_env_enabled("API_VERBOSE", "1"): logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}") diff --git a/src/llamafactory/api/common.py b/src/llamafactory/api/common.py index 59c84de6..f4d0c2fb 100644 --- a/src/llamafactory/api/common.py +++ b/src/llamafactory/api/common.py @@ -13,14 +13,14 @@ # limitations under the License. import json -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from pydantic import BaseModel -def dictify(data: "BaseModel") -> Dict[str, Any]: +def dictify(data: "BaseModel") -> dict[str, Any]: try: # pydantic v2 return data.model_dump(exclude_unset=True) except AttributeError: # pydantic v1 diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py index 310e743e..bb3029d5 100644 --- a/src/llamafactory/api/protocol.py +++ b/src/llamafactory/api/protocol.py @@ -14,7 +14,7 @@ import time from enum import Enum, unique -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from pydantic import BaseModel, Field from typing_extensions import Literal @@ -45,7 +45,7 @@ class ModelCard(BaseModel): class ModelList(BaseModel): object: Literal["list"] = "list" - data: List[ModelCard] = [] + data: list[ModelCard] = [] class Function(BaseModel): @@ -56,7 +56,7 @@ class Function(BaseModel): class FunctionDefinition(BaseModel): name: str description: str - parameters: Dict[str, Any] + parameters: dict[str, Any] class FunctionAvailable(BaseModel): @@ -82,26 +82,26 @@ class MultimodalInputItem(BaseModel): class ChatMessage(BaseModel): role: Role - content: Optional[Union[str, List[MultimodalInputItem]]] = None - tool_calls: Optional[List[FunctionCall]] = None + content: Optional[Union[str, list[MultimodalInputItem]]] = None + tool_calls: Optional[list[FunctionCall]] = None class ChatCompletionMessage(BaseModel): role: Optional[Role] = None content: Optional[str] = None - tool_calls: Optional[List[FunctionCall]] = None + tool_calls: Optional[list[FunctionCall]] = None class ChatCompletionRequest(BaseModel): model: str - messages: List[ChatMessage] - tools: Optional[List[FunctionAvailable]] = None + messages: list[ChatMessage] + tools: Optional[list[FunctionAvailable]] = None do_sample: Optional[bool] = None temperature: Optional[float] = None top_p: Optional[float] = None n: int = 1 max_tokens: Optional[int] = None - stop: Optional[Union[str, List[str]]] = None + stop: Optional[Union[str, list[str]]] = None stream: bool = False @@ -128,7 +128,7 @@ class ChatCompletionResponse(BaseModel): object: Literal["chat.completion"] = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[ChatCompletionResponseChoice] + choices: list[ChatCompletionResponseChoice] usage: ChatCompletionResponseUsage @@ -137,12 +137,12 @@ class ChatCompletionStreamResponse(BaseModel): object: Literal["chat.completion.chunk"] = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[ChatCompletionStreamResponseChoice] + choices: list[ChatCompletionStreamResponseChoice] class ScoreEvaluationRequest(BaseModel): model: str - messages: List[str] + messages: list[str] max_length: Optional[int] = None @@ -150,4 +150,4 @@ class ScoreEvaluationResponse(BaseModel): id: str object: Literal["score.evaluation"] = "score.evaluation" model: str - scores: List[float] + scores: list[float] diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py index 1ebc0437..3b9bf5f4 100644 --- a/src/llamafactory/chat/base_engine.py +++ b/src/llamafactory/chat/base_engine.py @@ -13,8 +13,9 @@ # limitations under the License. from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union if TYPE_CHECKING: @@ -36,8 +37,7 @@ class Response: class BaseEngine(ABC): - r""" - Base class for inference engine of chat models. + r"""Base class for inference engine of chat models. Must implements async methods: chat(), stream_chat() and get_scores(). """ @@ -47,7 +47,7 @@ class BaseEngine(ABC): tokenizer: "PreTrainedTokenizer" can_generate: bool template: "Template" - generating_args: Dict[str, Any] + generating_args: dict[str, Any] @abstractmethod def __init__( @@ -57,31 +57,27 @@ class BaseEngine(ABC): finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", ) -> None: - r""" - Initializes an inference engine. - """ + r"""Initialize an inference engine.""" ... @abstractmethod async def chat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, videos: Optional[Sequence["VideoInput"]] = None, audios: Optional[Sequence["AudioInput"]] = None, **input_kwargs, - ) -> List["Response"]: - r""" - Gets a list of responses of the chat model. - """ + ) -> list["Response"]: + r"""Get a list of responses of the chat model.""" ... @abstractmethod async def stream_chat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, @@ -89,18 +85,14 @@ class BaseEngine(ABC): audios: Optional[Sequence["AudioInput"]] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: - r""" - Gets the response token-by-token of the chat model. - """ + r"""Get the response token-by-token of the chat model.""" ... @abstractmethod async def get_scores( self, - batch_input: List[str], + batch_input: list[str], **input_kwargs, - ) -> List[float]: - r""" - Gets a list of scores of the reward model. - """ + ) -> list[float]: + r"""Get a list of scores of the reward model.""" ... diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index ef273947..63651184 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -17,8 +17,9 @@ import asyncio import os +from collections.abc import AsyncGenerator, Generator, Sequence from threading import Thread -from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional from ..extras.constants import EngineName from ..extras.misc import torch_gc @@ -38,20 +39,19 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None: class ChatModel: - r""" - General class for chat models. Backed by huggingface or vllm engines. + r"""General class for chat models. Backed by huggingface or vllm engines. Supports both sync and async methods. Sync methods: chat(), stream_chat() and get_scores(). Async methods: achat(), astream_chat() and aget_scores(). """ - def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, args: Optional[dict[str, Any]] = None) -> None: model_args, data_args, finetuning_args, generating_args = get_infer_args(args) if model_args.infer_backend == EngineName.HF: - self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args) + self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args) elif model_args.infer_backend == EngineName.VLLM: - self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args) + self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args) else: raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}") @@ -61,17 +61,15 @@ class ChatModel: def chat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, videos: Optional[Sequence["VideoInput"]] = None, audios: Optional[Sequence["AudioInput"]] = None, **input_kwargs, - ) -> List["Response"]: - r""" - Gets a list of responses of the chat model. - """ + ) -> list["Response"]: + r"""Get a list of responses of the chat model.""" task = asyncio.run_coroutine_threadsafe( self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop ) @@ -79,22 +77,20 @@ class ChatModel: async def achat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, videos: Optional[Sequence["VideoInput"]] = None, audios: Optional[Sequence["AudioInput"]] = None, **input_kwargs, - ) -> List["Response"]: - r""" - Asynchronously gets a list of responses of the chat model. - """ + ) -> list["Response"]: + r"""Asynchronously get a list of responses of the chat model.""" return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs) def stream_chat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, @@ -102,9 +98,7 @@ class ChatModel: audios: Optional[Sequence["AudioInput"]] = None, **input_kwargs, ) -> Generator[str, None, None]: - r""" - Gets the response token-by-token of the chat model. - """ + r"""Get the response token-by-token of the chat model.""" generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs) while True: try: @@ -115,7 +109,7 @@ class ChatModel: async def astream_chat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, @@ -123,9 +117,7 @@ class ChatModel: audios: Optional[Sequence["AudioInput"]] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: - r""" - Asynchronously gets the response token-by-token of the chat model. - """ + r"""Asynchronously get the response token-by-token of the chat model.""" async for new_token in self.engine.stream_chat( messages, system, tools, images, videos, audios, **input_kwargs ): @@ -133,23 +125,19 @@ class ChatModel: def get_scores( self, - batch_input: List[str], + batch_input: list[str], **input_kwargs, - ) -> List[float]: - r""" - Gets a list of scores of the reward model. - """ + ) -> list[float]: + r"""Get a list of scores of the reward model.""" task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop) return task.result() async def aget_scores( self, - batch_input: List[str], + batch_input: list[str], **input_kwargs, - ) -> List[float]: - r""" - Asynchronously gets a list of scores of the reward model. - """ + ) -> list[float]: + r"""Asynchronously get a list of scores of the reward model.""" return await self.engine.get_scores(batch_input, **input_kwargs) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 4b829881..510cf0c6 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -15,8 +15,9 @@ import asyncio import concurrent.futures import os +from collections.abc import AsyncGenerator, Sequence from threading import Thread -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from transformers import GenerationConfig, TextIteratorStreamer @@ -76,15 +77,15 @@ class HuggingfaceEngine(BaseEngine): tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], template: "Template", - generating_args: Dict[str, Any], - messages: Sequence[Dict[str, str]], + generating_args: dict[str, Any], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, videos: Optional[Sequence["VideoInput"]] = None, audios: Optional[Sequence["AudioInput"]] = None, - input_kwargs: Optional[Dict[str, Any]] = {}, - ) -> Tuple[Dict[str, Any], int]: + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> tuple[dict[str, Any], int]: mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]} if images is not None: mm_input_dict.update({"images": images, "imglens": [len(images)]}) @@ -130,7 +131,7 @@ class HuggingfaceEngine(BaseEngine): skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) max_length: Optional[int] = input_kwargs.pop("max_length", None) max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) - stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) + stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None) if stop is not None: logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.") @@ -217,15 +218,15 @@ class HuggingfaceEngine(BaseEngine): tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], template: "Template", - generating_args: Dict[str, Any], - messages: Sequence[Dict[str, str]], + generating_args: dict[str, Any], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, videos: Optional[Sequence["VideoInput"]] = None, audios: Optional[Sequence["AudioInput"]] = None, - input_kwargs: Optional[Dict[str, Any]] = {}, - ) -> List["Response"]: + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> list["Response"]: gen_kwargs, prompt_length = HuggingfaceEngine._process_args( model, tokenizer, @@ -272,14 +273,14 @@ class HuggingfaceEngine(BaseEngine): tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], template: "Template", - generating_args: Dict[str, Any], - messages: Sequence[Dict[str, str]], + generating_args: dict[str, Any], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, videos: Optional[Sequence["VideoInput"]] = None, audios: Optional[Sequence["AudioInput"]] = None, - input_kwargs: Optional[Dict[str, Any]] = {}, + input_kwargs: Optional[dict[str, Any]] = {}, ) -> Callable[[], str]: gen_kwargs, _ = HuggingfaceEngine._process_args( model, @@ -317,12 +318,12 @@ class HuggingfaceEngine(BaseEngine): def _get_scores( model: "PreTrainedModelWrapper", tokenizer: "PreTrainedTokenizer", - batch_input: List[str], - input_kwargs: Optional[Dict[str, Any]] = {}, - ) -> List[float]: + batch_input: list[str], + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> list[float]: max_length: Optional[int] = input_kwargs.pop("max_length", None) device = getattr(model.pretrained_model, "device", "cuda") - inputs: Dict[str, "torch.Tensor"] = tokenizer( + inputs: dict[str, torch.Tensor] = tokenizer( batch_input, padding=True, truncation=True, @@ -330,21 +331,21 @@ class HuggingfaceEngine(BaseEngine): return_tensors="pt", add_special_tokens=False, ).to(device) - values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1] + values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1] scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1)) return scores @override async def chat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, videos: Optional[Sequence["VideoInput"]] = None, audios: Optional[Sequence["AudioInput"]] = None, **input_kwargs, - ) -> List["Response"]: + ) -> list["Response"]: if not self.can_generate: raise ValueError("The current model does not support `chat`.") @@ -370,7 +371,7 @@ class HuggingfaceEngine(BaseEngine): @override async def stream_chat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, @@ -408,9 +409,9 @@ class HuggingfaceEngine(BaseEngine): @override async def get_scores( self, - batch_input: List[str], + batch_input: list[str], **input_kwargs, - ) -> List[float]: + ) -> list[float]: if self.can_generate: raise ValueError("Cannot get scores using an auto-regressive model.") diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 0acfc370..ab478728 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -13,7 +13,8 @@ # limitations under the License. import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union +from collections.abc import AsyncGenerator, AsyncIterator, Sequence +from typing import TYPE_CHECKING, Any, Optional, Union from typing_extensions import override @@ -53,7 +54,7 @@ class VllmEngine(BaseEngine): self.model_args = model_args config = load_config(model_args) # may download model from ms hub if getattr(config, "quantization_config", None): # gptq models should use float16 - quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) + quantization_config: dict[str, Any] = getattr(config, "quantization_config", None) quant_method = quantization_config.get("quant_method", "") if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto": model_args.infer_dtype = "float16" @@ -101,7 +102,7 @@ class VllmEngine(BaseEngine): async def _generate( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, @@ -143,7 +144,7 @@ class VllmEngine(BaseEngine): skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) max_length: Optional[int] = input_kwargs.pop("max_length", None) max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) - stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) + stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None) if length_penalty is not None: logger.warning_rank0("Length penalty is not supported by the vllm engine yet.") @@ -201,14 +202,14 @@ class VllmEngine(BaseEngine): @override async def chat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, videos: Optional[Sequence["VideoInput"]] = None, audios: Optional[Sequence["AudioInput"]] = None, **input_kwargs, - ) -> List["Response"]: + ) -> list["Response"]: final_output = None generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) async for request_output in generator: @@ -230,7 +231,7 @@ class VllmEngine(BaseEngine): @override async def stream_chat( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, images: Optional[Sequence["ImageInput"]] = None, @@ -248,7 +249,7 @@ class VllmEngine(BaseEngine): @override async def get_scores( self, - batch_input: List[str], + batch_input: list[str], **input_kwargs, - ) -> List[float]: + ) -> list[float]: raise NotImplementedError("vLLM engine does not support get_scores.") diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index 247d8cf0..11c8c9fc 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer __all__ = [ + "TEMPLATES", "KTODataCollatorWithPadding", "MultiModalDataCollatorForSeq2Seq", "PairwiseDataCollatorWithPadding", - "SFTDataCollatorWith4DAttentionMask", "Role", - "split_dataset", - "get_dataset", - "TEMPLATES", + "SFTDataCollatorWith4DAttentionMask", "Template", + "get_dataset", "get_template_and_fix_tokenizer", + "split_dataset", ] diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 68d2978c..be9d9eb0 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -15,8 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence +from typing import TYPE_CHECKING, Any, Literal, Optional import numpy as np import torch @@ -38,9 +39,10 @@ if TYPE_CHECKING: def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": - r""" - Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), - while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking. + r"""Expand 2d attention mask to 4d attention mask. + + Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), + handle packed sequences and transforms the mask to lower triangular form to prevent future peeking. e.g. ```python @@ -78,8 +80,7 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype @dataclass class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): - r""" - Data collator that supports VLMs. + r"""Data collator that supports VLMs. Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios. """ @@ -91,7 +92,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if self.template is None: raise ValueError("Template is required for MultiModalDataCollator.") - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]: batch_images, batch_videos, batch_audios = [], [], [] batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], [] for feature in features: @@ -166,7 +167,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): for i, feature in enumerate(features): feature["token_type_ids"] = token_type_ids[i] - features: Dict[str, "torch.Tensor"] = super().__call__(features) + features: dict[str, torch.Tensor] = super().__call__(features) if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope rope_index_kwargs = { @@ -198,15 +199,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): @dataclass class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): - r""" - Data collator for 4d attention mask. - """ + r"""Data collator for 4d attention mask.""" block_diag_attn: bool = False attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" compute_dtype: "torch.dtype" = torch.float32 - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]: features = super().__call__(features) if self.block_diag_attn and self.attn_implementation != "flash_attention_2": features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) @@ -220,13 +219,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): @dataclass class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): - r""" - Data collator for pairwise data. - """ + r"""Data collator for pairwise data.""" - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: - r""" - Pads batched data to the longest sequence in the batch. + def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + r"""Pad batched data to the longest sequence in the batch. We generate 2 * n examples where the first n examples represent chosen examples and the last n examples represent rejected examples. @@ -249,11 +245,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): @dataclass class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): - r""" - Data collator for KTO data. - """ + r"""Data collator for KTO data.""" - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]: target_features = [] kl_features = [] kto_tags = [] diff --git a/src/llamafactory/data/converter.py b/src/llamafactory/data/converter.py index ec456cd1..8449d7c5 100644 --- a/src/llamafactory/data/converter.py +++ b/src/llamafactory/data/converter.py @@ -14,8 +14,9 @@ import os from abc import abstractmethod +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Union from ..extras import logging from .data_utils import Role @@ -36,10 +37,8 @@ class DatasetConverter: dataset_attr: "DatasetAttr" data_args: "DataArguments" - def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]: - r""" - Optionally concatenates media path to media dir when loading from local disk. - """ + def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]: + r"""Optionally concatenate media path to media dir when loading from local disk.""" if not isinstance(medias, list): medias = [medias] if medias is not None else [] elif len(medias) == 0: @@ -57,16 +56,14 @@ class DatasetConverter: return medias @abstractmethod - def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: - r""" - Converts a single example in the dataset to the standard format. - """ + def __call__(self, example: dict[str, Any]) -> dict[str, Any]: + r"""Convert a single example in the dataset to the standard format.""" ... @dataclass class AlpacaDatasetConverter(DatasetConverter): - def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: + def __call__(self, example: dict[str, Any]) -> dict[str, Any]: prompt = [] if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list): for old_prompt, old_response in example[self.dataset_attr.history]: @@ -116,7 +113,7 @@ class AlpacaDatasetConverter(DatasetConverter): @dataclass class SharegptDatasetConverter(DatasetConverter): - def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: + def __call__(self, example: dict[str, Any]) -> dict[str, Any]: tag_mapping = { self.dataset_attr.user_tag: Role.USER.value, self.dataset_attr.assistant_tag: Role.ASSISTANT.value, @@ -216,10 +213,8 @@ DATASET_CONVERTERS = { } -def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None: - r""" - Register a new dataset converter. - """ +def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None: + r"""Register a new dataset converter.""" if name in DATASET_CONVERTERS: raise ValueError(f"Dataset converter {name} already exists.") @@ -227,9 +222,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter": - r""" - Gets a dataset converter. - """ + r"""Get a dataset converter.""" if name not in DATASET_CONVERTERS: raise ValueError(f"Dataset converter {name} not found.") @@ -242,17 +235,17 @@ def align_dataset( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", ) -> Union["Dataset", "IterableDataset"]: - r""" - Aligned dataset: - _prompt: [{"role": "user", "content": "..."}] * (2T - 1) - _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) - _system: "..." - _tools: "...", - _images: [], - _videos: [], - _audios: [], - """ + r"""Align the dataset to a specific format. + Aligned dataset: + _prompt: [{"role": "user", "content": "..."}] * (2T - 1) + _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) + _system: "..." + _tools: "..." + _images: [] + _videos: [] + _audios: [] + """ column_names = list(next(iter(dataset)).keys()) kwargs = {} if not data_args.streaming: diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py index fd050f91..13a5b4cb 100644 --- a/src/llamafactory/data/data_utils.py +++ b/src/llamafactory/data/data_utils.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from enum import Enum, unique -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union +from typing import TYPE_CHECKING, Optional, TypedDict, Union from datasets import DatasetDict, concatenate_datasets, interleave_datasets @@ -29,7 +30,7 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] +SLOTS = Sequence[Union[str, set[str], dict[str, str]]] @unique @@ -43,15 +44,13 @@ class Role(str, Enum): class DatasetModule(TypedDict): train_dataset: Optional[Union["Dataset", "IterableDataset"]] - eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]] + eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]] def merge_dataset( - all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int + all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int ) -> Union["Dataset", "IterableDataset"]: - r""" - Merges multiple datasets to a unified dataset. - """ + r"""Merge multiple datasets to a unified dataset.""" if len(all_datasets) == 1: return all_datasets[0] @@ -78,14 +77,13 @@ def merge_dataset( def split_dataset( dataset: Optional[Union["Dataset", "IterableDataset"]], - eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]], + eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]], data_args: "DataArguments", seed: int, ) -> "DatasetDict": - r""" - Splits the dataset and returns a dataset dict containing train set and validation set. + r"""Split the dataset and returns a dataset dict containing train set and validation set. - Supports both map dataset and iterable dataset. + Support both map dataset and iterable dataset. """ if eval_dataset is not None and data_args.val_size > 1e-6: raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") @@ -120,10 +118,8 @@ def split_dataset( def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule": - r""" - Converts dataset or dataset dict to dataset module. - """ - dataset_module: "DatasetModule" = {} + r"""Convert dataset or dataset dict to dataset module.""" + dataset_module: DatasetModule = {} if isinstance(dataset, DatasetDict): # dataset dict if "train" in dataset: dataset_module["train_dataset"] = dataset["train"] diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 754fc54b..0a101ac6 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -16,7 +16,7 @@ import json import re from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import Optional, Union from typing_extensions import override @@ -31,14 +31,11 @@ class Formatter(ABC): @abstractmethod def apply(self, **kwargs) -> SLOTS: - r""" - Forms a list of slots according to the inputs to encode. - """ + r"""Forms a list of slots according to the inputs to encode.""" ... - def extract(self, content: str) -> Union[str, List["FunctionCall"]]: - r""" - Extract a list of tuples from the response message if using tools. + def extract(self, content: str) -> Union[str, list["FunctionCall"]]: + r"""Extract a list of tuples from the response message if using tools. Each tuple consists of function name and function arguments. """ @@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter): if thought: content = content.replace(thought.group(0), "") - functions: List["FunctionCall"] = [] + functions: list[FunctionCall] = [] try: tool_calls = json.loads(content) if not isinstance(tool_calls, list): # parallel function call @@ -141,5 +138,5 @@ class ToolFormatter(Formatter): raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string @override - def extract(self, content: str) -> Union[str, List["FunctionCall"]]: + def extract(self, content: str) -> Union[str, list["FunctionCall"]]: return self.tool_utils.tool_extractor(content) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 2055c7ff..5f243329 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -13,7 +13,8 @@ # limitations under the License. import os -from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union +from collections.abc import Sequence +from typing import TYPE_CHECKING, Literal, Optional, Union import numpy as np from datasets import load_dataset, load_from_disk @@ -54,9 +55,7 @@ def _load_single_dataset( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", ) -> Union["Dataset", "IterableDataset"]: - r""" - Loads a single dataset and aligns it to the standard format. - """ + r"""Load a single dataset and aligns it to the standard format.""" logger.info_rank0(f"Loading dataset {dataset_attr}...") data_path, data_name, data_dir, data_files = None, None, None, None if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]: @@ -164,10 +163,8 @@ def _get_merged_dataset( training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], merge: bool = True, -) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]: - r""" - Returns the merged datasets in the standard format. - """ +) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]: + r"""Return the merged datasets in the standard format.""" if dataset_names is None: return None @@ -192,9 +189,7 @@ def _get_dataset_processor( processor: Optional["ProcessorMixin"], do_generate: bool = False, ) -> "DatasetProcessor": - r""" - Returns the corresponding dataset processor. - """ + r"""Return the corresponding dataset processor.""" if stage == "pt": dataset_processor_class = PretrainDatasetProcessor elif stage == "sft" and not do_generate: @@ -236,9 +231,7 @@ def _get_preprocessed_dataset( processor: Optional["ProcessorMixin"] = None, is_eval: bool = False, ) -> Optional[Union["Dataset", "IterableDataset"]]: - r""" - Preprocesses the dataset, including format checking and tokenization. - """ + r"""Preprocesses the dataset, including format checking and tokenization.""" if dataset is None: return None @@ -284,9 +277,7 @@ def get_dataset( tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"] = None, ) -> "DatasetModule": - r""" - Gets the train dataset and optionally gets the evaluation dataset. - """ + r"""Get the train dataset and optionally gets the evaluation dataset.""" # Load tokenized dataset if path exists if data_args.tokenized_path is not None: if has_tokenized_data(data_args.tokenized_path): diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 0a63800a..422b10aa 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1,10 +1,11 @@ import inspect import math import re +from collections.abc import Sequence from copy import deepcopy from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union +from typing import TYPE_CHECKING, Optional, TypedDict, Union import numpy as np import torch @@ -58,12 +59,12 @@ if TYPE_CHECKING: def _get_paligemma_token_type_ids( imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin" -) -> List[List[int]]: - r""" - Gets paligemma token type ids for computing loss. +) -> list[list[int]]: + r"""Get paligemma token type ids for computing loss. Returns: batch_token_type_ids: shape (batch_size, sequence_length) + """ batch_token_type_ids = [] for imglen, seqlen in zip(imglens, seqlens): @@ -87,11 +88,9 @@ class MMPluginMixin: videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], ) -> None: - r""" - Validates if this model accepts the input modalities. - """ - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None) - feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None) + r"""Validate if this model accepts the input modalities.""" + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) if len(images) != 0 and self.image_token is None: raise ValueError( "This model does not support image input. Please check whether the correct `template` is used." @@ -119,9 +118,7 @@ class MMPluginMixin: def _preprocess_image( self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs ) -> "ImageObject": - r""" - Pre-processes a single image. - """ + r"""Pre-process a single image.""" if (image.width * image.height) > image_max_pixels: resize_factor = math.sqrt(image_max_pixels / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) @@ -139,10 +136,8 @@ class MMPluginMixin: def _get_video_sample_indices( self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs - ) -> List[int]: - r""" - Computes video sample indices according to fps. - """ + ) -> list[int]: + r"""Compute video sample indices according to fps.""" total_frames = video_stream.frames if total_frames == 0: # infinite video return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32) @@ -151,10 +146,8 @@ class MMPluginMixin: sample_frames = min(total_frames, video_maxlen, sample_frames) return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) - def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]: - r""" - Regularizes images to avoid error. Including reading and pre-processing. - """ + def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]: + r"""Regularize images to avoid error. Including reading and pre-processing.""" results = [] for image in images: if isinstance(image, str): @@ -174,16 +167,14 @@ class MMPluginMixin: return results - def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]: - r""" - Regularizes videos to avoid error. Including reading, resizing and converting. - """ + def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]: + r"""Regularizes videos to avoid error. Including reading, resizing and converting.""" results = [] for video in videos: container = av.open(video, "r") video_stream = next(stream for stream in container.streams if stream.type == "video") sample_indices = self._get_video_sample_indices(video_stream, **kwargs) - frames: List["ImageObject"] = [] + frames: list[ImageObject] = [] container.seek(0) for frame_idx, frame in enumerate(container.decode(video_stream)): if frame_idx in sample_indices: @@ -194,10 +185,8 @@ class MMPluginMixin: return results - def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]: - r""" - Regularizes audios to avoid error. Including reading and resampling. - """ + def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]: + r"""Regularizes audios to avoid error. Including reading and resampling.""" results = [] for audio in audios: if isinstance(audio, str): @@ -216,9 +205,8 @@ class MMPluginMixin: videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: "ProcessorMixin", - ) -> Dict[str, "torch.Tensor"]: - r""" - Processes visual inputs. + ) -> dict[str, "torch.Tensor"]: + r"""Process visual inputs. Returns: (llava and paligemma) pixel_values: tensor with shape (B, C, H, W) @@ -229,9 +217,9 @@ class MMPluginMixin: It holds num_patches == torch.prod(image_grid_thw) """ - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None) - video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor) - feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None) + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor) + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) mm_inputs = {} if len(images) != 0: @@ -278,31 +266,27 @@ class MMPluginMixin: class BasePlugin(MMPluginMixin): def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: - r""" - Pre-processes input messages before tokenization for VLMs. - """ + ) -> list[dict[str, str]]: + r"""Pre-processes input messages before tokenization for VLMs.""" self._validate_input(processor, images, videos, audios) return messages def process_token_ids( self, - input_ids: List[int], - labels: Optional[List[int]], + input_ids: list[int], + labels: Optional[list[int]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], - ) -> Tuple[List[int], Optional[List[int]]]: - r""" - Pre-processes token ids after tokenization for VLMs. - """ + ) -> tuple[list[int], Optional[list[int]]]: + r"""Pre-processes token ids after tokenization for VLMs.""" self._validate_input(processor, images, videos, audios) return input_ids, labels @@ -314,20 +298,21 @@ class BasePlugin(MMPluginMixin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - r""" - Builds batched multimodal inputs for VLMs. + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + r"""Build batched multimodal inputs for VLMs. Arguments: images: a list of image inputs, shape (num_images,) videos: a list of video inputs, shape (num_videos,) + audios: a list of audio inputs, shape (num_audios,) imglens: number of images in each sample, shape (batch_size,) vidlens: number of videos in each sample, shape (batch_size,) audlens: number of audios in each sample, shape (batch_size,) batch_ids: token ids of input samples, shape (batch_size, seq_len) processor: a processor for pre-processing images and videos + """ self._validate_input(processor, images, videos, audios) return {} @@ -338,12 +323,12 @@ class LlavaPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens = 0 image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1 @@ -370,9 +355,9 @@ class LlavaPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) @@ -382,12 +367,12 @@ class LlavaNextPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens = 0 messages = deepcopy(messages) @@ -426,9 +411,9 @@ class LlavaNextPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) @@ -438,12 +423,12 @@ class LlavaNextVideoPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens, num_video_tokens = 0, 0 messages = deepcopy(messages) @@ -502,9 +487,9 @@ class LlavaNextVideoPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) @@ -514,16 +499,16 @@ class MiniCPMVPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 messages = deepcopy(messages) - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image_processor: BaseImageProcessor = getattr(processor, "image_processor") mm_inputs = {} audio_inputs = {} if len(images) != 0 and len(videos) != 0: @@ -619,9 +604,9 @@ class MiniCPMVPlugin(BasePlugin): audios: Sequence["AudioInput"], processor: "ProcessorMixin", **kwargs, - ) -> Dict[str, "torch.Tensor"]: - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None) + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) mm_inputs = {} if len(images) != 0: images = self._regularize_images( @@ -691,9 +676,9 @@ class MiniCPMVPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) # image bound image_bounds_list = [] @@ -756,12 +741,12 @@ class MllamaPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens = 0 messages = deepcopy(messages) @@ -782,10 +767,9 @@ class MllamaPlugin(BasePlugin): videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: "ProcessorMixin", - imglens: List[int], - ) -> Dict[str, "torch.Tensor"]: - r""" - Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]]. + imglens: list[int], + ) -> dict[str, "torch.Tensor"]: + r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]]. Returns: pixel_values: tensor with shape @@ -794,8 +778,9 @@ class MllamaPlugin(BasePlugin): aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1). aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4). num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). + """ - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image_processor: BaseImageProcessor = getattr(processor, "image_processor") mm_inputs = {} if len(images) > 0: images = self._regularize_images( @@ -821,9 +806,9 @@ class MllamaPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens) if mm_inputs: @@ -850,12 +835,12 @@ class PaliGemmaPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens = 0 messages = deepcopy(messages) @@ -875,14 +860,14 @@ class PaliGemmaPlugin(BasePlugin): @override def process_token_ids( self, - input_ids: List[int], - labels: Optional[List[int]], + input_ids: list[int], + labels: Optional[list[int]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], - ) -> Tuple[List[int], Optional[List[int]]]: + ) -> tuple[list[int], Optional[list[int]]]: self._validate_input(processor, images, videos, audios) num_images = len(images) image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token @@ -902,9 +887,9 @@ class PaliGemmaPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) seqlens = [len(input_ids) for input_ids in batch_ids] mm_inputs = self._get_mm_inputs(images, videos, audios, processor) @@ -917,12 +902,12 @@ class PixtralPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) patch_size = getattr(processor, "patch_size") image_token = getattr(processor, "image_token") @@ -968,9 +953,9 @@ class PixtralPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) mm_inputs = self._get_mm_inputs(images, videos, audios, processor) mm_inputs.pop("image_sizes", None) @@ -982,12 +967,12 @@ class Qwen2AudioPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) bos_token: str = getattr(processor, "audio_bos_token") eos_token: str = getattr(processor, "audio_eos_token") @@ -1028,9 +1013,9 @@ class Qwen2AudioPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) @@ -1057,13 +1042,13 @@ class Qwen2VLPlugin(BasePlugin): @override def _regularize_videos( self, videos: Sequence["VideoInput"], **kwargs - ) -> Tuple[List[List["ImageObject"]], List[float]]: + ) -> tuple[list[list["ImageObject"]], list[float]]: results, fps_per_video = [], [] for video in videos: container = av.open(video, "r") video_stream = next(stream for stream in container.streams if stream.type == "video") sample_indices = self._get_video_sample_indices(video_stream, **kwargs) - frames: List["ImageObject"] = [] + frames: list[ImageObject] = [] container.seek(0) for frame_idx, frame in enumerate(container.decode(video_stream)): if frame_idx in sample_indices: @@ -1088,8 +1073,8 @@ class Qwen2VLPlugin(BasePlugin): videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: "ProcessorMixin", - ) -> Dict[str, "torch.Tensor"]: - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None) + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) mm_inputs = {} if len(images) != 0: images = self._regularize_images( @@ -1115,16 +1100,16 @@ class Qwen2VLPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens, num_video_tokens = 0, 0 messages = deepcopy(messages) - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image_processor: BaseImageProcessor = getattr(processor, "image_processor") merge_length: int = getattr(image_processor, "merge_size") ** 2 if self.expand_mm_tokens: @@ -1176,13 +1161,13 @@ class Qwen2VLPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) mm_inputs = self._get_mm_inputs(images, videos, audios, processor) fps_per_video = mm_inputs.pop("fps_per_video", []) - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image_processor: BaseImageProcessor = getattr(processor, "image_processor") if "second_per_grid_ts" in processor.model_input_names and fps_per_video: mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video] @@ -1194,12 +1179,12 @@ class VideoLlavaPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens, num_video_tokens = 0, 0 messages = deepcopy(messages) @@ -1255,9 +1240,9 @@ class VideoLlavaPlugin(BasePlugin): imglens: Sequence[int], vidlens: Sequence[int], audlens: Sequence[int], - batch_ids: Sequence[List[int]], + batch_ids: Sequence[list[int]], processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) @@ -1277,10 +1262,8 @@ PLUGINS = { } -def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None: - r""" - Registers a multimodal plugin. - """ +def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None: + r"""Register a multimodal plugin.""" if name in PLUGINS: raise ValueError(f"Multimodal plugin {name} already exists.") @@ -1293,9 +1276,7 @@ def get_mm_plugin( video_token: Optional[str] = None, audio_token: Optional[str] = None, ) -> "BasePlugin": - r""" - Gets plugin for multimodal inputs. - """ + r"""Get plugin for multimodal inputs.""" if name not in PLUGINS: raise ValueError(f"Multimodal plugin `{name}` not found.") diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index ac6bc932..4e1c7aff 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -14,8 +14,9 @@ import json import os +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Dict, List, Literal, Optional, Sequence +from typing import Any, Literal, Optional from transformers.utils import cached_file @@ -25,9 +26,7 @@ from ..extras.misc import use_modelscope, use_openmind @dataclass class DatasetAttr: - r""" - Dataset attributes. - """ + r"""Dataset attributes.""" # basic configs load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"] @@ -68,10 +67,10 @@ class DatasetAttr: def __repr__(self) -> str: return self.dataset_name - def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: + def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None: setattr(self, key, obj.get(key, default)) - def join(self, attr: Dict[str, Any]) -> None: + def join(self, attr: dict[str, Any]) -> None: self.set_attr("formatting", attr, default="alpaca") self.set_attr("ranking", attr, default=False) self.set_attr("subset", attr) @@ -92,10 +91,8 @@ class DatasetAttr: self.set_attr(tag, attr["tags"]) -def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: - r""" - Gets the attributes of the datasets. - """ +def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]: + r"""Get the attributes of the datasets.""" if dataset_names is None: dataset_names = [] @@ -116,7 +113,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - dataset_info = None - dataset_list: List["DatasetAttr"] = [] + dataset_list: list[DatasetAttr] = [] for name in dataset_names: if dataset_info is None: # dataset_dir is ONLINE if use_modelscope(): diff --git a/src/llamafactory/data/processor/__init__.py b/src/llamafactory/data/processor/__init__.py index fa82a88e..a827d005 100644 --- a/src/llamafactory/data/processor/__init__.py +++ b/src/llamafactory/data/processor/__init__.py @@ -9,9 +9,9 @@ from .unsupervised import UnsupervisedDatasetProcessor __all__ = [ "DatasetProcessor", "FeedbackDatasetProcessor", + "PackedSupervisedDatasetProcessor", "PairwiseDatasetProcessor", "PretrainDatasetProcessor", - "PackedSupervisedDatasetProcessor", "SupervisedDatasetProcessor", "UnsupervisedDatasetProcessor", ] diff --git a/src/llamafactory/data/processor/feedback.py b/src/llamafactory/data/processor/feedback.py index fb3c4803..89233e10 100644 --- a/src/llamafactory/data/processor/feedback.py +++ b/src/llamafactory/data/processor/feedback.py @@ -13,7 +13,8 @@ # limitations under the License. from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Optional from ...extras import logging from ...extras.constants import IGNORE_INDEX @@ -30,15 +31,15 @@ logger = logging.get_logger(__name__) class FeedbackDatasetProcessor(DatasetProcessor): def _encode_data_example( self, - prompt: Sequence[Dict[str, str]], - response: Sequence[Dict[str, str]], - kl_response: Sequence[Dict[str, str]], + prompt: Sequence[dict[str, str]], + response: Sequence[dict[str, str]], + kl_response: Sequence[dict[str, str]], system: Optional[str], tools: Optional[str], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - ) -> Tuple[List[int], List[int], List[int], List[int], bool]: + ) -> tuple[list[int], list[int], list[int], list[int], bool]: if response[0]["content"]: # desired example kto_tag = True messages = prompt + [response[0]] @@ -82,7 +83,7 @@ class FeedbackDatasetProcessor(DatasetProcessor): kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids return input_ids, labels, kl_input_ids, kl_labels, kto_tag - def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs kl_response = examples["_response"][::-1] model_inputs = defaultdict(list) @@ -121,7 +122,7 @@ class FeedbackDatasetProcessor(DatasetProcessor): return model_inputs - def print_data_example(self, example: Dict[str, List[int]]) -> None: + def print_data_example(self, example: dict[str, list[int]]) -> None: valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) diff --git a/src/llamafactory/data/processor/pairwise.py b/src/llamafactory/data/processor/pairwise.py index f30ebbf8..e0a81f0b 100644 --- a/src/llamafactory/data/processor/pairwise.py +++ b/src/llamafactory/data/processor/pairwise.py @@ -13,7 +13,8 @@ # limitations under the License. from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Optional from ...extras import logging from ...extras.constants import IGNORE_INDEX @@ -30,14 +31,14 @@ logger = logging.get_logger(__name__) class PairwiseDatasetProcessor(DatasetProcessor): def _encode_data_example( self, - prompt: Sequence[Dict[str, str]], - response: Sequence[Dict[str, str]], + prompt: Sequence[dict[str, str]], + response: Sequence[dict[str, str]], system: Optional[str], tools: Optional[str], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - ) -> Tuple[List[int], List[int], List[int], List[int]]: + ) -> tuple[list[int], list[int], list[int], list[int]]: chosen_messages = self.template.mm_plugin.process_messages( prompt + [response[0]], images, videos, audios, self.processor ) @@ -68,7 +69,7 @@ class PairwiseDatasetProcessor(DatasetProcessor): rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels - def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: # build input pairs with format ` X`, `Y1 ` and `Y2 ` model_inputs = defaultdict(list) for i in range(len(examples["_prompt"])): @@ -99,7 +100,7 @@ class PairwiseDatasetProcessor(DatasetProcessor): return model_inputs - def print_data_example(self, example: Dict[str, List[int]]) -> None: + def print_data_example(self, example: dict[str, list[int]]) -> None: valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"])) valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) diff --git a/src/llamafactory/data/processor/pretrain.py b/src/llamafactory/data/processor/pretrain.py index 87e35ad1..385b3914 100644 --- a/src/llamafactory/data/processor/pretrain.py +++ b/src/llamafactory/data/processor/pretrain.py @@ -17,14 +17,14 @@ from dataclasses import dataclass from itertools import chain -from typing import Any, Dict, List +from typing import Any from .processor_utils import DatasetProcessor @dataclass class PretrainDatasetProcessor(DatasetProcessor): - def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: # build grouped texts with format `X1 X2 X3 ...` if packing is enabled eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] @@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor): return result - def print_data_example(self, example: Dict[str, List[int]]) -> None: + def print_data_example(self, example: dict[str, list[int]]) -> None: print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) diff --git a/src/llamafactory/data/processor/processor_utils.py b/src/llamafactory/data/processor/processor_utils.py index 9e5cb086..528ff52b 100644 --- a/src/llamafactory/data/processor/processor_utils.py +++ b/src/llamafactory/data/processor/processor_utils.py @@ -14,8 +14,9 @@ import bisect from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: @@ -27,9 +28,7 @@ if TYPE_CHECKING: @dataclass class DatasetProcessor(ABC): - r""" - A class for data processors. - """ + r"""A class for data processors.""" template: "Template" tokenizer: "PreTrainedTokenizer" @@ -37,32 +36,24 @@ class DatasetProcessor(ABC): data_args: "DataArguments" @abstractmethod - def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - r""" - Builds model inputs from the examples. - """ + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: + r"""Build model inputs from the examples.""" ... @abstractmethod - def print_data_example(self, example: Dict[str, List[int]]) -> None: - r""" - Print a data example to stdout. - """ + def print_data_example(self, example: dict[str, list[int]]) -> None: + r"""Print a data example to stdout.""" ... def search_for_fit(numbers: Sequence[int], capacity: int) -> int: - r""" - Finds the index of largest number that fits into the knapsack with the given capacity. - """ + r"""Find the index of largest number that fits into the knapsack with the given capacity.""" index = bisect.bisect(numbers, capacity) return -1 if index == 0 else (index - 1) -def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: - r""" - An efficient greedy algorithm with binary search for the knapsack problem. - """ +def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]: + r"""Implement efficient greedy algorithm with binary search for the knapsack problem.""" numbers.sort() # sort numbers in ascending order for binary search knapsacks = [] @@ -83,10 +74,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: return knapsacks -def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: - r""" - Computes the real sequence length after truncation by the cutoff_len. - """ +def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]: + r"""Compute the real sequence length after truncation by the cutoff_len.""" if target_len * 2 < cutoff_len: # truncate source max_target_len = cutoff_len elif source_len * 2 < cutoff_len: # truncate target diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py index e83de97b..1e62e9a3 100644 --- a/src/llamafactory/data/processor/supervised.py +++ b/src/llamafactory/data/processor/supervised.py @@ -13,8 +13,9 @@ # limitations under the License. from collections import defaultdict +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Optional from ...extras import logging from ...extras.constants import IGNORE_INDEX @@ -32,14 +33,14 @@ logger = logging.get_logger(__name__) class SupervisedDatasetProcessor(DatasetProcessor): def _encode_data_example( self, - prompt: Sequence[Dict[str, str]], - response: Sequence[Dict[str, str]], + prompt: Sequence[dict[str, str]], + response: Sequence[dict[str, str]], system: Optional[str], tools: Optional[str], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - ) -> Tuple[List[int], List[int]]: + ) -> tuple[list[int], list[int]]: messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor) input_ids, labels = self.template.mm_plugin.process_token_ids( [], [], images, videos, audios, self.tokenizer, self.processor @@ -85,7 +86,7 @@ class SupervisedDatasetProcessor(DatasetProcessor): return input_ids, labels - def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: # build inputs with format ` X Y ` and labels with format ` ... Y ` # for multiturn examples, we only mask the prompt part in each prompt-response pair. model_inputs = defaultdict(list) @@ -114,7 +115,7 @@ class SupervisedDatasetProcessor(DatasetProcessor): return model_inputs - def print_data_example(self, example: Dict[str, List[int]]) -> None: + def print_data_example(self, example: dict[str, list[int]]) -> None: valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) @@ -124,7 +125,7 @@ class SupervisedDatasetProcessor(DatasetProcessor): @dataclass class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): - def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: # TODO: use `position_ids` to achieve packing # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... Y2 ` diff --git a/src/llamafactory/data/processor/unsupervised.py b/src/llamafactory/data/processor/unsupervised.py index 38a0b442..2ce628d9 100644 --- a/src/llamafactory/data/processor/unsupervised.py +++ b/src/llamafactory/data/processor/unsupervised.py @@ -13,7 +13,8 @@ # limitations under the License. from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Optional from ...extras import logging from ..data_utils import Role @@ -30,14 +31,14 @@ logger = logging.get_logger(__name__) class UnsupervisedDatasetProcessor(DatasetProcessor): def _encode_data_example( self, - prompt: Sequence[Dict[str, str]], - response: Sequence[Dict[str, str]], + prompt: Sequence[dict[str, str]], + response: Sequence[dict[str, str]], system: Optional[str], tools: Optional[str], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - ) -> Tuple[List[int], List[int]]: + ) -> tuple[list[int], list[int]]: if len(response) == 1: messages = prompt + response else: @@ -56,7 +57,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor): labels = labels[:target_len] return input_ids, labels - def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: # build inputs with format ` X` and labels with format `Y ` model_inputs = defaultdict(list) for i in range(len(examples["_prompt"])): @@ -84,7 +85,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor): return model_inputs - def print_data_example(self, example: Dict[str, List[int]]) -> None: + def print_data_example(self, example: dict[str, list[int]]) -> None: print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("label_ids:\n{}".format(example["labels"])) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index a789cf7a..50d0da24 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import TYPE_CHECKING, Optional, Union from typing_extensions import override @@ -46,8 +47,8 @@ class Template: format_tools: "Formatter" format_prefix: "Formatter" default_system: str - stop_words: List[str] - thought_words: Tuple[str, str] + stop_words: list[str] + thought_words: tuple[str, str] efficient_eos: bool replace_eos: bool replace_jinja_template: bool @@ -56,13 +57,11 @@ class Template: def encode_oneturn( self, tokenizer: "PreTrainedTokenizer", - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - ) -> Tuple[List[int], List[int]]: - r""" - Returns a single pair of token ids representing prompt and response respectively. - """ + ) -> tuple[list[int], list[int]]: + r"""Return a single pair of token ids representing prompt and response respectively.""" encoded_messages = self._encode(tokenizer, messages, system, tools) prompt_ids = [] for encoded_ids in encoded_messages[:-1]: @@ -74,36 +73,28 @@ class Template: def encode_multiturn( self, tokenizer: "PreTrainedTokenizer", - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - ) -> List[Tuple[List[int], List[int]]]: - r""" - Returns multiple pairs of token ids representing prompts and responses respectively. - """ + ) -> list[tuple[list[int], list[int]]]: + r"""Return multiple pairs of token ids representing prompts and responses respectively.""" encoded_messages = self._encode(tokenizer, messages, system, tools) return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] - def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]: - r""" - Extracts tool message. - """ + def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]: + r"""Extract tool message.""" return self.format_tools.extract(content) - def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]: - r""" - Returns stop token ids. - """ + def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]: + r"""Return stop token ids.""" stop_token_ids = {tokenizer.eos_token_id} for token in self.stop_words: stop_token_ids.add(tokenizer.convert_tokens_to_ids(token)) return list(stop_token_ids) - def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]: - r""" - Converts elements to token ids. - """ + def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]: + r"""Convert elements to token ids.""" token_ids = [] for elem in elements: if isinstance(elem, str): @@ -124,14 +115,14 @@ class Template: def _encode( self, tokenizer: "PreTrainedTokenizer", - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: Optional[str], tools: Optional[str], - ) -> List[List[int]]: - r""" - Encodes formatted inputs to pairs of token ids. + ) -> list[list[int]]: + r"""Encode formatted inputs to pairs of token ids. + Turn 0: prefix + system + query resp - Turn t: query resp + Turn t: query resp. """ system = system or self.default_system encoded_messages = [] @@ -161,9 +152,7 @@ class Template: @staticmethod def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None: - r""" - Adds or replaces eos token to the tokenizer. - """ + r"""Add or replace eos token to the tokenizer.""" is_added = tokenizer.eos_token_id is None num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) @@ -176,9 +165,7 @@ class Template: logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None: - r""" - Adds eos token and pad token to the tokenizer. - """ + r"""Add eos token and pad token to the tokenizer.""" stop_words = self.stop_words if self.replace_eos: if not stop_words: @@ -204,16 +191,12 @@ class Template: @staticmethod def _jinja_escape(content: str) -> str: - r""" - Escape single quotes in content. - """ + r"""Escape single quotes in content.""" return content.replace("'", r"\'") @staticmethod def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: - r""" - Converts slots to jinja template. - """ + r"""Convert slots to jinja template.""" slot_items = [] for slot in slots: if isinstance(slot, str): @@ -235,9 +218,7 @@ class Template: return " + ".join(slot_items) def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str: - r""" - Returns the jinja template. - """ + r"""Return the jinja template.""" prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer) system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message") user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) @@ -265,9 +246,7 @@ class Template: return jinja_template def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None: - r""" - Replaces the jinja template in the tokenizer. - """ + r"""Replace the jinja template in the tokenizer.""" if tokenizer.chat_template is None or self.replace_jinja_template: try: tokenizer.chat_template = self._get_jinja_template(tokenizer) @@ -278,9 +257,7 @@ class Template: def _convert_slots_to_ollama( slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content" ) -> str: - r""" - Converts slots to ollama template. - """ + r"""Convert slots to ollama template.""" slot_items = [] for slot in slots: if isinstance(slot, str): @@ -302,9 +279,7 @@ class Template: return "".join(slot_items) def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str: - r""" - Returns the ollama template. - """ + r"""Return the ollama template.""" prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer) system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System") user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content") @@ -316,8 +291,7 @@ class Template: ) def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str: - r""" - Returns the ollama modelfile. + r"""Return the ollama modelfile. TODO: support function calling. """ @@ -340,10 +314,10 @@ class Llama2Template(Template): def _encode( self, tokenizer: "PreTrainedTokenizer", - messages: Sequence[Dict[str, str]], + messages: Sequence[dict[str, str]], system: str, tools: str, - ) -> List[List[int]]: + ) -> list[list[int]]: system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): @@ -402,7 +376,7 @@ class Llama2Template(Template): return jinja_template -TEMPLATES: Dict[str, "Template"] = {} +TEMPLATES: dict[str, "Template"] = {} def register_template( @@ -416,15 +390,14 @@ def register_template( format_prefix: Optional["Formatter"] = None, default_system: str = "", stop_words: Optional[Sequence[str]] = None, - thought_words: Optional[Tuple[str, str]] = None, + thought_words: Optional[tuple[str, str]] = None, efficient_eos: bool = False, replace_eos: bool = False, replace_jinja_template: bool = False, mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), - template_class: Type["Template"] = Template, + template_class: type["Template"] = Template, ) -> None: - r""" - Registers a chat template. + r"""Register a chat template. To add the following chat template: ``` @@ -472,9 +445,7 @@ def register_template( def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": - r""" - Extracts a chat template from the tokenizer. - """ + r"""Extract a chat template from the tokenizer.""" def find_diff(short_str: str, long_str: str) -> str: i, j = 0, 0 @@ -532,9 +503,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template": - r""" - Gets chat template and fixes the tokenizer. - """ + r"""Get chat template and fixes the tokenizer.""" if data_args.template is None: if isinstance(tokenizer.chat_template, str): logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.") @@ -1149,7 +1118,8 @@ register_template( format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), default_system=( - "你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n" + "你是一个经过良好训练的AI助手,你的名字是Marco-o1." + "由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n" "当你回答问题时,你的思考应该在内完成,内输出你的结果。\n" "应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,内的输出需要遵循用户输入的语言。\n" ), diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index f3269413..8ad1466a 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -17,7 +17,7 @@ import re from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, List, NamedTuple, Tuple, Union +from typing import Any, NamedTuple, Union from typing_extensions import override @@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = ( @dataclass class ToolUtils(ABC): - """ - Base class for tool utilities. - """ + """Base class for tool utilities.""" @staticmethod @abstractmethod - def tool_formatter(tools: List[Dict[str, Any]]) -> str: - r""" - Generates the system message describing all the available tools. - """ + def tool_formatter(tools: list[dict[str, Any]]) -> str: + r"""Generate the system message describing all the available tools.""" ... @staticmethod @abstractmethod - def function_formatter(functions: List["FunctionCall"]) -> str: - r""" - Generates the assistant message including all the tool calls. - """ + def function_formatter(functions: list["FunctionCall"]) -> str: + r"""Generate the assistant message including all the tool calls.""" ... @staticmethod @abstractmethod - def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: - r""" - Extracts all the function calls from the assistant message. + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: + r"""Extract all the function calls from the assistant message. It should be an inverse function of `function_formatter`. """ @@ -92,13 +85,11 @@ class ToolUtils(ABC): class DefaultToolUtils(ToolUtils): - r""" - Default tool using template. - """ + r"""Default tool using template.""" @override @staticmethod - def tool_formatter(tools: List[Dict[str, Any]]) -> str: + def tool_formatter(tools: list[dict[str, Any]]) -> str: tool_text = "" tool_names = [] for tool in tools: @@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> str: + def function_formatter(functions: list["FunctionCall"]) -> str: function_text = "" for name, arguments in functions: function_text += f"Action: {name}\nAction Input: {arguments}\n" @@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils): @override @staticmethod - def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL) - action_match: List[Tuple[str, str]] = re.findall(regex, content) + action_match: list[tuple[str, str]] = re.findall(regex, content) if not action_match: return content @@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils): class GLM4ToolUtils(ToolUtils): - r""" - GLM-4 tool using template. - """ + r"""GLM-4 tool using template.""" @override @staticmethod - def tool_formatter(tools: List[Dict[str, Any]]) -> str: + def tool_formatter(tools: list[dict[str, Any]]) -> str: tool_text = "" for tool in tools: tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( @@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> str: + def function_formatter(functions: list["FunctionCall"]) -> str: if len(functions) > 1: raise ValueError("GLM-4 does not support parallel functions.") @@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils): @override @staticmethod - def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: if "\n" not in content: return content @@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils): class Llama3ToolUtils(ToolUtils): - r""" - Llama 3.x tool using template with `tools_in_user_message=False`. + r"""Llama 3.x tool using template with `tools_in_user_message=False`. Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling """ @override @staticmethod - def tool_formatter(tools: List[Dict[str, Any]]) -> str: + def tool_formatter(tools: list[dict[str, Any]]) -> str: date = datetime.now().strftime("%d %b %Y") tool_text = "" for tool in tools: @@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> str: + def function_formatter(functions: list["FunctionCall"]) -> str: if len(functions) > 1: raise ValueError("Llama-3 does not support parallel functions.") @@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils): @override @staticmethod - def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: try: tool = json.loads(content.strip()) except json.JSONDecodeError: @@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils): class MistralToolUtils(ToolUtils): - r""" - Mistral v0.3 tool using template. - """ + r"""Mistral v0.3 tool using template.""" @override @staticmethod - def tool_formatter(tools: List[Dict[str, Any]]) -> str: + def tool_formatter(tools: list[dict[str, Any]]) -> str: wrapped_tools = [] for tool in tools: wrapped_tools.append({"type": "function", "function": tool}) @@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> str: + def function_formatter(functions: list["FunctionCall"]) -> str: function_texts = [] for name, arguments in functions: function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}') @@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils): @override @staticmethod - def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: try: tools = json.loads(content.strip()) except json.JSONDecodeError: @@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils): class QwenToolUtils(ToolUtils): - r""" - Qwen 2.5 tool using template. - """ + r"""Qwen 2.5 tool using template.""" @override @staticmethod - def tool_formatter(tools: List[Dict[str, Any]]) -> str: + def tool_formatter(tools: list[dict[str, Any]]) -> str: tool_text = "" for tool in tools: wrapped_tool = {"type": "function", "function": tool} @@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> str: + def function_formatter(functions: list["FunctionCall"]) -> str: function_texts = [] for name, arguments in functions: function_texts.append( @@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils): @override @staticmethod - def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: regex = re.compile(r"(.+?)(?=\s*|\s*$)", re.DOTALL) - tool_match: List[str] = re.findall(regex, content) + tool_match: list[str] = re.findall(regex, content) if not tool_match: return content diff --git a/src/llamafactory/eval/evaluator.py b/src/llamafactory/eval/evaluator.py index 99758dd2..7729c59b 100644 --- a/src/llamafactory/eval/evaluator.py +++ b/src/llamafactory/eval/evaluator.py @@ -39,7 +39,7 @@ import json import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional import numpy as np import torch @@ -59,7 +59,7 @@ if TYPE_CHECKING: class Evaluator: - def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, args: Optional[dict[str, Any]] = None) -> None: self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args) self.tokenizer = load_tokenizer(self.model_args)["tokenizer"] self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 @@ -69,7 +69,7 @@ class Evaluator: self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES] @torch.inference_mode() - def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]: + def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]: logits = self.model(**batch_input).logits lengths = torch.sum(batch_input["attention_mask"], dim=-1) word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0) @@ -88,7 +88,7 @@ class Evaluator: ) with open(mapping, encoding="utf-8") as f: - categorys: Dict[str, Dict[str, str]] = json.load(f) + categorys: dict[str, dict[str, str]] = json.load(f) category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS} pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0) @@ -136,7 +136,7 @@ class Evaluator: pbar.close() self._save_results(category_corrects, results) - def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None: + def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None: score_info = "\n".join( [ f"{category_name:>15}: {100 * np.mean(category_correct):.2f}" diff --git a/src/llamafactory/eval/template.py b/src/llamafactory/eval/template.py index e1454097..83f70171 100644 --- a/src/llamafactory/eval/template.py +++ b/src/llamafactory/eval/template.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from dataclasses import dataclass -from typing import Dict, List, Sequence, Tuple from ..data import Role from ..extras.constants import CHOICES @@ -25,20 +25,19 @@ class EvalTemplate: choice: str answer: str - def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: - r""" + def _parse_example(self, example: dict[str, str]) -> tuple[str, str]: + r"""Parse eval example. + input: a dict with keys {"question", "A", "B", "C", "D", "answer"} - output: a tuple of (prompt, response) + output: a tuple of (prompt, response). """ candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] return "".join([example["question"]] + candidates + [self.answer]), example["answer"] def format_example( - self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str - ) -> List[Dict[str, str]]: - r""" - Converts dataset examples to messages. - """ + self, target_data: dict[str, str], support_set: Sequence[dict[str, str]], subject_name: str + ) -> list[dict[str, str]]: + r"""Convert dataset examples to messages.""" messages = [] for k in range(len(support_set)): prompt, response = self._parse_example(support_set[k]) @@ -52,7 +51,7 @@ class EvalTemplate: return messages -eval_templates: Dict[str, "EvalTemplate"] = {} +eval_templates: dict[str, "EvalTemplate"] = {} def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index dce6d83b..a3f222e9 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -15,7 +15,7 @@ import os from collections import OrderedDict, defaultdict from enum import Enum -from typing import Dict, Optional +from typing import Optional from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME @@ -122,7 +122,7 @@ class RopeScaling(str, Enum): def register_model_group( - models: Dict[str, Dict[DownloadSource, str]], + models: dict[str, dict[DownloadSource, str]], template: Optional[str] = None, multimodal: bool = False, ) -> None: diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py index 8f98b055..8fc030a8 100644 --- a/src/llamafactory/extras/logging.py +++ b/src/llamafactory/extras/logging.py @@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO class LoggerHandler(logging.Handler): - r""" - Redirects the logging output to the logging file for LLaMA Board. - """ + r"""Redirect the logging output to the logging file for LLaMA Board.""" def __init__(self, output_dir: str) -> None: super().__init__() @@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler): class _Logger(logging.Logger): - r""" - A logger that supports rank0 logging. - """ + r"""A logger that supports rank0 logging.""" def info_rank0(self, *args, **kwargs) -> None: self.info(*args, **kwargs) @@ -82,9 +78,7 @@ class _Logger(logging.Logger): def _get_default_logging_level() -> "logging._Level": - r""" - Returns the default logging level. - """ + r"""Return the default logging level.""" env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None) if env_level_str: if env_level_str.upper() in logging._nameToLevel: @@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger": def _configure_library_root_logger() -> None: - r""" - Configures root logger using a stdout stream handler with an explicit format. - """ + r"""Configure root logger using a stdout stream handler with an explicit format.""" global _default_handler with _thread_lock: @@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None: def get_logger(name: Optional[str] = None) -> "_Logger": - r""" - Returns a logger with the specified name. It it not supposed to be accessed externally. - """ + r"""Return a logger with the specified name. It it not supposed to be accessed externally.""" if name is None: name = _get_library_name() @@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger": def add_handler(handler: "logging.Handler") -> None: - r""" - Adds a handler to the root logger. - """ + r"""Add a handler to the root logger.""" _configure_library_root_logger() _get_library_root_logger().addHandler(handler) def remove_handler(handler: logging.Handler) -> None: - r""" - Removes a handler to the root logger. - """ + r"""Remove a handler to the root logger.""" _configure_library_root_logger() _get_library_root_logger().removeHandler(handler) diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index f637d728..c37a8be6 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -17,7 +17,8 @@ import gc import os -from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Literal, Union import torch import torch.distributed as dist @@ -54,9 +55,7 @@ logger = logging.get_logger(__name__) class AverageMeter: - r""" - Computes and stores the average and current value. - """ + r"""Compute and store the average and current value.""" def __init__(self): self.reset() @@ -75,9 +74,7 @@ class AverageMeter: def check_version(requirement: str, mandatory: bool = False) -> None: - r""" - Optionally checks the package version. - """ + r"""Optionally check the package version.""" if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory: logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.") return @@ -91,9 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None: def check_dependencies() -> None: - r""" - Checks the version of the required packages. - """ + r"""Check the version of the required packages.""" check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0") check_version("datasets>=2.16.0,<=3.2.0") check_version("accelerate>=0.34.0,<=1.2.1") @@ -103,10 +98,8 @@ def check_dependencies() -> None: logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.") -def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float: - r""" - Calculates effective tokens per second. - """ +def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float: + r"""Calculate effective tokens per second.""" effective_token_num = 0 for data in dataset: if stage == "sft": @@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], return result / dist.get_world_size() if dist.is_initialized() else result -def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: - r""" - Returns the number of trainable parameters and number of all parameters in the model. - """ +def count_parameters(model: "torch.nn.Module") -> tuple[int, int]: + r"""Return the number of trainable parameters and number of all parameters in the model.""" trainable_params, all_param = 0, 0 for param in model.parameters(): num_params = param.numel() @@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: def get_current_device() -> "torch.device": - r""" - Gets the current available device. - """ + r"""Get the current available device.""" if is_torch_xpu_available(): device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) elif is_torch_npu_available(): @@ -166,9 +155,7 @@ def get_current_device() -> "torch.device": def get_device_count() -> int: - r""" - Gets the number of available GPU or NPU devices. - """ + r"""Get the number of available GPU or NPU devices.""" if is_torch_xpu_available(): return torch.xpu.device_count() elif is_torch_npu_available(): @@ -180,18 +167,14 @@ def get_device_count() -> int: def get_logits_processor() -> "LogitsProcessorList": - r""" - Gets logits processor that removes NaN and Inf logits. - """ + r"""Get logits processor that removes NaN and Inf logits.""" logits_processor = LogitsProcessorList() logits_processor.append(InfNanRemoveLogitsProcessor()) return logits_processor -def get_peak_memory() -> Tuple[int, int]: - r""" - Gets the peak memory usage for the current device (in Bytes). - """ +def get_peak_memory() -> tuple[int, int]: + r"""Get the peak memory usage for the current device (in Bytes).""" if is_torch_npu_available(): return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved() elif is_torch_cuda_available(): @@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]: def has_tokenized_data(path: "os.PathLike") -> bool: - r""" - Checks if the path has a tokenized dataset. - """ + r"""Check if the path has a tokenized dataset.""" return os.path.isdir(path) and len(os.listdir(path)) > 0 def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype": - r""" - Infers the optimal dtype according to the model_dtype and device compatibility. - """ + r"""Infer the optimal dtype according to the model_dtype and device compatibility.""" if _is_bf16_available and model_dtype == torch.bfloat16: return torch.bfloat16 elif _is_fp16_available: @@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype": def is_gpu_or_npu_available() -> bool: - r""" - Checks if the GPU or NPU is available. - """ + r"""Check if the GPU or NPU is available.""" return is_torch_npu_available() or is_torch_cuda_available() def is_env_enabled(env_var: str, default: str = "0") -> bool: - r""" - Checks if the environment variable is enabled. - """ + r"""Check if the environment variable is enabled.""" return os.getenv(env_var, default).lower() in ["true", "y", "1"] def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray": - r""" - Casts a torch tensor or a numpy array to a numpy array. - """ + r"""Cast a torch tensor or a numpy array to a numpy array.""" if isinstance(inputs, torch.Tensor): inputs = inputs.cpu() if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4 @@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray": def skip_check_imports() -> None: - r""" - Avoids flash attention import error in custom model files. - """ + r"""Avoid flash attention import error in custom model files.""" if not is_env_enabled("FORCE_CHECK_IMPORTS"): transformers.dynamic_module_utils.check_imports = get_relative_imports def torch_gc() -> None: - r""" - Collects GPU or NPU memory. - """ + r"""Collect GPU or NPU memory.""" gc.collect() if is_torch_xpu_available(): torch.xpu.empty_cache() diff --git a/src/llamafactory/extras/ploting.py b/src/llamafactory/extras/ploting.py index d05970d2..be89bcc5 100644 --- a/src/llamafactory/extras/ploting.py +++ b/src/llamafactory/extras/ploting.py @@ -15,7 +15,7 @@ import json import math import os -from typing import Any, Dict, List +from typing import Any from transformers.trainer import TRAINER_STATE_NAME @@ -31,10 +31,8 @@ if is_matplotlib_available(): logger = logging.get_logger(__name__) -def smooth(scalars: List[float]) -> List[float]: - r""" - EMA implementation according to TensorBoard. - """ +def smooth(scalars: list[float]) -> list[float]: + r"""EMA implementation according to TensorBoard.""" if len(scalars) == 0: return [] @@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]: return smoothed -def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure": - r""" - Plots loss curves in LlamaBoard. - """ +def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure": + r"""Plot loss curves in LlamaBoard.""" plt.close("all") plt.switch_backend("agg") fig = plt.figure() @@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur return fig -def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None: - r""" - Plots loss curves and saves the image. - """ +def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None: + r"""Plot loss curves and saves the image.""" plt.switch_backend("agg") with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f: data = json.load(f) diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index cd28f580..80b49248 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -16,14 +16,12 @@ # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional @dataclass class DataArguments: - r""" - Arguments pertaining to what data we are going to input our model for training and evaluation. - """ + r"""Arguments pertaining to what data we are going to input our model for training and evaluation.""" template: Optional[str] = field( default=None, @@ -162,5 +160,5 @@ class DataArguments: if self.mask_history and self.train_on_prompt: raise ValueError("`mask_history` is incompatible with `train_on_prompt`.") - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/src/llamafactory/hparams/evaluation_args.py b/src/llamafactory/hparams/evaluation_args.py index ec1867e8..d92e8b1e 100644 --- a/src/llamafactory/hparams/evaluation_args.py +++ b/src/llamafactory/hparams/evaluation_args.py @@ -21,9 +21,7 @@ from datasets import DownloadMode @dataclass class EvaluationArguments: - r""" - Arguments pertaining to specify the evaluation parameters. - """ + r"""Arguments pertaining to specify the evaluation parameters.""" task: str = field( metadata={"help": "Name of the evaluation task."}, diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index afcb171c..69c4ef77 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -13,14 +13,12 @@ # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional @dataclass class FreezeArguments: - r""" - Arguments pertaining to the freeze (partial-parameter) training. - """ + r"""Arguments pertaining to the freeze (partial-parameter) training.""" freeze_trainable_layers: int = field( default=2, @@ -56,9 +54,7 @@ class FreezeArguments: @dataclass class LoraArguments: - r""" - Arguments pertaining to the LoRA training. - """ + r"""Arguments pertaining to the LoRA training.""" additional_target: Optional[str] = field( default=None, @@ -128,9 +124,7 @@ class LoraArguments: @dataclass class RLHFArguments: - r""" - Arguments pertaining to the PPO, DPO and KTO training. - """ + r"""Arguments pertaining to the PPO, DPO and KTO training.""" pref_beta: float = field( default=0.1, @@ -212,9 +206,7 @@ class RLHFArguments: @dataclass class GaloreArguments: - r""" - Arguments pertaining to the GaLore algorithm. - """ + r"""Arguments pertaining to the GaLore algorithm.""" use_galore: bool = field( default=False, @@ -253,9 +245,7 @@ class GaloreArguments: @dataclass class ApolloArguments: - r""" - Arguments pertaining to the APOLLO algorithm. - """ + r"""Arguments pertaining to the APOLLO algorithm.""" use_apollo: bool = field( default=False, @@ -306,9 +296,7 @@ class ApolloArguments: @dataclass class BAdamArgument: - r""" - Arguments pertaining to the BAdam optimizer. - """ + r"""Arguments pertaining to the BAdam optimizer.""" use_badam: bool = field( default=False, @@ -393,9 +381,7 @@ class SwanLabArguments: class FinetuningArguments( SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments ): - r""" - Arguments pertaining to which techniques we are going to fine-tuning with. - """ + r"""Arguments pertaining to which techniques we are going to fine-tuning with.""" pure_bf16: bool = field( default=False, @@ -452,13 +438,13 @@ class FinetuningArguments( return [item.strip() for item in arg.split(",")] return arg - self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules) - self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules) + self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules) + self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules) self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 - self.lora_target: List[str] = split_arg(self.lora_target) - self.additional_target: Optional[List[str]] = split_arg(self.additional_target) - self.galore_target: List[str] = split_arg(self.galore_target) - self.apollo_target: List[str] = split_arg(self.apollo_target) + self.lora_target: list[str] = split_arg(self.lora_target) + self.additional_target: Optional[list[str]] = split_arg(self.additional_target) + self.galore_target: list[str] = split_arg(self.galore_target) + self.apollo_target: list[str] = split_arg(self.apollo_target) self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"] assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." @@ -499,7 +485,7 @@ class FinetuningArguments( if self.pissa_init: raise ValueError("`pissa_init` is only valid for LoRA training.") - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: args = asdict(self) args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()} return args diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py index db3306d6..251822b1 100644 --- a/src/llamafactory/hparams/generating_args.py +++ b/src/llamafactory/hparams/generating_args.py @@ -13,16 +13,14 @@ # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Any, Dict, Optional +from typing import Any, Optional from transformers import GenerationConfig @dataclass class GeneratingArguments: - r""" - Arguments pertaining to specify the decoding parameters. - """ + r"""Arguments pertaining to specify the decoding parameters.""" do_sample: bool = field( default=True, @@ -35,7 +33,9 @@ class GeneratingArguments: top_p: float = field( default=0.7, metadata={ - "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." + "help": ( + "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." + ) }, ) top_k: int = field( @@ -71,7 +71,7 @@ class GeneratingArguments: metadata={"help": "Whether or not to remove special tokens in the decoding."}, ) - def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]: + def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]: args = asdict(self) if args.get("max_new_tokens", -1) > 0: args.pop("max_length", None) diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 7b5fc93b..2c7c00f8 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -17,7 +17,7 @@ import json from dataclasses import asdict, dataclass, field, fields -from typing import Any, Dict, Literal, Optional, Union +from typing import Any, Literal, Optional, Union import torch from transformers.training_args import _convert_str_dict @@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling @dataclass class BaseModelArguments: - r""" - Arguments pertaining to the model. - """ + r"""Arguments pertaining to the model.""" model_name_or_path: Optional[str] = field( default=None, @@ -184,9 +182,7 @@ class BaseModelArguments: @dataclass class QuantizationArguments: - r""" - Arguments pertaining to the quantization method. - """ + r"""Arguments pertaining to the quantization method.""" quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field( default="bitsandbytes", @@ -212,9 +208,7 @@ class QuantizationArguments: @dataclass class ProcessorArguments: - r""" - Arguments pertaining to the image processor. - """ + r"""Arguments pertaining to the image processor.""" image_max_pixels: int = field( default=768 * 768, @@ -244,9 +238,7 @@ class ProcessorArguments: @dataclass class ExportArguments: - r""" - Arguments pertaining to the model export. - """ + r"""Arguments pertaining to the model export.""" export_dir: Optional[str] = field( default=None, @@ -292,9 +284,7 @@ class ExportArguments: @dataclass class VllmArguments: - r""" - Arguments pertaining to the vLLM worker. - """ + r"""Arguments pertaining to the vLLM worker.""" vllm_maxlen: int = field( default=4096, @@ -324,8 +314,7 @@ class VllmArguments: @dataclass class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments): - r""" - Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer. + r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer. The class on the most right will be displayed first. """ @@ -335,7 +324,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz init=False, metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."}, ) - device_map: Optional[Union[str, Dict[str, Any]]] = field( + device_map: Optional[Union[str, dict[str, Any]]] = field( default=None, init=False, metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."}, @@ -372,7 +361,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz return result - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: args = asdict(self) args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()} return args diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index dc7f6c1b..464b8472 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -19,7 +19,7 @@ import json import os import sys from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch import transformers @@ -47,17 +47,15 @@ check_dependencies() _TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] -_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] -_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] +_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] -_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] +_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] -def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]: - r""" - Gets arguments from the command line or a config file. - """ +def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]: + r"""Get arguments from the command line or a config file.""" if args is not None: return args @@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[ def _parse_args( - parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False -) -> Tuple[Any]: + parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False +) -> tuple[Any]: args = read_args(args) if isinstance(args, dict): return parser.parse_dict(args, allow_extra_keys=allow_extra_keys) @@ -161,31 +159,31 @@ def _check_extra_dependencies( check_version("rouge_chinese", mandatory=True) -def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: +def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS: parser = HfArgumentParser(_TRAIN_ARGS) allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) -def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: +def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS: parser = HfArgumentParser(_INFER_ARGS) allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) -def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: +def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS: parser = HfArgumentParser(_EVAL_ARGS) allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) -def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments: +def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments: parser = HfArgumentParser(RayArguments) (ray_args,) = _parse_args(parser, args, allow_extra_keys=True) return ray_args -def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: +def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS: model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) # Setup logging @@ -364,9 +362,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ and training_args.resume_from_checkpoint is not None ): logger.warning_rank0( - "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( - training_args.resume_from_checkpoint - ) + f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint." ) # Post-process model arguments @@ -382,20 +378,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ # Log on each process the small summary logger.info( - "Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format( - training_args.process_index, - training_args.world_size, - training_args.device, - training_args.parallel_mode == ParallelMode.DISTRIBUTED, - str(model_args.compute_dtype), - ) + f"Process rank: {training_args.process_index}, " + f"world size: {training_args.world_size}, device: {training_args.device}, " + f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, " + f"compute dtype: {str(model_args.compute_dtype)}" ) transformers.set_seed(training_args.seed) return model_args, data_args, training_args, finetuning_args, generating_args -def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: +def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS: model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) _set_transformers_logging() @@ -426,7 +419,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ return model_args, data_args, finetuning_args, generating_args -def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: +def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS: model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) _set_transformers_logging() diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py index 341c5c5e..38a93650 100644 --- a/src/llamafactory/hparams/training_args.py +++ b/src/llamafactory/hparams/training_args.py @@ -10,9 +10,7 @@ from ..extras.misc import use_ray @dataclass class RayArguments: - r""" - Arguments pertaining to the Ray training. - """ + r"""Arguments pertaining to the Ray training.""" ray_run_name: Optional[str] = field( default=None, @@ -43,9 +41,7 @@ class RayArguments: @dataclass class TrainingArguments(RayArguments, Seq2SeqTrainingArguments): - r""" - Arguments pertaining to the trainer. - """ + r"""Arguments pertaining to the trainer.""" def __post_init__(self): Seq2SeqTrainingArguments.__post_init__(self) diff --git a/src/llamafactory/model/__init__.py b/src/llamafactory/model/__init__.py index 1957ff8d..71d4f47f 100644 --- a/src/llamafactory/model/__init__.py +++ b/src/llamafactory/model/__init__.py @@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params __all__ = [ "QuantizationMethod", + "find_all_linear_modules", "load_config", "load_model", "load_tokenizer", - "find_all_linear_modules", "load_valuehead_params", ] diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py index 399500e0..7d3ed389 100644 --- a/src/llamafactory/model/adapter.py +++ b/src/llamafactory/model/adapter.py @@ -81,9 +81,8 @@ def _setup_freeze_tuning( if finetuning_args.use_llama_pro: if num_layers % finetuning_args.freeze_trainable_layers != 0: raise ValueError( - "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format( - num_layers, finetuning_args.freeze_trainable_layers - ) + f"`num_layers` {num_layers} should be " + f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}." ) stride = num_layers // finetuning_args.freeze_trainable_layers @@ -178,7 +177,7 @@ def _setup_lora_tuning( } for adapter in adapter_to_merge: - model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs) + model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs) model = model.merge_and_unload() if len(adapter_to_merge) > 0: @@ -263,8 +262,7 @@ def init_adapter( finetuning_args: "FinetuningArguments", is_trainable: bool, ) -> "PreTrainedModel": - r""" - Initializes the adapters. + r"""Initialize the adapters. Support full-parameter, freeze and LoRA training. diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 45cc1fa5..fb7846fe 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -13,7 +13,7 @@ # limitations under the License. import os -from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Optional, TypedDict import torch from transformers import ( @@ -51,9 +51,8 @@ class TokenizerModule(TypedDict): processor: Optional["ProcessorMixin"] -def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: - r""" - Gets arguments to load config/tokenizer/model. +def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]: + r"""Get arguments to load config/tokenizer/model. Note: including inplace operation of model_args. """ @@ -68,8 +67,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": - r""" - Loads pretrained tokenizer and optionally loads processor. + r"""Load pretrained tokenizer and optionally loads processor. Note: including inplace operation of model_args. """ @@ -110,9 +108,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": def load_config(model_args: "ModelArguments") -> "PretrainedConfig": - r""" - Loads model config. - """ + r"""Load model config.""" init_kwargs = _get_init_kwargs(model_args) return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) @@ -124,9 +120,7 @@ def load_model( is_trainable: bool = False, add_valuehead: bool = False, ) -> "PreTrainedModel": - r""" - Loads pretrained model. - """ + r"""Load pretrained model.""" init_kwargs = _get_init_kwargs(model_args) config = load_config(model_args) patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) @@ -194,8 +188,9 @@ def load_model( trainable_params, all_param = count_parameters(model) if is_trainable: - param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( - trainable_params, all_param, 100 * trainable_params / all_param + param_stats = ( + f"trainable params: {trainable_params:,} || " + f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}" ) else: param_stats = f"all params: {all_param:,}" diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py index 916f1934..eb498310 100644 --- a/src/llamafactory/model/model_utils/checkpointing.py +++ b/src/llamafactory/model/model_utils/checkpointing.py @@ -21,7 +21,7 @@ import inspect from functools import WRAPPER_ASSIGNMENTS, partial, wraps from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch @@ -40,9 +40,7 @@ logger = logging.get_logger(__name__) def get_unsloth_gradient_checkpointing_func() -> Callable: class UnslothGradientCheckpointing(torch.autograd.Function): - r""" - Saves VRAM by smartly offloading to RAM. - """ + r"""Saves VRAM by smartly offloading to RAM.""" @staticmethod @torch.cuda.amp.custom_fwd @@ -77,13 +75,11 @@ def get_unsloth_gradient_checkpointing_func() -> Callable: def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable: - r""" - Only applies gradient checkpointing to trainable layers. - """ + r"""Only applies gradient checkpointing to trainable layers.""" @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",)) def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs): - module: "torch.nn.Module" = func.__self__ + module: torch.nn.Module = func.__self__ has_grad = False if any(param.requires_grad for param in module.parameters()): @@ -103,11 +99,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable def _gradient_checkpointing_enable( self: "PreTrainedModel", - gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None, + gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None, use_unsloth_gc: bool = False, ) -> None: - r""" - Activates gradient checkpointing for the current model. + r"""Activates gradient checkpointing for the current model. Modification of the original method to enable gradient checkpointing for block-wise optimizer. """ @@ -134,17 +129,18 @@ def _gradient_checkpointing_enable( def _fp32_forward_post_hook( - module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" + module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor" ) -> "torch.Tensor": return output.to(torch.float32) def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None: - r""" - Includes: - (1) cast the layernorm in fp32 - (2) make output embedding layer require grads - (3) add the upcasting of the lm_head in fp32 + r"""Prepare the model before training. + + Include: + (1) cast the layernorm in fp32 + (2) make output embedding layer require grads + (3) add the upcasting of the lm_head in fp32. """ if model_args.upcast_layernorm: logger.info_rank0("Upcasting layernorm weights in float32.") diff --git a/src/llamafactory/model/model_utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py index 199b53c3..c10e34f4 100644 --- a/src/llamafactory/model/model_utils/embedding.py +++ b/src/llamafactory/model/model_utils/embedding.py @@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: - r""" - Resize token embeddings. - """ + r"""Resize token embeddings.""" if is_deepspeed_zero3_enabled(): import deepspeed # type: ignore diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py index 798b3906..12ea91e5 100644 --- a/src/llamafactory/model/model_utils/longlora.py +++ b/src/llamafactory/model/model_utils/longlora.py @@ -18,7 +18,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional import torch import torch.nn as nn @@ -54,14 +54,14 @@ def llama_attention_forward( past_key_value: Optional["Cache"] = None, output_attentions: bool = False, cache_position: Optional["torch.LongTensor"] = None, - position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None, + position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None, **kwargs, -) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]: +) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]: bsz, q_len, _ = hidden_states.size() - query_states: "torch.Tensor" = self.q_proj(hidden_states) - key_states: "torch.Tensor" = self.k_proj(hidden_states) - value_states: "torch.Tensor" = self.v_proj(hidden_states) + query_states: torch.Tensor = self.q_proj(hidden_states) + key_states: torch.Tensor = self.k_proj(hidden_states) + value_states: torch.Tensor = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -139,17 +139,17 @@ def llama_flash_attention_2_forward( past_key_value: Optional["Cache"] = None, output_attentions: bool = False, cache_position: Optional["torch.LongTensor"] = None, - position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None, + position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None, **kwargs, -) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]: +) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]: # LlamaFlashAttention2 attention does not support output_attentions output_attentions = False bsz, q_len, _ = hidden_states.size() - query_states: "torch.Tensor" = self.q_proj(hidden_states) - key_states: "torch.Tensor" = self.k_proj(hidden_states) - value_states: "torch.Tensor" = self.v_proj(hidden_states) + query_states: torch.Tensor = self.q_proj(hidden_states) + key_states: torch.Tensor = self.k_proj(hidden_states) + value_states: torch.Tensor = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -209,7 +209,7 @@ def llama_flash_attention_2_forward( if is_transformers_version_greater_than("4.43.0"): from transformers.modeling_flash_attention_utils import _flash_attention_forward - attn_output: "torch.Tensor" = _flash_attention_forward( + attn_output: torch.Tensor = _flash_attention_forward( query_states, key_states, value_states, @@ -221,7 +221,7 @@ def llama_flash_attention_2_forward( is_causal=self.is_causal, ) else: - attn_output: "torch.Tensor" = self._flash_attention_forward( + attn_output: torch.Tensor = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate ) @@ -254,9 +254,9 @@ def llama_sdpa_attention_forward( past_key_value: Optional["Cache"] = None, output_attentions: bool = False, cache_position: Optional["torch.LongTensor"] = None, - position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None, + position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None, **kwargs, -) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]: +) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]: if output_attentions: transformers_logger.warning_once( "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention" @@ -274,9 +274,9 @@ def llama_sdpa_attention_forward( bsz, q_len, _ = hidden_states.size() - query_states: "torch.Tensor" = self.q_proj(hidden_states) - key_states: "torch.Tensor" = self.k_proj(hidden_states) - value_states: "torch.Tensor" = self.v_proj(hidden_states) + query_states: torch.Tensor = self.q_proj(hidden_states) + key_states: torch.Tensor = self.k_proj(hidden_states) + value_states: torch.Tensor = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py index fc777ecb..b0249b47 100644 --- a/src/llamafactory/model/model_utils/misc.py +++ b/src/llamafactory/model/model_utils/misc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING from ...extras import logging from .visual import COMPOSITE_MODELS @@ -25,10 +25,8 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: - r""" - Finds all available modules to apply LoRA, GaLore or APOLLO. - """ +def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]: + r"""Find all available modules to apply LoRA, GaLore or APOLLO.""" model_type = getattr(model.config, "model_type", None) forbidden_modules = {"lm_head"} if model_type == "chatglm": @@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) return list(module_names) -def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]: - r""" - Finds the modules in the expanded blocks to apply lora. - """ +def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]: + r"""Find the modules in the expanded blocks to apply lora.""" num_layers = getattr(model.config, "num_hidden_layers", None) if not num_layers: raise ValueError("Model was not supported.") diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py index 4e520d5c..9e225ad5 100644 --- a/src/llamafactory/model/model_utils/moe.py +++ b/src/llamafactory/model/model_utils/moe.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING import torch from transformers.integrations import is_deepspeed_zero3_enabled @@ -34,9 +35,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch def add_z3_leaf_module(model: "PreTrainedModel") -> None: - r""" - Sets module as a leaf module to skip partitioning in deepspeed zero3. - """ + r"""Set module as a leaf module to skip partitioning in deepspeed zero3.""" if not is_deepspeed_zero3_enabled(): return diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 275d7895..475d7bc3 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -37,7 +37,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING import torch import torch.nn.functional as F @@ -59,8 +59,7 @@ logger = logging.get_logger(__name__) def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": - r""" - Gets the sequnce lengths in the current batch. + r"""Get the sequnce lengths in the current batch. e.g. ```python @@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": bsz = attention_mask.size(0) dtype, device = attention_mask.dtype, attention_mask.device max_num = torch.max(attention_mask).item() - counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device) + counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device) for i in range(max_num): counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1) @@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": return seqlens -def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]: - r""" - Prepares the indices and seqlens for flash attn varlen function. +def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]: + r"""Prepare the indices and seqlens for flash attn varlen function. Returns: indices: indices of non-masked tokens from the flattened sequence. @@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor [0, 2, 5, 6, 8, 11] 3 ``` + """ seqlens_in_batch = get_seqlens_in_batch(attention_mask) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index e000ee23..860e2c2a 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -19,7 +19,7 @@ import os import random from enum import Enum, unique -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any import torch from datasets import load_dataset @@ -43,9 +43,7 @@ logger = logging.get_logger(__name__) @unique class QuantizationMethod(str, Enum): - r""" - Borrowed from `transformers.utils.quantization_config.QuantizationMethod`. - """ + r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.""" BITS_AND_BYTES = "bitsandbytes" GPTQ = "gptq" @@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum): HQQ = "hqq" -def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]: - r""" - Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization. - """ +def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]: + r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.""" if os.path.isfile(model_args.export_quantization_dataset): data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) data_files = model_args.export_quantization_dataset @@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.") sample_idx = random.randint(0, len(dataset) - 1) - sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") + sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") n_try += 1 if sample["input_ids"].size(1) > maxlen: break # TODO: fix large maxlen @@ -101,11 +97,9 @@ def configure_quantization( config: "PretrainedConfig", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", - init_kwargs: Dict[str, Any], + init_kwargs: dict[str, Any], ) -> None: - r""" - Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer) - """ + r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer).""" if getattr(config, "quantization_config", None): # ptq if model_args.quantization_bit is not None: logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.") @@ -113,7 +107,7 @@ def configure_quantization( if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.") - quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) + quantization_config: dict[str, Any] = getattr(config, "quantization_config", None) quant_method = quantization_config.get("quant_method", "") if quant_method == QuantizationMethod.GPTQ: diff --git a/src/llamafactory/model/model_utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py index 899cc971..8bb6aa64 100644 --- a/src/llamafactory/model/model_utils/unsloth.py +++ b/src/llamafactory/model/model_utils/unsloth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional from ...extras import logging from ...extras.misc import get_current_device @@ -29,7 +29,7 @@ logger = logging.get_logger(__name__) def _get_unsloth_kwargs( config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments" -) -> Dict[str, Any]: +) -> dict[str, Any]: return { "model_name": model_name_or_path, "max_seq_length": model_args.model_max_length or 4096, @@ -47,10 +47,8 @@ def _get_unsloth_kwargs( def load_unsloth_pretrained_model( config: "PretrainedConfig", model_args: "ModelArguments" ) -> Optional["PreTrainedModel"]: - r""" - Optionally loads pretrained model with unsloth. Used in training. - """ - from unsloth import FastLanguageModel + r"""Optionally load pretrained model with unsloth. Used in training.""" + from unsloth import FastLanguageModel # type: ignore unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args) try: @@ -64,12 +62,10 @@ def load_unsloth_pretrained_model( def get_unsloth_peft_model( - model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] + model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any] ) -> "PreTrainedModel": - r""" - Gets the peft model for the pretrained model with unsloth. Used in training. - """ - from unsloth import FastLanguageModel + r"""Get the peft model for the pretrained model with unsloth. Used in training.""" + from unsloth import FastLanguageModel # type: ignore unsloth_peft_kwargs = { "model": model, @@ -82,10 +78,8 @@ def get_unsloth_peft_model( def load_unsloth_peft_model( config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool ) -> "PreTrainedModel": - r""" - Loads peft model with unsloth. Used in both training and inference. - """ - from unsloth import FastLanguageModel + r"""Load peft model with unsloth. Used in both training and inference.""" + from unsloth import FastLanguageModel # type: ignore unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) try: diff --git a/src/llamafactory/model/model_utils/valuehead.py b/src/llamafactory/model/model_utils/valuehead.py index ace90f75..137c6b7d 100644 --- a/src/llamafactory/model/model_utils/valuehead.py +++ b/src/llamafactory/model/model_utils/valuehead.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING import torch from transformers.utils import cached_file @@ -30,9 +30,8 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: - r""" - Loads value head parameters from Hugging Face Hub or local disk. +def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]: + r"""Load value head parameters from Hugging Face Hub or local disk. Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. """ diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 316740f0..8c1c3df9 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -15,8 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple +from typing import TYPE_CHECKING, Optional import torch import transformers @@ -40,9 +41,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__) class CompositeModel: model_type: str projector_key: str - vision_model_keys: List[str] - language_model_keys: List[str] - lora_conflict_keys: List[str] + vision_model_keys: list[str] + language_model_keys: list[str] + lora_conflict_keys: list[str] def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module": for key in self.projector_key.split("."): @@ -51,15 +52,15 @@ class CompositeModel: return module -COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {} +COMPOSITE_MODELS: dict[str, "CompositeModel"] = {} def _register_composite_model( model_type: str, projector_key: Optional[str] = None, - vision_model_keys: Optional[List[str]] = None, - language_model_keys: Optional[List[str]] = None, - lora_conflict_keys: Optional[List[str]] = None, + vision_model_keys: Optional[list[str]] = None, + language_model_keys: Optional[list[str]] = None, + lora_conflict_keys: Optional[list[str]] = None, ): COMPOSITE_MODELS[model_type] = CompositeModel( model_type=model_type, @@ -116,12 +117,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL): def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None: - r""" - Casts projector output to half precision for fine-tuning quantized VLMs. - """ + r"""Cast projector output to half precision for fine-tuning quantized VLMs.""" def _mm_projector_forward_post_hook( - module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" + module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor" ) -> "torch.Tensor": return output.to(model_args.compute_dtype) @@ -137,9 +136,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen def configure_visual_model(config: "PretrainedConfig") -> None: - r""" - Patches VLMs before loading them. - """ + r"""Patch VLMs before loading them.""" if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None): # required for ds zero3 and valuehead models setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) @@ -149,10 +146,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None: transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL -def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]: - r""" - Freezes vision tower and language model for VLM full/freeze tuning. - """ +def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]: + r"""Freeze vision tower and language model for VLM full/freeze tuning.""" model_type = getattr(config, "model_type", None) forbidden_modules = set() if model_type in COMPOSITE_MODELS: @@ -175,9 +170,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni def get_image_seqlen(config: "PretrainedConfig") -> int: - r""" - Computes the number of special tokens per image. - """ + r"""Compute the number of special tokens per image.""" model_type = getattr(config, "model_type", None) if model_type == "llava": image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2 @@ -192,17 +185,13 @@ def get_image_seqlen(config: "PretrainedConfig") -> int: def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int: - r""" - Computes the patch size of the vit. - """ + r"""Compute the patch size of the vit.""" patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1)) return patch_size def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int: - r""" - Get the vision_feature_select_strategy. - """ + r"""Get the vision_feature_select_strategy.""" vision_feature_select_strategy = getattr( config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default") ) @@ -211,10 +200,8 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P def patch_target_modules( model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str] -) -> List[str]: - r""" - Freezes vision tower for VLM LoRA tuning. - """ +) -> list[str]: + r"""Freezes vision tower for VLM LoRA tuning.""" model_type = getattr(model.config, "model_type", None) if model_type in COMPOSITE_MODELS: forbidden_modules = get_forbidden_modules(model.config, finetuning_args) diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 126abe5d..f732c792 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -13,7 +13,7 @@ # limitations under the License. from types import MethodType -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any import torch from peft import PeftModel @@ -93,7 +93,7 @@ def patch_config( config: "PretrainedConfig", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", - init_kwargs: Dict[str, Any], + init_kwargs: dict[str, Any], is_trainable: bool, ) -> None: if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 45191d9e..d6d8ecf7 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -19,7 +19,7 @@ import sys import time from concurrent.futures import ThreadPoolExecutor from datetime import timedelta -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional import torch import transformers @@ -56,7 +56,8 @@ logger = logging.get_logger(__name__) def fix_valuehead_checkpoint( model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool ) -> None: - r""" + r"""Fix the valuehead checkpoint files. + The model is already unwrapped. There are three cases: @@ -72,10 +73,10 @@ def fix_valuehead_checkpoint( if safe_serialization: path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: - state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} + state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} else: path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) - state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") + state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") os.remove(path_to_checkpoint) decoder_state_dict, v_head_state_dict = {}, {} @@ -98,9 +99,7 @@ def fix_valuehead_checkpoint( class FixValueHeadModelCallback(TrainerCallback): - r""" - A callback for fixing the checkpoint for valuehead models. - """ + r"""A callback for fixing the checkpoint for valuehead models.""" @override def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): @@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback): class SaveProcessorCallback(TrainerCallback): - r""" - A callback for saving the processor. - """ + r"""A callback for saving the processor.""" def __init__(self, processor: "ProcessorMixin") -> None: self.processor = processor @@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback): class PissaConvertCallback(TrainerCallback): - r""" - A callback for converting the PiSSA adapter to a normal one. - """ + r"""A callback for converting the PiSSA adapter to a normal one.""" @override def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): @@ -177,9 +172,7 @@ class PissaConvertCallback(TrainerCallback): class LogCallback(TrainerCallback): - r""" - A callback for logging training and evaluation status. - """ + r"""A callback for logging training and evaluation status.""" def __init__(self) -> None: # Progress @@ -188,7 +181,7 @@ class LogCallback(TrainerCallback): self.max_steps = 0 self.elapsed_time = "" self.remaining_time = "" - self.thread_pool: Optional["ThreadPoolExecutor"] = None + self.thread_pool: Optional[ThreadPoolExecutor] = None # Status self.aborted = False self.do_train = False @@ -219,7 +212,7 @@ class LogCallback(TrainerCallback): self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) self.remaining_time = str(timedelta(seconds=int(remaining_time))) - def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None: + def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None: with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f: f.write(json.dumps(logs) + "\n") @@ -348,9 +341,7 @@ class LogCallback(TrainerCallback): class ReporterCallback(TrainerCallback): - r""" - A callback for reporting training status to external logger. - """ + r"""A callback for reporting training status to external logger.""" def __init__( self, diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 68cb2c3c..e8499416 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -19,7 +19,7 @@ import warnings from collections import defaultdict from contextlib import nullcontext from types import MethodType -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import torch import torch.nn.functional as F @@ -129,15 +129,11 @@ class CustomDPOTrainer(DPOTrainer): @override def get_batch_samples(self, epoch_iterator, num_batches): - r""" - Replaces the method of KTO Trainer with the one of the standard Trainer. - """ + r"""Replace the method of DPO Trainer with the one of the standard Trainer.""" return Trainer.get_batch_samples(self, epoch_iterator, num_batches) def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": - r""" - Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model. - """ + r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.""" log_odds = (chosen_logps - rejected_logps) - ( torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps)) ) @@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer): return orpo_loss def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": - r""" - Computes SimPO loss for batched log probabilities of the policy model. - """ + r"""Compute SimPO loss for batched log probabilities of the policy model.""" pi_logratios = chosen_logps - rejected_logps gamma_logratios = self.simpo_gamma / self.beta logits = pi_logratios - gamma_logratios @@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer): policy_rejected_logps: "torch.Tensor", reference_chosen_logps: Optional["torch.Tensor"], reference_rejected_logps: Optional["torch.Tensor"], - ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: - r""" - Computes loss for preference learning. - """ + ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r"""Compute loss for preference learning.""" if not self.finetuning_args.use_ref_model: if self.loss_type == "orpo": losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps) @@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer): @override def concatenated_forward( - self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] - ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: - r""" - Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO. + self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"] + ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO. Otherwise the average log probabilities. """ if self.finetuning_args.use_ref_model: batch = nested_detach(batch, clone=True) # avoid error - all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) + all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"]) if self.loss_type in ["ipo", "orpo", "simpo"]: all_logps = all_logps / valid_length @@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer): @override def compute_reference_log_probs( - self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] - ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]: - r""" - Computes log probabilities of the reference model. - """ + self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"] + ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]: + r"""Compute log probabilities of the reference model.""" if not self.finetuning_args.use_ref_model: return None, None @@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer): def get_batch_loss_metrics( self, model: "PreTrainedModel", - batch: Dict[str, "torch.Tensor"], + batch: dict[str, "torch.Tensor"], train_eval: Literal["train", "eval"] = "train", - ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: - r""" - Computes the DPO loss and other metrics for the given batch of inputs for train or test. - """ + ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]: + r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" metrics = {} ( policy_chosen_logps, @@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer): @override def compute_loss( - self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs - ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: - r""" - Subclass and override to accept extra kwargs. - """ + self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs + ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]: + r"""Subclass and override to accept extra kwargs.""" return super().compute_loss(model, inputs, return_outputs) @override - def log(self, logs: Dict[str, float], *args, **kwargs) -> None: - r""" - Log `logs` on the various objects watching training, including stored metrics. - """ + def log(self, logs: dict[str, float], *args, **kwargs) -> None: + r"""Log `logs` on the various objects watching training, including stored metrics.""" # logs either has "loss" or "eval_loss" train_eval = "train" if "loss" in logs else "eval" # Add averaged stored metrics to logs diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py index e513f5c4..7a9ff517 100644 --- a/src/llamafactory/train/dpo/workflow.py +++ b/src/llamafactory/train/dpo/workflow.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...extras.constants import IGNORE_INDEX @@ -38,7 +38,7 @@ def run_dpo( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", - callbacks: Optional[List["TrainerCallback"]] = None, + callbacks: Optional[list["TrainerCallback"]] = None, ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 65c72449..d4c07092 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -19,7 +19,7 @@ import warnings from collections import defaultdict from contextlib import nullcontext from types import MethodType -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import torch from transformers import Trainer @@ -120,9 +120,7 @@ class CustomKTOTrainer(KTOTrainer): @override def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: - r""" - Replaces the sequential sampler of KTO Trainer created by trl with the random sampler. - """ + r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler.""" if self.finetuning_args.disable_shuffling: return torch.utils.data.SequentialSampler(self.train_dataset) @@ -130,18 +128,14 @@ class CustomKTOTrainer(KTOTrainer): @override def get_batch_samples(self, epoch_iterator, num_batches): - r""" - Replaces the method of KTO Trainer with the one of the standard Trainer. - """ + r"""Replace the method of KTO Trainer with the one of the standard Trainer.""" return Trainer.get_batch_samples(self, epoch_iterator, num_batches) @override def forward( - self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" - ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: - r""" - Runs forward pass and computes the log probabilities. - """ + self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" + ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r"""Run forward pass and computes the log probabilities.""" batch = nested_detach(batch, clone=True) # avoid error model_inputs = { "input_ids": batch[f"{prefix}input_ids"], @@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer): @override def concatenated_forward( - self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] - ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"] + ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: target_logits, target_logps, target_logps_avg = self.forward(model, batch) with torch.no_grad(): _, kl_logps, _ = self.forward(model, batch, prefix="kl_") @@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer): @override def compute_reference_log_probs( - self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] - ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: - r""" - Computes log probabilities of the reference model. - """ + self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"] + ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r"""Compute log probabilities of the reference model.""" if self.ref_model is None: ref_model = model ref_context = self.accelerator.unwrap_model(model).disable_adapter() @@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer): def get_batch_loss_metrics( self, model: "PreTrainedModel", - batch: Dict[str, "torch.Tensor"], - ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: - r""" - Computes the DPO loss and other metrics for the given batch of inputs for train or test. - """ + batch: dict[str, "torch.Tensor"], + ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]: + r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" metrics = {} ( policy_chosen_logps, @@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer): @override def compute_loss( - self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs - ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: - r""" - Subclass and override to accept extra kwargs. - """ + self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs + ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]: + r"""Subclass and override to accept extra kwargs.""" return super().compute_loss(model, inputs, return_outputs) @override - def log(self, logs: Dict[str, float], *args, **kwargs) -> None: - r""" - Log `logs` on the various objects watching training, including stored metrics. - """ + def log(self, logs: dict[str, float], *args, **kwargs) -> None: + r"""Log `logs` on the various objects watching training, including stored metrics.""" # logs either has "loss" or "eval_loss" train_eval = "train" if "loss" in logs else "eval" prefix = "eval_" if train_eval == "eval" else "" @@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer): metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device) metric_list = self.accelerator.reduce(metric_list, "sum").tolist() - metric_dict: Dict[str, float] = dict(zip(key_list, metric_list)) + metric_dict: dict[str, float] = dict(zip(key_list, metric_list)) for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths if f"count/{split}" in metric_dict: for key in ("rewards", "logps", "logits"): diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py index e98510d5..74668720 100644 --- a/src/llamafactory/train/kto/workflow.py +++ b/src/llamafactory/train/kto/workflow.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...extras.constants import IGNORE_INDEX @@ -37,7 +37,7 @@ def run_kto( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", - callbacks: Optional[List["TrainerCallback"]] = None, + callbacks: Optional[list["TrainerCallback"]] = None, ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] diff --git a/src/llamafactory/train/ppo/ppo_utils.py b/src/llamafactory/train/ppo/ppo_utils.py index 55b79b4e..9d462e77 100644 --- a/src/llamafactory/train/ppo/ppo_utils.py +++ b/src/llamafactory/train/ppo/ppo_utils.py @@ -14,7 +14,7 @@ import json from contextlib import nullcontext -from typing import TYPE_CHECKING, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Literal, Optional import torch from transformers.integrations import is_deepspeed_zero3_enabled @@ -31,10 +31,8 @@ if TYPE_CHECKING: from trl import AutoModelForCausalLMWithValueHead -def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]: - r""" - Gets reward scores from the API server. - """ +def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]: + r"""Get reward scores from the API server.""" headers = {"Content-Type": "application/json"} payload = {"model": "model", "messages": messages} response = requests.post(server_url, json=payload, headers=headers) @@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: - r""" - Replaces the default/reward modules in the model. The model is already unwrapped. - """ + r"""Replace the default/reward modules in the model. The model is already unwrapped.""" v_head_layer = model.v_head.summary if is_deepspeed_zero3_enabled(): import deepspeed # type: ignore @@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device) -def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]: - r""" - Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered). - """ +def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]: + r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered).""" layer_norm_params = {} for name, param in model.named_parameters(): if param.data.dtype == torch.float32: @@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]: return layer_norm_params -def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None: - r""" - Restores the layernorm parameters in the model. The model is already unwrapped (and gathered). - """ +def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None: + r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered).""" for name, param in model.named_parameters(): if name in layernorm_params: param.data = layernorm_params[name] diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 4ab7a118..00258acd 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -20,7 +20,7 @@ import os import sys import warnings from types import MethodType -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional import torch from accelerate.utils import DistributedDataParallelKwargs @@ -62,9 +62,7 @@ logger = logging.get_logger(__name__) class CustomPPOTrainer(PPOTrainer, Trainer): - r""" - Inherits PPOTrainer. - """ + r"""Inherit PPOTrainer.""" def __init__( self, @@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", - callbacks: Optional[List["TrainerCallback"]], + callbacks: Optional[list["TrainerCallback"]], model: "AutoModelForCausalLMWithValueHead", reward_model: Optional["AutoModelForCausalLMWithValueHead"], ref_model: Optional["AutoModelForCausalLMWithValueHead"], @@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.add_callback(BAdamCallback) def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: - r""" - Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. - """ + r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.""" if resume_from_checkpoint is not None: raise ValueError("`resume_from_checkpoint` will be supported in the future version.") @@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): logger.info_rank0(f" Num Epochs = {num_train_epochs:,}") logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") logger.info_rank0( - " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format( - total_train_batch_size - ) + f" Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}" ) logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}") logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}") @@ -339,21 +333,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer): return lr_scheduler @torch.no_grad() - def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]: - r""" - Generates model's responses given queries. - """ + def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]: + r"""Generate model's responses given queries.""" if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1 start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item() for k, v in batch.items(): batch[k] = v[:, start_index:] with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: - unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model) if self.model_args.upcast_layernorm: layernorm_params = dump_layernorm(unwrapped_model) - generate_output: "torch.Tensor" = unwrapped_model.generate( + generate_output: torch.Tensor = unwrapped_model.generate( generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch ) if self.model_args.upcast_layernorm: @@ -381,11 +373,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer): @torch.no_grad() def get_rewards( self, - queries: List["torch.Tensor"], - responses: List["torch.Tensor"], - ) -> List["torch.Tensor"]: - r""" - Computes scores using given reward model. + queries: list["torch.Tensor"], + responses: list["torch.Tensor"], + ) -> list["torch.Tensor"]: + r"""Compute scores using given reward model. Both inputs and outputs are put on CPU. """ @@ -394,8 +385,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer): messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False) return get_rewards_from_server(self.reward_model, messages) - batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses) - unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses) + unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model) if self.finetuning_args.reward_model_type == "lora": replace_model(unwrapped_model, target="reward") @@ -404,7 +395,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): reward_model = self.reward_model with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16 - values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1] + values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1] if self.finetuning_args.reward_model_type == "lora": replace_model(unwrapped_model, target="default") @@ -419,12 +410,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer): model: "AutoModelForCausalLMWithValueHead", queries: "torch.Tensor", responses: "torch.Tensor", - model_inputs: Dict[str, Any], + model_inputs: dict[str, Any], return_logits: bool = False, response_masks: Optional["torch.Tensor"] = None, - ) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]: - r""" - Calculates model outputs in multiple batches. + ) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]: + r"""Calculate model outputs in multiple batches. Subclass and override to inject custom behavior. """ @@ -483,8 +473,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): @override def save_model(self, output_dir: Optional[str] = None) -> None: - r""" - Saves model checkpoint. + r"""Save model checkpoint. Subclass and override to inject custom behavior. """ @@ -508,5 +497,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.model.save_checkpoint(output_dir) elif self.args.should_save: - unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model) self._save(output_dir, state_dict=unwrapped_model.state_dict()) diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py index 50210583..4e64a256 100644 --- a/src/llamafactory/train/ppo/workflow.py +++ b/src/llamafactory/train/ppo/workflow.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer from ...extras.ploting import plot_loss @@ -37,7 +37,7 @@ def run_ppo( training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", - callbacks: Optional[List["TrainerCallback"]] = None, + callbacks: Optional[list["TrainerCallback"]] = None, ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] @@ -53,7 +53,7 @@ def run_ppo( reward_model = create_reward_model(model, model_args, finetuning_args) # Initialize our Trainer - ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer( + ppo_trainer: CustomPPOTrainer = CustomPPOTrainer( model_args=model_args, training_args=training_args, finetuning_args=finetuning_args, diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 3024004d..346611f9 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -31,9 +31,7 @@ if TYPE_CHECKING: class CustomTrainer(Trainer): - r""" - Inherits Trainer for custom optimizer. - """ + r"""Inherit Trainer for custom optimizer.""" def __init__( self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index 06afdc12..ecbbe00d 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -16,7 +16,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from transformers import DataCollatorForLanguageModeling @@ -38,7 +38,7 @@ def run_pt( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", - callbacks: Optional[List["TrainerCallback"]] = None, + callbacks: Optional[list["TrainerCallback"]] = None, ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] diff --git a/src/llamafactory/train/rm/metric.py b/src/llamafactory/train/rm/metric.py index 6f08b107..a7c3c43f 100644 --- a/src/llamafactory/train/rm/metric.py +++ b/src/llamafactory/train/rm/metric.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Optional import numpy as np @@ -26,11 +26,9 @@ if TYPE_CHECKING: @dataclass class ComputeAccuracy: - r""" - Computes reward accuracy and supports `batch_eval_metrics`. - """ + r"""Compute reward accuracy and support `batch_eval_metrics`.""" - def _dump(self) -> Optional[Dict[str, float]]: + def _dump(self) -> Optional[dict[str, float]]: result = None if hasattr(self, "score_dict"): result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} @@ -41,7 +39,7 @@ class ComputeAccuracy: def __post_init__(self): self._dump() - def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: + def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]: chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1]) if not chosen_scores.shape: self.score_dict["accuracy"].append(chosen_scores > rejected_scores) diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index 11c0cbc4..508f7516 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -18,7 +18,7 @@ import json import os from types import MethodType -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import torch from transformers import Trainer @@ -41,9 +41,7 @@ logger = logging.get_logger(__name__) class PairwiseTrainer(Trainer): - r""" - Inherits Trainer to compute pairwise loss. - """ + r"""Inherits Trainer to compute pairwise loss.""" def __init__( self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs @@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer): @override def compute_loss( - self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs - ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: - r""" - Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. + self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs + ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]: + r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected. Subclass and override to inject custom behavior. @@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer): return loss def save_predictions(self, predict_results: "PredictionOutput") -> None: - r""" - Saves model predictions to `output_dir`. + r"""Save model predictions to `output_dir`. A custom behavior that not contained in Seq2SeqTrainer. """ @@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer): chosen_scores, rejected_scores = predict_results.predictions with open(output_prediction_file, "w", encoding="utf-8") as writer: - res: List[str] = [] + res: list[str] = [] for c_score, r_score in zip(chosen_scores, rejected_scores): res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)})) diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py index bced1018..14607cca 100644 --- a/src/llamafactory/train/rm/workflow.py +++ b/src/llamafactory/train/rm/workflow.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...extras.ploting import plot_loss @@ -37,7 +37,7 @@ def run_rm( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", - callbacks: Optional[List["TrainerCallback"]] = None, + callbacks: Optional[list["TrainerCallback"]] = None, ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py index b64df0c5..323d6241 100644 --- a/src/llamafactory/train/sft/metric.py +++ b/src/llamafactory/train/sft/metric.py @@ -17,7 +17,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Optional import numpy as np import torch @@ -45,9 +45,7 @@ if is_rouge_available(): def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor": - r""" - Computes the token with the largest likelihood to reduce memory footprint. - """ + r"""Compute the token with the largest likelihood to reduce memory footprint.""" if isinstance(logits, (list, tuple)): if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size) logits = logits[0] @@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor @dataclass class ComputeAccuracy: - r""" - Computes accuracy and supports `batch_eval_metrics`. - """ + r"""Compute accuracy and support `batch_eval_metrics`.""" - def _dump(self) -> Optional[Dict[str, float]]: + def _dump(self) -> Optional[dict[str, float]]: result = None if hasattr(self, "score_dict"): result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} @@ -77,7 +73,7 @@ class ComputeAccuracy: def __post_init__(self): self._dump() - def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: + def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]: preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids) for i in range(len(preds)): pred, label = preds[i, :-1], labels[i, 1:] @@ -90,15 +86,14 @@ class ComputeAccuracy: @dataclass class ComputeSimilarity: - r""" - Computes text similarity scores and supports `batch_eval_metrics`. + r"""Compute text similarity scores and support `batch_eval_metrics`. Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer. """ tokenizer: "PreTrainedTokenizer" - def _dump(self) -> Optional[Dict[str, float]]: + def _dump(self) -> Optional[dict[str, float]]: result = None if hasattr(self, "score_dict"): result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} @@ -109,7 +104,7 @@ class ComputeSimilarity: def __post_init__(self): self._dump() - def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: + def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]: preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids) preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 9050a9ba..ea22c6bb 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -18,7 +18,7 @@ import json import os from types import MethodType -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import torch @@ -44,21 +44,19 @@ logger = logging.get_logger(__name__) class CustomSeq2SeqTrainer(Seq2SeqTrainer): - r""" - Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. - """ + r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.""" def __init__( self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], - gen_kwargs: Optional[Dict[str, Any]] = None, + gen_kwargs: Optional[dict[str, Any]] = None, **kwargs, ) -> None: if is_transformers_version_greater_than("4.46"): kwargs["processing_class"] = kwargs.pop("tokenizer") else: - self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer") + self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer") super().__init__(**kwargs) self.finetuning_args = finetuning_args @@ -99,13 +97,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): def prediction_step( self, model: "torch.nn.Module", - inputs: Dict[str, Union["torch.Tensor", Any]], + inputs: dict[str, Union["torch.Tensor", Any]], prediction_loss_only: bool, - ignore_keys: Optional[List[str]] = None, + ignore_keys: Optional[list[str]] = None, **gen_kwargs, - ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]: - r""" - Removes the prompt part in the generated tokens. + ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]: + r"""Remove the prompt part in the generated tokens. Subclass and override to inject custom behavior. """ @@ -126,8 +123,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): def save_predictions( self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True ) -> None: - r""" - Saves model predictions to `output_dir`. + r"""Save model predictions to `output_dir`. A custom behavior that not contained in Seq2SeqTrainer. """ diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 5b904244..1474a74a 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer from ...extras.constants import IGNORE_INDEX @@ -43,7 +43,7 @@ def run_sft( training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", - callbacks: Optional[List["TrainerCallback"]] = None, + callbacks: Optional[list["TrainerCallback"]] = None, ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] diff --git a/src/llamafactory/train/test_utils.py b/src/llamafactory/train/test_utils.py index f8ba3510..ceffe1e0 100644 --- a/src/llamafactory/train/test_utils.py +++ b/src/llamafactory/train/test_utils.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, Union import torch from peft import PeftModel @@ -43,7 +44,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True -def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: +def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]: linear_modules, extra_modules = set(), set() for name, param in model.named_parameters(): if any(module in name for module in ["lora_A", "lora_B"]): @@ -83,7 +84,7 @@ def load_reference_model( ) -> Union["PreTrainedModel", "LoraModel"]: current_device = get_current_device() if add_valuehead: - model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained( + model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained( model_path, torch_dtype=torch.float16, device_map=current_device ) if not is_trainable: @@ -111,7 +112,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule": def patch_valuehead_model() -> None: - def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None: + def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None: state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")} self.v_head.load_state_dict(state_dict, strict=False) del state_dict diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 161e2860..5f6aeb8f 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -21,7 +21,7 @@ import json import os from collections.abc import Mapping from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from transformers import Trainer @@ -63,12 +63,10 @@ logger = logging.get_logger(__name__) class DummyOptimizer(torch.optim.Optimizer): - r""" - A dummy optimizer used for the GaLore or APOLLO algorithm. - """ + r"""A dummy optimizer used for the GaLore or APOLLO algorithm.""" def __init__( - self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None + self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None ) -> None: dummy_tensor = torch.randn(1, 1) self.optimizer_dict = optimizer_dict @@ -112,8 +110,7 @@ def create_modelcard_and_push( def create_ref_model( model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False ) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]: - r""" - Creates reference model for PPO/DPO training. Evaluation mode is not supported. + r"""Create reference model for PPO/DPO training. Evaluation mode is not supported. The valuehead parameter is randomly initialized since it is useless for PPO training. """ @@ -148,9 +145,7 @@ def create_ref_model( def create_reward_model( model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments" ) -> Optional["AutoModelForCausalLMWithValueHead"]: - r""" - Creates reward model for PPO training. - """ + r"""Create reward model for PPO training.""" if finetuning_args.reward_model_type == "api": assert finetuning_args.reward_model.startswith("http"), "Please provide full url." logger.info_rank0(f"Use reward server {finetuning_args.reward_model}") @@ -189,10 +184,8 @@ def create_reward_model( return reward_model -def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: - r""" - Returns a list of names of parameters with weight decay. (weights in non-layernorm layers) - """ +def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]: + r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers).""" decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS) decay_parameters = [name for name in decay_parameters if "bias" not in name] return decay_parameters @@ -208,7 +201,7 @@ def _create_galore_optimizer( else: galore_targets = finetuning_args.galore_target - galore_params: List["torch.nn.Parameter"] = [] + galore_params: list[torch.nn.Parameter] = [] for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets): for param in module.parameters(): @@ -224,7 +217,7 @@ def _create_galore_optimizer( id_galore_params = {id(param) for param in galore_params} decay_params, nodecay_params = [], [] # they are non-galore parameters - trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params + trainable_params: list[torch.nn.Parameter] = [] # galore_params + decay_params + nodecay_params decay_param_names = _get_decay_parameter_names(model) for name, param in model.named_parameters(): if param.requires_grad: @@ -251,7 +244,7 @@ def _create_galore_optimizer( if training_args.gradient_accumulation_steps != 1: raise ValueError("Per-layer GaLore does not support gradient accumulation.") - optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {} + optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {} for param in nodecay_params: param_groups = [dict(params=[param], weight_decay=0.0)] optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) @@ -296,7 +289,7 @@ def _create_apollo_optimizer( else: apollo_targets = finetuning_args.apollo_target - apollo_params: List["torch.nn.Parameter"] = [] + apollo_params: list[torch.nn.Parameter] = [] for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets): for param in module.parameters(): @@ -315,7 +308,7 @@ def _create_apollo_optimizer( id_apollo_params = {id(param) for param in apollo_params} decay_params, nodecay_params = [], [] # they are non-apollo parameters - trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params + trainable_params: list[torch.nn.Parameter] = [] # apollo_params + decay_params + nodecay_params decay_param_names = _get_decay_parameter_names(model) for name, param in model.named_parameters(): if param.requires_grad: @@ -338,7 +331,7 @@ def _create_apollo_optimizer( if training_args.gradient_accumulation_steps != 1: raise ValueError("Per-layer APOLLO does not support gradient accumulation.") - optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {} + optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {} for param in nodecay_params: param_groups = [dict(params=[param], weight_decay=0.0)] optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) @@ -380,7 +373,7 @@ def _create_loraplus_optimizer( embedding_lr = finetuning_args.loraplus_lr_embedding decay_param_names = _get_decay_parameter_names(model) - param_dict: Dict[str, List["torch.nn.Parameter"]] = { + param_dict: dict[str, list[torch.nn.Parameter]] = { "lora_a": [], "lora_b": [], "lora_b_nodecay": [], @@ -524,7 +517,7 @@ def create_custom_scheduler( ) -> None: if optimizer is not None and isinstance(optimizer, DummyOptimizer): optimizer_dict = optimizer.optimizer_dict - scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {} + scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {} for param in optimizer_dict.keys(): scheduler_dict[param] = get_scheduler( @@ -544,13 +537,13 @@ def create_custom_scheduler( def get_batch_logps( logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX -) -> Tuple["torch.Tensor", "torch.Tensor"]: - r""" - Computes the log probabilities of the given labels under the given logits. +) -> tuple["torch.Tensor", "torch.Tensor"]: + r"""Compute the log probabilities of the given labels under the given logits. Returns: logps: A tensor of shape (batch_size,) containing the sum of log probabilities. valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens. + """ if logits.shape[:-1] != labels.shape: raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.") @@ -564,12 +557,10 @@ def get_batch_logps( def nested_detach( - tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]], + tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]], clone: bool = False, ): - r""" - Detach `tensors` (even if it's a nested list/tuple/dict of tensors). - """ + r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors).""" if isinstance(tensors, (list, tuple)): return type(tensors)(nested_detach(t, clone=clone) for t in tensors) elif isinstance(tensors, Mapping): @@ -585,9 +576,7 @@ def nested_detach( def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback": - r""" - Gets the callback for logging to SwanLab. - """ + r"""Get the callback for logging to SwanLab.""" import swanlab # type: ignore from swanlab.integration.transformers import SwanLabCallback # type: ignore @@ -624,7 +613,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall def get_ray_trainer( training_function: Callable, - train_loop_config: Dict[str, Any], + train_loop_config: dict[str, Any], ray_args: "RayArguments", ) -> "TorchTrainer": if not ray_args.use_ray: diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 1b9ad082..c5d926ac 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -14,7 +14,7 @@ import os import shutil -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional import torch import torch.distributed as dist @@ -48,9 +48,9 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def _training_function(config: Dict[str, Any]) -> None: +def _training_function(config: dict[str, Any]) -> None: args = config.get("args") - callbacks: List[Any] = config.get("callbacks") + callbacks: list[Any] = config.get("callbacks") model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) callbacks.append(LogCallback()) @@ -84,7 +84,7 @@ def _training_function(config: Dict[str, Any]) -> None: logger.warning(f"Failed to destroy process group: {e}.") -def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None: +def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None: args = read_args(args) if "-h" in args or "--help" in args: get_train_args(args) @@ -103,7 +103,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra _training_function(config={"args": args, "callbacks": callbacks}) -def export_model(args: Optional[Dict[str, Any]] = None) -> None: +def export_model(args: Optional[dict[str, Any]] = None) -> None: model_args, data_args, finetuning_args, _ = get_infer_args(args) if model_args.export_dir is None: diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index 944a3d06..575cd584 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -14,7 +14,8 @@ import json import os -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple +from collections.abc import Generator +from typing import TYPE_CHECKING, Any, Optional from transformers.utils import is_torch_npu_available @@ -37,15 +38,12 @@ if is_gradio_available(): def _escape_html(text: str) -> str: - r""" - Escapes HTML characters. - """ + r"""Escape HTML characters.""" return text.replace("<", "<").replace(">", ">") -def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str: - r""" - Post-processes the response text. +def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str: + r"""Post-process the response text. Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py """ @@ -74,7 +72,7 @@ class WebChatModel(ChatModel): def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None: self.manager = manager self.demo_mode = demo_mode - self.engine: Optional["BaseEngine"] = None + self.engine: Optional[BaseEngine] = None if not lazy_init: # read arguments from command line super().__init__() @@ -160,14 +158,13 @@ class WebChatModel(ChatModel): @staticmethod def append( - chatbot: List[Dict[str, str]], - messages: List[Dict[str, str]], + chatbot: list[dict[str, str]], + messages: list[dict[str, str]], role: str, query: str, escape_html: bool, - ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]: - r""" - Adds the user input to chatbot. + ) -> tuple[list[dict[str, str]], list[dict[str, str]], str]: + r"""Add the user input to chatbot. Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html Output: infer.chatbot, infer.messages, infer.query @@ -180,8 +177,8 @@ class WebChatModel(ChatModel): def stream( self, - chatbot: List[Dict[str, str]], - messages: List[Dict[str, str]], + chatbot: list[dict[str, str]], + messages: list[dict[str, str]], lang: str, system: str, tools: str, @@ -193,9 +190,8 @@ class WebChatModel(ChatModel): temperature: float, skip_special_tokens: bool, escape_html: bool, - ) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]: - r""" - Generates output text in stream. + ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]: + r"""Generate output text in stream. Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ... Output: infer.chatbot, infer.messages diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index f8349714..2387174a 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -17,7 +17,7 @@ import os import signal from collections import defaultdict from datetime import datetime -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from psutil import Process from yaml import safe_dump, safe_load @@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml" def abort_process(pid: int) -> None: - r""" - Aborts the processes recursively in a bottom-up way. - """ + r"""Abort the processes recursively in a bottom-up way.""" try: children = Process(pid).children() if children: @@ -59,9 +57,7 @@ def abort_process(pid: int) -> None: def get_save_dir(*paths: str) -> os.PathLike: - r""" - Gets the path to saved model checkpoints. - """ + r"""Get the path to saved model checkpoints.""" if os.path.sep in paths[-1]: logger.warning_rank0("Found complex path, some features may be not available.") return paths[-1] @@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike: def _get_config_path() -> os.PathLike: - r""" - Gets the path to user config. - """ + r"""Get the path to user config.""" return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) -def load_config() -> Dict[str, Union[str, Dict[str, Any]]]: - r""" - Loads user config if exists. - """ +def load_config() -> dict[str, Union[str, dict[str, Any]]]: + r"""Load user config if exists.""" try: with open(_get_config_path(), encoding="utf-8") as f: return safe_load(f) @@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]: def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None: - r""" - Saves user config. - """ + r"""Save user config.""" os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) user_config = load_config() user_config["lang"] = lang or user_config["lang"] @@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona def get_model_path(model_name: str) -> str: - r""" - Gets the model path according to the model name. - """ + r"""Get the model path according to the model name.""" user_config = load_config() - path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str)) + path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str)) model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "") if ( use_modelscope() @@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str: def get_template(model_name: str) -> str: - r""" - Gets the template name if the model is a chat/distill/instruct model. - """ + r"""Get the template name if the model is a chat/distill/instruct model.""" return DEFAULT_TEMPLATE.get(model_name, "default") def get_time() -> str: - r""" - Gets current date and time. - """ + r"""Get current date and time.""" return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S") def is_multimodal(model_name: str) -> bool: - r""" - Judges if the model is a vision language model. - """ + r"""Judge if the model is a vision language model.""" return model_name in MULTIMODAL_SUPPORTED_MODELS -def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: - r""" - Loads dataset_info.json. - """ +def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]: + r"""Load dataset_info.json.""" if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"): logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.") return {} @@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: return {} -def load_args(config_path: str) -> Optional[Dict[str, Any]]: - r""" - Loads the training configuration from config path. - """ +def load_args(config_path: str) -> Optional[dict[str, Any]]: + r"""Load the training configuration from config path.""" try: with open(config_path, encoding="utf-8") as f: return safe_load(f) @@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]: return None -def save_args(config_path: str, config_dict: Dict[str, Any]) -> None: - r""" - Saves the training configuration to config path. - """ +def save_args(config_path: str, config_dict: dict[str, Any]) -> None: + r"""Save the training configuration to config path.""" with open(config_path, "w", encoding="utf-8") as f: safe_dump(config_dict, f) -def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]: - r""" - Removes args with NoneType or False or empty string value. - """ +def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]: + r"""Remove args with NoneType or False or empty string value.""" no_skip_keys = ["packing"] return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")} -def gen_cmd(args: Dict[str, Any]) -> str: - r""" - Generates CLI commands for previewing. - """ +def gen_cmd(args: dict[str, Any]) -> str: + r"""Generate CLI commands for previewing.""" cmd_lines = ["llamafactory-cli train "] for k, v in _clean_cmd(args).items(): if isinstance(v, dict): @@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str: return cmd_text -def save_cmd(args: Dict[str, Any]) -> str: - r""" - Saves CLI commands to launch training. - """ +def save_cmd(args: dict[str, Any]) -> str: + r"""Save CLI commands to launch training.""" output_dir = args["output_dir"] os.makedirs(output_dir, exist_ok=True) with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f: @@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str: def load_eval_results(path: os.PathLike) -> str: - r""" - Gets scores after evaluation. - """ + r"""Get scores after evaluation.""" with open(path, encoding="utf-8") as f: result = json.dumps(json.load(f), indent=4) @@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str: def create_ds_config() -> None: - r""" - Creates deepspeed config in the current directory. - """ + r"""Create deepspeed config in the current directory.""" os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) ds_config = { "train_batch_size": "auto", diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py index 51e9691d..52217e16 100644 --- a/src/llamafactory/webui/components/chatbot.py +++ b/src/llamafactory/webui/components/chatbot.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from typing import TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING from ...data import Role from ...extras.packages import is_gradio_available @@ -31,9 +31,7 @@ if TYPE_CHECKING: def check_json_schema(text: str, lang: str) -> None: - r""" - Checks if the json schema is valid. - """ + r"""Check if the json schema is valid.""" try: tools = json.loads(text) if tools: @@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None: def create_chat_box( engine: "Engine", visible: bool = False -) -> Tuple["Component", "Component", Dict[str, "Component"]]: +) -> tuple["Component", "Component", dict[str, "Component"]]: lang = engine.manager.get_elem_by_id("top.lang") with gr.Column(visible=visible) as chat_box: chatbot = gr.Chatbot(type="messages", show_copy_button=True) diff --git a/src/llamafactory/webui/components/data.py b/src/llamafactory/webui/components/data.py index 1dbc68d5..8f27bd19 100644 --- a/src/llamafactory/webui/components/data.py +++ b/src/llamafactory/webui/components/data.py @@ -14,7 +14,7 @@ import json import os -from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Any from ...extras.constants import DATA_CONFIG from ...extras.packages import is_gradio_available @@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int: def can_preview(dataset_dir: str, dataset: list) -> "gr.Button": - r""" - Checks if the dataset is a local dataset. - """ + r"""Check if the dataset is a local dataset.""" try: with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: dataset_info = json.load(f) @@ -59,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button": return gr.Button(interactive=False) -def _load_data_file(file_path: str) -> List[Any]: +def _load_data_file(file_path: str) -> list[Any]: with open(file_path, encoding="utf-8") as f: if file_path.endswith(".json"): return json.load(f) @@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]: return list(f) -def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]: - r""" - Gets the preview samples from the dataset. - """ +def get_preview(dataset_dir: str, dataset: list, page_index: int) -> tuple[int, list, "gr.Column"]: + r"""Get the preview samples from the dataset.""" with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: dataset_info = json.load(f) @@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True) -def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]: +def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> dict[str, "Component"]: data_preview_btn = gr.Button(interactive=False, scale=1) with gr.Column(visible=False, elem_classes="modal-box") as preview_box: with gr.Row(): diff --git a/src/llamafactory/webui/components/eval.py b/src/llamafactory/webui/components/eval.py index 7be0a5b4..3804a77d 100644 --- a/src/llamafactory/webui/components/eval.py +++ b/src/llamafactory/webui/components/eval.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from ...extras.packages import is_gradio_available from ..common import DEFAULT_DATA_DIR @@ -30,7 +30,7 @@ if TYPE_CHECKING: from ..engine import Engine -def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: +def create_eval_tab(engine: "Engine") -> dict[str, "Component"]: input_elems = engine.manager.get_base_elems() elem_dict = dict() diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index c5034222..d292ee4f 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Generator, List, Union +from collections.abc import Generator +from typing import TYPE_CHECKING, Union from ...extras.constants import PEFT_METHODS from ...extras.misc import torch_gc @@ -35,7 +36,7 @@ if TYPE_CHECKING: GPTQ_BITS = ["8", "4", "3", "2"] -def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown": +def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown": if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0: return gr.Dropdown(value="none", interactive=False) else: @@ -47,7 +48,7 @@ def save_model( model_name: str, model_path: str, finetuning_type: str, - checkpoint_path: Union[str, List[str]], + checkpoint_path: Union[str, list[str]], template: str, export_size: int, export_quantization_bit: str, @@ -106,7 +107,7 @@ def save_model( yield ALERTS["info_exported"][lang] -def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: +def create_export_tab(engine: "Engine") -> dict[str, "Component"]: with gr.Row(): export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1) export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none") diff --git a/src/llamafactory/webui/components/infer.py b/src/llamafactory/webui/components/infer.py index 48536a36..5d0f4e88 100644 --- a/src/llamafactory/webui/components/infer.py +++ b/src/llamafactory/webui/components/infer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from ...extras.packages import is_gradio_available from ..common import is_multimodal @@ -29,7 +29,7 @@ if TYPE_CHECKING: from ..engine import Engine -def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: +def create_infer_tab(engine: "Engine") -> dict[str, "Component"]: input_elems = engine.manager.get_base_elems() elem_dict = dict() diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py index 978f93cd..f3616455 100644 --- a/src/llamafactory/webui/components/top.py +++ b/src/llamafactory/webui/components/top.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from ...data import TEMPLATES from ...extras.constants import METHODS, SUPPORTED_MODELS @@ -29,7 +29,7 @@ if TYPE_CHECKING: from gradio.components import Component -def create_top() -> Dict[str, "Component"]: +def create_top() -> dict[str, "Component"]: with gr.Row(): lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1) available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index bb40027a..7ca99647 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from transformers.trainer_utils import SchedulerType @@ -34,7 +34,7 @@ if TYPE_CHECKING: from ..engine import Engine -def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: +def create_train_tab(engine: "Engine") -> dict[str, "Component"]: input_elems = engine.manager.get_base_elems() elem_dict = dict() @@ -382,8 +382,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) lang = engine.manager.get_elem_by_id("top.lang") - model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name") - finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type") + model_name: gr.Dropdown = engine.manager.get_elem_by_id("top.model_name") + finetuning_type: gr.Dropdown = engine.manager.get_elem_by_id("top.finetuning_type") arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None) arg_load_btn.click( diff --git a/src/llamafactory/webui/control.py b/src/llamafactory/webui/control.py index 21997964..08aed40d 100644 --- a/src/llamafactory/webui/control.py +++ b/src/llamafactory/webui/control.py @@ -14,7 +14,7 @@ import json import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from transformers.trainer_utils import get_last_checkpoint @@ -39,8 +39,7 @@ if is_gradio_available(): def can_quantize(finetuning_type: str) -> "gr.Dropdown": - r""" - Judges if the quantization is available in this finetuning type. + r"""Judge if the quantization is available in this finetuning type. Inputs: top.finetuning_type Outputs: top.quantization_bit @@ -52,8 +51,7 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown": def can_quantize_to(quantization_method: str) -> "gr.Dropdown": - r""" - Gets the available quantization bits. + r"""Get the available quantization bits. Inputs: top.quantization_method Outputs: top.quantization_bit @@ -68,9 +66,8 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown": return gr.Dropdown(choices=available_bits) -def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]: - r""" - Modifys states after changing the training stage. +def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> tuple[list[str], bool]: + r"""Modify states after changing the training stage. Inputs: train.training_stage Outputs: train.dataset, train.packing @@ -78,9 +75,8 @@ def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple return [], TRAINING_STAGES[training_stage] == "pt" -def get_model_info(model_name: str) -> Tuple[str, str]: - r""" - Gets the necessary information of this model. +def get_model_info(model_name: str) -> tuple[str, str]: + r"""Get the necessary information of this model. Inputs: top.model_name Outputs: top.model_path, top.template @@ -88,9 +84,8 @@ def get_model_info(model_name: str) -> Tuple[str, str]: return get_model_path(model_name), get_template(model_name) -def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]: - r""" - Gets training infomation for monitor. +def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]: + r"""Get training infomation for monitor. If do_train is True: Inputs: top.lang, train.output_path @@ -110,7 +105,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup trainer_log_path = os.path.join(output_path, TRAINER_LOG) if os.path.isfile(trainer_log_path): - trainer_log: List[Dict[str, Any]] = [] + trainer_log: list[dict[str, Any]] = [] with open(trainer_log_path, encoding="utf-8") as f: for line in f: trainer_log.append(json.loads(line)) @@ -143,8 +138,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown": - r""" - Lists all available checkpoints. + r"""List all available checkpoints. Inputs: top.model_name, top.finetuning_type Outputs: top.checkpoint_path @@ -166,8 +160,7 @@ def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown": def list_config_paths(current_time: str) -> "gr.Dropdown": - r""" - Lists all the saved configuration files. + r"""List all the saved configuration files. Inputs: train.current_time Outputs: train.config_path @@ -182,8 +175,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown": def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": - r""" - Lists all available datasets in the dataset dir for the training stage. + r"""List all available datasets in the dataset dir for the training stage. Inputs: *.dataset_dir, *.training_stage Outputs: *.dataset @@ -195,8 +187,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown": - r""" - Lists all the directories that can resume from. + r"""List all the directories that can resume from. Inputs: top.model_name, top.finetuning_type, train.current_time Outputs: train.output_dir diff --git a/src/llamafactory/webui/engine.py b/src/llamafactory/webui/engine.py index 2708139d..8844cacf 100644 --- a/src/llamafactory/webui/engine.py +++ b/src/llamafactory/webui/engine.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any from .chatter import WebChatModel from .common import create_ds_config, get_time, load_config @@ -26,9 +26,7 @@ if TYPE_CHECKING: class Engine: - r""" - A general engine to control the behaviors of Web UI. - """ + r"""A general engine to control the behaviors of Web UI.""" def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None: self.demo_mode = demo_mode @@ -39,11 +37,9 @@ class Engine: if not demo_mode: create_ds_config() - def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]: - r""" - Updates gradio components according to the (elem_id, properties) mapping. - """ - output_dict: Dict["Component", "Component"] = {} + def _update_component(self, input_dict: dict[str, dict[str, Any]]) -> dict["Component", "Component"]: + r"""Update gradio components according to the (elem_id, properties) mapping.""" + output_dict: dict[Component, Component] = {} for elem_id, elem_attr in input_dict.items(): elem = self.manager.get_elem_by_id(elem_id) output_dict[elem] = elem.__class__(**elem_attr) @@ -51,9 +47,7 @@ class Engine: return output_dict def resume(self): - r""" - Gets the initial value of gradio components and restores training status if necessary. - """ + r"""Get the initial value of gradio components and restores training status if necessary.""" user_config = load_config() if not self.demo_mode else {} # do not use config in demo mode lang = user_config.get("lang", None) or "en" init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}} @@ -79,9 +73,7 @@ class Engine: yield self._update_component({"eval.resume_btn": {"value": True}}) def change_lang(self, lang: str): - r""" - Updates the displayed language of gradio components. - """ + r"""Update the displayed language of gradio components.""" return { elem: elem.__class__(**LOCALES[elem_name][lang]) for elem_name, elem in self.manager.get_elem_iter() diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py index c6722fb8..325ee623 100644 --- a/src/llamafactory/webui/interface.py +++ b/src/llamafactory/webui/interface.py @@ -48,7 +48,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks": gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button") engine.manager.add_elems("top", create_top()) - lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang") + lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang") with gr.Tab("Train"): engine.manager.add_elems("train", create_train_tab(engine)) diff --git a/src/llamafactory/webui/manager.py b/src/llamafactory/webui/manager.py index 3b6f5a9a..e762fa6b 100644 --- a/src/llamafactory/webui/manager.py +++ b/src/llamafactory/webui/manager.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple +from collections.abc import Generator +from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -20,54 +21,41 @@ if TYPE_CHECKING: class Manager: - r""" - A class to manage all the gradio components in Web UI. - """ + r"""A class to manage all the gradio components in Web UI.""" def __init__(self) -> None: - self._id_to_elem: Dict[str, "Component"] = {} - self._elem_to_id: Dict["Component", str] = {} + self._id_to_elem: dict[str, Component] = {} + self._elem_to_id: dict[Component, str] = {} - def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None: - r""" - Adds elements to manager. - """ + def add_elems(self, tab_name: str, elem_dict: dict[str, "Component"]) -> None: + r"""Add elements to manager.""" for elem_name, elem in elem_dict.items(): elem_id = f"{tab_name}.{elem_name}" self._id_to_elem[elem_id] = elem self._elem_to_id[elem] = elem_id - def get_elem_list(self) -> List["Component"]: - r""" - Returns the list of all elements. - """ + def get_elem_list(self) -> list["Component"]: + r"""Return the list of all elements.""" return list(self._id_to_elem.values()) - def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]: - r""" - Returns an iterator over all elements with their names. - """ + def get_elem_iter(self) -> Generator[tuple[str, "Component"], None, None]: + r"""Return an iterator over all elements with their names.""" for elem_id, elem in self._id_to_elem.items(): yield elem_id.split(".")[-1], elem def get_elem_by_id(self, elem_id: str) -> "Component": - r""" - Gets element by id. + r"""Get element by id. Example: top.lang, train.dataset """ return self._id_to_elem[elem_id] def get_id_by_elem(self, elem: "Component") -> str: - r""" - Gets id by element. - """ + r"""Get id by element.""" return self._elem_to_id[elem] - def get_base_elems(self) -> Set["Component"]: - r""" - Gets the base elements that are commonly used. - """ + def get_base_elems(self) -> set["Component"]: + r"""Get the base elements that are commonly used.""" return { self._id_to_elem["top.lang"], self._id_to_elem["top.model_name"], diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 9fe92f27..8f9176d2 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -14,9 +14,10 @@ import json import os +from collections.abc import Generator from copy import deepcopy from subprocess import Popen, TimeoutExpired -from typing import TYPE_CHECKING, Any, Dict, Generator, Optional +from typing import TYPE_CHECKING, Any, Optional from transformers.trainer import TRAINING_ARGS_NAME from transformers.utils import is_torch_npu_available @@ -51,17 +52,16 @@ if TYPE_CHECKING: class Runner: - r""" - A class to manage the running status of the trainers. - """ + r"""A class to manage the running status of the trainers.""" def __init__(self, manager: "Manager", demo_mode: bool = False) -> None: + r"""Init a runner.""" self.manager = manager self.demo_mode = demo_mode """ Resume """ - self.trainer: Optional["Popen"] = None + self.trainer: Optional[Popen] = None self.do_train = True - self.running_data: Dict["Component", Any] = None + self.running_data: dict[Component, Any] = None """ State """ self.aborted = False self.running = False @@ -71,10 +71,8 @@ class Runner: if self.trainer is not None: abort_process(self.trainer.pid) - def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str: - r""" - Validates the configuration. - """ + def _initialize(self, data: dict["Component", Any], do_train: bool, from_preview: bool) -> str: + r"""Validate the configuration.""" get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path") dataset = get("train.dataset") if do_train else get("eval.dataset") @@ -116,9 +114,7 @@ class Runner: return "" def _finalize(self, lang: str, finish_info: str) -> str: - r""" - Cleans the cached memory and resets the runner. - """ + r"""Clean the cached memory and resets the runner.""" finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info gr.Info(finish_info) self.trainer = None @@ -128,10 +124,8 @@ class Runner: torch_gc() return finish_info - def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: - r""" - Builds and validates the training arguments. - """ + def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]: + r"""Build and validate the training arguments.""" get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() @@ -291,10 +285,8 @@ class Runner: return args - def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: - r""" - Builds and validates the evaluation arguments. - """ + def _parse_eval_args(self, data: dict["Component", Any]) -> dict[str, Any]: + r"""Build and validate the evaluation arguments.""" get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() @@ -345,10 +337,8 @@ class Runner: return args - def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]: - r""" - Previews the training commands. - """ + def _preview(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", str], None, None]: + r"""Preview the training commands.""" output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) error = self._initialize(data, do_train, from_preview=True) if error: @@ -358,10 +348,8 @@ class Runner: args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) yield {output_box: gen_cmd(args)} - def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]: - r""" - Starts the training process. - """ + def _launch(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", Any], None, None]: + r"""Start the training process.""" output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) error = self._initialize(data, do_train, from_preview=False) if error: @@ -383,10 +371,8 @@ class Runner: self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env) yield from self.monitor() - def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]: - r""" - Builds a dictionary containing the current training configuration. - """ + def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]: + r"""Build a dictionary containing the current training configuration.""" config_dict = {} skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"] for elem, value in data.items(): @@ -409,9 +395,7 @@ class Runner: yield from self._launch(data, do_train=False) def monitor(self): - r""" - Monitors the training progress and logs. - """ + r"""Monitorgit the training progress and logs.""" self.aborted = False self.running = True @@ -469,9 +453,7 @@ class Runner: yield return_dict def save_args(self, data): - r""" - Saves the training configuration to config path. - """ + r"""Save the training configuration to config path.""" output_box = self.manager.get_elem_by_id("train.output_box") error = self._initialize(data, do_train=True, from_preview=True) if error: @@ -487,27 +469,23 @@ class Runner: return {output_box: ALERTS["info_config_saved"][lang] + save_path} def load_args(self, lang: str, config_path: str): - r""" - Loads the training configuration from config path. - """ + r"""Load the training configuration from config path.""" output_box = self.manager.get_elem_by_id("train.output_box") config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path)) if config_dict is None: gr.Warning(ALERTS["err_config_not_found"][lang]) return {output_box: ALERTS["err_config_not_found"][lang]} - output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]} + output_dict: dict[Component, Any] = {output_box: ALERTS["info_config_loaded"][lang]} for elem_id, value in config_dict.items(): output_dict[self.manager.get_elem_by_id(elem_id)] = value return output_dict def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str): - r""" - Restore the training status if output_dir exists. - """ + r"""Restore the training status if output_dir exists.""" output_box = self.manager.get_elem_by_id("train.output_box") - output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]} + output_dict: dict[Component, Any] = {output_box: LOCALES["output_box"][lang]["value"]} if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)): gr.Warning(ALERTS["warn_output_dir_exists"][lang]) output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang] diff --git a/tests/data/processor/test_pairwise.py b/tests/data/processor/test_pairwise.py index 3faac9a7..569a55ab 100644 --- a/tests/data/processor/test_pairwise.py +++ b/tests/data/processor/test_pairwise.py @@ -14,7 +14,6 @@ import os import random -from typing import Dict, List import pytest from datasets import load_dataset @@ -43,7 +42,7 @@ TRAIN_ARGS = { } -def _convert_sharegpt_to_openai(messages: List[Dict[str, str]]) -> List[Dict[str, str]]: +def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str, str]]: role_mapping = {"human": "user", "gpt": "assistant", "system": "system"} new_messages = [] for message in messages: diff --git a/tests/data/processor/test_processor_utils.py b/tests/data/processor/test_processor_utils.py index 64d2ab91..e004cb06 100644 --- a/tests/data/processor/test_processor_utils.py +++ b/tests/data/processor/test_processor_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import pytest @@ -31,5 +30,5 @@ from llamafactory.data.processor.processor_utils import infer_seqlen ((10, 10, 1000), (10, 10)), ], ) -def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]): +def test_infer_seqlen(test_input: tuple[int, int, int], test_output: tuple[int, int]): assert test_output == infer_seqlen(*test_input) diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py index 542bafb9..7e08465a 100644 --- a/tests/data/test_formatter.py +++ b/tests/data/test_formatter.py @@ -112,7 +112,8 @@ def test_glm4_tool_formatter(): assert formatter.apply(content=json.dumps(TOOLS)) == [ "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n" - f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。" + f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n" + "在调用上述函数时,请使用 Json 格式表示调用的参数。" ] @@ -136,7 +137,8 @@ def test_llama3_tool_formatter(): wrapped_tool = {"type": "function", "function": TOOLS[0]} assert formatter.apply(content=json.dumps(TOOLS)) == [ f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n" - "You have access to the following functions. To call a function, please respond with JSON for a function call. " + "You have access to the following functions. " + "To call a function, please respond with JSON for a function call. " """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """ f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n" ] diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 0d26acd1..c47842bc 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -13,7 +13,8 @@ # limitations under the License. import os -from typing import TYPE_CHECKING, Any, Dict, List, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any import pytest import torch @@ -69,12 +70,12 @@ LABELS = [0, 1, 2, 3, 4] BATCH_IDS = [[1] * 1024] -def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") +def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor") return image_processor(images=IMAGES, return_tensors="pt") -def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None: +def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None: assert batch_a.keys() == batch_b.keys() for key in batch_a.keys(): if isinstance(batch_a[key], torch.Tensor): @@ -96,11 +97,11 @@ def _check_plugin( plugin: "BasePlugin", tokenizer: "PreTrainedTokenizer", processor: "ProcessorMixin", - expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES, - expected_input_ids: List[int] = INPUT_IDS, - expected_labels: List[int] = LABELS, - expected_mm_inputs: Dict[str, Any] = {}, - expected_no_mm_inputs: Dict[str, Any] = {}, + expected_mm_messages: Sequence[dict[str, str]] = MM_MESSAGES, + expected_input_ids: list[int] = INPUT_IDS, + expected_labels: list[int] = LABELS, + expected_mm_inputs: dict[str, Any] = {}, + expected_no_mm_inputs: dict[str, Any] = {}, ) -> None: # test mm_messages if plugin.__class__.__name__ != "BasePlugin": diff --git a/tests/data/test_template.py b/tests/data/test_template.py index e2e6b942..b3f2052e 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -13,7 +13,8 @@ # limitations under the License. import os -from typing import TYPE_CHECKING, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING import pytest from transformers import AutoTokenizer @@ -42,8 +43,7 @@ MESSAGES = [ def _check_tokenization( tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str] ) -> None: - r""" - Checks token ids and texts. + r"""Check token ids and texts. encode(text) == token_ids decode(token_ids) == text @@ -54,8 +54,7 @@ def _check_tokenization( def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None: - r""" - Checks template. + r"""Check template. Args: model_id: the model id on hugging face hub. @@ -63,6 +62,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s prompt_str: the string corresponding to the prompt part. answer_str: the string corresponding to the answer part. use_fast: whether to use fast tokenizer. + """ tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN) content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False) diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py index f0246016..36c20434 100644 --- a/tests/model/model_utils/test_checkpointing.py +++ b/tests/model/model_utils/test_checkpointing.py @@ -62,5 +62,5 @@ def test_upcast_layernorm(): def test_upcast_lmhead_output(): model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS) inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device()) - outputs: "torch.Tensor" = model.get_output_embeddings()(inputs) + outputs: torch.Tensor = model.get_output_embeddings()(inputs) assert outputs.dtype == torch.float32 diff --git a/tests/train/test_sft_trainer.py b/tests/train/test_sft_trainer.py index 1f84071e..c520bb3a 100644 --- a/tests/train/test_sft_trainer.py +++ b/tests/train/test_sft_trainer.py @@ -14,7 +14,7 @@ import os from dataclasses import dataclass, field -from typing import Any, Dict, List +from typing import Any import pytest from transformers import DataCollatorWithPadding @@ -46,9 +46,9 @@ TRAIN_ARGS = { @dataclass class DataCollatorWithVerbose(DataCollatorWithPadding): - verbose_list: List[Dict[str, Any]] = field(default_factory=list) + verbose_list: list[dict[str, Any]] = field(default_factory=list) - def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: self.verbose_list.extend(features) batch = super().__call__(features) return {k: v[:, :1] for k, v in batch.items()} # truncate input length