[misc] upgrade format to py39 (#7256)

hoshi-hiyouga 2025-03-12 00:08:41 +08:00 committed by GitHub
parent cdafa8a15e
commit 7c1640ed5f
113 changed files with 984 additions and 1407 deletions

View File

@ -10,7 +10,7 @@ _DESCRIPTION = "BELLE multiturn chat dataset."
_CITATION = """\
@article{belle2023exploring,
title={Exploring the Impact of Instruction Data Scaling on Large Language Models: An Empirical Study on Real-World Use Cases},
title={Exploring the Impact of Instruction Data Scaling on Large Language Models},
author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li},
journal={arXiv preprint arXiv:2303.14742},
year={2023}

View File

@ -1,6 +1,5 @@
import json
import os
from typing import List
import datasets
@ -50,7 +49,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
]
def _generate_examples(self, filepaths: List[str]):
def _generate_examples(self, filepaths: list[str]):
key = 0
for filepath in filepaths:
with open(filepath, encoding="utf-8") as f:

View File

@ -1,6 +1,5 @@
import json
import os
from typing import List
import datasets
@ -11,7 +10,7 @@ _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dia
_CITATION = """\
@misc{UltraChat,
author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and others},
title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
year = {2023},
publisher = {GitHub},
@ -40,7 +39,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]
def _generate_examples(self, filepaths: List[str]):
def _generate_examples(self, filepaths: list[str]):
for filepath in filepaths:
with open(filepath, encoding="utf-8") as f:
for row in f:
@ -49,7 +48,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
except Exception:
continue
key: int = data["id"]
content: List[str] = data["data"]
content: list[str] = data["data"]
if len(content) % 2 == 1:
content.pop(-1)
if len(content) < 2:

View File

@ -21,14 +21,15 @@ import pandas as pd
_CITATION = """\
@article{huang2023ceval,
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and others},
journal={arXiv preprint arXiv:2305.08322},
year={2023}
}
"""
_DESCRIPTION = """\
C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
C-Eval is a comprehensive Chinese evaluation suite for foundation models.
It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
"""
_HOMEPAGE = "https://cevalbenchmark.com"

View File

@ -21,14 +21,15 @@ import pandas as pd
_CITATION = """\
@article{li2023cmmlu,
title={CMMLU: Measuring massive multitask language understanding in Chinese},
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and others},
journal={arXiv preprint arXiv:2306.09212},
year={2023}
}
"""
_DESCRIPTION = """\
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge
and reasoning abilities of LLMs within the Chinese language and cultural context.
"""
_HOMEPAGE = "https://github.com/haonan-li/CMMLU"

View File

@ -21,14 +21,15 @@ import pandas as pd
_CITATION = """\
@article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding},
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
author={Dan Hendrycks and Collin Burns and others},
journal={Proceedings of the International Conference on Learning Representations (ICLR)},
year={2021}
}
"""
_DESCRIPTION = """\
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart,
Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
"""
_HOMEPAGE = "https://github.com/hendrycks/test"

View File

@ -19,13 +19,35 @@ dynamic = [
]
[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 119
indent-width = 4
[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
ignore = [
"C408", # collection
"C901", # complex
"E731", # lambda function
"E741", # ambiguous var name
"D100", # no doc public module
"D101", # no doc public class
"D102", # no doc public method
"D103", # no doc public function
"D104", # no doc public package
"D105", # no doc magic method
"D107", # no doc __init__
]
extend-select = [
"C", # complexity
"E", # error
"F", # pyflakes
"I", # isort
"W", # warning
"UP", # pyupgrade
"D", # pydocstyle
"PT009", # pytest assert
"RUF022", # sort __all__
]
[tool.ruff.lint.isort]
lines-after-imports = 2
@ -41,6 +63,9 @@ known-third-party = [
"trl"
]
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
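For reference, a minimal before/after sketch (not taken from this commit) of the rewrite that the new `target-version = "py39"` plus the `UP` (pyupgrade) rule pushes across the repository: PEP 585 builtin generics replace `typing.List`/`typing.Dict`, and abstract containers such as `Sequence` move to `collections.abc`. `Optional` and `Union` remain, since the `X | Y` syntax only becomes available on Python 3.10.

```python
# Hypothetical module, for illustration only.

# Before (py38 style):
#   from typing import Dict, List, Optional, Sequence
#   def count_tokens(batches: Sequence[List[str]], pad: Optional[str] = None) -> Dict[str, int]: ...

# After (py39 style, in the direction the UP rules suggest):
from collections.abc import Sequence
from typing import Optional


def count_tokens(batches: Sequence[list[str]], pad: Optional[str] = None) -> dict[str, int]:
    """Count token occurrences across batches, optionally skipping a padding token."""
    counts: dict[str, int] = {}
    for batch in batches:
        for token in batch:
            if token != pad:
                counts[token] = counts.get(token, 0) + 1
    return counts
```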

View File

@ -14,7 +14,7 @@
import json
import os
from typing import Sequence
from collections.abc import Sequence
from openai import OpenAI
from transformers.utils.versions import require_version

View File

@ -15,7 +15,7 @@
import json
import os
from collections import OrderedDict
from typing import Any, Dict
from typing import Any
import fire
import torch
@ -29,13 +29,13 @@ CONFIG_NAME = "config.json"
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
baichuan2_state_dict.update(shard_weight)
llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
if "W_pack" in key:
proj_size = value.size(0) // 3
@ -75,7 +75,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
def save_config(input_dir: str, output_dir: str):
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict: dict[str, Any] = json.load(f)
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
llama2_config_dict.pop("auto_map", None)
@ -94,8 +94,8 @@ def llamafy_baichuan2(
shard_size: str = "2GB",
save_safetensors: bool = True,
):
r"""
Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
r"""Convert the Baichuan2-7B model in the same format as LLaMA2-7B.
Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
"""

View File

@ -15,7 +15,7 @@
import json
import os
from collections import OrderedDict
from typing import Any, Dict
from typing import Any
import fire
import torch
@ -37,14 +37,14 @@ CONFIG_NAME = "config.json"
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
qwen_state_dict: dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
for key in f.keys():
qwen_state_dict[key] = f.get_tensor(key)
llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
torch_dtype = None
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
if torch_dtype is None:
@ -112,9 +112,9 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
qwen_config_dict: Dict[str, Any] = json.load(f)
qwen_config_dict: dict[str, Any] = json.load(f)
llama2_config_dict: Dict[str, Any] = OrderedDict()
llama2_config_dict: dict[str, Any] = OrderedDict()
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
llama2_config_dict["hidden_act"] = "silu"
llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
@ -147,8 +147,8 @@ def llamafy_qwen(
shard_size: str = "2GB",
save_safetensors: bool = False,
):
r"""
Converts the Qwen models in the same format as LLaMA2.
r"""Convert the Qwen models in the same format as LLaMA2.
Usage: python llamafy_qwen.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
"""

View File

@ -18,7 +18,7 @@
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
import fire
import torch
@ -44,11 +44,11 @@ def block_expansion(
shard_size: str = "5GB",
save_safetensors: bool = True,
):
r"""
Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
"""
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
num_layers = getattr(config, "num_hidden_layers")
if num_layers % num_expand != 0:
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
@ -70,7 +70,7 @@ def block_expansion(
split = num_layers // num_expand
layer_cnt = 0
state_dict = model.state_dict()
output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
output_state_dict: dict[str, torch.Tensor] = OrderedDict()
for i in range(num_layers):
for key, value in state_dict.items():
if f".{i:d}." in key:

View File

@ -38,8 +38,8 @@ def quantize_loftq(
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
r"""
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
r"""Initialize LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
@ -72,7 +72,7 @@ def quantize_loftq(
print(f"Adapter weights saved in {loftq_dir}")
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model: PreTrainedModel = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print(f"Model weights saved in {output_dir}")

View File

@ -37,8 +37,8 @@ def quantize_pissa(
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
r"""
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
r"""Initialize LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
@ -67,7 +67,7 @@ def quantize_pissa(
print(f"Adapter weights saved in {pissa_dir}")
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model: PreTrainedModel = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print(f"Model weights saved in {output_dir}")

View File

@ -29,8 +29,8 @@ def calculate_flops(
seq_length: int = 512,
flash_attn: str = "auto",
):
r"""
Calculates the flops of pre-trained models.
r"""Calculate the flops of pre-trained models.
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
"""
with get_accelerator().device(0):

View File

@ -45,8 +45,8 @@ def calculate_lr(
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
r"""
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage:
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
"""
@ -89,9 +89,8 @@ def calculate_lr(
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral_or_gemma else lr
print(
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
lr, valid_ratio * 100, token_batch_size
)
f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
f"and effective token batch size {token_batch_size:.2f}"
)
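A quick worked example of the square-root scaling printed above; `BASE_LR` and `BASE_BS` are defined elsewhere in the script, so the values below are assumptions for illustration only:

```python
import math

BASE_LR = 3e-4        # assumed reference learning rate
BASE_BS = 4_000_000   # assumed reference token batch size

token_batch_size = 16 * 1024  # e.g. batch_size 16 with cutoff_len 1024
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)  # lr ~ sqrt(batch_size)
print(f"Optimal learning rate is {lr:.2e}")  # 1.92e-05 with these assumed constants
```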

View File

@ -34,9 +34,7 @@ def compute_model_flops(
include_recompute: bool = False,
include_flashattn: bool = False,
) -> int:
r"""
Calculates the FLOPs of model per forward/backward pass.
"""
r"""Calculate the FLOPs of model per forward/backward pass."""
config = AutoConfig.from_pretrained(model_name_or_path)
hidden_size = getattr(config, "hidden_size", None)
vocab_size = getattr(config, "vocab_size", None)
@ -86,9 +84,7 @@ def compute_model_flops(
def compute_device_flops(world_size: int) -> float:
r"""
Calculates the FLOPs of the device capability per second.
"""
r"""Calculate the FLOPs of the device capability per second."""
device_name = torch.cuda.get_device_name()
if "H100" in device_name or "H800" in device_name:
return 989 * 1e12 * world_size
@ -114,8 +110,8 @@ def calculate_mfu(
liger_kernel: bool = False,
unsloth_gc: bool = False,
) -> float:
r"""
Calculates MFU for given model and hyper-params.
r"""Calculate MFU for given model and hyper-params.
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
"""
args = {
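MFU here is the ratio of FLOPs the run actually achieves to the device's peak throughput: `compute_model_flops` provides the per-step numerator and `compute_device_flops` the per-second denominator (989 TFLOPs/s for H100/H800, as hard-coded above). A simplified sketch with made-up numbers, since the real function measures step time from a short training run:

```python
def mfu_sketch(model_flops_per_step: float, seconds_per_step: float, device_flops_per_sec: float) -> float:
    """Return model FLOPs utilization: achieved FLOPs/s divided by peak FLOPs/s."""
    return (model_flops_per_step / seconds_per_step) / device_flops_per_sec


# Assumed numbers: 2.5e15 FLOPs per step, 5 s per step, one H800-class device.
print(f"MFU: {mfu_sketch(2.5e15, 5.0, 989e12):.2%}")  # ~50.56%
```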

View File

@ -13,8 +13,9 @@
# limitations under the License.
import json
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Sequence
from typing import Any, Literal, Optional
import fire
import torch
@ -30,16 +31,12 @@ from llamafactory.model import load_model, load_tokenizer
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
r"""Data collator for pairwise data."""
train_on_prompt: bool = False
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
"""
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, torch.Tensor]:
r"""Pad batched data to the longest sequence in the batch."""
chosen_features = []
for feature in features:
chosen_features.append(
@ -68,8 +65,8 @@ def calculate_ppl(
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
r"""
Calculates the ppl on the dataset of the pre-trained models.
r"""Calculate the ppl on the dataset of the pre-trained models.
Usage: export CUDA_VISIBLE_DEVICES=0
python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
"""
@ -111,17 +108,17 @@ def calculate_ppl(
criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = []
batch: Dict[str, "torch.Tensor"]
batch: dict[str, torch.Tensor]
with torch.no_grad():
for batch in tqdm(dataloader, desc="Computing perplexities"):
batch = batch.to(model.device)
outputs = model(**batch)
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
shift_labels: torch.Tensor = batch["labels"][..., 1:]
loss_mask = shift_labels != IGNORE_INDEX
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
flatten_labels = shift_labels.contiguous().view(-1)
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
total_ppl += sentence_logps.exp().sum().item()

View File

@ -29,8 +29,8 @@ def length_cdf(
template: str = "default",
interval: int = 1000,
):
r"""
Calculates the distribution of the input lengths in the dataset.
r"""Calculate the distribution of the input lengths in the dataset.
Usage: export CUDA_VISIBLE_DEVICES=0
python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
"""

View File

@ -52,8 +52,8 @@ def vllm_infer(
image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32,
):
r"""
Performs batch generation using vLLM engine, which supports tensor parallelism.
r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
"""
check_version("vllm>=0.4.3,<=0.7.3")

View File

@ -14,7 +14,6 @@
import os
import re
from typing import List
from setuptools import find_packages, setup
@ -27,14 +26,14 @@ def get_version() -> str:
return version
def get_requires() -> List[str]:
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
def get_console_scripts() -> List[str]:
def get_console_scripts() -> list[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Efficient fine-tuning of large language models.
r"""Efficient fine-tuning of large language models.
Level:
api, webui > chat, eval, train > data, model > hparams > extras

View File

@ -16,9 +16,7 @@ import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Optional
from typing_extensions import Annotated
from typing import Annotated, Optional
from ..chat import ChatModel
from ..extras.constants import EngineName

View File

@ -18,7 +18,8 @@ import json
import os
import re
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Optional
from ..data import Role as DataRole
from ..extras import logging
@ -71,7 +72,7 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]:
if is_env_enabled("API_VERBOSE", "1"):
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")

View File

@ -13,14 +13,14 @@
# limitations under the License.
import json
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pydantic import BaseModel
def dictify(data: "BaseModel") -> Dict[str, Any]:
def dictify(data: "BaseModel") -> dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1

View File

@ -14,7 +14,7 @@
import time
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Literal
@ -45,7 +45,7 @@ class ModelCard(BaseModel):
class ModelList(BaseModel):
object: Literal["list"] = "list"
data: List[ModelCard] = []
data: list[ModelCard] = []
class Function(BaseModel):
@ -56,7 +56,7 @@ class Function(BaseModel):
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
parameters: dict[str, Any]
class FunctionAvailable(BaseModel):
@ -82,26 +82,26 @@ class MultimodalInputItem(BaseModel):
class ChatMessage(BaseModel):
role: Role
content: Optional[Union[str, List[MultimodalInputItem]]] = None
tool_calls: Optional[List[FunctionCall]] = None
content: Optional[Union[str, list[MultimodalInputItem]]] = None
tool_calls: Optional[list[FunctionCall]] = None
class ChatCompletionMessage(BaseModel):
role: Optional[Role] = None
content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
tool_calls: Optional[list[FunctionCall]] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: Optional[List[FunctionAvailable]] = None
messages: list[ChatMessage]
tools: Optional[list[FunctionAvailable]] = None
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop: Optional[Union[str, list[str]]] = None
stream: bool = False
@ -128,7 +128,7 @@ class ChatCompletionResponse(BaseModel):
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
choices: list[ChatCompletionResponseChoice]
usage: ChatCompletionResponseUsage
@ -137,12 +137,12 @@ class ChatCompletionStreamResponse(BaseModel):
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionStreamResponseChoice]
choices: list[ChatCompletionStreamResponseChoice]
class ScoreEvaluationRequest(BaseModel):
model: str
messages: List[str]
messages: list[str]
max_length: Optional[int] = None
@ -150,4 +150,4 @@ class ScoreEvaluationResponse(BaseModel):
id: str
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: List[float]
scores: list[float]
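A small usage sketch of the updated schema, assuming the module path `llamafactory.api.protocol` and that the `Role` enum defined alongside these models exposes a `USER` member:

```python
from llamafactory.api.protocol import ChatCompletionRequest, ChatMessage, Role

request = ChatCompletionRequest(
    model="llama3",
    messages=[ChatMessage(role=Role.USER, content="Hello!")],
    temperature=0.7,
    stop=["</s>"],  # Optional[Union[str, list[str]]] now uses the builtin list generic
)
print(request.model_dump(exclude_unset=True))  # pydantic v2; pydantic v1 would use .dict()
```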

View File

@ -13,8 +13,9 @@
# limitations under the License.
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
if TYPE_CHECKING:
@ -36,8 +37,7 @@ class Response:
class BaseEngine(ABC):
r"""
Base class for inference engine of chat models.
r"""Base class for inference engine of chat models.
Must implement async methods: chat(), stream_chat() and get_scores().
"""
@ -47,7 +47,7 @@ class BaseEngine(ABC):
tokenizer: "PreTrainedTokenizer"
can_generate: bool
template: "Template"
generating_args: Dict[str, Any]
generating_args: dict[str, Any]
@abstractmethod
def __init__(
@ -57,31 +57,27 @@ class BaseEngine(ABC):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
r"""
Initializes an inference engine.
"""
r"""Initialize an inference engine."""
...
@abstractmethod
async def chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Get a list of responses of the chat model."""
...
@abstractmethod
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
@ -89,18 +85,14 @@ class BaseEngine(ABC):
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Gets the response token-by-token of the chat model.
"""
r"""Get the response token-by-token of the chat model."""
...
@abstractmethod
async def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Get a list of scores of the reward model."""
...

View File

@ -17,8 +17,9 @@
import asyncio
import os
from collections.abc import AsyncGenerator, Generator, Sequence
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from typing import TYPE_CHECKING, Any, Optional
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
@ -38,20 +39,19 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel:
r"""
General class for chat models. Backed by huggingface or vllm engines.
r"""General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores().
"""
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
if model_args.infer_backend == EngineName.HF:
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.VLLM:
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
@ -61,17 +61,15 @@ class ChatModel:
def chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Get a list of responses of the chat model."""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
)
@ -79,22 +77,20 @@ class ChatModel:
async def achat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Asynchronously get a list of responses of the chat model."""
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
@ -102,9 +98,7 @@ class ChatModel:
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
r"""Get the response token-by-token of the chat model."""
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
while True:
try:
@ -115,7 +109,7 @@ class ChatModel:
async def astream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
@ -123,9 +117,7 @@ class ChatModel:
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
r"""Asynchronously get the response token-by-token of the chat model."""
async for new_token in self.engine.stream_chat(
messages, system, tools, images, videos, audios, **input_kwargs
):
@ -133,23 +125,19 @@ class ChatModel:
def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Get a list of scores of the reward model."""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
async def aget_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Asynchronously gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Asynchronously get a list of scores of the reward model."""
return await self.engine.get_scores(batch_input, **input_kwargs)
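All of the sync wrappers above follow one pattern: the coroutine is submitted to a background event loop running in a separate thread (started via `_start_background_loop`), and the caller blocks on the returned future. A stripped-down sketch of that pattern with illustrative names:

```python
import asyncio
from threading import Thread


def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


class SyncOverAsync:
    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        Thread(target=_start_background_loop, args=(self._loop,), daemon=True).start()

    async def agreet(self, name: str) -> str:
        await asyncio.sleep(0.1)  # stand-in for an async inference call
        return f"Hello, {name}!"

    def greet(self, name: str) -> str:
        # Submit the coroutine to the background loop, then block until it completes.
        task = asyncio.run_coroutine_threadsafe(self.agreet(name), self._loop)
        return task.result()


print(SyncOverAsync().greet("world"))
```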

View File

@ -15,8 +15,9 @@
import asyncio
import concurrent.futures
import os
from collections.abc import AsyncGenerator, Sequence
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer
@ -76,15 +77,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
generating_args: dict[str, Any],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
input_kwargs: Optional[dict[str, Any]] = {},
) -> tuple[dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
@ -130,7 +131,7 @@ class HuggingfaceEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if stop is not None:
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
@ -217,15 +218,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
generating_args: dict[str, Any],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
input_kwargs: Optional[dict[str, Any]] = {},
) -> list["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model,
tokenizer,
@ -272,14 +273,14 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
generating_args: dict[str, Any],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
input_kwargs: Optional[dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model,
@ -317,12 +318,12 @@ class HuggingfaceEngine(BaseEngine):
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: List[str],
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List[float]:
batch_input: list[str],
input_kwargs: Optional[dict[str, Any]] = {},
) -> list[float]:
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs: Dict[str, "torch.Tensor"] = tokenizer(
inputs: dict[str, torch.Tensor] = tokenizer(
batch_input,
padding=True,
truncation=True,
@ -330,21 +331,21 @@ class HuggingfaceEngine(BaseEngine):
return_tensors="pt",
add_special_tokens=False,
).to(device)
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
) -> list["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
@ -370,7 +371,7 @@ class HuggingfaceEngine(BaseEngine):
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
@ -408,9 +409,9 @@ class HuggingfaceEngine(BaseEngine):
@override
async def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
) -> list[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")

View File

@ -13,7 +13,8 @@
# limitations under the License.
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import override
@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
self.model_args = model_args
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
@ -101,7 +102,7 @@ class VllmEngine(BaseEngine):
async def _generate(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
@ -143,7 +144,7 @@ class VllmEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if length_penalty is not None:
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
@ -201,14 +202,14 @@ class VllmEngine(BaseEngine):
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
) -> list["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
async for request_output in generator:
@ -230,7 +231,7 @@ class VllmEngine(BaseEngine):
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
@ -248,7 +249,7 @@ class VllmEngine(BaseEngine):
@override
async def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
) -> list[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")

View File

@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"TEMPLATES",
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"SFTDataCollatorWith4DAttentionMask",
"Template",
"get_dataset",
"get_template_and_fix_tokenizer",
"split_dataset",
]

View File

@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
from typing import TYPE_CHECKING, Any, Literal, Optional
import numpy as np
import torch
@ -38,9 +39,10 @@ if TYPE_CHECKING:
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
r"""Expand 2d attention mask to 4d attention mask.
Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
handle packed sequences and transform the mask to lower triangular form to prevent future peeking.
e.g.
```python
@ -78,8 +80,7 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
r"""Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
"""
@ -91,7 +92,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
for feature in features:
@ -166,7 +167,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features: dict[str, torch.Tensor] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
rope_index_kwargs = {
@ -198,15 +199,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
r"""Data collator for 4d attention mask."""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@ -220,13 +219,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
r"""Data collator for pairwise data."""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
r"""
Pads batched data to the longest sequence in the batch.
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
r"""Pad batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
@ -249,11 +245,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
r"""Data collator for KTO data."""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
target_features = []
kl_features = []
kto_tags = []
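The `prepare_4d_attention_mask` docstring earlier in this file describes expanding a `(batch_size, seq_len)` tensor of packing indices into a `(batch_size, 1, seq_len, seq_len)` lower-triangular block mask. A minimal independent sketch of that idea (the actual function body is not shown in this diff):

```python
import torch


def expand_packed_mask(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Expand (bsz, seq_len) packing indices into a (bsz, 1, seq_len, seq_len) causal block mask."""
    bsz, seq_len = mask_with_indices.size()
    min_value = torch.finfo(dtype).min
    # A position may attend to another only if both belong to the same packed sequence
    # (same non-zero index) and the attended position is not in the future.
    same_segment = mask_with_indices.unsqueeze(1) == mask_with_indices.unsqueeze(2)
    non_padding = (mask_with_indices != 0).unsqueeze(1)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = same_segment & non_padding & causal
    attn_mask = torch.full((bsz, 1, seq_len, seq_len), min_value, dtype=dtype)
    return attn_mask.masked_fill(allowed.unsqueeze(1), 0)


# Two sequences packed into one row, followed by a padding position (index 0).
print(expand_packed_mask(torch.tensor([[1, 1, 2, 2, 0]]), torch.float32)[0, 0])
```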

View File

@ -14,8 +14,9 @@
import os
from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from ..extras import logging
from .data_utils import Role
@ -36,10 +37,8 @@ class DatasetConverter:
dataset_attr: "DatasetAttr"
data_args: "DataArguments"
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
r"""
Optionally concatenates media path to media dir when loading from local disk.
"""
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]:
r"""Optionally concatenate media path to media dir when loading from local disk."""
if not isinstance(medias, list):
medias = [medias] if medias is not None else []
elif len(medias) == 0:
@ -57,16 +56,14 @@ class DatasetConverter:
return medias
@abstractmethod
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
r"""
Converts a single example in the dataset to the standard format.
"""
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
r"""Convert a single example in the dataset to the standard format."""
...
@dataclass
class AlpacaDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
prompt = []
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
for old_prompt, old_response in example[self.dataset_attr.history]:
@ -116,7 +113,7 @@ class AlpacaDatasetConverter(DatasetConverter):
@dataclass
class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
@ -216,10 +213,8 @@ DATASET_CONVERTERS = {
}
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
r"""
Register a new dataset converter.
"""
def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
r"""Register a new dataset converter."""
if name in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} already exists.")
@ -227,9 +222,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
r"""
Gets a dataset converter.
"""
r"""Get a dataset converter."""
if name not in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} not found.")
@ -242,17 +235,17 @@ def align_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
_audios: [],
"""
r"""Align the dataset to a specific format.
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "..."
_images: []
_videos: []
_audios: []
"""
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
from typing import TYPE_CHECKING, Optional, TypedDict, Union
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
@ -29,7 +30,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
SLOTS = Sequence[Union[str, set[str], dict[str, str]]]
@unique
@ -43,15 +44,13 @@ class Role(str, Enum):
class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets to a unified dataset.
"""
r"""Merge multiple datasets to a unified dataset."""
if len(all_datasets) == 1:
return all_datasets[0]
@ -78,14 +77,13 @@ def merge_dataset(
def split_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments",
seed: int,
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set and validation set.
r"""Split the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
Support both map dataset and iterable dataset.
"""
if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
@ -120,10 +118,8 @@ def split_dataset(
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
r"""
Converts dataset or dataset dict to dataset module.
"""
dataset_module: "DatasetModule" = {}
r"""Convert dataset or dataset dict to dataset module."""
dataset_module: DatasetModule = {}
if isinstance(dataset, DatasetDict): # dataset dict
if "train" in dataset:
dataset_module["train_dataset"] = dataset["train"]

View File

@ -16,7 +16,7 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Union
from typing import Optional, Union
from typing_extensions import override
@ -31,14 +31,11 @@ class Formatter(ABC):
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
r"""
Forms a list of slots according to the inputs to encode.
"""
r"""Forms a list of slots according to the inputs to encode."""
...
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extract a list of tuples from the response message if using tools.
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments.
"""
@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
if thought:
content = content.replace(thought.group(0), "")
functions: List["FunctionCall"] = []
functions: list[FunctionCall] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
return self.tool_utils.tool_extractor(content)

View File

@ -13,7 +13,8 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import load_dataset, load_from_disk
@ -54,9 +55,7 @@ def _load_single_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Loads a single dataset and aligns it to the standard format.
"""
r"""Load a single dataset and aligns it to the standard format."""
logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
@ -164,10 +163,8 @@ def _get_merged_dataset(
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
merge: bool = True,
) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
r"""
Returns the merged datasets in the standard format.
"""
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
r"""Return the merged datasets in the standard format."""
if dataset_names is None:
return None
@ -192,9 +189,7 @@ def _get_dataset_processor(
processor: Optional["ProcessorMixin"],
do_generate: bool = False,
) -> "DatasetProcessor":
r"""
Returns the corresponding dataset processor.
"""
r"""Return the corresponding dataset processor."""
if stage == "pt":
dataset_processor_class = PretrainDatasetProcessor
elif stage == "sft" and not do_generate:
@ -236,9 +231,7 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
r"""Preprocesses the dataset, including format checking and tokenization."""
if dataset is None:
return None
@ -284,9 +277,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
r"""
Gets the train dataset and optionally gets the evaluation dataset.
"""
r"""Get the train dataset and optionally gets the evaluation dataset."""
# Load tokenized dataset if path exists
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):

View File

@ -1,10 +1,11 @@
import inspect
import math
import re
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
from typing import TYPE_CHECKING, Optional, TypedDict, Union
import numpy as np
import torch
@ -58,12 +59,12 @@ if TYPE_CHECKING:
def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
) -> list[list[int]]:
r"""Get paligemma token type ids for computing loss.
Returns:
batch_token_type_ids: shape (batch_size, sequence_length)
"""
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
@ -87,11 +88,9 @@ class MMPluginMixin:
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> None:
r"""
Validates if this model accepts the input modalities.
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
r"""Validate if this model accepts the input modalities."""
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
if len(images) != 0 and self.image_token is None:
raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used."
@ -119,9 +118,7 @@ class MMPluginMixin:
def _preprocess_image(
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
) -> "ImageObject":
r"""
Pre-processes a single image.
"""
r"""Pre-process a single image."""
if (image.width * image.height) > image_max_pixels:
resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
@ -139,10 +136,8 @@ class MMPluginMixin:
def _get_video_sample_indices(
self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
) -> List[int]:
r"""
Computes video sample indices according to fps.
"""
) -> list[int]:
r"""Compute video sample indices according to fps."""
total_frames = video_stream.frames
if total_frames == 0: # infinite video
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
@ -151,10 +146,8 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including reading and pre-processing.
"""
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]:
r"""Regularize images to avoid error. Including reading and pre-processing."""
results = []
for image in images:
if isinstance(image, str):
@ -174,16 +167,14 @@ class MMPluginMixin:
return results
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = []
frames: list[ImageObject] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
@ -194,10 +185,8 @@ class MMPluginMixin:
return results
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
r"""
Regularizes audios to avoid error. Including reading and resampling.
"""
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
r"""Regularizes audios to avoid error. Including reading and resampling."""
results = []
for audio in audios:
if isinstance(audio, str):
@ -216,9 +205,8 @@ class MMPluginMixin:
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs.
Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W)
@ -229,9 +217,9 @@ class MMPluginMixin:
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
@ -278,31 +266,27 @@ class MMPluginMixin:
class BasePlugin(MMPluginMixin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
Pre-processes input messages before tokenization for VLMs.
"""
) -> list[dict[str, str]]:
r"""Pre-processes input messages before tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
input_ids: list[int],
labels: Optional[list[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
r"""
Pre-processes token ids after tokenization for VLMs.
"""
) -> tuple[list[int], Optional[list[int]]]:
r"""Pre-processes token ids after tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return input_ids, labels
@ -314,20 +298,21 @@ class BasePlugin(MMPluginMixin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
Builds batched multimodal inputs for VLMs.
) -> dict[str, Union[list[int], "torch.Tensor"]]:
r"""Build batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
audios: a list of audio inputs, shape (num_audios,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
audlens: number of audios in each sample, shape (batch_size,)
batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self._validate_input(processor, images, videos, audios)
return {}
@ -338,12 +323,12 @@ class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
@ -370,9 +355,9 @@ class LlavaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@ -382,12 +367,12 @@ class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
@ -426,9 +411,9 @@ class LlavaNextPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@ -438,12 +423,12 @@ class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
@ -502,9 +487,9 @@ class LlavaNextVideoPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@ -514,16 +499,16 @@ class MiniCPMVPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {}
audio_inputs = {}
if len(images) != 0 and len(videos) != 0:
@ -619,9 +604,9 @@ class MiniCPMVPlugin(BasePlugin):
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
@ -691,9 +676,9 @@ class MiniCPMVPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
# image bound
image_bounds_list = []
@ -756,12 +741,12 @@ class MllamaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
@ -782,10 +767,9 @@ class MllamaPlugin(BasePlugin):
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
imglens: List[int],
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
imglens: list[int],
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns:
pixel_values: tensor with shape
@ -794,8 +778,9 @@ class MllamaPlugin(BasePlugin):
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) > 0:
images = self._regularize_images(
@ -821,9 +806,9 @@ class MllamaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
if mm_inputs:
@ -850,12 +835,12 @@ class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
@ -875,14 +860,14 @@ class PaliGemmaPlugin(BasePlugin):
@override
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
input_ids: list[int],
labels: Optional[list[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
) -> tuple[list[int], Optional[list[int]]]:
self._validate_input(processor, images, videos, audios)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
@ -902,9 +887,9 @@ class PaliGemmaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@ -917,12 +902,12 @@ class PixtralPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
@ -968,9 +953,9 @@ class PixtralPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("image_sizes", None)
@ -982,12 +967,12 @@ class Qwen2AudioPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token")
@ -1028,9 +1013,9 @@ class Qwen2AudioPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@ -1057,13 +1042,13 @@ class Qwen2VLPlugin(BasePlugin):
@override
def _regularize_videos(
self, videos: Sequence["VideoInput"], **kwargs
) -> Tuple[List[List["ImageObject"]], List[float]]:
) -> tuple[list[list["ImageObject"]], list[float]]:
results, fps_per_video = [], []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = []
frames: list[ImageObject] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
@ -1088,8 +1073,8 @@ class Qwen2VLPlugin(BasePlugin):
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
@ -1115,16 +1100,16 @@ class Qwen2VLPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
@ -1176,13 +1161,13 @@ class Qwen2VLPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", [])
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
@ -1194,12 +1179,12 @@ class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
@ -1255,9 +1240,9 @@ class VideoLlavaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@ -1277,10 +1262,8 @@ PLUGINS = {
}
def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
r"""
Registers a multimodal plugin.
"""
def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
r"""Register a multimodal plugin."""
if name in PLUGINS:
raise ValueError(f"Multimodal plugin {name} already exists.")
@ -1293,9 +1276,7 @@ def get_mm_plugin(
video_token: Optional[str] = None,
audio_token: Optional[str] = None,
) -> "BasePlugin":
r"""
Gets plugin for multimodal inputs.
"""
r"""Get plugin for multimodal inputs."""
if name not in PLUGINS:
raise ValueError(f"Multimodal plugin `{name}` not found.")

View File

@ -14,8 +14,9 @@
import json
import os
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence
from typing import Any, Literal, Optional
from transformers.utils import cached_file
@ -25,9 +26,7 @@ from ..extras.misc import use_modelscope, use_openmind
@dataclass
class DatasetAttr:
r"""
Dataset attributes.
"""
r"""Dataset attributes."""
# basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
@ -68,10 +67,10 @@ class DatasetAttr:
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def join(self, attr: Dict[str, Any]) -> None:
def join(self, attr: dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr)
@ -92,10 +91,8 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
r"""
Gets the attributes of the datasets.
"""
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets."""
if dataset_names is None:
dataset_names = []
@ -116,7 +113,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None
dataset_list: List["DatasetAttr"] = []
dataset_list: list[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope():
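A rough sketch of the attribute-merging pattern that `DatasetAttr.set_attr` / `join` use above, with a trimmed stand-in dataclass and a made-up `dataset_info.json`-style entry (only a few of the real fields are kept):

```python
from dataclasses import dataclass
from typing import Any, Literal, Optional


@dataclass
class DatasetAttr:
    # trimmed stand-in with only a handful of the real fields
    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
    dataset_name: str
    formatting: str = "alpaca"
    ranking: bool = False
    subset: Optional[str] = None

    def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
        # copy a value from the config dict onto the dataclass, falling back to a default
        setattr(self, key, obj.get(key, default))

    def join(self, attr: dict[str, Any]) -> None:
        # merge the per-dataset entry from dataset_info.json
        self.set_attr("formatting", attr, default="alpaca")
        self.set_attr("ranking", attr, default=False)
        self.set_attr("subset", attr)


info = {"hf_hub_url": "org/dataset", "formatting": "sharegpt", "ranking": True}
attr = DatasetAttr(load_from="hf_hub", dataset_name=info["hf_hub_url"])
attr.join(info)
print(attr.formatting, attr.ranking, attr.subset)  # sharegpt True None
```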

View File

@ -9,9 +9,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [
"DatasetProcessor",
"FeedbackDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"PairwiseDatasetProcessor",
"PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor",
]

View File

@ -13,7 +13,8 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@ -30,15 +31,15 @@ logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
kl_response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
) -> tuple[list[int], list[int], list[int], list[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@ -82,7 +83,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
@ -121,7 +122,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@ -13,7 +13,8 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]:
) -> tuple[list[int], list[int], list[int], list[int]]:
chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor
)
@ -68,7 +69,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
@ -99,7 +100,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))

View File

@ -17,14 +17,14 @@
from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List
from typing import Any
from .processor_utils import DatasetProcessor
@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
return result
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@ -14,8 +14,9 @@
import bisect
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
@ -27,9 +28,7 @@ if TYPE_CHECKING:
@dataclass
class DatasetProcessor(ABC):
r"""
A class for data processors.
"""
r"""A class for data processors."""
template: "Template"
tokenizer: "PreTrainedTokenizer"
@ -37,32 +36,24 @@ class DatasetProcessor(ABC):
data_args: "DataArguments"
@abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
r"""
Builds model inputs from the examples.
"""
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
r"""Build model inputs from the examples."""
...
@abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None:
r"""
Print a data example to stdout.
"""
def print_data_example(self, example: dict[str, list[int]]) -> None:
r"""Print a data example to stdout."""
...
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
r"""Find the index of largest number that fits into the knapsack with the given capacity."""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []
@ -83,10 +74,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
r"""Compute the real sequence length after truncation by the cutoff_len."""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target
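The two knapsack helpers changed above are small enough to restate; a self-contained sketch of `search_for_fit` plus a greedy packing loop built on it, assuming every length is already truncated to at most `capacity` (as `cutoff_len` guarantees upstream):

```python
import bisect


def search_for_fit(numbers: list[int], capacity: int) -> int:
    """Index of the largest element <= capacity in an ascending list, or -1 if nothing fits."""
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else index - 1


def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
    """Greedily pack sequence lengths into bins of size `capacity`; assumes every length <= capacity."""
    numbers.sort()  # ascending order so the binary search works
    knapsacks = []
    while numbers:
        knapsack = []
        remaining = capacity
        while True:
            index = search_for_fit(numbers, remaining)
            if index == -1:  # nothing left fits into this bin
                break
            remaining -= numbers[index]
            knapsack.append(numbers.pop(index))
        knapsacks.append(knapsack)
    return knapsacks


print(greedy_knapsack([7, 2, 5, 3, 1], capacity=8))  # [[7, 1], [5, 3], [2]]
```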

View File

@ -13,8 +13,9 @@
# limitations under the License.
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@ -32,14 +33,14 @@ logger = logging.get_logger(__name__)
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
@ -85,7 +86,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
@ -114,7 +115,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
@ -124,7 +125,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
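A toy sketch of the multiturn masking rule stated in the comment above: inputs are concatenated prompt-response turns and only the response tokens of each turn survive in the labels (made-up ids; the real processor also applies the chat template, truncation and multimodal tokens):

```python
IGNORE_INDEX = -100


def encode_multiturn(pairs: list[tuple[list[int], list[int]]]) -> tuple[list[int], list[int]]:
    """Concatenate (prompt_ids, response_ids) turns, masking every prompt span in the labels."""
    input_ids: list[int] = []
    labels: list[int] = []
    for prompt_ids, response_ids in pairs:
        input_ids += prompt_ids + response_ids
        labels += [IGNORE_INDEX] * len(prompt_ids) + response_ids
    return input_ids, labels


turns = [([1, 5, 6], [7, 8, 2]), ([1, 9], [10, 2])]  # made-up token ids per turn
input_ids, labels = encode_multiturn(turns)
print(input_ids)  # [1, 5, 6, 7, 8, 2, 1, 9, 10, 2]
print(labels)     # [-100, -100, -100, 7, 8, 2, -100, -100, 10, 2]
```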

View File

@ -13,7 +13,8 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ..data_utils import Role
@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
) -> tuple[list[int], list[int]]:
if len(response) == 1:
messages = prompt + response
else:
@ -56,7 +57,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
labels = labels[:target_len]
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
@ -84,7 +85,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Optional, Union
from typing_extensions import override
@ -46,8 +47,8 @@ class Template:
format_tools: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
thought_words: Tuple[str, str]
stop_words: list[str]
thought_words: tuple[str, str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
@ -56,13 +57,11 @@ class Template:
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
@ -74,36 +73,28 @@ class Template:
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts tool message.
"""
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract tool message."""
return self.format_tools.extract(content)
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
r"""
Returns stop token ids.
"""
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r"""Return stop token ids."""
stop_token_ids = {tokenizer.eos_token_id}
for token in self.stop_words:
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
return list(stop_token_ids)
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r"""Convert elements to token ids."""
token_ids = []
for elem in elements:
if isinstance(elem, str):
@ -124,14 +115,14 @@ class Template:
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
) -> list[list[int]]:
r"""Encode formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp
Turn t: query resp
Turn t: query resp.
"""
system = system or self.default_system
encoded_messages = []
@ -161,9 +152,7 @@ class Template:
@staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r"""
Adds or replaces eos token to the tokenizer.
"""
r"""Add or replace eos token to the tokenizer."""
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
@ -176,9 +165,7 @@ class Template:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Adds eos token and pad token to the tokenizer.
"""
r"""Add eos token and pad token to the tokenizer."""
stop_words = self.stop_words
if self.replace_eos:
if not stop_words:
@ -204,16 +191,12 @@ class Template:
@staticmethod
def _jinja_escape(content: str) -> str:
r"""
Escape single quotes in content.
"""
r"""Escape single quotes in content."""
return content.replace("'", r"\'")
@staticmethod
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
r"""
Converts slots to jinja template.
"""
r"""Convert slots to jinja template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
@ -235,9 +218,7 @@ class Template:
return " + ".join(slot_items)
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
r"""Return the jinja template."""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
@ -265,9 +246,7 @@ class Template:
return jinja_template
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Replaces the jinja template in the tokenizer.
"""
r"""Replace the jinja template in the tokenizer."""
if tokenizer.chat_template is None or self.replace_jinja_template:
try:
tokenizer.chat_template = self._get_jinja_template(tokenizer)
@ -278,9 +257,7 @@ class Template:
def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
r"""
Converts slots to ollama template.
"""
r"""Convert slots to ollama template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
@ -302,9 +279,7 @@ class Template:
return "".join(slot_items)
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama template.
"""
r"""Return the ollama template."""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
@ -316,8 +291,7 @@ class Template:
)
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama modelfile.
r"""Return the ollama modelfile.
TODO: support function calling.
"""
@ -340,10 +314,10 @@ class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: str,
tools: str,
) -> List[List[int]]:
) -> list[list[int]]:
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
@ -402,7 +376,7 @@ class Llama2Template(Template):
return jinja_template
TEMPLATES: Dict[str, "Template"] = {}
TEMPLATES: dict[str, "Template"] = {}
def register_template(
@ -416,15 +390,14 @@ def register_template(
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Optional[Sequence[str]] = None,
thought_words: Optional[Tuple[str, str]] = None,
thought_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: Type["Template"] = Template,
template_class: type["Template"] = Template,
) -> None:
r"""
Registers a chat template.
r"""Register a chat template.
To add the following chat template:
```
@ -472,9 +445,7 @@ def register_template(
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
r"""
Extracts a chat template from the tokenizer.
"""
r"""Extract a chat template from the tokenizer."""
def find_diff(short_str: str, long_str: str) -> str:
i, j = 0, 0
@ -532,9 +503,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
r"""Get chat template and fixes the tokenizer."""
if data_args.template is None:
if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
@ -1149,7 +1118,8 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=(
"你是一个经过良好训练的AI助手你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"你是一个经过良好训练的AI助手你的名字是Marco-o1."
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文但是有2个特例一个是对原文中的引用另一个是是数学应该使用markdown格式<Output>内的输出需要遵循用户输入的语言。\n"
),
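The `register_template` docstring above introduces a usage example; a hedged sketch of what registering a custom template looks like, assuming the package layout at this commit (the slot strings below are illustrative, not a template shipped with the repo):

```python
# Usage sketch: depends on the llamafactory package being importable; the slot
# strings for the hypothetical "custom" template are made up for illustration.
from llamafactory.data.formatter import EmptyFormatter, StringFormatter
from llamafactory.data.template import register_template

register_template(
    name="custom",
    format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
    format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
    format_prefix=EmptyFormatter(slots=["<s>"]),
)
```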

View File

@ -17,7 +17,7 @@ import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing import Any, NamedTuple, Union
from typing_extensions import override
@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (
@dataclass
class ToolUtils(ABC):
"""
Base class for tool utilities.
"""
"""Base class for tool utilities."""
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
def tool_formatter(tools: list[dict[str, Any]]) -> str:
r"""Generate the system message describing all the available tools."""
...
@staticmethod
@abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
r"""
Generates the assistant message including all the tool calls.
"""
def function_formatter(functions: list["FunctionCall"]) -> str:
r"""Generate the assistant message including all the tool calls."""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the assistant message.
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract all the function calls from the assistant message.
It should be an inverse function of `function_formatter`.
"""
@ -92,13 +85,11 @@ class ToolUtils(ABC):
class DefaultToolUtils(ToolUtils):
r"""
Default tool using template.
"""
r"""Default tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
action_match: list[tuple[str, str]] = re.findall(regex, content)
if not action_match:
return content
@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
r"""
GLM-4 tool using template.
"""
r"""GLM-4 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
if "\n" not in content:
return content
@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
class Llama3ToolUtils(ToolUtils):
r"""
Llama 3.x tool using template with `tools_in_user_message=False`.
r"""Llama 3.x tool using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tool = json.loads(content.strip())
except json.JSONDecodeError:
@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
class MistralToolUtils(ToolUtils):
r"""
Mistral v0.3 tool using template.
"""
r"""Mistral v0.3 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = []
for tool in tools:
wrapped_tools.append({"type": "function", "function": tool})
@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tools = json.loads(content.strip())
except json.JSONDecodeError:
@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
class QwenToolUtils(ToolUtils):
r"""
Qwen 2.5 tool using template.
"""
r"""Qwen 2.5 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
tool_match: List[str] = re.findall(regex, content)
tool_match: list[str] = re.findall(regex, content)
if not tool_match:
return content
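A self-contained sketch of the Qwen-style `<tool_call>` extraction that the regex above performs; the real `QwenToolUtils.tool_extractor` returns `FunctionCall` tuples, while this stand-in returns plain `(name, arguments)` pairs:

```python
import json
import re
from typing import Union


def extract_tool_calls(content: str) -> Union[str, list[tuple[str, str]]]:
    """Parse Qwen-2.5-style <tool_call> blocks; fall back to the raw text if none parse."""
    regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
    matches: list[str] = re.findall(regex, content)
    if not matches:
        return content

    results = []
    for text in matches:
        try:
            tool = json.loads(text.strip())
        except json.JSONDecodeError:
            return content
        results.append((tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
    return results


reply = '<tool_call>{"name": "get_weather", "arguments": {"city": "Beijing"}}</tool_call>'
print(extract_tool_calls(reply))  # [('get_weather', '{"city": "Beijing"}')]
```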

View File

@ -39,7 +39,7 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional
import numpy as np
import torch
@ -59,7 +59,7 @@ if TYPE_CHECKING:
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
@ -69,7 +69,7 @@ class Evaluator:
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@ -88,7 +88,7 @@ class Evaluator:
)
with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
categorys: dict[str, dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
@ -136,7 +136,7 @@ class Evaluator:
pbar.close()
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
score_info = "\n".join(
[
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple
from ..data import Role
from ..extras.constants import CHOICES
@ -25,20 +25,19 @@ class EvalTemplate:
choice: str
answer: str
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
r"""Parse eval example.
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
output: a tuple of (prompt, response).
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
self, target_data: dict[str, str], support_set: Sequence[dict[str, str]], subject_name: str
) -> list[dict[str, str]]:
r"""Convert dataset examples to messages."""
messages = []
for k in range(len(support_set)):
prompt, response = self._parse_example(support_set[k])
@ -52,7 +51,7 @@ class EvalTemplate:
return messages
eval_templates: Dict[str, "EvalTemplate"] = {}
eval_templates: dict[str, "EvalTemplate"] = {}
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
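A self-contained sketch of the `_parse_example` logic, using format strings close to the English eval template (a `"\n{choice}. {content}"` choice slot and a `"\nAnswer:"` suffix); the example row is made up:

```python
CHOICES = ["A", "B", "C", "D"]


def parse_example(
    example: dict[str, str], choice: str = "\n{choice}. {content}", answer: str = "\nAnswer:"
) -> tuple[str, str]:
    """Turn a {"question", "A".."D", "answer"} row into a (prompt, response) pair."""
    candidates = [choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
    return "".join([example["question"]] + candidates + [answer]), example["answer"]


row = {"question": "1 + 1 = ?", "A": "1", "B": "2", "C": "3", "D": "4", "answer": "B"}
prompt, response = parse_example(row)
print(prompt)
# 1 + 1 = ?
# A. 1
# B. 2
# C. 3
# D. 4
# Answer:
print(response)  # B
```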

View File

@ -15,7 +15,7 @@
import os
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@ -122,7 +122,7 @@ class RopeScaling(str, Enum):
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None,
multimodal: bool = False,
) -> None:

View File

@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO
class LoggerHandler(logging.Handler):
r"""
Redirects the logging output to the logging file for LLaMA Board.
"""
r"""Redirect the logging output to the logging file for LLaMA Board."""
def __init__(self, output_dir: str) -> None:
super().__init__()
@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger):
r"""
A logger that supports rank0 logging.
"""
r"""A logger that supports rank0 logging."""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
@ -82,9 +78,7 @@ class _Logger(logging.Logger):
def _get_default_logging_level() -> "logging._Level":
r"""
Returns the default logging level.
"""
r"""Return the default logging level."""
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":
def _configure_library_root_logger() -> None:
r"""
Configures root logger using a stdout stream handler with an explicit format.
"""
r"""Configure root logger using a stdout stream handler with an explicit format."""
global _default_handler
with _thread_lock:
@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:
def get_logger(name: Optional[str] = None) -> "_Logger":
r"""
Returns a logger with the specified name. It it not supposed to be accessed externally.
"""
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None:
name = _get_library_name()
@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":
def add_handler(handler: "logging.Handler") -> None:
r"""
Adds a handler to the root logger.
"""
r"""Add a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""
Removes a handler to the root logger.
"""
r"""Remove a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)

View File

@ -17,7 +17,8 @@
import gc
import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Literal, Union
import torch
import torch.distributed as dist
@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
class AverageMeter:
r"""
Computes and stores the average and current value.
"""
r"""Compute and store the average and current value."""
def __init__(self):
self.reset()
@ -75,9 +74,7 @@ class AverageMeter:
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""
Optionally checks the package version.
"""
r"""Optionally check the package version."""
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
@ -91,9 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
r"""Check the version of the required packages."""
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.2.0")
check_version("accelerate>=0.34.0,<=1.2.1")
@ -103,10 +98,8 @@ def check_dependencies() -> None:
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
r"""
Calculates effective tokens per second.
"""
def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
r"""Calculate effective tokens per second."""
effective_token_num = 0
for data in dataset:
if stage == "sft":
@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
return result / dist.get_world_size() if dist.is_initialized() else result
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
r"""Return the number of trainable parameters and number of all parameters in the model."""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
r"""Get the current available device."""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available():
@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
def get_device_count() -> int:
r"""
Gets the number of available GPU or NPU devices.
"""
r"""Get the number of available GPU or NPU devices."""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
@ -180,18 +167,14 @@ def get_device_count() -> int:
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
r"""Get logits processor that removes NaN and Inf logits."""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def get_peak_memory() -> Tuple[int, int]:
r"""
Gets the peak memory usage for the current device (in Bytes).
"""
def get_peak_memory() -> tuple[int, int]:
r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
r"""Check if the path has a tokenized dataset."""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def is_gpu_or_npu_available() -> bool:
r"""
Checks if the GPU or NPU is available.
"""
r"""Check if the GPU or NPU is available."""
return is_torch_npu_available() or is_torch_cuda_available()
def is_env_enabled(env_var: str, default: str = "0") -> bool:
r"""
Checks if the environment variable is enabled.
"""
r"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
"""
r"""Cast a torch tensor or a numpy array to a numpy array."""
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
r"""Avoid flash attention import error in custom model files."""
if not is_env_enabled("FORCE_CHECK_IMPORTS"):
transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None:
r"""
Collects GPU or NPU memory.
"""
r"""Collect GPU or NPU memory."""
gc.collect()
if is_torch_xpu_available():
torch.xpu.empty_cache()
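
For reference, a minimal sketch of the parameter-counting idea behind `count_parameters` and its new `tuple[int, int]` return type; the helper name and the toy `torch.nn.Linear` model below are illustrative only.

```python
import torch


def count_trainable_and_all(model: torch.nn.Module) -> tuple[int, int]:
    # Same idea as count_parameters above: tally trainable vs. all parameters.
    trainable, total = 0, 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return trainable, total


model = torch.nn.Linear(8, 2)  # 8 * 2 weights + 2 biases = 18 parameters
print(count_trainable_and_all(model))  # (18, 18)
```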

View File

@ -15,7 +15,7 @@
import json
import math
import os
from typing import Any, Dict, List
from typing import Any
from transformers.trainer import TRAINER_STATE_NAME
@ -31,10 +31,8 @@ if is_matplotlib_available():
logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
def smooth(scalars: list[float]) -> list[float]:
r"""EMA implementation according to TensorBoard."""
if len(scalars) == 0:
return []
@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""
Plots loss curves in LlamaBoard.
"""
def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""Plot loss curves in LlamaBoard."""
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
return fig
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""
def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
r"""Plot loss curves and saves the image."""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f)
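
A rough sketch of the TensorBoard-style EMA referenced in `smooth` above; the fixed `weight` is an assumption, since the project derives its smoothing factor from the sequence length.

```python
def ema_smooth(scalars: list[float], weight: float = 0.9) -> list[float]:
    # Exponential moving average over loss values, TensorBoard-style.
    if not scalars:
        return []
    last = scalars[0]
    smoothed = []
    for value in scalars:
        last = last * weight + (1 - weight) * value
        smoothed.append(last)
    return smoothed


print(ema_smooth([2.0, 1.0, 0.5, 0.25]))
```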

View File

@ -16,14 +16,12 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
from typing import Any, Literal, Optional
@dataclass
class DataArguments:
r"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
template: Optional[str] = field(
default=None,
@ -162,5 +160,5 @@ class DataArguments:
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)

View File

@ -21,9 +21,7 @@ from datasets import DownloadMode
@dataclass
class EvaluationArguments:
r"""
Arguments pertaining to specify the evaluation parameters.
"""
r"""Arguments pertaining to specify the evaluation parameters."""
task: str = field(
metadata={"help": "Name of the evaluation task."},

View File

@ -13,14 +13,12 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional
@dataclass
class FreezeArguments:
r"""
Arguments pertaining to the freeze (partial-parameter) training.
"""
r"""Arguments pertaining to the freeze (partial-parameter) training."""
freeze_trainable_layers: int = field(
default=2,
@ -56,9 +54,7 @@ class FreezeArguments:
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
r"""Arguments pertaining to the LoRA training."""
additional_target: Optional[str] = field(
default=None,
@ -128,9 +124,7 @@ class LoraArguments:
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO, DPO and KTO training.
"""
r"""Arguments pertaining to the PPO, DPO and KTO training."""
pref_beta: float = field(
default=0.1,
@ -212,9 +206,7 @@ class RLHFArguments:
@dataclass
class GaloreArguments:
r"""
Arguments pertaining to the GaLore algorithm.
"""
r"""Arguments pertaining to the GaLore algorithm."""
use_galore: bool = field(
default=False,
@ -253,9 +245,7 @@ class GaloreArguments:
@dataclass
class ApolloArguments:
r"""
Arguments pertaining to the APOLLO algorithm.
"""
r"""Arguments pertaining to the APOLLO algorithm."""
use_apollo: bool = field(
default=False,
@ -306,9 +296,7 @@ class ApolloArguments:
@dataclass
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
r"""Arguments pertaining to the BAdam optimizer."""
use_badam: bool = field(
default=False,
@ -393,9 +381,7 @@ class SwanLabArguments:
class FinetuningArguments(
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
pure_bf16: bool = field(
default=False,
@ -452,13 +438,13 @@ class FinetuningArguments(
return [item.strip() for item in arg.split(",")]
return arg
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.lora_target: list[str] = split_arg(self.lora_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@ -499,7 +485,7 @@ class FinetuningArguments(
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args
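
The repeated `split_arg` calls above turn comma-separated strings into `list[str]`; a small self-contained sketch reconstructed from the lines shown in the hunk:

```python
from typing import Optional, Union


def split_arg(arg: Optional[Union[str, list[str]]]) -> Optional[list[str]]:
    # Comma-separated strings become lists; lists and None pass through unchanged.
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg


print(split_arg("q_proj, v_proj"))  # ['q_proj', 'v_proj']
print(split_arg(None))              # None
```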

View File

@ -13,16 +13,14 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
from typing import Any, Optional
from transformers import GenerationConfig
@dataclass
class GeneratingArguments:
r"""
Arguments pertaining to specify the decoding parameters.
"""
r"""Arguments pertaining to specify the decoding parameters."""
do_sample: bool = field(
default=True,
@ -35,7 +33,9 @@ class GeneratingArguments:
top_p: float = field(
default=0.7,
metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
"help": (
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
)
},
)
top_k: int = field(
@ -71,7 +71,7 @@ class GeneratingArguments:
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)

View File

@ -17,7 +17,7 @@
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
from typing import Any, Literal, Optional, Union
import torch
from transformers.training_args import _convert_str_dict
@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling
@dataclass
class BaseModelArguments:
r"""
Arguments pertaining to the model.
"""
r"""Arguments pertaining to the model."""
model_name_or_path: Optional[str] = field(
default=None,
@ -184,9 +182,7 @@ class BaseModelArguments:
@dataclass
class QuantizationArguments:
r"""
Arguments pertaining to the quantization method.
"""
r"""Arguments pertaining to the quantization method."""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
@ -212,9 +208,7 @@ class QuantizationArguments:
@dataclass
class ProcessorArguments:
r"""
Arguments pertaining to the image processor.
"""
r"""Arguments pertaining to the image processor."""
image_max_pixels: int = field(
default=768 * 768,
@ -244,9 +238,7 @@ class ProcessorArguments:
@dataclass
class ExportArguments:
r"""
Arguments pertaining to the model export.
"""
r"""Arguments pertaining to the model export."""
export_dir: Optional[str] = field(
default=None,
@ -292,9 +284,7 @@ class ExportArguments:
@dataclass
class VllmArguments:
r"""
Arguments pertaining to the vLLM worker.
"""
r"""Arguments pertaining to the vLLM worker."""
vllm_maxlen: int = field(
default=4096,
@ -324,8 +314,7 @@ class VllmArguments:
@dataclass
class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
The rightmost class will be displayed first.
"""
@ -335,7 +324,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, Dict[str, Any]]] = field(
device_map: Optional[Union[str, dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
@ -372,7 +361,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
return result
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args
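
The masking comprehension above (and the analogous `api_key` one in `FinetuningArguments.to_dict`) can be checked in isolation; the sample keys and values below are made up.

```python
args = {"hf_hub_token": "hf_xxx", "model_name_or_path": "meta-llama/Llama-3-8B-Instruct"}
# Any key ending with "token" is replaced by a redacted placeholder before logging.
masked = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
print(masked)  # {'hf_hub_token': '<HF_HUB_TOKEN>', 'model_name_or_path': 'meta-llama/Llama-3-8B-Instruct'}
```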

View File

@ -19,7 +19,7 @@ import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
import transformers
@ -47,17 +47,15 @@ check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
r"""
Gets arguments from the command line or a config file.
"""
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
r"""Get arguments from the command line or a config file."""
if args is not None:
return args
@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[
def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
) -> tuple[Any]:
args = read_args(args)
if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
@ -161,31 +159,31 @@ def _check_extra_dependencies(
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
@ -364,9 +362,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
and training_args.resume_from_checkpoint is not None
):
logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
)
# Post-process model arguments
@ -382,20 +378,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
# Log on each process the small summary
logger.info(
"Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
training_args.process_index,
training_args.world_size,
training_args.device,
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
f"Process rank: {training_args.process_index}, "
f"world size: {training_args.world_size}, device: {training_args.device}, "
f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
f"compute dtype: {str(model_args.compute_dtype)}"
)
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
@ -426,7 +419,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
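
For context, a minimal sketch of how `HfArgumentParser.parse_dict` (used by `_parse_args` above) consumes a dict of arguments; the `ToyArgs` dataclass is hypothetical and stands in for the real argument classes.

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ToyArgs:
    learning_rate: float = field(default=5e-5)
    output_dir: str = field(default="saves/demo")


parser = HfArgumentParser(ToyArgs)
(toy_args,) = parser.parse_dict({"learning_rate": 1e-4}, allow_extra_keys=True)
print(toy_args)  # ToyArgs(learning_rate=0.0001, output_dir='saves/demo')
```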

View File

@ -10,9 +10,7 @@ from ..extras.misc import use_ray
@dataclass
class RayArguments:
r"""
Arguments pertaining to the Ray training.
"""
r"""Arguments pertaining to the Ray training."""
ray_run_name: Optional[str] = field(
default=None,
@ -43,9 +41,7 @@ class RayArguments:
@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r"""
Arguments pertaining to the trainer.
"""
r"""Arguments pertaining to the trainer."""
def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)

View File

@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params
__all__ = [
"QuantizationMethod",
"find_all_linear_modules",
"load_config",
"load_model",
"load_tokenizer",
"find_all_linear_modules",
"load_valuehead_params",
]

View File

@ -81,9 +81,8 @@ def _setup_freeze_tuning(
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.freeze_trainable_layers
)
f"`num_layers` {num_layers} should be "
f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
)
stride = num_layers // finetuning_args.freeze_trainable_layers
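
The divisibility check and `stride` computation above select evenly spaced trainable blocks for LLaMA-Pro-style freeze tuning; a rough sketch under the assumption that the last block of each stride-sized group is kept trainable (the exact index arithmetic in the project may differ).

```python
num_layers, freeze_trainable_layers = 32, 4

if num_layers % freeze_trainable_layers != 0:
    raise ValueError(
        f"`num_layers` {num_layers} should be "
        f"divisible by `num_layer_trainable` {freeze_trainable_layers}."
    )

stride = num_layers // freeze_trainable_layers
# Assumed selection rule: keep the last block of every stride-sized group trainable.
trainable_layer_ids = list(range(stride - 1, num_layers, stride))
print(trainable_layer_ids)  # [7, 15, 23, 31]
```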
@ -178,7 +177,7 @@ def _setup_lora_tuning(
}
for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
@ -263,8 +262,7 @@ def init_adapter(
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
r"""Initialize the adapters.
Support full-parameter, freeze and LoRA training.

View File

@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
@ -51,9 +51,8 @@ class TokenizerModule(TypedDict):
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
r"""Get arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
@ -68,8 +67,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer and optionally loads processor.
r"""Load pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args.
"""
@ -110,9 +108,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
r"""Load model config."""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
@ -124,9 +120,7 @@ def load_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model.
"""
r"""Load pretrained model."""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
@ -194,8 +188,9 @@ def load_model(
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
param_stats = (
f"trainable params: {trainable_params:,} || "
f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
)
else:
param_stats = f"all params: {all_param:,}"

View File

@ -21,7 +21,7 @@
import inspect
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
r"""Saves VRAM by smartly offloading to RAM."""
@staticmethod
@torch.cuda.amp.custom_fwd
@ -77,13 +75,11 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r"""
Only applies gradient checkpointing to trainable layers.
"""
r"""Only applies gradient checkpointing to trainable layers."""
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
module: torch.nn.Module = func.__self__
has_grad = False
if any(param.requires_grad for param in module.parameters()):
@ -103,11 +99,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
def _gradient_checkpointing_enable(
self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
use_unsloth_gc: bool = False,
) -> None:
r"""
Activates gradient checkpointing for the current model.
r"""Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
@ -134,17 +129,18 @@ def _gradient_checkpointing_enable(
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
r"""Prepare the model before training.
Include:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32.
"""
if model_args.upcast_layernorm:
logger.info_rank0("Upcasting layernorm weights in float32.")

View File

@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
r"""Resize token embeddings."""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore

View File

@ -18,7 +18,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
@ -54,14 +54,14 @@ def llama_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(
attn_output: torch.Tensor = _flash_attention_forward(
query_states,
key_states,
value_states,
@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
is_causal=self.is_causal,
)
else:
attn_output: "torch.Tensor" = self._flash_attention_forward(
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING
from ...extras import logging
from .visual import COMPOSITE_MODELS
@ -25,10 +25,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r"""
Finds all available modules to apply LoRA, GaLore or APOLLO.
"""
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}
if model_type == "chatglm":
@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
return list(module_names)
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
r"""
Finds the modules in the expanded blocks to apply lora.
"""
def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
r"""Find the modules in the expanded blocks to apply lora."""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
@ -34,9 +35,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
if not is_deepspeed_zero3_enabled():
return

View File

@ -37,7 +37,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""
Gets the sequnce lengths in the current batch.
r"""Get the sequnce lengths in the current batch.
e.g.
```python
@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask).item()
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
return seqlens
def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
r"""
Prepares the indices and seqlens for flash attn varlen function.
def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
r"""Prepare the indices and seqlens for flash attn varlen function.
Returns:
indices: indices of non-masked tokens from the flattened sequence.
@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
[0, 2, 5, 6, 8, 11]
3
```
"""
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
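
Following the docstring example above, the seqlens/indices computation can be reproduced with plain tensor ops. This sketch mirrors the counting loop shown in the hunk; the non-zero filtering step is an assumption consistent with the docstring, and the mask values are only illustrative.

```python
import torch

attention_mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])

bsz = attention_mask.size(0)
max_num = int(attention_mask.max())
counts = torch.zeros((bsz, max_num), dtype=attention_mask.dtype)
for i in range(max_num):
    counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

# Non-zero counts give the packed sub-sequence lengths; indices are the
# positions of non-padding tokens in the flattened batch.
seqlens = counts.flatten()
seqlens = seqlens[seqlens.nonzero()].flatten()
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
print(seqlens.tolist())  # [2, 3, 1, 2, 3]
print(indices.tolist())  # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
```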

View File

@ -19,7 +19,7 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any
import torch
from datasets import load_dataset
@ -43,9 +43,7 @@ logger = logging.get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum):
HQQ = "hqq"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r"""
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
"""
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
n_try += 1
if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen
@ -101,11 +97,9 @@ def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
init_kwargs: dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None:
logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
@ -113,7 +107,7 @@ def configure_quantization(
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.misc import get_current_device
@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
) -> dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
r"""
Optionally loads pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
r"""Optionally load pretrained model with unsloth. Used in training."""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try:
@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(
def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
) -> "PreTrainedModel":
r"""
Gets the peft model for the pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
r"""Get the peft model for the pretrained model with unsloth. Used in training."""
from unsloth import FastLanguageModel # type: ignore
unsloth_peft_kwargs = {
"model": model,
@ -82,10 +78,8 @@ def get_unsloth_peft_model(
def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Loads peft model with unsloth. Used in both training and inference.
"""
from unsloth import FastLanguageModel
r"""Load peft model with unsloth. Used in both training and inference."""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
import torch
from transformers.utils import cached_file
@ -30,9 +30,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
r"""Load value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""

View File

@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
from typing import TYPE_CHECKING, Optional
import torch
import transformers
@ -40,9 +41,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
class CompositeModel:
model_type: str
projector_key: str
vision_model_keys: List[str]
language_model_keys: List[str]
lora_conflict_keys: List[str]
vision_model_keys: list[str]
language_model_keys: list[str]
lora_conflict_keys: list[str]
def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
for key in self.projector_key.split("."):
@ -51,15 +52,15 @@ class CompositeModel:
return module
COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
def _register_composite_model(
model_type: str,
projector_key: Optional[str] = None,
vision_model_keys: Optional[List[str]] = None,
language_model_keys: Optional[List[str]] = None,
lora_conflict_keys: Optional[List[str]] = None,
vision_model_keys: Optional[list[str]] = None,
language_model_keys: Optional[list[str]] = None,
lora_conflict_keys: Optional[list[str]] = None,
):
COMPOSITE_MODELS[model_type] = CompositeModel(
model_type=model_type,
@ -116,12 +117,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Casts projector output to half precision for fine-tuning quantized VLMs.
"""
r"""Cast projector output to half precision for fine-tuning quantized VLMs."""
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
@ -137,9 +136,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
def configure_visual_model(config: "PretrainedConfig") -> None:
r"""
Patches VLMs before loading them.
"""
r"""Patch VLMs before loading them."""
if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
# required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
@ -149,10 +146,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
r"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
r"""Freeze vision tower and language model for VLM full/freeze tuning."""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in COMPOSITE_MODELS:
@ -175,9 +170,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""
Computes the number of special tokens per image.
"""
r"""Compute the number of special tokens per image."""
model_type = getattr(config, "model_type", None)
if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
@ -192,17 +185,13 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Computes the patch size of the vit.
"""
r"""Compute the patch size of the vit."""
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Get the vision_feature_select_strategy.
"""
r"""Get the vision_feature_select_strategy."""
vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
)
@ -211,10 +200,8 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P
def patch_target_modules(
model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> List[str]:
r"""
Freezes vision tower for VLM LoRA tuning.
"""
) -> list[str]:
r"""Freezes vision tower for VLM LoRA tuning."""
model_type = getattr(model.config, "model_type", None)
if model_type in COMPOSITE_MODELS:
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)

View File

@ -13,7 +13,7 @@
# limitations under the License.
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any
import torch
from peft import PeftModel
@ -93,7 +93,7 @@ def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
init_kwargs: dict[str, Any],
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32

View File

@ -19,7 +19,7 @@ import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
import transformers
@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
r"""
r"""Fix the valuehead checkpoint files.
The model is already unwrapped.
There are three cases:
@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
if safe_serialization:
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
else:
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
os.remove(path_to_checkpoint)
decoder_state_dict, v_head_state_dict = {}, {}
@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
class FixValueHeadModelCallback(TrainerCallback):
r"""
A callback for fixing the checkpoint for valuehead models.
"""
r"""A callback for fixing the checkpoint for valuehead models."""
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
class SaveProcessorCallback(TrainerCallback):
r"""
A callback for saving the processor.
"""
r"""A callback for saving the processor."""
def __init__(self, processor: "ProcessorMixin") -> None:
self.processor = processor
@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
class PissaConvertCallback(TrainerCallback):
r"""
A callback for converting the PiSSA adapter to a normal one.
"""
r"""A callback for converting the PiSSA adapter to a normal one."""
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@ -177,9 +172,7 @@ class PissaConvertCallback(TrainerCallback):
class LogCallback(TrainerCallback):
r"""
A callback for logging training and evaluation status.
"""
r"""A callback for logging training and evaluation status."""
def __init__(self) -> None:
# Progress
@ -188,7 +181,7 @@ class LogCallback(TrainerCallback):
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
# Status
self.aborted = False
self.do_train = False
@ -219,7 +212,7 @@ class LogCallback(TrainerCallback):
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
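
The `_write_log` hunk above appends one JSON object per line; a self-contained sketch of that pattern, where the `trainer_log.jsonl` file name is an assumption standing in for the project's `TRAINER_LOG` constant.

```python
import json
import os


def write_log(output_dir: str, logs: dict) -> None:
    # Append-only JSONL log, one record per logging step, as in _write_log above.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
        f.write(json.dumps(logs) + "\n")


write_log("saves/demo", {"current_steps": 10, "total_steps": 100, "loss": 1.23})
```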
@ -348,9 +341,7 @@ class LogCallback(TrainerCallback):
class ReporterCallback(TrainerCallback):
r"""
A callback for reporting training status to external logger.
"""
r"""A callback for reporting training status to external logger."""
def __init__(
self,

View File

@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
import torch.nn.functional as F
@ -129,15 +129,11 @@ class CustomDPOTrainer(DPOTrainer):
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
"""
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
log_odds = (chosen_logps - rejected_logps) - (
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
return orpo_loss
def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes SimPO loss for batched log probabilities of the policy model.
"""
r"""Compute SimPO loss for batched log probabilities of the policy model."""
pi_logratios = chosen_logps - rejected_logps
gamma_logratios = self.simpo_gamma / self.beta
logits = pi_logratios - gamma_logratios
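
The SimPO branch above reduces to a few tensor ops. A minimal sketch follows; the final `-logsigmoid(beta * logits)` step is the standard SimPO formulation and is assumed here since the hunk is truncated, and the `beta` / `simpo_gamma` values are arbitrary examples.

```python
import torch
import torch.nn.functional as F


def simpo_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor,
               beta: float = 2.0, simpo_gamma: float = 0.5) -> torch.Tensor:
    # logits = (chosen - rejected) - gamma / beta, then -logsigmoid(beta * logits).
    pi_logratios = chosen_logps - rejected_logps
    gamma_logratios = simpo_gamma / beta
    logits = pi_logratios - gamma_logratios
    return -F.logsigmoid(beta * logits)


print(simpo_loss(torch.tensor([-0.8]), torch.tensor([-1.5])))
```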
@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
policy_rejected_logps: "torch.Tensor",
reference_chosen_logps: Optional["torch.Tensor"],
reference_rejected_logps: Optional["torch.Tensor"],
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes loss for preference learning.
"""
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute loss for preference learning."""
if not self.finetuning_args.use_ref_model:
if self.loss_type == "orpo":
losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):
@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Computes log probabilities of the reference model.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""Compute log probabilities of the reference model."""
if not self.finetuning_args.use_ref_model:
return None, None
@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
batch: dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train",
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Subclass and override to accept extra kwargs.
"""
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Subclass and override to accept extra kwargs."""
return super().compute_loss(model, inputs, return_outputs)
@override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r"""Log `logs` on the various objects watching training, including stored metrics."""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
@ -38,7 +38,7 @@ def run_dpo(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

View File

@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
from transformers import Trainer
@ -120,9 +120,7 @@ class CustomKTOTrainer(KTOTrainer):
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
r"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
@ -130,18 +128,14 @@ class CustomKTOTrainer(KTOTrainer):
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
@override
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Runs forward pass and computes the log probabilities.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Run forward pass and computes the log probabilities."""
batch = nested_detach(batch, clone=True) # avoid error
model_inputs = {
"input_ids": batch[f"{prefix}input_ids"],
@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logits, target_logps, target_logps_avg = self.forward(model, batch)
with torch.no_grad():
_, kl_logps, _ = self.forward(model, batch, prefix="kl_")
@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):
@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes log probabilities of the reference model.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute log probabilities of the reference model."""
if self.ref_model is None:
ref_model = model
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
batch: dict[str, "torch.Tensor"],
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Subclass and override to accept extra kwargs.
"""
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Subclass and override to accept extra kwargs."""
return super().compute_loss(model, inputs, return_outputs)
@override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r"""Log `logs` on the various objects watching training, including stored metrics."""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
prefix = "eval_" if train_eval == "eval" else ""
@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
if f"count/{split}" in metric_dict:
for key in ("rewards", "logps", "logits"):

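This file follows the pattern applied across the whole commit: typing-module generics (`Dict`, `List`, `Tuple`) are replaced by the built-in generics standardized in PEP 585, which are valid as runtime annotations only on Python 3.9 and later. A minimal before/after sketch with made-up names, just to illustrate the rewrite:

```python
# Before: Python 3.8-style aliases from the typing module.
from typing import Dict, List, Optional, Tuple

def split_batch(batch: Dict[str, List[int]], size: int) -> Tuple[List[int], Optional[int]]:
    ...

# After: PEP 585 built-in generics, runtime-valid on Python >= 3.9.
from typing import Optional

def split_batch(batch: dict[str, list[int]], size: int) -> tuple[list[int], Optional[int]]:
    ...
```

`Optional` and `Union` still come from `typing` because the `X | None` union syntax (PEP 604) would additionally require Python 3.10 or a `from __future__ import annotations` import.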
View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
@ -37,7 +37,7 @@ def run_kto(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

View File

@ -14,7 +14,7 @@
import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
from typing import TYPE_CHECKING, Literal, Optional
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
@ -31,10 +31,8 @@ if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
r"""
Gets reward scores from the API server.
"""
def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
r"""Get reward scores from the API server."""
headers = {"Content-Type": "application/json"}
payload = {"model": "model", "messages": messages}
response = requests.post(server_url, json=payload, headers=headers)
@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
r"""
Replaces the default/reward modules in the model. The model is already unwrapped.
"""
r"""Replace the default/reward modules in the model. The model is already unwrapped."""
v_head_layer = model.v_head.summary
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
r"""
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
layer_norm_params = {}
for name, param in model.named_parameters():
if param.data.dtype == torch.float32:
@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
return layer_norm_params
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
r"""
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
for name, param in model.named_parameters():
if name in layernorm_params:
param.data = layernorm_params[name]

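The `dump_layernorm`/`restore_layernorm` pair exists because layernorm weights may be upcast to float32 for training stability, while generation is easier to run in a single half-precision dtype: float32 copies are stashed before generation and written back afterwards. A rough sketch of that round trip, using hypothetical helper names; the in-place downcast before generation is an assumption about the part of the function not shown in this hunk:

```python
import torch

def dump_fp32_params(model: torch.nn.Module, gen_dtype: torch.dtype = torch.float16) -> dict[str, torch.Tensor]:
    # Stash every float32 parameter (typically the upcast layernorms) and cast it
    # down in place so generation runs in one dtype (the downcast is assumed here).
    saved = {}
    for name, param in model.named_parameters():
        if param.data.dtype == torch.float32:
            saved[name] = param.data.detach().clone()
            param.data = param.data.to(gen_dtype)
    return saved

def restore_fp32_params(model: torch.nn.Module, saved: dict[str, torch.Tensor]) -> None:
    # Write the saved float32 copies back once generation has finished.
    for name, param in model.named_parameters():
        if name in saved:
            param.data = saved[name]
```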
View File

@ -20,7 +20,7 @@ import os
import sys
import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional
import torch
from accelerate.utils import DistributedDataParallelKwargs
@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)
class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Inherits PPOTrainer.
"""
r"""Inherit PPOTrainer."""
def __init__(
self,
@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]],
callbacks: Optional[list["TrainerCallback"]],
model: "AutoModelForCausalLMWithValueHead",
reward_model: Optional["AutoModelForCausalLMWithValueHead"],
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.add_callback(BAdamCallback)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
if resume_from_checkpoint is not None:
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
logger.info_rank0(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
total_train_batch_size
)
f" Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
)
logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
@ -339,21 +333,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return lr_scheduler
@torch.no_grad()
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
r"""
Generates model's responses given queries.
"""
def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
r"""Generate model's responses given queries."""
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items():
batch[k] = v[:, start_index:]
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model)
generate_output: "torch.Tensor" = unwrapped_model.generate(
generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
@ -381,11 +373,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@torch.no_grad()
def get_rewards(
self,
queries: List["torch.Tensor"],
responses: List["torch.Tensor"],
) -> List["torch.Tensor"]:
r"""
Computes scores using given reward model.
queries: list["torch.Tensor"],
responses: list["torch.Tensor"],
) -> list["torch.Tensor"]:
r"""Compute scores using given reward model.
Both inputs and outputs are put on CPU.
"""
@ -394,8 +385,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
return get_rewards_from_server(self.reward_model, messages)
batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward")
@ -404,7 +395,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model = self.reward_model
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]
values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
@ -419,12 +410,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
model: "AutoModelForCausalLMWithValueHead",
queries: "torch.Tensor",
responses: "torch.Tensor",
model_inputs: Dict[str, Any],
model_inputs: dict[str, Any],
return_logits: bool = False,
response_masks: Optional["torch.Tensor"] = None,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r"""
Calculates model outputs in multiple batches.
) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r"""Calculate model outputs in multiple batches.
Subclass and override to inject custom behavior.
"""
@ -483,8 +473,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@override
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
r"""Save model checkpoint.
Subclass and override to inject custom behavior.
"""
@ -508,5 +497,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.save_checkpoint(output_dir)
elif self.args.should_save:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
self._save(output_dir, state_dict=unwrapped_model.state_dict())

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
@ -37,7 +37,7 @@ def run_ppo(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
@ -53,7 +53,7 @@ def run_ppo(
reward_model = create_reward_model(model, model_args, finetuning_args)
# Initialize our Trainer
ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
ppo_trainer: CustomPPOTrainer = CustomPPOTrainer(
model_args=model_args,
training_args=training_args,
finetuning_args=finetuning_args,

View File

@ -31,9 +31,7 @@ if TYPE_CHECKING:
class CustomTrainer(Trainer):
r"""
Inherits Trainer for custom optimizer.
"""
r"""Inherit Trainer for custom optimizer."""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs

View File

@ -16,7 +16,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from transformers import DataCollatorForLanguageModeling
@ -38,7 +38,7 @@ def run_pt(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

View File

@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Optional
import numpy as np
@ -26,11 +26,9 @@ if TYPE_CHECKING:
@dataclass
class ComputeAccuracy:
r"""
Computes reward accuracy and supports `batch_eval_metrics`.
"""
r"""Compute reward accuracy and support `batch_eval_metrics`."""
def _dump(self) -> Optional[Dict[str, float]]:
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@ -41,7 +39,7 @@ class ComputeAccuracy:
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
if not chosen_scores.shape:
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)

View File

@ -18,7 +18,7 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
from transformers import Trainer
@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)
class PairwiseTrainer(Trainer):
r"""
Inherits Trainer to compute pairwise loss.
"""
r"""Inherits Trainer to compute pairwise loss."""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.
Subclass and override to inject custom behavior.
@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
return loss
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.
r"""Save model predictions to `output_dir`.
A custom behavior not contained in Seq2SeqTrainer.
"""
@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
res: list[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))

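`compute_loss` relies on the batch layout described in the docstring: the first n examples are the chosen responses and the last n the rejected ones, and the objective is the negative log-sigmoid of the score margin. A minimal sketch of that objective, assuming per-example scalar scores have already been pulled out of the value head:

```python
import torch
import torch.nn.functional as F

def pairwise_loss(scores: torch.Tensor) -> torch.Tensor:
    # scores: shape (2n,); the first n entries are chosen, the last n are rejected.
    n = scores.size(0) // 2
    chosen, rejected = scores[:n], scores[n:]
    return -F.logsigmoid(chosen - rejected).mean()

# Chosen scores well above rejected scores give a small loss:
loss = pairwise_loss(torch.tensor([2.0, 1.5, -0.5, -1.0]))
```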
View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
@ -37,7 +37,7 @@ def run_rm(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

View File

@ -17,7 +17,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Optional
import numpy as np
import torch
@ -45,9 +45,7 @@ if is_rouge_available():
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
r"""
Computes the token with the largest likelihood to reduce memory footprint.
"""
r"""Compute the token with the largest likelihood to reduce memory footprint."""
if isinstance(logits, (list, tuple)):
if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
logits = logits[0]
@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
@dataclass
class ComputeAccuracy:
r"""
Computes accuracy and supports `batch_eval_metrics`.
"""
r"""Compute accuracy and support `batch_eval_metrics`."""
def _dump(self) -> Optional[Dict[str, float]]:
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@ -77,7 +73,7 @@ class ComputeAccuracy:
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
for i in range(len(preds)):
pred, label = preds[i, :-1], labels[i, 1:]
@ -90,15 +86,14 @@ class ComputeAccuracy:
@dataclass
class ComputeSimilarity:
r"""
Computes text similarity scores and supports `batch_eval_metrics`.
r"""Compute text similarity scores and support `batch_eval_metrics`.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
"""
tokenizer: "PreTrainedTokenizer"
def _dump(self) -> Optional[Dict[str, float]]:
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@ -109,7 +104,7 @@ class ComputeSimilarity:
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)

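`eval_logit_processor` collapses the full vocabulary logits to argmax token ids before metrics are gathered, and `ComputeAccuracy` then compares shifted predictions against shifted labels while skipping padded positions. A standalone sketch of that masked, next-token accuracy; `IGNORE_INDEX = -100` is an assumption matching the usual Hugging Face label-padding convention:

```python
import numpy as np

IGNORE_INDEX = -100  # assumed padding value for labels

def shifted_token_accuracy(preds: np.ndarray, labels: np.ndarray) -> float:
    # preds/labels: (batch, seq_len) integer token ids; labels use IGNORE_INDEX on masked positions.
    accs = []
    for pred, label in zip(preds[:, :-1], labels[:, 1:]):  # align prediction t with label t+1
        mask = label != IGNORE_INDEX
        if mask.any():
            accs.append(float(np.mean(pred[mask] == label[mask])))
    return float(np.mean(accs)) if accs else 0.0
```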
View File

@ -18,7 +18,7 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
import torch
@ -44,21 +44,19 @@ logger = logging.get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""
r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""
def __init__(
self,
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
gen_kwargs: Optional[Dict[str, Any]] = None,
gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
@ -99,13 +97,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def prediction_step(
self,
model: "torch.nn.Module",
inputs: Dict[str, Union["torch.Tensor", Any]],
inputs: dict[str, Union["torch.Tensor", Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
ignore_keys: Optional[list[str]] = None,
**gen_kwargs,
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Removes the prompt part in the generated tokens.
) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""Remove the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
"""
@ -126,8 +123,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def save_predictions(
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
) -> None:
r"""
Saves model predictions to `output_dir`.
r"""Save model predictions to `output_dir`.
A custom behavior not contained in Seq2SeqTrainer.
"""

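For decoder-only models, `generate()` returns the prompt concatenated with the completion, so `prediction_step` trims the prompt before the predictions are decoded and saved. A stripped-down sketch of that idea (the real method also handles padding and label lengths, which are omitted here):

```python
import torch

def strip_prompt(generated: torch.Tensor, prompt_len: int) -> torch.Tensor:
    # generated: (batch, prompt_len + new_tokens) ids from a decoder-only model's generate();
    # keep only the newly generated tokens.
    return generated[:, prompt_len:].contiguous()

# Usage sketch: response_ids = strip_prompt(output_ids, inputs["input_ids"].size(-1))
```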
View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
@ -43,7 +43,7 @@ def run_sft(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional, Union
import torch
from peft import PeftModel
@ -43,7 +44,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
linear_modules, extra_modules = set(), set()
for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]):
@ -83,7 +84,7 @@ def load_reference_model(
) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device()
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=current_device
)
if not is_trainable:
@ -111,7 +112,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":
def patch_valuehead_model() -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None:
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
self.v_head.load_state_dict(state_dict, strict=False)
del state_dict

View File

@ -21,7 +21,7 @@ import json
import os
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from transformers import Trainer
@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer):
r"""
A dummy optimizer used for the GaLore or APOLLO algorithm.
"""
r"""A dummy optimizer used for the GaLore or APOLLO algorithm."""
def __init__(
self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
) -> None:
dummy_tensor = torch.randn(1, 1)
self.optimizer_dict = optimizer_dict
@ -112,8 +110,7 @@ def create_modelcard_and_push(
def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
r"""Create reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
@ -148,9 +145,7 @@ def create_ref_model(
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> Optional["AutoModelForCausalLMWithValueHead"]:
r"""
Creates reward model for PPO training.
"""
r"""Create reward model for PPO training."""
if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
@ -189,10 +184,8 @@ def create_reward_model(
return reward_model
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
"""
def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]:
r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
return decay_parameters
@ -208,7 +201,7 @@ def _create_galore_optimizer(
else:
galore_targets = finetuning_args.galore_target
galore_params: List["torch.nn.Parameter"] = []
galore_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
for param in module.parameters():
@ -224,7 +217,7 @@ def _create_galore_optimizer(
id_galore_params = {id(param) for param in galore_params}
decay_params, nodecay_params = [], [] # they are non-galore parameters
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
trainable_params: list[torch.nn.Parameter] = [] # galore_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
@ -251,7 +244,7 @@ def _create_galore_optimizer(
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
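The per-layer GaLore path builds one small optimizer per parameter (the `optimizer_dict`) and steps each of them as soon as that parameter's gradient is ready, which is also why it cannot be combined with gradient accumulation; the `DummyOptimizer` defined above is then only a placeholder for the Trainer. The wiring below is a sketch under two assumptions not shown in this hunk: that the updates are triggered from per-parameter gradient hooks, and that PyTorch >= 2.1 is available for `register_post_accumulate_grad_hook`:

```python
import torch

def attach_layerwise_optimizers(params: list[torch.nn.Parameter], lr: float = 1e-3) -> dict:
    # One tiny optimizer per trainable parameter.
    optimizer_dict: dict[torch.nn.Parameter, torch.optim.Optimizer] = {
        p: torch.optim.AdamW([p], lr=lr) for p in params if p.requires_grad
    }

    def make_hook(p: torch.nn.Parameter):
        def hook(_param: torch.Tensor) -> None:
            # Step and reset this parameter's optimizer right after its gradient is accumulated.
            optimizer_dict[p].step()
            optimizer_dict[p].zero_grad()
        return hook

    for p in optimizer_dict:
        p.register_post_accumulate_grad_hook(make_hook(p))  # assumes PyTorch >= 2.1
    return optimizer_dict
```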
@ -296,7 +289,7 @@ def _create_apollo_optimizer(
else:
apollo_targets = finetuning_args.apollo_target
apollo_params: List["torch.nn.Parameter"] = []
apollo_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
for param in module.parameters():
@ -315,7 +308,7 @@ def _create_apollo_optimizer(
id_apollo_params = {id(param) for param in apollo_params}
decay_params, nodecay_params = [], [] # they are non-apollo parameters
trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params
trainable_params: list[torch.nn.Parameter] = [] # apollo_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
@ -338,7 +331,7 @@ def _create_apollo_optimizer(
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
embedding_lr = finetuning_args.loraplus_lr_embedding
decay_param_names = _get_decay_parameter_names(model)
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
param_dict: dict[str, list[torch.nn.Parameter]] = {
"lora_a": [],
"lora_b": [],
"lora_b_nodecay": [],
@ -524,7 +517,7 @@ def create_custom_scheduler(
) -> None:
if optimizer is not None and isinstance(optimizer, DummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
@ -544,13 +537,13 @@ def create_custom_scheduler(
def get_batch_logps(
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
) -> Tuple["torch.Tensor", "torch.Tensor"]:
r"""
Computes the log probabilities of the given labels under the given logits.
) -> tuple["torch.Tensor", "torch.Tensor"]:
r"""Compute the log probabilities of the given labels under the given logits.
Returns:
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
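`get_batch_logps` sums the per-token log-probabilities of the labels under the logits and also returns the number of unmasked label positions. A self-contained sketch of that computation, with the next-token shift made explicit and `IGNORE_INDEX = -100` assumed as the padding value:

```python
import torch

IGNORE_INDEX = -100  # assumed label padding value

def batch_logps(logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # logits: (batch, seq_len, vocab); labels: (batch, seq_len) with IGNORE_INDEX on masked positions.
    labels = labels[:, 1:].clone()   # token t is predicted from positions < t
    logits = logits[:, :-1, :]
    mask = labels != IGNORE_INDEX
    labels[~mask] = 0                # any valid index; masked out below
    per_token = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return (per_token * mask).sum(-1), mask.sum(-1)
```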
@ -564,12 +557,10 @@ def get_batch_logps(
def nested_detach(
tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False,
):
r"""
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
"""
r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
elif isinstance(tensors, Mapping):
@ -585,9 +576,7 @@ def nested_detach(
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
r"""
Gets the callback for logging to SwanLab.
"""
r"""Get the callback for logging to SwanLab."""
import swanlab # type: ignore
from swanlab.integration.transformers import SwanLabCallback # type: ignore
@ -624,7 +613,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
def get_ray_trainer(
training_function: Callable,
train_loop_config: Dict[str, Any],
train_loop_config: dict[str, Any],
ray_args: "RayArguments",
) -> "TorchTrainer":
if not ray_args.use_ray:

View File

@ -14,7 +14,7 @@
import os
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
import torch.distributed as dist
@ -48,9 +48,9 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def _training_function(config: Dict[str, Any]) -> None:
def _training_function(config: dict[str, Any]) -> None:
args = config.get("args")
callbacks: List[Any] = config.get("callbacks")
callbacks: list[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks.append(LogCallback())
@ -84,7 +84,7 @@ def _training_function(config: Dict[str, Any]) -> None:
logger.warning(f"Failed to destroy process group: {e}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
args = read_args(args)
if "-h" in args or "--help" in args:
get_train_args(args)
@ -103,7 +103,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
_training_function(config={"args": args, "callbacks": callbacks})
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
def export_model(args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, _ = get_infer_args(args)
if model_args.export_dir is None:

View File

@ -14,7 +14,8 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Optional
from transformers.utils import is_torch_npu_available
@ -37,15 +38,12 @@ if is_gradio_available():
def _escape_html(text: str) -> str:
r"""
Escapes HTML characters.
"""
r"""Escape HTML characters."""
return text.replace("<", "&lt;").replace(">", "&gt;")
def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
r"""
Post-processes the response text.
def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
r"""Post-process the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
@ -74,7 +72,7 @@ class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.engine: Optional["BaseEngine"] = None
self.engine: Optional[BaseEngine] = None
if not lazy_init: # read arguments from command line
super().__init__()
@ -160,14 +158,13 @@ class WebChatModel(ChatModel):
@staticmethod
def append(
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
role: str,
query: str,
escape_html: bool,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
r"""
Adds the user input to chatbot.
) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
r"""Add the user input to chatbot.
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Output: infer.chatbot, infer.messages, infer.query
@ -180,8 +177,8 @@ class WebChatModel(ChatModel):
def stream(
self,
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
lang: str,
system: str,
tools: str,
@ -193,9 +190,8 @@ class WebChatModel(ChatModel):
temperature: float,
skip_special_tokens: bool,
escape_html: bool,
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
r"""
Generates output text in stream.
) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
r"""Generate output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages

View File

@ -17,7 +17,7 @@ import os
import signal
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from psutil import Process
from yaml import safe_dump, safe_load
@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"
def abort_process(pid: int) -> None:
r"""
Aborts the processes recursively in a bottom-up way.
"""
r"""Abort the processes recursively in a bottom-up way."""
try:
children = Process(pid).children()
if children:
@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:
def get_save_dir(*paths: str) -> os.PathLike:
r"""
Gets the path to saved model checkpoints.
"""
r"""Get the path to saved model checkpoints."""
if os.path.sep in paths[-1]:
logger.warning_rank0("Found complex path, some features may not be available.")
return paths[-1]
@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:
def _get_config_path() -> os.PathLike:
r"""
Gets the path to user config.
"""
r"""Get the path to user config."""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
r"""
Loads user config if exists.
"""
def load_config() -> dict[str, Union[str, dict[str, Any]]]:
r"""Load user config if exists."""
try:
with open(_get_config_path(), encoding="utf-8") as f:
return safe_load(f)
@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
r"""
Saves user config.
"""
r"""Save user config."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
def get_model_path(model_name: str) -> str:
r"""
Gets the model path according to the model name.
"""
r"""Get the model path according to the model name."""
user_config = load_config()
path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
if (
use_modelscope()
@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str:
def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat/distill/instruct model.
"""
r"""Get the template name if the model is a chat/distill/instruct model."""
return DEFAULT_TEMPLATE.get(model_name, "default")
def get_time() -> str:
r"""
Gets current date and time.
"""
r"""Get current date and time."""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def is_multimodal(model_name: str) -> bool:
r"""
Judges if the model is a vision language model.
"""
r"""Judge if the model is a vision language model."""
return model_name in MULTIMODAL_SUPPORTED_MODELS
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
r"""
Loads dataset_info.json.
"""
def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
r"""Load dataset_info.json."""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
return {}
@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {}
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Loads the training configuration from config path.
"""
def load_args(config_path: str) -> Optional[dict[str, Any]]:
r"""Load the training configuration from config path."""
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
r"""
Saves the training configuration to config path.
"""
def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
r"""Save the training configuration to config path."""
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args with NoneType or False or empty string value.
"""
def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
r"""Remove args with NoneType or False or empty string value."""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates CLI commands for previewing.
"""
def gen_cmd(args: dict[str, Any]) -> str:
r"""Generate CLI commands for previewing."""
cmd_lines = ["llamafactory-cli train "]
for k, v in _clean_cmd(args).items():
if isinstance(v, dict):
@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
return cmd_text
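`_clean_cmd` drops arguments whose value is `None`, `False`, or an empty string (with a small keep-list such as `packing`), and `gen_cmd` turns the remainder into a `llamafactory-cli train` command for the preview box. A rough sketch of that formatting step; the exact quoting and line joining are assumptions, not copied from the repository:

```python
import json
from typing import Any

def preview_cmd(args: dict[str, Any]) -> str:
    cleaned = {
        k: v for k, v in args.items()
        if k == "packing" or (v is not None and v is not False and v != "")
    }
    lines = ["llamafactory-cli train "]
    for key, value in cleaned.items():
        if isinstance(value, dict):  # nested configs serialized to JSON (assumed)
            value = json.dumps(value, ensure_ascii=False)
        lines.append(f"    --{key} {value} ")
    return "\\\n".join(lines)

# Example: preview_cmd({"stage": "sft", "do_train": True, "output_dir": "saves/test", "adapter": None})
```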
def save_cmd(args: Dict[str, Any]) -> str:
r"""
Saves CLI commands to launch training.
"""
def save_cmd(args: dict[str, Any]) -> str:
r"""Save CLI commands to launch training."""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:
def load_eval_results(path: os.PathLike) -> str:
r"""
Gets scores after evaluation.
"""
r"""Get scores after evaluation."""
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:
def create_ds_config() -> None:
r"""
Creates deepspeed config in the current directory.
"""
r"""Create deepspeed config in the current directory."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",

View File

@ -13,7 +13,7 @@
# limitations under the License.
import json
from typing import TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING
from ...data import Role
from ...extras.packages import is_gradio_available
@ -31,9 +31,7 @@ if TYPE_CHECKING:
def check_json_schema(text: str, lang: str) -> None:
r"""
Checks if the json schema is valid.
"""
r"""Check if the json schema is valid."""
try:
tools = json.loads(text)
if tools:
@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None:
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
) -> tuple["Component", "Component", dict[str, "Component"]]:
lang = engine.manager.get_elem_by_id("top.lang")
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(type="messages", show_copy_button=True)

View File

@ -14,7 +14,7 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing import TYPE_CHECKING, Any
from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available
@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
r"""
Checks if the dataset is a local dataset.
"""
r"""Check if the dataset is a local dataset."""
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
@ -59,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
return gr.Button(interactive=False)
def _load_data_file(file_path: str) -> List[Any]:
def _load_data_file(file_path: str) -> list[Any]:
with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]:
return list(f)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
r"""
Gets the preview samples from the dataset.
"""
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> tuple[int, list, "gr.Column"]:
r"""Get the preview samples from the dataset."""
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> dict[str, "Component"]:
data_preview_btn = gr.Button(interactive=False, scale=1)
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR
@ -30,7 +30,7 @@ if TYPE_CHECKING:
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_eval_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Generator, List, Union
from collections.abc import Generator
from typing import TYPE_CHECKING, Union
from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
@ -35,7 +36,7 @@ if TYPE_CHECKING:
GPTQ_BITS = ["8", "4", "3", "2"]
def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False)
else:
@ -47,7 +48,7 @@ def save_model(
model_name: str,
model_path: str,
finetuning_type: str,
checkpoint_path: Union[str, List[str]],
checkpoint_path: Union[str, list[str]],
template: str,
export_size: int,
export_quantization_bit: str,
@ -106,7 +107,7 @@ def save_model(
yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from ...extras.packages import is_gradio_available
from ..common import is_multimodal
@ -29,7 +29,7 @@ if TYPE_CHECKING:
from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
@ -29,7 +29,7 @@ if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, "Component"]:
def create_top() -> dict[str, "Component"]:
with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

Some files were not shown because too many files have changed in this diff.