mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
modify style
Former-commit-id: 1dcabafe72fe21c7f9122a6bc1a1ccc4f5d08fdd
This commit is contained in:
parent
f42c0b26d1
commit
e6cf251fb8
@ -3,6 +3,7 @@ from .loader import get_dataset
|
|||||||
from .template import Template, get_template_and_fix_tokenizer, templates
|
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||||
from .utils import Role, split_dataset
|
from .utils import Role, split_dataset
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PairwiseDataCollatorWithPadding",
|
"PairwiseDataCollatorWithPadding",
|
||||||
"get_dataset",
|
"get_dataset",
|
||||||
|
@ -13,9 +13,7 @@ if TYPE_CHECKING:
|
|||||||
from .parser import DatasetAttr
|
from .parser import DatasetAttr
|
||||||
|
|
||||||
|
|
||||||
def convert_alpaca(
|
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
|
||||||
) -> Dict[str, List[Any]]:
|
|
||||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||||
for i in range(len(examples[dataset_attr.prompt])):
|
for i in range(len(examples[dataset_attr.prompt])):
|
||||||
prompt = []
|
prompt = []
|
||||||
@ -33,16 +31,11 @@ def convert_alpaca(
|
|||||||
|
|
||||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||||
|
|
||||||
if dataset_attr.response and isinstance(
|
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||||
examples[dataset_attr.response][i], list
|
|
||||||
):
|
|
||||||
response = [
|
response = [
|
||||||
{"role": Role.ASSISTANT.value, "content": content}
|
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||||
for content in examples[dataset_attr.response][i]
|
|
||||||
]
|
]
|
||||||
elif dataset_attr.response and isinstance(
|
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||||
examples[dataset_attr.response][i], str
|
|
||||||
):
|
|
||||||
response = [
|
response = [
|
||||||
{
|
{
|
||||||
"role": Role.ASSISTANT.value,
|
"role": Role.ASSISTANT.value,
|
||||||
@ -54,17 +47,13 @@ def convert_alpaca(
|
|||||||
|
|
||||||
outputs["prompt"].append(prompt)
|
outputs["prompt"].append(prompt)
|
||||||
outputs["response"].append(response)
|
outputs["response"].append(response)
|
||||||
outputs["system"].append(
|
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||||
examples[dataset_attr.system][i] if dataset_attr.system else ""
|
|
||||||
)
|
|
||||||
outputs["tools"].append("")
|
outputs["tools"].append("")
|
||||||
outputs["images"].append([])
|
outputs["images"].append([])
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def convert_sharegpt(
|
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
|
||||||
) -> Dict[str, List[Any]]:
|
|
||||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||||
tag_mapping = {
|
tag_mapping = {
|
||||||
dataset_attr.user_tag: Role.USER.value,
|
dataset_attr.user_tag: Role.USER.value,
|
||||||
@ -77,10 +66,7 @@ def convert_sharegpt(
|
|||||||
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||||
accept_tags = (odd_tags, even_tags)
|
accept_tags = (odd_tags, even_tags)
|
||||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||||
if (
|
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
|
||||||
dataset_attr.system_tag
|
|
||||||
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
|
|
||||||
):
|
|
||||||
system = messages[0][dataset_attr.content_tag]
|
system = messages[0][dataset_attr.content_tag]
|
||||||
messages = messages[1:]
|
messages = messages[1:]
|
||||||
else:
|
else:
|
||||||
@ -105,17 +91,13 @@ def convert_sharegpt(
|
|||||||
outputs["prompt"].append(aligned_messages[:-1])
|
outputs["prompt"].append(aligned_messages[:-1])
|
||||||
outputs["response"].append(aligned_messages[-1:])
|
outputs["response"].append(aligned_messages[-1:])
|
||||||
outputs["system"].append(system)
|
outputs["system"].append(system)
|
||||||
outputs["tools"].append(
|
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||||
examples[dataset_attr.tools][i] if dataset_attr.tools else ""
|
|
||||||
)
|
|
||||||
outputs["images"].append([])
|
outputs["images"].append([])
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def convert_llava(
|
def convert_llava(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
|
||||||
) -> Dict[str, List[Any]]:
|
|
||||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||||
tag_mapping = {
|
tag_mapping = {
|
||||||
dataset_attr.user_tag: Role.USER.value,
|
dataset_attr.user_tag: Role.USER.value,
|
||||||
@ -128,10 +110,7 @@ def convert_llava(
|
|||||||
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||||
accept_tags = (odd_tags, even_tags)
|
accept_tags = (odd_tags, even_tags)
|
||||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||||
if (
|
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
|
||||||
dataset_attr.system_tag
|
|
||||||
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
|
|
||||||
):
|
|
||||||
system = messages[0][dataset_attr.content_tag]
|
system = messages[0][dataset_attr.content_tag]
|
||||||
messages = messages[1:]
|
messages = messages[1:]
|
||||||
else:
|
else:
|
||||||
@ -156,13 +135,9 @@ def convert_llava(
|
|||||||
outputs["prompt"].append(aligned_messages[:-1])
|
outputs["prompt"].append(aligned_messages[:-1])
|
||||||
outputs["response"].append(aligned_messages[-1:])
|
outputs["response"].append(aligned_messages[-1:])
|
||||||
outputs["system"].append(system)
|
outputs["system"].append(system)
|
||||||
outputs["tools"].append(
|
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||||
examples[dataset_attr.tools][i] if dataset_attr.tools else ""
|
|
||||||
)
|
|
||||||
print(examples[dataset_attr.images][i])
|
print(examples[dataset_attr.images][i])
|
||||||
outputs["images"].append(
|
outputs["images"].append(examples[dataset_attr.images][i] if dataset_attr.images else [])
|
||||||
examples[dataset_attr.images][i] if dataset_attr.images else []
|
|
||||||
)
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Literal, Union, Optional
|
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||||
|
|
||||||
from datasets import load_dataset, load_from_disk
|
from datasets import load_dataset, load_from_disk
|
||||||
|
|
||||||
@ -13,9 +13,10 @@ from .preprocess import get_preprocess_and_print_func
|
|||||||
from .template import get_template_and_fix_tokenizer
|
from .template import get_template_and_fix_tokenizer
|
||||||
from .utils import checksum, merge_dataset
|
from .utils import checksum, merge_dataset
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from transformers import Seq2SeqTrainingArguments, AutoProcessor
|
from transformers import AutoProcessor, Seq2SeqTrainingArguments
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
from ..hparams import DataArguments, ModelArguments
|
from ..hparams import DataArguments, ModelArguments
|
||||||
@ -78,20 +79,14 @@ def load_single_dataset(
|
|||||||
split=data_args.split,
|
split=data_args.split,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
token=model_args.ms_hub_token,
|
token=model_args.ms_hub_token,
|
||||||
use_streaming=(
|
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||||
data_args.streaming and (dataset_attr.load_from != "file")
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if isinstance(dataset, MsDataset):
|
if isinstance(dataset, MsDataset):
|
||||||
dataset = dataset.to_hf_dataset()
|
dataset = dataset.to_hf_dataset()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||||
"Please install modelscope via `pip install modelscope -U`"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if (
|
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||||
"trust_remote_code" in inspect.signature(load_dataset).parameters
|
|
||||||
): # for datasets==2.16.0
|
|
||||||
kwargs = {"trust_remote_code": True}
|
kwargs = {"trust_remote_code": True}
|
||||||
else:
|
else:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
@ -108,9 +103,7 @@ def load_single_dataset(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.streaming and (
|
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||||
dataset_attr.load_from == "file"
|
|
||||||
): # faster than specifying streaming=True
|
|
||||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||||
|
|
||||||
if data_args.max_samples is not None: # truncate dataset
|
if data_args.max_samples is not None: # truncate dataset
|
||||||
@ -135,13 +128,9 @@ def get_dataset(
|
|||||||
# Load tokenized dataset
|
# Load tokenized dataset
|
||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if has_tokenized_data(data_args.tokenized_path):
|
if has_tokenized_data(data_args.tokenized_path):
|
||||||
logger.warning(
|
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||||
"Loading dataset from disk will ignore other data arguments."
|
|
||||||
)
|
|
||||||
dataset = load_from_disk(data_args.tokenized_path)
|
dataset = load_from_disk(data_args.tokenized_path)
|
||||||
logger.info(
|
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||||
"Loaded tokenized dataset from {}.".format(data_args.tokenized_path)
|
|
||||||
)
|
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
dataset = dataset.to_iterable_dataset()
|
dataset = dataset.to_iterable_dataset()
|
||||||
return dataset
|
return dataset
|
||||||
@ -152,16 +141,10 @@ def get_dataset(
|
|||||||
with training_args.main_process_first(desc="load dataset"):
|
with training_args.main_process_first(desc="load dataset"):
|
||||||
all_datasets = []
|
all_datasets = []
|
||||||
for dataset_attr in get_dataset_list(data_args):
|
for dataset_attr in get_dataset_list(data_args):
|
||||||
if (stage == "rm" and dataset_attr.ranking is False) or (
|
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||||
stage != "rm" and dataset_attr.ranking is True
|
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"The dataset is not applicable in the current training stage."
|
|
||||||
)
|
|
||||||
|
|
||||||
all_datasets.append(
|
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||||
load_single_dataset(dataset_attr, model_args, data_args)
|
|
||||||
)
|
|
||||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||||
|
|
||||||
with training_args.main_process_first(desc="pre-process dataset"):
|
with training_args.main_process_first(desc="pre-process dataset"):
|
||||||
@ -177,21 +160,13 @@ def get_dataset(
|
|||||||
desc="Running tokenizer on dataset",
|
desc="Running tokenizer on dataset",
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||||
preprocess_func, batched=True, remove_columns=column_names, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if training_args.should_save:
|
if training_args.should_save:
|
||||||
dataset.save_to_disk(data_args.tokenized_path)
|
dataset.save_to_disk(data_args.tokenized_path)
|
||||||
logger.info(
|
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||||
"Tokenized dataset saved at {}.".format(data_args.tokenized_path)
|
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"Please restart the training with `--tokenized_path {}`.".format(
|
|
||||||
data_args.tokenized_path
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
@ -199,8 +174,6 @@ def get_dataset(
|
|||||||
try:
|
try:
|
||||||
print_function(next(iter(dataset)))
|
print_function(next(iter(dataset)))
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise RuntimeError(
|
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||||
"Cannot find valid samples, check `data/README.md` for the data format."
|
|
||||||
)
|
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
@ -50,9 +50,7 @@ class DatasetAttr:
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return self.dataset_name
|
return self.dataset_name
|
||||||
|
|
||||||
def set_attr(
|
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
||||||
self, key: str, obj: Dict[str, Any], default: Optional[Any] = None
|
|
||||||
) -> None:
|
|
||||||
setattr(self, key, obj.get(key, default))
|
setattr(self, key, obj.get(key, default))
|
||||||
|
|
||||||
|
|
||||||
@ -71,16 +69,12 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||||||
except Exception as err:
|
except Exception as err:
|
||||||
if len(dataset_names) != 0:
|
if len(dataset_names) != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot open {} due to {}.".format(
|
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||||
os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
|
|
||||||
if data_args.interleave_probs is not None:
|
if data_args.interleave_probs is not None:
|
||||||
data_args.interleave_probs = [
|
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
||||||
float(prob.strip()) for prob in data_args.interleave_probs.split(",")
|
|
||||||
]
|
|
||||||
|
|
||||||
dataset_list: List[DatasetAttr] = []
|
dataset_list: List[DatasetAttr] = []
|
||||||
for name in dataset_names:
|
for name in dataset_names:
|
||||||
@ -98,21 +92,13 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||||||
|
|
||||||
if has_hf_url or has_ms_url:
|
if has_hf_url or has_ms_url:
|
||||||
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
||||||
dataset_attr = DatasetAttr(
|
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
||||||
"ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
dataset_attr = DatasetAttr(
|
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||||
"hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]
|
|
||||||
)
|
|
||||||
elif "script_url" in dataset_info[name]:
|
elif "script_url" in dataset_info[name]:
|
||||||
dataset_attr = DatasetAttr(
|
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||||
"script", dataset_name=dataset_info[name]["script_url"]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
dataset_attr = DatasetAttr(
|
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||||
"file", dataset_name=dataset_info[name]["file_name"]
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_attr.set_attr("file_sha1", dataset_info[name])
|
dataset_attr.set_attr("file_sha1", dataset_info[name])
|
||||||
dataset_attr.set_attr("subset", dataset_info[name])
|
dataset_attr.set_attr("subset", dataset_info[name])
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple, Optional
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple
|
||||||
|
|
||||||
from ..extras.constants import IGNORE_INDEX
|
from ..extras.constants import IGNORE_INDEX
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
@ -9,7 +9,7 @@ from .utils import Role
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer, AutoProcessor
|
from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
from .template import Template
|
from .template import Template
|
||||||
@ -24,22 +24,16 @@ def preprocess_pretrain_dataset(
|
|||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||||
text_examples = [
|
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||||
messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]
|
|
||||||
]
|
|
||||||
|
|
||||||
if not data_args.packing:
|
if not data_args.packing:
|
||||||
if data_args.template == "gemma":
|
if data_args.template == "gemma":
|
||||||
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
||||||
|
|
||||||
result = tokenizer(
|
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
|
||||||
text_examples, add_special_tokens=False, max_length=data_args.cutoff_len
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||||
concatenated_examples = {
|
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||||
k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()
|
|
||||||
}
|
|
||||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||||
block_size = data_args.cutoff_len
|
block_size = data_args.cutoff_len
|
||||||
total_length = (total_length // block_size) * block_size
|
total_length = (total_length // block_size) * block_size
|
||||||
@ -87,9 +81,7 @@ def preprocess_supervised_dataset(
|
|||||||
if data_args.train_on_prompt:
|
if data_args.train_on_prompt:
|
||||||
source_mask = source_ids
|
source_mask = source_ids
|
||||||
elif turn_idx != 0 and template.efficient_eos:
|
elif turn_idx != 0 and template.efficient_eos:
|
||||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
|
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||||
len(source_ids) - 1
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||||
|
|
||||||
@ -128,9 +120,7 @@ def preprocess_packed_supervised_dataset(
|
|||||||
if data_args.train_on_prompt:
|
if data_args.train_on_prompt:
|
||||||
source_mask = source_ids
|
source_mask = source_ids
|
||||||
elif len(input_ids) != 0 and template.efficient_eos:
|
elif len(input_ids) != 0 and template.efficient_eos:
|
||||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
|
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||||
len(source_ids) - 1
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||||
|
|
||||||
@ -190,9 +180,7 @@ def preprocess_multimodal_supervised_dataset(
|
|||||||
if data_args.train_on_prompt:
|
if data_args.train_on_prompt:
|
||||||
source_mask = source_ids
|
source_mask = source_ids
|
||||||
elif turn_idx != 0 and template.efficient_eos:
|
elif turn_idx != 0 and template.efficient_eos:
|
||||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
|
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||||
len(source_ids) - 1
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||||
|
|
||||||
@ -206,9 +194,7 @@ def preprocess_multimodal_supervised_dataset(
|
|||||||
model_inputs["input_ids"].append(input_ids)
|
model_inputs["input_ids"].append(input_ids)
|
||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
model_inputs["labels"].append(labels)
|
model_inputs["labels"].append(labels)
|
||||||
pixel_values = processor.image_processor(
|
pixel_values = processor.image_processor(examples["images"][0], return_tensors="pt")["pixel_values"][0]
|
||||||
examples["images"][0], return_tensors="pt"
|
|
||||||
)["pixel_values"][0]
|
|
||||||
model_inputs["pixel_values"].append(pixel_values)
|
model_inputs["pixel_values"].append(pixel_values)
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
@ -229,9 +215,7 @@ def preprocess_unsupervised_dataset(
|
|||||||
if len(examples["response"][i]) == 1:
|
if len(examples["response"][i]) == 1:
|
||||||
messages = examples["prompt"][i] + examples["response"][i]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
else:
|
else:
|
||||||
messages = examples["prompt"][i] + [
|
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||||
{"role": Role.ASSISTANT.value, "content": ""}
|
|
||||||
]
|
|
||||||
|
|
||||||
input_ids, labels = template.encode_oneturn(
|
input_ids, labels = template.encode_oneturn(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@ -294,15 +278,9 @@ def preprocess_pairwise_dataset(
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
def print_supervised_dataset_example(
|
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
|
|
||||||
) -> None:
|
|
||||||
print("input_ids:\n{}".format(example["input_ids"]))
|
print("input_ids:\n{}".format(example["input_ids"]))
|
||||||
print(
|
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||||
"inputs:\n{}".format(
|
|
||||||
tokenizer.decode(example["input_ids"], skip_special_tokens=False)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
print("label_ids:\n{}".format(example["labels"]))
|
print("label_ids:\n{}".format(example["labels"]))
|
||||||
print(
|
print(
|
||||||
"labels:\n{}".format(
|
"labels:\n{}".format(
|
||||||
@ -314,38 +292,18 @@ def print_supervised_dataset_example(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_pairwise_dataset_example(
|
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
|
|
||||||
) -> None:
|
|
||||||
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||||
print(
|
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||||
"prompt:\n{}".format(
|
|
||||||
tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||||
print(
|
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
||||||
"chosen:\n{}".format(
|
|
||||||
tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||||
print(
|
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||||
"rejected:\n{}".format(
|
|
||||||
tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def print_unsupervised_dataset_example(
|
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
|
|
||||||
) -> None:
|
|
||||||
print("input_ids:\n{}".format(example["input_ids"]))
|
print("input_ids:\n{}".format(example["input_ids"]))
|
||||||
print(
|
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||||
"inputs:\n{}".format(
|
|
||||||
tokenizer.decode(example["input_ids"], skip_special_tokens=False)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_preprocess_and_print_func(
|
def get_preprocess_and_print_func(
|
||||||
@ -357,12 +315,8 @@ def get_preprocess_and_print_func(
|
|||||||
processor: Optional["AutoProcessor"] = None,
|
processor: Optional["AutoProcessor"] = None,
|
||||||
) -> Tuple[Callable, Callable]:
|
) -> Tuple[Callable, Callable]:
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
|
||||||
preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
)
|
|
||||||
print_function = partial(
|
|
||||||
print_unsupervised_dataset_example, tokenizer=tokenizer
|
|
||||||
)
|
|
||||||
elif stage == "sft" and not training_args.predict_with_generate:
|
elif stage == "sft" and not training_args.predict_with_generate:
|
||||||
if data_args.packing:
|
if data_args.packing:
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
@ -402,8 +356,6 @@ def get_preprocess_and_print_func(
|
|||||||
template=template,
|
template=template,
|
||||||
data_args=data_args,
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
print_function = partial(
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
print_unsupervised_dataset_example, tokenizer=tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return preprocess_func, print_function
|
return preprocess_func, print_function
|
||||||
|
@ -42,9 +42,7 @@ class Template:
|
|||||||
r"""
|
r"""
|
||||||
Returns a single pair of token ids representing prompt and response respectively.
|
Returns a single pair of token ids representing prompt and response respectively.
|
||||||
"""
|
"""
|
||||||
encoded_pairs = self._encode(
|
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||||
tokenizer, messages, system, tools, cutoff_len, reserved_label_len
|
|
||||||
)
|
|
||||||
prompt_ids = []
|
prompt_ids = []
|
||||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||||
prompt_ids += query_ids + resp_ids
|
prompt_ids += query_ids + resp_ids
|
||||||
@ -64,9 +62,7 @@ class Template:
|
|||||||
r"""
|
r"""
|
||||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||||
"""
|
"""
|
||||||
return self._encode(
|
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||||
tokenizer, messages, system, tools, cutoff_len, reserved_label_len
|
|
||||||
)
|
|
||||||
|
|
||||||
def _encode(
|
def _encode(
|
||||||
self,
|
self,
|
||||||
@ -93,9 +89,7 @@ class Template:
|
|||||||
elements += self.format_separator.apply()
|
elements += self.format_separator.apply()
|
||||||
|
|
||||||
if message["role"] == Role.USER.value:
|
if message["role"] == Role.USER.value:
|
||||||
elements += self.format_user.apply(
|
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
||||||
content=message["content"], idx=str(i // 2)
|
|
||||||
)
|
|
||||||
elif message["role"] == Role.ASSISTANT.value:
|
elif message["role"] == Role.ASSISTANT.value:
|
||||||
elements += self.format_assistant.apply(content=message["content"])
|
elements += self.format_assistant.apply(content=message["content"])
|
||||||
elif message["role"] == Role.OBSERVATION.value:
|
elif message["role"] == Role.OBSERVATION.value:
|
||||||
@ -130,11 +124,7 @@ class Template:
|
|||||||
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
||||||
token_ids += [tokenizer.eos_token_id]
|
token_ids += [tokenizer.eos_token_id]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
||||||
"Input must be string, set[str] or dict[str, str], got {}".format(
|
|
||||||
type(elem)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
@ -192,9 +182,7 @@ class Llama2Template(Template):
|
|||||||
elements += self.format_separator.apply()
|
elements += self.format_separator.apply()
|
||||||
|
|
||||||
if message["role"] == Role.USER.value:
|
if message["role"] == Role.USER.value:
|
||||||
elements += self.format_user.apply(
|
elements += self.format_user.apply(content=system_text + message["content"])
|
||||||
content=system_text + message["content"]
|
|
||||||
)
|
|
||||||
elif message["role"] == Role.ASSISTANT.value:
|
elif message["role"] == Role.ASSISTANT.value:
|
||||||
elements += self.format_assistant.apply(content=message["content"])
|
elements += self.format_assistant.apply(content=message["content"])
|
||||||
elif message["role"] == Role.OBSERVATION.value:
|
elif message["role"] == Role.OBSERVATION.value:
|
||||||
@ -257,9 +245,7 @@ def _register_template(
|
|||||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||||
default_function_formatter = FunctionFormatter(
|
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||||
slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots
|
|
||||||
)
|
|
||||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||||
default_separator_formatter = EmptyFormatter()
|
default_separator_formatter = EmptyFormatter()
|
||||||
templates[name] = template_class(
|
templates[name] = template_class(
|
||||||
@ -295,9 +281,7 @@ def _jinja_escape(content: str) -> str:
|
|||||||
return content.replace("\n", r"\n").replace("'", r"\'")
|
return content.replace("\n", r"\n").replace("'", r"\'")
|
||||||
|
|
||||||
|
|
||||||
def _convert_slots_to_jinja(
|
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
|
||||||
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
|
|
||||||
) -> str:
|
|
||||||
slot_items = []
|
slot_items = []
|
||||||
for slot in slots:
|
for slot in slots:
|
||||||
if isinstance(slot, str):
|
if isinstance(slot, str):
|
||||||
@ -311,9 +295,7 @@ def _convert_slots_to_jinja(
|
|||||||
elif isinstance(slot, set):
|
elif isinstance(slot, set):
|
||||||
if "bos_token" in slot:
|
if "bos_token" in slot:
|
||||||
slot_items.append("'" + tokenizer.bos_token + "'")
|
slot_items.append("'" + tokenizer.bos_token + "'")
|
||||||
elif (
|
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
|
||||||
"eos_token" in slot
|
|
||||||
): # do not use {{ eos_token }} since it may be replaced
|
|
||||||
slot_items.append("'" + tokenizer.eos_token + "'")
|
slot_items.append("'" + tokenizer.eos_token + "'")
|
||||||
elif isinstance(slot, dict):
|
elif isinstance(slot, dict):
|
||||||
raise ValueError("Dict is not supported.")
|
raise ValueError("Dict is not supported.")
|
||||||
@ -325,37 +307,25 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
|||||||
jinja_template = ""
|
jinja_template = ""
|
||||||
|
|
||||||
if template.default_system:
|
if template.default_system:
|
||||||
jinja_template += (
|
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
||||||
"{% set system_message = '"
|
|
||||||
+ _jinja_escape(template.default_system)
|
|
||||||
+ "' %}"
|
|
||||||
)
|
|
||||||
|
|
||||||
jinja_template += (
|
jinja_template += (
|
||||||
"{% if messages[0]['role'] == 'system' %}"
|
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
|
||||||
"{% set system_message = messages[0]['content'] %}"
|
|
||||||
"{% endif %}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
system_message = _convert_slots_to_jinja(
|
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
||||||
template.format_system.apply(), tokenizer, placeholder="system_message"
|
|
||||||
)
|
|
||||||
if isinstance(template, Llama2Template):
|
if isinstance(template, Llama2Template):
|
||||||
pass
|
pass
|
||||||
elif template.force_system:
|
elif template.force_system:
|
||||||
jinja_template += "{{ " + system_message + " }}"
|
jinja_template += "{{ " + system_message + " }}"
|
||||||
else:
|
else:
|
||||||
jinja_template += (
|
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
||||||
"{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
|
||||||
)
|
|
||||||
|
|
||||||
jinja_template += "{% for message in messages %}"
|
jinja_template += "{% for message in messages %}"
|
||||||
jinja_template += "{% set content = message['content'] %}"
|
jinja_template += "{% set content = message['content'] %}"
|
||||||
if isinstance(template, Llama2Template):
|
if isinstance(template, Llama2Template):
|
||||||
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
|
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
|
||||||
jinja_template += (
|
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
|
||||||
"{% set content = " + system_message + " + message['content'] %}"
|
|
||||||
)
|
|
||||||
jinja_template += "{% endif %}"
|
jinja_template += "{% endif %}"
|
||||||
jinja_template += "{% if message['role'] == 'user' %}"
|
jinja_template += "{% if message['role'] == 'user' %}"
|
||||||
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
|
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
|
||||||
@ -403,9 +373,7 @@ def get_template_and_fix_tokenizer(
|
|||||||
)
|
)
|
||||||
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
||||||
if num_added_tokens > 0:
|
if num_added_tokens > 0:
|
||||||
logger.warning(
|
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
||||||
"New tokens have been added, make sure `resize_vocab` is True."
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
||||||
@ -417,9 +385,7 @@ def get_template_and_fix_tokenizer(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="alpaca",
|
name="alpaca",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||||
slots=["### Instruction:\n{{content}}\n\n### Response:\n"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
"Below is an instruction that describes a task. "
|
"Below is an instruction that describes a task. "
|
||||||
@ -458,9 +424,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="baichuan",
|
name="baichuan",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||||
slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]
|
|
||||||
),
|
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -483,9 +447,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="bluelm",
|
name="bluelm",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
||||||
slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -504,9 +466,7 @@ _register_template(
|
|||||||
_register_template(
|
_register_template(
|
||||||
name="chatglm2",
|
name="chatglm2",
|
||||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||||
format_system=StringFormatter(
|
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
force_system=True,
|
force_system=True,
|
||||||
@ -515,13 +475,9 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="chatglm3",
|
name="chatglm3",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||||
slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
|
||||||
),
|
|
||||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||||
format_system=StringFormatter(
|
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
|
|
||||||
),
|
|
||||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
@ -539,9 +495,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="chatglm3_system",
|
name="chatglm3_system",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||||
slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
|
||||||
),
|
|
||||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||||
format_system=StringFormatter(
|
format_system=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
@ -572,15 +526,9 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="chatml",
|
name="chatml",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
),
|
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_system=StringFormatter(
|
|
||||||
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
|
||||||
),
|
|
||||||
format_observation=StringFormatter(
|
|
||||||
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
@ -589,15 +537,9 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="chatml_de",
|
name="chatml_de",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
),
|
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_system=StringFormatter(
|
|
||||||
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
|
||||||
),
|
|
||||||
format_observation=StringFormatter(
|
|
||||||
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
|
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
|
||||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||||
@ -607,9 +549,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="codegeex2",
|
name="codegeex2",
|
||||||
format_system=StringFormatter(
|
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
|
|
||||||
),
|
|
||||||
force_system=True,
|
force_system=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -639,15 +579,9 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="dbrx",
|
name="dbrx",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
),
|
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_system=StringFormatter(
|
|
||||||
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
|
||||||
),
|
|
||||||
format_observation=StringFormatter(
|
|
||||||
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
"You are DBRX, created by Databricks. You were last updated in December 2023. "
|
"You are DBRX, created by Databricks. You were last updated in December 2023. "
|
||||||
@ -725,9 +659,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="gemma",
|
name="gemma",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||||
slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
|
||||||
),
|
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||||
@ -740,9 +672,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="intern",
|
name="intern",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||||
slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
|
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
|
||||||
stop_words=["<eoa>"],
|
stop_words=["<eoa>"],
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
@ -751,12 +681,8 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="intern2",
|
name="intern2",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
),
|
|
||||||
format_system=StringFormatter(
|
|
||||||
slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
||||||
@ -859,9 +785,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="orion",
|
name="orion",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||||
slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]
|
|
||||||
),
|
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||||
force_system=True,
|
force_system=True,
|
||||||
)
|
)
|
||||||
@ -869,15 +793,9 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="phi",
|
name="phi",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||||
slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]
|
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
|
||||||
),
|
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||||
format_system=StringFormatter(
|
|
||||||
slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]
|
|
||||||
),
|
|
||||||
format_observation=StringFormatter(
|
|
||||||
slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
default_system="You are a helpful AI assistant.",
|
default_system="You are a helpful AI assistant.",
|
||||||
stop_words=["<|end|>"],
|
stop_words=["<|end|>"],
|
||||||
@ -887,15 +805,9 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="qwen",
|
name="qwen",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
),
|
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_system=StringFormatter(
|
|
||||||
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
|
||||||
),
|
|
||||||
format_observation=StringFormatter(
|
|
||||||
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
default_system="You are a helpful assistant.",
|
default_system="You are a helpful assistant.",
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
@ -951,12 +863,8 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="yayi",
|
name="yayi",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||||
slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]
|
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||||
),
|
|
||||||
format_system=StringFormatter(
|
|
||||||
slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
"You are a helpful, respectful and honest assistant named YaYi "
|
"You are a helpful, respectful and honest assistant named YaYi "
|
||||||
@ -975,9 +883,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="yi",
|
name="yi",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
|
||||||
),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
@ -995,9 +901,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="zephyr",
|
name="zephyr",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
||||||
slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]
|
|
||||||
),
|
|
||||||
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
|
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
|
||||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||||
default_system="You are a friendly chatbot who always responds in the style of a pirate",
|
default_system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||||
|
@ -15,33 +15,23 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
adapter_name_or_path: Optional[str] = field(
|
adapter_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
|
||||||
"help": "Path to the adapter weight or identifier from huggingface.co/models."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
cache_dir: Optional[str] = field(
|
cache_dir: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||||
"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
use_fast_tokenizer: bool = field(
|
use_fast_tokenizer: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={
|
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
|
||||||
"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
resize_vocab: bool = field(
|
resize_vocab: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
|
||||||
"help": "Whether or not to resize the tokenizer vocab and the embedding layers."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
split_special_tokens: bool = field(
|
split_special_tokens: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||||
"help": "Whether or not the special tokens should be split during the tokenization process."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
new_special_tokens: Optional[str] = field(
|
new_special_tokens: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
@ -49,9 +39,7 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
model_revision: str = field(
|
model_revision: str = field(
|
||||||
default="main",
|
default="main",
|
||||||
metadata={
|
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||||
"help": "The specific model version to use (can be a branch name, tag name or commit id)."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
low_cpu_mem_usage: bool = field(
|
low_cpu_mem_usage: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
@ -59,9 +47,7 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
quantization_bit: Optional[int] = field(
|
quantization_bit: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
|
||||||
"help": "The number of bits to quantize the model using bitsandbytes."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
quantization_type: Literal["fp4", "nf4"] = field(
|
quantization_type: Literal["fp4", "nf4"] = field(
|
||||||
default="nf4",
|
default="nf4",
|
||||||
@ -69,21 +55,15 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
double_quantization: bool = field(
|
double_quantization: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={
|
metadata={"help": "Whether or not to use double quantization in int4 training."},
|
||||||
"help": "Whether or not to use double quantization in int4 training."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
quantization_device_map: Optional[Literal["auto"]] = field(
|
quantization_device_map: Optional[Literal["auto"]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||||
"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||||
"help": "Which scaling strategy should be adopted for the RoPE embeddings."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
|
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
|
||||||
default="auto",
|
default="auto",
|
||||||
@ -91,27 +71,19 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
shift_attn: bool = field(
|
shift_attn: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||||
"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
|
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
|
||||||
"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
use_unsloth: bool = field(
|
use_unsloth: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||||
"help": "Whether or not to use unsloth's optimization for the LoRA training."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
moe_aux_loss_coef: Optional[float] = field(
|
moe_aux_loss_coef: Optional[float] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
|
||||||
"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
disable_gradient_checkpointing: bool = field(
|
disable_gradient_checkpointing: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
@ -135,9 +107,7 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
vllm_gpu_util: float = field(
|
vllm_gpu_util: float = field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
metadata={
|
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
|
||||||
"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
vllm_enforce_eager: bool = field(
|
vllm_enforce_eager: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
@ -177,9 +147,7 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
export_quantization_dataset: Optional[str] = field(
|
export_quantization_dataset: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
|
||||||
"help": "Path to the dataset or dataset name to use in quantizing the exported model."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
export_quantization_nsamples: int = field(
|
export_quantization_nsamples: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
@ -187,27 +155,19 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
export_quantization_maxlen: int = field(
|
export_quantization_maxlen: int = field(
|
||||||
default=1024,
|
default=1024,
|
||||||
metadata={
|
metadata={"help": "The maximum length of the model inputs used for quantization."},
|
||||||
"help": "The maximum length of the model inputs used for quantization."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
export_legacy_format: bool = field(
|
export_legacy_format: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
|
||||||
"help": "Whether or not to save the `.bin` files instead of `.safetensors`."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
export_hub_model_id: Optional[str] = field(
|
export_hub_model_id: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
|
||||||
"help": "The name of the repository if push the model to the Hugging Face hub."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
print_param_status: bool = field(
|
print_param_status: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
|
||||||
"help": "For debugging purposes, print the status of the parameters in the model."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
use_mllm: bool = field(
|
use_mllm: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
@ -220,21 +180,13 @@ class ModelArguments:
|
|||||||
self.model_max_length = None
|
self.model_max_length = None
|
||||||
|
|
||||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||||
raise ValueError(
|
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||||
"`split_special_tokens` is only supported for slow tokenizers."
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
||||||
self.adapter_name_or_path is not None
|
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
||||||
): # support merging multiple lora weights
|
|
||||||
self.adapter_name_or_path = [
|
|
||||||
path.strip() for path in self.adapter_name_or_path.split(",")
|
|
||||||
]
|
|
||||||
|
|
||||||
if self.new_special_tokens is not None: # support multiple special tokens
|
if self.new_special_tokens is not None: # support multiple special tokens
|
||||||
self.new_special_tokens = [
|
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
|
||||||
token.strip() for token in self.new_special_tokens.split(",")
|
|
||||||
]
|
|
||||||
|
|
||||||
assert self.quantization_bit in [
|
assert self.quantization_bit in [
|
||||||
None,
|
None,
|
||||||
@ -249,10 +201,7 @@ class ModelArguments:
|
|||||||
2,
|
2,
|
||||||
], "We only accept 2/3/4/8-bit quantization."
|
], "We only accept 2/3/4/8-bit quantization."
|
||||||
|
|
||||||
if (
|
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
|
||||||
self.export_quantization_bit is not None
|
|
||||||
and self.export_quantization_dataset is None
|
|
||||||
):
|
|
||||||
raise ValueError("Quantization dataset is necessary for exporting.")
|
raise ValueError("Quantization dataset is necessary for exporting.")
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from .loader import load_config, load_model, load_tokenizer
|
from .loader import load_config, load_model, load_tokenizer
|
||||||
from .utils.misc import find_all_linear_modules, load_valuehead_params
|
from .utils.misc import find_all_linear_modules, load_valuehead_params
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_config",
|
"load_config",
|
||||||
"load_model",
|
"load_model",
|
||||||
|
@ -38,9 +38,7 @@ def init_adapter(
|
|||||||
logger.info("Adapter is not found at evaluation, load the base model.")
|
logger.info("Adapter is not found at evaluation, load the base model.")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
if finetuning_args.finetuning_type != "lora" and getattr(
|
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
|
||||||
model, "quantization_method", None
|
|
||||||
):
|
|
||||||
raise ValueError("You can only use lora for quantized models.")
|
raise ValueError("You can only use lora for quantized models.")
|
||||||
|
|
||||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||||
@ -68,12 +66,8 @@ def init_adapter(
|
|||||||
|
|
||||||
stride = num_layers // finetuning_args.num_layer_trainable
|
stride = num_layers // finetuning_args.num_layer_trainable
|
||||||
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
|
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
|
||||||
elif (
|
elif finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||||
finetuning_args.num_layer_trainable > 0
|
trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers)
|
||||||
): # fine-tuning the last n layers if num_layer_trainable > 0
|
|
||||||
trainable_layer_ids = range(
|
|
||||||
num_layers - finetuning_args.num_layer_trainable, num_layers
|
|
||||||
)
|
|
||||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||||
trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
|
trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
|
||||||
|
|
||||||
@ -88,15 +82,11 @@ def init_adapter(
|
|||||||
for module_name in finetuning_args.name_module_trainable:
|
for module_name in finetuning_args.name_module_trainable:
|
||||||
if module_name not in freeze_modules:
|
if module_name not in freeze_modules:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Module {} is not found, please choose from {}".format(
|
"Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
|
||||||
module_name, ", ".join(freeze_modules)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx in trainable_layer_ids:
|
for idx in trainable_layer_ids:
|
||||||
trainable_layers.append(
|
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
|
||||||
".{:d}.{}".format(idx, module_name if module_name != "all" else "")
|
|
||||||
)
|
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if any(trainable_layer in name for trainable_layer in trainable_layers):
|
if any(trainable_layer in name for trainable_layer in trainable_layers):
|
||||||
@ -105,43 +95,27 @@ def init_adapter(
|
|||||||
else:
|
else:
|
||||||
param.requires_grad_(False)
|
param.requires_grad_(False)
|
||||||
|
|
||||||
logger.info(
|
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||||
"Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids)))
|
|
||||||
)
|
|
||||||
|
|
||||||
if finetuning_args.finetuning_type == "lora":
|
if finetuning_args.finetuning_type == "lora":
|
||||||
logger.info(
|
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||||
"Fine-tuning method: {}".format(
|
|
||||||
"DoRA" if finetuning_args.use_dora else "LoRA"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
adapter_to_resume = None
|
adapter_to_resume = None
|
||||||
|
|
||||||
if model_args.adapter_name_or_path is not None:
|
if model_args.adapter_name_or_path is not None:
|
||||||
is_mergeable = True
|
is_mergeable = True
|
||||||
if getattr(
|
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
|
||||||
model, "quantization_method", None
|
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
|
||||||
): # merge lora in quantized model is unstable
|
|
||||||
assert (
|
|
||||||
len(model_args.adapter_name_or_path) == 1
|
|
||||||
), "Quantized model only accepts a single adapter."
|
|
||||||
is_mergeable = False
|
is_mergeable = False
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
assert (
|
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
|
||||||
len(model_args.adapter_name_or_path) == 1
|
|
||||||
), "Cannot use multiple adapters in DeepSpeed ZeRO-3."
|
|
||||||
is_mergeable = False
|
is_mergeable = False
|
||||||
|
|
||||||
if model_args.use_unsloth:
|
if model_args.use_unsloth:
|
||||||
assert (
|
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
|
||||||
len(model_args.adapter_name_or_path) == 1
|
|
||||||
), "Unsloth model only accepts a single adapter."
|
|
||||||
is_mergeable = False
|
is_mergeable = False
|
||||||
|
|
||||||
if (is_trainable and not finetuning_args.create_new_adapter) or (
|
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
|
||||||
not is_mergeable
|
|
||||||
):
|
|
||||||
adapter_to_merge = model_args.adapter_name_or_path[:-1]
|
adapter_to_merge = model_args.adapter_name_or_path[:-1]
|
||||||
adapter_to_resume = model_args.adapter_name_or_path[-1]
|
adapter_to_resume = model_args.adapter_name_or_path[-1]
|
||||||
else:
|
else:
|
||||||
@ -158,9 +132,7 @@ def init_adapter(
|
|||||||
|
|
||||||
if adapter_to_resume is not None: # resume lora training
|
if adapter_to_resume is not None: # resume lora training
|
||||||
if model_args.use_unsloth:
|
if model_args.use_unsloth:
|
||||||
model = load_unsloth_peft_model(
|
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
|
||||||
config, model_args, is_trainable=is_trainable
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
@ -169,27 +141,19 @@ def init_adapter(
|
|||||||
offload_folder=model_args.offload_folder,
|
offload_folder=model_args.offload_folder,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||||
is_trainable and adapter_to_resume is None
|
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||||
): # create new lora weights while training
|
|
||||||
if (
|
|
||||||
len(finetuning_args.lora_target) == 1
|
|
||||||
and finetuning_args.lora_target[0] == "all"
|
|
||||||
):
|
|
||||||
target_modules = find_all_linear_modules(model)
|
target_modules = find_all_linear_modules(model)
|
||||||
else:
|
else:
|
||||||
target_modules = finetuning_args.lora_target
|
target_modules = finetuning_args.lora_target
|
||||||
|
|
||||||
if finetuning_args.use_llama_pro:
|
if finetuning_args.use_llama_pro:
|
||||||
target_modules = find_expanded_modules(
|
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
|
||||||
model, target_modules, finetuning_args.num_layer_trainable
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
finetuning_args.use_dora
|
finetuning_args.use_dora
|
||||||
and getattr(model, "quantization_method", None) is not None
|
and getattr(model, "quantization_method", None) is not None
|
||||||
and getattr(model, "quantization_method", None)
|
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
|
||||||
!= QuantizationMethod.BITS_AND_BYTES
|
|
||||||
):
|
):
|
||||||
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
|
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
|
||||||
|
|
||||||
@ -202,11 +166,7 @@ def init_adapter(
|
|||||||
module_names.add(name.split(".")[-1])
|
module_names.add(name.split(".")[-1])
|
||||||
|
|
||||||
finetuning_args.additional_target = module_names
|
finetuning_args.additional_target = module_names
|
||||||
logger.warning(
|
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
|
||||||
"Vocab has been resized, add {} to trainable params.".format(
|
|
||||||
",".join(module_names)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
peft_kwargs = {
|
peft_kwargs = {
|
||||||
"r": finetuning_args.lora_rank,
|
"r": finetuning_args.lora_rank,
|
||||||
@ -233,10 +193,6 @@ def init_adapter(
|
|||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
if model_args.adapter_name_or_path is not None:
|
if model_args.adapter_name_or_path is not None:
|
||||||
logger.info(
|
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||||
"Loaded adapter(s): {}".format(
|
|
||||||
",".join(model_args.adapter_name_or_path)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Any, Dict, Union
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
|
||||||
AutoProcessor,
|
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
|
AutoProcessor,
|
||||||
|
AutoTokenizer,
|
||||||
)
|
)
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
@ -17,6 +17,7 @@ from .utils.misc import load_valuehead_params, register_autoclass
|
|||||||
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||||
from .utils.unsloth import load_unsloth_pretrained_model
|
from .utils.unsloth import load_unsloth_pretrained_model
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
||||||
|
|
||||||
@ -42,7 +43,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
|||||||
|
|
||||||
def load_tokenizer(
|
def load_tokenizer(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
) -> Dict[str, Union["PreTrainedTokenizer", "AutoProcesser"]]:
|
) -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
|
||||||
r"""
|
r"""
|
||||||
Loads pretrained tokenizer.
|
Loads pretrained tokenizer.
|
||||||
|
|
||||||
@ -70,14 +71,10 @@ def load_tokenizer(
|
|||||||
dict(additional_special_tokens=model_args.new_special_tokens),
|
dict(additional_special_tokens=model_args.new_special_tokens),
|
||||||
replace_additional_special_tokens=False,
|
replace_additional_special_tokens=False,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
|
||||||
"Add {} to special tokens.".format(",".join(model_args.new_special_tokens))
|
|
||||||
)
|
|
||||||
if num_added_tokens > 0 and not model_args.resize_vocab:
|
if num_added_tokens > 0 and not model_args.resize_vocab:
|
||||||
model_args.resize_vocab = True
|
model_args.resize_vocab = True
|
||||||
logger.warning(
|
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
|
||||||
"New tokens have been added, changed `resize_vocab` to True."
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_tokenizer(tokenizer)
|
patch_tokenizer(tokenizer)
|
||||||
tokenizer_modules = {"tokenizer": tokenizer, "processor": None}
|
tokenizer_modules = {"tokenizer": tokenizer, "processor": None}
|
||||||
@ -174,10 +171,8 @@ def load_model(
|
|||||||
|
|
||||||
trainable_params, all_param = count_parameters(model)
|
trainable_params, all_param = count_parameters(model)
|
||||||
if is_trainable:
|
if is_trainable:
|
||||||
param_stats = (
|
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||||
"trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
trainable_params, all_param, 100 * trainable_params / all_param
|
||||||
trainable_params, all_param, 100 * trainable_params / all_param
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
param_stats = "all params: {:d}".format(all_param)
|
param_stats = "all params: {:d}".format(all_param)
|
||||||
|
@ -50,29 +50,17 @@ def run_sft(
|
|||||||
tokenizer.padding_side = "left" # use left-padding in generation
|
tokenizer.padding_side = "left" # use left-padding in generation
|
||||||
|
|
||||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||||
setattr(
|
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||||
model, "_hf_peft_config_loaded", True
|
|
||||||
) # hack here: make model compatible with prediction
|
|
||||||
|
|
||||||
data_collator = DataCollatorForSeq2Seq(
|
data_collator = DataCollatorForSeq2Seq(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
pad_to_multiple_of=(
|
pad_to_multiple_of=(8 if tokenizer.padding_side == "right" else None), # for shift short attention
|
||||||
8 if tokenizer.padding_side == "right" else None
|
label_pad_token_id=(IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id),
|
||||||
), # for shift short attention
|
|
||||||
label_pad_token_id=(
|
|
||||||
IGNORE_INDEX
|
|
||||||
if data_args.ignore_pad_token_for_loss
|
|
||||||
else tokenizer.pad_token_id
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Override the decoding parameters of Seq2SeqTrainer
|
# Override the decoding parameters of Seq2SeqTrainer
|
||||||
training_args.generation_max_length = (
|
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
|
||||||
training_args.generation_max_length or data_args.cutoff_len
|
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
|
||||||
)
|
|
||||||
training_args.generation_num_beams = (
|
|
||||||
data_args.eval_num_beams or training_args.generation_num_beams
|
|
||||||
)
|
|
||||||
if model_args.use_mllm:
|
if model_args.use_mllm:
|
||||||
training_args.remove_unused_columns = False
|
training_args.remove_unused_columns = False
|
||||||
|
|
||||||
@ -84,25 +72,19 @@ def run_sft(
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
compute_metrics=(
|
compute_metrics=(ComputeMetrics(tokenizer) if training_args.predict_with_generate else None),
|
||||||
ComputeMetrics(tokenizer) if training_args.predict_with_generate else None
|
|
||||||
),
|
|
||||||
**split_dataset(dataset, data_args, training_args),
|
**split_dataset(dataset, data_args, training_args),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keyword arguments for `model.generate`
|
# Keyword arguments for `model.generate`
|
||||||
gen_kwargs = generating_args.to_dict()
|
gen_kwargs = generating_args.to_dict()
|
||||||
gen_kwargs["eos_token_id"] = [
|
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||||
tokenizer.eos_token_id
|
|
||||||
] + tokenizer.additional_special_tokens_ids
|
|
||||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
train_result = trainer.train(
|
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||||
resume_from_checkpoint=training_args.resume_from_checkpoint
|
|
||||||
)
|
|
||||||
trainer.save_model()
|
trainer.save_model()
|
||||||
trainer.log_metrics("train", train_result.metrics)
|
trainer.log_metrics("train", train_result.metrics)
|
||||||
trainer.save_metrics("train", train_result.metrics)
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
@ -113,27 +95,19 @@ def run_sft(
|
|||||||
# Evaluation
|
# Evaluation
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||||
if (
|
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
|
||||||
training_args.predict_with_generate
|
|
||||||
): # eval_loss will be wrong if predict_with_generate is enabled
|
|
||||||
metrics.pop("eval_loss", None)
|
metrics.pop("eval_loss", None)
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
# Predict
|
# Predict
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
predict_results = trainer.predict(
|
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
|
||||||
dataset, metric_key_prefix="predict", **gen_kwargs
|
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||||
)
|
|
||||||
if (
|
|
||||||
training_args.predict_with_generate
|
|
||||||
): # predict_loss will be wrong if predict_with_generate is enabled
|
|
||||||
predict_results.metrics.pop("predict_loss", None)
|
predict_results.metrics.pop("predict_loss", None)
|
||||||
trainer.log_metrics("predict", predict_results.metrics)
|
trainer.log_metrics("predict", predict_results.metrics)
|
||||||
trainer.save_metrics("predict", predict_results.metrics)
|
trainer.save_metrics("predict", predict_results.metrics)
|
||||||
trainer.save_predictions(predict_results)
|
trainer.save_predictions(predict_results)
|
||||||
|
|
||||||
# Create model card
|
# Create model card
|
||||||
create_modelcard_and_push(
|
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||||
trainer, model_args, data_args, training_args, finetuning_args
|
|
||||||
)
|
|
||||||
|
4
src/llmtuner/train/sftmm/__init__.py
Normal file
4
src/llmtuner/train/sftmm/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .workflow import run_sft_mm
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["run_sft_mm"]
|
61
src/llmtuner/train/sftmm/metric.py
Normal file
61
src/llmtuner/train/sftmm/metric.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ...extras.constants import IGNORE_INDEX
|
||||||
|
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
if is_jieba_available():
|
||||||
|
import jieba # type: ignore
|
||||||
|
|
||||||
|
if is_nltk_available():
|
||||||
|
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
|
||||||
|
|
||||||
|
if is_rouge_available():
|
||||||
|
from rouge_chinese import Rouge
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ComputeMetrics:
|
||||||
|
r"""
|
||||||
|
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer: "PreTrainedTokenizer"
|
||||||
|
|
||||||
|
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||||
|
r"""
|
||||||
|
Uses the model predictions to compute metrics.
|
||||||
|
"""
|
||||||
|
preds, labels = eval_preds
|
||||||
|
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||||
|
|
||||||
|
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||||
|
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||||
|
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||||
|
|
||||||
|
for pred, label in zip(decoded_preds, decoded_labels):
|
||||||
|
hypothesis = list(jieba.cut(pred))
|
||||||
|
reference = list(jieba.cut(label))
|
||||||
|
|
||||||
|
if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
|
||||||
|
result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
|
||||||
|
else:
|
||||||
|
rouge = Rouge()
|
||||||
|
scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
|
||||||
|
result = scores[0]
|
||||||
|
|
||||||
|
for k, v in result.items():
|
||||||
|
score_dict[k].append(round(v["f"] * 100, 4))
|
||||||
|
|
||||||
|
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
|
||||||
|
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||||
|
|
||||||
|
return {k: float(np.mean(v)) for k, v in score_dict.items()}
|
39
src/llmtuner/train/sftmm/trainer.py
Normal file
39
src/llmtuner/train/sftmm/trainer.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from types import MethodType
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import Seq2SeqTrainer
|
||||||
|
|
||||||
|
from ...extras.logging import get_logger
|
||||||
|
from ..utils import create_custom_optimzer, create_custom_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...hparams import FinetuningArguments
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||||
|
r"""
|
||||||
|
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.finetuning_args = finetuning_args
|
||||||
|
if finetuning_args.use_badam:
|
||||||
|
from badam import clip_grad_norm_for_sparse_tensor
|
||||||
|
|
||||||
|
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||||
|
|
||||||
|
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||||
|
if self.optimizer is None:
|
||||||
|
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
|
||||||
|
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||||
|
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||||
|
return super().create_scheduler(num_training_steps, optimizer)
|
101
src/llmtuner/train/sftmm/workflow.py
Normal file
101
src/llmtuner/train/sftmm/workflow.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
|
||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
from transformers import DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
from ...data import get_dataset
|
||||||
|
from ...extras.constants import IGNORE_INDEX
|
||||||
|
from ...extras.misc import get_logits_processor
|
||||||
|
from ...extras.ploting import plot_loss
|
||||||
|
from ...model import load_model, load_processor
|
||||||
|
from ..sft.metric import ComputeMetrics
|
||||||
|
from ..utils import create_modelcard_and_push
|
||||||
|
from .trainer import CustomSeq2SeqTrainer
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||||
|
|
||||||
|
from ...hparams import (
|
||||||
|
DataArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
|
ModelArguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_sft_mm(
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
generating_args: "GeneratingArguments",
|
||||||
|
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||||
|
):
|
||||||
|
processor = load_processor(model_args)
|
||||||
|
tokenizer = processor.tokenizer
|
||||||
|
dataset = get_dataset(tokenizer, model_args, data_args, training_args, "sft", processor)
|
||||||
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||||
|
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||||
|
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||||
|
train_dataset = dataset
|
||||||
|
eval_dataset = dataset
|
||||||
|
data_collator = DataCollatorForSeq2Seq(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
pad_to_multiple_of=(8 if tokenizer.padding_side == "right" else None), # for shift short attention
|
||||||
|
label_pad_token_id=(IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override the decoding parameters of Seq2SeqTrainer
|
||||||
|
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
|
||||||
|
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
|
||||||
|
training_args.remove_unused_columns = False
|
||||||
|
|
||||||
|
# Initialize our Trainer
|
||||||
|
trainer = CustomSeq2SeqTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
finetuning_args=finetuning_args,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_collator=data_collator,
|
||||||
|
callbacks=callbacks,
|
||||||
|
compute_metrics=(ComputeMetrics(tokenizer) if training_args.predict_with_generate else None),
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keyword arguments for `model.generate`
|
||||||
|
gen_kwargs = generating_args.to_dict()
|
||||||
|
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||||
|
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||||
|
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||||
|
|
||||||
|
# Training
|
||||||
|
if training_args.do_train:
|
||||||
|
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||||
|
trainer.save_model()
|
||||||
|
trainer.log_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||||
|
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
if training_args.do_eval:
|
||||||
|
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||||
|
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
|
||||||
|
metrics.pop("eval_loss", None)
|
||||||
|
trainer.log_metrics("eval", metrics)
|
||||||
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
|
# Predict
|
||||||
|
if training_args.do_predict:
|
||||||
|
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
|
||||||
|
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||||
|
predict_results.metrics.pop("predict_loss", None)
|
||||||
|
trainer.log_metrics("predict", predict_results.metrics)
|
||||||
|
trainer.save_metrics("predict", predict_results.metrics)
|
||||||
|
trainer.save_predictions(predict_results)
|
||||||
|
|
||||||
|
# Create model card
|
||||||
|
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
@ -15,6 +15,7 @@ from .pt import run_pt
|
|||||||
from .rm import run_rm
|
from .rm import run_rm
|
||||||
from .sft import run_sft
|
from .sft import run_sft
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user