mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-07-31 10:42:50 +08:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
parent
cdafa8a15e
commit
7c1640ed5f
@ -10,7 +10,7 @@ _DESCRIPTION = "BELLE multiturn chat dataset."
|
||||
|
||||
_CITATION = """\
|
||||
@article{belle2023exploring,
|
||||
title={Exploring the Impact of Instruction Data Scaling on Large Language Models: An Empirical Study on Real-World Use Cases},
|
||||
title={Exploring the Impact of Instruction Data Scaling on Large Language Models},
|
||||
author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li},
|
||||
journal={arXiv preprint arXiv:2303.14742},
|
||||
year={2023}
|
||||
|
@ -1,6 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import datasets
|
||||
|
||||
@ -50,7 +49,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
def _generate_examples(self, filepaths: list[str]):
|
||||
key = 0
|
||||
for filepath in filepaths:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
|
@ -1,6 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import datasets
|
||||
|
||||
@ -11,7 +10,7 @@ _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dia
|
||||
|
||||
_CITATION = """\
|
||||
@misc{UltraChat,
|
||||
author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
|
||||
author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and others},
|
||||
title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
|
||||
year = {2023},
|
||||
publisher = {GitHub},
|
||||
@ -40,7 +39,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
|
||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
def _generate_examples(self, filepaths: list[str]):
|
||||
for filepath in filepaths:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
for row in f:
|
||||
@ -49,7 +48,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
except Exception:
|
||||
continue
|
||||
key: int = data["id"]
|
||||
content: List[str] = data["data"]
|
||||
content: list[str] = data["data"]
|
||||
if len(content) % 2 == 1:
|
||||
content.pop(-1)
|
||||
if len(content) < 2:
|
||||
|
@ -21,14 +21,15 @@ import pandas as pd
|
||||
_CITATION = """\
|
||||
@article{huang2023ceval,
|
||||
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
|
||||
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
|
||||
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and others},
|
||||
journal={arXiv preprint arXiv:2305.08322},
|
||||
year={2023}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
|
||||
C-Eval is a comprehensive Chinese evaluation suite for foundation models.
|
||||
It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://cevalbenchmark.com"
|
||||
|
@ -21,14 +21,15 @@ import pandas as pd
|
||||
_CITATION = """\
|
||||
@article{li2023cmmlu,
|
||||
title={CMMLU: Measuring massive multitask language understanding in Chinese},
|
||||
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
|
||||
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and others,
|
||||
journal={arXiv preprint arXiv:2306.09212},
|
||||
year={2023}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
|
||||
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge
|
||||
and reasoning abilities of LLMs within the Chinese language and cultural context.
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://github.com/haonan-li/CMMLU"
|
||||
|
@ -21,14 +21,15 @@ import pandas as pd
|
||||
_CITATION = """\
|
||||
@article{hendryckstest2021,
|
||||
title={Measuring Massive Multitask Language Understanding},
|
||||
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
|
||||
author={Dan Hendrycks and Collin Burns and others},
|
||||
journal={Proceedings of the International Conference on Learning Representations (ICLR)},
|
||||
year={2021}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
|
||||
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart,
|
||||
Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://github.com/hendrycks/test"
|
||||
|
@ -19,13 +19,35 @@ dynamic = [
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py38"
|
||||
target-version = "py39"
|
||||
line-length = 119
|
||||
indent-width = 4
|
||||
|
||||
[tool.ruff.lint]
|
||||
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
|
||||
select = ["C", "E", "F", "I", "W"]
|
||||
ignore = [
|
||||
"C408", # collection
|
||||
"C901", # complex
|
||||
"E731", # lambda function
|
||||
"E741", # ambiguous var name
|
||||
"D100", # no doc public module
|
||||
"D101", # no doc public class
|
||||
"D102", # no doc public method
|
||||
"D103", # no doc public function
|
||||
"D104", # no doc public package
|
||||
"D105", # no doc magic method
|
||||
"D107", # no doc __init__
|
||||
]
|
||||
extend-select = [
|
||||
"C", # complexity
|
||||
"E", # error
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"W", # warning
|
||||
"UP", # pyupgrade
|
||||
"D", # pydocstyle
|
||||
"PT009", # pytest assert
|
||||
"RUF022", # sort __all__
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
lines-after-imports = 2
|
||||
@ -41,6 +63,9 @@ known-third-party = [
|
||||
"trl"
|
||||
]
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Sequence
|
||||
from collections.abc import Sequence
|
||||
|
||||
from openai import OpenAI
|
||||
from transformers.utils.versions import require_version
|
||||
|
@ -15,7 +15,7 @@
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@ -29,13 +29,13 @@ CONFIG_NAME = "config.json"
|
||||
|
||||
|
||||
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
|
||||
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
|
||||
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
|
||||
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
|
||||
baichuan2_state_dict.update(shard_weight)
|
||||
|
||||
llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
|
||||
for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
|
||||
if "W_pack" in key:
|
||||
proj_size = value.size(0) // 3
|
||||
@ -75,7 +75,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
||||
|
||||
def save_config(input_dir: str, output_dir: str):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||
llama2_config_dict: dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||
llama2_config_dict.pop("auto_map", None)
|
||||
@ -94,8 +94,8 @@ def llamafy_baichuan2(
|
||||
shard_size: str = "2GB",
|
||||
save_safetensors: bool = True,
|
||||
):
|
||||
r"""
|
||||
Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
||||
r"""Convert the Baichuan2-7B model in the same format as LLaMA2-7B.
|
||||
|
||||
Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
|
||||
Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
||||
"""
|
||||
|
@ -15,7 +15,7 @@
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@ -37,14 +37,14 @@ CONFIG_NAME = "config.json"
|
||||
|
||||
|
||||
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
|
||||
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
qwen_state_dict: dict[str, torch.Tensor] = OrderedDict()
|
||||
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
|
||||
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
qwen_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
|
||||
torch_dtype = None
|
||||
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
|
||||
if torch_dtype is None:
|
||||
@ -112,9 +112,9 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
||||
|
||||
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||
qwen_config_dict: Dict[str, Any] = json.load(f)
|
||||
qwen_config_dict: dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict: Dict[str, Any] = OrderedDict()
|
||||
llama2_config_dict: dict[str, Any] = OrderedDict()
|
||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||
llama2_config_dict["hidden_act"] = "silu"
|
||||
llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
|
||||
@ -147,8 +147,8 @@ def llamafy_qwen(
|
||||
shard_size: str = "2GB",
|
||||
save_safetensors: bool = False,
|
||||
):
|
||||
r"""
|
||||
Converts the Qwen models in the same format as LLaMA2.
|
||||
r"""Convert the Qwen models in the same format as LLaMA2.
|
||||
|
||||
Usage: python llamafy_qwen.py --input_dir input --output_dir output
|
||||
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
||||
"""
|
||||
|
@ -18,7 +18,7 @@
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@ -44,11 +44,11 @@ def block_expansion(
|
||||
shard_size: str = "5GB",
|
||||
save_safetensors: bool = True,
|
||||
):
|
||||
r"""
|
||||
Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
|
||||
r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.
|
||||
|
||||
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
||||
"""
|
||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
num_layers = getattr(config, "num_hidden_layers")
|
||||
if num_layers % num_expand != 0:
|
||||
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
|
||||
@ -70,7 +70,7 @@ def block_expansion(
|
||||
split = num_layers // num_expand
|
||||
layer_cnt = 0
|
||||
state_dict = model.state_dict()
|
||||
output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
|
||||
output_state_dict: dict[str, torch.Tensor] = OrderedDict()
|
||||
for i in range(num_layers):
|
||||
for key, value in state_dict.items():
|
||||
if f".{i:d}." in key:
|
||||
|
@ -38,8 +38,8 @@ def quantize_loftq(
|
||||
lora_target: tuple = ("q_proj", "v_proj"),
|
||||
save_safetensors: bool = True,
|
||||
):
|
||||
r"""
|
||||
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
||||
r"""Initialize LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
|
||||
|
||||
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||
"""
|
||||
if isinstance(lora_target, str):
|
||||
@ -72,7 +72,7 @@ def quantize_loftq(
|
||||
print(f"Adapter weights saved in {loftq_dir}")
|
||||
|
||||
# Save base model
|
||||
base_model: "PreTrainedModel" = peft_model.unload()
|
||||
base_model: PreTrainedModel = peft_model.unload()
|
||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
@ -37,8 +37,8 @@ def quantize_pissa(
|
||||
lora_target: tuple = ("q_proj", "v_proj"),
|
||||
save_safetensors: bool = True,
|
||||
):
|
||||
r"""
|
||||
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
|
||||
r"""Initialize LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).
|
||||
|
||||
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||
"""
|
||||
if isinstance(lora_target, str):
|
||||
@ -67,7 +67,7 @@ def quantize_pissa(
|
||||
print(f"Adapter weights saved in {pissa_dir}")
|
||||
|
||||
# Save base model
|
||||
base_model: "PreTrainedModel" = peft_model.unload()
|
||||
base_model: PreTrainedModel = peft_model.unload()
|
||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
@ -29,8 +29,8 @@ def calculate_flops(
|
||||
seq_length: int = 512,
|
||||
flash_attn: str = "auto",
|
||||
):
|
||||
r"""
|
||||
Calculates the flops of pre-trained models.
|
||||
r"""Calculate the flops of pre-trained models.
|
||||
|
||||
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||
"""
|
||||
with get_accelerator().device(0):
|
||||
|
@ -45,8 +45,8 @@ def calculate_lr(
|
||||
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
|
||||
packing: bool = False,
|
||||
):
|
||||
r"""
|
||||
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||
r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||
|
||||
Usage:
|
||||
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
|
||||
"""
|
||||
@ -89,9 +89,8 @@ def calculate_lr(
|
||||
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size)
|
||||
lr = lr / 6.0 if is_mistral_or_gemma else lr
|
||||
print(
|
||||
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
|
||||
lr, valid_ratio * 100, token_batch_size
|
||||
)
|
||||
f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
|
||||
f"and effective token batch size {token_batch_size:.2f}"
|
||||
)
|
||||
|
||||
|
||||
|
@ -34,9 +34,7 @@ def compute_model_flops(
|
||||
include_recompute: bool = False,
|
||||
include_flashattn: bool = False,
|
||||
) -> int:
|
||||
r"""
|
||||
Calculates the FLOPs of model per forward/backward pass.
|
||||
"""
|
||||
r"""Calculate the FLOPs of model per forward/backward pass."""
|
||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
hidden_size = getattr(config, "hidden_size", None)
|
||||
vocab_size = getattr(config, "vocab_size", None)
|
||||
@ -86,9 +84,7 @@ def compute_model_flops(
|
||||
|
||||
|
||||
def compute_device_flops(world_size: int) -> float:
|
||||
r"""
|
||||
Calculates the FLOPs of the device capability per second.
|
||||
"""
|
||||
r"""Calculate the FLOPs of the device capability per second."""
|
||||
device_name = torch.cuda.get_device_name()
|
||||
if "H100" in device_name or "H800" in device_name:
|
||||
return 989 * 1e12 * world_size
|
||||
@ -114,8 +110,8 @@ def calculate_mfu(
|
||||
liger_kernel: bool = False,
|
||||
unsloth_gc: bool = False,
|
||||
) -> float:
|
||||
r"""
|
||||
Calculates MFU for given model and hyper-params.
|
||||
r"""Calculate MFU for given model and hyper-params.
|
||||
|
||||
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
|
||||
"""
|
||||
args = {
|
||||
|
@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Literal, Optional, Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@ -30,16 +31,12 @@ from llamafactory.model import load_model, load_tokenizer
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
r"""Data collator for pairwise data."""
|
||||
|
||||
train_on_prompt: bool = False
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
"""
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, torch.Tensor]:
|
||||
r"""Pad batched data to the longest sequence in the batch."""
|
||||
chosen_features = []
|
||||
for feature in features:
|
||||
chosen_features.append(
|
||||
@ -68,8 +65,8 @@ def calculate_ppl(
|
||||
max_samples: Optional[int] = None,
|
||||
train_on_prompt: bool = False,
|
||||
):
|
||||
r"""
|
||||
Calculates the ppl on the dataset of the pre-trained models.
|
||||
r"""Calculate the ppl on the dataset of the pre-trained models.
|
||||
|
||||
Usage: export CUDA_VISIBLE_DEVICES=0
|
||||
python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
|
||||
"""
|
||||
@ -111,17 +108,17 @@ def calculate_ppl(
|
||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
total_ppl = 0
|
||||
perplexities = []
|
||||
batch: Dict[str, "torch.Tensor"]
|
||||
batch: dict[str, torch.Tensor]
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Computing perplexities"):
|
||||
batch = batch.to(model.device)
|
||||
outputs = model(**batch)
|
||||
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
|
||||
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
|
||||
shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
|
||||
shift_labels: torch.Tensor = batch["labels"][..., 1:]
|
||||
loss_mask = shift_labels != IGNORE_INDEX
|
||||
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
|
||||
flatten_labels = shift_labels.contiguous().view(-1)
|
||||
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
|
||||
token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
|
||||
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
|
||||
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
||||
total_ppl += sentence_logps.exp().sum().item()
|
||||
|
@ -29,8 +29,8 @@ def length_cdf(
|
||||
template: str = "default",
|
||||
interval: int = 1000,
|
||||
):
|
||||
r"""
|
||||
Calculates the distribution of the input lengths in the dataset.
|
||||
r"""Calculate the distribution of the input lengths in the dataset.
|
||||
|
||||
Usage: export CUDA_VISIBLE_DEVICES=0
|
||||
python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
|
||||
"""
|
||||
|
@ -52,8 +52,8 @@ def vllm_infer(
|
||||
image_max_pixels: int = 768 * 768,
|
||||
image_min_pixels: int = 32 * 32,
|
||||
):
|
||||
r"""
|
||||
Performs batch generation using vLLM engine, which supports tensor parallelism.
|
||||
r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
|
||||
|
||||
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
|
||||
"""
|
||||
check_version("vllm>=0.4.3,<=0.7.3")
|
||||
|
5
setup.py
5
setup.py
@ -14,7 +14,6 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
@ -27,14 +26,14 @@ def get_version() -> str:
|
||||
return version
|
||||
|
||||
|
||||
def get_requires() -> List[str]:
|
||||
def get_requires() -> list[str]:
|
||||
with open("requirements.txt", encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
||||
return lines
|
||||
|
||||
|
||||
def get_console_scripts() -> List[str]:
|
||||
def get_console_scripts() -> list[str]:
|
||||
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
|
||||
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
|
||||
console_scripts.append("lmf = llamafactory.cli:main")
|
||||
|
@ -12,8 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""
|
||||
Efficient fine-tuning of large language models.
|
||||
r"""Efficient fine-tuning of large language models.
|
||||
|
||||
Level:
|
||||
api, webui > chat, eval, train > data, model > hparams > extras
|
||||
|
@ -16,9 +16,7 @@ import asyncio
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import Annotated
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..extras.constants import EngineName
|
||||
|
@ -18,7 +18,8 @@ import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras import logging
|
||||
@ -71,7 +72,7 @@ ROLE_MAPPING = {
|
||||
|
||||
def _process_request(
|
||||
request: "ChatCompletionRequest",
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
|
||||
) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]:
|
||||
if is_env_enabled("API_VERBOSE", "1"):
|
||||
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
|
||||
|
||||
|
@ -13,14 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
def dictify(data: "BaseModel") -> dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import time
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
@ -45,7 +45,7 @@ class ModelCard(BaseModel):
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: Literal["list"] = "list"
|
||||
data: List[ModelCard] = []
|
||||
data: list[ModelCard] = []
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
@ -56,7 +56,7 @@ class Function(BaseModel):
|
||||
class FunctionDefinition(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any]
|
||||
parameters: dict[str, Any]
|
||||
|
||||
|
||||
class FunctionAvailable(BaseModel):
|
||||
@ -82,26 +82,26 @@ class MultimodalInputItem(BaseModel):
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Role
|
||||
content: Optional[Union[str, List[MultimodalInputItem]]] = None
|
||||
tool_calls: Optional[List[FunctionCall]] = None
|
||||
content: Optional[Union[str, list[MultimodalInputItem]]] = None
|
||||
tool_calls: Optional[list[FunctionCall]] = None
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: Optional[Role] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[FunctionCall]] = None
|
||||
tool_calls: Optional[list[FunctionCall]] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
tools: Optional[List[FunctionAvailable]] = None
|
||||
messages: list[ChatMessage]
|
||||
tools: Optional[list[FunctionAvailable]] = None
|
||||
do_sample: Optional[bool] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: int = 1
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop: Optional[Union[str, list[str]]] = None
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@ -128,7 +128,7 @@ class ChatCompletionResponse(BaseModel):
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
choices: list[ChatCompletionResponseChoice]
|
||||
usage: ChatCompletionResponseUsage
|
||||
|
||||
|
||||
@ -137,12 +137,12 @@ class ChatCompletionStreamResponse(BaseModel):
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionStreamResponseChoice]
|
||||
choices: list[ChatCompletionStreamResponseChoice]
|
||||
|
||||
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[str]
|
||||
messages: list[str]
|
||||
max_length: Optional[int] = None
|
||||
|
||||
|
||||
@ -150,4 +150,4 @@ class ScoreEvaluationResponse(BaseModel):
|
||||
id: str
|
||||
object: Literal["score.evaluation"] = "score.evaluation"
|
||||
model: str
|
||||
scores: List[float]
|
||||
scores: list[float]
|
||||
|
@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -36,8 +37,7 @@ class Response:
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
r"""
|
||||
Base class for inference engine of chat models.
|
||||
r"""Base class for inference engine of chat models.
|
||||
|
||||
Must implements async methods: chat(), stream_chat() and get_scores().
|
||||
"""
|
||||
@ -47,7 +47,7 @@ class BaseEngine(ABC):
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
can_generate: bool
|
||||
template: "Template"
|
||||
generating_args: Dict[str, Any]
|
||||
generating_args: dict[str, Any]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
@ -57,31 +57,27 @@ class BaseEngine(ABC):
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
r"""
|
||||
Initializes an inference engine.
|
||||
"""
|
||||
r"""Initialize an inference engine."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Gets a list of responses of the chat model.
|
||||
"""
|
||||
) -> list["Response"]:
|
||||
r"""Get a list of responses of the chat model."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@ -89,18 +85,14 @@ class BaseEngine(ABC):
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
Gets the response token-by-token of the chat model.
|
||||
"""
|
||||
r"""Get the response token-by-token of the chat model."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Gets a list of scores of the reward model.
|
||||
"""
|
||||
) -> list[float]:
|
||||
r"""Get a list of scores of the reward model."""
|
||||
...
|
||||
|
@ -17,8 +17,9 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator, Generator, Sequence
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ..extras.constants import EngineName
|
||||
from ..extras.misc import torch_gc
|
||||
@ -38,20 +39,19 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
|
||||
|
||||
|
||||
class ChatModel:
|
||||
r"""
|
||||
General class for chat models. Backed by huggingface or vllm engines.
|
||||
r"""General class for chat models. Backed by huggingface or vllm engines.
|
||||
|
||||
Supports both sync and async methods.
|
||||
Sync methods: chat(), stream_chat() and get_scores().
|
||||
Async methods: achat(), astream_chat() and aget_scores().
|
||||
"""
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
||||
if model_args.infer_backend == EngineName.HF:
|
||||
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
elif model_args.infer_backend == EngineName.VLLM:
|
||||
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
|
||||
|
||||
@ -61,17 +61,15 @@ class ChatModel:
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Gets a list of responses of the chat model.
|
||||
"""
|
||||
) -> list["Response"]:
|
||||
r"""Get a list of responses of the chat model."""
|
||||
task = asyncio.run_coroutine_threadsafe(
|
||||
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
|
||||
)
|
||||
@ -79,22 +77,20 @@ class ChatModel:
|
||||
|
||||
async def achat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Asynchronously gets a list of responses of the chat model.
|
||||
"""
|
||||
) -> list["Response"]:
|
||||
r"""Asynchronously get a list of responses of the chat model."""
|
||||
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@ -102,9 +98,7 @@ class ChatModel:
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
r"""
|
||||
Gets the response token-by-token of the chat model.
|
||||
"""
|
||||
r"""Get the response token-by-token of the chat model."""
|
||||
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
@ -115,7 +109,7 @@ class ChatModel:
|
||||
|
||||
async def astream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@ -123,9 +117,7 @@ class ChatModel:
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
Asynchronously gets the response token-by-token of the chat model.
|
||||
"""
|
||||
r"""Asynchronously get the response token-by-token of the chat model."""
|
||||
async for new_token in self.engine.stream_chat(
|
||||
messages, system, tools, images, videos, audios, **input_kwargs
|
||||
):
|
||||
@ -133,23 +125,19 @@ class ChatModel:
|
||||
|
||||
def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Gets a list of scores of the reward model.
|
||||
"""
|
||||
) -> list[float]:
|
||||
r"""Get a list of scores of the reward model."""
|
||||
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
|
||||
return task.result()
|
||||
|
||||
async def aget_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Asynchronously gets a list of scores of the reward model.
|
||||
"""
|
||||
) -> list[float]:
|
||||
r"""Asynchronously get a list of scores of the reward model."""
|
||||
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||
|
||||
|
||||
|
@ -15,8 +15,9 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import os
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
@ -76,15 +77,15 @@ class HuggingfaceEngine(BaseEngine):
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
generating_args: dict[str, Any],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
input_kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
|
||||
if images is not None:
|
||||
mm_input_dict.update({"images": images, "imglens": [len(images)]})
|
||||
@ -130,7 +131,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if stop is not None:
|
||||
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
|
||||
@ -217,15 +218,15 @@ class HuggingfaceEngine(BaseEngine):
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
generating_args: dict[str, Any],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List["Response"]:
|
||||
input_kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> list["Response"]:
|
||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||
model,
|
||||
tokenizer,
|
||||
@ -272,14 +273,14 @@ class HuggingfaceEngine(BaseEngine):
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
generating_args: dict[str, Any],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
input_kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> Callable[[], str]:
|
||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||
model,
|
||||
@ -317,12 +318,12 @@ class HuggingfaceEngine(BaseEngine):
|
||||
def _get_scores(
|
||||
model: "PreTrainedModelWrapper",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
batch_input: List[str],
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List[float]:
|
||||
batch_input: list[str],
|
||||
input_kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> list[float]:
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
device = getattr(model.pretrained_model, "device", "cuda")
|
||||
inputs: Dict[str, "torch.Tensor"] = tokenizer(
|
||||
inputs: dict[str, torch.Tensor] = tokenizer(
|
||||
batch_input,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
@ -330,21 +331,21 @@ class HuggingfaceEngine(BaseEngine):
|
||||
return_tensors="pt",
|
||||
add_special_tokens=False,
|
||||
).to(device)
|
||||
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
|
||||
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
|
||||
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
||||
return scores
|
||||
|
||||
@override
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
) -> list["Response"]:
|
||||
if not self.can_generate:
|
||||
raise ValueError("The current model does not support `chat`.")
|
||||
|
||||
@ -370,7 +371,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
@override
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@ -408,9 +409,9 @@ class HuggingfaceEngine(BaseEngine):
|
||||
@override
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
) -> list[float]:
|
||||
if self.can_generate:
|
||||
raise ValueError("Cannot get scores using an auto-regressive model.")
|
||||
|
||||
|
@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
|
||||
self.model_args = model_args
|
||||
config = load_config(model_args) # may download model from ms hub
|
||||
if getattr(config, "quantization_config", None): # gptq models should use float16
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
|
||||
model_args.infer_dtype = "float16"
|
||||
@ -101,7 +102,7 @@ class VllmEngine(BaseEngine):
|
||||
|
||||
async def _generate(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@ -143,7 +144,7 @@ class VllmEngine(BaseEngine):
|
||||
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if length_penalty is not None:
|
||||
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
|
||||
@ -201,14 +202,14 @@ class VllmEngine(BaseEngine):
|
||||
@override
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
) -> list["Response"]:
|
||||
final_output = None
|
||||
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
|
||||
async for request_output in generator:
|
||||
@ -230,7 +231,7 @@ class VllmEngine(BaseEngine):
|
||||
@override
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@ -248,7 +249,7 @@ class VllmEngine(BaseEngine):
|
||||
@override
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
) -> list[float]:
|
||||
raise NotImplementedError("vLLM engine does not support get_scores.")
|
||||
|
@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TEMPLATES",
|
||||
"KTODataCollatorWithPadding",
|
||||
"MultiModalDataCollatorForSeq2Seq",
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"SFTDataCollatorWith4DAttentionMask",
|
||||
"Role",
|
||||
"split_dataset",
|
||||
"get_dataset",
|
||||
"TEMPLATES",
|
||||
"SFTDataCollatorWith4DAttentionMask",
|
||||
"Template",
|
||||
"get_dataset",
|
||||
"get_template_and_fix_tokenizer",
|
||||
"split_dataset",
|
||||
]
|
||||
|
@ -15,8 +15,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -38,9 +39,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
||||
r"""
|
||||
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
|
||||
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
|
||||
r"""Expand 2d attention mask to 4d attention mask.
|
||||
|
||||
Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
|
||||
handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.
|
||||
|
||||
e.g.
|
||||
```python
|
||||
@ -78,8 +80,7 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
|
||||
|
||||
@dataclass
|
||||
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator that supports VLMs.
|
||||
r"""Data collator that supports VLMs.
|
||||
|
||||
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
|
||||
"""
|
||||
@ -91,7 +92,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if self.template is None:
|
||||
raise ValueError("Template is required for MultiModalDataCollator.")
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_audios = [], [], []
|
||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||
for feature in features:
|
||||
@ -166,7 +167,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
for i, feature in enumerate(features):
|
||||
feature["token_type_ids"] = token_type_ids[i]
|
||||
|
||||
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||
|
||||
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
|
||||
rope_index_kwargs = {
|
||||
@ -198,15 +199,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
@dataclass
|
||||
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for 4d attention mask.
|
||||
"""
|
||||
r"""Data collator for 4d attention mask."""
|
||||
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
@ -220,13 +219,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
r"""Data collator for pairwise data."""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
r"""Pad batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
@ -249,11 +245,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
r"""Data collator for KTO data."""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
|
@ -14,8 +14,9 @@
|
||||
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from ..extras import logging
|
||||
from .data_utils import Role
|
||||
@ -36,10 +37,8 @@ class DatasetConverter:
|
||||
dataset_attr: "DatasetAttr"
|
||||
data_args: "DataArguments"
|
||||
|
||||
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
|
||||
r"""
|
||||
Optionally concatenates media path to media dir when loading from local disk.
|
||||
"""
|
||||
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]:
|
||||
r"""Optionally concatenate media path to media dir when loading from local disk."""
|
||||
if not isinstance(medias, list):
|
||||
medias = [medias] if medias is not None else []
|
||||
elif len(medias) == 0:
|
||||
@ -57,16 +56,14 @@ class DatasetConverter:
|
||||
return medias
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
|
||||
r"""
|
||||
Converts a single example in the dataset to the standard format.
|
||||
"""
|
||||
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
|
||||
r"""Convert a single example in the dataset to the standard format."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlpacaDatasetConverter(DatasetConverter):
|
||||
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
|
||||
prompt = []
|
||||
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
|
||||
for old_prompt, old_response in example[self.dataset_attr.history]:
|
||||
@ -116,7 +113,7 @@ class AlpacaDatasetConverter(DatasetConverter):
|
||||
|
||||
@dataclass
|
||||
class SharegptDatasetConverter(DatasetConverter):
|
||||
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
|
||||
tag_mapping = {
|
||||
self.dataset_attr.user_tag: Role.USER.value,
|
||||
self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||
@ -216,10 +213,8 @@ DATASET_CONVERTERS = {
|
||||
}
|
||||
|
||||
|
||||
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
|
||||
r"""
|
||||
Register a new dataset converter.
|
||||
"""
|
||||
def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
|
||||
r"""Register a new dataset converter."""
|
||||
if name in DATASET_CONVERTERS:
|
||||
raise ValueError(f"Dataset converter {name} already exists.")
|
||||
|
||||
@ -227,9 +222,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
|
||||
|
||||
|
||||
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
|
||||
r"""
|
||||
Gets a dataset converter.
|
||||
"""
|
||||
r"""Get a dataset converter."""
|
||||
if name not in DATASET_CONVERTERS:
|
||||
raise ValueError(f"Dataset converter {name} not found.")
|
||||
|
||||
@ -242,17 +235,17 @@ def align_dataset(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Aligned dataset:
|
||||
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
||||
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
||||
_system: "..."
|
||||
_tools: "...",
|
||||
_images: [],
|
||||
_videos: [],
|
||||
_audios: [],
|
||||
"""
|
||||
r"""Align the dataset to a specific format.
|
||||
|
||||
Aligned dataset:
|
||||
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
||||
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
||||
_system: "..."
|
||||
_tools: "..."
|
||||
_images: []
|
||||
_videos: []
|
||||
_audios: []
|
||||
"""
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
|
@ -12,8 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict, Union
|
||||
|
||||
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
||||
|
||||
@ -29,7 +30,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
SLOTS = Sequence[Union[str, set[str], dict[str, str]]]
|
||||
|
||||
|
||||
@unique
|
||||
@ -43,15 +44,13 @@ class Role(str, Enum):
|
||||
|
||||
class DatasetModule(TypedDict):
|
||||
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
||||
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
|
||||
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
|
||||
|
||||
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
||||
all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Merges multiple datasets to a unified dataset.
|
||||
"""
|
||||
r"""Merge multiple datasets to a unified dataset."""
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
|
||||
@ -78,14 +77,13 @@ def merge_dataset(
|
||||
|
||||
def split_dataset(
|
||||
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
||||
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
|
||||
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
|
||||
data_args: "DataArguments",
|
||||
seed: int,
|
||||
) -> "DatasetDict":
|
||||
r"""
|
||||
Splits the dataset and returns a dataset dict containing train set and validation set.
|
||||
r"""Split the dataset and returns a dataset dict containing train set and validation set.
|
||||
|
||||
Supports both map dataset and iterable dataset.
|
||||
Support both map dataset and iterable dataset.
|
||||
"""
|
||||
if eval_dataset is not None and data_args.val_size > 1e-6:
|
||||
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
|
||||
@ -120,10 +118,8 @@ def split_dataset(
|
||||
|
||||
|
||||
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
|
||||
r"""
|
||||
Converts dataset or dataset dict to dataset module.
|
||||
"""
|
||||
dataset_module: "DatasetModule" = {}
|
||||
r"""Convert dataset or dataset dict to dataset module."""
|
||||
dataset_module: DatasetModule = {}
|
||||
if isinstance(dataset, DatasetDict): # dataset dict
|
||||
if "train" in dataset:
|
||||
dataset_module["train_dataset"] = dataset["train"]
|
||||
|
@ -16,7 +16,7 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@ -31,14 +31,11 @@ class Formatter(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
r"""
|
||||
Forms a list of slots according to the inputs to encode.
|
||||
"""
|
||||
r"""Forms a list of slots according to the inputs to encode."""
|
||||
...
|
||||
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extract a list of tuples from the response message if using tools.
|
||||
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
r"""Extract a list of tuples from the response message if using tools.
|
||||
|
||||
Each tuple consists of function name and function arguments.
|
||||
"""
|
||||
@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
|
||||
if thought:
|
||||
content = content.replace(thought.group(0), "")
|
||||
|
||||
functions: List["FunctionCall"] = []
|
||||
functions: list[FunctionCall] = []
|
||||
try:
|
||||
tool_calls = json.loads(content)
|
||||
if not isinstance(tool_calls, list): # parallel function call
|
||||
@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
|
||||
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
return self.tool_utils.tool_extractor(content)
|
||||
|
@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
@ -54,9 +55,7 @@ def _load_single_dataset(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Loads a single dataset and aligns it to the standard format.
|
||||
"""
|
||||
r"""Load a single dataset and aligns it to the standard format."""
|
||||
logger.info_rank0(f"Loading dataset {dataset_attr}...")
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
|
||||
@ -164,10 +163,8 @@ def _get_merged_dataset(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
merge: bool = True,
|
||||
) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
|
||||
r"""
|
||||
Returns the merged datasets in the standard format.
|
||||
"""
|
||||
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
|
||||
r"""Return the merged datasets in the standard format."""
|
||||
if dataset_names is None:
|
||||
return None
|
||||
|
||||
@ -192,9 +189,7 @@ def _get_dataset_processor(
|
||||
processor: Optional["ProcessorMixin"],
|
||||
do_generate: bool = False,
|
||||
) -> "DatasetProcessor":
|
||||
r"""
|
||||
Returns the corresponding dataset processor.
|
||||
"""
|
||||
r"""Return the corresponding dataset processor."""
|
||||
if stage == "pt":
|
||||
dataset_processor_class = PretrainDatasetProcessor
|
||||
elif stage == "sft" and not do_generate:
|
||||
@ -236,9 +231,7 @@ def _get_preprocessed_dataset(
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
is_eval: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
r"""
|
||||
Preprocesses the dataset, including format checking and tokenization.
|
||||
"""
|
||||
r"""Preprocesses the dataset, including format checking and tokenization."""
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
@ -284,9 +277,7 @@ def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> "DatasetModule":
|
||||
r"""
|
||||
Gets the train dataset and optionally gets the evaluation dataset.
|
||||
"""
|
||||
r"""Get the train dataset and optionally gets the evaluation dataset."""
|
||||
# Load tokenized dataset if path exists
|
||||
if data_args.tokenized_path is not None:
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
|
@ -1,10 +1,11 @@
|
||||
import inspect
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -58,12 +59,12 @@ if TYPE_CHECKING:
|
||||
|
||||
def _get_paligemma_token_type_ids(
|
||||
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
|
||||
) -> List[List[int]]:
|
||||
r"""
|
||||
Gets paligemma token type ids for computing loss.
|
||||
) -> list[list[int]]:
|
||||
r"""Get paligemma token type ids for computing loss.
|
||||
|
||||
Returns:
|
||||
batch_token_type_ids: shape (batch_size, sequence_length)
|
||||
|
||||
"""
|
||||
batch_token_type_ids = []
|
||||
for imglen, seqlen in zip(imglens, seqlens):
|
||||
@ -87,11 +88,9 @@ class MMPluginMixin:
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> None:
|
||||
r"""
|
||||
Validates if this model accepts the input modalities.
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
|
||||
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
|
||||
r"""Validate if this model accepts the input modalities."""
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError(
|
||||
"This model does not support image input. Please check whether the correct `template` is used."
|
||||
@ -119,9 +118,7 @@ class MMPluginMixin:
|
||||
def _preprocess_image(
|
||||
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
|
||||
) -> "ImageObject":
|
||||
r"""
|
||||
Pre-processes a single image.
|
||||
"""
|
||||
r"""Pre-process a single image."""
|
||||
if (image.width * image.height) > image_max_pixels:
|
||||
resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
|
||||
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
||||
@ -139,10 +136,8 @@ class MMPluginMixin:
|
||||
|
||||
def _get_video_sample_indices(
|
||||
self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
|
||||
) -> List[int]:
|
||||
r"""
|
||||
Computes video sample indices according to fps.
|
||||
"""
|
||||
) -> list[int]:
|
||||
r"""Compute video sample indices according to fps."""
|
||||
total_frames = video_stream.frames
|
||||
if total_frames == 0: # infinite video
|
||||
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
|
||||
@ -151,10 +146,8 @@ class MMPluginMixin:
|
||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
|
||||
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
|
||||
r"""
|
||||
Regularizes images to avoid error. Including reading and pre-processing.
|
||||
"""
|
||||
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]:
|
||||
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
||||
results = []
|
||||
for image in images:
|
||||
if isinstance(image, str):
|
||||
@ -174,16 +167,14 @@ class MMPluginMixin:
|
||||
|
||||
return results
|
||||
|
||||
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
|
||||
r"""
|
||||
Regularizes videos to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
|
||||
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
||||
results = []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
frames: List["ImageObject"] = []
|
||||
frames: list[ImageObject] = []
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
@ -194,10 +185,8 @@ class MMPluginMixin:
|
||||
|
||||
return results
|
||||
|
||||
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
|
||||
r"""
|
||||
Regularizes audios to avoid error. Including reading and resampling.
|
||||
"""
|
||||
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
|
||||
r"""Regularizes audios to avoid error. Including reading and resampling."""
|
||||
results = []
|
||||
for audio in audios:
|
||||
if isinstance(audio, str):
|
||||
@ -216,9 +205,8 @@ class MMPluginMixin:
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Processes visual inputs.
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
r"""Process visual inputs.
|
||||
|
||||
Returns: (llava and paligemma)
|
||||
pixel_values: tensor with shape (B, C, H, W)
|
||||
@ -229,9 +217,9 @@ class MMPluginMixin:
|
||||
|
||||
It holds num_patches == torch.prod(image_grid_thw)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
|
||||
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
|
||||
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
mm_inputs = {}
|
||||
|
||||
if len(images) != 0:
|
||||
@ -278,31 +266,27 @@ class MMPluginMixin:
|
||||
class BasePlugin(MMPluginMixin):
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""
|
||||
Pre-processes input messages before tokenization for VLMs.
|
||||
"""
|
||||
) -> list[dict[str, str]]:
|
||||
r"""Pre-processes input messages before tokenization for VLMs."""
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return messages
|
||||
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
labels: Optional[List[int]],
|
||||
input_ids: list[int],
|
||||
labels: Optional[list[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
r"""
|
||||
Pre-processes token ids after tokenization for VLMs.
|
||||
"""
|
||||
) -> tuple[list[int], Optional[list[int]]]:
|
||||
r"""Pre-processes token ids after tokenization for VLMs."""
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return input_ids, labels
|
||||
|
||||
@ -314,20 +298,21 @@ class BasePlugin(MMPluginMixin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
r"""
|
||||
Builds batched multimodal inputs for VLMs.
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
r"""Build batched multimodal inputs for VLMs.
|
||||
|
||||
Arguments:
|
||||
images: a list of image inputs, shape (num_images,)
|
||||
videos: a list of video inputs, shape (num_videos,)
|
||||
audios: a list of audio inputs, shape (num_audios,)
|
||||
imglens: number of images in each sample, shape (batch_size,)
|
||||
vidlens: number of videos in each sample, shape (batch_size,)
|
||||
audlens: number of audios in each sample, shape (batch_size,)
|
||||
batch_ids: token ids of input samples, shape (batch_size, seq_len)
|
||||
processor: a processor for pre-processing images and videos
|
||||
|
||||
"""
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return {}
|
||||
@ -338,12 +323,12 @@ class LlavaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
|
||||
@ -370,9 +355,9 @@ class LlavaPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
@ -382,12 +367,12 @@ class LlavaNextPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
@ -426,9 +411,9 @@ class LlavaNextPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
@ -438,12 +423,12 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
@ -502,9 +487,9 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
@ -514,16 +499,16 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
audio_inputs = {}
|
||||
if len(images) != 0 and len(videos) != 0:
|
||||
@ -619,9 +604,9 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: "ProcessorMixin",
|
||||
**kwargs,
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
@ -691,9 +676,9 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
# image bound
|
||||
image_bounds_list = []
|
||||
@ -756,12 +741,12 @@ class MllamaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
@ -782,10 +767,9 @@ class MllamaPlugin(BasePlugin):
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: "ProcessorMixin",
|
||||
imglens: List[int],
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
|
||||
imglens: list[int],
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
|
||||
|
||||
Returns:
|
||||
pixel_values: tensor with shape
|
||||
@ -794,8 +778,9 @@ class MllamaPlugin(BasePlugin):
|
||||
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
|
||||
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
|
||||
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
|
||||
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
if len(images) > 0:
|
||||
images = self._regularize_images(
|
||||
@ -821,9 +806,9 @@ class MllamaPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
|
||||
if mm_inputs:
|
||||
@ -850,12 +835,12 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
@ -875,14 +860,14 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
labels: Optional[List[int]],
|
||||
input_ids: list[int],
|
||||
labels: Optional[list[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
) -> tuple[list[int], Optional[list[int]]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_images = len(images)
|
||||
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
|
||||
@ -902,9 +887,9 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
seqlens = [len(input_ids) for input_ids in batch_ids]
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
@ -917,12 +902,12 @@ class PixtralPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
patch_size = getattr(processor, "patch_size")
|
||||
image_token = getattr(processor, "image_token")
|
||||
@ -968,9 +953,9 @@ class PixtralPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs.pop("image_sizes", None)
|
||||
@ -982,12 +967,12 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
bos_token: str = getattr(processor, "audio_bos_token")
|
||||
eos_token: str = getattr(processor, "audio_eos_token")
|
||||
@ -1028,9 +1013,9 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
@ -1057,13 +1042,13 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
@override
|
||||
def _regularize_videos(
|
||||
self, videos: Sequence["VideoInput"], **kwargs
|
||||
) -> Tuple[List[List["ImageObject"]], List[float]]:
|
||||
) -> tuple[list[list["ImageObject"]], list[float]]:
|
||||
results, fps_per_video = [], []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
frames: List["ImageObject"] = []
|
||||
frames: list[ImageObject] = []
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
@ -1088,8 +1073,8 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
@ -1115,16 +1100,16 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
if self.expand_mm_tokens:
|
||||
@ -1176,13 +1161,13 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
fps_per_video = mm_inputs.pop("fps_per_video", [])
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
|
||||
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
|
||||
|
||||
@ -1194,12 +1179,12 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
@ -1255,9 +1240,9 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
audlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
batch_ids: Sequence[list[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
@ -1277,10 +1262,8 @@ PLUGINS = {
|
||||
}
|
||||
|
||||
|
||||
def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
|
||||
r"""
|
||||
Registers a multimodal plugin.
|
||||
"""
|
||||
def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
|
||||
r"""Register a multimodal plugin."""
|
||||
if name in PLUGINS:
|
||||
raise ValueError(f"Multimodal plugin {name} already exists.")
|
||||
|
||||
@ -1293,9 +1276,7 @@ def get_mm_plugin(
|
||||
video_token: Optional[str] = None,
|
||||
audio_token: Optional[str] = None,
|
||||
) -> "BasePlugin":
|
||||
r"""
|
||||
Gets plugin for multimodal inputs.
|
||||
"""
|
||||
r"""Get plugin for multimodal inputs."""
|
||||
if name not in PLUGINS:
|
||||
raise ValueError(f"Multimodal plugin `{name}` not found.")
|
||||
|
||||
|
@ -14,8 +14,9 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from transformers.utils import cached_file
|
||||
|
||||
@ -25,9 +26,7 @@ from ..extras.misc import use_modelscope, use_openmind
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
r"""
|
||||
Dataset attributes.
|
||||
"""
|
||||
r"""Dataset attributes."""
|
||||
|
||||
# basic configs
|
||||
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
|
||||
@ -68,10 +67,10 @@ class DatasetAttr:
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
||||
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
|
||||
setattr(self, key, obj.get(key, default))
|
||||
|
||||
def join(self, attr: Dict[str, Any]) -> None:
|
||||
def join(self, attr: dict[str, Any]) -> None:
|
||||
self.set_attr("formatting", attr, default="alpaca")
|
||||
self.set_attr("ranking", attr, default=False)
|
||||
self.set_attr("subset", attr)
|
||||
@ -92,10 +91,8 @@ class DatasetAttr:
|
||||
self.set_attr(tag, attr["tags"])
|
||||
|
||||
|
||||
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
|
||||
r"""
|
||||
Gets the attributes of the datasets.
|
||||
"""
|
||||
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]:
|
||||
r"""Get the attributes of the datasets."""
|
||||
if dataset_names is None:
|
||||
dataset_names = []
|
||||
|
||||
@ -116,7 +113,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
|
||||
|
||||
dataset_info = None
|
||||
|
||||
dataset_list: List["DatasetAttr"] = []
|
||||
dataset_list: list[DatasetAttr] = []
|
||||
for name in dataset_names:
|
||||
if dataset_info is None: # dataset_dir is ONLINE
|
||||
if use_modelscope():
|
||||
|
@ -9,9 +9,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
|
||||
__all__ = [
|
||||
"DatasetProcessor",
|
||||
"FeedbackDatasetProcessor",
|
||||
"PackedSupervisedDatasetProcessor",
|
||||
"PairwiseDatasetProcessor",
|
||||
"PretrainDatasetProcessor",
|
||||
"PackedSupervisedDatasetProcessor",
|
||||
"SupervisedDatasetProcessor",
|
||||
"UnsupervisedDatasetProcessor",
|
||||
]
|
||||
|
@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@ -30,15 +31,15 @@ logger = logging.get_logger(__name__)
|
||||
class FeedbackDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
kl_response: Sequence[Dict[str, str]],
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
kl_response: Sequence[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
|
||||
) -> tuple[list[int], list[int], list[int], list[int], bool]:
|
||||
if response[0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = prompt + [response[0]]
|
||||
@ -82,7 +83,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
|
||||
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
|
||||
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
||||
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["_response"][::-1]
|
||||
model_inputs = defaultdict(list)
|
||||
@ -121,7 +122,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
def print_data_example(self, example: dict[str, list[int]]) -> None:
|
||||
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
|
||||
class PairwiseDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
) -> tuple[list[int], list[int], list[int], list[int]]:
|
||||
chosen_messages = self.template.mm_plugin.process_messages(
|
||||
prompt + [response[0]], images, videos, audios, self.processor
|
||||
)
|
||||
@ -68,7 +69,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
|
||||
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
|
||||
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
|
||||
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
@ -99,7 +100,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
def print_data_example(self, example: dict[str, list[int]]) -> None:
|
||||
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
|
||||
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
|
||||
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
|
||||
|
@ -17,14 +17,14 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from .processor_utils import DatasetProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class PretrainDatasetProcessor(DatasetProcessor):
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
|
||||
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
|
||||
@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return result
|
||||
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
def print_data_example(self, example: dict[str, list[int]]) -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
@ -14,8 +14,9 @@
|
||||
|
||||
import bisect
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -27,9 +28,7 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class DatasetProcessor(ABC):
|
||||
r"""
|
||||
A class for data processors.
|
||||
"""
|
||||
r"""A class for data processors."""
|
||||
|
||||
template: "Template"
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
@ -37,32 +36,24 @@ class DatasetProcessor(ABC):
|
||||
data_args: "DataArguments"
|
||||
|
||||
@abstractmethod
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
r"""
|
||||
Builds model inputs from the examples.
|
||||
"""
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
r"""Build model inputs from the examples."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
r"""
|
||||
Print a data example to stdout.
|
||||
"""
|
||||
def print_data_example(self, example: dict[str, list[int]]) -> None:
|
||||
r"""Print a data example to stdout."""
|
||||
...
|
||||
|
||||
|
||||
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
|
||||
r"""
|
||||
Finds the index of largest number that fits into the knapsack with the given capacity.
|
||||
"""
|
||||
r"""Find the index of largest number that fits into the knapsack with the given capacity."""
|
||||
index = bisect.bisect(numbers, capacity)
|
||||
return -1 if index == 0 else (index - 1)
|
||||
|
||||
|
||||
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
|
||||
r"""
|
||||
An efficient greedy algorithm with binary search for the knapsack problem.
|
||||
"""
|
||||
def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
|
||||
r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
|
||||
numbers.sort() # sort numbers in ascending order for binary search
|
||||
knapsacks = []
|
||||
|
||||
@ -83,10 +74,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
|
||||
return knapsacks
|
||||
|
||||
|
||||
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
|
||||
r"""
|
||||
Computes the real sequence length after truncation by the cutoff_len.
|
||||
"""
|
||||
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
|
||||
r"""Compute the real sequence length after truncation by the cutoff_len."""
|
||||
if target_len * 2 < cutoff_len: # truncate source
|
||||
max_target_len = cutoff_len
|
||||
elif source_len * 2 < cutoff_len: # truncate target
|
||||
|
@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@ -32,14 +33,14 @@ logger = logging.get_logger(__name__)
|
||||
class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
) -> tuple[list[int], list[int]]:
|
||||
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
|
||||
input_ids, labels = self.template.mm_plugin.process_token_ids(
|
||||
[], [], images, videos, audios, self.tokenizer, self.processor
|
||||
@ -85,7 +86,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = defaultdict(list)
|
||||
@ -114,7 +115,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
def print_data_example(self, example: dict[str, list[int]]) -> None:
|
||||
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
@ -124,7 +125,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
|
||||
@dataclass
|
||||
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# TODO: use `position_ids` to achieve packing
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
|
@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
from ..data_utils import Role
|
||||
@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
|
||||
class UnsupervisedDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
) -> tuple[list[int], list[int]]:
|
||||
if len(response) == 1:
|
||||
messages = prompt + response
|
||||
else:
|
||||
@ -56,7 +57,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
|
||||
labels = labels[:target_len]
|
||||
return input_ids, labels
|
||||
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
@ -84,7 +85,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
def print_data_example(self, example: dict[str, list[int]]) -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
|
@ -12,8 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@ -46,8 +47,8 @@ class Template:
|
||||
format_tools: "Formatter"
|
||||
format_prefix: "Formatter"
|
||||
default_system: str
|
||||
stop_words: List[str]
|
||||
thought_words: Tuple[str, str]
|
||||
stop_words: list[str]
|
||||
thought_words: tuple[str, str]
|
||||
efficient_eos: bool
|
||||
replace_eos: bool
|
||||
replace_jinja_template: bool
|
||||
@ -56,13 +57,11 @@ class Template:
|
||||
def encode_oneturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
"""
|
||||
) -> tuple[list[int], list[int]]:
|
||||
r"""Return a single pair of token ids representing prompt and response respectively."""
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
prompt_ids = []
|
||||
for encoded_ids in encoded_messages[:-1]:
|
||||
@ -74,36 +73,28 @@ class Template:
|
||||
def encode_multiturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
"""
|
||||
) -> list[tuple[list[int], list[int]]]:
|
||||
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
||||
|
||||
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extracts tool message.
|
||||
"""
|
||||
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
r"""Extract tool message."""
|
||||
return self.format_tools.extract(content)
|
||||
|
||||
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
|
||||
r"""
|
||||
Returns stop token ids.
|
||||
"""
|
||||
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
|
||||
r"""Return stop token ids."""
|
||||
stop_token_ids = {tokenizer.eos_token_id}
|
||||
for token in self.stop_words:
|
||||
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
return list(stop_token_ids)
|
||||
|
||||
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
|
||||
r"""
|
||||
Converts elements to token ids.
|
||||
"""
|
||||
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
|
||||
r"""Convert elements to token ids."""
|
||||
token_ids = []
|
||||
for elem in elements:
|
||||
if isinstance(elem, str):
|
||||
@ -124,14 +115,14 @@ class Template:
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
) -> List[List[int]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
) -> list[list[int]]:
|
||||
r"""Encode formatted inputs to pairs of token ids.
|
||||
|
||||
Turn 0: prefix + system + query resp
|
||||
Turn t: query resp
|
||||
Turn t: query resp.
|
||||
"""
|
||||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
@ -161,9 +152,7 @@ class Template:
|
||||
|
||||
@staticmethod
|
||||
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
|
||||
r"""
|
||||
Adds or replaces eos token to the tokenizer.
|
||||
"""
|
||||
r"""Add or replace eos token to the tokenizer."""
|
||||
is_added = tokenizer.eos_token_id is None
|
||||
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
||||
|
||||
@ -176,9 +165,7 @@ class Template:
|
||||
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
|
||||
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
|
||||
r"""
|
||||
Adds eos token and pad token to the tokenizer.
|
||||
"""
|
||||
r"""Add eos token and pad token to the tokenizer."""
|
||||
stop_words = self.stop_words
|
||||
if self.replace_eos:
|
||||
if not stop_words:
|
||||
@ -204,16 +191,12 @@ class Template:
|
||||
|
||||
@staticmethod
|
||||
def _jinja_escape(content: str) -> str:
|
||||
r"""
|
||||
Escape single quotes in content.
|
||||
"""
|
||||
r"""Escape single quotes in content."""
|
||||
return content.replace("'", r"\'")
|
||||
|
||||
@staticmethod
|
||||
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
|
||||
r"""
|
||||
Converts slots to jinja template.
|
||||
"""
|
||||
r"""Convert slots to jinja template."""
|
||||
slot_items = []
|
||||
for slot in slots:
|
||||
if isinstance(slot, str):
|
||||
@ -235,9 +218,7 @@ class Template:
|
||||
return " + ".join(slot_items)
|
||||
|
||||
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
|
||||
r"""
|
||||
Returns the jinja template.
|
||||
"""
|
||||
r"""Return the jinja template."""
|
||||
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
|
||||
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
|
||||
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
|
||||
@ -265,9 +246,7 @@ class Template:
|
||||
return jinja_template
|
||||
|
||||
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
|
||||
r"""
|
||||
Replaces the jinja template in the tokenizer.
|
||||
"""
|
||||
r"""Replace the jinja template in the tokenizer."""
|
||||
if tokenizer.chat_template is None or self.replace_jinja_template:
|
||||
try:
|
||||
tokenizer.chat_template = self._get_jinja_template(tokenizer)
|
||||
@ -278,9 +257,7 @@ class Template:
|
||||
def _convert_slots_to_ollama(
|
||||
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
|
||||
) -> str:
|
||||
r"""
|
||||
Converts slots to ollama template.
|
||||
"""
|
||||
r"""Convert slots to ollama template."""
|
||||
slot_items = []
|
||||
for slot in slots:
|
||||
if isinstance(slot, str):
|
||||
@ -302,9 +279,7 @@ class Template:
|
||||
return "".join(slot_items)
|
||||
|
||||
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
|
||||
r"""
|
||||
Returns the ollama template.
|
||||
"""
|
||||
r"""Return the ollama template."""
|
||||
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
|
||||
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
|
||||
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
|
||||
@ -316,8 +291,7 @@ class Template:
|
||||
)
|
||||
|
||||
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
|
||||
r"""
|
||||
Returns the ollama modelfile.
|
||||
r"""Return the ollama modelfile.
|
||||
|
||||
TODO: support function calling.
|
||||
"""
|
||||
@ -340,10 +314,10 @@ class Llama2Template(Template):
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
) -> List[List[int]]:
|
||||
) -> list[list[int]]:
|
||||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
@ -402,7 +376,7 @@ class Llama2Template(Template):
|
||||
return jinja_template
|
||||
|
||||
|
||||
TEMPLATES: Dict[str, "Template"] = {}
|
||||
TEMPLATES: dict[str, "Template"] = {}
|
||||
|
||||
|
||||
def register_template(
|
||||
@ -416,15 +390,14 @@ def register_template(
|
||||
format_prefix: Optional["Formatter"] = None,
|
||||
default_system: str = "",
|
||||
stop_words: Optional[Sequence[str]] = None,
|
||||
thought_words: Optional[Tuple[str, str]] = None,
|
||||
thought_words: Optional[tuple[str, str]] = None,
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
replace_jinja_template: bool = False,
|
||||
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
|
||||
template_class: Type["Template"] = Template,
|
||||
template_class: type["Template"] = Template,
|
||||
) -> None:
|
||||
r"""
|
||||
Registers a chat template.
|
||||
r"""Register a chat template.
|
||||
|
||||
To add the following chat template:
|
||||
```
|
||||
@ -472,9 +445,7 @@ def register_template(
|
||||
|
||||
|
||||
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
|
||||
r"""
|
||||
Extracts a chat template from the tokenizer.
|
||||
"""
|
||||
r"""Extract a chat template from the tokenizer."""
|
||||
|
||||
def find_diff(short_str: str, long_str: str) -> str:
|
||||
i, j = 0, 0
|
||||
@ -532,9 +503,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
|
||||
|
||||
|
||||
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
|
||||
r"""
|
||||
Gets chat template and fixes the tokenizer.
|
||||
"""
|
||||
r"""Get chat template and fixes the tokenizer."""
|
||||
if data_args.template is None:
|
||||
if isinstance(tokenizer.chat_template, str):
|
||||
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
|
||||
@ -1149,7 +1118,8 @@ register_template(
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
default_system=(
|
||||
"你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
|
||||
"你是一个经过良好训练的AI助手,你的名字是Marco-o1."
|
||||
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
|
||||
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
|
||||
"<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
|
||||
),
|
||||
|
@ -17,7 +17,7 @@ import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, NamedTuple, Tuple, Union
|
||||
from typing import Any, NamedTuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
"""
|
||||
Base class for tool utilities.
|
||||
"""
|
||||
"""Base class for tool utilities."""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
r"""
|
||||
Generates the system message describing all the available tools.
|
||||
"""
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
r"""Generate the system message describing all the available tools."""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
r"""
|
||||
Generates the assistant message including all the tool calls.
|
||||
"""
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
r"""Generate the assistant message including all the tool calls."""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extracts all the function calls from the assistant message.
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
r"""Extract all the function calls from the assistant message.
|
||||
|
||||
It should be an inverse function of `function_formatter`.
|
||||
"""
|
||||
@ -92,13 +85,11 @@ class ToolUtils(ABC):
|
||||
|
||||
|
||||
class DefaultToolUtils(ToolUtils):
|
||||
r"""
|
||||
Default tool using template.
|
||||
"""
|
||||
r"""Default tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_text = ""
|
||||
for name, arguments in functions:
|
||||
function_text += f"Action: {name}\nAction Input: {arguments}\n"
|
||||
@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||
action_match: list[tuple[str, str]] = re.findall(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
r"""
|
||||
GLM-4 tool using template.
|
||||
"""
|
||||
r"""GLM-4 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||
@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("GLM-4 does not support parallel functions.")
|
||||
|
||||
@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
|
||||
@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class Llama3ToolUtils(ToolUtils):
|
||||
r"""
|
||||
Llama 3.x tool using template with `tools_in_user_message=False`.
|
||||
r"""Llama 3.x tool using template with `tools_in_user_message=False`.
|
||||
|
||||
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
|
||||
"""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
date = datetime.now().strftime("%d %b %Y")
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("Llama-3 does not support parallel functions.")
|
||||
|
||||
@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
try:
|
||||
tool = json.loads(content.strip())
|
||||
except json.JSONDecodeError:
|
||||
@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class MistralToolUtils(ToolUtils):
|
||||
r"""
|
||||
Mistral v0.3 tool using template.
|
||||
"""
|
||||
r"""Mistral v0.3 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
wrapped_tools = []
|
||||
for tool in tools:
|
||||
wrapped_tools.append({"type": "function", "function": tool})
|
||||
@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
|
||||
@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
try:
|
||||
tools = json.loads(content.strip())
|
||||
except json.JSONDecodeError:
|
||||
@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class QwenToolUtils(ToolUtils):
|
||||
r"""
|
||||
Qwen 2.5 tool using template.
|
||||
"""
|
||||
r"""Qwen 2.5 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
wrapped_tool = {"type": "function", "function": tool}
|
||||
@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(
|
||||
@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
|
||||
tool_match: List[str] = re.findall(regex, content)
|
||||
tool_match: list[str] = re.findall(regex, content)
|
||||
if not tool_match:
|
||||
return content
|
||||
|
||||
|
@ -39,7 +39,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
|
||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
|
||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
@ -69,7 +69,7 @@ class Evaluator:
|
||||
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
|
||||
|
||||
@torch.inference_mode()
|
||||
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
|
||||
def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
|
||||
logits = self.model(**batch_input).logits
|
||||
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
|
||||
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
|
||||
@ -88,7 +88,7 @@ class Evaluator:
|
||||
)
|
||||
|
||||
with open(mapping, encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
categorys: dict[str, dict[str, str]] = json.load(f)
|
||||
|
||||
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||
@ -136,7 +136,7 @@ class Evaluator:
|
||||
pbar.close()
|
||||
self._save_results(category_corrects, results)
|
||||
|
||||
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
|
||||
def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
|
||||
score_info = "\n".join(
|
||||
[
|
||||
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
|
||||
|
@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Sequence, Tuple
|
||||
|
||||
from ..data import Role
|
||||
from ..extras.constants import CHOICES
|
||||
@ -25,20 +25,19 @@ class EvalTemplate:
|
||||
choice: str
|
||||
answer: str
|
||||
|
||||
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||
r"""
|
||||
def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
|
||||
r"""Parse eval example.
|
||||
|
||||
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
|
||||
output: a tuple of (prompt, response)
|
||||
output: a tuple of (prompt, response).
|
||||
"""
|
||||
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
||||
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
||||
|
||||
def format_example(
|
||||
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""
|
||||
Converts dataset examples to messages.
|
||||
"""
|
||||
self, target_data: dict[str, str], support_set: Sequence[dict[str, str]], subject_name: str
|
||||
) -> list[dict[str, str]]:
|
||||
r"""Convert dataset examples to messages."""
|
||||
messages = []
|
||||
for k in range(len(support_set)):
|
||||
prompt, response = self._parse_example(support_set[k])
|
||||
@ -52,7 +51,7 @@ class EvalTemplate:
|
||||
return messages
|
||||
|
||||
|
||||
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||
eval_templates: dict[str, "EvalTemplate"] = {}
|
||||
|
||||
|
||||
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
|
||||
|
@ -15,7 +15,7 @@
|
||||
import os
|
||||
from collections import OrderedDict, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
|
||||
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
|
||||
@ -122,7 +122,7 @@ class RopeScaling(str, Enum):
|
||||
|
||||
|
||||
def register_model_group(
|
||||
models: Dict[str, Dict[DownloadSource, str]],
|
||||
models: dict[str, dict[DownloadSource, str]],
|
||||
template: Optional[str] = None,
|
||||
multimodal: bool = False,
|
||||
) -> None:
|
||||
|
@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO
|
||||
|
||||
|
||||
class LoggerHandler(logging.Handler):
|
||||
r"""
|
||||
Redirects the logging output to the logging file for LLaMA Board.
|
||||
"""
|
||||
r"""Redirect the logging output to the logging file for LLaMA Board."""
|
||||
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
super().__init__()
|
||||
@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):
|
||||
|
||||
|
||||
class _Logger(logging.Logger):
|
||||
r"""
|
||||
A logger that supports rank0 logging.
|
||||
"""
|
||||
r"""A logger that supports rank0 logging."""
|
||||
|
||||
def info_rank0(self, *args, **kwargs) -> None:
|
||||
self.info(*args, **kwargs)
|
||||
@ -82,9 +78,7 @@ class _Logger(logging.Logger):
|
||||
|
||||
|
||||
def _get_default_logging_level() -> "logging._Level":
|
||||
r"""
|
||||
Returns the default logging level.
|
||||
"""
|
||||
r"""Return the default logging level."""
|
||||
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
|
||||
if env_level_str:
|
||||
if env_level_str.upper() in logging._nameToLevel:
|
||||
@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":
|
||||
|
||||
|
||||
def _configure_library_root_logger() -> None:
|
||||
r"""
|
||||
Configures root logger using a stdout stream handler with an explicit format.
|
||||
"""
|
||||
r"""Configure root logger using a stdout stream handler with an explicit format."""
|
||||
global _default_handler
|
||||
|
||||
with _thread_lock:
|
||||
@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> "_Logger":
|
||||
r"""
|
||||
Returns a logger with the specified name. It it not supposed to be accessed externally.
|
||||
"""
|
||||
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
|
||||
if name is None:
|
||||
name = _get_library_name()
|
||||
|
||||
@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":
|
||||
|
||||
|
||||
def add_handler(handler: "logging.Handler") -> None:
|
||||
r"""
|
||||
Adds a handler to the root logger.
|
||||
"""
|
||||
r"""Add a handler to the root logger."""
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().addHandler(handler)
|
||||
|
||||
|
||||
def remove_handler(handler: logging.Handler) -> None:
|
||||
r"""
|
||||
Removes a handler to the root logger.
|
||||
"""
|
||||
r"""Remove a handler to the root logger."""
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().removeHandler(handler)
|
||||
|
||||
|
@ -17,7 +17,8 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
r"""
|
||||
Computes and stores the average and current value.
|
||||
"""
|
||||
r"""Compute and store the average and current value."""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@ -75,9 +74,7 @@ class AverageMeter:
|
||||
|
||||
|
||||
def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
r"""
|
||||
Optionally checks the package version.
|
||||
"""
|
||||
r"""Optionally check the package version."""
|
||||
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
|
||||
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
@ -91,9 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""
|
||||
Checks the version of the required packages.
|
||||
"""
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("datasets>=2.16.0,<=3.2.0")
|
||||
check_version("accelerate>=0.34.0,<=1.2.1")
|
||||
@ -103,10 +98,8 @@ def check_dependencies() -> None:
|
||||
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
|
||||
|
||||
|
||||
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
r"""
|
||||
Calculates effective tokens per second.
|
||||
"""
|
||||
def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
r"""Calculate effective tokens per second."""
|
||||
effective_token_num = 0
|
||||
for data in dataset:
|
||||
if stage == "sft":
|
||||
@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
|
||||
return result / dist.get_world_size() if dist.is_initialized() else result
|
||||
|
||||
|
||||
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the number of trainable parameters and number of all parameters in the model.
|
||||
"""
|
||||
def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
|
||||
r"""Return the number of trainable parameters and number of all parameters in the model."""
|
||||
trainable_params, all_param = 0, 0
|
||||
for param in model.parameters():
|
||||
num_params = param.numel()
|
||||
@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
|
||||
|
||||
|
||||
def get_current_device() -> "torch.device":
|
||||
r"""
|
||||
Gets the current available device.
|
||||
"""
|
||||
r"""Get the current available device."""
|
||||
if is_torch_xpu_available():
|
||||
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif is_torch_npu_available():
|
||||
@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
|
||||
|
||||
|
||||
def get_device_count() -> int:
|
||||
r"""
|
||||
Gets the number of available GPU or NPU devices.
|
||||
"""
|
||||
r"""Get the number of available GPU or NPU devices."""
|
||||
if is_torch_xpu_available():
|
||||
return torch.xpu.device_count()
|
||||
elif is_torch_npu_available():
|
||||
@ -180,18 +167,14 @@ def get_device_count() -> int:
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
r"""Get logits processor that removes NaN and Inf logits."""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def get_peak_memory() -> Tuple[int, int]:
|
||||
r"""
|
||||
Gets the peak memory usage for the current device (in Bytes).
|
||||
"""
|
||||
def get_peak_memory() -> tuple[int, int]:
|
||||
r"""Get the peak memory usage for the current device (in Bytes)."""
|
||||
if is_torch_npu_available():
|
||||
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
|
||||
elif is_torch_cuda_available():
|
||||
@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
|
||||
|
||||
|
||||
def has_tokenized_data(path: "os.PathLike") -> bool:
|
||||
r"""
|
||||
Checks if the path has a tokenized dataset.
|
||||
"""
|
||||
r"""Check if the path has a tokenized dataset."""
|
||||
return os.path.isdir(path) and len(os.listdir(path)) > 0
|
||||
|
||||
|
||||
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
|
||||
r"""
|
||||
Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||
"""
|
||||
r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
|
||||
if _is_bf16_available and model_dtype == torch.bfloat16:
|
||||
return torch.bfloat16
|
||||
elif _is_fp16_available:
|
||||
@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
|
||||
|
||||
|
||||
def is_gpu_or_npu_available() -> bool:
|
||||
r"""
|
||||
Checks if the GPU or NPU is available.
|
||||
"""
|
||||
r"""Check if the GPU or NPU is available."""
|
||||
return is_torch_npu_available() or is_torch_cuda_available()
|
||||
|
||||
|
||||
def is_env_enabled(env_var: str, default: str = "0") -> bool:
|
||||
r"""
|
||||
Checks if the environment variable is enabled.
|
||||
"""
|
||||
r"""Check if the environment variable is enabled."""
|
||||
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
|
||||
|
||||
|
||||
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
||||
r"""
|
||||
Casts a torch tensor or a numpy array to a numpy array.
|
||||
"""
|
||||
r"""Cast a torch tensor or a numpy array to a numpy array."""
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = inputs.cpu()
|
||||
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
|
||||
@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
||||
|
||||
|
||||
def skip_check_imports() -> None:
|
||||
r"""
|
||||
Avoids flash attention import error in custom model files.
|
||||
"""
|
||||
r"""Avoid flash attention import error in custom model files."""
|
||||
if not is_env_enabled("FORCE_CHECK_IMPORTS"):
|
||||
transformers.dynamic_module_utils.check_imports = get_relative_imports
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU or NPU memory.
|
||||
"""
|
||||
r"""Collect GPU or NPU memory."""
|
||||
gc.collect()
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
|
@ -15,7 +15,7 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
@ -31,10 +31,8 @@ if is_matplotlib_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def smooth(scalars: List[float]) -> List[float]:
|
||||
r"""
|
||||
EMA implementation according to TensorBoard.
|
||||
"""
|
||||
def smooth(scalars: list[float]) -> list[float]:
|
||||
r"""EMA implementation according to TensorBoard."""
|
||||
if len(scalars) == 0:
|
||||
return []
|
||||
|
||||
@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
return smoothed
|
||||
|
||||
|
||||
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
|
||||
r"""
|
||||
Plots loss curves in LlamaBoard.
|
||||
"""
|
||||
def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
|
||||
r"""Plot loss curves in LlamaBoard."""
|
||||
plt.close("all")
|
||||
plt.switch_backend("agg")
|
||||
fig = plt.figure()
|
||||
@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
|
||||
return fig
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
r"""
|
||||
Plots loss curves and saves the image.
|
||||
"""
|
||||
def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
|
||||
r"""Plot loss curves and saves the image."""
|
||||
plt.switch_backend("agg")
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
@ -16,14 +16,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
r"""
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
|
||||
|
||||
template: Optional[str] = field(
|
||||
default=None,
|
||||
@ -162,5 +160,5 @@ class DataArguments:
|
||||
if self.mask_history and self.train_on_prompt:
|
||||
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
@ -21,9 +21,7 @@ from datasets import DownloadMode
|
||||
|
||||
@dataclass
|
||||
class EvaluationArguments:
|
||||
r"""
|
||||
Arguments pertaining to specify the evaluation parameters.
|
||||
"""
|
||||
r"""Arguments pertaining to specify the evaluation parameters."""
|
||||
|
||||
task: str = field(
|
||||
metadata={"help": "Name of the evaluation task."},
|
||||
|
@ -13,14 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class FreezeArguments:
|
||||
r"""
|
||||
Arguments pertaining to the freeze (partial-parameter) training.
|
||||
"""
|
||||
r"""Arguments pertaining to the freeze (partial-parameter) training."""
|
||||
|
||||
freeze_trainable_layers: int = field(
|
||||
default=2,
|
||||
@ -56,9 +54,7 @@ class FreezeArguments:
|
||||
|
||||
@dataclass
|
||||
class LoraArguments:
|
||||
r"""
|
||||
Arguments pertaining to the LoRA training.
|
||||
"""
|
||||
r"""Arguments pertaining to the LoRA training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
@ -128,9 +124,7 @@ class LoraArguments:
|
||||
|
||||
@dataclass
|
||||
class RLHFArguments:
|
||||
r"""
|
||||
Arguments pertaining to the PPO, DPO and KTO training.
|
||||
"""
|
||||
r"""Arguments pertaining to the PPO, DPO and KTO training."""
|
||||
|
||||
pref_beta: float = field(
|
||||
default=0.1,
|
||||
@ -212,9 +206,7 @@ class RLHFArguments:
|
||||
|
||||
@dataclass
|
||||
class GaloreArguments:
|
||||
r"""
|
||||
Arguments pertaining to the GaLore algorithm.
|
||||
"""
|
||||
r"""Arguments pertaining to the GaLore algorithm."""
|
||||
|
||||
use_galore: bool = field(
|
||||
default=False,
|
||||
@ -253,9 +245,7 @@ class GaloreArguments:
|
||||
|
||||
@dataclass
|
||||
class ApolloArguments:
|
||||
r"""
|
||||
Arguments pertaining to the APOLLO algorithm.
|
||||
"""
|
||||
r"""Arguments pertaining to the APOLLO algorithm."""
|
||||
|
||||
use_apollo: bool = field(
|
||||
default=False,
|
||||
@ -306,9 +296,7 @@ class ApolloArguments:
|
||||
|
||||
@dataclass
|
||||
class BAdamArgument:
|
||||
r"""
|
||||
Arguments pertaining to the BAdam optimizer.
|
||||
"""
|
||||
r"""Arguments pertaining to the BAdam optimizer."""
|
||||
|
||||
use_badam: bool = field(
|
||||
default=False,
|
||||
@ -393,9 +381,7 @@ class SwanLabArguments:
|
||||
class FinetuningArguments(
|
||||
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
|
||||
):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
|
||||
|
||||
pure_bf16: bool = field(
|
||||
default=False,
|
||||
@ -452,13 +438,13 @@ class FinetuningArguments(
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target: List[str] = split_arg(self.lora_target)
|
||||
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: List[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: List[str] = split_arg(self.apollo_target)
|
||||
self.lora_target: list[str] = split_arg(self.lora_target)
|
||||
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: list[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: list[str] = split_arg(self.apollo_target)
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
@ -499,7 +485,7 @@ class FinetuningArguments(
|
||||
if self.pissa_init:
|
||||
raise ValueError("`pissa_init` is only valid for LoRA training.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
args = asdict(self)
|
||||
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
|
||||
return args
|
||||
|
@ -13,16 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from transformers import GenerationConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratingArguments:
|
||||
r"""
|
||||
Arguments pertaining to specify the decoding parameters.
|
||||
"""
|
||||
r"""Arguments pertaining to specify the decoding parameters."""
|
||||
|
||||
do_sample: bool = field(
|
||||
default=True,
|
||||
@ -35,7 +33,9 @@ class GeneratingArguments:
|
||||
top_p: float = field(
|
||||
default=0.7,
|
||||
metadata={
|
||||
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
|
||||
"help": (
|
||||
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
|
||||
)
|
||||
},
|
||||
)
|
||||
top_k: int = field(
|
||||
@ -71,7 +71,7 @@ class GeneratingArguments:
|
||||
metadata={"help": "Whether or not to remove special tokens in the decoding."},
|
||||
)
|
||||
|
||||
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
|
||||
def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
|
||||
args = asdict(self)
|
||||
if args.get("max_new_tokens", -1) > 0:
|
||||
args.pop("max_length", None)
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling
|
||||
|
||||
@dataclass
|
||||
class BaseModelArguments:
|
||||
r"""
|
||||
Arguments pertaining to the model.
|
||||
"""
|
||||
r"""Arguments pertaining to the model."""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
@ -184,9 +182,7 @@ class BaseModelArguments:
|
||||
|
||||
@dataclass
|
||||
class QuantizationArguments:
|
||||
r"""
|
||||
Arguments pertaining to the quantization method.
|
||||
"""
|
||||
r"""Arguments pertaining to the quantization method."""
|
||||
|
||||
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
|
||||
default="bitsandbytes",
|
||||
@ -212,9 +208,7 @@ class QuantizationArguments:
|
||||
|
||||
@dataclass
|
||||
class ProcessorArguments:
|
||||
r"""
|
||||
Arguments pertaining to the image processor.
|
||||
"""
|
||||
r"""Arguments pertaining to the image processor."""
|
||||
|
||||
image_max_pixels: int = field(
|
||||
default=768 * 768,
|
||||
@ -244,9 +238,7 @@ class ProcessorArguments:
|
||||
|
||||
@dataclass
|
||||
class ExportArguments:
|
||||
r"""
|
||||
Arguments pertaining to the model export.
|
||||
"""
|
||||
r"""Arguments pertaining to the model export."""
|
||||
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
@ -292,9 +284,7 @@ class ExportArguments:
|
||||
|
||||
@dataclass
|
||||
class VllmArguments:
|
||||
r"""
|
||||
Arguments pertaining to the vLLM worker.
|
||||
"""
|
||||
r"""Arguments pertaining to the vLLM worker."""
|
||||
|
||||
vllm_maxlen: int = field(
|
||||
default=4096,
|
||||
@ -324,8 +314,7 @@ class VllmArguments:
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
|
||||
r"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
|
||||
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
|
||||
|
||||
The class on the most right will be displayed first.
|
||||
"""
|
||||
@ -335,7 +324,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
|
||||
init=False,
|
||||
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
|
||||
)
|
||||
device_map: Optional[Union[str, Dict[str, Any]]] = field(
|
||||
device_map: Optional[Union[str, dict[str, Any]]] = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
|
||||
@ -372,7 +361,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
|
||||
|
||||
return result
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
args = asdict(self)
|
||||
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
|
||||
return args
|
||||
|
@ -19,7 +19,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@ -47,17 +47,15 @@ check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
|
||||
|
||||
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
|
||||
r"""
|
||||
Gets arguments from the command line or a config file.
|
||||
"""
|
||||
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
|
||||
r"""Get arguments from the command line or a config file."""
|
||||
if args is not None:
|
||||
return args
|
||||
|
||||
@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[
|
||||
|
||||
|
||||
def _parse_args(
|
||||
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
|
||||
) -> Tuple[Any]:
|
||||
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
|
||||
) -> tuple[Any]:
|
||||
args = read_args(args)
|
||||
if isinstance(args, dict):
|
||||
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
|
||||
@ -161,31 +159,31 @@ def _check_extra_dependencies(
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
|
||||
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
|
||||
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
|
||||
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
|
||||
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
|
||||
parser = HfArgumentParser(RayArguments)
|
||||
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
|
||||
return ray_args
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
|
||||
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
@ -364,9 +362,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
and training_args.resume_from_checkpoint is not None
|
||||
):
|
||||
logger.warning_rank0(
|
||||
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
)
|
||||
f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
|
||||
)
|
||||
|
||||
# Post-process model arguments
|
||||
@ -382,20 +378,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
|
||||
# Log on each process the small summary
|
||||
logger.info(
|
||||
"Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
|
||||
training_args.process_index,
|
||||
training_args.world_size,
|
||||
training_args.device,
|
||||
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
|
||||
str(model_args.compute_dtype),
|
||||
)
|
||||
f"Process rank: {training_args.process_index}, "
|
||||
f"world size: {training_args.world_size}, device: {training_args.device}, "
|
||||
f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
|
||||
f"compute dtype: {str(model_args.compute_dtype)}"
|
||||
)
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
|
||||
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
@ -426,7 +419,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
|
||||
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
|
@ -10,9 +10,7 @@ from ..extras.misc import use_ray
|
||||
|
||||
@dataclass
|
||||
class RayArguments:
|
||||
r"""
|
||||
Arguments pertaining to the Ray training.
|
||||
"""
|
||||
r"""Arguments pertaining to the Ray training."""
|
||||
|
||||
ray_run_name: Optional[str] = field(
|
||||
default=None,
|
||||
@ -43,9 +41,7 @@ class RayArguments:
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
|
||||
r"""
|
||||
Arguments pertaining to the trainer.
|
||||
"""
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
def __post_init__(self):
|
||||
Seq2SeqTrainingArguments.__post_init__(self)
|
||||
|
@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params
|
||||
|
||||
__all__ = [
|
||||
"QuantizationMethod",
|
||||
"find_all_linear_modules",
|
||||
"load_config",
|
||||
"load_model",
|
||||
"load_tokenizer",
|
||||
"find_all_linear_modules",
|
||||
"load_valuehead_params",
|
||||
]
|
||||
|
@ -81,9 +81,8 @@ def _setup_freeze_tuning(
|
||||
if finetuning_args.use_llama_pro:
|
||||
if num_layers % finetuning_args.freeze_trainable_layers != 0:
|
||||
raise ValueError(
|
||||
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
|
||||
num_layers, finetuning_args.freeze_trainable_layers
|
||||
)
|
||||
f"`num_layers` {num_layers} should be "
|
||||
f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
|
||||
)
|
||||
|
||||
stride = num_layers // finetuning_args.freeze_trainable_layers
|
||||
@ -178,7 +177,7 @@ def _setup_lora_tuning(
|
||||
}
|
||||
|
||||
for adapter in adapter_to_merge:
|
||||
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
|
||||
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(adapter_to_merge) > 0:
|
||||
@ -263,8 +262,7 @@ def init_adapter(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool,
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
r"""Initialize the adapters.
|
||||
|
||||
Support full-parameter, freeze and LoRA training.
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
@ -51,9 +51,8 @@ class TokenizerModule(TypedDict):
|
||||
processor: Optional["ProcessorMixin"]
|
||||
|
||||
|
||||
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
r"""
|
||||
Gets arguments to load config/tokenizer/model.
|
||||
def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
|
||||
r"""Get arguments to load config/tokenizer/model.
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
@ -68,8 +67,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
|
||||
|
||||
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
r"""
|
||||
Loads pretrained tokenizer and optionally loads processor.
|
||||
r"""Load pretrained tokenizer and optionally loads processor.
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
@ -110,9 +108,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
|
||||
|
||||
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
|
||||
r"""
|
||||
Loads model config.
|
||||
"""
|
||||
r"""Load model config."""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
|
||||
@ -124,9 +120,7 @@ def load_model(
|
||||
is_trainable: bool = False,
|
||||
add_valuehead: bool = False,
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Loads pretrained model.
|
||||
"""
|
||||
r"""Load pretrained model."""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
config = load_config(model_args)
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
@ -194,8 +188,9 @@ def load_model(
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
if is_trainable:
|
||||
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
param_stats = (
|
||||
f"trainable params: {trainable_params:,} || "
|
||||
f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
|
||||
)
|
||||
else:
|
||||
param_stats = f"all params: {all_param:,}"
|
||||
|
@ -21,7 +21,7 @@
|
||||
import inspect
|
||||
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
def get_unsloth_gradient_checkpointing_func() -> Callable:
|
||||
class UnslothGradientCheckpointing(torch.autograd.Function):
|
||||
r"""
|
||||
Saves VRAM by smartly offloading to RAM.
|
||||
"""
|
||||
r"""Saves VRAM by smartly offloading to RAM."""
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
@ -77,13 +75,11 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
|
||||
|
||||
|
||||
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
|
||||
r"""
|
||||
Only applies gradient checkpointing to trainable layers.
|
||||
"""
|
||||
r"""Only applies gradient checkpointing to trainable layers."""
|
||||
|
||||
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
|
||||
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
|
||||
module: "torch.nn.Module" = func.__self__
|
||||
module: torch.nn.Module = func.__self__
|
||||
|
||||
has_grad = False
|
||||
if any(param.requires_grad for param in module.parameters()):
|
||||
@ -103,11 +99,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
|
||||
|
||||
def _gradient_checkpointing_enable(
|
||||
self: "PreTrainedModel",
|
||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
|
||||
gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
|
||||
use_unsloth_gc: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Activates gradient checkpointing for the current model.
|
||||
r"""Activates gradient checkpointing for the current model.
|
||||
|
||||
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
|
||||
"""
|
||||
@ -134,17 +129,18 @@ def _gradient_checkpointing_enable(
|
||||
|
||||
|
||||
def _fp32_forward_post_hook(
|
||||
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
|
||||
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
|
||||
) -> "torch.Tensor":
|
||||
return output.to(torch.float32)
|
||||
|
||||
|
||||
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
|
||||
r"""
|
||||
Includes:
|
||||
(1) cast the layernorm in fp32
|
||||
(2) make output embedding layer require grads
|
||||
(3) add the upcasting of the lm_head in fp32
|
||||
r"""Prepare the model before training.
|
||||
|
||||
Include:
|
||||
(1) cast the layernorm in fp32
|
||||
(2) make output embedding layer require grads
|
||||
(3) add the upcasting of the lm_head in fp32.
|
||||
"""
|
||||
if model_args.upcast_layernorm:
|
||||
logger.info_rank0("Upcasting layernorm weights in float32.")
|
||||
|
@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
|
||||
|
||||
|
||||
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
|
||||
r"""
|
||||
Resize token embeddings.
|
||||
"""
|
||||
r"""Resize token embeddings."""
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed # type: ignore
|
||||
|
||||
|
@ -18,7 +18,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -54,14 +54,14 @@ def llama_attention_forward(
|
||||
past_key_value: Optional["Cache"] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional["torch.LongTensor"] = None,
|
||||
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
|
||||
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
|
||||
**kwargs,
|
||||
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
|
||||
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states: "torch.Tensor" = self.q_proj(hidden_states)
|
||||
key_states: "torch.Tensor" = self.k_proj(hidden_states)
|
||||
value_states: "torch.Tensor" = self.v_proj(hidden_states)
|
||||
query_states: torch.Tensor = self.q_proj(hidden_states)
|
||||
key_states: torch.Tensor = self.k_proj(hidden_states)
|
||||
value_states: torch.Tensor = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
|
||||
past_key_value: Optional["Cache"] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional["torch.LongTensor"] = None,
|
||||
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
|
||||
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
|
||||
**kwargs,
|
||||
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
|
||||
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
|
||||
# LlamaFlashAttention2 attention does not support output_attentions
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states: "torch.Tensor" = self.q_proj(hidden_states)
|
||||
key_states: "torch.Tensor" = self.k_proj(hidden_states)
|
||||
value_states: "torch.Tensor" = self.v_proj(hidden_states)
|
||||
query_states: torch.Tensor = self.q_proj(hidden_states)
|
||||
key_states: torch.Tensor = self.k_proj(hidden_states)
|
||||
value_states: torch.Tensor = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
|
||||
if is_transformers_version_greater_than("4.43.0"):
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
attn_output: "torch.Tensor" = _flash_attention_forward(
|
||||
attn_output: torch.Tensor = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
|
||||
is_causal=self.is_causal,
|
||||
)
|
||||
else:
|
||||
attn_output: "torch.Tensor" = self._flash_attention_forward(
|
||||
attn_output: torch.Tensor = self._flash_attention_forward(
|
||||
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
|
||||
)
|
||||
|
||||
@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
|
||||
past_key_value: Optional["Cache"] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional["torch.LongTensor"] = None,
|
||||
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
|
||||
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
|
||||
**kwargs,
|
||||
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
|
||||
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
|
||||
if output_attentions:
|
||||
transformers_logger.warning_once(
|
||||
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
|
||||
@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states: "torch.Tensor" = self.q_proj(hidden_states)
|
||||
key_states: "torch.Tensor" = self.k_proj(hidden_states)
|
||||
value_states: "torch.Tensor" = self.v_proj(hidden_states)
|
||||
query_states: torch.Tensor = self.q_proj(hidden_states)
|
||||
key_states: torch.Tensor = self.k_proj(hidden_states)
|
||||
value_states: torch.Tensor = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras import logging
|
||||
from .visual import COMPOSITE_MODELS
|
||||
@ -25,10 +25,8 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
|
||||
r"""
|
||||
Finds all available modules to apply LoRA, GaLore or APOLLO.
|
||||
"""
|
||||
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
|
||||
r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
|
||||
model_type = getattr(model.config, "model_type", None)
|
||||
forbidden_modules = {"lm_head"}
|
||||
if model_type == "chatglm":
|
||||
@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
|
||||
return list(module_names)
|
||||
|
||||
|
||||
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
|
||||
r"""
|
||||
Finds the modules in the expanded blocks to apply lora.
|
||||
"""
|
||||
def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
|
||||
r"""Find the modules in the expanded blocks to apply lora."""
|
||||
num_layers = getattr(model.config, "num_hidden_layers", None)
|
||||
if not num_layers:
|
||||
raise ValueError("Model was not supported.")
|
||||
|
@ -12,7 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
@ -34,9 +35,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch
|
||||
|
||||
|
||||
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
r"""
|
||||
Sets module as a leaf module to skip partitioning in deepspeed zero3.
|
||||
"""
|
||||
r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
return
|
||||
|
||||
|
@ -37,7 +37,7 @@
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Gets the sequnce lengths in the current batch.
|
||||
r"""Get the sequnce lengths in the current batch.
|
||||
|
||||
e.g.
|
||||
```python
|
||||
@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
|
||||
bsz = attention_mask.size(0)
|
||||
dtype, device = attention_mask.dtype, attention_mask.device
|
||||
max_num = torch.max(attention_mask).item()
|
||||
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
|
||||
counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
|
||||
for i in range(max_num):
|
||||
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
|
||||
|
||||
@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
|
||||
return seqlens
|
||||
|
||||
|
||||
def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
|
||||
r"""
|
||||
Prepares the indices and seqlens for flash attn varlen function.
|
||||
def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
|
||||
r"""Prepare the indices and seqlens for flash attn varlen function.
|
||||
|
||||
Returns:
|
||||
indices: indices of non-masked tokens from the flattened sequence.
|
||||
@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
|
||||
[0, 2, 5, 6, 8, 11]
|
||||
3
|
||||
```
|
||||
|
||||
"""
|
||||
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
|
@ -19,7 +19,7 @@
|
||||
import os
|
||||
import random
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
@ -43,9 +43,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
@unique
|
||||
class QuantizationMethod(str, Enum):
|
||||
r"""
|
||||
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
|
||||
"""
|
||||
r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
|
||||
|
||||
BITS_AND_BYTES = "bitsandbytes"
|
||||
GPTQ = "gptq"
|
||||
@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum):
|
||||
HQQ = "hqq"
|
||||
|
||||
|
||||
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
|
||||
r"""
|
||||
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
|
||||
"""
|
||||
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
|
||||
r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
|
||||
if os.path.isfile(model_args.export_quantization_dataset):
|
||||
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
|
||||
data_files = model_args.export_quantization_dataset
|
||||
@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
|
||||
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
|
||||
|
||||
sample_idx = random.randint(0, len(dataset) - 1)
|
||||
sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||
sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||
n_try += 1
|
||||
if sample["input_ids"].size(1) > maxlen:
|
||||
break # TODO: fix large maxlen
|
||||
@ -101,11 +97,9 @@ def configure_quantization(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
init_kwargs: Dict[str, Any],
|
||||
init_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
r"""
|
||||
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
|
||||
"""
|
||||
r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
|
||||
if getattr(config, "quantization_config", None): # ptq
|
||||
if model_args.quantization_bit is not None:
|
||||
logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
|
||||
@ -113,7 +107,7 @@ def configure_quantization(
|
||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
|
||||
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
|
||||
if quant_method == QuantizationMethod.GPTQ:
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.misc import get_current_device
|
||||
@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
def _get_unsloth_kwargs(
|
||||
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"model_name": model_name_or_path,
|
||||
"max_seq_length": model_args.model_max_length or 4096,
|
||||
@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
|
||||
def load_unsloth_pretrained_model(
|
||||
config: "PretrainedConfig", model_args: "ModelArguments"
|
||||
) -> Optional["PreTrainedModel"]:
|
||||
r"""
|
||||
Optionally loads pretrained model with unsloth. Used in training.
|
||||
"""
|
||||
from unsloth import FastLanguageModel
|
||||
r"""Optionally load pretrained model with unsloth. Used in training."""
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
|
||||
try:
|
||||
@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(
|
||||
|
||||
|
||||
def get_unsloth_peft_model(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Gets the peft model for the pretrained model with unsloth. Used in training.
|
||||
"""
|
||||
from unsloth import FastLanguageModel
|
||||
r"""Get the peft model for the pretrained model with unsloth. Used in training."""
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_peft_kwargs = {
|
||||
"model": model,
|
||||
@ -82,10 +78,8 @@ def get_unsloth_peft_model(
|
||||
def load_unsloth_peft_model(
|
||||
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Loads peft model with unsloth. Used in both training and inference.
|
||||
"""
|
||||
from unsloth import FastLanguageModel
|
||||
r"""Load peft model with unsloth. Used in both training and inference."""
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
|
||||
try:
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from transformers.utils import cached_file
|
||||
@ -30,9 +30,8 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Loads value head parameters from Hugging Face Hub or local disk.
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
|
||||
r"""Load value head parameters from Hugging Face Hub or local disk.
|
||||
|
||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||
"""
|
||||
|
@ -15,8 +15,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@ -40,9 +41,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
|
||||
class CompositeModel:
|
||||
model_type: str
|
||||
projector_key: str
|
||||
vision_model_keys: List[str]
|
||||
language_model_keys: List[str]
|
||||
lora_conflict_keys: List[str]
|
||||
vision_model_keys: list[str]
|
||||
language_model_keys: list[str]
|
||||
lora_conflict_keys: list[str]
|
||||
|
||||
def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
|
||||
for key in self.projector_key.split("."):
|
||||
@ -51,15 +52,15 @@ class CompositeModel:
|
||||
return module
|
||||
|
||||
|
||||
COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
|
||||
COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
|
||||
|
||||
|
||||
def _register_composite_model(
|
||||
model_type: str,
|
||||
projector_key: Optional[str] = None,
|
||||
vision_model_keys: Optional[List[str]] = None,
|
||||
language_model_keys: Optional[List[str]] = None,
|
||||
lora_conflict_keys: Optional[List[str]] = None,
|
||||
vision_model_keys: Optional[list[str]] = None,
|
||||
language_model_keys: Optional[list[str]] = None,
|
||||
lora_conflict_keys: Optional[list[str]] = None,
|
||||
):
|
||||
COMPOSITE_MODELS[model_type] = CompositeModel(
|
||||
model_type=model_type,
|
||||
@ -116,12 +117,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
|
||||
|
||||
|
||||
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
|
||||
r"""
|
||||
Casts projector output to half precision for fine-tuning quantized VLMs.
|
||||
"""
|
||||
r"""Cast projector output to half precision for fine-tuning quantized VLMs."""
|
||||
|
||||
def _mm_projector_forward_post_hook(
|
||||
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
|
||||
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
|
||||
) -> "torch.Tensor":
|
||||
return output.to(model_args.compute_dtype)
|
||||
|
||||
@ -137,9 +136,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
|
||||
|
||||
|
||||
def configure_visual_model(config: "PretrainedConfig") -> None:
|
||||
r"""
|
||||
Patches VLMs before loading them.
|
||||
"""
|
||||
r"""Patch VLMs before loading them."""
|
||||
if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
|
||||
# required for ds zero3 and valuehead models
|
||||
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
|
||||
@ -149,10 +146,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
|
||||
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
|
||||
|
||||
|
||||
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
|
||||
r"""
|
||||
Freezes vision tower and language model for VLM full/freeze tuning.
|
||||
"""
|
||||
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
|
||||
r"""Freeze vision tower and language model for VLM full/freeze tuning."""
|
||||
model_type = getattr(config, "model_type", None)
|
||||
forbidden_modules = set()
|
||||
if model_type in COMPOSITE_MODELS:
|
||||
@ -175,9 +170,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
|
||||
|
||||
|
||||
def get_image_seqlen(config: "PretrainedConfig") -> int:
|
||||
r"""
|
||||
Computes the number of special tokens per image.
|
||||
"""
|
||||
r"""Compute the number of special tokens per image."""
|
||||
model_type = getattr(config, "model_type", None)
|
||||
if model_type == "llava":
|
||||
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
|
||||
@ -192,17 +185,13 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
|
||||
|
||||
|
||||
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
|
||||
r"""
|
||||
Computes the patch size of the vit.
|
||||
"""
|
||||
r"""Compute the patch size of the vit."""
|
||||
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
|
||||
return patch_size
|
||||
|
||||
|
||||
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
|
||||
r"""
|
||||
Get the vision_feature_select_strategy.
|
||||
"""
|
||||
r"""Get the vision_feature_select_strategy."""
|
||||
vision_feature_select_strategy = getattr(
|
||||
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
|
||||
)
|
||||
@ -211,10 +200,8 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P
|
||||
|
||||
def patch_target_modules(
|
||||
model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
|
||||
) -> List[str]:
|
||||
r"""
|
||||
Freezes vision tower for VLM LoRA tuning.
|
||||
"""
|
||||
) -> list[str]:
|
||||
r"""Freezes vision tower for VLM LoRA tuning."""
|
||||
model_type = getattr(model.config, "model_type", None)
|
||||
if model_type in COMPOSITE_MODELS:
|
||||
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
@ -93,7 +93,7 @@ def patch_config(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
init_kwargs: Dict[str, Any],
|
||||
init_kwargs: dict[str, Any],
|
||||
is_trainable: bool,
|
||||
) -> None:
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
|
@ -19,7 +19,7 @@ import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
|
||||
def fix_valuehead_checkpoint(
|
||||
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
|
||||
) -> None:
|
||||
r"""
|
||||
r"""Fix the valuehead checkpoint files.
|
||||
|
||||
The model is already unwrapped.
|
||||
|
||||
There are three cases:
|
||||
@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
|
||||
if safe_serialization:
|
||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
decoder_state_dict, v_head_state_dict = {}, {}
|
||||
@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for fixing the checkpoint for valuehead models.
|
||||
"""
|
||||
r"""A callback for fixing the checkpoint for valuehead models."""
|
||||
|
||||
@override
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
|
||||
|
||||
|
||||
class SaveProcessorCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for saving the processor.
|
||||
"""
|
||||
r"""A callback for saving the processor."""
|
||||
|
||||
def __init__(self, processor: "ProcessorMixin") -> None:
|
||||
self.processor = processor
|
||||
@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
|
||||
|
||||
|
||||
class PissaConvertCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for converting the PiSSA adapter to a normal one.
|
||||
"""
|
||||
r"""A callback for converting the PiSSA adapter to a normal one."""
|
||||
|
||||
@override
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
@ -177,9 +172,7 @@ class PissaConvertCallback(TrainerCallback):
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for logging training and evaluation status.
|
||||
"""
|
||||
r"""A callback for logging training and evaluation status."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Progress
|
||||
@ -188,7 +181,7 @@ class LogCallback(TrainerCallback):
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
self.thread_pool: Optional["ThreadPoolExecutor"] = None
|
||||
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
# Status
|
||||
self.aborted = False
|
||||
self.do_train = False
|
||||
@ -219,7 +212,7 @@ class LogCallback(TrainerCallback):
|
||||
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
||||
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
||||
|
||||
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
|
||||
def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
|
||||
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
@ -348,9 +341,7 @@ class LogCallback(TrainerCallback):
|
||||
|
||||
|
||||
class ReporterCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for reporting training status to external logger.
|
||||
"""
|
||||
r"""A callback for reporting training status to external logger."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -19,7 +19,7 @@ import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -129,15 +129,11 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
@override
|
||||
def get_batch_samples(self, epoch_iterator, num_batches):
|
||||
r"""
|
||||
Replaces the method of KTO Trainer with the one of the standard Trainer.
|
||||
"""
|
||||
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
|
||||
|
||||
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
|
||||
"""
|
||||
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
|
||||
log_odds = (chosen_logps - rejected_logps) - (
|
||||
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
|
||||
)
|
||||
@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
return orpo_loss
|
||||
|
||||
def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Computes SimPO loss for batched log probabilities of the policy model.
|
||||
"""
|
||||
r"""Compute SimPO loss for batched log probabilities of the policy model."""
|
||||
pi_logratios = chosen_logps - rejected_logps
|
||||
gamma_logratios = self.simpo_gamma / self.beta
|
||||
logits = pi_logratios - gamma_logratios
|
||||
@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
policy_rejected_logps: "torch.Tensor",
|
||||
reference_chosen_logps: Optional["torch.Tensor"],
|
||||
reference_rejected_logps: Optional["torch.Tensor"],
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Computes loss for preference learning.
|
||||
"""
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Compute loss for preference learning."""
|
||||
if not self.finetuning_args.use_ref_model:
|
||||
if self.loss_type == "orpo":
|
||||
losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
|
||||
@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
@override
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
|
||||
|
||||
Otherwise the average log probabilities.
|
||||
"""
|
||||
if self.finetuning_args.use_ref_model:
|
||||
batch = nested_detach(batch, clone=True) # avoid error
|
||||
|
||||
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
all_logps = all_logps / valid_length
|
||||
@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
@override
|
||||
def compute_reference_log_probs(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
|
||||
r"""
|
||||
Computes log probabilities of the reference model.
|
||||
"""
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
|
||||
) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
|
||||
r"""Compute log probabilities of the reference model."""
|
||||
if not self.finetuning_args.use_ref_model:
|
||||
return None, None
|
||||
|
||||
@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, "torch.Tensor"],
|
||||
batch: dict[str, "torch.Tensor"],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
"""
|
||||
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
|
||||
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
(
|
||||
policy_chosen_logps,
|
||||
@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||
r"""
|
||||
Subclass and override to accept extra kwargs.
|
||||
"""
|
||||
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
|
||||
r"""Subclass and override to accept extra kwargs."""
|
||||
return super().compute_loss(model, inputs, return_outputs)
|
||||
|
||||
@override
|
||||
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
|
||||
r"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
"""
|
||||
def log(self, logs: dict[str, float], *args, **kwargs) -> None:
|
||||
r"""Log `logs` on the various objects watching training, including stored metrics."""
|
||||
# logs either has "loss" or "eval_loss"
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
|
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@ -38,7 +38,7 @@ def run_dpo(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
|
@ -19,7 +19,7 @@ import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@ -120,9 +120,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
|
||||
r"""
|
||||
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
|
||||
"""
|
||||
r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
|
||||
if self.finetuning_args.disable_shuffling:
|
||||
return torch.utils.data.SequentialSampler(self.train_dataset)
|
||||
|
||||
@ -130,18 +128,14 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def get_batch_samples(self, epoch_iterator, num_batches):
|
||||
r"""
|
||||
Replaces the method of KTO Trainer with the one of the standard Trainer.
|
||||
"""
|
||||
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
|
||||
|
||||
@override
|
||||
def forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Runs forward pass and computes the log probabilities.
|
||||
"""
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Run forward pass and computes the log probabilities."""
|
||||
batch = nested_detach(batch, clone=True) # avoid error
|
||||
model_inputs = {
|
||||
"input_ids": batch[f"{prefix}input_ids"],
|
||||
@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
target_logits, target_logps, target_logps_avg = self.forward(model, batch)
|
||||
with torch.no_grad():
|
||||
_, kl_logps, _ = self.forward(model, batch, prefix="kl_")
|
||||
@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def compute_reference_log_probs(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Computes log probabilities of the reference model.
|
||||
"""
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Compute log probabilities of the reference model."""
|
||||
if self.ref_model is None:
|
||||
ref_model = model
|
||||
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
|
||||
@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, "torch.Tensor"],
|
||||
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
"""
|
||||
batch: dict[str, "torch.Tensor"],
|
||||
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
|
||||
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
(
|
||||
policy_chosen_logps,
|
||||
@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||
r"""
|
||||
Subclass and override to accept extra kwargs.
|
||||
"""
|
||||
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
|
||||
r"""Subclass and override to accept extra kwargs."""
|
||||
return super().compute_loss(model, inputs, return_outputs)
|
||||
|
||||
@override
|
||||
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
|
||||
r"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
"""
|
||||
def log(self, logs: dict[str, float], *args, **kwargs) -> None:
|
||||
r"""Log `logs` on the various objects watching training, including stored metrics."""
|
||||
# logs either has "loss" or "eval_loss"
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
|
||||
metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
|
||||
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
|
||||
metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
|
||||
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
|
||||
if f"count/{split}" in metric_dict:
|
||||
for key in ("rewards", "logps", "logits"):
|
||||
|
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@ -37,7 +37,7 @@ def run_kto(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import torch
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
@ -31,10 +31,8 @@ if TYPE_CHECKING:
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
|
||||
def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
|
||||
r"""
|
||||
Gets reward scores from the API server.
|
||||
"""
|
||||
def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
|
||||
r"""Get reward scores from the API server."""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
payload = {"model": "model", "messages": messages}
|
||||
response = requests.post(server_url, json=payload, headers=headers)
|
||||
@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch
|
||||
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
r"""
|
||||
Replaces the default/reward modules in the model. The model is already unwrapped.
|
||||
"""
|
||||
r"""Replace the default/reward modules in the model. The model is already unwrapped."""
|
||||
v_head_layer = model.v_head.summary
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed # type: ignore
|
||||
@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
|
||||
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
|
||||
|
||||
|
||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
|
||||
"""
|
||||
def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
|
||||
r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
|
||||
layer_norm_params = {}
|
||||
for name, param in model.named_parameters():
|
||||
if param.data.dtype == torch.float32:
|
||||
@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
|
||||
return layer_norm_params
|
||||
|
||||
|
||||
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||
r"""
|
||||
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
|
||||
"""
|
||||
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
|
||||
r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
|
||||
for name, param in model.named_parameters():
|
||||
if name in layernorm_params:
|
||||
param.data = layernorm_params[name]
|
||||
|
@ -20,7 +20,7 @@ import os
|
||||
import sys
|
||||
import warnings
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
r"""Inherit PPOTrainer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]],
|
||||
callbacks: Optional[list["TrainerCallback"]],
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
reward_model: Optional["AutoModelForCausalLMWithValueHead"],
|
||||
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
|
||||
@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
|
||||
r"""
|
||||
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
|
||||
"""
|
||||
r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
|
||||
if resume_from_checkpoint is not None:
|
||||
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
|
||||
|
||||
@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
|
||||
logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
|
||||
logger.info_rank0(
|
||||
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
|
||||
total_train_batch_size
|
||||
)
|
||||
f" Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
|
||||
)
|
||||
logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
|
||||
logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
|
||||
@ -339,21 +333,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
return lr_scheduler
|
||||
|
||||
@torch.no_grad()
|
||||
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
|
||||
r"""
|
||||
Generates model's responses given queries.
|
||||
"""
|
||||
def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
|
||||
r"""Generate model's responses given queries."""
|
||||
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
||||
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
for k, v in batch.items():
|
||||
batch[k] = v[:, start_index:]
|
||||
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
|
||||
if self.model_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(unwrapped_model)
|
||||
|
||||
generate_output: "torch.Tensor" = unwrapped_model.generate(
|
||||
generate_output: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
|
||||
)
|
||||
if self.model_args.upcast_layernorm:
|
||||
@ -381,11 +373,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
@torch.no_grad()
|
||||
def get_rewards(
|
||||
self,
|
||||
queries: List["torch.Tensor"],
|
||||
responses: List["torch.Tensor"],
|
||||
) -> List["torch.Tensor"]:
|
||||
r"""
|
||||
Computes scores using given reward model.
|
||||
queries: list["torch.Tensor"],
|
||||
responses: list["torch.Tensor"],
|
||||
) -> list["torch.Tensor"]:
|
||||
r"""Compute scores using given reward model.
|
||||
|
||||
Both inputs and outputs are put on CPU.
|
||||
"""
|
||||
@ -394,8 +385,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
|
||||
return get_rewards_from_server(self.reward_model, messages)
|
||||
|
||||
batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
|
||||
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
@ -404,7 +395,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
reward_model = self.reward_model
|
||||
|
||||
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
|
||||
values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]
|
||||
values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
@ -419,12 +410,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
queries: "torch.Tensor",
|
||||
responses: "torch.Tensor",
|
||||
model_inputs: Dict[str, Any],
|
||||
model_inputs: dict[str, Any],
|
||||
return_logits: bool = False,
|
||||
response_masks: Optional["torch.Tensor"] = None,
|
||||
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Calculates model outputs in multiple batches.
|
||||
) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Calculate model outputs in multiple batches.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
@ -483,8 +473,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
@override
|
||||
def save_model(self, output_dir: Optional[str] = None) -> None:
|
||||
r"""
|
||||
Saves model checkpoint.
|
||||
r"""Save model checkpoint.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
@ -508,5 +497,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.model.save_checkpoint(output_dir)
|
||||
|
||||
elif self.args.should_save:
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
|
||||
self._save(output_dir, state_dict=unwrapped_model.state_dict())
|
||||
|
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.ploting import plot_loss
|
||||
@ -37,7 +37,7 @@ def run_ppo(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
@ -53,7 +53,7 @@ def run_ppo(
|
||||
reward_model = create_reward_model(model, model_args, finetuning_args)
|
||||
|
||||
# Initialize our Trainer
|
||||
ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
|
||||
ppo_trainer: CustomPPOTrainer = CustomPPOTrainer(
|
||||
model_args=model_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
|
@ -31,9 +31,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class CustomTrainer(Trainer):
|
||||
r"""
|
||||
Inherits Trainer for custom optimizer.
|
||||
"""
|
||||
r"""Inherit Trainer for custom optimizer."""
|
||||
|
||||
def __init__(
|
||||
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
|
||||
|
@ -16,7 +16,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
@ -38,7 +38,7 @@ def run_pt(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -26,11 +26,9 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class ComputeAccuracy:
|
||||
r"""
|
||||
Computes reward accuracy and supports `batch_eval_metrics`.
|
||||
"""
|
||||
r"""Compute reward accuracy and support `batch_eval_metrics`."""
|
||||
|
||||
def _dump(self) -> Optional[Dict[str, float]]:
|
||||
def _dump(self) -> Optional[dict[str, float]]:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
|
||||
@ -41,7 +39,7 @@ class ComputeAccuracy:
|
||||
def __post_init__(self):
|
||||
self._dump()
|
||||
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
|
||||
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
|
||||
if not chosen_scores.shape:
|
||||
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
|
||||
|
@ -18,7 +18,7 @@
|
||||
import json
|
||||
import os
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class PairwiseTrainer(Trainer):
|
||||
r"""
|
||||
Inherits Trainer to compute pairwise loss.
|
||||
"""
|
||||
r"""Inherits Trainer to compute pairwise loss."""
|
||||
|
||||
def __init__(
|
||||
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
|
||||
@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
|
||||
r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
|
||||
return loss
|
||||
|
||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
r"""Save model predictions to `output_dir`.
|
||||
|
||||
A custom behavior that not contained in Seq2SeqTrainer.
|
||||
"""
|
||||
@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
|
||||
chosen_scores, rejected_scores = predict_results.predictions
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
res: list[str] = []
|
||||
for c_score, r_score in zip(chosen_scores, rejected_scores):
|
||||
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.ploting import plot_loss
|
||||
@ -37,7 +37,7 @@ def run_rm(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
|
@ -17,7 +17,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -45,9 +45,7 @@ if is_rouge_available():
|
||||
|
||||
|
||||
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Computes the token with the largest likelihood to reduce memory footprint.
|
||||
"""
|
||||
r"""Compute the token with the largest likelihood to reduce memory footprint."""
|
||||
if isinstance(logits, (list, tuple)):
|
||||
if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
|
||||
logits = logits[0]
|
||||
@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
|
||||
|
||||
@dataclass
|
||||
class ComputeAccuracy:
|
||||
r"""
|
||||
Computes accuracy and supports `batch_eval_metrics`.
|
||||
"""
|
||||
r"""Compute accuracy and support `batch_eval_metrics`."""
|
||||
|
||||
def _dump(self) -> Optional[Dict[str, float]]:
|
||||
def _dump(self) -> Optional[dict[str, float]]:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
|
||||
@ -77,7 +73,7 @@ class ComputeAccuracy:
|
||||
def __post_init__(self):
|
||||
self._dump()
|
||||
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
|
||||
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
|
||||
for i in range(len(preds)):
|
||||
pred, label = preds[i, :-1], labels[i, 1:]
|
||||
@ -90,15 +86,14 @@ class ComputeAccuracy:
|
||||
|
||||
@dataclass
|
||||
class ComputeSimilarity:
|
||||
r"""
|
||||
Computes text similarity scores and supports `batch_eval_metrics`.
|
||||
r"""Compute text similarity scores and support `batch_eval_metrics`.
|
||||
|
||||
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
|
||||
"""
|
||||
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
|
||||
def _dump(self) -> Optional[Dict[str, float]]:
|
||||
def _dump(self) -> Optional[dict[str, float]]:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
|
||||
@ -109,7 +104,7 @@ class ComputeSimilarity:
|
||||
def __post_init__(self):
|
||||
self._dump()
|
||||
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
|
||||
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
|
||||
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
|
@ -18,7 +18,7 @@
|
||||
import json
|
||||
import os
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -44,21 +44,19 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
|
||||
"""
|
||||
r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
finetuning_args: "FinetuningArguments",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
gen_kwargs: Optional[Dict[str, Any]] = None,
|
||||
gen_kwargs: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if is_transformers_version_greater_than("4.46"):
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
else:
|
||||
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
|
||||
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
@ -99,13 +97,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
def prediction_step(
|
||||
self,
|
||||
model: "torch.nn.Module",
|
||||
inputs: Dict[str, Union["torch.Tensor", Any]],
|
||||
inputs: dict[str, Union["torch.Tensor", Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
**gen_kwargs,
|
||||
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
|
||||
r"""
|
||||
Removes the prompt part in the generated tokens.
|
||||
) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
|
||||
r"""Remove the prompt part in the generated tokens.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
@ -126,8 +123,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
def save_predictions(
|
||||
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
|
||||
) -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
r"""Save model predictions to `output_dir`.
|
||||
|
||||
A custom behavior that not contained in Seq2SeqTrainer.
|
||||
"""
|
||||
|
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@ -43,7 +43,7 @@ def run_sft(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
|
@ -12,7 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
@ -43,7 +44,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
|
||||
|
||||
|
||||
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
|
||||
def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
|
||||
linear_modules, extra_modules = set(), set()
|
||||
for name, param in model.named_parameters():
|
||||
if any(module in name for module in ["lora_A", "lora_B"]):
|
||||
@ -83,7 +84,7 @@ def load_reference_model(
|
||||
) -> Union["PreTrainedModel", "LoraModel"]:
|
||||
current_device = get_current_device()
|
||||
if add_valuehead:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_path, torch_dtype=torch.float16, device_map=current_device
|
||||
)
|
||||
if not is_trainable:
|
||||
@ -111,7 +112,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":
|
||||
|
||||
|
||||
def patch_valuehead_model() -> None:
|
||||
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
|
||||
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None:
|
||||
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
|
||||
self.v_head.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
|
@ -21,7 +21,7 @@ import json
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DummyOptimizer(torch.optim.Optimizer):
|
||||
r"""
|
||||
A dummy optimizer used for the GaLore or APOLLO algorithm.
|
||||
"""
|
||||
r"""A dummy optimizer used for the GaLore or APOLLO algorithm."""
|
||||
|
||||
def __init__(
|
||||
self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
|
||||
self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
|
||||
) -> None:
|
||||
dummy_tensor = torch.randn(1, 1)
|
||||
self.optimizer_dict = optimizer_dict
|
||||
@ -112,8 +110,7 @@ def create_modelcard_and_push(
|
||||
def create_ref_model(
|
||||
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
|
||||
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
|
||||
r"""
|
||||
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
|
||||
r"""Create reference model for PPO/DPO training. Evaluation mode is not supported.
|
||||
|
||||
The valuehead parameter is randomly initialized since it is useless for PPO training.
|
||||
"""
|
||||
@ -148,9 +145,7 @@ def create_ref_model(
|
||||
def create_reward_model(
|
||||
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
|
||||
) -> Optional["AutoModelForCausalLMWithValueHead"]:
|
||||
r"""
|
||||
Creates reward model for PPO training.
|
||||
"""
|
||||
r"""Create reward model for PPO training."""
|
||||
if finetuning_args.reward_model_type == "api":
|
||||
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
|
||||
logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
|
||||
@ -189,10 +184,8 @@ def create_reward_model(
|
||||
return reward_model
|
||||
|
||||
|
||||
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
|
||||
r"""
|
||||
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
|
||||
"""
|
||||
def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]:
|
||||
r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
|
||||
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
|
||||
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||
return decay_parameters
|
||||
@ -208,7 +201,7 @@ def _create_galore_optimizer(
|
||||
else:
|
||||
galore_targets = finetuning_args.galore_target
|
||||
|
||||
galore_params: List["torch.nn.Parameter"] = []
|
||||
galore_params: list[torch.nn.Parameter] = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
|
||||
for param in module.parameters():
|
||||
@ -224,7 +217,7 @@ def _create_galore_optimizer(
|
||||
|
||||
id_galore_params = {id(param) for param in galore_params}
|
||||
decay_params, nodecay_params = [], [] # they are non-galore parameters
|
||||
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
|
||||
trainable_params: list[torch.nn.Parameter] = [] # galore_params + decay_params + nodecay_params
|
||||
decay_param_names = _get_decay_parameter_names(model)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
@ -251,7 +244,7 @@ def _create_galore_optimizer(
|
||||
if training_args.gradient_accumulation_steps != 1:
|
||||
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
|
||||
|
||||
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
|
||||
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
|
||||
for param in nodecay_params:
|
||||
param_groups = [dict(params=[param], weight_decay=0.0)]
|
||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||
@ -296,7 +289,7 @@ def _create_apollo_optimizer(
|
||||
else:
|
||||
apollo_targets = finetuning_args.apollo_target
|
||||
|
||||
apollo_params: List["torch.nn.Parameter"] = []
|
||||
apollo_params: list[torch.nn.Parameter] = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
|
||||
for param in module.parameters():
|
||||
@ -315,7 +308,7 @@ def _create_apollo_optimizer(
|
||||
|
||||
id_apollo_params = {id(param) for param in apollo_params}
|
||||
decay_params, nodecay_params = [], [] # they are non-apollo parameters
|
||||
trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params
|
||||
trainable_params: list[torch.nn.Parameter] = [] # apollo_params + decay_params + nodecay_params
|
||||
decay_param_names = _get_decay_parameter_names(model)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
@ -338,7 +331,7 @@ def _create_apollo_optimizer(
|
||||
if training_args.gradient_accumulation_steps != 1:
|
||||
raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
|
||||
|
||||
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
|
||||
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
|
||||
for param in nodecay_params:
|
||||
param_groups = [dict(params=[param], weight_decay=0.0)]
|
||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||
@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
|
||||
embedding_lr = finetuning_args.loraplus_lr_embedding
|
||||
|
||||
decay_param_names = _get_decay_parameter_names(model)
|
||||
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
|
||||
param_dict: dict[str, list[torch.nn.Parameter]] = {
|
||||
"lora_a": [],
|
||||
"lora_b": [],
|
||||
"lora_b_nodecay": [],
|
||||
@ -524,7 +517,7 @@ def create_custom_scheduler(
|
||||
) -> None:
|
||||
if optimizer is not None and isinstance(optimizer, DummyOptimizer):
|
||||
optimizer_dict = optimizer.optimizer_dict
|
||||
scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
|
||||
scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
|
||||
|
||||
for param in optimizer_dict.keys():
|
||||
scheduler_dict[param] = get_scheduler(
|
||||
@ -544,13 +537,13 @@ def create_custom_scheduler(
|
||||
|
||||
def get_batch_logps(
|
||||
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor"]:
|
||||
r"""
|
||||
Computes the log probabilities of the given labels under the given logits.
|
||||
) -> tuple["torch.Tensor", "torch.Tensor"]:
|
||||
r"""Compute the log probabilities of the given labels under the given logits.
|
||||
|
||||
Returns:
|
||||
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
|
||||
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
|
||||
|
||||
"""
|
||||
if logits.shape[:-1] != labels.shape:
|
||||
raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
|
||||
@ -564,12 +557,10 @@ def get_batch_logps(
|
||||
|
||||
|
||||
def nested_detach(
|
||||
tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
|
||||
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
|
||||
clone: bool = False,
|
||||
):
|
||||
r"""
|
||||
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
|
||||
"""
|
||||
r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
|
||||
elif isinstance(tensors, Mapping):
|
||||
@ -585,9 +576,7 @@ def nested_detach(
|
||||
|
||||
|
||||
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
|
||||
r"""
|
||||
Gets the callback for logging to SwanLab.
|
||||
"""
|
||||
r"""Get the callback for logging to SwanLab."""
|
||||
import swanlab # type: ignore
|
||||
from swanlab.integration.transformers import SwanLabCallback # type: ignore
|
||||
|
||||
@ -624,7 +613,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
|
||||
|
||||
def get_ray_trainer(
|
||||
training_function: Callable,
|
||||
train_loop_config: Dict[str, Any],
|
||||
train_loop_config: dict[str, Any],
|
||||
ray_args: "RayArguments",
|
||||
) -> "TorchTrainer":
|
||||
if not ray_args.use_ray:
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -48,9 +48,9 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _training_function(config: Dict[str, Any]) -> None:
|
||||
def _training_function(config: dict[str, Any]) -> None:
|
||||
args = config.get("args")
|
||||
callbacks: List[Any] = config.get("callbacks")
|
||||
callbacks: list[Any] = config.get("callbacks")
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
|
||||
callbacks.append(LogCallback())
|
||||
@ -84,7 +84,7 @@ def _training_function(config: Dict[str, Any]) -> None:
|
||||
logger.warning(f"Failed to destroy process group: {e}.")
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
|
||||
def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
|
||||
args = read_args(args)
|
||||
if "-h" in args or "--help" in args:
|
||||
get_train_args(args)
|
||||
@ -103,7 +103,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
|
||||
_training_function(config={"args": args, "callbacks": callbacks})
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
def export_model(args: Optional[dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, _ = get_infer_args(args)
|
||||
|
||||
if model_args.export_dir is None:
|
||||
|
@ -14,7 +14,8 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
@ -37,15 +38,12 @@ if is_gradio_available():
|
||||
|
||||
|
||||
def _escape_html(text: str) -> str:
|
||||
r"""
|
||||
Escapes HTML characters.
|
||||
"""
|
||||
r"""Escape HTML characters."""
|
||||
return text.replace("<", "<").replace(">", ">")
|
||||
|
||||
|
||||
def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
|
||||
r"""
|
||||
Post-processes the response text.
|
||||
def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
|
||||
r"""Post-process the response text.
|
||||
|
||||
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
|
||||
"""
|
||||
@ -74,7 +72,7 @@ class WebChatModel(ChatModel):
|
||||
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
self.engine: Optional["BaseEngine"] = None
|
||||
self.engine: Optional[BaseEngine] = None
|
||||
|
||||
if not lazy_init: # read arguments from command line
|
||||
super().__init__()
|
||||
@ -160,14 +158,13 @@ class WebChatModel(ChatModel):
|
||||
|
||||
@staticmethod
|
||||
def append(
|
||||
chatbot: List[Dict[str, str]],
|
||||
messages: List[Dict[str, str]],
|
||||
chatbot: list[dict[str, str]],
|
||||
messages: list[dict[str, str]],
|
||||
role: str,
|
||||
query: str,
|
||||
escape_html: bool,
|
||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
|
||||
r"""
|
||||
Adds the user input to chatbot.
|
||||
) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
|
||||
r"""Add the user input to chatbot.
|
||||
|
||||
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
|
||||
Output: infer.chatbot, infer.messages, infer.query
|
||||
@ -180,8 +177,8 @@ class WebChatModel(ChatModel):
|
||||
|
||||
def stream(
|
||||
self,
|
||||
chatbot: List[Dict[str, str]],
|
||||
messages: List[Dict[str, str]],
|
||||
chatbot: list[dict[str, str]],
|
||||
messages: list[dict[str, str]],
|
||||
lang: str,
|
||||
system: str,
|
||||
tools: str,
|
||||
@ -193,9 +190,8 @@ class WebChatModel(ChatModel):
|
||||
temperature: float,
|
||||
skip_special_tokens: bool,
|
||||
escape_html: bool,
|
||||
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
|
||||
r"""
|
||||
Generates output text in stream.
|
||||
) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
|
||||
r"""Generate output text in stream.
|
||||
|
||||
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
|
||||
Output: infer.chatbot, infer.messages
|
||||
|
@ -17,7 +17,7 @@ import os
|
||||
import signal
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from psutil import Process
|
||||
from yaml import safe_dump, safe_load
|
||||
@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"
|
||||
|
||||
|
||||
def abort_process(pid: int) -> None:
|
||||
r"""
|
||||
Aborts the processes recursively in a bottom-up way.
|
||||
"""
|
||||
r"""Abort the processes recursively in a bottom-up way."""
|
||||
try:
|
||||
children = Process(pid).children()
|
||||
if children:
|
||||
@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:
|
||||
|
||||
|
||||
def get_save_dir(*paths: str) -> os.PathLike:
|
||||
r"""
|
||||
Gets the path to saved model checkpoints.
|
||||
"""
|
||||
r"""Get the path to saved model checkpoints."""
|
||||
if os.path.sep in paths[-1]:
|
||||
logger.warning_rank0("Found complex path, some features may be not available.")
|
||||
return paths[-1]
|
||||
@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:
|
||||
|
||||
|
||||
def _get_config_path() -> os.PathLike:
|
||||
r"""
|
||||
Gets the path to user config.
|
||||
"""
|
||||
r"""Get the path to user config."""
|
||||
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
|
||||
|
||||
|
||||
def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
|
||||
r"""
|
||||
Loads user config if exists.
|
||||
"""
|
||||
def load_config() -> dict[str, Union[str, dict[str, Any]]]:
|
||||
r"""Load user config if exists."""
|
||||
try:
|
||||
with open(_get_config_path(), encoding="utf-8") as f:
|
||||
return safe_load(f)
|
||||
@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
|
||||
|
||||
|
||||
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
|
||||
r"""
|
||||
Saves user config.
|
||||
"""
|
||||
r"""Save user config."""
|
||||
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
|
||||
user_config = load_config()
|
||||
user_config["lang"] = lang or user_config["lang"]
|
||||
@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
|
||||
|
||||
|
||||
def get_model_path(model_name: str) -> str:
|
||||
r"""
|
||||
Gets the model path according to the model name.
|
||||
"""
|
||||
r"""Get the model path according to the model name."""
|
||||
user_config = load_config()
|
||||
path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
|
||||
path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
|
||||
model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
|
||||
if (
|
||||
use_modelscope()
|
||||
@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str:
|
||||
|
||||
|
||||
def get_template(model_name: str) -> str:
|
||||
r"""
|
||||
Gets the template name if the model is a chat/distill/instruct model.
|
||||
"""
|
||||
r"""Get the template name if the model is a chat/distill/instruct model."""
|
||||
return DEFAULT_TEMPLATE.get(model_name, "default")
|
||||
|
||||
|
||||
def get_time() -> str:
|
||||
r"""
|
||||
Gets current date and time.
|
||||
"""
|
||||
r"""Get current date and time."""
|
||||
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
|
||||
def is_multimodal(model_name: str) -> bool:
|
||||
r"""
|
||||
Judges if the model is a vision language model.
|
||||
"""
|
||||
r"""Judge if the model is a vision language model."""
|
||||
return model_name in MULTIMODAL_SUPPORTED_MODELS
|
||||
|
||||
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
r"""
|
||||
Loads dataset_info.json.
|
||||
"""
|
||||
def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
|
||||
r"""Load dataset_info.json."""
|
||||
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
|
||||
logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
|
||||
return {}
|
||||
@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
return {}
|
||||
|
||||
|
||||
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
r"""
|
||||
Loads the training configuration from config path.
|
||||
"""
|
||||
def load_args(config_path: str) -> Optional[dict[str, Any]]:
|
||||
r"""Load the training configuration from config path."""
|
||||
try:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
return safe_load(f)
|
||||
@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
return None
|
||||
|
||||
|
||||
def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
|
||||
r"""
|
||||
Saves the training configuration to config path.
|
||||
"""
|
||||
def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
|
||||
r"""Save the training configuration to config path."""
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
safe_dump(config_dict, f)
|
||||
|
||||
|
||||
def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
r"""
|
||||
Removes args with NoneType or False or empty string value.
|
||||
"""
|
||||
def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
|
||||
r"""Remove args with NoneType or False or empty string value."""
|
||||
no_skip_keys = ["packing"]
|
||||
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
r"""
|
||||
Generates CLI commands for previewing.
|
||||
"""
|
||||
def gen_cmd(args: dict[str, Any]) -> str:
|
||||
r"""Generate CLI commands for previewing."""
|
||||
cmd_lines = ["llamafactory-cli train "]
|
||||
for k, v in _clean_cmd(args).items():
|
||||
if isinstance(v, dict):
|
||||
@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
return cmd_text
|
||||
|
||||
|
||||
def save_cmd(args: Dict[str, Any]) -> str:
|
||||
r"""
|
||||
Saves CLI commands to launch training.
|
||||
"""
|
||||
def save_cmd(args: dict[str, Any]) -> str:
|
||||
r"""Save CLI commands to launch training."""
|
||||
output_dir = args["output_dir"]
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
|
||||
@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:
|
||||
|
||||
|
||||
def load_eval_results(path: os.PathLike) -> str:
|
||||
r"""
|
||||
Gets scores after evaluation.
|
||||
"""
|
||||
r"""Get scores after evaluation."""
|
||||
with open(path, encoding="utf-8") as f:
|
||||
result = json.dumps(json.load(f), indent=4)
|
||||
|
||||
@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:
|
||||
|
||||
|
||||
def create_ds_config() -> None:
|
||||
r"""
|
||||
Creates deepspeed config in the current directory.
|
||||
"""
|
||||
r"""Create deepspeed config in the current directory."""
|
||||
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
|
||||
ds_config = {
|
||||
"train_batch_size": "auto",
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...data import Role
|
||||
from ...extras.packages import is_gradio_available
|
||||
@ -31,9 +31,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def check_json_schema(text: str, lang: str) -> None:
|
||||
r"""
|
||||
Checks if the json schema is valid.
|
||||
"""
|
||||
r"""Check if the json schema is valid."""
|
||||
try:
|
||||
tools = json.loads(text)
|
||||
if tools:
|
||||
@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None:
|
||||
|
||||
def create_chat_box(
|
||||
engine: "Engine", visible: bool = False
|
||||
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
|
||||
) -> tuple["Component", "Component", dict[str, "Component"]]:
|
||||
lang = engine.manager.get_elem_by_id("top.lang")
|
||||
with gr.Column(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot(type="messages", show_copy_button=True)
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ...extras.constants import DATA_CONFIG
|
||||
from ...extras.packages import is_gradio_available
|
||||
@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int:
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
|
||||
r"""
|
||||
Checks if the dataset is a local dataset.
|
||||
"""
|
||||
r"""Check if the dataset is a local dataset."""
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
@ -59,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
|
||||
def _load_data_file(file_path: str) -> List[Any]:
|
||||
def _load_data_file(file_path: str) -> list[Any]:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
if file_path.endswith(".json"):
|
||||
return json.load(f)
|
||||
@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]:
|
||||
return list(f)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
|
||||
r"""
|
||||
Gets the preview samples from the dataset.
|
||||
"""
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> tuple[int, list, "gr.Column"]:
|
||||
r"""Get the preview samples from the dataset."""
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
|
||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
|
||||
|
||||
|
||||
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
|
||||
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> dict[str, "Component"]:
|
||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
||||
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
|
||||
with gr.Row():
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR
|
||||
@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
def create_eval_tab(engine: "Engine") -> dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
|
@ -12,7 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Union
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from ...extras.constants import PEFT_METHODS
|
||||
from ...extras.misc import torch_gc
|
||||
@ -35,7 +36,7 @@ if TYPE_CHECKING:
|
||||
GPTQ_BITS = ["8", "4", "3", "2"]
|
||||
|
||||
|
||||
def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
|
||||
def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
|
||||
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
|
||||
return gr.Dropdown(value="none", interactive=False)
|
||||
else:
|
||||
@ -47,7 +48,7 @@ def save_model(
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
finetuning_type: str,
|
||||
checkpoint_path: Union[str, List[str]],
|
||||
checkpoint_path: Union[str, list[str]],
|
||||
template: str,
|
||||
export_size: int,
|
||||
export_quantization_bit: str,
|
||||
@ -106,7 +107,7 @@ def save_model(
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import is_multimodal
|
||||
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...data import TEMPLATES
|
||||
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
||||
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_top() -> Dict[str, "Component"]:
|
||||
def create_top() -> dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user