merge data part to the text stream

Former-commit-id: c6dd89918feb25fe8c07857162421ad1706f791f
This commit is contained in:
BUAADreamer 2024-04-25 19:19:59 +08:00
parent 4e032ff95e
commit b6d78b2a64
15 changed files with 828 additions and 293 deletions

View File

@ -418,6 +418,17 @@
"hf_hub_url": "HuggingFaceH4/llava-instruct-mix-vsft" "hf_hub_url": "HuggingFaceH4/llava-instruct-mix-vsft"
}, },
"mllm_instruct_example": { "mllm_instruct_example": {
"hf_hub_url": "data/mllm_example_dataset" "file_name": "llava_instruct_example.json",
"formatting": "llava",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
} }
} }

View File

@ -2,7 +2,7 @@
{ {
"messages": [ "messages": [
{ {
"content": "Who are they?", "content": "Who are they?<image>",
"role": "user" "role": "user"
}, },
{ {
@ -18,12 +18,14 @@
"role": "assistant" "role": "assistant"
} }
], ],
"image": "1.jpg" "images": [
"data/images/1.jpg"
]
}, },
{ {
"messages": [ "messages": [
{ {
"content": "Who is he?", "content": "Who is he?<image>",
"role": "user" "role": "user"
}, },
{ {
@ -39,12 +41,14 @@
"role": "assistant" "role": "assistant"
} }
], ],
"image": "2.jpg" "images": [
"data/images/2.jpg"
]
}, },
{ {
"messages": [ "messages": [
{ {
"content": "Please describe this image", "content": "Please describe this image<image>",
"role": "user" "role": "user"
}, },
{ {
@ -60,6 +64,8 @@
"role": "assistant" "role": "assistant"
} }
], ],
"image": "3.jpg" "images": [
"data/images/3.jpg"
]
} }
] ]

View File

@ -1,32 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft_mm \
--do_train \
--model_name_or_path Salesforce/instructblip-vicuna-7b \
--dataset mllm_instruct_example \
--dataset_dir data \
--template default \
--finetuning_type lora \
--lora_target all \
--output_dir saves/instructblip-vicuna-7b/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 3 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 1 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 1e-5 \
--num_train_epochs 50 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--bf16

View File

@ -29,7 +29,10 @@ def get_processor(model_path):
def apply_lora(base_model_path, model_path, lora_path): def apply_lora(base_model_path, model_path, lora_path):
print(f"Loading the base model from {base_model_path}") print(f"Loading the base model from {base_model_path}")
base_model = AutoModelForVision2Seq.from_pretrained( base_model = AutoModelForVision2Seq.from_pretrained(
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda", base_model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="cuda",
) )
processor = get_processor(base_model_path) processor = get_processor(base_model_path)
tokenizer = processor.tokenizer tokenizer = processor.tokenizer
@ -60,11 +63,14 @@ def main(
if not os.path.exists(model_path) or do_merge: if not os.path.exists(model_path) or do_merge:
apply_lora(base_model_path, model_path, lora_model_path) apply_lora(base_model_path, model_path, lora_model_path)
model = AutoModelForVision2Seq.from_pretrained( model = AutoModelForVision2Seq.from_pretrained(
model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="cuda" model_path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
device_map="cuda",
) )
processor = get_processor(model_path) processor = get_processor(model_path)
raw_datasets = load_dataset(dataset_name) raw_datasets = load_dataset(dataset_name)
train_dataset = raw_datasets['train'] train_dataset = raw_datasets["train"]
examples = train_dataset.select(range(3)) examples = train_dataset.select(range(3))
texts = [] texts = []
images = [] images = []
@ -81,5 +87,5 @@ def main(
print(res) print(res)
if __name__ == '__main__': if __name__ == "__main__":
fire.Fire(main) fire.Fire(main)

View File

@ -1,12 +1,11 @@
from .collator import PairwiseDataCollatorWithPadding from .collator import PairwiseDataCollatorWithPadding
from .loader import get_dataset, get_mm_dataset from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset from .utils import Role, split_dataset
__all__ = [ __all__ = [
"PairwiseDataCollatorWithPadding", "PairwiseDataCollatorWithPadding",
"get_dataset", "get_dataset",
"get_mm_dataset",
"Template", "Template",
"get_template_and_fix_tokenizer", "get_template_and_fix_tokenizer",
"templates", "templates",

View File

@ -13,7 +13,9 @@ if TYPE_CHECKING:
from .parser import DatasetAttr from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]: def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []} outputs = {"prompt": [], "response": [], "system": [], "tools": []}
for i in range(len(examples[dataset_attr.prompt])): for i in range(len(examples[dataset_attr.prompt])):
prompt = [] prompt = []
@ -31,24 +33,38 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list): if dataset_attr.response and isinstance(
examples[dataset_attr.response][i], list
):
response = [ response = [
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i] {"role": Role.ASSISTANT.value, "content": content}
for content in examples[dataset_attr.response][i]
]
elif dataset_attr.response and isinstance(
examples[dataset_attr.response][i], str
):
response = [
{
"role": Role.ASSISTANT.value,
"content": examples[dataset_attr.response][i],
}
] ]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: else:
response = [] response = []
outputs["prompt"].append(prompt) outputs["prompt"].append(prompt)
outputs["response"].append(response) outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") outputs["system"].append(
examples[dataset_attr.system][i] if dataset_attr.system else ""
)
outputs["tools"].append("") outputs["tools"].append("")
outputs["images"].append([])
return outputs return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]: def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []} outputs = {"prompt": [], "response": [], "system": [], "tools": []}
tag_mapping = { tag_mapping = {
dataset_attr.user_tag: Role.USER.value, dataset_attr.user_tag: Role.USER.value,
@ -61,7 +77,10 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags) accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]): for i, messages in enumerate(examples[dataset_attr.messages]):
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag: if (
dataset_attr.system_tag
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
):
system = messages[0][dataset_attr.content_tag] system = messages[0][dataset_attr.content_tag]
messages = messages[1:] messages = messages[1:]
else: else:
@ -77,19 +96,81 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
raise ValueError("Invalid role tag in {}.".format(messages)) raise ValueError("Invalid role tag in {}.".format(messages))
aligned_messages.append( aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} {
"role": tag_mapping[message[dataset_attr.role_tag]],
"content": message[dataset_attr.content_tag],
}
) )
outputs["prompt"].append(aligned_messages[:-1]) outputs["prompt"].append(aligned_messages[:-1])
outputs["response"].append(aligned_messages[-1:]) outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system) outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") outputs["tools"].append(
examples[dataset_attr.tools][i] if dataset_attr.tools else ""
)
outputs["images"].append([])
return outputs
def convert_llava(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
dataset_attr.observation_tag: Role.OBSERVATION.value,
dataset_attr.function_tag: Role.FUNCTION.value,
dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]):
if (
dataset_attr.system_tag
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
):
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
aligned_messages = []
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
raise ValueError("Invalid role tag in {}.".format(messages))
aligned_messages.append(
{
"role": tag_mapping[message[dataset_attr.role_tag]],
"content": message[dataset_attr.content_tag],
}
)
outputs["prompt"].append(aligned_messages[:-1])
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(
examples[dataset_attr.tools][i] if dataset_attr.tools else ""
)
print(examples[dataset_attr.images][i])
outputs["images"].append(
examples[dataset_attr.images][i] if dataset_attr.images else []
)
return outputs return outputs
def align_dataset( def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments" dataset: Union["Dataset", "IterableDataset"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""
Aligned dataset: Aligned dataset:
@ -100,6 +181,8 @@ def align_dataset(
""" """
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr) convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
elif dataset_attr.formatting == "llava":
convert_func = partial(convert_llava, dataset_attr=dataset_attr)
else: else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr) convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
@ -107,13 +190,20 @@ def align_dataset(
features = Features.from_dict( features = Features.from_dict(
{ {
"prompt": [ "prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} {
"role": {"dtype": "string", "_type": "Value"},
"content": {"dtype": "string", "_type": "Value"},
}
], ],
"response": [ "response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} {
"role": {"dtype": "string", "_type": "Value"},
"content": {"dtype": "string", "_type": "Value"},
}
], ],
"system": {"dtype": "string", "_type": "Value"}, "system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"}, "tools": {"dtype": "string", "_type": "Value"},
"images": {"feature": {"_type": "Image"}, "_type": "Sequence"},
} }
) )
kwargs = {} kwargs = {}

View File

@ -1,6 +1,6 @@
import inspect import inspect
import os import os
from typing import TYPE_CHECKING, Literal, Union from typing import TYPE_CHECKING, Literal, Union, Optional
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
@ -25,9 +25,9 @@ logger = get_logger(__name__)
def load_single_dataset( def load_single_dataset(
dataset_attr: "DatasetAttr", dataset_attr: "DatasetAttr",
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
logger.info("Loading dataset {}...".format(dataset_attr)) logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
@ -78,14 +78,20 @@ def load_single_dataset(
split=data_args.split, split=data_args.split,
cache_dir=cache_dir, cache_dir=cache_dir,
token=model_args.ms_hub_token, token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), use_streaming=(
data_args.streaming and (dataset_attr.load_from != "file")
),
) )
if isinstance(dataset, MsDataset): if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset() dataset = dataset.to_hf_dataset()
except ImportError: except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`") raise ImportError(
"Please install modelscope via `pip install modelscope -U`"
)
else: else:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0 if (
"trust_remote_code" in inspect.signature(load_dataset).parameters
): # for datasets==2.16.0
kwargs = {"trust_remote_code": True} kwargs = {"trust_remote_code": True}
else: else:
kwargs = {} kwargs = {}
@ -102,7 +108,9 @@ def load_single_dataset(
**kwargs, **kwargs,
) )
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True if data_args.streaming and (
dataset_attr.load_from == "file"
): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if data_args.max_samples is not None: # truncate dataset if data_args.max_samples is not None: # truncate dataset
@ -113,11 +121,12 @@ def load_single_dataset(
def get_dataset( def get_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"], stage: Literal["pt", "sft", "rm", "ppo"],
processor: Optional["AutoProcessor"] = None,
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template) template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
@ -126,9 +135,13 @@ def get_dataset(
# Load tokenized dataset # Load tokenized dataset
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path): if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.") logger.warning(
"Loading dataset from disk will ignore other data arguments."
)
dataset = load_from_disk(data_args.tokenized_path) dataset = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) logger.info(
"Loaded tokenized dataset from {}.".format(data_args.tokenized_path)
)
if data_args.streaming: if data_args.streaming:
dataset = dataset.to_iterable_dataset() dataset = dataset.to_iterable_dataset()
return dataset return dataset
@ -139,15 +152,21 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"): with training_args.main_process_first(desc="load dataset"):
all_datasets = [] all_datasets = []
for dataset_attr in get_dataset_list(data_args): for dataset_attr in get_dataset_list(data_args):
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): if (stage == "rm" and dataset_attr.ranking is False) or (
raise ValueError("The dataset is not applicable in the current training stage.") stage != "rm" and dataset_attr.ranking is True
):
raise ValueError(
"The dataset is not applicable in the current training stage."
)
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) all_datasets.append(
load_single_dataset(dataset_attr, model_args, data_args)
)
dataset = merge_dataset(all_datasets, data_args, training_args) dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"): with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func( preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage tokenizer, template, data_args, training_args, stage, processor
) )
column_names = list(next(iter(dataset)).keys()) column_names = list(next(iter(dataset)).keys())
kwargs = {} kwargs = {}
@ -158,13 +177,21 @@ def get_dataset(
desc="Running tokenizer on dataset", desc="Running tokenizer on dataset",
) )
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) dataset = dataset.map(
preprocess_func, batched=True, remove_columns=column_names, **kwargs
)
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if training_args.should_save: if training_args.should_save:
dataset.save_to_disk(data_args.tokenized_path) dataset.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info(
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path)) "Tokenized dataset saved at {}.".format(data_args.tokenized_path)
)
logger.info(
"Please restart the training with `--tokenized_path {}`.".format(
data_args.tokenized_path
)
)
exit(0) exit(0)
@ -172,34 +199,8 @@ def get_dataset(
try: try:
print_function(next(iter(dataset))) print_function(next(iter(dataset)))
except StopIteration: except StopIteration:
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") raise RuntimeError(
"Cannot find valid samples, check `data/README.md` for the data format."
)
return dataset return dataset
def get_mm_dataset(
processor: "AutoProcessor",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
) -> Union["Dataset", "IterableDataset"]:
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
if data_args.streaming:
raise ValueError("Turn off `streaming` when saving dataset to disk.")
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args):
all_datasets.append(load_dataset(dataset_attr.dataset_name)['train'])
dataset = merge_dataset(all_datasets, data_args, training_args)
return dataset

View File

@ -25,7 +25,7 @@ class DatasetAttr:
subset: Optional[str] = None subset: Optional[str] = None
folder: Optional[str] = None folder: Optional[str] = None
ranking: bool = False ranking: bool = False
formatting: Literal["alpaca", "sharegpt"] = "alpaca" formatting: Literal["alpaca", "sharegpt", "llava"] = "alpaca"
""" columns """ """ columns """
system: Optional[str] = None system: Optional[str] = None
""" columns for the alpaca format """ """ columns for the alpaca format """
@ -44,11 +44,15 @@ class DatasetAttr:
observation_tag: Optional[str] = "observation" observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call" function_tag: Optional[str] = "function_call"
system_tag: Optional[str] = "system" system_tag: Optional[str] = "system"
""" columns for the mllm format """
images: Optional[str] = None
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: def set_attr(
self, key: str, obj: Dict[str, Any], default: Optional[Any] = None
) -> None:
setattr(self, key, obj.get(key, default)) setattr(self, key, obj.get(key, default))
@ -67,12 +71,16 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
except Exception as err: except Exception as err:
if len(dataset_names) != 0: if len(dataset_names) != 0:
raise ValueError( raise ValueError(
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)) "Cannot open {} due to {}.".format(
os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)
)
) )
dataset_info = None dataset_info = None
if data_args.interleave_probs is not None: if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")] data_args.interleave_probs = [
float(prob.strip()) for prob in data_args.interleave_probs.split(",")
]
dataset_list: List[DatasetAttr] = [] dataset_list: List[DatasetAttr] = []
for name in dataset_names: for name in dataset_names:
@ -90,31 +98,42 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
if has_hf_url or has_ms_url: if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url): if (use_modelscope() and has_ms_url) or (not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) dataset_attr = DatasetAttr(
"ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]
)
else: else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) dataset_attr = DatasetAttr(
"hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]
)
elif "script_url" in dataset_info[name]: elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) dataset_attr = DatasetAttr(
"script", dataset_name=dataset_info[name]["script_url"]
)
else: else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) dataset_attr = DatasetAttr(
"file", dataset_name=dataset_info[name]["file_name"]
)
dataset_attr.set_attr("file_sha1", dataset_info[name]) dataset_attr.set_attr("file_sha1", dataset_info[name])
dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("images", dataset_info[name], default="")
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system"] column_names = ["system"]
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"]) column_names.extend(["prompt", "query", "response", "history"])
elif dataset_attr.formatting == "llava":
column_names.extend(["messages", "images"])
else: else:
column_names.extend(["messages", "tools"]) column_names.extend(["messages", "tools"])
for column_name in column_names: for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"]) dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: if dataset_attr.formatting != "alpaca" and "tags" in dataset_info[name]:
tag_names = ( tag_names = (
"role_tag", "role_tag",
"content_tag", "content_tag",

View File

@ -1,6 +1,6 @@
from functools import partial from functools import partial
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple, Optional
from ..extras.constants import IGNORE_INDEX from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger from ..extras.logging import get_logger
@ -9,7 +9,7 @@ from .utils import Role
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer, AutoProcessor
from ..hparams import DataArguments from ..hparams import DataArguments
from .template import Template from .template import Template
@ -19,19 +19,27 @@ logger = get_logger(__name__)
def preprocess_pretrain_dataset( def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]] text_examples = [
messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]
]
if not data_args.packing: if not data_args.packing:
if data_args.template == "gemma": if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples] text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len) result = tokenizer(
text_examples, add_special_tokens=False, max_length=data_args.cutoff_len
)
else: else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False) tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} concatenated_examples = {
k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()
}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size total_length = (total_length // block_size) * block_size
@ -54,7 +62,11 @@ def preprocess_supervised_dataset(
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>` # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair. # for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}
for i in range(len(examples["prompt"])): for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
@ -75,7 +87,9 @@ def preprocess_supervised_dataset(
if data_args.train_on_prompt: if data_args.train_on_prompt:
source_mask = source_ids source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos: elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
len(source_ids) - 1
)
else: else:
source_mask = [IGNORE_INDEX] * len(source_ids) source_mask = [IGNORE_INDEX] * len(source_ids)
@ -114,7 +128,9 @@ def preprocess_packed_supervised_dataset(
if data_args.train_on_prompt: if data_args.train_on_prompt:
source_mask = source_ids source_mask = source_ids
elif len(input_ids) != 0 and template.efficient_eos: elif len(input_ids) != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
len(source_ids) - 1
)
else: else:
source_mask = [IGNORE_INDEX] * len(source_ids) source_mask = [IGNORE_INDEX] * len(source_ids)
@ -139,6 +155,64 @@ def preprocess_packed_supervised_dataset(
return model_inputs return model_inputs
def preprocess_multimodal_supervised_dataset(
examples: Dict[str, List[Any]],
processor: "AutoProcessor",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
tokenizer = processor.tokenizer
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"pixel_values": [],
}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
len(source_ids) - 1
)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
pixel_values = processor.image_processor(
examples["images"][0], return_tensors="pt"
)["pixel_values"][0]
model_inputs["pixel_values"].append(pixel_values)
return model_inputs
def preprocess_unsupervised_dataset( def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]], examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
@ -155,7 +229,9 @@ def preprocess_unsupervised_dataset(
if len(examples["response"][i]) == 1: if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i] messages = examples["prompt"][i] + examples["response"][i]
else: else:
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}] messages = examples["prompt"][i] + [
{"role": Role.ASSISTANT.value, "content": ""}
]
input_ids, labels = template.encode_oneturn( input_ids, labels = template.encode_oneturn(
tokenizer, tokenizer,
@ -218,29 +294,58 @@ def preprocess_pairwise_dataset(
return model_inputs return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: def print_supervised_dataset_example(
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
) -> None:
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print(
"inputs:\n{}".format(
tokenizer.decode(example["input_ids"], skip_special_tokens=False)
)
)
print("label_ids:\n{}".format(example["labels"])) print("label_ids:\n{}".format(example["labels"]))
print( print(
"labels:\n{}".format( "labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) tokenizer.decode(
list(filter(lambda x: x != IGNORE_INDEX, example["labels"])),
skip_special_tokens=False,
)
) )
) )
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: def print_pairwise_dataset_example(
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
) -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"])) print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) print(
"prompt:\n{}".format(
tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)
)
)
print("chosen_ids:\n{}".format(example["chosen_ids"])) print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) print(
"chosen:\n{}".format(
tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)
)
)
print("rejected_ids:\n{}".format(example["rejected_ids"])) print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) print(
"rejected:\n{}".format(
tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)
)
)
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: def print_unsupervised_dataset_example(
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
) -> None:
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print(
"inputs:\n{}".format(
tokenizer.decode(example["input_ids"], skip_special_tokens=False)
)
)
def get_preprocess_and_print_func( def get_preprocess_and_print_func(
@ -249,30 +354,56 @@ def get_preprocess_and_print_func(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"], stage: Literal["pt", "sft", "rm", "ppo"],
processor: Optional["AutoProcessor"] = None,
) -> Tuple[Callable, Callable]: ) -> Tuple[Callable, Callable]:
if stage == "pt": if stage == "pt":
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args) preprocess_func = partial(
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args
)
print_function = partial(
print_unsupervised_dataset_example, tokenizer=tokenizer
)
elif stage == "sft" and not training_args.predict_with_generate: elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing: if data_args.packing:
preprocess_func = partial( preprocess_func = partial(
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args preprocess_packed_supervised_dataset,
tokenizer=tokenizer,
template=template,
data_args=data_args,
)
elif processor is not None:
preprocess_func = partial(
preprocess_multimodal_supervised_dataset,
processor=processor,
template=template,
data_args=data_args,
) )
else: else:
preprocess_func = partial( preprocess_func = partial(
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args preprocess_supervised_dataset,
tokenizer=tokenizer,
template=template,
data_args=data_args,
) )
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm": elif stage == "rm":
preprocess_func = partial( preprocess_func = partial(
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args preprocess_pairwise_dataset,
tokenizer=tokenizer,
template=template,
data_args=data_args,
) )
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else: else:
preprocess_func = partial( preprocess_func = partial(
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args preprocess_unsupervised_dataset,
tokenizer=tokenizer,
template=template,
data_args=data_args,
)
print_function = partial(
print_unsupervised_dataset_example, tokenizer=tokenizer
) )
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function return preprocess_func, print_function

View File

@ -42,7 +42,9 @@ class Template:
r""" r"""
Returns a single pair of token ids representing prompt and response respectively. Returns a single pair of token ids representing prompt and response respectively.
""" """
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) encoded_pairs = self._encode(
tokenizer, messages, system, tools, cutoff_len, reserved_label_len
)
prompt_ids = [] prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]: for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids += query_ids + resp_ids prompt_ids += query_ids + resp_ids
@ -62,7 +64,9 @@ class Template:
r""" r"""
Returns multiple pairs of token ids representing prompts and responses respectively. Returns multiple pairs of token ids representing prompts and responses respectively.
""" """
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) return self._encode(
tokenizer, messages, system, tools, cutoff_len, reserved_label_len
)
def _encode( def _encode(
self, self,
@ -89,7 +93,9 @@ class Template:
elements += self.format_separator.apply() elements += self.format_separator.apply()
if message["role"] == Role.USER.value: if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) elements += self.format_user.apply(
content=message["content"], idx=str(i // 2)
)
elif message["role"] == Role.ASSISTANT.value: elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"]) elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value: elif message["role"] == Role.OBSERVATION.value:
@ -104,7 +110,9 @@ class Template:
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
def _convert_elements_to_ids( def _convert_elements_to_ids(
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]] self,
tokenizer: "PreTrainedTokenizer",
elements: List[Union[str, Dict[str, str]]],
) -> List[int]: ) -> List[int]:
r""" r"""
Converts elements to token ids. Converts elements to token ids.
@ -122,7 +130,11 @@ class Template:
elif "eos_token" in elem and tokenizer.eos_token_id is not None: elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id] token_ids += [tokenizer.eos_token_id]
else: else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) raise ValueError(
"Input must be string, set[str] or dict[str, str], got {}".format(
type(elem)
)
)
return token_ids return token_ids
@ -180,7 +192,9 @@ class Llama2Template(Template):
elements += self.format_separator.apply() elements += self.format_separator.apply()
if message["role"] == Role.USER.value: if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=system_text + message["content"]) elements += self.format_user.apply(
content=system_text + message["content"]
)
elif message["role"] == Role.ASSISTANT.value: elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"]) elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value: elif message["role"] == Role.OBSERVATION.value:
@ -243,7 +257,9 @@ def _register_template(
template_class = Llama2Template if name.startswith("llama2") else Template template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"]) default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) default_function_formatter = FunctionFormatter(
slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots
)
default_tool_formatter = ToolFormatter(tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter() default_separator_formatter = EmptyFormatter()
templates[name] = template_class( templates[name] = template_class(
@ -279,7 +295,9 @@ def _jinja_escape(content: str) -> str:
return content.replace("\n", r"\n").replace("'", r"\'") return content.replace("\n", r"\n").replace("'", r"\'")
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: def _convert_slots_to_jinja(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
slot_items = [] slot_items = []
for slot in slots: for slot in slots:
if isinstance(slot, str): if isinstance(slot, str):
@ -293,7 +311,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
elif isinstance(slot, set): elif isinstance(slot, set):
if "bos_token" in slot: if "bos_token" in slot:
slot_items.append("'" + tokenizer.bos_token + "'") slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced elif (
"eos_token" in slot
): # do not use {{ eos_token }} since it may be replaced
slot_items.append("'" + tokenizer.eos_token + "'") slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict): elif isinstance(slot, dict):
raise ValueError("Dict is not supported.") raise ValueError("Dict is not supported.")
@ -305,25 +325,37 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template = "" jinja_template = ""
if template.default_system: if template.default_system:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" jinja_template += (
"{% set system_message = '"
+ _jinja_escape(template.default_system)
+ "' %}"
)
jinja_template += ( jinja_template += (
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}" "{% if messages[0]['role'] == 'system' %}"
"{% set system_message = messages[0]['content'] %}"
"{% endif %}"
) )
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") system_message = _convert_slots_to_jinja(
template.format_system.apply(), tokenizer, placeholder="system_message"
)
if isinstance(template, Llama2Template): if isinstance(template, Llama2Template):
pass pass
elif template.force_system: elif template.force_system:
jinja_template += "{{ " + system_message + " }}" jinja_template += "{{ " + system_message + " }}"
else: else:
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" jinja_template += (
"{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
)
jinja_template += "{% for message in messages %}" jinja_template += "{% for message in messages %}"
jinja_template += "{% set content = message['content'] %}" jinja_template += "{% set content = message['content'] %}"
if isinstance(template, Llama2Template): if isinstance(template, Llama2Template):
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}" jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}" jinja_template += (
"{% set content = " + system_message + " + message['content'] %}"
)
jinja_template += "{% endif %}" jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}" jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer) user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
@ -366,11 +398,14 @@ def get_template_and_fix_tokenizer(
if stop_words: if stop_words:
num_added_tokens = tokenizer.add_special_tokens( num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False dict(additional_special_tokens=stop_words),
replace_additional_special_tokens=False,
) )
logger.info("Add {} to stop words.".format(",".join(stop_words))) logger.info("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0: if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.") logger.warning(
"New tokens have been added, make sure `resize_vocab` is True."
)
try: try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer) tokenizer.chat_template = _get_jinja_template(template, tokenizer)
@ -382,7 +417,9 @@ def get_template_and_fix_tokenizer(
_register_template( _register_template(
name="alpaca", name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), format_user=StringFormatter(
slots=["### Instruction:\n{{content}}\n\n### Response:\n"]
),
format_separator=EmptyFormatter(slots=["\n\n"]), format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=( default_system=(
"Below is an instruction that describes a task. " "Below is an instruction that describes a task. "
@ -407,7 +444,13 @@ _register_template(
_register_template( _register_template(
name="atom", name="atom",
format_user=StringFormatter( format_user=StringFormatter(
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"] slots=[
{"bos_token"},
"Human: {{content}}\n",
{"eos_token"},
{"bos_token"},
"Assistant:",
]
), ),
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]), format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
) )
@ -415,7 +458,9 @@ _register_template(
_register_template( _register_template(
name="baichuan", name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]), format_user=StringFormatter(
slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]
),
efficient_eos=True, efficient_eos=True,
) )
@ -438,7 +483,9 @@ _register_template(
_register_template( _register_template(
name="bluelm", name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), format_user=StringFormatter(
slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]
),
) )
@ -457,7 +504,9 @@ _register_template(
_register_template( _register_template(
name="chatglm2", name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
),
format_separator=EmptyFormatter(slots=["\n\n"]), format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True, efficient_eos=True,
force_system=True, force_system=True,
@ -466,12 +515,21 @@ _register_template(
_register_template( _register_template(
name="chatglm3", name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(
slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] slots=[
{"token": "<|observation|>"},
"\n",
"{{content}}",
{"token": "<|assistant|>"},
]
), ),
stop_words=["<|user|>", "<|observation|>"], stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True, efficient_eos=True,
@ -481,14 +539,27 @@ _register_template(
_register_template( _register_template(
name="chatglm3_system", name="chatglm3_system",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(
slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter( format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"] slots=[
{"token": "[gMASK]"},
{"token": "sop"},
{"token": "<|system|>"},
"\n",
"{{content}}",
]
), ),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] slots=[
{"token": "<|observation|>"},
"\n",
"{{content}}",
{"token": "<|assistant|>"},
]
), ),
default_system=( default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. " "You are ChatGLM3, a large language model trained by Zhipu.AI. "
@ -501,9 +572,15 @@ _register_template(
_register_template( _register_template(
name="chatml", name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), ),
format_system=StringFormatter(
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_observation=StringFormatter(
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"], stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True, replace_eos=True,
@ -512,9 +589,15 @@ _register_template(
_register_template( _register_template(
name="chatml_de", name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), ),
format_system=StringFormatter(
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_observation=StringFormatter(
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"], stop_words=["<|im_end|>", "<|im_start|>"],
@ -524,7 +607,9 @@ _register_template(
_register_template( _register_template(
name="codegeex2", name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
),
force_system=True, force_system=True,
) )
@ -554,9 +639,15 @@ _register_template(
_register_template( _register_template(
name="dbrx", name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), ),
format_system=StringFormatter(
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_observation=StringFormatter(
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
default_system=( default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. " "You are DBRX, created by Databricks. You were last updated in December 2023. "
@ -634,7 +725,9 @@ _register_template(
_register_template( _register_template(
name="gemma", name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]), format_user=StringFormatter(
slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"] slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
@ -647,7 +740,9 @@ _register_template(
_register_template( _register_template(
name="intern", name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]), format_user=StringFormatter(
slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]
),
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]), format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
stop_words=["<eoa>"], stop_words=["<eoa>"],
efficient_eos=True, efficient_eos=True,
@ -656,8 +751,12 @@ _register_template(
_register_template( _register_template(
name="intern2", name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]), slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
default_system=( default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n" "You are an AI assistant whose name is InternLM (书生·浦语).\n"
@ -707,7 +806,10 @@ _register_template(
] ]
), ),
format_system=StringFormatter( format_system=StringFormatter(
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"] slots=[
{"bos_token"},
"<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>",
]
), ),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[ slots=[
@ -742,7 +844,13 @@ _register_template(
_register_template( _register_template(
name="openchat", name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), format_user=StringFormatter(
slots=[
"GPT4 Correct User: {{content}}",
{"eos_token"},
"GPT4 Correct Assistant:",
]
),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True, force_system=True,
@ -751,7 +859,9 @@ _register_template(
_register_template( _register_template(
name="orion", name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), format_user=StringFormatter(
slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]
),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True, force_system=True,
) )
@ -759,9 +869,15 @@ _register_template(
_register_template( _register_template(
name="phi", name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_user=StringFormatter(
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]), slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]), ),
format_system=StringFormatter(
slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]
),
format_observation=StringFormatter(
slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]
),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.", default_system="You are a helpful AI assistant.",
stop_words=["<|end|>"], stop_words=["<|end|>"],
@ -771,9 +887,15 @@ _register_template(
_register_template( _register_template(
name="qwen", name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), ),
format_system=StringFormatter(
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
),
format_observation=StringFormatter(
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
@ -829,8 +951,12 @@ _register_template(
_register_template( _register_template(
name="yayi", name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), format_user=StringFormatter(
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]), slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]
),
format_system=StringFormatter(
slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]
),
format_separator=EmptyFormatter(slots=["\n\n"]), format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=( default_system=(
"You are a helpful, respectful and honest assistant named YaYi " "You are a helpful, respectful and honest assistant named YaYi "
@ -849,7 +975,9 @@ _register_template(
_register_template( _register_template(
name="yi", name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
@ -867,7 +995,9 @@ _register_template(
_register_template( _register_template(
name="zephyr", name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), format_user=StringFormatter(
slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]
),
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]), format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are a friendly chatbot who always responds in the style of a pirate", default_system="You are a friendly chatbot who always responds in the style of a pirate",
@ -879,3 +1009,13 @@ _register_template(
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]), format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
) )
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} "]),
format_assistant=StringFormatter(slots=["ASSISTANT: {{content}}"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
)

View File

@ -15,23 +15,33 @@ class ModelArguments:
) )
adapter_name_or_path: Optional[str] = field( adapter_name_or_path: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}, metadata={
"help": "Path to the adapter weight or identifier from huggingface.co/models."
},
) )
cache_dir: Optional[str] = field( cache_dir: Optional[str] = field(
default=None, default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, metadata={
"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."
},
) )
use_fast_tokenizer: bool = field( use_fast_tokenizer: bool = field(
default=True, default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}, metadata={
"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."
},
) )
resize_vocab: bool = field( resize_vocab: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}, metadata={
"help": "Whether or not to resize the tokenizer vocab and the embedding layers."
},
) )
split_special_tokens: bool = field( split_special_tokens: bool = field(
default=False, default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}, metadata={
"help": "Whether or not the special tokens should be split during the tokenization process."
},
) )
new_special_tokens: Optional[str] = field( new_special_tokens: Optional[str] = field(
default=None, default=None,
@ -39,7 +49,9 @@ class ModelArguments:
) )
model_revision: str = field( model_revision: str = field(
default="main", default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, metadata={
"help": "The specific model version to use (can be a branch name, tag name or commit id)."
},
) )
low_cpu_mem_usage: bool = field( low_cpu_mem_usage: bool = field(
default=True, default=True,
@ -47,7 +59,9 @@ class ModelArguments:
) )
quantization_bit: Optional[int] = field( quantization_bit: Optional[int] = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the model using bitsandbytes."}, metadata={
"help": "The number of bits to quantize the model using bitsandbytes."
},
) )
quantization_type: Literal["fp4", "nf4"] = field( quantization_type: Literal["fp4", "nf4"] = field(
default="nf4", default="nf4",
@ -55,15 +69,21 @@ class ModelArguments:
) )
double_quantization: bool = field( double_quantization: bool = field(
default=True, default=True,
metadata={"help": "Whether or not to use double quantization in int4 training."}, metadata={
"help": "Whether or not to use double quantization in int4 training."
},
) )
quantization_device_map: Optional[Literal["auto"]] = field( quantization_device_map: Optional[Literal["auto"]] = field(
default=None, default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."}, metadata={
"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."
},
) )
rope_scaling: Optional[Literal["linear", "dynamic"]] = field( rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None, default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}, metadata={
"help": "Which scaling strategy should be adopted for the RoPE embeddings."
},
) )
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field( flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
default="auto", default="auto",
@ -71,19 +91,27 @@ class ModelArguments:
) )
shift_attn: bool = field( shift_attn: bool = field(
default=False, default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}, metadata={
"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."
},
) )
mixture_of_depths: Optional[Literal["convert", "load"]] = field( mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None, default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."}, metadata={
"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."
},
) )
use_unsloth: bool = field( use_unsloth: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}, metadata={
"help": "Whether or not to use unsloth's optimization for the LoRA training."
},
) )
moe_aux_loss_coef: Optional[float] = field( moe_aux_loss_coef: Optional[float] = field(
default=None, default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."}, metadata={
"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."
},
) )
disable_gradient_checkpointing: bool = field( disable_gradient_checkpointing: bool = field(
default=False, default=False,
@ -107,7 +135,9 @@ class ModelArguments:
) )
vllm_gpu_util: float = field( vllm_gpu_util: float = field(
default=0.9, default=0.9,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."}, metadata={
"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."
},
) )
vllm_enforce_eager: bool = field( vllm_enforce_eager: bool = field(
default=False, default=False,
@ -147,7 +177,9 @@ class ModelArguments:
) )
export_quantization_dataset: Optional[str] = field( export_quantization_dataset: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}, metadata={
"help": "Path to the dataset or dataset name to use in quantizing the exported model."
},
) )
export_quantization_nsamples: int = field( export_quantization_nsamples: int = field(
default=128, default=128,
@ -155,19 +187,27 @@ class ModelArguments:
) )
export_quantization_maxlen: int = field( export_quantization_maxlen: int = field(
default=1024, default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."}, metadata={
"help": "The maximum length of the model inputs used for quantization."
},
) )
export_legacy_format: bool = field( export_legacy_format: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}, metadata={
"help": "Whether or not to save the `.bin` files instead of `.safetensors`."
},
) )
export_hub_model_id: Optional[str] = field( export_hub_model_id: Optional[str] = field(
default=None, default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}, metadata={
"help": "The name of the repository if push the model to the Hugging Face hub."
},
) )
print_param_status: bool = field( print_param_status: bool = field(
default=False, default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, metadata={
"help": "For debugging purposes, print the status of the parameters in the model."
},
) )
use_mllm: bool = field( use_mllm: bool = field(
default=False, default=False,
@ -180,18 +220,39 @@ class ModelArguments:
self.model_max_length = None self.model_max_length = None
if self.split_special_tokens and self.use_fast_tokenizer: if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") raise ValueError(
"`split_special_tokens` is only supported for slow tokenizers."
)
if self.adapter_name_or_path is not None: # support merging multiple lora weights if (
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] self.adapter_name_or_path is not None
): # support merging multiple lora weights
self.adapter_name_or_path = [
path.strip() for path in self.adapter_name_or_path.split(",")
]
if self.new_special_tokens is not None: # support multiple special tokens if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] self.new_special_tokens = [
token.strip() for token in self.new_special_tokens.split(",")
]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.quantization_bit in [
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization." None,
8,
4,
], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [
None,
8,
4,
3,
2,
], "We only accept 2/3/4/8-bit quantization."
if self.export_quantization_bit is not None and self.export_quantization_dataset is None: if (
self.export_quantization_bit is not None
and self.export_quantization_dataset is None
):
raise ValueError("Quantization dataset is necessary for exporting.") raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:

View File

@ -11,7 +11,7 @@ from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForVision2Seq from transformers import PretrainedConfig, PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
@ -21,11 +21,11 @@ logger = get_logger(__name__)
def init_adapter( def init_adapter(
config: "PretrainedConfig", config: "PretrainedConfig",
model: Union["PreTrainedModel","AutoModelForVision2Seq"], model: Union["PreTrainedModel"],
model_args: "ModelArguments", model_args: "ModelArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: bool, is_trainable: bool,
) -> Union["PreTrainedModel","AutoModelForVision2Seq"]: ) -> Union["PreTrainedModel"]:
r""" r"""
Initializes the adapters. Initializes the adapters.
@ -38,7 +38,9 @@ def init_adapter(
logger.info("Adapter is not found at evaluation, load the base model.") logger.info("Adapter is not found at evaluation, load the base model.")
return model return model
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None): if finetuning_args.finetuning_type != "lora" and getattr(
model, "quantization_method", None
):
raise ValueError("You can only use lora for quantized models.") raise ValueError("You can only use lora for quantized models.")
if finetuning_args.finetuning_type == "full" and is_trainable: if finetuning_args.finetuning_type == "full" and is_trainable:
@ -49,9 +51,9 @@ def init_adapter(
if finetuning_args.finetuning_type == "freeze" and is_trainable: if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze") logger.info("Fine-tuning method: Freeze")
num_layers = ( num_layers = (
getattr(model.config, "num_hidden_layers", None) getattr(model.config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None) or getattr(model.config, "num_layers", None)
or getattr(model.config, "n_layer", None) or getattr(model.config, "n_layer", None)
) )
if not num_layers: if not num_layers:
raise ValueError("Current model does not support freeze tuning.") raise ValueError("Current model does not support freeze tuning.")
@ -66,8 +68,12 @@ def init_adapter(
stride = num_layers // finetuning_args.num_layer_trainable stride = num_layers // finetuning_args.num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
elif finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 elif (
trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers) finetuning_args.num_layer_trainable > 0
): # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = range(
num_layers - finetuning_args.num_layer_trainable, num_layers
)
else: # fine-tuning the first n layers if num_layer_trainable < 0 else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = range(-finetuning_args.num_layer_trainable) trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
@ -82,11 +88,15 @@ def init_adapter(
for module_name in finetuning_args.name_module_trainable: for module_name in finetuning_args.name_module_trainable:
if module_name not in freeze_modules: if module_name not in freeze_modules:
raise ValueError( raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules)) "Module {} is not found, please choose from {}".format(
module_name, ", ".join(freeze_modules)
)
) )
for idx in trainable_layer_ids: for idx in trainable_layer_ids:
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else "")) trainable_layers.append(
".{:d}.{}".format(idx, module_name if module_name != "all" else "")
)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers): if any(trainable_layer in name for trainable_layer in trainable_layers):
@ -95,27 +105,43 @@ def init_adapter(
else: else:
param.requires_grad_(False) param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids)))) logger.info(
"Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids)))
)
if finetuning_args.finetuning_type == "lora": if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) logger.info(
"Fine-tuning method: {}".format(
"DoRA" if finetuning_args.use_dora else "LoRA"
)
)
adapter_to_resume = None adapter_to_resume = None
if model_args.adapter_name_or_path is not None: if model_args.adapter_name_or_path is not None:
is_mergeable = True is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable if getattr(
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." model, "quantization_method", None
): # merge lora in quantized model is unstable
assert (
len(model_args.adapter_name_or_path) == 1
), "Quantized model only accepts a single adapter."
is_mergeable = False is_mergeable = False
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." assert (
len(model_args.adapter_name_or_path) == 1
), "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False is_mergeable = False
if model_args.use_unsloth: if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter." assert (
len(model_args.adapter_name_or_path) == 1
), "Unsloth model only accepts a single adapter."
is_mergeable = False is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): if (is_trainable and not finetuning_args.create_new_adapter) or (
not is_mergeable
):
adapter_to_merge = model_args.adapter_name_or_path[:-1] adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1] adapter_to_resume = model_args.adapter_name_or_path[-1]
else: else:
@ -132,7 +158,9 @@ def init_adapter(
if adapter_to_resume is not None: # resume lora training if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth: if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable) model = load_unsloth_peft_model(
config, model_args, is_trainable=is_trainable
)
else: else:
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
@ -141,19 +169,27 @@ def init_adapter(
offload_folder=model_args.offload_folder, offload_folder=model_args.offload_folder,
) )
if is_trainable and adapter_to_resume is None: # create new lora weights while training if (
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": is_trainable and adapter_to_resume is None
): # create new lora weights while training
if (
len(finetuning_args.lora_target) == 1
and finetuning_args.lora_target[0] == "all"
):
target_modules = find_all_linear_modules(model) target_modules = find_all_linear_modules(model)
else: else:
target_modules = finetuning_args.lora_target target_modules = finetuning_args.lora_target
if finetuning_args.use_llama_pro: if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable) target_modules = find_expanded_modules(
model, target_modules, finetuning_args.num_layer_trainable
)
if ( if (
finetuning_args.use_dora finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None and getattr(model, "quantization_method", None) is not None
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES and getattr(model, "quantization_method", None)
!= QuantizationMethod.BITS_AND_BYTES
): ):
raise ValueError("DoRA is not compatible with PTQ-quantized models.") raise ValueError("DoRA is not compatible with PTQ-quantized models.")
@ -166,7 +202,11 @@ def init_adapter(
module_names.add(name.split(".")[-1]) module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) logger.warning(
"Vocab has been resized, add {} to trainable params.".format(
",".join(module_names)
)
)
peft_kwargs = { peft_kwargs = {
"r": finetuning_args.lora_rank, "r": finetuning_args.lora_rank,
@ -193,6 +233,10 @@ def init_adapter(
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
if model_args.adapter_name_or_path is not None: if model_args.adapter_name_or_path is not None:
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) logger.info(
"Loaded adapter(s): {}".format(
",".join(model_args.adapter_name_or_path)
)
)
return model return model

View File

@ -1,6 +1,12 @@
from typing import TYPE_CHECKING, Any, Dict, Union from typing import TYPE_CHECKING, Any, Dict, Union
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
AutoProcessor,
AutoModelForVision2Seq,
)
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger from ..extras.logging import get_logger
@ -62,10 +68,14 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
dict(additional_special_tokens=model_args.new_special_tokens), dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False, replace_additional_special_tokens=False,
) )
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens))) logger.info(
"Add {} to special tokens.".format(",".join(model_args.new_special_tokens))
)
if num_added_tokens > 0 and not model_args.resize_vocab: if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.") logger.warning(
"New tokens have been added, changed `resize_vocab` to True."
)
patch_tokenizer(tokenizer) patch_tokenizer(tokenizer)
return tokenizer return tokenizer
@ -111,7 +121,7 @@ def load_model(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: bool = False, is_trainable: bool = False,
add_valuehead: bool = False, add_valuehead: bool = False,
) -> Union["PreTrainedModel", "AutoModelForVision2Seq"]: ) -> Union["PreTrainedModel"]:
r""" r"""
Loads pretrained model. Loads pretrained model.
""" """
@ -170,8 +180,10 @@ def load_model(
trainable_params, all_param = count_parameters(model) trainable_params, all_param = count_parameters(model)
if is_trainable: if is_trainable:
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( param_stats = (
trainable_params, all_param, 100 * trainable_params / all_param "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
) )
else: else:
param_stats = "all params: {:d}".format(all_param) param_stats = "all params: {:d}".format(all_param)
@ -185,4 +197,4 @@ def load_model(
) )
) )
return model return model

View File

@ -19,7 +19,9 @@ class DataCollatorForVis2Seq:
texts.append(text) texts.append(text)
images.append(example["images"][0]) images.append(example["images"][0])
batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True) batch = self.processor(
text=texts, images=images, return_tensors="pt", padding=True
)
labels = batch["input_ids"].clone() labels = batch["input_ids"].clone()
if self.processor.tokenizer.pad_token_id is not None: if self.processor.tokenizer.pad_token_id is not None:
@ -27,3 +29,14 @@ class DataCollatorForVis2Seq:
batch["labels"] = labels batch["labels"] = labels
return batch return batch
@dataclass
class DataCollatorForMLLM:
processor: AutoProcessor
def __call__(self, examples):
print(examples[0].keys())
print(examples[0]["input_ids"])
batch = {}
return batch

View File

@ -1,47 +1,66 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
import os import os
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
from ...data import split_dataset, get_mm_dataset from ...data import get_dataset
from ...extras.misc import get_logits_processor from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss from ...extras.ploting import plot_loss
from ...model import load_tokenizer, load_processor, load_model from ...model import load_processor, load_model
from ..utils import create_modelcard_and_push from ..utils import create_modelcard_and_push
from .metric import ComputeMetrics from .metric import ComputeMetrics
from .trainer import CustomSeq2SeqTrainer from .trainer import CustomSeq2SeqTrainer
from .collator import DataCollatorForVis2Seq from transformers import DataCollatorForSeq2Seq
from ...extras.constants import IGNORE_INDEX
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ...hparams import (
DataArguments,
FinetuningArguments,
GeneratingArguments,
ModelArguments,
)
def run_sft_mm( def run_sft_mm(
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[List["TrainerCallback"]] = None,
): ):
processor = load_processor(model_args) processor = load_processor(model_args)
tokenizer = load_tokenizer(model_args) tokenizer = processor.tokenizer
CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}""" dataset = get_dataset(
tokenizer.chat_template = CHAT_TEMPLATE tokenizer, model_args, data_args, training_args, "sft", processor
processor.tokenizer = tokenizer )
model = load_model(processor.tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
dataset = get_mm_dataset(processor, model_args, data_args, training_args, stage="sft")
if getattr(model, "is_quantized", False) and not training_args.do_train: if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction setattr(
model, "_hf_peft_config_loaded", True
) # hack here: make model compatible with prediction
train_dataset = dataset train_dataset = dataset
eval_dataset = dataset eval_dataset = dataset
data_collator = DataCollatorForVis2Seq( data_collator = DataCollatorForSeq2Seq(
processor=processor, tokenizer=tokenizer,
pad_to_multiple_of=(
8 if tokenizer.padding_side == "right" else None
), # for shift short attention
label_pad_token_id=(
IGNORE_INDEX
if data_args.ignore_pad_token_for_loss
else tokenizer.pad_token_id
),
) )
# Override the decoding parameters of Seq2SeqTrainer # Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len training_args.generation_max_length = (
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams training_args.generation_max_length or data_args.cutoff_len
)
training_args.generation_num_beams = (
data_args.eval_num_beams or training_args.generation_num_beams
)
training_args.remove_unused_columns = False training_args.remove_unused_columns = False
# Initialize our Trainer # Initialize our Trainer
@ -52,19 +71,26 @@ def run_sft_mm(
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, compute_metrics=(
ComputeMetrics(tokenizer) if training_args.predict_with_generate else None
),
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,
) )
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict() gen_kwargs = generating_args.to_dict()
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids gen_kwargs["eos_token_id"] = [
tokenizer.eos_token_id
] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor() gen_kwargs["logits_processor"] = get_logits_processor()
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) train_result = trainer.train(
resume_from_checkpoint=training_args.resume_from_checkpoint
)
trainer.save_model() trainer.save_model()
trainer.log_metrics("train", train_result.metrics) trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
@ -75,19 +101,27 @@ def run_sft_mm(
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled if (
training_args.predict_with_generate
): # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None) metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
# Predict # Predict
if training_args.do_predict: if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) predict_results = trainer.predict(
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled dataset, metric_key_prefix="predict", **gen_kwargs
)
if (
training_args.predict_with_generate
): # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None) predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics) trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results) trainer.save_predictions(predict_results)
# Create model card # Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) create_modelcard_and_push(
trainer, model_args, data_args, training_args, finetuning_args
)