[misc] upgrade format to py39 (#7256)

Author: hoshi-hiyouga (committed via GitHub)
Date: 2025-03-12 00:08:41 +08:00
Commit: 7c1640ed5f (parent: cdafa8a15e)
113 changed files with 984 additions and 1407 deletions


@@ -10,7 +10,7 @@ _DESCRIPTION = "BELLE multiturn chat dataset."
 _CITATION = """\
 @article{belle2023exploring,
-  title={Exploring the Impact of Instruction Data Scaling on Large Language Models: An Empirical Study on Real-World Use Cases},
+  title={Exploring the Impact of Instruction Data Scaling on Large Language Models},
   author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li},
   journal={arXiv preprint arXiv:2303.14742},
   year={2023}


@@ -1,6 +1,5 @@
 import json
 import os
-from typing import List
 import datasets
@@ -50,7 +49,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
             datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
         ]
-    def _generate_examples(self, filepaths: List[str]):
+    def _generate_examples(self, filepaths: list[str]):
         key = 0
         for filepath in filepaths:
             with open(filepath, encoding="utf-8") as f:
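
The `List[str]` → `list[str]` changes in this commit rely on PEP 585: from Python 3.9 the built-in collection types are subscriptable, so `typing.List`, `typing.Dict` and friends are no longer needed, and abstract containers move to `collections.abc`. A small illustrative sketch of the before/after style (not taken from the repository):

from collections.abc import Sequence  # replaces typing.Sequence
from typing import Optional  # Optional/Union still live in typing on 3.9


# Python 3.8 style needed `from typing import Dict, List`:
#     def count_lines(paths: List[str]) -> Dict[str, int]: ...

# Python 3.9 style uses built-in generics directly (PEP 585).
def count_lines(paths: list[str]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for path in paths:
        with open(path, encoding="utf-8") as f:
            counts[path] = sum(1 for _ in f)
    return counts


def first_item(items: Sequence[str]) -> Optional[str]:
    return items[0] if items else None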


@@ -1,6 +1,5 @@
 import json
 import os
-from typing import List
 import datasets
@@ -11,7 +10,7 @@ _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dia
 _CITATION = """\
 @misc{UltraChat,
-  author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
+  author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and others},
   title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
   year = {2023},
   publisher = {GitHub},
@@ -40,7 +39,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
         file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)]  # multiple shards
         return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]
-    def _generate_examples(self, filepaths: List[str]):
+    def _generate_examples(self, filepaths: list[str]):
         for filepath in filepaths:
             with open(filepath, encoding="utf-8") as f:
                 for row in f:
@@ -49,7 +48,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
                     except Exception:
                         continue
                     key: int = data["id"]
-                    content: List[str] = data["data"]
+                    content: list[str] = data["data"]
                     if len(content) % 2 == 1:
                         content.pop(-1)
                     if len(content) < 2:
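
For context, the `len(content) % 2` check above keeps `content` as strictly alternating human/assistant turns before an example is emitted. A hedged sketch of that pairing step (function and variable names are illustrative, not the loader's exact code):

def to_turn_pairs(content: list[str]) -> list[tuple[str, str]]:
    if len(content) % 2 == 1:  # drop a dangling, unanswered last turn
        content = content[:-1]
    # consecutive (human, assistant) pairs
    return [(content[i], content[i + 1]) for i in range(0, len(content), 2)]


assert to_turn_pairs(["hi", "hello", "bye"]) == [("hi", "hello")]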


@@ -21,14 +21,15 @@ import pandas as pd
 _CITATION = """\
 @article{huang2023ceval,
   title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
-  author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
+  author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and others},
   journal={arXiv preprint arXiv:2305.08322},
   year={2023}
 }
 """
 _DESCRIPTION = """\
-C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
+C-Eval is a comprehensive Chinese evaluation suite for foundation models.
+It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
 """
 _HOMEPAGE = "https://cevalbenchmark.com"


@@ -21,14 +21,15 @@ import pandas as pd
 _CITATION = """\
 @article{li2023cmmlu,
   title={CMMLU: Measuring massive multitask language understanding in Chinese},
-  author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
+  author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and others},
   journal={arXiv preprint arXiv:2306.09212},
   year={2023}
 }
 """
 _DESCRIPTION = """\
-CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
+CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge
+and reasoning abilities of LLMs within the Chinese language and cultural context.
 """
 _HOMEPAGE = "https://github.com/haonan-li/CMMLU"


@@ -21,14 +21,15 @@ import pandas as pd
 _CITATION = """\
 @article{hendryckstest2021,
   title={Measuring Massive Multitask Language Understanding},
-  author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
+  author={Dan Hendrycks and Collin Burns and others},
   journal={Proceedings of the International Conference on Learning Representations (ICLR)},
   year={2021}
 }
 """
 _DESCRIPTION = """\
-Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
+Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart,
+Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
 """
 _HOMEPAGE = "https://github.com/hendrycks/test"


@@ -19,13 +19,35 @@ dynamic = [
 ]
 [tool.ruff]
-target-version = "py38"
+target-version = "py39"
 line-length = 119
 indent-width = 4
 [tool.ruff.lint]
-ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
-select = ["C", "E", "F", "I", "W"]
+ignore = [
+    "C408",  # collection
+    "C901",  # complex
+    "E731",  # lambda function
+    "E741",  # ambiguous var name
+    "D100",  # no doc public module
+    "D101",  # no doc public class
+    "D102",  # no doc public method
+    "D103",  # no doc public function
+    "D104",  # no doc public package
+    "D105",  # no doc magic method
+    "D107",  # no doc __init__
+]
+extend-select = [
+    "C",  # complexity
+    "E",  # error
+    "F",  # pyflakes
+    "I",  # isort
+    "W",  # warning
+    "UP",  # pyupgrade
+    "D",  # pydocstyle
+    "PT009",  # pytest assert
+    "RUF022",  # sort __all__
+]
 [tool.ruff.lint.isort]
 lines-after-imports = 2
@@ -41,6 +63,9 @@ known-third-party = [
     "trl"
 ]
+[tool.ruff.lint.pydocstyle]
+convention = "google"
 [tool.ruff.format]
 quote-style = "double"
 indent-style = "space"
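
Besides `UP` (pyupgrade), the new configuration enables the `D` (pydocstyle) rules with `convention = "google"`, which is why the docstrings touched in the hunks below are collapsed onto the opening quotes and rewritten in the imperative mood; running `ruff check --fix` with this configuration applies most of the `UP` rewrites automatically. A short illustrative docstring in that style (the function itself is made up, not from the repository):

def scale_lr(base_lr: float, batch_size: int, base_batch_size: int) -> float:
    r"""Scale the learning rate by the square root of the batch-size ratio.

    Args:
        base_lr: Learning rate tuned for ``base_batch_size`` tokens per step.
        batch_size: Effective token batch size of the current run.
        base_batch_size: Reference token batch size.

    Returns:
        The rescaled learning rate.
    """
    return base_lr * (batch_size / base_batch_size) ** 0.5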


@@ -14,7 +14,7 @@
 import json
 import os
-from typing import Sequence
+from collections.abc import Sequence
 from openai import OpenAI
 from transformers.utils.versions import require_version


@@ -15,7 +15,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Any, Dict
+from typing import Any
 import fire
 import torch
@@ -29,13 +29,13 @@ CONFIG_NAME = "config.json"
 def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
-    baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
             shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
             baichuan2_state_dict.update(shard_weight)
-    llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
         if "W_pack" in key:
             proj_size = value.size(0) // 3
@@ -75,7 +75,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 def save_config(input_dir: str, output_dir: str):
     with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
-        llama2_config_dict: Dict[str, Any] = json.load(f)
+        llama2_config_dict: dict[str, Any] = json.load(f)
     llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
     llama2_config_dict.pop("auto_map", None)
@@ -94,8 +94,8 @@ def llamafy_baichuan2(
     shard_size: str = "2GB",
     save_safetensors: bool = True,
 ):
-    r"""
-    Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
+    r"""Convert the Baichuan2-7B model in the same format as LLaMA2-7B.
     Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
     Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
     """


@@ -15,7 +15,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Any, Dict
+from typing import Any
 import fire
 import torch
@@ -37,14 +37,14 @@ CONFIG_NAME = "config.json"
 def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
-    qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    qwen_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
             with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
                 for key in f.keys():
                     qwen_state_dict[key] = f.get_tensor(key)
-    llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
     torch_dtype = None
     for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
         if torch_dtype is None:
@@ -112,9 +112,9 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 def save_config(input_dir: str, output_dir: str, torch_dtype: str):
     with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
-        qwen_config_dict: Dict[str, Any] = json.load(f)
-    llama2_config_dict: Dict[str, Any] = OrderedDict()
+        qwen_config_dict: dict[str, Any] = json.load(f)
+    llama2_config_dict: dict[str, Any] = OrderedDict()
     llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
     llama2_config_dict["hidden_act"] = "silu"
     llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
@@ -147,8 +147,8 @@ def llamafy_qwen(
     shard_size: str = "2GB",
     save_safetensors: bool = False,
 ):
-    r"""
-    Converts the Qwen models in the same format as LLaMA2.
+    r"""Convert the Qwen models in the same format as LLaMA2.
     Usage: python llamafy_qwen.py --input_dir input --output_dir output
     Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
     """


@@ -18,7 +18,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 import fire
 import torch
@@ -44,11 +44,11 @@ def block_expansion(
     shard_size: str = "5GB",
     save_safetensors: bool = True,
 ):
-    r"""
-    Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
+    r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.
     Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
     """
-    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+    config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
     num_layers = getattr(config, "num_hidden_layers")
     if num_layers % num_expand != 0:
         raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
@@ -70,7 +70,7 @@ def block_expansion(
     split = num_layers // num_expand
     layer_cnt = 0
     state_dict = model.state_dict()
-    output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
+    output_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for i in range(num_layers):
         for key, value in state_dict.items():
             if f".{i:d}." in key:


@@ -38,8 +38,8 @@ def quantize_loftq(
     lora_target: tuple = ("q_proj", "v_proj"),
     save_safetensors: bool = True,
 ):
-    r"""
-    Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
+    r"""Initialize LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
     Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
     """
     if isinstance(lora_target, str):
@@ -72,7 +72,7 @@ def quantize_loftq(
     print(f"Adapter weights saved in {loftq_dir}")
     # Save base model
-    base_model: "PreTrainedModel" = peft_model.unload()
+    base_model: PreTrainedModel = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
     print(f"Model weights saved in {output_dir}")


@@ -37,8 +37,8 @@ def quantize_pissa(
     lora_target: tuple = ("q_proj", "v_proj"),
     save_safetensors: bool = True,
 ):
-    r"""
-    Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
+    r"""Initialize LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).
     Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
     """
     if isinstance(lora_target, str):
@@ -67,7 +67,7 @@ def quantize_pissa(
     print(f"Adapter weights saved in {pissa_dir}")
     # Save base model
-    base_model: "PreTrainedModel" = peft_model.unload()
+    base_model: PreTrainedModel = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
     print(f"Model weights saved in {output_dir}")


@@ -29,8 +29,8 @@ def calculate_flops(
     seq_length: int = 512,
     flash_attn: str = "auto",
 ):
-    r"""
-    Calculates the flops of pre-trained models.
+    r"""Calculate the flops of pre-trained models.
     Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
     """
     with get_accelerator().device(0):


@@ -45,8 +45,8 @@ def calculate_lr(
     is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate,
     packing: bool = False,
 ):
-    r"""
-    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+    r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
     Usage:
         python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
     """
@@ -89,9 +89,8 @@ def calculate_lr(
     lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)  # lr ~ sqrt(batch_size)
     lr = lr / 6.0 if is_mistral_or_gemma else lr
     print(
-        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
-            lr, valid_ratio * 100, token_batch_size
-        )
+        f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
+        f"and effective token batch size {token_batch_size:.2f}"
     )
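
The rewritten print statement keeps the original format specifiers, since `{lr:.2e}` inside an f-string renders exactly like `"{:.2e}".format(lr)`. A quick equivalence check with arbitrary values:

lr, valid_ratio, token_batch_size = 3.2e-5, 0.874, 262144.0
old_style = "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
    lr, valid_ratio * 100, token_batch_size
)
new_style = (
    f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
    f"and effective token batch size {token_batch_size:.2f}"
)
assert old_style == new_style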


@@ -34,9 +34,7 @@ def compute_model_flops(
     include_recompute: bool = False,
     include_flashattn: bool = False,
 ) -> int:
-    r"""
-    Calculates the FLOPs of model per forward/backward pass.
-    """
+    r"""Calculate the FLOPs of model per forward/backward pass."""
     config = AutoConfig.from_pretrained(model_name_or_path)
     hidden_size = getattr(config, "hidden_size", None)
     vocab_size = getattr(config, "vocab_size", None)
@@ -86,9 +84,7 @@ def compute_model_flops(
 def compute_device_flops(world_size: int) -> float:
-    r"""
-    Calculates the FLOPs of the device capability per second.
-    """
+    r"""Calculate the FLOPs of the device capability per second."""
     device_name = torch.cuda.get_device_name()
     if "H100" in device_name or "H800" in device_name:
         return 989 * 1e12 * world_size
@@ -114,8 +110,8 @@ def calculate_mfu(
     liger_kernel: bool = False,
     unsloth_gc: bool = False,
 ) -> float:
-    r"""
-    Calculates MFU for given model and hyper-params.
+    r"""Calculate MFU for given model and hyper-params.
     Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
     """
     args = {


@@ -13,8 +13,9 @@
 # limitations under the License.
 import json
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Dict, Literal, Optional, Sequence
+from typing import Any, Literal, Optional
 import fire
 import torch
@@ -30,16 +31,12 @@ from llamafactory.model import load_model, load_tokenizer
 @dataclass
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for pairwise data.
-    """
+    r"""Data collator for pairwise data."""
     train_on_prompt: bool = False
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
-        r"""
-        Pads batched data to the longest sequence in the batch.
-        """
+    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, torch.Tensor]:
+        r"""Pad batched data to the longest sequence in the batch."""
         chosen_features = []
         for feature in features:
             chosen_features.append(
@@ -68,8 +65,8 @@ def calculate_ppl(
     max_samples: Optional[int] = None,
     train_on_prompt: bool = False,
 ):
-    r"""
-    Calculates the ppl on the dataset of the pre-trained models.
+    r"""Calculate the ppl on the dataset of the pre-trained models.
     Usage: export CUDA_VISIBLE_DEVICES=0
         python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
     """
@@ -111,17 +108,17 @@ def calculate_ppl(
     criterion = torch.nn.CrossEntropyLoss(reduction="none")
     total_ppl = 0
     perplexities = []
-    batch: Dict[str, "torch.Tensor"]
+    batch: dict[str, torch.Tensor]
     with torch.no_grad():
         for batch in tqdm(dataloader, desc="Computing perplexities"):
             batch = batch.to(model.device)
             outputs = model(**batch)
-            shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
-            shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
+            shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
+            shift_labels: torch.Tensor = batch["labels"][..., 1:]
             loss_mask = shift_labels != IGNORE_INDEX
             flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
             flatten_labels = shift_labels.contiguous().view(-1)
-            token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
+            token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
             token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
             sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
             total_ppl += sentence_logps.exp().sum().item()
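
As a reading aid for the loop above: per-sequence perplexity is the exponential of the mean token-level negative log-likelihood over the non-ignored positions, which is what the `loss_mask` arithmetic computes. A standalone sketch with made-up numbers:

import torch

token_nll = torch.tensor([[2.1, 1.3, 0.7, 0.0]])  # per-token cross-entropy for one sequence
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # zero where labels equal IGNORE_INDEX
mean_nll = (token_nll * loss_mask).sum(-1) / loss_mask.sum(-1)
perplexity = mean_nll.exp()  # exp of the mean NLL per sequence
print(perplexity)  # roughly tensor([3.92])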


@@ -29,8 +29,8 @@ def length_cdf(
     template: str = "default",
     interval: int = 1000,
 ):
-    r"""
-    Calculates the distribution of the input lengths in the dataset.
+    r"""Calculate the distribution of the input lengths in the dataset.
     Usage: export CUDA_VISIBLE_DEVICES=0
         python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
     """


@@ -52,8 +52,8 @@ def vllm_infer(
     image_max_pixels: int = 768 * 768,
     image_min_pixels: int = 32 * 32,
 ):
-    r"""
-    Performs batch generation using vLLM engine, which supports tensor parallelism.
+    r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
     check_version("vllm>=0.4.3,<=0.7.3")


@@ -14,7 +14,6 @@
 import os
 import re
-from typing import List
 from setuptools import find_packages, setup
@@ -27,14 +26,14 @@ def get_version() -> str:
     return version
-def get_requires() -> List[str]:
+def get_requires() -> list[str]:
     with open("requirements.txt", encoding="utf-8") as f:
         file_content = f.read()
         lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
         return lines
-def get_console_scripts() -> List[str]:
+def get_console_scripts() -> list[str]:
     console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
     if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
         console_scripts.append("lmf = llamafactory.cli:main")


@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-r"""
-Efficient fine-tuning of large language models.
+r"""Efficient fine-tuning of large language models.
 Level:
     api, webui > chat, eval, train > data, model > hparams > extras


@@ -16,9 +16,7 @@ import asyncio
 import os
 from contextlib import asynccontextmanager
 from functools import partial
-from typing import Optional
-from typing_extensions import Annotated
+from typing import Annotated, Optional
 from ..chat import ChatModel
 from ..extras.constants import EngineName


@@ -18,7 +18,8 @@ import json
 import os
 import re
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, Optional
 from ..data import Role as DataRole
 from ..extras import logging
@@ -71,7 +72,7 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
+) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]:
     if is_env_enabled("API_VERBOSE", "1"):
         logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")


@@ -13,14 +13,14 @@
 # limitations under the License.
 import json
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any
 if TYPE_CHECKING:
     from pydantic import BaseModel
-def dictify(data: "BaseModel") -> Dict[str, Any]:
+def dictify(data: "BaseModel") -> dict[str, Any]:
     try:  # pydantic v2
         return data.model_dump(exclude_unset=True)
     except AttributeError:  # pydantic v1


@@ -14,7 +14,7 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -45,7 +45,7 @@ class ModelCard(BaseModel):
 class ModelList(BaseModel):
     object: Literal["list"] = "list"
-    data: List[ModelCard] = []
+    data: list[ModelCard] = []
 class Function(BaseModel):
@@ -56,7 +56,7 @@ class Function(BaseModel):
 class FunctionDefinition(BaseModel):
     name: str
     description: str
-    parameters: Dict[str, Any]
+    parameters: dict[str, Any]
 class FunctionAvailable(BaseModel):
@@ -82,26 +82,26 @@ class MultimodalInputItem(BaseModel):
 class ChatMessage(BaseModel):
     role: Role
-    content: Optional[Union[str, List[MultimodalInputItem]]] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    content: Optional[Union[str, list[MultimodalInputItem]]] = None
+    tool_calls: Optional[list[FunctionCall]] = None
 class ChatCompletionMessage(BaseModel):
     role: Optional[Role] = None
     content: Optional[str] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    tool_calls: Optional[list[FunctionCall]] = None
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[ChatMessage]
-    tools: Optional[List[FunctionAvailable]] = None
+    messages: list[ChatMessage]
+    tools: Optional[list[FunctionAvailable]] = None
     do_sample: Optional[bool] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     n: int = 1
     max_tokens: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    stop: Optional[Union[str, list[str]]] = None
     stream: bool = False
@@ -128,7 +128,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"] = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseChoice]
+    choices: list[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage
@@ -137,12 +137,12 @@ class ChatCompletionStreamResponse(BaseModel):
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionStreamResponseChoice]
+    choices: list[ChatCompletionStreamResponseChoice]
 class ScoreEvaluationRequest(BaseModel):
     model: str
-    messages: List[str]
+    messages: list[str]
     max_length: Optional[int] = None
@@ -150,4 +150,4 @@ class ScoreEvaluationResponse(BaseModel):
     id: str
     object: Literal["score.evaluation"] = "score.evaluation"
     model: str
-    scores: List[float]
+    scores: list[float]
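
These pydantic models keep working unchanged after the annotation swap, since pydantic accepts built-in generics such as `list[ChatMessage]` on Python 3.9+. A hedged construction example (the import path and field values are assumptions for illustration):

from llamafactory.api.protocol import ChatCompletionRequest, ChatMessage, Role

request = ChatCompletionRequest(
    model="demo-model",
    messages=[ChatMessage(role=Role.USER, content="What is the capital of France?")],
    stream=False,
)
print(request.model_dump(exclude_unset=True))  # pydantic v2; the repo's dictify() falls back to .dict() on v1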


@@ -13,8 +13,9 @@
 # limitations under the License.
 from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator, Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 if TYPE_CHECKING:
@@ -36,8 +37,7 @@ class Response:
 class BaseEngine(ABC):
-    r"""
-    Base class for inference engine of chat models.
+    r"""Base class for inference engine of chat models.
     Must implements async methods: chat(), stream_chat() and get_scores().
     """
@@ -47,7 +47,7 @@ class BaseEngine(ABC):
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
     template: "Template"
-    generating_args: Dict[str, Any]
+    generating_args: dict[str, Any]
     @abstractmethod
     def __init__(
@@ -57,31 +57,27 @@ class BaseEngine(ABC):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
-        r"""
-        Initializes an inference engine.
-        """
+        r"""Initialize an inference engine."""
         ...
     @abstractmethod
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
         ...
     @abstractmethod
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -89,18 +85,14 @@ class BaseEngine(ABC):
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        r"""
-        Gets the response token-by-token of the chat model.
-        """
+        r"""Get the response token-by-token of the chat model."""
         ...
     @abstractmethod
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
         ...


@@ -17,8 +17,9 @@
 import asyncio
 import os
+from collections.abc import AsyncGenerator, Generator, Sequence
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Optional
 from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
@@ -38,20 +39,19 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
 class ChatModel:
-    r"""
-    General class for chat models. Backed by huggingface or vllm engines.
+    r"""General class for chat models. Backed by huggingface or vllm engines.
     Supports both sync and async methods.
     Sync methods: chat(), stream_chat() and get_scores().
     Async methods: achat(), astream_chat() and aget_scores().
     """
-    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
         if model_args.infer_backend == EngineName.HF:
-            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+            self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == EngineName.VLLM:
-            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
@@ -61,17 +61,15 @@ class ChatModel:
     def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
         task = asyncio.run_coroutine_threadsafe(
             self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
         )
@@ -79,22 +77,20 @@ class ChatModel:
     async def achat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Asynchronously gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Asynchronously get a list of responses of the chat model."""
         return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
     def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -102,9 +98,7 @@ class ChatModel:
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        r"""
-        Gets the response token-by-token of the chat model.
-        """
+        r"""Get the response token-by-token of the chat model."""
         generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
         while True:
             try:
@@ -115,7 +109,7 @@ class ChatModel:
     async def astream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -123,9 +117,7 @@ class ChatModel:
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        r"""
-        Asynchronously gets the response token-by-token of the chat model.
-        """
+        r"""Asynchronously get the response token-by-token of the chat model."""
         async for new_token in self.engine.stream_chat(
             messages, system, tools, images, videos, audios, **input_kwargs
         ):
@@ -133,23 +125,19 @@ class ChatModel:
     def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
         task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
         return task.result()
     async def aget_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Asynchronously gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Asynchronously get a list of scores of the reward model."""
         return await self.engine.get_scores(batch_input, **input_kwargs)
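
As the updated docstrings state, `ChatModel` exposes matching sync and async entry points over whichever engine is selected. A hedged usage sketch (the argument values are placeholders, and the attribute names on `Response` are assumed rather than quoted from the diff):

from llamafactory.chat import ChatModel

chat_model = ChatModel({"model_name_or_path": "path_to_model", "template": "default"})
messages = [{"role": "user", "content": "Hello!"}]

# Sync call: returns a list of Response objects.
responses = chat_model.chat(messages)
print(responses[0].response_text)

# Streaming call: yields the reply token by token.
for new_token in chat_model.stream_chat(messages):
    print(new_token, end="", flush=True)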


@@ -15,8 +15,9 @@
 import asyncio
 import concurrent.futures
 import os
+from collections.abc import AsyncGenerator, Sequence
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
@@ -76,15 +77,15 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> Tuple[Dict[str, Any], int]:
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> tuple[dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
         if images is not None:
             mm_input_dict.update({"images": images, "imglens": [len(images)]})
@@ -130,7 +131,7 @@ class HuggingfaceEngine(BaseEngine):
         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
         if stop is not None:
             logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
@@ -217,15 +218,15 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> List["Response"]:
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> list["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
             model,
             tokenizer,
@@ -272,14 +273,14 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
+        input_kwargs: Optional[dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
             model,
@@ -317,12 +318,12 @@ class HuggingfaceEngine(BaseEngine):
     def _get_scores(
         model: "PreTrainedModelWrapper",
         tokenizer: "PreTrainedTokenizer",
-        batch_input: List[str],
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> List[float]:
+        batch_input: list[str],
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> list[float]:
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         device = getattr(model.pretrained_model, "device", "cuda")
-        inputs: Dict[str, "torch.Tensor"] = tokenizer(
+        inputs: dict[str, torch.Tensor] = tokenizer(
            batch_input,
             padding=True,
             truncation=True,
@@ -330,21 +331,21 @@ class HuggingfaceEngine(BaseEngine):
             return_tensors="pt",
             add_special_tokens=False,
         ).to(device)
-        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
         scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return scores
     @override
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
+    ) -> list["Response"]:
         if not self.can_generate:
             raise ValueError("The current model does not support `chat`.")
@@ -370,7 +371,7 @@ class HuggingfaceEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -408,9 +409,9 @@ class HuggingfaceEngine(BaseEngine):
     @override
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
+    ) -> list[float]:
         if self.can_generate:
             raise ValueError("Cannot get scores using an auto-regressive model.")


@@ -13,7 +13,8 @@
 # limitations under the License.
 import uuid
-from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, Optional, Union
 from typing_extensions import override
@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
         self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
-            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
             quant_method = quantization_config.get("quant_method", "")
             if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                 model_args.infer_dtype = "float16"
@@ -101,7 +102,7 @@ class VllmEngine(BaseEngine):
     async def _generate(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -143,7 +144,7 @@ class VllmEngine(BaseEngine):
         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
         if length_penalty is not None:
             logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
@@ -201,14 +202,14 @@ class VllmEngine(BaseEngine):
     @override
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
+    ) -> list["Response"]:
         final_output = None
         generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
         async for request_output in generator:
@@ -230,7 +231,7 @@ class VllmEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -248,7 +249,7 @@ class VllmEngine(BaseEngine):
     @override
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
+    ) -> list[float]:
         raise NotImplementedError("vLLM engine does not support get_scores.")


@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 __all__ = [
+    "TEMPLATES",
     "KTODataCollatorWithPadding",
     "MultiModalDataCollatorForSeq2Seq",
     "PairwiseDataCollatorWithPadding",
-    "SFTDataCollatorWith4DAttentionMask",
     "Role",
-    "split_dataset",
-    "get_dataset",
-    "TEMPLATES",
+    "SFTDataCollatorWith4DAttentionMask",
     "Template",
+    "get_dataset",
     "get_template_and_fix_tokenizer",
+    "split_dataset",
 ]

View File

@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence from typing import TYPE_CHECKING, Any, Literal, Optional
import numpy as np import numpy as np
import torch import torch
@ -38,9 +39,10 @@ if TYPE_CHECKING:
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r""" r"""Expand 2d attention mask to 4d attention mask.
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking. Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.
e.g. e.g.
```python ```python
@ -78,8 +80,7 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass @dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r""" r"""Data collator that supports VLMs.
Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios. Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
""" """
@ -91,7 +92,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.template is None: if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.") raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], [] batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], [] batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
for feature in features: for feature in features:
@ -166,7 +167,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for i, feature in enumerate(features): for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i] feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features) features: dict[str, torch.Tensor] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
rope_index_kwargs = { rope_index_kwargs = {
@ -198,15 +199,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@dataclass @dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r""" r"""Data collator for 4d attention mask."""
Data collator for 4d attention mask.
"""
block_diag_attn: bool = False block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32 compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features) features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2": if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@ -220,13 +219,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
@dataclass @dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r""" r"""Data collator for pairwise data."""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
r""" r"""Pad batched data to the longest sequence in the batch.
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples. the last n examples represent rejected examples.
@ -249,11 +245,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
@dataclass @dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r""" r"""Data collator for KTO data."""
Data collator for KTO data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
target_features = [] target_features = []
kl_features = [] kl_features = []
kto_tags = [] kto_tags = []
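
The `prepare_4d_attention_mask` docstring above describes expanding a packed 2d mask of sequence indices into a 4d block-diagonal, lower-triangular mask. A minimal sketch of that idea follows; it is not the project's implementation, and the tensor names and fill values are assumptions:

```python
import torch

def expand_packed_mask(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Sketch: (batch, seq_len) sequence indices -> (batch, 1, seq_len, seq_len) additive mask."""
    bsz, seq_len = mask_with_indices.size()
    min_value = torch.finfo(dtype).min
    # expanded[b, 0, q, k] holds the packed-sequence index of key position k (0 means padding)
    expanded = mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    # a token may attend only to non-padding tokens of its own packed sequence (block-diagonal) ...
    same_sequence = expanded.eq(expanded.transpose(-1, -2)) & expanded.ne(0)
    # ... and only to non-future positions (lower triangular, no peeking ahead)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=mask_with_indices.device))
    allowed = same_sequence & causal
    # allowed positions contribute 0, everything else the most negative representable value
    return torch.where(allowed, torch.zeros((), dtype=dtype), torch.full((), min_value, dtype=dtype))
```

With two sequences packed as indices `[1, 1, 2, 2, 0]`, tokens of the second sequence cannot attend to the first, and the padding position attends to nothing.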

View File

@ -14,8 +14,9 @@
import os import os
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union from typing import TYPE_CHECKING, Any, Optional, Union
from ..extras import logging from ..extras import logging
from .data_utils import Role from .data_utils import Role
@ -36,10 +37,8 @@ class DatasetConverter:
dataset_attr: "DatasetAttr" dataset_attr: "DatasetAttr"
data_args: "DataArguments" data_args: "DataArguments"
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]: def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]:
r""" r"""Optionally concatenate media path to media dir when loading from local disk."""
Optionally concatenates media path to media dir when loading from local disk.
"""
if not isinstance(medias, list): if not isinstance(medias, list):
medias = [medias] if medias is not None else [] medias = [medias] if medias is not None else []
elif len(medias) == 0: elif len(medias) == 0:
@ -57,16 +56,14 @@ class DatasetConverter:
return medias return medias
@abstractmethod @abstractmethod
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
r""" r"""Convert a single example in the dataset to the standard format."""
Converts a single example in the dataset to the standard format.
"""
... ...
@dataclass @dataclass
class AlpacaDatasetConverter(DatasetConverter): class AlpacaDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
prompt = [] prompt = []
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list): if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
for old_prompt, old_response in example[self.dataset_attr.history]: for old_prompt, old_response in example[self.dataset_attr.history]:
@ -116,7 +113,7 @@ class AlpacaDatasetConverter(DatasetConverter):
@dataclass @dataclass
class SharegptDatasetConverter(DatasetConverter): class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
tag_mapping = { tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value, self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value, self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
@ -216,10 +213,8 @@ DATASET_CONVERTERS = {
} }
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None: def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
r""" r"""Register a new dataset converter."""
Register a new dataset converter.
"""
if name in DATASET_CONVERTERS: if name in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} already exists.") raise ValueError(f"Dataset converter {name} already exists.")
@ -227,9 +222,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter": def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
r""" r"""Get a dataset converter."""
Gets a dataset converter.
"""
if name not in DATASET_CONVERTERS: if name not in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} not found.") raise ValueError(f"Dataset converter {name} not found.")
@ -242,17 +235,17 @@ def align_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""Align the dataset to a specific format.
Aligned dataset: Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1) _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..." _system: "..."
_tools: "...", _tools: "..."
_images: [], _images: []
_videos: [], _videos: []
_audios: [], _audios: []
""" """
column_names = list(next(iter(dataset)).keys()) column_names = list(next(iter(dataset)).keys())
kwargs = {} kwargs = {}
if not data_args.streaming: if not data_args.streaming:
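
For illustration, a single converted example in the aligned format listed in the `align_dataset` docstring would roughly look like the dict below (the values are made up):

```python
# A hypothetical example in the aligned format described above.
aligned_example = {
    "_prompt": [{"role": "user", "content": "What is the capital of France?"}],
    "_response": [{"role": "assistant", "content": "Paris."}],  # N > 1 for ranking datasets
    "_system": "You are a helpful assistant.",
    "_tools": "",
    "_images": [],
    "_videos": [],
    "_audios": [],
}
```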

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Sequence
from enum import Enum, unique from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union from typing import TYPE_CHECKING, Optional, TypedDict, Union
from datasets import DatasetDict, concatenate_datasets, interleave_datasets from datasets import DatasetDict, concatenate_datasets, interleave_datasets
@ -29,7 +30,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] SLOTS = Sequence[Union[str, set[str], dict[str, str]]]
@unique @unique
@ -43,15 +44,13 @@ class Role(str, Enum):
class DatasetModule(TypedDict): class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]] train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]] eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
def merge_dataset( def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""Merge multiple datasets to a unified dataset."""
Merges multiple datasets to a unified dataset.
"""
if len(all_datasets) == 1: if len(all_datasets) == 1:
return all_datasets[0] return all_datasets[0]
@ -78,14 +77,13 @@ def merge_dataset(
def split_dataset( def split_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]], dataset: Optional[Union["Dataset", "IterableDataset"]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]], eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments", data_args: "DataArguments",
seed: int, seed: int,
) -> "DatasetDict": ) -> "DatasetDict":
r""" r"""Split the dataset and returns a dataset dict containing train set and validation set.
Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset. Support both map dataset and iterable dataset.
""" """
if eval_dataset is not None and data_args.val_size > 1e-6: if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
@ -120,10 +118,8 @@ def split_dataset(
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule": def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
r""" r"""Convert dataset or dataset dict to dataset module."""
Converts dataset or dataset dict to dataset module. dataset_module: DatasetModule = {}
"""
dataset_module: "DatasetModule" = {}
if isinstance(dataset, DatasetDict): # dataset dict if isinstance(dataset, DatasetDict): # dataset dict
if "train" in dataset: if "train" in dataset:
dataset_module["train_dataset"] = dataset["train"] dataset_module["train_dataset"] = dataset["train"]

View File

@ -16,7 +16,7 @@ import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional, Union from typing import Optional, Union
from typing_extensions import override from typing_extensions import override
@ -31,14 +31,11 @@ class Formatter(ABC):
@abstractmethod @abstractmethod
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
r""" r"""Forms a list of slots according to the inputs to encode."""
Forms a list of slots according to the inputs to encode.
"""
... ...
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
r""" r"""Extract a list of tuples from the response message if using tools.
Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments. Each tuple consists of function name and function arguments.
""" """
@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
if thought: if thought:
content = content.replace(thought.group(0), "") content = content.replace(thought.group(0), "")
functions: List["FunctionCall"] = [] functions: list[FunctionCall] = []
try: try:
tool_calls = json.loads(content) tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call if not isinstance(tool_calls, list): # parallel function call
@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override @override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
return self.tool_utils.tool_extractor(content) return self.tool_utils.tool_extractor(content)
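
`Formatter.extract` above returns either plain text or a list of `FunctionCall`s parsed from a tool-calling response. A minimal sketch of that contract, consistent with the `json.loads` and parallel-call handling shown in the `FunctionFormatter` hunk (the `FunctionCall` definition here is a stand-in, not the project's type):

```python
import json
from typing import NamedTuple, Union


class FunctionCall(NamedTuple):  # stand-in for the project's FunctionCall type
    name: str
    arguments: str


def extract_tool_calls(content: str) -> Union[str, list["FunctionCall"]]:
    try:
        tool_calls = json.loads(content)
    except json.JSONDecodeError:
        return content  # not a tool call, return the message as plain text

    if not isinstance(tool_calls, list):  # normalize a single call to a parallel-call list
        tool_calls = [tool_calls]

    # arguments are assumed to be JSON objects and are re-serialized as strings
    return [
        FunctionCall(call["name"], json.dumps(call["arguments"], ensure_ascii=False))
        for call in tool_calls
    ]
```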

View File

@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np import numpy as np
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
@ -54,9 +55,7 @@ def _load_single_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""Load a single dataset and aligns it to the standard format."""
Loads a single dataset and aligns it to the standard format.
"""
logger.info_rank0(f"Loading dataset {dataset_attr}...") logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
@ -164,10 +163,8 @@ def _get_merged_dataset(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
merge: bool = True, merge: bool = True,
) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]: ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
r""" r"""Return the merged datasets in the standard format."""
Returns the merged datasets in the standard format.
"""
if dataset_names is None: if dataset_names is None:
return None return None
@ -192,9 +189,7 @@ def _get_dataset_processor(
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
do_generate: bool = False, do_generate: bool = False,
) -> "DatasetProcessor": ) -> "DatasetProcessor":
r""" r"""Return the corresponding dataset processor."""
Returns the corresponding dataset processor.
"""
if stage == "pt": if stage == "pt":
dataset_processor_class = PretrainDatasetProcessor dataset_processor_class = PretrainDatasetProcessor
elif stage == "sft" and not do_generate: elif stage == "sft" and not do_generate:
@ -236,9 +231,7 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False, is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Optional[Union["Dataset", "IterableDataset"]]:
r""" r"""Preprocesses the dataset, including format checking and tokenization."""
Preprocesses the dataset, including format checking and tokenization.
"""
if dataset is None: if dataset is None:
return None return None
@ -284,9 +277,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule": ) -> "DatasetModule":
r""" r"""Get the train dataset and optionally gets the evaluation dataset."""
Gets the train dataset and optionally gets the evaluation dataset.
"""
# Load tokenized dataset if path exists # Load tokenized dataset if path exists
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path): if has_tokenized_data(data_args.tokenized_path):

View File

@ -1,10 +1,11 @@
import inspect import inspect
import math import math
import re import re
from collections.abc import Sequence
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union from typing import TYPE_CHECKING, Optional, TypedDict, Union
import numpy as np import numpy as np
import torch import torch
@ -58,12 +59,12 @@ if TYPE_CHECKING:
def _get_paligemma_token_type_ids( def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin" imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]: ) -> list[list[int]]:
r""" r"""Get paligemma token type ids for computing loss.
Gets paligemma token type ids for computing loss.
Returns: Returns:
batch_token_type_ids: shape (batch_size, sequence_length) batch_token_type_ids: shape (batch_size, sequence_length)
""" """
batch_token_type_ids = [] batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens): for imglen, seqlen in zip(imglens, seqlens):
@ -87,11 +88,9 @@ class MMPluginMixin:
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
) -> None: ) -> None:
r""" r"""Validate if this model accepts the input modalities."""
Validates if this model accepts the input modalities. image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
""" feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
if len(images) != 0 and self.image_token is None: if len(images) != 0 and self.image_token is None:
raise ValueError( raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used." "This model does not support image input. Please check whether the correct `template` is used."
@ -119,9 +118,7 @@ class MMPluginMixin:
def _preprocess_image( def _preprocess_image(
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
) -> "ImageObject": ) -> "ImageObject":
r""" r"""Pre-process a single image."""
Pre-processes a single image.
"""
if (image.width * image.height) > image_max_pixels: if (image.width * image.height) > image_max_pixels:
resize_factor = math.sqrt(image_max_pixels / (image.width * image.height)) resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor) width, height = int(image.width * resize_factor), int(image.height * resize_factor)
@ -139,10 +136,8 @@ class MMPluginMixin:
def _get_video_sample_indices( def _get_video_sample_indices(
self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
) -> List[int]: ) -> list[int]:
r""" r"""Compute video sample indices according to fps."""
Computes video sample indices according to fps.
"""
total_frames = video_stream.frames total_frames = video_stream.frames
if total_frames == 0: # infinite video if total_frames == 0: # infinite video
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32) return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
@ -151,10 +146,8 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames) sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]: def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]:
r""" r"""Regularize images to avoid error. Including reading and pre-processing."""
Regularizes images to avoid error. Including reading and pre-processing.
"""
results = [] results = []
for image in images: for image in images:
if isinstance(image, str): if isinstance(image, str):
@ -174,16 +167,14 @@ class MMPluginMixin:
return results return results
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]: def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
r""" r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
results = [] results = []
for video in videos: for video in videos:
container = av.open(video, "r") container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video") video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs) sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = [] frames: list[ImageObject] = []
container.seek(0) container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)): for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices: if frame_idx in sample_indices:
@ -194,10 +185,8 @@ class MMPluginMixin:
return results return results
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]: def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
r""" r"""Regularizes audios to avoid error. Including reading and resampling."""
Regularizes audios to avoid error. Including reading and resampling.
"""
results = [] results = []
for audio in audios: for audio in audios:
if isinstance(audio, str): if isinstance(audio, str):
@ -216,9 +205,8 @@ class MMPluginMixin:
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: "ProcessorMixin", processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
r""" r"""Process visual inputs.
Processes visual inputs.
Returns: (llava and paligemma) Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W) pixel_values: tensor with shape (B, C, H, W)
@ -229,9 +217,9 @@ class MMPluginMixin:
It holds num_patches == torch.prod(image_grid_thw) It holds num_patches == torch.prod(image_grid_thw)
""" """
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None) image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor) video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None) feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
@ -278,31 +266,27 @@ class MMPluginMixin:
class BasePlugin(MMPluginMixin): class BasePlugin(MMPluginMixin):
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
r""" r"""Pre-processes input messages before tokenization for VLMs."""
Pre-processes input messages before tokenization for VLMs.
"""
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return messages return messages
def process_token_ids( def process_token_ids(
self, self,
input_ids: List[int], input_ids: list[int],
labels: Optional[List[int]], labels: Optional[list[int]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]: ) -> tuple[list[int], Optional[list[int]]]:
r""" r"""Pre-processes token ids after tokenization for VLMs."""
Pre-processes token ids after tokenization for VLMs.
"""
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return input_ids, labels return input_ids, labels
@ -314,20 +298,21 @@ class BasePlugin(MMPluginMixin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
r""" r"""Build batched multimodal inputs for VLMs.
Builds batched multimodal inputs for VLMs.
Arguments: Arguments:
images: a list of image inputs, shape (num_images,) images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,) videos: a list of video inputs, shape (num_videos,)
audios: a list of audio inputs, shape (num_audios,)
imglens: number of images in each sample, shape (batch_size,) imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,) vidlens: number of videos in each sample, shape (batch_size,)
audlens: number of audios in each sample, shape (batch_size,) audlens: number of audios in each sample, shape (batch_size,)
batch_ids: token ids of input samples, shape (batch_size, seq_len) batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos processor: a processor for pre-processing images and videos
""" """
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return {} return {}
@ -338,12 +323,12 @@ class LlavaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1 image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
@ -370,9 +355,9 @@ class LlavaPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor) return self._get_mm_inputs(images, videos, audios, processor)
@ -382,12 +367,12 @@ class LlavaNextPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
@ -426,9 +411,9 @@ class LlavaNextPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor) return self._get_mm_inputs(images, videos, audios, processor)
@ -438,12 +423,12 @@ class LlavaNextVideoPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0 num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
@ -502,9 +487,9 @@ class LlavaNextVideoPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor) return self._get_mm_inputs(images, videos, audios, processor)
@ -514,16 +499,16 @@ class MiniCPMVPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {} mm_inputs = {}
audio_inputs = {} audio_inputs = {}
if len(images) != 0 and len(videos) != 0: if len(images) != 0 and len(videos) != 0:
@ -619,9 +604,9 @@ class MiniCPMVPlugin(BasePlugin):
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: "ProcessorMixin", processor: "ProcessorMixin",
**kwargs, **kwargs,
) -> Dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: BaseImageProcessor = getattr(processor, "image_processor")
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None) feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
@ -691,9 +676,9 @@ class MiniCPMVPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
# image bound # image bound
image_bounds_list = [] image_bounds_list = []
@ -756,12 +741,12 @@ class MllamaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
@ -782,10 +767,9 @@ class MllamaPlugin(BasePlugin):
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: "ProcessorMixin", processor: "ProcessorMixin",
imglens: List[int], imglens: list[int],
) -> Dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
r""" r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns: Returns:
pixel_values: tensor with shape pixel_values: tensor with shape
@ -794,8 +778,9 @@ class MllamaPlugin(BasePlugin):
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1). aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4). aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
""" """
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {} mm_inputs = {}
if len(images) > 0: if len(images) > 0:
images = self._regularize_images( images = self._regularize_images(
@ -821,9 +806,9 @@ class MllamaPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens) mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
if mm_inputs: if mm_inputs:
@ -850,12 +835,12 @@ class PaliGemmaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
@ -875,14 +860,14 @@ class PaliGemmaPlugin(BasePlugin):
@override @override
def process_token_ids( def process_token_ids(
self, self,
input_ids: List[int], input_ids: list[int],
labels: Optional[List[int]], labels: Optional[list[int]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]: ) -> tuple[list[int], Optional[list[int]]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_images = len(images) num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
@ -902,9 +887,9 @@ class PaliGemmaPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
seqlens = [len(input_ids) for input_ids in batch_ids] seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@ -917,12 +902,12 @@ class PixtralPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
patch_size = getattr(processor, "patch_size") patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token") image_token = getattr(processor, "image_token")
@ -968,9 +953,9 @@ class PixtralPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("image_sizes", None) mm_inputs.pop("image_sizes", None)
@ -982,12 +967,12 @@ class Qwen2AudioPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token") bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token") eos_token: str = getattr(processor, "audio_eos_token")
@ -1028,9 +1013,9 @@ class Qwen2AudioPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor) return self._get_mm_inputs(images, videos, audios, processor)
@ -1057,13 +1042,13 @@ class Qwen2VLPlugin(BasePlugin):
@override @override
def _regularize_videos( def _regularize_videos(
self, videos: Sequence["VideoInput"], **kwargs self, videos: Sequence["VideoInput"], **kwargs
) -> Tuple[List[List["ImageObject"]], List[float]]: ) -> tuple[list[list["ImageObject"]], list[float]]:
results, fps_per_video = [], [] results, fps_per_video = [], []
for video in videos: for video in videos:
container = av.open(video, "r") container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video") video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs) sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = [] frames: list[ImageObject] = []
container.seek(0) container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)): for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices: if frame_idx in sample_indices:
@ -1088,8 +1073,8 @@ class Qwen2VLPlugin(BasePlugin):
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: "ProcessorMixin", processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None) image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
@ -1115,16 +1100,16 @@ class Qwen2VLPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0 num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2 merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens: if self.expand_mm_tokens:
@ -1176,13 +1161,13 @@ class Qwen2VLPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", []) fps_per_video = mm_inputs.pop("fps_per_video", [])
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: BaseImageProcessor = getattr(processor, "image_processor")
if "second_per_grid_ts" in processor.model_input_names and fps_per_video: if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video] mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
@ -1194,12 +1179,12 @@ class VideoLlavaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0 num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
@ -1255,9 +1240,9 @@ class VideoLlavaPlugin(BasePlugin):
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int], audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor) return self._get_mm_inputs(images, videos, audios, processor)
@ -1277,10 +1262,8 @@ PLUGINS = {
} }
def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None: def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
r""" r"""Register a multimodal plugin."""
Registers a multimodal plugin.
"""
if name in PLUGINS: if name in PLUGINS:
raise ValueError(f"Multimodal plugin {name} already exists.") raise ValueError(f"Multimodal plugin {name} already exists.")
@ -1293,9 +1276,7 @@ def get_mm_plugin(
video_token: Optional[str] = None, video_token: Optional[str] = None,
audio_token: Optional[str] = None, audio_token: Optional[str] = None,
) -> "BasePlugin": ) -> "BasePlugin":
r""" r"""Get plugin for multimodal inputs."""
Gets plugin for multimodal inputs.
"""
if name not in PLUGINS: if name not in PLUGINS:
raise ValueError(f"Multimodal plugin `{name}` not found.") raise ValueError(f"Multimodal plugin `{name}` not found.")

View File

@ -14,8 +14,9 @@
import json import json
import os import os
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence from typing import Any, Literal, Optional
from transformers.utils import cached_file from transformers.utils import cached_file
@ -25,9 +26,7 @@ from ..extras.misc import use_modelscope, use_openmind
@dataclass @dataclass
class DatasetAttr: class DatasetAttr:
r""" r"""Dataset attributes."""
Dataset attributes.
"""
# basic configs # basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
@ -68,10 +67,10 @@ class DatasetAttr:
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default)) setattr(self, key, obj.get(key, default))
def join(self, attr: Dict[str, Any]) -> None: def join(self, attr: dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca") self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False) self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr) self.set_attr("subset", attr)
@ -92,10 +91,8 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"]) self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]:
r""" r"""Get the attributes of the datasets."""
Gets the attributes of the datasets.
"""
if dataset_names is None: if dataset_names is None:
dataset_names = [] dataset_names = []
@ -116,7 +113,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None dataset_info = None
dataset_list: List["DatasetAttr"] = [] dataset_list: list[DatasetAttr] = []
for name in dataset_names: for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope(): if use_modelscope():
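
`DatasetAttr.join` above reads per-dataset options from a mapping via `set_attr`, falling back to defaults. For illustration only, a hypothetical entry and what `set_attr` amounts to (the keys mirror the calls shown in the diff):

```python
# Hypothetical dataset_info entry; keys and values are assumptions for illustration.
attr_entry = {
    "formatting": "sharegpt",              # default would be "alpaca"
    "ranking": False,
    "columns": {"messages": "conversations"},
}

# set_attr(key, obj, default) is essentially a dict lookup with a fallback:
formatting = attr_entry.get("formatting", "alpaca")
ranking = attr_entry.get("ranking", False)
```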

View File

@ -9,9 +9,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [ __all__ = [
"DatasetProcessor", "DatasetProcessor",
"FeedbackDatasetProcessor", "FeedbackDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"PairwiseDatasetProcessor", "PairwiseDatasetProcessor",
"PretrainDatasetProcessor", "PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor", "SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor", "UnsupervisedDatasetProcessor",
] ]

View File

@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@ -30,15 +31,15 @@ logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor): class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: Sequence[dict[str, str]],
response: Sequence[Dict[str, str]], response: Sequence[dict[str, str]],
kl_response: Sequence[Dict[str, str]], kl_response: Sequence[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]: ) -> tuple[list[int], list[int], list[int], list[int], bool]:
if response[0]["content"]: # desired example if response[0]["content"]: # desired example
kto_tag = True kto_tag = True
messages = prompt + [response[0]] messages = prompt + [response[0]]
@ -82,7 +83,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1] kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
@ -121,7 +122,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor): class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: Sequence[dict[str, str]],
response: Sequence[Dict[str, str]], response: Sequence[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]: ) -> tuple[list[int], list[int], list[int], list[int]]:
chosen_messages = self.template.mm_plugin.process_messages( chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor prompt + [response[0]], images, videos, audios, self.processor
) )
@ -68,7 +69,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
@ -99,7 +100,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"])) valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))

View File

@ -17,14 +17,14 @@
from dataclasses import dataclass from dataclasses import dataclass
from itertools import chain from itertools import chain
from typing import Any, Dict, List from typing import Any
from .processor_utils import DatasetProcessor from .processor_utils import DatasetProcessor
@dataclass @dataclass
class PretrainDatasetProcessor(DatasetProcessor): class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
return result return result
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@ -14,8 +14,9 @@
import bisect import bisect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
@ -27,9 +28,7 @@ if TYPE_CHECKING:
@dataclass @dataclass
class DatasetProcessor(ABC): class DatasetProcessor(ABC):
r""" r"""A class for data processors."""
A class for data processors.
"""
template: "Template" template: "Template"
tokenizer: "PreTrainedTokenizer" tokenizer: "PreTrainedTokenizer"
@ -37,32 +36,24 @@ class DatasetProcessor(ABC):
data_args: "DataArguments" data_args: "DataArguments"
@abstractmethod @abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
r""" r"""Build model inputs from the examples."""
Builds model inputs from the examples.
"""
... ...
@abstractmethod @abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
r""" r"""Print a data example to stdout."""
Print a data example to stdout.
"""
... ...
def search_for_fit(numbers: Sequence[int], capacity: int) -> int: def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r""" r"""Find the index of the largest number that fits into the knapsack with the given capacity."""
Finds the index of the largest number that fits into the knapsack with the given capacity.
"""
index = bisect.bisect(numbers, capacity) index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1) return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
r""" r"""Implement an efficient greedy algorithm with binary search for the knapsack problem."""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers.sort() # sort numbers in ascending order for binary search numbers.sort() # sort numbers in ascending order for binary search
knapsacks = [] knapsacks = []
@ -83,10 +74,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks return knapsacks
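A self-contained sketch of the binary-search greedy packing these two helpers describe; the bodies below are written here for illustration and are not copied verbatim from the repository.

import bisect

def search_for_fit(numbers: list[int], capacity: int) -> int:
    """Index of the largest number that still fits, or -1 if none fits (assumes ascending order)."""
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else index - 1

def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
    """Greedily pack numbers into knapsacks of the given capacity."""
    numbers.sort()  # ascending order enables binary search
    knapsacks = []
    while numbers:
        current_knapsack = []
        remaining_capacity = capacity
        while True:
            index = search_for_fit(numbers, remaining_capacity)
            if index == -1:
                break  # nothing else fits in this knapsack
            remaining_capacity -= numbers[index]
            current_knapsack.append(numbers.pop(index))
        knapsacks.append(current_knapsack)
    return knapsacks

print(greedy_knapsack([5, 3, 8, 2, 7], capacity=10))  # [[8, 2], [7, 3], [5]]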
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
r""" r"""Compute the real sequence length after truncation by the cutoff_len."""
Computes the real sequence length after truncation by the cutoff_len.
"""
if target_len * 2 < cutoff_len: # truncate source if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target elif source_len * 2 < cutoff_len: # truncate target

View File

@ -13,8 +13,9 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@ -32,14 +33,14 @@ logger = logging.get_logger(__name__)
class SupervisedDatasetProcessor(DatasetProcessor): class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: Sequence[dict[str, str]],
response: Sequence[Dict[str, str]], response: Sequence[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor) messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids( input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor [], [], images, videos, audios, self.tokenizer, self.processor
@ -85,7 +86,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return input_ids, labels return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>` # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair. # for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
@ -114,7 +115,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
@ -124,7 +125,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
@dataclass @dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# TODO: use `position_ids` to achieve packing # TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>` # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>` # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
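Concretely, the `<ignore>` positions in these label layouts are IGNORE_INDEX (-100 by common Hugging Face convention) placed over the prompt tokens, so that only response tokens contribute to the loss. A toy single-turn example with made-up token ids:

IGNORE_INDEX = -100
source_ids = [11, 12, 13]      # prompt tokens (hypothetical ids)
response_ids = [21, 22, 2]     # response tokens followed by eos
input_ids = source_ids + response_ids
labels = [IGNORE_INDEX] * len(source_ids) + response_ids
print(input_ids)  # [11, 12, 13, 21, 22, 2]
print(labels)     # [-100, -100, -100, 21, 22, 2]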

View File

@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ..data_utils import Role from ..data_utils import Role
@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor): class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: Sequence[dict[str, str]],
response: Sequence[Dict[str, str]], response: Sequence[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"], audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
if len(response) == 1: if len(response) == 1:
messages = prompt + response messages = prompt + response
else: else:
@ -56,7 +57,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
labels = labels[:target_len] labels = labels[:target_len]
return input_ids, labels return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>` # build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
@ -84,7 +85,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"])) print("label_ids:\n{}".format(example["labels"]))

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union from typing import TYPE_CHECKING, Optional, Union
from typing_extensions import override from typing_extensions import override
@ -46,8 +47,8 @@ class Template:
format_tools: "Formatter" format_tools: "Formatter"
format_prefix: "Formatter" format_prefix: "Formatter"
default_system: str default_system: str
stop_words: List[str] stop_words: list[str]
thought_words: Tuple[str, str] thought_words: tuple[str, str]
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
replace_jinja_template: bool replace_jinja_template: bool
@ -56,13 +57,11 @@ class Template:
def encode_oneturn( def encode_oneturn(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
r""" r"""Return a single pair of token ids representing prompt and response respectively."""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_messages = self._encode(tokenizer, messages, system, tools) encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = [] prompt_ids = []
for encoded_ids in encoded_messages[:-1]: for encoded_ids in encoded_messages[:-1]:
@ -74,36 +73,28 @@ class Template:
def encode_multiturn( def encode_multiturn(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
) -> List[Tuple[List[int], List[int]]]: ) -> list[tuple[list[int], list[int]]]:
r""" r"""Return multiple pairs of token ids representing prompts and responses respectively."""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
encoded_messages = self._encode(tokenizer, messages, system, tools) encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
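The difference between the two encoders is easiest to see on plain lists (hypothetical token ids): the oneturn style flattens every turn except the final reply into a single prompt, while the multiturn style keeps one (prompt, response) pair per user/assistant turn.

encoded_messages = [[1, 2], [3], [4, 5], [6]]  # user, assistant, user, assistant (made-up ids)

# encode_oneturn style: everything before the final reply becomes the prompt
prompt_ids = [tok for turn in encoded_messages[:-1] for tok in turn]
response_ids = encoded_messages[-1]
print(prompt_ids, response_ids)  # [1, 2, 3, 4, 5] [6]

# encode_multiturn style: one (prompt, response) pair per turn
pairs = [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
print(pairs)  # [([1, 2], [3]), ([4, 5], [6])]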
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]: def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
r""" r"""Extract tool message."""
Extracts tool message.
"""
return self.format_tools.extract(content) return self.format_tools.extract(content)
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]: def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r""" r"""Return stop token ids."""
Returns stop token ids.
"""
stop_token_ids = {tokenizer.eos_token_id} stop_token_ids = {tokenizer.eos_token_id}
for token in self.stop_words: for token in self.stop_words:
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token)) stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
return list(stop_token_ids) return list(stop_token_ids)
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]: def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r""" r"""Convert elements to token ids."""
Converts elements to token ids.
"""
token_ids = [] token_ids = []
for elem in elements: for elem in elements:
if isinstance(elem, str): if isinstance(elem, str):
@ -124,14 +115,14 @@ class Template:
def _encode( def _encode(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
) -> List[List[int]]: ) -> list[list[int]]:
r""" r"""Encode formatted inputs to pairs of token ids.
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp Turn 0: prefix + system + query resp
Turn t: query resp Turn t: query resp.
""" """
system = system or self.default_system system = system or self.default_system
encoded_messages = [] encoded_messages = []
@ -161,9 +152,7 @@ class Template:
@staticmethod @staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None: def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r""" r"""Add or replace eos token to the tokenizer."""
Adds or replaces eos token to the tokenizer.
"""
is_added = tokenizer.eos_token_id is None is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
@ -176,9 +165,7 @@ class Template:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None: def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
r""" r"""Add eos token and pad token to the tokenizer."""
Adds eos token and pad token to the tokenizer.
"""
stop_words = self.stop_words stop_words = self.stop_words
if self.replace_eos: if self.replace_eos:
if not stop_words: if not stop_words:
@ -204,16 +191,12 @@ class Template:
@staticmethod @staticmethod
def _jinja_escape(content: str) -> str: def _jinja_escape(content: str) -> str:
r""" r"""Escape single quotes in content."""
Escape single quotes in content.
"""
return content.replace("'", r"\'") return content.replace("'", r"\'")
@staticmethod @staticmethod
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
r""" r"""Convert slots to jinja template."""
Converts slots to jinja template.
"""
slot_items = [] slot_items = []
for slot in slots: for slot in slots:
if isinstance(slot, str): if isinstance(slot, str):
@ -235,9 +218,7 @@ class Template:
return " + ".join(slot_items) return " + ".join(slot_items)
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str: def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r""" r"""Return the jinja template."""
Returns the jinja template.
"""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer) prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message") system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
@ -265,9 +246,7 @@ class Template:
return jinja_template return jinja_template
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None: def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
r""" r"""Replace the jinja template in the tokenizer."""
Replaces the jinja template in the tokenizer.
"""
if tokenizer.chat_template is None or self.replace_jinja_template: if tokenizer.chat_template is None or self.replace_jinja_template:
try: try:
tokenizer.chat_template = self._get_jinja_template(tokenizer) tokenizer.chat_template = self._get_jinja_template(tokenizer)
@ -278,9 +257,7 @@ class Template:
def _convert_slots_to_ollama( def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content" slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str: ) -> str:
r""" r"""Convert slots to ollama template."""
Converts slots to ollama template.
"""
slot_items = [] slot_items = []
for slot in slots: for slot in slots:
if isinstance(slot, str): if isinstance(slot, str):
@ -302,9 +279,7 @@ class Template:
return "".join(slot_items) return "".join(slot_items)
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str: def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r""" r"""Return the ollama template."""
Returns the ollama template.
"""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer) prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System") system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content") user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
@ -316,8 +291,7 @@ class Template:
) )
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str: def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r""" r"""Return the ollama modelfile.
Returns the ollama modelfile.
TODO: support function calling. TODO: support function calling.
""" """
@ -340,10 +314,10 @@ class Llama2Template(Template):
def _encode( def _encode(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]], messages: Sequence[dict[str, str]],
system: str, system: str,
tools: str, tools: str,
) -> List[List[int]]: ) -> list[list[int]]:
system = system or self.default_system system = system or self.default_system
encoded_messages = [] encoded_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
@ -402,7 +376,7 @@ class Llama2Template(Template):
return jinja_template return jinja_template
TEMPLATES: Dict[str, "Template"] = {} TEMPLATES: dict[str, "Template"] = {}
def register_template( def register_template(
@ -416,15 +390,14 @@ def register_template(
format_prefix: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None,
default_system: str = "", default_system: str = "",
stop_words: Optional[Sequence[str]] = None, stop_words: Optional[Sequence[str]] = None,
thought_words: Optional[Tuple[str, str]] = None, thought_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
replace_jinja_template: bool = False, replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: Type["Template"] = Template, template_class: type["Template"] = Template,
) -> None: ) -> None:
r""" r"""Register a chat template.
Registers a chat template.
To add the following chat template: To add the following chat template:
``` ```
@ -472,9 +445,7 @@ def register_template(
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
r""" r"""Extract a chat template from the tokenizer."""
Extracts a chat template from the tokenizer.
"""
def find_diff(short_str: str, long_str: str) -> str: def find_diff(short_str: str, long_str: str) -> str:
i, j = 0, 0 i, j = 0, 0
@ -532,9 +503,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template": def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r""" r"""Get the chat template and fix the tokenizer."""
Gets chat template and fixes the tokenizer.
"""
if data_args.template is None: if data_args.template is None:
if isinstance(tokenizer.chat_template, str): if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.") logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
@ -1149,7 +1118,8 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=( default_system=(
"你是一个经过良好训练的AI助手你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n" "你是一个经过良好训练的AI助手你的名字是Marco-o1."
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n" "当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文但是有2个特例一个是对原文中的引用另一个是是数学应该使用markdown格式<Output>内的输出需要遵循用户输入的语言。\n" "<Thought>应该尽可能是英文但是有2个特例一个是对原文中的引用另一个是是数学应该使用markdown格式<Output>内的输出需要遵循用户输入的语言。\n"
), ),

View File

@ -17,7 +17,7 @@ import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Tuple, Union from typing import Any, NamedTuple, Union
from typing_extensions import override from typing_extensions import override
@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (
@dataclass @dataclass
class ToolUtils(ABC): class ToolUtils(ABC):
""" """Base class for tool utilities."""
Base class for tool utilities.
"""
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
r""" r"""Generate the system message describing all the available tools."""
Generates the system message describing all the available tools.
"""
... ...
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
r""" r"""Generate the assistant message including all the tool calls."""
Generates the assistant message including all the tool calls.
"""
... ...
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
r""" r"""Extract all the function calls from the assistant message.
Extracts all the function calls from the assistant message.
It should be an inverse function of `function_formatter`. It should be an inverse function of `function_formatter`.
""" """
@ -92,13 +85,11 @@ class ToolUtils(ABC):
class DefaultToolUtils(ToolUtils): class DefaultToolUtils(ToolUtils):
r""" r"""Default tool using template."""
Default tool using template.
"""
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
tool_names = [] tool_names = []
for tool in tools: for tool in tools:
@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = "" function_text = ""
for name, arguments in functions: for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n" function_text += f"Action: {name}\nAction Input: {arguments}\n"
@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL) regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content) action_match: list[tuple[str, str]] = re.findall(regex, content)
if not action_match: if not action_match:
return content return content
@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils): class GLM4ToolUtils(ToolUtils):
r""" r"""GLM-4 tool using template."""
GLM-4 tool using template.
"""
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1: if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.") raise ValueError("GLM-4 does not support parallel functions.")
@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
if "\n" not in content: if "\n" not in content:
return content return content
@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
class Llama3ToolUtils(ToolUtils): class Llama3ToolUtils(ToolUtils):
r""" r"""Llama 3.x tool using template with `tools_in_user_message=False`.
Llama 3.x tool using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
""" """
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
date = datetime.now().strftime("%d %b %Y") date = datetime.now().strftime("%d %b %Y")
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1: if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.") raise ValueError("Llama-3 does not support parallel functions.")
@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try: try:
tool = json.loads(content.strip()) tool = json.loads(content.strip())
except json.JSONDecodeError: except json.JSONDecodeError:
@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
class MistralToolUtils(ToolUtils): class MistralToolUtils(ToolUtils):
r""" r"""Mistral v0.3 tool using template."""
Mistral v0.3 tool using template.
"""
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = [] wrapped_tools = []
for tool in tools: for tool in tools:
wrapped_tools.append({"type": "function", "function": tool}) wrapped_tools.append({"type": "function", "function": tool})
@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = [] function_texts = []
for name, arguments in functions: for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}') function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try: try:
tools = json.loads(content.strip()) tools = json.loads(content.strip())
except json.JSONDecodeError: except json.JSONDecodeError:
@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
class QwenToolUtils(ToolUtils): class QwenToolUtils(ToolUtils):
r""" r"""Qwen 2.5 tool using template."""
Qwen 2.5 tool using template.
"""
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
wrapped_tool = {"type": "function", "function": tool} wrapped_tool = {"type": "function", "function": tool}
@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = [] function_texts = []
for name, arguments in functions: for name, arguments in functions:
function_texts.append( function_texts.append(
@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL) regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
tool_match: List[str] = re.findall(regex, content) tool_match: list[str] = re.findall(regex, content)
if not tool_match: if not tool_match:
return content return content
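A quick, self-contained check of the `<tool_call>` extraction pattern shown above; the sample content is made up for illustration.

import json
import re

content = (
    '<tool_call>{"name": "get_weather", "arguments": {"city": "Beijing"}}</tool_call>\n'
    '<tool_call>{"name": "get_time", "arguments": {"tz": "UTC"}}</tool_call>'
)
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
for match in re.findall(regex, content):
    tool = json.loads(match.strip())
    print(tool["name"], tool["arguments"])
# get_weather {'city': 'Beijing'}
# get_time {'tz': 'UTC'}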

View File

@ -39,7 +39,7 @@
import json import json
import os import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Optional
import numpy as np import numpy as np
import torch import torch
@ -59,7 +59,7 @@ if TYPE_CHECKING:
class Evaluator: class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args) self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"] self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
@ -69,7 +69,7 @@ class Evaluator:
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES] self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode() @torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]: def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
logits = self.model(**batch_input).logits logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1) lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0) word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
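The attention-mask trick above selects, for each right-padded sequence, the logits at its last real token. A small standalone sketch with random tensors (shapes are assumptions for illustration):

import torch

logits = torch.randn(2, 5, 32)                      # (batch, seq_len, vocab), dummy values
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])    # right padding
lengths = torch.sum(attention_mask, dim=-1)          # tensor([3, 5])
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
print(word_probs.shape)                              # torch.Size([2, 32])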
@ -88,7 +88,7 @@ class Evaluator:
) )
with open(mapping, encoding="utf-8") as f: with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: dict[str, dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS} category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0) pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
@ -136,7 +136,7 @@ class Evaluator:
pbar.close() pbar.close()
self._save_results(category_corrects, results) self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None: def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
score_info = "\n".join( score_info = "\n".join(
[ [
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}" f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple
from ..data import Role from ..data import Role
from ..extras.constants import CHOICES from ..extras.constants import CHOICES
@ -25,20 +25,19 @@ class EvalTemplate:
choice: str choice: str
answer: str answer: str
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
r""" r"""Parse eval example.
input: a dict with keys {"question", "A", "B", "C", "D", "answer"} input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response) output: a tuple of (prompt, response).
""" """
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"] return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
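Walking the parsing logic above through a toy record; the choice/answer format strings here are assumptions for illustration, not necessarily the shipped templates.

CHOICES = ["A", "B", "C", "D"]
choice_fmt = "\n{choice}. {content}"  # assumed format string
answer_fmt = "\nAnswer:"              # assumed format string

example = {"question": "2 + 2 = ?", "A": "3", "B": "4", "C": "5", "D": "22", "answer": "B"}
candidates = [choice_fmt.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
prompt, response = "".join([example["question"]] + candidates + [answer_fmt]), example["answer"]
print(prompt)    # the question, the four lettered options, then "Answer:"
print(response)  # B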
def format_example( def format_example(
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str self, target_data: dict[str, str], support_set: Sequence[dict[str, str]], subject_name: str
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
r""" r"""Convert dataset examples to messages."""
Converts dataset examples to messages.
"""
messages = [] messages = []
for k in range(len(support_set)): for k in range(len(support_set)):
prompt, response = self._parse_example(support_set[k]) prompt, response = self._parse_example(support_set[k])
@ -52,7 +51,7 @@ class EvalTemplate:
return messages return messages
eval_templates: Dict[str, "EvalTemplate"] = {} eval_templates: dict[str, "EvalTemplate"] = {}
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:

View File

@ -15,7 +15,7 @@
import os import os
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from enum import Enum from enum import Enum
from typing import Dict, Optional from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@ -122,7 +122,7 @@ class RopeScaling(str, Enum):
def register_model_group( def register_model_group(
models: Dict[str, Dict[DownloadSource, str]], models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None, template: Optional[str] = None,
multimodal: bool = False, multimodal: bool = False,
) -> None: ) -> None:

View File

@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO
class LoggerHandler(logging.Handler): class LoggerHandler(logging.Handler):
r""" r"""Redirect the logging output to the logging file for LLaMA Board."""
Redirects the logging output to the logging file for LLaMA Board.
"""
def __init__(self, output_dir: str) -> None: def __init__(self, output_dir: str) -> None:
super().__init__() super().__init__()
@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger): class _Logger(logging.Logger):
r""" r"""A logger that supports rank0 logging."""
A logger that supports rank0 logging.
"""
def info_rank0(self, *args, **kwargs) -> None: def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs) self.info(*args, **kwargs)
@ -82,9 +78,7 @@ class _Logger(logging.Logger):
def _get_default_logging_level() -> "logging._Level": def _get_default_logging_level() -> "logging._Level":
r""" r"""Return the default logging level."""
Returns the default logging level.
"""
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None) env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str: if env_level_str:
if env_level_str.upper() in logging._nameToLevel: if env_level_str.upper() in logging._nameToLevel:
@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":
def _configure_library_root_logger() -> None: def _configure_library_root_logger() -> None:
r""" r"""Configure root logger using a stdout stream handler with an explicit format."""
Configures root logger using a stdout stream handler with an explicit format.
"""
global _default_handler global _default_handler
with _thread_lock: with _thread_lock:
@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:
def get_logger(name: Optional[str] = None) -> "_Logger": def get_logger(name: Optional[str] = None) -> "_Logger":
r""" r"""Return a logger with the specified name. It is not supposed to be accessed externally."""
Returns a logger with the specified name. It is not supposed to be accessed externally.
"""
if name is None: if name is None:
name = _get_library_name() name = _get_library_name()
@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":
def add_handler(handler: "logging.Handler") -> None: def add_handler(handler: "logging.Handler") -> None:
r""" r"""Add a handler to the root logger."""
Adds a handler to the root logger.
"""
_configure_library_root_logger() _configure_library_root_logger()
_get_library_root_logger().addHandler(handler) _get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None: def remove_handler(handler: logging.Handler) -> None:
r""" r"""Remove a handler from the root logger."""
Removes a handler from the root logger.
"""
_configure_library_root_logger() _configure_library_root_logger()
_get_library_root_logger().removeHandler(handler) _get_library_root_logger().removeHandler(handler)

View File

@ -17,7 +17,8 @@
import gc import gc
import os import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Literal, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
class AverageMeter: class AverageMeter:
r""" r"""Compute and store the average and current value."""
Computes and stores the average and current value.
"""
def __init__(self): def __init__(self):
self.reset() self.reset()
@ -75,9 +74,7 @@ class AverageMeter:
def check_version(requirement: str, mandatory: bool = False) -> None: def check_version(requirement: str, mandatory: bool = False) -> None:
r""" r"""Optionally check the package version."""
Optionally checks the package version.
"""
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory: if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.") logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return return
@ -91,9 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None: def check_dependencies() -> None:
r""" r"""Check the version of the required packages."""
Checks the version of the required packages.
"""
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0") check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.2.0") check_version("datasets>=2.16.0,<=3.2.0")
check_version("accelerate>=0.34.0,<=1.2.1") check_version("accelerate>=0.34.0,<=1.2.1")
@ -103,10 +98,8 @@ def check_dependencies() -> None:
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.") logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float: def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
r""" r"""Calculate effective tokens per second."""
Calculates effective tokens per second.
"""
effective_token_num = 0 effective_token_num = 0
for data in dataset: for data in dataset:
if stage == "sft": if stage == "sft":
@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
return result / dist.get_world_size() if dist.is_initialized() else result return result / dist.get_world_size() if dist.is_initialized() else result
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
r""" r"""Return the number of trainable parameters and number of all parameters in the model."""
Returns the number of trainable parameters and number of all parameters in the model.
"""
trainable_params, all_param = 0, 0 trainable_params, all_param = 0, 0
for param in model.parameters(): for param in model.parameters():
num_params = param.numel() num_params = param.numel()
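In its simplest form, counting trainable versus total parameters looks like the sketch below; it ignores any quantization or ZeRO-3 adjustments the full implementation may make.

import torch

def count_parameters_simple(model: torch.nn.Module) -> tuple[int, int]:
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param

model = torch.nn.Linear(4, 2)
print(count_parameters_simple(model))  # (10, 10)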
@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
def get_current_device() -> "torch.device": def get_current_device() -> "torch.device":
r""" r"""Get the current available device."""
Gets the current available device.
"""
if is_torch_xpu_available(): if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available(): elif is_torch_npu_available():
@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
def get_device_count() -> int: def get_device_count() -> int:
r""" r"""Get the number of available GPU or NPU devices."""
Gets the number of available GPU or NPU devices.
"""
if is_torch_xpu_available(): if is_torch_xpu_available():
return torch.xpu.device_count() return torch.xpu.device_count()
elif is_torch_npu_available(): elif is_torch_npu_available():
@ -180,18 +167,14 @@ def get_device_count() -> int:
def get_logits_processor() -> "LogitsProcessorList": def get_logits_processor() -> "LogitsProcessorList":
r""" r"""Get logits processor that removes NaN and Inf logits."""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList() logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor()) logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor return logits_processor
def get_peak_memory() -> Tuple[int, int]: def get_peak_memory() -> tuple[int, int]:
r""" r"""Get the peak memory usage for the current device (in Bytes)."""
Gets the peak memory usage for the current device (in Bytes).
"""
if is_torch_npu_available(): if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved() return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available(): elif is_torch_cuda_available():
@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
def has_tokenized_data(path: "os.PathLike") -> bool: def has_tokenized_data(path: "os.PathLike") -> bool:
r""" r"""Check if the path has a tokenized dataset."""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0 return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype": def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r""" r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if _is_bf16_available and model_dtype == torch.bfloat16: if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16 return torch.bfloat16
elif _is_fp16_available: elif _is_fp16_available:
@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def is_gpu_or_npu_available() -> bool: def is_gpu_or_npu_available() -> bool:
r""" r"""Check if the GPU or NPU is available."""
Checks if the GPU or NPU is available.
"""
return is_torch_npu_available() or is_torch_cuda_available() return is_torch_npu_available() or is_torch_cuda_available()
def is_env_enabled(env_var: str, default: str = "0") -> bool: def is_env_enabled(env_var: str, default: str = "0") -> bool:
r""" r"""Check if the environment variable is enabled."""
Checks if the environment variable is enabled.
"""
return os.getenv(env_var, default).lower() in ["true", "y", "1"] return os.getenv(env_var, default).lower() in ["true", "y", "1"]
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray": def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r""" r"""Cast a torch tensor or a numpy array to a numpy array."""
Casts a torch tensor or a numpy array to a numpy array.
"""
if isinstance(inputs, torch.Tensor): if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu() inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4 if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def skip_check_imports() -> None: def skip_check_imports() -> None:
r""" r"""Avoid flash attention import error in custom model files."""
Avoids flash attention import error in custom model files.
"""
if not is_env_enabled("FORCE_CHECK_IMPORTS"): if not is_env_enabled("FORCE_CHECK_IMPORTS"):
transformers.dynamic_module_utils.check_imports = get_relative_imports transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None: def torch_gc() -> None:
r""" r"""Collect GPU or NPU memory."""
Collects GPU or NPU memory.
"""
gc.collect() gc.collect()
if is_torch_xpu_available(): if is_torch_xpu_available():
torch.xpu.empty_cache() torch.xpu.empty_cache()

View File

@ -15,7 +15,7 @@
import json import json
import math import math
import os import os
from typing import Any, Dict, List from typing import Any
from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer import TRAINER_STATE_NAME
@ -31,10 +31,8 @@ if is_matplotlib_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]: def smooth(scalars: list[float]) -> list[float]:
r""" r"""EMA implementation according to TensorBoard."""
EMA implementation according to TensorBoard.
"""
if len(scalars) == 0: if len(scalars) == 0:
return [] return []
@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed return smoothed
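For reference, a generic exponential moving average along the lines of the TensorBoard-style smooth() above; the fixed weight of 0.9 is an assumption rather than whatever adaptive weight the repository computes.

def smooth_simple(scalars: list[float], weight: float = 0.9) -> list[float]:
    if len(scalars) == 0:
        return []
    smoothed, last = [], scalars[0]
    for point in scalars:
        last = last * weight + (1 - weight) * point  # exponential moving average
        smoothed.append(last)
    return smoothed

print(smooth_simple([1.0, 0.5, 0.8, 0.2]))  # [1.0, 0.95, 0.935, 0.8615]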
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure": def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
r""" r"""Plot loss curves in LlamaBoard."""
Plots loss curves in LlamaBoard.
"""
plt.close("all") plt.close("all")
plt.switch_backend("agg") plt.switch_backend("agg")
fig = plt.figure() fig = plt.figure()
@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
return fig return fig
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None: def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
r""" r"""Plot loss curves and saves the image."""
Plots loss curves and saves the image.
"""
plt.switch_backend("agg") plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f: with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
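Only the signature of `smooth` survives in the hunk above, so here is a hedged sketch of what a TensorBoard-style exponential moving average over a list of scalars looks like; the fixed `weight` below is an assumption for illustration, and the actual file may derive it from the sequence length.

```python
def smooth(scalars: list[float], weight: float = 0.9) -> list[float]:
    """Exponential moving average, in the spirit of TensorBoard's smoothing slider."""
    if len(scalars) == 0:
        return []

    smoothed = []
    last = scalars[0]
    for value in scalars:
        # each point blends the running average with the new value
        last = last * weight + (1.0 - weight) * value
        smoothed.append(last)

    return smoothed


print(smooth([4.0, 0.0, 0.0, 0.0]))  # decays toward 0 instead of dropping instantly
```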

View File

@ -16,14 +16,12 @@
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional from typing import Any, Literal, Optional
@dataclass @dataclass
class DataArguments: class DataArguments:
r""" r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: Optional[str] = field( template: Optional[str] = field(
default=None, default=None,
@ -162,5 +160,5 @@ class DataArguments:
if self.mask_history and self.train_on_prompt: if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.") raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)

View File

@ -21,9 +21,7 @@ from datasets import DownloadMode
@dataclass @dataclass
class EvaluationArguments: class EvaluationArguments:
r""" r"""Arguments pertaining to specify the evaluation parameters."""
Arguments pertaining to specify the evaluation parameters.
"""
task: str = field( task: str = field(
metadata={"help": "Name of the evaluation task."}, metadata={"help": "Name of the evaluation task."},

View File

@ -13,14 +13,12 @@
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional from typing import Any, Literal, Optional
@dataclass @dataclass
class FreezeArguments: class FreezeArguments:
r""" r"""Arguments pertaining to the freeze (partial-parameter) training."""
Arguments pertaining to the freeze (partial-parameter) training.
"""
freeze_trainable_layers: int = field( freeze_trainable_layers: int = field(
default=2, default=2,
@ -56,9 +54,7 @@ class FreezeArguments:
@dataclass @dataclass
class LoraArguments: class LoraArguments:
r""" r"""Arguments pertaining to the LoRA training."""
Arguments pertaining to the LoRA training.
"""
additional_target: Optional[str] = field( additional_target: Optional[str] = field(
default=None, default=None,
@ -128,9 +124,7 @@ class LoraArguments:
@dataclass @dataclass
class RLHFArguments: class RLHFArguments:
r""" r"""Arguments pertaining to the PPO, DPO and KTO training."""
Arguments pertaining to the PPO, DPO and KTO training.
"""
pref_beta: float = field( pref_beta: float = field(
default=0.1, default=0.1,
@ -212,9 +206,7 @@ class RLHFArguments:
@dataclass @dataclass
class GaloreArguments: class GaloreArguments:
r""" r"""Arguments pertaining to the GaLore algorithm."""
Arguments pertaining to the GaLore algorithm.
"""
use_galore: bool = field( use_galore: bool = field(
default=False, default=False,
@ -253,9 +245,7 @@ class GaloreArguments:
@dataclass @dataclass
class ApolloArguments: class ApolloArguments:
r""" r"""Arguments pertaining to the APOLLO algorithm."""
Arguments pertaining to the APOLLO algorithm.
"""
use_apollo: bool = field( use_apollo: bool = field(
default=False, default=False,
@ -306,9 +296,7 @@ class ApolloArguments:
@dataclass @dataclass
class BAdamArgument: class BAdamArgument:
r""" r"""Arguments pertaining to the BAdam optimizer."""
Arguments pertaining to the BAdam optimizer.
"""
use_badam: bool = field( use_badam: bool = field(
default=False, default=False,
@ -393,9 +381,7 @@ class SwanLabArguments:
class FinetuningArguments( class FinetuningArguments(
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
): ):
r""" r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
pure_bf16: bool = field( pure_bf16: bool = field(
default=False, default=False,
@ -452,13 +438,13 @@ class FinetuningArguments(
return [item.strip() for item in arg.split(",")] return [item.strip() for item in arg.split(",")]
return arg return arg
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules) self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules) self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: List[str] = split_arg(self.lora_target) self.lora_target: list[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target) self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target) self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target) self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"] self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@ -499,7 +485,7 @@ class FinetuningArguments(
if self.pissa_init: if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.") raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
args = asdict(self) args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()} args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args return args
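The `split_arg` helper used in the `__post_init__` hunk above is what lets `lora_target`, `galore_target`, and friends be written as a single comma-separated string; a standalone sketch of its behaviour:

```python
from typing import Optional, Union


def split_arg(arg: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
    # comma-separated strings become stripped lists; None and lists pass through untouched
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]

    return arg


print(split_arg("q_proj, k_proj,v_proj"))  # ['q_proj', 'k_proj', 'v_proj']
print(split_arg(None))                     # None
print(split_arg(["all"]))                  # ['all']
```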

View File

@ -13,16 +13,14 @@
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional from typing import Any, Optional
from transformers import GenerationConfig from transformers import GenerationConfig
@dataclass @dataclass
class GeneratingArguments: class GeneratingArguments:
r""" r"""Arguments pertaining to specify the decoding parameters."""
Arguments pertaining to specify the decoding parameters.
"""
do_sample: bool = field( do_sample: bool = field(
default=True, default=True,
@ -35,7 +33,9 @@ class GeneratingArguments:
top_p: float = field( top_p: float = field(
default=0.7, default=0.7,
metadata={ metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." "help": (
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
)
}, },
) )
top_k: int = field( top_k: int = field(
@ -71,7 +71,7 @@ class GeneratingArguments:
metadata={"help": "Whether or not to remove special tokens in the decoding."}, metadata={"help": "Whether or not to remove special tokens in the decoding."},
) )
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]: def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
args = asdict(self) args = asdict(self)
if args.get("max_new_tokens", -1) > 0: if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None) args.pop("max_length", None)

View File

@ -17,7 +17,7 @@
import json import json
from dataclasses import asdict, dataclass, field, fields from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Literal, Optional, Union
import torch import torch
from transformers.training_args import _convert_str_dict from transformers.training_args import _convert_str_dict
@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling
@dataclass @dataclass
class BaseModelArguments: class BaseModelArguments:
r""" r"""Arguments pertaining to the model."""
Arguments pertaining to the model.
"""
model_name_or_path: Optional[str] = field( model_name_or_path: Optional[str] = field(
default=None, default=None,
@ -184,9 +182,7 @@ class BaseModelArguments:
@dataclass @dataclass
class QuantizationArguments: class QuantizationArguments:
r""" r"""Arguments pertaining to the quantization method."""
Arguments pertaining to the quantization method.
"""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field( quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes", default="bitsandbytes",
@ -212,9 +208,7 @@ class QuantizationArguments:
@dataclass @dataclass
class ProcessorArguments: class ProcessorArguments:
r""" r"""Arguments pertaining to the image processor."""
Arguments pertaining to the image processor.
"""
image_max_pixels: int = field( image_max_pixels: int = field(
default=768 * 768, default=768 * 768,
@ -244,9 +238,7 @@ class ProcessorArguments:
@dataclass @dataclass
class ExportArguments: class ExportArguments:
r""" r"""Arguments pertaining to the model export."""
Arguments pertaining to the model export.
"""
export_dir: Optional[str] = field( export_dir: Optional[str] = field(
default=None, default=None,
@ -292,9 +284,7 @@ class ExportArguments:
@dataclass @dataclass
class VllmArguments: class VllmArguments:
r""" r"""Arguments pertaining to the vLLM worker."""
Arguments pertaining to the vLLM worker.
"""
vllm_maxlen: int = field( vllm_maxlen: int = field(
default=4096, default=4096,
@ -324,8 +314,7 @@ class VllmArguments:
@dataclass @dataclass
class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments): class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
r""" r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
The rightmost class will be displayed first. The rightmost class will be displayed first.
""" """
@ -335,7 +324,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
init=False, init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."}, metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
) )
device_map: Optional[Union[str, Dict[str, Any]]] = field( device_map: Optional[Union[str, dict[str, Any]]] = field(
default=None, default=None,
init=False, init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."}, metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
@ -372,7 +361,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
return result return result
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
args = asdict(self) args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()} args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args return args
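The `ModelArguments` docstring above points out that the rightmost base class is displayed first. That is ordinary dataclass behaviour: fields are collected from base classes in reverse MRO order, which a tiny standalone example makes concrete.

```python
from dataclasses import dataclass, fields


@dataclass
class A:
    a: int = 1


@dataclass
class B:
    b: int = 2


@dataclass
class C(B, A):  # A is the rightmost base
    c: int = 3


print([f.name for f in fields(C)])  # ['a', 'b', 'c'] -- the rightmost base's fields come first
```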

View File

@ -19,7 +19,7 @@ import json
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
import torch import torch
import transformers import transformers
@ -47,17 +47,15 @@ check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]: def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
r""" r"""Get arguments from the command line or a config file."""
Gets arguments from the command line or a config file.
"""
if args is not None: if args is not None:
return args return args
@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[
def _parse_args( def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]: ) -> tuple[Any]:
args = read_args(args) args = read_args(args)
if isinstance(args, dict): if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys) return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
@ -161,31 +159,31 @@ def _check_extra_dependencies(
check_version("rouge_chinese", mandatory=True) check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS) parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS) parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS) parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments: def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments) parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True) (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args return ray_args
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging # Setup logging
@ -364,9 +362,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
and training_args.resume_from_checkpoint is not None and training_args.resume_from_checkpoint is not None
): ):
logger.warning_rank0( logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
training_args.resume_from_checkpoint
)
) )
# Post-process model arguments # Post-process model arguments
@ -382,20 +378,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
# Log on each process the small summary # Log on each process the small summary
logger.info( logger.info(
"Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format( f"Process rank: {training_args.process_index}, "
training_args.process_index, f"world size: {training_args.world_size}, device: {training_args.device}, "
training_args.world_size, f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
training_args.device, f"compute dtype: {str(model_args.compute_dtype)}"
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
) )
transformers.set_seed(training_args.seed) transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging() _set_transformers_logging()
@ -426,7 +419,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
return model_args, data_args, finetuning_args, generating_args return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging() _set_transformers_logging()
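`_parse_args` above accepts either a dict (e.g. loaded from a YAML/JSON config) or a list of CLI tokens and routes them to the matching `HfArgumentParser` entry point. A condensed sketch with a toy dataclass in place of the real argument classes (assumes `transformers` is installed):

```python
from dataclasses import dataclass, field
from typing import Any, Optional, Union

from transformers import HfArgumentParser


@dataclass
class ToyArguments:
    model_name_or_path: Optional[str] = field(default=None)
    learning_rate: float = field(default=5e-5)


def parse(args: Union[dict[str, Any], list[str]]) -> tuple[Any, ...]:
    parser = HfArgumentParser(ToyArguments)
    if isinstance(args, dict):
        # config-file style: keys map directly onto dataclass fields
        return parser.parse_dict(args, allow_extra_keys=False)

    # CLI style: a flat list of "--key value" tokens
    return parser.parse_args_into_dataclasses(args=args)


(from_dict,) = parse({"model_name_or_path": "my-model", "learning_rate": 1e-4})
(from_cli,) = parse(["--model_name_or_path", "my-model", "--learning_rate", "1e-4"])
print(from_dict == from_cli)  # True
```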

View File

@ -10,9 +10,7 @@ from ..extras.misc import use_ray
@dataclass @dataclass
class RayArguments: class RayArguments:
r""" r"""Arguments pertaining to the Ray training."""
Arguments pertaining to the Ray training.
"""
ray_run_name: Optional[str] = field( ray_run_name: Optional[str] = field(
default=None, default=None,
@ -43,9 +41,7 @@ class RayArguments:
@dataclass @dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments): class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r""" r"""Arguments pertaining to the trainer."""
Arguments pertaining to the trainer.
"""
def __post_init__(self): def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self) Seq2SeqTrainingArguments.__post_init__(self)

View File

@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params
__all__ = [ __all__ = [
"QuantizationMethod", "QuantizationMethod",
"find_all_linear_modules",
"load_config", "load_config",
"load_model", "load_model",
"load_tokenizer", "load_tokenizer",
"find_all_linear_modules",
"load_valuehead_params", "load_valuehead_params",
] ]

View File

@ -81,9 +81,8 @@ def _setup_freeze_tuning(
if finetuning_args.use_llama_pro: if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0: if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError( raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format( f"`num_layers` {num_layers} should be "
num_layers, finetuning_args.freeze_trainable_layers f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
)
) )
stride = num_layers // finetuning_args.freeze_trainable_layers stride = num_layers // finetuning_args.freeze_trainable_layers
@ -178,7 +177,7 @@ def _setup_lora_tuning(
} }
for adapter in adapter_to_merge: for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs) model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload() model = model.merge_and_unload()
if len(adapter_to_merge) > 0: if len(adapter_to_merge) > 0:
@ -263,8 +262,7 @@ def init_adapter(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: bool, is_trainable: bool,
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""Initialize the adapters.
Initializes the adapters.
Support full-parameter, freeze and LoRA training. Support full-parameter, freeze and LoRA training.
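The `_setup_lora_tuning` hunk above merges any adapters listed in `adapter_to_merge` back into the base weights before attaching a trainable adapter. Stripped of the surrounding plumbing, the pattern is roughly the following (paths are placeholders; `peft` and `transformers` are assumed to be available):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# placeholder paths, for illustration only
base_model_path = "path/to/base-model"
adapters_to_merge = ["path/to/adapter-a", "path/to/adapter-b"]

model = AutoModelForCausalLM.from_pretrained(base_model_path)
for adapter in adapters_to_merge:
    # wrap the current weights with the adapter ...
    model = PeftModel.from_pretrained(model, adapter)
    # ... then fold the adapter into the base weights and drop the PEFT wrapper
    model = model.merge_and_unload()

print(f"merged {len(adapters_to_merge)} adapter(s)")
```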

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch import torch
from transformers import ( from transformers import (
@ -51,9 +51,8 @@ class TokenizerModule(TypedDict):
processor: Optional["ProcessorMixin"] processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
r""" r"""Get arguments to load config/tokenizer/model.
Gets arguments to load config/tokenizer/model.
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
@ -68,8 +67,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r""" r"""Load pretrained tokenizer and optionally loads processor.
Loads pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
@ -110,9 +108,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
def load_config(model_args: "ModelArguments") -> "PretrainedConfig": def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r""" r"""Load model config."""
Loads model config.
"""
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
@ -124,9 +120,7 @@ def load_model(
is_trainable: bool = False, is_trainable: bool = False,
add_valuehead: bool = False, add_valuehead: bool = False,
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""Load pretrained model."""
Loads pretrained model.
"""
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args) config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
@ -194,8 +188,9 @@ def load_model(
trainable_params, all_param = count_parameters(model) trainable_params, all_param = count_parameters(model)
if is_trainable: if is_trainable:
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( param_stats = (
trainable_params, all_param, 100 * trainable_params / all_param f"trainable params: {trainable_params:,} || "
f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
) )
else: else:
param_stats = f"all params: {all_param:,}" param_stats = f"all params: {all_param:,}"

View File

@ -21,7 +21,7 @@
import inspect import inspect
from functools import WRAPPER_ASSIGNMENTS, partial, wraps from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch import torch
@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable: def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function): class UnslothGradientCheckpointing(torch.autograd.Function):
r""" r"""Saves VRAM by smartly offloading to RAM."""
Saves VRAM by smartly offloading to RAM.
"""
@staticmethod @staticmethod
@torch.cuda.amp.custom_fwd @torch.cuda.amp.custom_fwd
@ -77,13 +75,11 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable: def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r""" r"""Only applies gradient checkpointing to trainable layers."""
Only applies gradient checkpointing to trainable layers.
"""
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",)) @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs): def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__ module: torch.nn.Module = func.__self__
has_grad = False has_grad = False
if any(param.requires_grad for param in module.parameters()): if any(param.requires_grad for param in module.parameters()):
@ -103,11 +99,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
def _gradient_checkpointing_enable( def _gradient_checkpointing_enable(
self: "PreTrainedModel", self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None, gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
use_unsloth_gc: bool = False, use_unsloth_gc: bool = False,
) -> None: ) -> None:
r""" r"""Activates gradient checkpointing for the current model.
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer. Modification of the original method to enable gradient checkpointing for block-wise optimizer.
""" """
@ -134,17 +129,18 @@ def _gradient_checkpointing_enable(
def _fp32_forward_post_hook( def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor": ) -> "torch.Tensor":
return output.to(torch.float32) return output.to(torch.float32)
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None: def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r""" r"""Prepare the model before training.
Includes:
Include:
(1) cast the layernorm in fp32 (1) cast the layernorm in fp32
(2) make output embedding layer require grads (2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32 (3) add the upcasting of the lm_head in fp32.
""" """
if model_args.upcast_layernorm: if model_args.upcast_layernorm:
logger.info_rank0("Upcasting layernorm weights in float32.") logger.info_rank0("Upcasting layernorm weights in float32.")

View File

@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r""" r"""Resize token embeddings."""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore import deepspeed # type: ignore

View File

@ -18,7 +18,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -54,14 +54,14 @@ def llama_attention_forward(
past_key_value: Optional["Cache"] = None, past_key_value: Optional["Cache"] = None,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None, cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None, position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs, **kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]: ) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states) query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states) key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states) value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
past_key_value: Optional["Cache"] = None, past_key_value: Optional["Cache"] = None,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None, cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None, position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs, **kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]: ) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions # LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False output_attentions = False
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states) query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states) key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states) value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if is_transformers_version_greater_than("4.43.0"): if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward( attn_output: torch.Tensor = _flash_attention_forward(
query_states, query_states,
key_states, key_states,
value_states, value_states,
@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
is_causal=self.is_causal, is_causal=self.is_causal,
) )
else: else:
attn_output: "torch.Tensor" = self._flash_attention_forward( attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
) )
@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
past_key_value: Optional["Cache"] = None, past_key_value: Optional["Cache"] = None,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None, cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None, position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs, **kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]: ) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
if output_attentions: if output_attentions:
transformers_logger.warning_once( transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention" "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states) query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states) key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states) value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING
from ...extras import logging from ...extras import logging
from .visual import COMPOSITE_MODELS from .visual import COMPOSITE_MODELS
@ -25,10 +25,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
r""" r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
Finds all available modules to apply LoRA, GaLore or APOLLO.
"""
model_type = getattr(model.config, "model_type", None) model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"} forbidden_modules = {"lm_head"}
if model_type == "chatglm": if model_type == "chatglm":
@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
return list(module_names) return list(module_names)
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]: def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
r""" r"""Find the modules in the expanded blocks to apply lora."""
Finds the modules in the expanded blocks to apply lora.
"""
num_layers = getattr(model.config, "num_hidden_layers", None) num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers: if not num_layers:
raise ValueError("Model was not supported.") raise ValueError("Model was not supported.")

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING
import torch import torch
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
@ -34,9 +35,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch
def add_z3_leaf_module(model: "PreTrainedModel") -> None: def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r""" r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
if not is_deepspeed_zero3_enabled(): if not is_deepspeed_zero3_enabled():
return return

View File

@ -37,7 +37,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r""" r"""Get the sequnce lengths in the current batch.
Gets the sequnce lengths in the current batch.
e.g. e.g.
```python ```python
@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
bsz = attention_mask.size(0) bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask).item() max_num = torch.max(attention_mask).item()
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device) counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num): for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1) counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
return seqlens return seqlens
def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]: def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
r""" r"""Prepare the indices and seqlens for flash attn varlen function.
Prepares the indices and seqlens for flash attn varlen function.
Returns: Returns:
indices: indices of non-masked tokens from the flattened sequence. indices: indices of non-masked tokens from the flattened sequence.
@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
[0, 2, 5, 6, 8, 11] [0, 2, 5, 6, 8, 11]
3 3
``` ```
""" """
seqlens_in_batch = get_seqlens_in_batch(attention_mask) seqlens_in_batch = get_seqlens_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
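In the packed-sequence attention mask above, each token stores the index of the sequence it belongs to (1, 2, ...) instead of a plain 0/1 value, and `get_seqlens_in_batch` recovers one length per packed sequence from it. A hedged sketch of the full function; the final flatten-and-filter step is paraphrased because the hunk elides it:

```python
import torch


def get_seqlens_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
    bsz = attention_mask.size(0)
    dtype, device = attention_mask.dtype, attention_mask.device
    max_num = int(torch.max(attention_mask).item())
    counts = torch.zeros((bsz, max_num), dtype=dtype, device=device)
    for i in range(max_num):
        # how many tokens in each row belong to packed sequence i + 1
        counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

    # paraphrased ending: flatten and keep the non-zero lengths
    counts = counts.flatten()
    return counts[counts.nonzero().squeeze(dim=-1)]


mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
print(get_seqlens_in_batch(mask))  # tensor([2, 3])
```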

View File

@ -19,7 +19,7 @@
import os import os
import random import random
from enum import Enum, unique from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List from typing import TYPE_CHECKING, Any
import torch import torch
from datasets import load_dataset from datasets import load_dataset
@ -43,9 +43,7 @@ logger = logging.get_logger(__name__)
@unique @unique
class QuantizationMethod(str, Enum): class QuantizationMethod(str, Enum):
r""" r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
BITS_AND_BYTES = "bitsandbytes" BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq" GPTQ = "gptq"
@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum):
HQQ = "hqq" HQQ = "hqq"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]: def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
r""" r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
"""
if os.path.isfile(model_args.export_quantization_dataset): if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset data_files = model_args.export_quantization_dataset
@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.") raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
sample_idx = random.randint(0, len(dataset) - 1) sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
n_try += 1 n_try += 1
if sample["input_ids"].size(1) > maxlen: if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen break # TODO: fix large maxlen
@ -101,11 +97,9 @@ def configure_quantization(
config: "PretrainedConfig", config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments", model_args: "ModelArguments",
init_kwargs: Dict[str, Any], init_kwargs: dict[str, Any],
) -> None: ) -> None:
r""" r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
if getattr(config, "quantization_config", None): # ptq if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.") logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
@ -113,7 +107,7 @@ def configure_quantization(
if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.") raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "") quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ: if quant_method == QuantizationMethod.GPTQ:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.misc import get_current_device from ...extras.misc import get_current_device
@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
def _get_unsloth_kwargs( def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments" config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"model_name": model_name_or_path, "model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096, "max_seq_length": model_args.model_max_length or 4096,
@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
def load_unsloth_pretrained_model( def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments" config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]: ) -> Optional["PreTrainedModel"]:
r""" r"""Optionally load pretrained model with unsloth. Used in training."""
Optionally loads pretrained model with unsloth. Used in training. from unsloth import FastLanguageModel # type: ignore
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args) unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try: try:
@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(
def get_unsloth_peft_model( def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""Get the peft model for the pretrained model with unsloth. Used in training."""
Gets the peft model for the pretrained model with unsloth. Used in training. from unsloth import FastLanguageModel # type: ignore
"""
from unsloth import FastLanguageModel
unsloth_peft_kwargs = { unsloth_peft_kwargs = {
"model": model, "model": model,
@ -82,10 +78,8 @@ def get_unsloth_peft_model(
def load_unsloth_peft_model( def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""Load peft model with unsloth. Used in both training and inference."""
Loads peft model with unsloth. Used in both training and inference. from unsloth import FastLanguageModel # type: ignore
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try: try:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING
import torch import torch
from transformers.utils import cached_file from transformers.utils import cached_file
@ -30,9 +30,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
r""" r"""Load value head parameters from Hugging Face Hub or local disk.
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
""" """

View File

@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple from typing import TYPE_CHECKING, Optional
import torch import torch
import transformers import transformers
@ -40,9 +41,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
class CompositeModel: class CompositeModel:
model_type: str model_type: str
projector_key: str projector_key: str
vision_model_keys: List[str] vision_model_keys: list[str]
language_model_keys: List[str] language_model_keys: list[str]
lora_conflict_keys: List[str] lora_conflict_keys: list[str]
def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module": def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
for key in self.projector_key.split("."): for key in self.projector_key.split("."):
@ -51,15 +52,15 @@ class CompositeModel:
return module return module
COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {} COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
def _register_composite_model( def _register_composite_model(
model_type: str, model_type: str,
projector_key: Optional[str] = None, projector_key: Optional[str] = None,
vision_model_keys: Optional[List[str]] = None, vision_model_keys: Optional[list[str]] = None,
language_model_keys: Optional[List[str]] = None, language_model_keys: Optional[list[str]] = None,
lora_conflict_keys: Optional[List[str]] = None, lora_conflict_keys: Optional[list[str]] = None,
): ):
COMPOSITE_MODELS[model_type] = CompositeModel( COMPOSITE_MODELS[model_type] = CompositeModel(
model_type=model_type, model_type=model_type,
@ -116,12 +117,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None: def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r""" r"""Cast projector output to half precision for fine-tuning quantized VLMs."""
Casts projector output to half precision for fine-tuning quantized VLMs.
"""
def _mm_projector_forward_post_hook( def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor": ) -> "torch.Tensor":
return output.to(model_args.compute_dtype) return output.to(model_args.compute_dtype)
@ -137,9 +136,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
def configure_visual_model(config: "PretrainedConfig") -> None: def configure_visual_model(config: "PretrainedConfig") -> None:
r""" r"""Patch VLMs before loading them."""
Patches VLMs before loading them.
"""
if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None): if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
# required for ds zero3 and valuehead models # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
@ -149,10 +146,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]: def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
r""" r"""Freeze vision tower and language model for VLM full/freeze tuning."""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
forbidden_modules = set() forbidden_modules = set()
if model_type in COMPOSITE_MODELS: if model_type in COMPOSITE_MODELS:
@ -175,9 +170,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
def get_image_seqlen(config: "PretrainedConfig") -> int: def get_image_seqlen(config: "PretrainedConfig") -> int:
r""" r"""Compute the number of special tokens per image."""
Computes the number of special tokens per image.
"""
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
if model_type == "llava": if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2 image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
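For LLaVA-style models the image placeholder count computed above is just the number of ViT patches; with the usual CLIP ViT-L/14-336 settings that works out as follows (the numbers are illustrative, not read from any config in this diff):

```python
# illustrative CLIP ViT-L/14-336 values, as used by LLaVA-1.5-style models
image_size = 336
patch_size = 14

image_seqlen = (image_size // patch_size) ** 2
print(image_seqlen)  # 576 placeholder tokens per image
```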
@ -192,17 +185,13 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int: def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r""" r"""Compute the patch size of the vit."""
Computes the patch size of the vit.
"""
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1)) patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
return patch_size return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int: def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r""" r"""Get the vision_feature_select_strategy."""
Get the vision_feature_select_strategy.
"""
vision_feature_select_strategy = getattr( vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default") config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
) )
@ -211,10 +200,8 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P
def patch_target_modules( def patch_target_modules(
model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str] model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> List[str]: ) -> list[str]:
r""" r"""Freezes vision tower for VLM LoRA tuning."""
Freezes vision tower for VLM LoRA tuning.
"""
model_type = getattr(model.config, "model_type", None) model_type = getattr(model.config, "model_type", None)
if model_type in COMPOSITE_MODELS: if model_type in COMPOSITE_MODELS:
forbidden_modules = get_forbidden_modules(model.config, finetuning_args) forbidden_modules = get_forbidden_modules(model.config, finetuning_args)

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Dict from typing import TYPE_CHECKING, Any
import torch import torch
from peft import PeftModel from peft import PeftModel
@ -93,7 +93,7 @@ def patch_config(
config: "PretrainedConfig", config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments", model_args: "ModelArguments",
init_kwargs: Dict[str, Any], init_kwargs: dict[str, Any],
is_trainable: bool, is_trainable: bool,
) -> None: ) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32

View File

@ -19,7 +19,7 @@ import sys
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
import transformers import transformers
@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
def fix_valuehead_checkpoint( def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None: ) -> None:
r""" r"""Fix the valuehead checkpoint files.
The model is already unwrapped. The model is already unwrapped.
There are three cases: There are three cases:
@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
if safe_serialization: if safe_serialization:
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
else: else:
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
os.remove(path_to_checkpoint) os.remove(path_to_checkpoint)
decoder_state_dict, v_head_state_dict = {}, {} decoder_state_dict, v_head_state_dict = {}, {}
@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
class FixValueHeadModelCallback(TrainerCallback): class FixValueHeadModelCallback(TrainerCallback):
r""" r"""A callback for fixing the checkpoint for valuehead models."""
A callback for fixing the checkpoint for valuehead models.
"""
@override @override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
class SaveProcessorCallback(TrainerCallback): class SaveProcessorCallback(TrainerCallback):
r""" r"""A callback for saving the processor."""
A callback for saving the processor.
"""
def __init__(self, processor: "ProcessorMixin") -> None: def __init__(self, processor: "ProcessorMixin") -> None:
self.processor = processor self.processor = processor
@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
class PissaConvertCallback(TrainerCallback): class PissaConvertCallback(TrainerCallback):
r""" r"""A callback for converting the PiSSA adapter to a normal one."""
A callback for converting the PiSSA adapter to a normal one.
"""
@override @override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@ -177,9 +172,7 @@ class PissaConvertCallback(TrainerCallback):
class LogCallback(TrainerCallback): class LogCallback(TrainerCallback):
r""" r"""A callback for logging training and evaluation status."""
A callback for logging training and evaluation status.
"""
def __init__(self) -> None: def __init__(self) -> None:
# Progress # Progress
@ -188,7 +181,7 @@ class LogCallback(TrainerCallback):
self.max_steps = 0 self.max_steps = 0
self.elapsed_time = "" self.elapsed_time = ""
self.remaining_time = "" self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None self.thread_pool: Optional[ThreadPoolExecutor] = None
# Status # Status
self.aborted = False self.aborted = False
self.do_train = False self.do_train = False
@ -219,7 +212,7 @@ class LogCallback(TrainerCallback):
self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time))) self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None: def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f: with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n") f.write(json.dumps(logs) + "\n")
@ -348,9 +341,7 @@ class LogCallback(TrainerCallback):
class ReporterCallback(TrainerCallback): class ReporterCallback(TrainerCallback):
r""" r"""A callback for reporting training status to external logger."""
A callback for reporting training status to external logger.
"""
def __init__( def __init__(
self, self,

View File

@ -19,7 +19,7 @@ import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -129,15 +129,11 @@ class CustomDPOTrainer(DPOTrainer):
@override @override
def get_batch_samples(self, epoch_iterator, num_batches): def get_batch_samples(self, epoch_iterator, num_batches):
r""" r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches) return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r""" r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
"""
log_odds = (chosen_logps - rejected_logps) - ( log_odds = (chosen_logps - rejected_logps) - (
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps)) torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
) )
@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
return orpo_loss return orpo_loss
def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r""" r"""Compute SimPO loss for batched log probabilities of the policy model."""
Computes SimPO loss for batched log probabilities of the policy model.
"""
pi_logratios = chosen_logps - rejected_logps pi_logratios = chosen_logps - rejected_logps
gamma_logratios = self.simpo_gamma / self.beta gamma_logratios = self.simpo_gamma / self.beta
logits = pi_logratios - gamma_logratios logits = pi_logratios - gamma_logratios
@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
policy_rejected_logps: "torch.Tensor", policy_rejected_logps: "torch.Tensor",
reference_chosen_logps: Optional["torch.Tensor"], reference_chosen_logps: Optional["torch.Tensor"],
reference_rejected_logps: Optional["torch.Tensor"], reference_rejected_logps: Optional["torch.Tensor"],
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""Compute loss for preference learning."""
Computes loss for preference learning.
"""
if not self.finetuning_args.use_ref_model: if not self.finetuning_args.use_ref_model:
if self.loss_type == "orpo": if self.loss_type == "orpo":
losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps) losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):
@override @override
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities. Otherwise the average log probabilities.
""" """
if self.finetuning_args.use_ref_model: if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error batch = nested_detach(batch, clone=True) # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"]) all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
if self.loss_type in ["ipo", "orpo", "simpo"]: if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length all_logps = all_logps / valid_length
@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):
@override @override
def compute_reference_log_probs( def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]: ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r""" r"""Compute log probabilities of the reference model."""
Computes log probabilities of the reference model.
"""
if not self.finetuning_args.use_ref_model: if not self.finetuning_args.use_ref_model:
return None, None return None, None
@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
def get_batch_loss_metrics( def get_batch_loss_metrics(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"], batch: dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train", train_eval: Literal["train", "eval"] = "train",
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r""" r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {} metrics = {}
( (
policy_chosen_logps, policy_chosen_logps,
@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):
@override @override
def compute_loss( def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r""" r"""Subclass and override to accept extra kwargs."""
Subclass and override to accept extra kwargs.
"""
return super().compute_loss(model, inputs, return_outputs) return super().compute_loss(model, inputs, return_outputs)
@override @override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None: def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r""" r"""Log `logs` on the various objects watching training, including stored metrics."""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss" # logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs # Add averaged stored metrics to logs
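
The ORPO and SimPO branches above reduce to short closed-form expressions on the (length-normalized) log probabilities; a standalone sketch, with beta and simpo_gamma as example hyperparameters rather than values taken from this repository:

import torch
import torch.nn.functional as F

beta, simpo_gamma = 0.1, 0.5  # illustrative hyperparameters

def odds_ratio_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor) -> torch.Tensor:
    # log odds ratio between the chosen and rejected completions (same formula as above)
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    return -F.logsigmoid(log_odds)  # OR penalty; in ORPO this is scaled by beta and added to the SFT loss

def simpo_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor) -> torch.Tensor:
    # reference-free margin loss with target margin gamma / beta
    logits = (chosen_logps - rejected_logps) - simpo_gamma / beta
    return -F.logsigmoid(beta * logits)

chosen = torch.tensor([-0.5])    # average log-probability per token (made-up numbers)
rejected = torch.tensor([-1.5])
print(odds_ratio_loss(chosen, rejected), simpo_loss(chosen, rejected))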

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@ -38,7 +38,7 @@ def run_dpo(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]

View File

@ -19,7 +19,7 @@ import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import torch import torch
from transformers import Trainer from transformers import Trainer
@ -120,9 +120,7 @@ class CustomKTOTrainer(KTOTrainer):
@override @override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
r""" r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
if self.finetuning_args.disable_shuffling: if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset) return torch.utils.data.SequentialSampler(self.train_dataset)
@ -130,18 +128,14 @@ class CustomKTOTrainer(KTOTrainer):
@override @override
def get_batch_samples(self, epoch_iterator, num_batches): def get_batch_samples(self, epoch_iterator, num_batches):
r""" r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches) return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
@override @override
def forward( def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""Run forward pass and compute the log probabilities."""
Runs forward pass and computes the log probabilities.
"""
batch = nested_detach(batch, clone=True) # avoid error batch = nested_detach(batch, clone=True) # avoid error
model_inputs = { model_inputs = {
"input_ids": batch[f"{prefix}input_ids"], "input_ids": batch[f"{prefix}input_ids"],
@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):
@override @override
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logits, target_logps, target_logps_avg = self.forward(model, batch) target_logits, target_logps, target_logps_avg = self.forward(model, batch)
with torch.no_grad(): with torch.no_grad():
_, kl_logps, _ = self.forward(model, batch, prefix="kl_") _, kl_logps, _ = self.forward(model, batch, prefix="kl_")
@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):
@override @override
def compute_reference_log_probs( def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""Compute log probabilities of the reference model."""
Computes log probabilities of the reference model.
"""
if self.ref_model is None: if self.ref_model is None:
ref_model = model ref_model = model
ref_context = self.accelerator.unwrap_model(model).disable_adapter() ref_context = self.accelerator.unwrap_model(model).disable_adapter()
@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
def get_batch_loss_metrics( def get_batch_loss_metrics(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"], batch: dict[str, "torch.Tensor"],
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r""" r"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {} metrics = {}
( (
policy_chosen_logps, policy_chosen_logps,
@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):
@override @override
def compute_loss( def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r""" r"""Subclass and override to accept extra kwargs."""
Subclass and override to accept extra kwargs.
"""
return super().compute_loss(model, inputs, return_outputs) return super().compute_loss(model, inputs, return_outputs)
@override @override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None: def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r""" r"""Log `logs` on the various objects watching training, including stored metrics."""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss" # logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
prefix = "eval_" if train_eval == "eval" else "" prefix = "eval_" if train_eval == "eval" else ""
@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device) metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
metric_list = self.accelerator.reduce(metric_list, "sum").tolist() metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list)) metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
if f"count/{split}" in metric_dict: if f"count/{split}" in metric_dict:
for key in ("rewards", "logps", "logits"): for key in ("rewards", "logps", "logits"):
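
The prefix trick in CustomKTOTrainer.forward, which selects either the target or the KL mini-batch from one collated dict via an f-string key, can be sketched as follows; the tensors are dummies and only the key-naming pattern comes from the code above:

import torch

batch = {
    "input_ids": torch.ones(2, 4, dtype=torch.long),
    "attention_mask": torch.ones(2, 4, dtype=torch.long),
    "kl_input_ids": torch.ones(2, 6, dtype=torch.long),
    "kl_attention_mask": torch.ones(2, 6, dtype=torch.long),
}

def select_inputs(batch: dict, prefix: str = "") -> dict:
    # prefix="" -> target sequences, prefix="kl_" -> mismatched pairs for the KL term
    return {
        "input_ids": batch[f"{prefix}input_ids"],
        "attention_mask": batch[f"{prefix}attention_mask"],
    }

print(select_inputs(batch)["input_ids"].shape)         # torch.Size([2, 4])
print(select_inputs(batch, "kl_")["input_ids"].shape)  # torch.Size([2, 6])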

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@ -37,7 +37,7 @@ def run_kto(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]

View File

@ -14,7 +14,7 @@
import json import json
from contextlib import nullcontext from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional from typing import TYPE_CHECKING, Literal, Optional
import torch import torch
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
@ -31,10 +31,8 @@ if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]: def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
r""" r"""Get reward scores from the API server."""
Gets reward scores from the API server.
"""
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
payload = {"model": "model", "messages": messages} payload = {"model": "model", "messages": messages}
response = requests.post(server_url, json=payload, headers=headers) response = requests.post(server_url, json=payload, headers=headers)
@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
r""" r"""Replace the default/reward modules in the model. The model is already unwrapped."""
Replaces the default/reward modules in the model. The model is already unwrapped.
"""
v_head_layer = model.v_head.summary v_head_layer = model.v_head.summary
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore import deepspeed # type: ignore
@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device) v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]: def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
r""" r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
layer_norm_params = {} layer_norm_params = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.data.dtype == torch.float32: if param.data.dtype == torch.float32:
@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
return layer_norm_params return layer_norm_params
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None: def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
r""" r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if name in layernorm_params: if name in layernorm_params:
param.data = layernorm_params[name] param.data = layernorm_params[name]
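
get_rewards_from_server above is a thin HTTP client around the reward API; a minimal sketch of the request it issues, where the example URL and the "scores" field of the reply are assumptions for illustration and only the payload shape is taken from the code shown here:

import json
import requests
import torch

def get_rewards_from_server(server_url: str, messages: list[str]) -> list[torch.Tensor]:
    # POST the decoded transcripts and read back one scalar score per message.
    headers = {"Content-Type": "application/json"}
    payload = {"model": "model", "messages": messages}
    response = requests.post(server_url, json=payload, headers=headers)
    rewards = json.loads(response.text)["scores"]  # assumed response schema
    return [torch.tensor([reward], dtype=torch.float32) for reward in rewards]

# Hypothetical usage against a locally hosted reward server:
# get_rewards_from_server("http://localhost:8000/v1/score/evaluation", ["<transcript>"])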

View File

@ -20,7 +20,7 @@ import os
import sys import sys
import warnings import warnings
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from accelerate.utils import DistributedDataParallelKwargs from accelerate.utils import DistributedDataParallelKwargs
@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)
class CustomPPOTrainer(PPOTrainer, Trainer): class CustomPPOTrainer(PPOTrainer, Trainer):
r""" r"""Inherit PPOTrainer."""
Inherits PPOTrainer.
"""
def __init__( def __init__(
self, self,
@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]], callbacks: Optional[list["TrainerCallback"]],
model: "AutoModelForCausalLMWithValueHead", model: "AutoModelForCausalLMWithValueHead",
reward_model: Optional["AutoModelForCausalLMWithValueHead"], reward_model: Optional["AutoModelForCausalLMWithValueHead"],
ref_model: Optional["AutoModelForCausalLMWithValueHead"], ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r""" r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
if resume_from_checkpoint is not None: if resume_from_checkpoint is not None:
raise ValueError("`resume_from_checkpoint` will be supported in the future version.") raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info_rank0(f" Num Epochs = {num_train_epochs:,}") logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
logger.info_rank0( logger.info_rank0(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format( f" Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
total_train_batch_size
)
) )
logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}") logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}") logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
@ -339,21 +333,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return lr_scheduler return lr_scheduler
@torch.no_grad() @torch.no_grad()
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]: def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
r""" r"""Generate model's responses given queries."""
Generates model's responses given queries.
"""
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1 if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item() start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items(): for k, v in batch.items():
batch[k] = v[:, start_index:] batch[k] = v[:, start_index:]
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.model_args.upcast_layernorm: if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model) layernorm_params = dump_layernorm(unwrapped_model)
generate_output: "torch.Tensor" = unwrapped_model.generate( generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
) )
if self.model_args.upcast_layernorm: if self.model_args.upcast_layernorm:
@ -381,11 +373,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@torch.no_grad() @torch.no_grad()
def get_rewards( def get_rewards(
self, self,
queries: List["torch.Tensor"], queries: list["torch.Tensor"],
responses: List["torch.Tensor"], responses: list["torch.Tensor"],
) -> List["torch.Tensor"]: ) -> list["torch.Tensor"]:
r""" r"""Compute scores using given reward model.
Computes scores using given reward model.
Both inputs and outputs are put on CPU. Both inputs and outputs are put on CPU.
""" """
@ -394,8 +385,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False) messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
return get_rewards_from_server(self.reward_model, messages) return get_rewards_from_server(self.reward_model, messages)
batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses) batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.finetuning_args.reward_model_type == "lora": if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward") replace_model(unwrapped_model, target="reward")
@ -404,7 +395,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model = self.reward_model reward_model = self.reward_model
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16 with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1] values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]
if self.finetuning_args.reward_model_type == "lora": if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default") replace_model(unwrapped_model, target="default")
@ -419,12 +410,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
model: "AutoModelForCausalLMWithValueHead", model: "AutoModelForCausalLMWithValueHead",
queries: "torch.Tensor", queries: "torch.Tensor",
responses: "torch.Tensor", responses: "torch.Tensor",
model_inputs: Dict[str, Any], model_inputs: dict[str, Any],
return_logits: bool = False, return_logits: bool = False,
response_masks: Optional["torch.Tensor"] = None, response_masks: Optional["torch.Tensor"] = None,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r""" r"""Calculate model outputs in multiple batches.
Calculates model outputs in multiple batches.
Subclass and override to inject custom behavior. Subclass and override to inject custom behavior.
""" """
@ -483,8 +473,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@override @override
def save_model(self, output_dir: Optional[str] = None) -> None: def save_model(self, output_dir: Optional[str] = None) -> None:
r""" r"""Save model checkpoint.
Saves model checkpoint.
Subclass and override to inject custom behavior. Subclass and override to inject custom behavior.
""" """
@ -508,5 +497,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.save_checkpoint(output_dir) self.model.save_checkpoint(output_dir)
elif self.args.should_save: elif self.args.should_save:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
self._save(output_dir, state_dict=unwrapped_model.state_dict()) self._save(output_dir, state_dict=unwrapped_model.state_dict())
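
The "Total train batch size (w. parallel, buffer, distributed & accumulation)" figure logged above folds several multipliers into one number; a hedged arithmetic sketch, where the variable names and example values are illustrative and the exact composition is inferred from that log message rather than shown in this hunk:

per_device_train_batch_size = 4   # example value
gradient_accumulation_steps = 8   # example value
ppo_buffer_size = 1               # example value
world_size = 2                    # number of distributed processes (example value)

backward_batch_size = per_device_train_batch_size * gradient_accumulation_steps
total_train_batch_size = backward_batch_size * ppo_buffer_size * world_size
print(f" Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}")  # -> 64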

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss from ...extras.ploting import plot_loss
@ -37,7 +37,7 @@ def run_ppo(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
@ -53,7 +53,7 @@ def run_ppo(
reward_model = create_reward_model(model, model_args, finetuning_args) reward_model = create_reward_model(model, model_args, finetuning_args)
# Initialize our Trainer # Initialize our Trainer
ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer( ppo_trainer: CustomPPOTrainer = CustomPPOTrainer(
model_args=model_args, model_args=model_args,
training_args=training_args, training_args=training_args,
finetuning_args=finetuning_args, finetuning_args=finetuning_args,

View File

@ -31,9 +31,7 @@ if TYPE_CHECKING:
class CustomTrainer(Trainer): class CustomTrainer(Trainer):
r""" r"""Inherit Trainer for custom optimizer."""
Inherits Trainer for custom optimizer.
"""
def __init__( def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs

View File

@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
from transformers import DataCollatorForLanguageModeling from transformers import DataCollatorForLanguageModeling
@ -38,7 +38,7 @@ def run_pt(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
@ -26,11 +26,9 @@ if TYPE_CHECKING:
@dataclass @dataclass
class ComputeAccuracy: class ComputeAccuracy:
r""" r"""Compute reward accuracy and support `batch_eval_metrics`."""
Computes reward accuracy and supports `batch_eval_metrics`.
"""
def _dump(self) -> Optional[Dict[str, float]]: def _dump(self) -> Optional[dict[str, float]]:
result = None result = None
if hasattr(self, "score_dict"): if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@ -41,7 +39,7 @@ class ComputeAccuracy:
def __post_init__(self): def __post_init__(self):
self._dump() self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1]) chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
if not chosen_scores.shape: if not chosen_scores.shape:
self.score_dict["accuracy"].append(chosen_scores > rejected_scores) self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
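
The reward accuracy tracked by ComputeAccuracy above is simply the fraction of pairs in which the chosen response outscores the rejected one; a minimal numpy sketch with made-up scores:

import numpy as np

chosen_scores = np.array([1.2, 0.3, 2.0, -0.5])    # made-up reward-model outputs
rejected_scores = np.array([0.8, 0.9, 1.0, -1.0])

accuracy = float(np.mean(chosen_scores > rejected_scores))
print(accuracy)  # 0.75 -> the chosen answer wins in 3 of the 4 pairs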

View File

@ -18,7 +18,7 @@
import json import json
import os import os
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Optional, Union
import torch import torch
from transformers import Trainer from transformers import Trainer
@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)
class PairwiseTrainer(Trainer): class PairwiseTrainer(Trainer):
r""" r"""Inherits Trainer to compute pairwise loss."""
Inherits Trainer to compute pairwise loss.
"""
def __init__( def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):
@override @override
def compute_loss( def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r""" r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
Subclass and override to inject custom behavior. Subclass and override to inject custom behavior.
@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
return loss return loss
def save_predictions(self, predict_results: "PredictionOutput") -> None: def save_predictions(self, predict_results: "PredictionOutput") -> None:
r""" r"""Save model predictions to `output_dir`.
Saves model predictions to `output_dir`.
A custom behavior that is not contained in Seq2SeqTrainer. A custom behavior that is not contained in Seq2SeqTrainer.
""" """
@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = predict_results.predictions chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer: with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = [] res: list[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores): for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)})) res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
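
A hedged sketch of the pairwise ranking loss that PairwiseTrainer's docstring describes, assuming the standard Bradley-Terry form (negative log-sigmoid of the score margin), which is not spelled out in this hunk; the batch layout mirrors "first n chosen, last n rejected":

import torch
import torch.nn.functional as F

# Final-token scores for a batch of 2n sequences (made-up values):
# the first n rows are chosen responses, the last n are rejected ones.
values = torch.tensor([1.2, 0.3, 0.8, -0.4])
n = values.size(0) // 2
chosen_scores, rejected_scores = values[:n], values[n:]

loss = -F.logsigmoid(chosen_scores - rejected_scores).mean()
print(loss)  # small when the chosen scores exceed the rejected ones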

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss from ...extras.ploting import plot_loss
@ -37,7 +37,7 @@ def run_rm(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]

View File

@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
import torch import torch
@ -45,9 +45,7 @@ if is_rouge_available():
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor": def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
r""" r"""Compute the token with the largest likelihood to reduce memory footprint."""
Computes the token with the largest likelihood to reduce memory footprint.
"""
if isinstance(logits, (list, tuple)): if isinstance(logits, (list, tuple)):
if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size) if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
logits = logits[0] logits = logits[0]
@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
@dataclass @dataclass
class ComputeAccuracy: class ComputeAccuracy:
r""" r"""Compute accuracy and support `batch_eval_metrics`."""
Computes accuracy and supports `batch_eval_metrics`.
"""
def _dump(self) -> Optional[Dict[str, float]]: def _dump(self) -> Optional[dict[str, float]]:
result = None result = None
if hasattr(self, "score_dict"): if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@ -77,7 +73,7 @@ class ComputeAccuracy:
def __post_init__(self): def __post_init__(self):
self._dump() self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids) preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
for i in range(len(preds)): for i in range(len(preds)):
pred, label = preds[i, :-1], labels[i, 1:] pred, label = preds[i, :-1], labels[i, 1:]
@ -90,15 +86,14 @@ class ComputeAccuracy:
@dataclass @dataclass
class ComputeSimilarity: class ComputeSimilarity:
r""" r"""Compute text similarity scores and support `batch_eval_metrics`.
Computes text similarity scores and supports `batch_eval_metrics`.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer. Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
""" """
tokenizer: "PreTrainedTokenizer" tokenizer: "PreTrainedTokenizer"
def _dump(self) -> Optional[Dict[str, float]]: def _dump(self) -> Optional[dict[str, float]]:
result = None result = None
if hasattr(self, "score_dict"): if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@ -109,7 +104,7 @@ class ComputeSimilarity:
def __post_init__(self): def __post_init__(self):
self._dump() self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids) preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
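
A small sketch of the shifted, masked token accuracy that ComputeAccuracy computes for SFT; IGNORE_INDEX is assumed to be -100 here, the usual label-padding value, and the token ids are invented:

import numpy as np

IGNORE_INDEX = -100  # assumed value for masked-out label positions

preds = np.array([[7, 9, 5, 0]])          # argmax tokens from eval_logit_processor
labels = np.array([[-100, -100, 9, 3]])   # -100 marks prompt/padding positions

pred, label = preds[0, :-1], labels[0, 1:]  # logits at step t predict the token at t + 1
mask = label != IGNORE_INDEX                # drop ignored positions from scoring
accuracy = float(np.mean(pred[mask] == label[mask]))
print(accuracy)  # 0.5 -> one of the two scored tokens matches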

View File

@ -18,7 +18,7 @@
import json import json
import os import os
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np import numpy as np
import torch import torch
@ -44,21 +44,19 @@ logger = logging.get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer): class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r""" r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""
def __init__( def __init__(
self, self,
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
gen_kwargs: Optional[Dict[str, Any]] = None, gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
if is_transformers_version_greater_than("4.46"): if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer") kwargs["processing_class"] = kwargs.pop("tokenizer")
else: else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer") self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
super().__init__(**kwargs) super().__init__(**kwargs)
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
@ -99,13 +97,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def prediction_step( def prediction_step(
self, self,
model: "torch.nn.Module", model: "torch.nn.Module",
inputs: Dict[str, Union["torch.Tensor", Any]], inputs: dict[str, Union["torch.Tensor", Any]],
prediction_loss_only: bool, prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None, ignore_keys: Optional[list[str]] = None,
**gen_kwargs, **gen_kwargs,
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]: ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r""" r"""Remove the prompt part in the generated tokens.
Removes the prompt part in the generated tokens.
Subclass and override to inject custom behavior. Subclass and override to inject custom behavior.
""" """
@ -126,8 +123,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def save_predictions( def save_predictions(
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
) -> None: ) -> None:
r""" r"""Save model predictions to `output_dir`.
Saves model predictions to `output_dir`.
A custom behavior that is not contained in Seq2SeqTrainer. A custom behavior that is not contained in Seq2SeqTrainer.
""" """

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@ -43,7 +43,7 @@ def run_sft(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional, Union
import torch import torch
from peft import PeftModel from peft import PeftModel
@ -43,7 +44,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
linear_modules, extra_modules = set(), set() linear_modules, extra_modules = set(), set()
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]): if any(module in name for module in ["lora_A", "lora_B"]):
@ -83,7 +84,7 @@ def load_reference_model(
) -> Union["PreTrainedModel", "LoraModel"]: ) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device() current_device = get_current_device()
if add_valuehead: if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained( model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=current_device model_path, torch_dtype=torch.float16, device_map=current_device
) )
if not is_trainable: if not is_trainable:
@ -111,7 +112,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":
def patch_valuehead_model() -> None: def patch_valuehead_model() -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None: def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None:
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")} state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
self.v_head.load_state_dict(state_dict, strict=False) self.v_head.load_state_dict(state_dict, strict=False)
del state_dict del state_dict
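
The post_init patch above keeps only the value-head weights and strips their "v_head." prefix (7 characters) before loading them into the value head; a tiny sketch with invented keys:

import torch

state_dict = {
    "v_head.summary.weight": torch.zeros(1, 8),   # kept, prefix stripped
    "v_head.summary.bias": torch.zeros(1),        # kept, prefix stripped
    "lm_head.weight": torch.zeros(10, 8),         # dropped: not a value-head tensor
}

# len("v_head.") == 7, so k[7:] removes the prefix before load_state_dict.
v_head_only = {k[7:]: v for k, v in state_dict.items() if k.startswith("v_head.")}
print(sorted(v_head_only))  # ['summary.bias', 'summary.weight']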

View File

@ -21,7 +21,7 @@ import json
import os import os
from collections.abc import Mapping from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch import torch
from transformers import Trainer from transformers import Trainer
@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer): class DummyOptimizer(torch.optim.Optimizer):
r""" r"""A dummy optimizer used for the GaLore or APOLLO algorithm."""
A dummy optimizer used for the GaLore or APOLLO algorithm.
"""
def __init__( def __init__(
self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
) -> None: ) -> None:
dummy_tensor = torch.randn(1, 1) dummy_tensor = torch.randn(1, 1)
self.optimizer_dict = optimizer_dict self.optimizer_dict = optimizer_dict
@ -112,8 +110,7 @@ def create_modelcard_and_push(
def create_ref_model( def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]: ) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
r""" r"""Create reference model for PPO/DPO training. Evaluation mode is not supported.
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training. The valuehead parameter is randomly initialized since it is useless for PPO training.
""" """
@ -148,9 +145,7 @@ def create_ref_model(
def create_reward_model( def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments" model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> Optional["AutoModelForCausalLMWithValueHead"]: ) -> Optional["AutoModelForCausalLMWithValueHead"]:
r""" r"""Create reward model for PPO training."""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "api": if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide full url." assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
logger.info_rank0(f"Use reward server {finetuning_args.reward_model}") logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
@ -189,10 +184,8 @@ def create_reward_model(
return reward_model return reward_model
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]:
r""" r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
"""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS) decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name] decay_parameters = [name for name in decay_parameters if "bias" not in name]
return decay_parameters return decay_parameters
@ -208,7 +201,7 @@ def _create_galore_optimizer(
else: else:
galore_targets = finetuning_args.galore_target galore_targets = finetuning_args.galore_target
galore_params: List["torch.nn.Parameter"] = [] galore_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets): if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
for param in module.parameters(): for param in module.parameters():
@ -224,7 +217,7 @@ def _create_galore_optimizer(
id_galore_params = {id(param) for param in galore_params} id_galore_params = {id(param) for param in galore_params}
decay_params, nodecay_params = [], [] # they are non-galore parameters decay_params, nodecay_params = [], [] # they are non-galore parameters
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params trainable_params: list[torch.nn.Parameter] = [] # galore_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model) decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.requires_grad: if param.requires_grad:
@ -251,7 +244,7 @@ def _create_galore_optimizer(
if training_args.gradient_accumulation_steps != 1: if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.") raise ValueError("Per-layer GaLore does not support gradient accumulation.")
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {} optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params: for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)] param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
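The per-layer branch above keeps one optimizer per parameter in optimizer_dict, which is also why it refuses gradient accumulation: each parameter is stepped as soon as its own gradient is ready, so there is nothing left to accumulate across micro-batches. A minimal sketch of that pattern, assuming PyTorch 2.1+ for register_post_accumulate_grad_hook and leaving out the GaLore projections themselves:

import torch

model = torch.nn.Linear(8, 2)
# one small optimizer per parameter, mirroring optimizer_dict above
optimizer_dict: dict[torch.nn.Parameter, torch.optim.Optimizer] = {
    p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()
}

def step_hook(p: torch.Tensor) -> None:
    # fires right after p.grad is fully accumulated for this backward pass
    optimizer_dict[p].step()
    optimizer_dict[p].zero_grad()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(step_hook)

loss = model(torch.randn(4, 8)).sum()
loss.backward()  # every parameter is updated inside its own hook; no global optimizer.step()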
@ -296,7 +289,7 @@ def _create_apollo_optimizer(
else: else:
apollo_targets = finetuning_args.apollo_target apollo_targets = finetuning_args.apollo_target
apollo_params: List["torch.nn.Parameter"] = [] apollo_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets): if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
for param in module.parameters(): for param in module.parameters():
@ -315,7 +308,7 @@ def _create_apollo_optimizer(
id_apollo_params = {id(param) for param in apollo_params} id_apollo_params = {id(param) for param in apollo_params}
decay_params, nodecay_params = [], [] # they are non-apollo parameters decay_params, nodecay_params = [], [] # they are non-apollo parameters
trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params trainable_params: list[torch.nn.Parameter] = [] # apollo_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model) decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.requires_grad: if param.requires_grad:
@ -338,7 +331,7 @@ def _create_apollo_optimizer(
if training_args.gradient_accumulation_steps != 1: if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer APOLLO does not support gradient accumulation.") raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {} optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params: for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)] param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
embedding_lr = finetuning_args.loraplus_lr_embedding embedding_lr = finetuning_args.loraplus_lr_embedding
decay_param_names = _get_decay_parameter_names(model) decay_param_names = _get_decay_parameter_names(model)
param_dict: Dict[str, List["torch.nn.Parameter"]] = { param_dict: dict[str, list[torch.nn.Parameter]] = {
"lora_a": [], "lora_a": [],
"lora_b": [], "lora_b": [],
"lora_b_nodecay": [], "lora_b_nodecay": [],
@ -524,7 +517,7 @@ def create_custom_scheduler(
) -> None: ) -> None:
if optimizer is not None and isinstance(optimizer, DummyOptimizer): if optimizer is not None and isinstance(optimizer, DummyOptimizer):
optimizer_dict = optimizer.optimizer_dict optimizer_dict = optimizer.optimizer_dict
scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {} scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
for param in optimizer_dict.keys(): for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler( scheduler_dict[param] = get_scheduler(
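When the optimizer is the DummyOptimizer produced by the per-layer path, the learning-rate schedule has to be split the same way, one scheduler per per-parameter optimizer, which is what scheduler_dict above holds. A sketch of that pairing; the scheduler type and step counts are illustrative only:

import torch
from transformers import get_scheduler

params = [torch.nn.Parameter(torch.zeros(3)), torch.nn.Parameter(torch.zeros(5))]
optimizer_dict = {p: torch.optim.SGD([p], lr=0.1) for p in params}

scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {
    p: get_scheduler("cosine", opt, num_warmup_steps=10, num_training_steps=1000)
    for p, opt in optimizer_dict.items()
}

for p in params:
    optimizer_dict[p].step()   # in the repo these steps happen inside per-parameter hooks;
    scheduler_dict[p].step()   # they are called directly here just to show the pairing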
@ -544,13 +537,13 @@ def create_custom_scheduler(
def get_batch_logps( def get_batch_logps(
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
) -> Tuple["torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor"]:
r""" r"""Compute the log probabilities of the given labels under the given logits.
Computes the log probabilities of the given labels under the given logits.
Returns: Returns:
logps: A tensor of shape (batch_size,) containing the sum of log probabilities. logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens. valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
""" """
if logits.shape[:-1] != labels.shape: if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.") raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
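get_batch_logps above is the workhorse behind DPO-style losses. A self-contained sketch of the computation its docstring describes: per-token log-probabilities gathered at the label ids, padded positions masked with label_pad_token_id, then summed per sequence. The causal shift of logits against labels is assumed to match the repo's behaviour, since the shown hunk cuts off before that part:

import torch

IGNORE_INDEX = -100  # same sentinel the repo uses for padded label positions

def batch_logps(logits: torch.Tensor, labels: torch.Tensor, label_pad_token_id: int = IGNORE_INDEX) -> tuple[torch.Tensor, torch.Tensor]:
    labels = labels[:, 1:].clone()            # causal shift: token t is predicted from positions < t
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id
    labels = labels.masked_fill(~loss_mask, 0)  # any valid index works for masked positions
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)

logits = torch.randn(2, 5, 32)               # (batch, seqlen, vocab)
labels = torch.randint(0, 32, (2, 5))
labels[0, :2] = IGNORE_INDEX                 # pretend the prompt part is masked
logps, valid_length = batch_logps(logits, labels)
print(logps.shape, valid_length)             # torch.Size([2]) tensor([3, 4])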
@ -564,12 +557,10 @@ def get_batch_logps(
def nested_detach( def nested_detach(
tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]], tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False, clone: bool = False,
): ):
r""" r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
"""
if isinstance(tensors, (list, tuple)): if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t, clone=clone) for t in tensors) return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
elif isinstance(tensors, Mapping): elif isinstance(tensors, Mapping):
@ -585,9 +576,7 @@ def nested_detach(
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback": def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
r""" r"""Get the callback for logging to SwanLab."""
Gets the callback for logging to SwanLab.
"""
import swanlab # type: ignore import swanlab # type: ignore
from swanlab.integration.transformers import SwanLabCallback # type: ignore from swanlab.integration.transformers import SwanLabCallback # type: ignore
@ -624,7 +613,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
def get_ray_trainer( def get_ray_trainer(
training_function: Callable, training_function: Callable,
train_loop_config: Dict[str, Any], train_loop_config: dict[str, Any],
ray_args: "RayArguments", ray_args: "RayArguments",
) -> "TorchTrainer": ) -> "TorchTrainer":
if not ray_args.use_ray: if not ray_args.use_ray:
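Taken together, the hunks in this file show the whole shape of the commit: PEP 585 built-in generics (list, dict, tuple) replace typing.List/Dict/Tuple, collections.abc supplies Generator, and multi-line docstring summaries collapse onto the opening line. A minimal before/after illustration, not taken from any one file:

# before (needs typing on Python 3.8)
from typing import Generator, List, Tuple

def index_items(items: List[str]) -> Generator[Tuple[int, str], None, None]:
    yield from enumerate(items)

# after (Python 3.9+), the style this commit converges on
from collections.abc import Generator

def index_items(items: list[str]) -> Generator[tuple[int, str], None, None]:
    yield from enumerate(items)

Quoted names are also dropped from several local variable annotations (for example list[torch.nn.Parameter]), which is safe because annotations on local variables are never evaluated at runtime.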

View File

@ -14,7 +14,7 @@
import os import os
import shutil import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -48,9 +48,9 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def _training_function(config: Dict[str, Any]) -> None: def _training_function(config: dict[str, Any]) -> None:
args = config.get("args") args = config.get("args")
callbacks: List[Any] = config.get("callbacks") callbacks: list[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks.append(LogCallback()) callbacks.append(LogCallback())
@ -84,7 +84,7 @@ def _training_function(config: Dict[str, Any]) -> None:
logger.warning(f"Failed to destroy process group: {e}.") logger.warning(f"Failed to destroy process group: {e}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None: def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
args = read_args(args) args = read_args(args)
if "-h" in args or "--help" in args: if "-h" in args or "--help" in args:
get_train_args(args) get_train_args(args)
@ -103,7 +103,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
_training_function(config={"args": args, "callbacks": callbacks}) _training_function(config={"args": args, "callbacks": callbacks})
def export_model(args: Optional[Dict[str, Any]] = None) -> None: def export_model(args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, _ = get_infer_args(args) model_args, data_args, finetuning_args, _ = get_infer_args(args)
if model_args.export_dir is None: if model_args.export_dir is None:
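For context on the two entry points touched above, run_exp and export_model both accept a plain dict[str, Any] of arguments rather than a parsed namespace. A hedged usage sketch: the import path is assumed from the repo layout, the key names mirror the YAML recipes shipped with LLaMA-Factory, and the model path is a placeholder, none of it is taken from this diff:

from llamafactory.train.tuner import run_exp  # import path assumed from the repo layout

run_exp({
    "stage": "sft",
    "do_train": True,
    "model_name_or_path": "path/to/base-model",   # placeholder
    "dataset": "alpaca_en_demo",
    "template": "llama3",
    "finetuning_type": "lora",
    "output_dir": "saves/demo/lora/sft",
})  # launches a real training run; export_model takes the same kind of dict for merging and export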

View File

@ -14,7 +14,8 @@
import json import json
import os import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Optional
from transformers.utils import is_torch_npu_available from transformers.utils import is_torch_npu_available
@ -37,15 +38,12 @@ if is_gradio_available():
def _escape_html(text: str) -> str: def _escape_html(text: str) -> str:
r""" r"""Escape HTML characters."""
Escapes HTML characters.
"""
return text.replace("<", "&lt;").replace(">", "&gt;") return text.replace("<", "&lt;").replace(">", "&gt;")
def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str: def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
r""" r"""Post-process the response text.
Post-processes the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
""" """
@ -74,7 +72,7 @@ class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None: def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager self.manager = manager
self.demo_mode = demo_mode self.demo_mode = demo_mode
self.engine: Optional["BaseEngine"] = None self.engine: Optional[BaseEngine] = None
if not lazy_init: # read arguments from command line if not lazy_init: # read arguments from command line
super().__init__() super().__init__()
@ -160,14 +158,13 @@ class WebChatModel(ChatModel):
@staticmethod @staticmethod
def append( def append(
chatbot: List[Dict[str, str]], chatbot: list[dict[str, str]],
messages: List[Dict[str, str]], messages: list[dict[str, str]],
role: str, role: str,
query: str, query: str,
escape_html: bool, escape_html: bool,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]: ) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
r""" r"""Add the user input to chatbot.
Adds the user input to chatbot.
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Output: infer.chatbot, infer.messages, infer.query Output: infer.chatbot, infer.messages, infer.query
@ -180,8 +177,8 @@ class WebChatModel(ChatModel):
def stream( def stream(
self, self,
chatbot: List[Dict[str, str]], chatbot: list[dict[str, str]],
messages: List[Dict[str, str]], messages: list[dict[str, str]],
lang: str, lang: str,
system: str, system: str,
tools: str, tools: str,
@ -193,9 +190,8 @@ class WebChatModel(ChatModel):
temperature: float, temperature: float,
skip_special_tokens: bool, skip_special_tokens: bool,
escape_html: bool, escape_html: bool,
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]: ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
r""" r"""Generate output text in stream.
Generates output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ... Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages Output: infer.chatbot, infer.messages

View File

@ -17,7 +17,7 @@ import os
import signal import signal
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Optional, Union from typing import Any, Optional, Union
from psutil import Process from psutil import Process
from yaml import safe_dump, safe_load from yaml import safe_dump, safe_load
@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"
def abort_process(pid: int) -> None: def abort_process(pid: int) -> None:
r""" r"""Abort the processes recursively in a bottom-up way."""
Aborts the processes recursively in a bottom-up way.
"""
try: try:
children = Process(pid).children() children = Process(pid).children()
if children: if children:
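The hunk stops before the recursion, so for reference this is the bottom-up abort the docstring describes, sketched standalone; the exact signal sent to each process is an assumption:

import os
import signal

from psutil import Process

def abort_process(pid: int) -> None:
    # kill the process tree leaves-first, then the given pid itself
    try:
        for child in Process(pid).children():
            abort_process(child.pid)
        os.kill(pid, signal.SIGABRT)  # the signal choice is an assumption; SIGTERM would also work
    except Exception:
        pass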
@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:
def get_save_dir(*paths: str) -> os.PathLike: def get_save_dir(*paths: str) -> os.PathLike:
r""" r"""Get the path to saved model checkpoints."""
Gets the path to saved model checkpoints.
"""
if os.path.sep in paths[-1]: if os.path.sep in paths[-1]:
logger.warning_rank0("Found complex path, some features may be not available.") logger.warning_rank0("Found complex path, some features may be not available.")
return paths[-1] return paths[-1]
@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:
def _get_config_path() -> os.PathLike: def _get_config_path() -> os.PathLike:
r""" r"""Get the path to user config."""
Gets the path to user config.
"""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Union[str, Dict[str, Any]]]: def load_config() -> dict[str, Union[str, dict[str, Any]]]:
r""" r"""Load user config if exists."""
Loads user config if exists.
"""
try: try:
with open(_get_config_path(), encoding="utf-8") as f: with open(_get_config_path(), encoding="utf-8") as f:
return safe_load(f) return safe_load(f)
@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None: def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
r""" r"""Save user config."""
Saves user config.
"""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config() user_config = load_config()
user_config["lang"] = lang or user_config["lang"] user_config["lang"] = lang or user_config["lang"]
@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
def get_model_path(model_name: str) -> str: def get_model_path(model_name: str) -> str:
r""" r"""Get the model path according to the model name."""
Gets the model path according to the model name.
"""
user_config = load_config() user_config = load_config()
path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str)) path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "") model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
if ( if (
use_modelscope() use_modelscope()
@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str:
def get_template(model_name: str) -> str: def get_template(model_name: str) -> str:
r""" r"""Get the template name if the model is a chat/distill/instruct model."""
Gets the template name if the model is a chat/distill/instruct model.
"""
return DEFAULT_TEMPLATE.get(model_name, "default") return DEFAULT_TEMPLATE.get(model_name, "default")
def get_time() -> str: def get_time() -> str:
r""" r"""Get current date and time."""
Gets current date and time.
"""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S") return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def is_multimodal(model_name: str) -> bool: def is_multimodal(model_name: str) -> bool:
r""" r"""Judge if the model is a vision language model."""
Judges if the model is a vision language model.
"""
return model_name in MULTIMODAL_SUPPORTED_MODELS return model_name in MULTIMODAL_SUPPORTED_MODELS
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
r""" r"""Load dataset_info.json."""
Loads dataset_info.json.
"""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"): if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.") logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
return {} return {}
@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {} return {}
def load_args(config_path: str) -> Optional[Dict[str, Any]]: def load_args(config_path: str) -> Optional[dict[str, Any]]:
r""" r"""Load the training configuration from config path."""
Loads the training configuration from config path.
"""
try: try:
with open(config_path, encoding="utf-8") as f: with open(config_path, encoding="utf-8") as f:
return safe_load(f) return safe_load(f)
@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
return None return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> None: def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
r""" r"""Save the training configuration to config path."""
Saves the training configuration to config path.
"""
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f) safe_dump(config_dict, f)
def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]: def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
r""" r"""Remove args with NoneType or False or empty string value."""
Removes args with NoneType or False or empty string value.
"""
no_skip_keys = ["packing"] no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")} return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str: def gen_cmd(args: dict[str, Any]) -> str:
r""" r"""Generate CLI commands for previewing."""
Generates CLI commands for previewing.
"""
cmd_lines = ["llamafactory-cli train "] cmd_lines = ["llamafactory-cli train "]
for k, v in _clean_cmd(args).items(): for k, v in _clean_cmd(args).items():
if isinstance(v, dict): if isinstance(v, dict):
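_clean_cmd and gen_cmd together turn the WebUI's argument dict into a copy-pasteable llamafactory-cli command. A simplified standalone sketch; the repo additionally special-cases dict-valued arguments, as the hunk above shows:

from typing import Any

def clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
    no_skip_keys = ["packing"]  # kept even when falsy, mirroring the helper above
    return {k: v for k, v in args.items() if k in no_skip_keys or (v is not None and v is not False and v != "")}

def gen_cmd(args: dict[str, Any]) -> str:
    cmd_lines = ["llamafactory-cli train "]
    cmd_lines += [f"    --{k} {v} " for k, v in clean_cmd(args).items()]
    return "\\\n".join(cmd_lines)

print(gen_cmd({"stage": "sft", "do_train": True, "learning_rate": None}))
# llamafactory-cli train \
#     --stage sft \
#     --do_train True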
@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
return cmd_text return cmd_text
def save_cmd(args: Dict[str, Any]) -> str: def save_cmd(args: dict[str, Any]) -> str:
r""" r"""Save CLI commands to launch training."""
Saves CLI commands to launch training.
"""
output_dir = args["output_dir"] output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:
def load_eval_results(path: os.PathLike) -> str: def load_eval_results(path: os.PathLike) -> str:
r""" r"""Get scores after evaluation."""
Gets scores after evaluation.
"""
with open(path, encoding="utf-8") as f: with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4) result = json.dumps(json.load(f), indent=4)
@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:
def create_ds_config() -> None: def create_ds_config() -> None:
r""" r"""Create deepspeed config in the current directory."""
Creates deepspeed config in the current directory.
"""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = { ds_config = {
"train_batch_size": "auto", "train_batch_size": "auto",

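create_ds_config writes DeepSpeed presets in which most fields are the literal string "auto", so the HF Trainer can fill them in from its own arguments at launch time. A minimal illustration of the shape such a file takes; the exact fields and ZeRO stages the function emits may differ:

import json

ds_config = {
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "zero_optimization": {"stage": 2, "overlap_comm": True},
    "bf16": {"enabled": "auto"},
}

with open("ds_z2_config.json", "w", encoding="utf-8") as f:
    json.dump(ds_config, f, indent=2)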
View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import json import json
from typing import TYPE_CHECKING, Dict, Tuple from typing import TYPE_CHECKING
from ...data import Role from ...data import Role
from ...extras.packages import is_gradio_available from ...extras.packages import is_gradio_available
@ -31,9 +31,7 @@ if TYPE_CHECKING:
def check_json_schema(text: str, lang: str) -> None: def check_json_schema(text: str, lang: str) -> None:
r""" r"""Check if the json schema is valid."""
Checks if the json schema is valid.
"""
try: try:
tools = json.loads(text) tools = json.loads(text)
if tools: if tools:
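check_json_schema above guards the tools textbox in the chat tab. A simplified sketch of the validation it implies, with prints standing in for the Gradio alerts; requiring a name field on every tool is an assumption about the repo's check:

import json

def check_json_schema(text: str) -> None:
    # validate a tools definition pasted into the WebUI (simplified sketch)
    try:
        tools = json.loads(text)
        if tools:
            assert isinstance(tools, list)
            for tool in tools:
                if "name" not in tool:
                    raise NotImplementedError("Name not found.")
    except NotImplementedError:
        print("invalid: tool missing a name")
    except Exception:
        print("invalid: not a valid JSON schema")

check_json_schema('[{"name": "get_weather", "parameters": {}}]')  # passes silently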
@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None:
def create_chat_box( def create_chat_box(
engine: "Engine", visible: bool = False engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]: ) -> tuple["Component", "Component", dict[str, "Component"]]:
lang = engine.manager.get_elem_by_id("top.lang") lang = engine.manager.get_elem_by_id("top.lang")
with gr.Column(visible=visible) as chat_box: with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(type="messages", show_copy_button=True) chatbot = gr.Chatbot(type="messages", show_copy_button=True)

View File

@ -14,7 +14,7 @@
import json import json
import os import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple from typing import TYPE_CHECKING, Any
from ...extras.constants import DATA_CONFIG from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available from ...extras.packages import is_gradio_available
@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button": def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
r""" r"""Check if the dataset is a local dataset."""
Checks if the dataset is a local dataset.
"""
try: try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
@ -59,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
return gr.Button(interactive=False) return gr.Button(interactive=False)
def _load_data_file(file_path: str) -> List[Any]: def _load_data_file(file_path: str) -> list[Any]:
with open(file_path, encoding="utf-8") as f: with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"): if file_path.endswith(".json"):
return json.load(f) return json.load(f)
@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]:
return list(f) return list(f)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]: def get_preview(dataset_dir: str, dataset: list, page_index: int) -> tuple[int, list, "gr.Column"]:
r""" r"""Get the preview samples from the dataset."""
Gets the preview samples from the dataset.
"""
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True) return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]: def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> dict[str, "Component"]:
data_preview_btn = gr.Button(interactive=False, scale=1) data_preview_btn = gr.Button(interactive=False, scale=1)
with gr.Column(visible=False, elem_classes="modal-box") as preview_box: with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row(): with gr.Row():

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING
from ...extras.packages import is_gradio_available from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR from ..common import DEFAULT_DATA_DIR
@ -30,7 +30,7 @@ if TYPE_CHECKING:
from ..engine import Engine from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: def create_eval_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems() input_elems = engine.manager.get_base_elems()
elem_dict = dict() elem_dict = dict()

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict, Generator, List, Union from collections.abc import Generator
from typing import TYPE_CHECKING, Union
from ...extras.constants import PEFT_METHODS from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc from ...extras.misc import torch_gc
@ -35,7 +36,7 @@ if TYPE_CHECKING:
GPTQ_BITS = ["8", "4", "3", "2"] GPTQ_BITS = ["8", "4", "3", "2"]
def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown": def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0: if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False) return gr.Dropdown(value="none", interactive=False)
else: else:
@ -47,7 +48,7 @@ def save_model(
model_name: str, model_name: str,
model_path: str, model_path: str,
finetuning_type: str, finetuning_type: str,
checkpoint_path: Union[str, List[str]], checkpoint_path: Union[str, list[str]],
template: str, template: str,
export_size: int, export_size: int,
export_quantization_bit: str, export_quantization_bit: str,
@ -106,7 +107,7 @@ def save_model(
yield ALERTS["info_exported"][lang] yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
with gr.Row(): with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1) export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none") export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING
from ...extras.packages import is_gradio_available from ...extras.packages import is_gradio_available
from ..common import is_multimodal from ..common import is_multimodal
@ -29,7 +29,7 @@ if TYPE_CHECKING:
from ..engine import Engine from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems() input_elems = engine.manager.get_base_elems()
elem_dict = dict() elem_dict = dict()

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING
from ...data import TEMPLATES from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS from ...extras.constants import METHODS, SUPPORTED_MODELS
@ -29,7 +29,7 @@ if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component
def create_top() -> Dict[str, "Component"]: def create_top() -> dict[str, "Component"]:
with gr.Row(): with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1) lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

Some files were not shown because too many files have changed in this diff.