Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 23:58:11 +08:00)

[breaking change] refactor data pipeline (#6901)

* refactor data
* rename file

Former-commit-id: 7a1a4ce6451cb782573d0bd9dd27a5e443e3a18b

This commit is contained in:
parent 80b89978d9
commit 46203856fc
src/llamafactory/data/aligner.py (deleted, 241 lines)
@@ -1,241 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .parser import DatasetAttr


logger = logging.get_logger(__name__)


def _regularize_medias(
    inputs: Union[Any, Sequence[Any]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List[Any]]:
    r"""
    Optionally concatenates media path to media dir when loading from local disk.
    """
    if not isinstance(inputs, list):
        inputs = [inputs]
    elif len(inputs) == 0:
        return None
    else:
        inputs = inputs[:]

    if dataset_attr.load_from in ["script", "file"]:
        for i in range(len(inputs)):
            if isinstance(inputs[i], str) and os.path.isfile(os.path.join(data_args.media_dir, inputs[i])):
                inputs[i] = os.path.join(data_args.media_dir, inputs[i])

    return inputs


def convert_alpaca(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts alpaca format dataset to the standard format.
    """
    prompt = []
    if dataset_attr.history and isinstance(example[dataset_attr.history], list):
        for old_prompt, old_response in example[dataset_attr.history]:
            prompt.append({"role": Role.USER.value, "content": old_prompt})
            prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

    query = []
    if dataset_attr.prompt and example[dataset_attr.prompt]:
        query.append(example[dataset_attr.prompt])

    if dataset_attr.query and example[dataset_attr.query]:
        query.append(example[dataset_attr.query])

    prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], str)
        and isinstance(example[dataset_attr.rejected], str)
    ):  # pairwise example
        response = [
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
        ]
    elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
    else:  # unsupervised
        response = []

    regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": example[dataset_attr.system] if dataset_attr.system else "",
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
        "_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
    }
    return output


def convert_sharegpt(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts sharegpt format dataset to the standard format.
    """
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
        dataset_attr.observation_tag: Role.OBSERVATION.value,
        dataset_attr.function_tag: Role.FUNCTION.value,
        dataset_attr.system_tag: Role.SYSTEM.value,
    }
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
    messages = example[dataset_attr.messages]
    if (
        dataset_attr.system_tag
        and len(messages) != 0
        and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
    ):
        system = messages[0][dataset_attr.content_tag]
        messages = messages[1:]
    else:
        system = example[dataset_attr.system] if dataset_attr.system else ""

    aligned_messages = []
    broken_data = False
    for turn_idx, message in enumerate(messages):
        if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
            logger.warning_rank0(f"Invalid role tag in {messages}.")
            broken_data = True
            break

        aligned_messages.append(
            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
        )

    if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
        dataset_attr.ranking and len(aligned_messages) % 2 == 0
    ):
        logger.warning_rank0(f"Invalid message count in {messages}.")
        broken_data = True

    if broken_data:
        logger.warning_rank0("Skipping this abnormal example.")
        prompt, response = [], []
    elif dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], dict)
        and isinstance(example[dataset_attr.rejected], dict)
    ):  # pairwise example
        chosen = example[dataset_attr.chosen]
        rejected = example[dataset_attr.rejected]
        if (
            chosen[dataset_attr.role_tag] not in accept_tags[-1]
            or rejected[dataset_attr.role_tag] not in accept_tags[-1]
        ):
            logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
            broken_data = True

        prompt = aligned_messages
        response = [
            {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
            {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
        ]
    else:  # normal example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]

    regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": system,
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
        "_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
    }
    return output


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _tools: "...",
        _images: [],
        _videos: [],
        _audios: [],
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
    else:
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    return dataset.map(
        convert_func,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
src/llamafactory/data/converter.py (new file, 269 lines)
@@ -0,0 +1,269 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .parser import DatasetAttr


logger = logging.get_logger(__name__)


@dataclass
class DatasetConverter:
    dataset_attr: "DatasetAttr"
    data_args: "DataArguments"

    def _find_medias(self, inputs: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
        r"""
        Optionally concatenates media path to media dir when loading from local disk.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]
        elif len(inputs) == 0:
            return None
        else:
            inputs = inputs[:]

        if self.dataset_attr.load_from in ["script", "file"]:
            for i in range(len(inputs)):
                if isinstance(inputs[i], str) and os.path.isfile(os.path.join(self.data_args.media_dir, inputs[i])):
                    inputs[i] = os.path.join(self.data_args.media_dir, inputs[i])

        return inputs

    @abstractmethod
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        r"""
        Converts a single example in the dataset to the standard format.
        """
        ...


@dataclass
class AlpacaDatasetConverter(DatasetConverter):
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        prompt = []
        if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
            for old_prompt, old_response in example[self.dataset_attr.history]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        query = []
        if self.dataset_attr.prompt and example[self.dataset_attr.prompt]:
            query.append(example[self.dataset_attr.prompt])

        if self.dataset_attr.query and example[self.dataset_attr.query]:
            query.append(example[self.dataset_attr.query])

        prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

        if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], str)
            and isinstance(example[self.dataset_attr.rejected], str)
        ):  # pairwise example
            response = [
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]},
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]},
            ]
        elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str):  # normal example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
        else:  # unsupervised
            response = []

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": example[self.dataset_attr.system] if self.dataset_attr.system else "",
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


@dataclass
class SharegptDatasetConverter(DatasetConverter):
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        tag_mapping = {
            self.dataset_attr.user_tag: Role.USER.value,
            self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
            self.dataset_attr.observation_tag: Role.OBSERVATION.value,
            self.dataset_attr.function_tag: Role.FUNCTION.value,
            self.dataset_attr.system_tag: Role.SYSTEM.value,
        }
        odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
        even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
        accept_tags = (odd_tags, even_tags)
        messages = example[self.dataset_attr.messages]
        if (
            self.dataset_attr.system_tag
            and len(messages) != 0
            and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
        ):
            system = messages[0][self.dataset_attr.content_tag]
            messages = messages[1:]
        else:
            system = example[self.dataset_attr.system] if self.dataset_attr.system else ""

        aligned_messages = []
        broken_data = False
        for turn_idx, message in enumerate(messages):
            if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
                logger.warning_rank0(f"Invalid role tag in {messages}.")
                broken_data = True
                break

            aligned_messages.append(
                {
                    "role": tag_mapping[message[self.dataset_attr.role_tag]],
                    "content": message[self.dataset_attr.content_tag],
                }
            )

        if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
            self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
        ):
            logger.warning_rank0(f"Invalid message count in {messages}.")
            broken_data = True

        if broken_data:
            logger.warning_rank0("Skipping this abnormal example.")
            prompt, response = [], []
        elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], dict)
            and isinstance(example[self.dataset_attr.rejected], dict)
        ):  # pairwise example
            chosen = example[self.dataset_attr.chosen]
            rejected = example[self.dataset_attr.rejected]
            if (
                chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
                or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
            ):
                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
                broken_data = True

            prompt = aligned_messages
            response = [
                {
                    "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
                    "content": chosen[self.dataset_attr.content_tag],
                },
                {
                    "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
                    "content": rejected[self.dataset_attr.content_tag],
                },
            ]
        else:  # normal example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": system,
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


DATASET_CONVERTERS = {
    "alpaca": AlpacaDatasetConverter,
    "sharegpt": SharegptDatasetConverter,
}


def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
    r"""
    Register a new dataset converter.
    """
    if name in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} already exists.")

    DATASET_CONVERTERS[name] = dataset_converter


def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
    r"""
    Gets a dataset converter.
    """
    if name not in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} not found.")

    return DATASET_CONVERTERS[name](dataset_attr, data_args)


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _tools: "...",
        _images: [],
        _videos: [],
        _audios: [],
    """
    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
    return dataset.map(
        dataset_converter,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
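The `DATASET_CONVERTERS` registry replaces the old hard-coded `alpaca`/`sharegpt` dispatch with an extension point. A minimal sketch of how a project-local format could hook in; the `"qa"` name, the `QaDatasetConverter` class, and the `question`/`answer` columns are hypothetical, not part of this commit:

from dataclasses import dataclass
from typing import Any, Dict

from llamafactory.data.converter import DatasetConverter, register_dataset_converter
from llamafactory.data.data_utils import Role


@dataclass
class QaDatasetConverter(DatasetConverter):
    # Hypothetical converter: maps a {"question", "answer"} schema to the standard format.
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        return {
            "_prompt": [{"role": Role.USER.value, "content": example["question"]}],
            "_response": [{"role": Role.ASSISTANT.value, "content": example["answer"]}],
            "_system": "",
            "_tools": "",
            "_images": None,
            "_videos": None,
            "_audios": None,
        }


register_dataset_converter("qa", QaDatasetConverter)

After registration, a dataset whose attributes carry `formatting="qa"` would be resolved by `get_dataset_converter` inside `align_dataset`, exactly like the built-in formats.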
src/llamafactory/data/loader.py
@@ -22,10 +22,17 @@ from datasets import DatasetDict, load_dataset, load_from_disk
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
-from .aligner import align_dataset
+from .converter import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
-from .preprocess import get_preprocess_and_print_func
+from .processor import (
+    FeedbackDatasetProcessor,
+    PackedSupervisedDatasetProcessor,
+    PairwiseDatasetProcessor,
+    PretrainDatasetProcessor,
+    SupervisedDatasetProcessor,
+    UnsupervisedDatasetProcessor,
+)
 
 
 if TYPE_CHECKING:
@@ -35,6 +42,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, ModelArguments
     from .data_utils import DatasetModule
     from .parser import DatasetAttr
+    from .processor import DatasetProcessor
     from .template import Template
 
 
@@ -158,7 +166,7 @@ def _get_merged_dataset(
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
     r"""
-    Gets the merged datasets in the standard format.
+    Returns the merged datasets in the standard format.
     """
     if dataset_names is None:
         return None
@@ -173,6 +181,48 @@ def _get_merged_dataset(
     return merge_dataset(datasets, data_args, seed=training_args.seed)
 
 
+def _get_dataset_processor(
+    data_args: "DataArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    do_generate: bool = False,
+) -> "DatasetProcessor":
+    r"""
+    Returns the corresponding dataset processor.
+    """
+    if stage == "pt":
+        dataset_processor_class = PretrainDatasetProcessor
+    elif stage == "sft" and not do_generate:
+        if data_args.packing:
+            if data_args.neat_packing:  # hack datasets to have int32 attention mask
+                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
+
+                def __init__(self, data, **kwargs):
+                    return TypedSequence.__init__(
+                        self,
+                        data,
+                        type=kwargs.pop("type", None),
+                        try_type=kwargs.pop("try_type", None),
+                        optimized_int_type=kwargs.pop("optimized_int_type", None),
+                    )
+
+                OptimizedTypedSequence.__init__ = __init__
+
+            dataset_processor_class = PackedSupervisedDatasetProcessor
+        else:
+            dataset_processor_class = SupervisedDatasetProcessor
+
+    elif stage == "rm":
+        dataset_processor_class = PairwiseDatasetProcessor
+    elif stage == "kto":
+        dataset_processor_class = FeedbackDatasetProcessor
+    else:
+        dataset_processor_class = UnsupervisedDatasetProcessor
+
+    return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
+
+
 def _get_preprocessed_dataset(
     dataset: Optional[Union["Dataset", "IterableDataset"]],
     data_args: "DataArguments",
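A note on the `neat_packing` branch above: with neat packing the attention mask stores packed-segment indices rather than 0/1 flags, and the `datasets` library's `OptimizedTypedSequence` would otherwise downcast the `attention_mask` column to a narrow integer type. The monkey-patch strips the column-specific defaults so the int32 mask the inline comment mentions survives serialization. The toy values below are my own illustration, not from the commit:

# Two sequences of lengths 3 and 2 packed into one row: each position keeps
# the index of the segment it belongs to, so values may exceed 1.
packed_input_ids = [101, 102, 103, 201, 202]
packed_attention_mask = [1, 1, 1, 2, 2]  # segment indices, not a binary mask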
@@ -189,7 +239,7 @@ def _get_preprocessed_dataset(
     if dataset is None:
         return None
 
-    preprocess_func, print_function = get_preprocess_and_print_func(
+    dataset_processor = _get_dataset_processor(
         data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
     )
     column_names = list(next(iter(dataset)).keys())
@@ -202,7 +252,7 @@ def _get_preprocessed_dataset(
     )
 
     dataset = dataset.map(
-        preprocess_func,
+        dataset_processor.preprocess_dataset,
         batched=True,
         batch_size=data_args.preprocessing_batch_size,
         remove_columns=column_names,
@@ -212,7 +262,7 @@ def _get_preprocessed_dataset(
     if training_args.should_log:
         try:
             print("eval example:" if is_eval else "training example:")
-            print_function(next(iter(dataset)))
+            dataset_processor.print_data_example(next(iter(dataset)))
         except StopIteration:
             if stage == "pt":
                 raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
src/llamafactory/data/mm_plugin.py
@@ -4,7 +4,7 @@ import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
 
 import numpy as np
 import torch
@@ -1241,14 +1241,26 @@ PLUGINS = {
 }
 
 
+def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
+    r"""
+    Registers a multimodal plugin.
+    """
+    if name in PLUGINS:
+        raise ValueError(f"Multimodal plugin {name} already exists.")
+
+    PLUGINS[name] = plugin_class
+
+
 def get_mm_plugin(
     name: str,
     image_token: Optional[str] = None,
     video_token: Optional[str] = None,
     audio_token: Optional[str] = None,
 ) -> "BasePlugin":
-    plugin_class = PLUGINS.get(name, None)
-    if plugin_class is None:
+    r"""
+    Gets plugin for multimodal inputs.
+    """
+    if name not in PLUGINS:
         raise ValueError(f"Multimodal plugin `{name}` not found.")
 
-    return plugin_class(image_token, video_token, audio_token)
+    return PLUGINS[name](image_token, video_token, audio_token)
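`register_mm_plugin` gives multimodal plugins the same registry treatment as dataset converters. A sketch of the intended usage; the `"my_plugin"` name and the pass-through subclass are made up for illustration:

from llamafactory.data.mm_plugin import BasePlugin, get_mm_plugin, register_mm_plugin


class MyPlugin(BasePlugin):
    # Hypothetical plugin that inherits BasePlugin's default behavior unchanged.
    pass


register_mm_plugin("my_plugin", MyPlugin)
plugin = get_mm_plugin("my_plugin", image_token="<image>")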
src/llamafactory/data/parser.py
@@ -45,7 +45,7 @@ class DatasetAttr:
     images: Optional[str] = None
     videos: Optional[str] = None
     audios: Optional[str] = None
-    # rlhf columns
+    # dpo columns
     chosen: Optional[str] = None
     rejected: Optional[str] = None
     kto_tag: Optional[str] = None
@@ -71,6 +71,26 @@ class DatasetAttr:
     def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
         setattr(self, key, obj.get(key, default))
 
+    def join(self, attr: Dict[str, Any]) -> None:
+        self.set_attr("formatting", attr, default="alpaca")
+        self.set_attr("ranking", attr, default=False)
+        self.set_attr("subset", attr)
+        self.set_attr("split", attr, default="train")
+        self.set_attr("folder", attr)
+        self.set_attr("num_samples", attr)
+
+        if "columns" in attr:
+            column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
+            column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
+            for column_name in column_names:
+                self.set_attr(column_name, attr["columns"])
+
+        if "tags" in attr:
+            tag_names = ["role_tag", "content_tag"]
+            tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
+            for tag in tag_names:
+                self.set_attr(tag, attr["tags"])
+
 
 def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
     r"""
@@ -128,36 +148,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
         else:
             dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
 
-        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
-        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
-        dataset_attr.set_attr("subset", dataset_info[name])
-        dataset_attr.set_attr("split", dataset_info[name], default="train")
-        dataset_attr.set_attr("folder", dataset_info[name])
-        dataset_attr.set_attr("num_samples", dataset_info[name])
-
-        if "columns" in dataset_info[name]:
-            column_names = ["system", "tools", "images", "videos", "audios", "chosen", "rejected", "kto_tag"]
-            if dataset_attr.formatting == "alpaca":
-                column_names.extend(["prompt", "query", "response", "history"])
-            else:
-                column_names.extend(["messages"])
-
-            for column_name in column_names:
-                dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
-
-        if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
-            tag_names = (
-                "role_tag",
-                "content_tag",
-                "user_tag",
-                "assistant_tag",
-                "observation_tag",
-                "function_tag",
-                "system_tag",
-            )
-            for tag in tag_names:
-                dataset_attr.set_attr(tag, dataset_info[name]["tags"])
+        dataset_attr.join(dataset_info[name])
 
         dataset_list.append(dataset_attr)
 
     return dataset_list
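`DatasetAttr.join` absorbs in one call the per-key unpacking `get_dataset_list` used to do inline; note it also drops the old formatting-specific filtering and simply tries every known column and tag name. A sketch of a `dataset_info.json` entry it would consume, written as a Python dict; the file name and tag values are illustrative, not taken from this commit:

# Illustrative entry for a local sharegpt-style file:
dataset_info_entry = {
    "file_name": "my_sharegpt_data.json",
    "formatting": "sharegpt",
    "columns": {"messages": "conversations", "system": "system"},
    "tags": {
        "role_tag": "from",
        "content_tag": "value",
        "user_tag": "human",
        "assistant_tag": "gpt",
    },
}

# DatasetAttr("file", dataset_name="my_sharegpt_data.json").join(dataset_info_entry)
# now sets formatting, ranking, columns, and tags in a single call.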
src/llamafactory/data/preprocess.py (deleted, 111 lines)
@@ -1,111 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple

from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
    preprocess_packed_supervised_dataset,
    preprocess_supervised_dataset,
    print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ..hparams import DataArguments
    from .template import Template


def get_preprocess_and_print_func(
    data_args: "DataArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    do_generate: bool = False,
) -> Tuple[Callable, Callable]:
    if stage == "pt":
        preprocess_func = partial(
            preprocess_pretrain_dataset,
            tokenizer=tokenizer,
            data_args=data_args,
        )
        print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not do_generate:
        if data_args.packing:
            if data_args.neat_packing:  # hack datasets to have int32 attention mask
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.pop("type", None),
                        try_type=kwargs.pop("try_type", None),
                        optimized_int_type=kwargs.pop("optimized_int_type", None),
                    )

                OptimizedTypedSequence.__init__ = __init__

            preprocess_func = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )
        else:
            preprocess_func = partial(
                preprocess_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )

        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    elif stage == "rm":
        preprocess_func = partial(
            preprocess_pairwise_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    elif stage == "kto":
        preprocess_func = partial(
            preprocess_feedback_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    else:
        preprocess_func = partial(
            preprocess_unsupervised_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

    return preprocess_func, print_function
src/llamafactory/data/processor/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
from .processor_utils import DatasetProcessor
from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
from .unsupervised import UnsupervisedDatasetProcessor


__all__ = [
    "DatasetProcessor",
    "FeedbackDatasetProcessor",
    "PairwiseDatasetProcessor",
    "PretrainDatasetProcessor",
    "PackedSupervisedDatasetProcessor",
    "SupervisedDatasetProcessor",
    "UnsupervisedDatasetProcessor",
]
src/llamafactory/data/processor/feedback.py (new file, 129 lines)
@@ -0,0 +1,129 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class FeedbackDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        kl_response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
        if response[0]["content"]:  # desired example
            kto_tag = True
            messages = prompt + [response[0]]
        else:  # undesired example
            kto_tag = False
            messages = prompt + [response[1]]

        if kl_response[0]["content"]:
            kl_messages = prompt + [kl_response[0]]
        else:
            kl_messages = prompt + [kl_response[1]]

        messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
        kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor)
        prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
        kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools)

        if self.template.efficient_eos:
            response_ids += [self.tokenizer.eos_token_id]
            kl_response_ids += [self.tokenizer.eos_token_id]

        prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )

        source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len)
        prompt_ids = prompt_ids[:source_len]
        response_ids = response_ids[:target_len]
        kl_source_len, kl_target_len = infer_seqlen(
            len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len
        )
        kl_prompt_ids = kl_prompt_ids[:kl_source_len]
        kl_response_ids = kl_response_ids[:kl_target_len]

        input_ids = prompt_ids + response_ids
        labels = [IGNORE_INDEX] * source_len + response_ids
        kl_input_ids = kl_prompt_ids + kl_response_ids
        kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
        return input_ids, labels, kl_input_ids, kl_labels, kto_tag

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
        kl_response = examples["_response"][::-1]
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                kl_response=kl_response[i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["kl_input_ids"].append(kl_input_ids)
            model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
            model_inputs["kl_labels"].append(kl_labels)
            model_inputs["kto_tags"].append(kto_tag)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
        undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
        if desirable_num == 0 or undesirable_num == 0:
            logger.warning_rank0("Your dataset only has one preference type.")

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
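The `[::-1]` flip in `preprocess_dataset` is the whole KL-pair construction: each prompt is re-encoded against the response of the mirror-image example in the same batch, producing the mismatched completions KTO uses to estimate its KL reference term. A toy illustration (contents made up):

# Within one preprocessing batch:
responses = ["resp_a", "resp_b", "resp_c"]  # stand-ins for examples["_response"]
kl_responses = responses[::-1]              # ["resp_c", "resp_b", "resp_a"]
# Example 0 is encoded with resp_a (matched) and resp_c (mismatched KL pair);
# only the middle example of an odd-sized batch is paired with itself.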
src/llamafactory/data/processor/pairwise.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class PairwiseDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int], List[int], List[int]]:
        chosen_messages = self.template.mm_plugin.process_messages(
            prompt + [response[0]], images, videos, audios, self.processor
        )
        rejected_messages = self.template.mm_plugin.process_messages(
            prompt + [response[1]], images, videos, audios, self.processor
        )
        prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools)
        _, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools)

        if self.template.efficient_eos:
            chosen_ids += [self.tokenizer.eos_token_id]
            rejected_ids += [self.tokenizer.eos_token_id]

        prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        # consider the response is more important
        source_len, target_len = infer_seqlen(
            len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len
        )
        prompt_ids = prompt_ids[:source_len]
        chosen_ids = chosen_ids[:target_len]
        rejected_ids = rejected_ids[:target_len]

        chosen_input_ids = prompt_ids + chosen_ids
        chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
        rejected_input_ids = prompt_ids + rejected_ids
        rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
        return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["chosen_input_ids"].append(chosen_input_ids)
            model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
            model_inputs["chosen_labels"].append(chosen_labels)
            model_inputs["rejected_input_ids"].append(rejected_input_ids)
            model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
            model_inputs["rejected_labels"].append(rejected_labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
        valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
        print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
        print(
            "chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))
        )
        print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
        print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
        print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
        print(
            "rejected_inputs:\n{}".format(
                self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)
            )
        )
        print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
        print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
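`infer_seqlen` is imported from `processor_utils` but its body is outside this diff. Passing `max(len(chosen_ids), len(rejected_ids))` as the target length gives both candidate responses one shared truncation budget, matching the "response is more important" comment. A proportional-split sketch consistent with that call site; the exact thresholds are an assumption, not the repository's implementation:

from typing import Tuple


def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    # Sketch: divide cutoff_len between prompt and response, keeping whichever
    # side already fits in half the budget whole (assumed behavior).
    if target_len * 2 < cutoff_len:  # response is short: trim the prompt only
        max_target_len = cutoff_len
    elif source_len * 2 < cutoff_len:  # prompt is short: trim the response only
        max_target_len = cutoff_len - source_len
    else:  # both are long: trim proportionally
        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))

    new_target_len = min(max_target_len, target_len)
    new_source_len = min(max(cutoff_len - new_target_len, 0), source_len)
    return new_source_len, new_target_len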
src/llamafactory/data/processor/pretrain.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List

from .processor_utils import DatasetProcessor


@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
        eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
        text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

        if not self.data_args.packing:
            if getattr(self.tokenizer, "add_bos_token", False):
                text_examples = [self.tokenizer.bos_token + example for example in text_examples]

            result = self.tokenizer(
                text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
            )
        else:
            tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False)
            concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
            total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
            block_size = self.data_args.cutoff_len
            total_length = (total_length // block_size) * block_size
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            if getattr(self.tokenizer, "add_bos_token", False):
                for i in range(len(result["input_ids"])):
                    result["input_ids"][i][0] = self.tokenizer.bos_token_id

        return result

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
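The packing branch in `PretrainDatasetProcessor` concatenates every tokenized example, drops the tail shorter than one block, and slices the rest into `cutoff_len`-sized rows. A worked toy example (token ids made up, `cutoff_len=4`):

block_size = 4                              # stands in for data_args.cutoff_len
concatenated = [1, 2, 3, 4, 5, 6, 7, 8, 9]  # all examples joined together
total_length = (len(concatenated) // block_size) * block_size  # 8: the trailing 9 is dropped
blocks = [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]
assert blocks == [[1, 2, 3, 4], [5, 6, 7, 8]]
# With add_bos_token, the first id of each block is then overwritten with bos_token_id.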
@@ -13,7 +13,42 @@
 # limitations under the License.

 import bisect
-from typing import List, Sequence, Tuple
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, ProcessorMixin
+
+    from ...hparams import DataArguments
+    from ..template import Template
+
+
+@dataclass
+class DatasetProcessor(ABC):
+    r"""
+    A class for data processors.
+    """
+
+    template: "Template"
+    tokenizer: "PreTrainedTokenizer"
+    processor: Optional["ProcessorMixin"]
+    data_args: "DataArguments"
+
+    @abstractmethod
+    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        r"""
+        Builds model inputs from the examples.
+        """
+        ...
+
+    @abstractmethod
+    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+        r"""
+        Print a data example to stdout.
+        """
+        ...


 def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
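The new abstract base above replaces the module-level `preprocess_*`/`print_*` function pairs with one object per task. A hypothetical minimal subclass, only to show the contract (not part of this commit):

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class IdentityDatasetProcessor(DatasetProcessor):  # DatasetProcessor as defined in this hunk
    """Illustrative no-op processor: returns the examples unchanged."""

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        return examples

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print(example)
```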
200
src/llamafactory/data/processor/supervised.py
Normal file
@@ -0,0 +1,200 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int]]:
        messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
        input_ids, labels = self.template.mm_plugin.process_token_ids(
            [], [], images, videos, audios, self.tokenizer, self.processor
        )
        encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
        total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
        if self.data_args.mask_history:
            encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

        for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
            if total_length >= self.data_args.cutoff_len:
                break

            source_len, target_len = infer_seqlen(
                len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
            )
            source_ids = source_ids[:source_len]
            target_ids = target_ids[:target_len]
            total_length += source_len + target_len

            if self.data_args.train_on_prompt:
                source_label = source_ids
            elif self.template.efficient_eos:
                source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
            else:
                source_label = [IGNORE_INDEX] * source_len

            if self.data_args.mask_history and turn_idx != 0:  # train on the last turn only
                target_label = [IGNORE_INDEX] * target_len
            else:
                target_label = target_ids

            if self.data_args.mask_history:  # reversed sequences
                input_ids = source_ids + target_ids + input_ids
                labels = source_label + target_label + labels
            else:
                input_ids += source_ids + target_ids
                labels += source_label + target_label

        if self.template.efficient_eos:
            input_ids += [self.tokenizer.eos_token_id]
            labels += [self.tokenizer.eos_token_id]

        return input_ids, labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")


@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # TODO: use `position_ids` to achieve packing
        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
        valid_num = 0
        batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
        lengths = []
        length2indexes = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            length = len(input_ids)
            if length > self.data_args.cutoff_len:
                logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
            else:
                lengths.append(length)
                length2indexes[length].append(valid_num)
                batch_input_ids.append(input_ids)
                batch_labels.append(labels)
                batch_images.append(examples["_images"][i] or [])
                batch_videos.append(examples["_videos"][i] or [])
                batch_audios.append(examples["_audios"][i] or [])
                valid_num += 1

        model_inputs = defaultdict(list)
        knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
        for knapsack in knapsacks:
            packed_input_ids, packed_attention_masks, packed_labels = [], [], []
            packed_images, packed_videos, packed_audios = [], [], []
            for i, length in enumerate(knapsack):
                index = length2indexes[length].pop()
                packed_input_ids += batch_input_ids[index]
                packed_labels += batch_labels[index]
                packed_images += batch_images[index]
                packed_videos += batch_videos[index]
                packed_audios += batch_audios[index]
                if self.data_args.neat_packing:
                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
                else:
                    packed_attention_masks += [1] * len(batch_input_ids[index])

            if len(packed_input_ids) < self.data_args.cutoff_len + 1:  # avoid flash_attn drops attn mask
                pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
                packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
                packed_labels += [IGNORE_INDEX] * pad_length
                if self.data_args.neat_packing:
                    packed_attention_masks += [0] * pad_length
                else:
                    packed_attention_masks += [1] * pad_length  # more efficient flash_attn

            if len(packed_input_ids) != self.data_args.cutoff_len + 1:
                raise ValueError("The length of packed example should be identical to the cutoff length.")

            model_inputs["input_ids"].append(packed_input_ids)
            model_inputs["attention_mask"].append(packed_attention_masks)
            model_inputs["labels"].append(packed_labels)
            model_inputs["images"].append(packed_images or None)
            model_inputs["videos"].append(packed_videos or None)
            model_inputs["audios"].append(packed_audios or None)

        return model_inputs
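In the packed processor above, `neat_packing` tags each packed sequence with its own attention-mask id so a suitable collator can rebuild block-diagonal attention, while plain packing marks every slot with 1. A small sketch of the resulting mask layout, assuming three sequences packed into a row of `cutoff_len + 1 = 9` slots:

```python
lengths = [3, 2, 2]  # three packed sequences

neat_mask = []
for i, n in enumerate(lengths):
    neat_mask += [i + 1] * n  # per-sequence ids, starting from 1
neat_mask += [0] * (9 - len(neat_mask))  # padding is fully masked out
assert neat_mask == [1, 1, 1, 2, 2, 3, 3, 0, 0]

plain_mask = [1] * 7 + [1] * 2  # without neat_packing, even the pad keeps 1 for flash_attn
assert plain_mask == [1] * 9
```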
91
src/llamafactory/data/processor/unsupervised.py
Normal file
@@ -0,0 +1,91 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ..data_utils import Role
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class UnsupervisedDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int]]:
        if len(response) == 1:
            messages = prompt + response
        else:
            messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

        messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
        input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
        if self.template.efficient_eos:
            labels += [self.tokenizer.eos_token_id]

        input_ids, _ = self.template.mm_plugin.process_token_ids(
            input_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len)
        input_ids = input_ids[:source_len]
        labels = labels[:target_len]
        return input_ids, labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build inputs with format `<bos> X` and labels with format `Y <eos>`
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False)))
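For inference-style data the processor keeps the prompt as `input_ids` and the expected answer as `labels`, splitting the length budget with `infer_seqlen`. A simplified stand-in for that helper (the real one lives in `processor_utils.py`; this proportional split is an assumption for illustration):

```python
from typing import Tuple


def infer_seqlen_sketch(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    """Hypothetical proportional split of cutoff_len between prompt and answer."""
    if source_len + target_len <= cutoff_len:
        return source_len, target_len

    new_source_len = max(1, cutoff_len * source_len // (source_len + target_len))
    return new_source_len, cutoff_len - new_source_len


# a 60-token prompt and a 40-token answer squeezed into 50 slots
assert infer_seqlen_sketch(60, 40, 50) == (30, 20)  # both sides shrink proportionally
```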
@@ -1,137 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_feedback_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    kl_response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
    if response[0]["content"]:  # desired example
        kto_tag = True
        messages = prompt + [response[0]]
    else:  # undesired example
        kto_tag = False
        messages = prompt + [response[1]]

    if kl_response[0]["content"]:
        kl_messages = prompt + [kl_response[0]]
    else:
        kl_messages = prompt + [kl_response[1]]

    messages = template.mm_plugin.process_messages(messages, images, videos, audios, processor)
    kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, audios, processor)
    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
    kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)

    if template.efficient_eos:
        response_ids += [tokenizer.eos_token_id]
        kl_response_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(
        prompt_ids, None, images, videos, audios, tokenizer, processor
    )
    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(
        kl_prompt_ids, None, images, videos, audios, tokenizer, processor
    )

    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    response_ids = response_ids[:target_len]
    kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
    kl_prompt_ids = kl_prompt_ids[:kl_source_len]
    kl_response_ids = kl_response_ids[:kl_target_len]

    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * source_len + response_ids
    kl_input_ids = kl_prompt_ids + kl_response_ids
    kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
    return input_ids, labels, kl_input_ids, kl_labels, kto_tag


def preprocess_feedback_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
    kl_response = examples["_response"][::-1]
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            kl_response=kl_response[i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["kl_input_ids"].append(kl_input_ids)
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
        model_inputs["kl_labels"].append(kl_labels)
        model_inputs["kto_tags"].append(kto_tag)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])
        model_inputs["audios"].append(examples["_audios"][i])

    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
    if desirable_num == 0 or undesirable_num == 0:
        logger.warning_rank0("Your dataset only has one preference type.")

    return model_inputs
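The deleted feedback (KTO) pipeline built its KL reference pairs by reversing the batch's responses, so each prompt is matched with an unrelated response. The trick in isolation:

```python
prompts = ["p0", "p1", "p2", "p3"]
responses = ["r0", "r1", "r2", "r3"]

kl_responses = responses[::-1]  # flip the matched pairs, as in preprocess_feedback_dataset
assert list(zip(prompts, kl_responses)) == [
    ("p0", "r3"), ("p1", "r2"), ("p2", "r1"), ("p3", "r0")
]
```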
@@ -1,124 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_pairwise_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, audios, processor)
    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, audios, processor)
    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)

    if template.efficient_eos:
        chosen_ids += [tokenizer.eos_token_id]
        rejected_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(
        prompt_ids, None, images, videos, audios, tokenizer, processor
    )
    # consider the response is more important
    source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    chosen_ids = chosen_ids[:target_len]
    rejected_ids = rejected_ids[:target_len]

    chosen_input_ids = prompt_ids + chosen_ids
    chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
    rejected_input_ids = prompt_ids + rejected_ids
    rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels


def preprocess_pairwise_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["chosen_input_ids"].append(chosen_input_ids)
        model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
        model_inputs["chosen_labels"].append(chosen_labels)
        model_inputs["rejected_input_ids"].append(rejected_input_ids)
        model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
        model_inputs["rejected_labels"].append(rejected_labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])
        model_inputs["audios"].append(examples["_audios"][i])

    return model_inputs


def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
    valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
    print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
    print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
    print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
    print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
    print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
    print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
    print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
    print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
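Note how the deleted pairwise encoder sized the shared length budget by the longer of the two candidates ("the response is more important"), so chosen and rejected are truncated to the same `target_len` over the same prompt. A numeric sketch, again using a hypothetical proportional split:

```python
prompt_len, chosen_len, rejected_len, cutoff_len = 10, 30, 20, 32

response_len = max(chosen_len, rejected_len)  # budget follows the longer candidate
source_len = max(1, cutoff_len * prompt_len // (prompt_len + response_len))  # -> 8
target_len = cutoff_len - source_len  # -> 24

# both candidates get the same cap, even though rejected is already shorter
assert (source_len, target_len) == (8, 24)
```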
@@ -1,59 +0,0 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from ...hparams import DataArguments


def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

    if not data_args.packing:
        if getattr(tokenizer, "add_bos_token", False):
            text_examples = [tokenizer.bos_token + example for example in text_examples]

        result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
    else:
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
        block_size = data_args.cutoff_len
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        if getattr(tokenizer, "add_bos_token", False):
            for i in range(len(result["input_ids"])):
                result["input_ids"][i][0] = tokenizer.bos_token_id

    return result


def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
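Together with the new `pretrain.py` above, this deletion shows the call-site shape changing from a free function to a bound processor object. A hedged before/after sketch (the `template`, `tokenizer`, `data_args`, and `examples` objects are assumed to exist in the caller):

```python
# before: collaborators threaded through every call
# result = preprocess_pretrain_dataset(examples, tokenizer, data_args)

# after: collaborators bound once on the dataclass, then reused
pretrain_processor = PretrainDatasetProcessor(
    template=template, tokenizer=tokenizer, processor=None, data_args=data_args
)
result = pretrain_processor.preprocess_dataset(examples)
pretrain_processor.print_data_example({"input_ids": result["input_ids"][0]})
```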
@@ -1,226 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_supervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
    train_on_prompt: bool,
    mask_history: bool,
) -> Tuple[List[int], List[int]]:
    messages = template.mm_plugin.process_messages(prompt + response, images, videos, audios, processor)
    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, audios, tokenizer, processor)
    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
    if mask_history:
        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        if total_length >= cutoff_len:
            break

        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
        source_ids = source_ids[:source_len]
        target_ids = target_ids[:target_len]
        total_length += source_len + target_len

        if train_on_prompt:
            source_label = source_ids
        elif template.efficient_eos:
            source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
        else:
            source_label = [IGNORE_INDEX] * source_len

        if mask_history and turn_idx != 0:  # train on the last turn only
            target_label = [IGNORE_INDEX] * target_len
        else:
            target_label = target_ids

        if mask_history:  # reversed sequences
            input_ids = source_ids + target_ids + input_ids
            labels = source_label + target_label + labels
        else:
            input_ids += source_ids + target_ids
            labels += source_label + target_label

    if template.efficient_eos:
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]

    return input_ids, labels


def preprocess_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_supervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])
        model_inputs["audios"].append(examples["_audios"][i])

    return model_inputs


def preprocess_packed_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # TODO: use `position_ids` to achieve packing
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
    valid_num = 0
    batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
    lengths = []
    length2indexes = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_supervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        length = len(input_ids)
        if length > data_args.cutoff_len:
            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
        else:
            lengths.append(length)
            length2indexes[length].append(valid_num)
            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            batch_images.append(examples["_images"][i] or [])
            batch_videos.append(examples["_videos"][i] or [])
            batch_audios.append(examples["_audios"][i] or [])
            valid_num += 1

    model_inputs = defaultdict(list)
    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
    for knapsack in knapsacks:
        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
        packed_images, packed_videos, packed_audios = [], [], []
        for i, length in enumerate(knapsack):
            index = length2indexes[length].pop()
            packed_input_ids += batch_input_ids[index]
            packed_labels += batch_labels[index]
            packed_images += batch_images[index]
            packed_videos += batch_videos[index]
            packed_audios += batch_audios[index]
            if data_args.neat_packing:
                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
            else:
                packed_attention_masks += [1] * len(batch_input_ids[index])

        if len(packed_input_ids) < data_args.cutoff_len:
            pad_length = data_args.cutoff_len - len(packed_input_ids)
            packed_input_ids += [tokenizer.pad_token_id] * pad_length
            packed_labels += [IGNORE_INDEX] * pad_length
            if data_args.neat_packing:
                packed_attention_masks += [0] * pad_length
            else:
                packed_attention_masks += [1] * pad_length  # more efficient flash_attn

        if len(packed_input_ids) != data_args.cutoff_len:
            raise ValueError("The length of packed example should be identical to the cutoff length.")

        model_inputs["input_ids"].append(packed_input_ids)
        model_inputs["attention_mask"].append(packed_attention_masks)
        model_inputs["labels"].append(packed_labels)
        model_inputs["images"].append(packed_images or None)
        model_inputs["videos"].append(packed_videos or None)
        model_inputs["audios"].append(packed_audios or None)

    return model_inputs


def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
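Comparing this deleted packed path with the new `PackedSupervisedDatasetProcessor` above: the old code reserved one slot inside `cutoff_len` (encoding and knapsacking at `cutoff_len - 1`, padding each row up to exactly `cutoff_len`), whereas the new code packs at the full `cutoff_len` and pads each row to `cutoff_len + 1` so flash_attn does not drop the attention mask. Numerically:

```python
cutoff_len = 8

old_capacity = cutoff_len - 1  # reserved for the padding token
old_row_len = cutoff_len       # old check: len(packed_input_ids) == cutoff_len

new_capacity = cutoff_len      # full budget for real tokens
new_row_len = cutoff_len + 1   # new check: len(packed_input_ids) == cutoff_len + 1

assert (old_capacity, old_row_len) == (7, 8)
assert (new_capacity, new_row_len) == (8, 9)
```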
@@ -1,107 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ..data_utils import Role
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_unsupervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int]]:
    if len(response) == 1:
        messages = prompt + response
    else:
        messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

    messages = template.mm_plugin.process_messages(messages, images, videos, audios, processor)
    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
    if template.efficient_eos:
        labels += [tokenizer.eos_token_id]

    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, audios, tokenizer, processor)
    source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
    input_ids = input_ids[:source_len]
    labels = labels[:target_len]
    return input_ids, labels


def preprocess_unsupervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build inputs with format `<bos> X` and labels with format `Y <eos>`
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_unsupervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])
        model_inputs["audios"].append(examples["_audios"][i])

    return model_inputs


def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
@@ -405,7 +405,7 @@ class Llama2Template(Template):
 TEMPLATES: Dict[str, "Template"] = {}


-def _register_template(
+def register_template(
     name: str,
     format_user: Optional["Formatter"] = None,
     format_assistant: Optional["Formatter"] = None,
@@ -421,7 +421,7 @@ def _register_template(
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
-    template_class: Type[Template] = Template,
+    template_class: Type["Template"] = Template,
 ) -> None:
     r"""
     Registers a chat template.
@@ -436,7 +436,7 @@ def _register_template(

     The corresponding code should be:
     ```
-    _register_template(
+    register_template(
         name="custom",
         format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
         format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
@@ -444,6 +444,9 @@ def _register_template(
     )
     ```
     """
+    if name in TEMPLATES:
+        raise ValueError(f"Template {name} already exists.")
+
     default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=default_slots)
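With the rename to a public `register_template` plus the new duplicate-name guard, custom templates can be registered directly and colliding names fail fast. A sketch following the docstring's own example (the import path is an assumption):

```python
from llamafactory.data.template import StringFormatter, register_template  # path assumed

register_template(
    name="custom",
    format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
    format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
)

# registering the same name twice now raises instead of silently overwriting
try:
    register_template(name="custom")
except ValueError as err:
    print(err)  # Template custom already exists.
```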
@@ -562,7 +565,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     return template


-_register_template(
+register_template(
     name="alpaca",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
@@ -573,7 +576,7 @@ _register_template(
 )


-_register_template(
+register_template(
     name="aquila",
     format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}###"]),
@@ -586,7 +589,7 @@ _register_template(
 )


-_register_template(
+register_template(
     name="atom",
     format_user=StringFormatter(
         slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
@@ -595,21 +598,21 @@ _register_template(
 )


-_register_template(
+register_template(
     name="baichuan",
     format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
     efficient_eos=True,
 )


-_register_template(
+register_template(
     name="baichuan2",
     format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
     efficient_eos=True,
 )


-_register_template(
+register_template(
     name="belle",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
@@ -617,13 +620,13 @@ _register_template(
 )


-_register_template(
+register_template(
     name="bluelm",
     format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
 )


-_register_template(
+register_template(
     name="breeze",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
@@ -631,7 +634,7 @@ _register_template(
 )


-_register_template(
+register_template(
     name="chatglm2",
     format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
@@ -639,7 +642,7 @@ _register_template(
 )


-_register_template(
+register_template(
     name="chatglm3",
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
@@ -655,7 +658,7 @@ _register_template(
 )


-_register_template(
+register_template(
     name="chatml",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@@ -668,7 +671,7 @@ _register_template(


 # copied from chatml template
-_register_template(
+register_template(
     name="chatml_de",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@@ -681,13 +684,13 @@ _register_template(
 )


-_register_template(
+register_template(
     name="codegeex2",
     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
 )


-_register_template(
+register_template(
     name="codegeex4",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
@@ -704,7 +707,7 @@ _register_template(
 )


-_register_template(
+register_template(
     name="cohere",
     format_user=StringFormatter(
         slots=[
@@ -719,7 +722,7 @@ _register_template(
 )


-_register_template(
+register_template(
     name="cpm",
     format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
@@ -727,7 +730,7 @@ _register_template(


 # copied from chatml template
-_register_template(
+register_template(
     name="cpm3",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@@ -738,7 +741,7 @@ _register_template(


 # copied from chatml template
-_register_template(
+register_template(
     name="dbrx",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@@ -763,7 +766,7 @@ _register_template(
 )


-_register_template(
+register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
@@ -771,14 +774,14 @@ _register_template(
 )


-_register_template(
+register_template(
     name="deepseek3",
     format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )


-_register_template(
+register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
     format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
@@ -792,7 +795,7 @@ _register_template(
 )


-_register_template(
+register_template(
     name="default",
     format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
@@ -800,13 +803,13 @@ _register_template(
 )


-_register_template(
+register_template(
     name="empty",
     format_assistant=StringFormatter(slots=["{{content}}"]),
 )


-_register_template(
+register_template(
     name="exaone",
     format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
|
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
|
||||||
@ -814,7 +817,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="falcon",
|
name="falcon",
|
||||||
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||||
@ -822,14 +825,14 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="fewshot",
|
name="fewshot",
|
||||||
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="gemma",
|
name="gemma",
|
||||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
|
||||||
@ -840,7 +843,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="glm4",
|
name="glm4",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||||
@ -854,7 +857,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="granite3",
|
name="granite3",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
@ -866,7 +869,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="index",
|
name="index",
|
||||||
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
|
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
|
||||||
format_system=StringFormatter(slots=["<unk>{{content}}"]),
|
format_system=StringFormatter(slots=["<unk>{{content}}"]),
|
||||||
@ -874,7 +877,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="intern",
|
name="intern",
|
||||||
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
|
||||||
@ -891,7 +894,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="intern2",
|
name="intern2",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -908,7 +911,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="llama2",
|
name="llama2",
|
||||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||||
@ -917,7 +920,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from llama2 template
|
# copied from llama2 template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llama2_zh",
|
name="llama2_zh",
|
||||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||||
@ -926,7 +929,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="llama3",
|
name="llama3",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
@ -954,7 +957,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from llama3 template
|
# copied from llama3 template
|
||||||
_register_template(
|
register_template(
|
||||||
name="mllama",
|
name="mllama",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
@ -983,7 +986,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from vicuna template
|
# copied from vicuna template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llava",
|
name="llava",
|
||||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
@ -995,7 +998,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from vicuna template
|
# copied from vicuna template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llava_next",
|
name="llava_next",
|
||||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
@ -1007,7 +1010,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from llama3 template
|
# copied from llama3 template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llava_next_llama3",
|
name="llava_next_llama3",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
@ -1036,7 +1039,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from mistral template
|
# copied from mistral template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llava_next_mistral",
|
name="llava_next_mistral",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||||
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
||||||
@ -1051,7 +1054,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from qwen template
|
# copied from qwen template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llava_next_qwen",
|
name="llava_next_qwen",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1068,7 +1071,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llava_next_yi",
|
name="llava_next_yi",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1079,7 +1082,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from vicuna template
|
# copied from vicuna template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llava_next_video",
|
name="llava_next_video",
|
||||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
@ -1091,7 +1094,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from mistral template
|
# copied from mistral template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llava_next_video_mistral",
|
name="llava_next_video_mistral",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||||
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
||||||
@ -1106,7 +1109,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
register_template(
|
||||||
name="llava_next_video_yi",
|
name="llava_next_video_yi",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1117,7 +1120,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
register_template(
|
||||||
name="marco",
|
name="marco",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1133,7 +1136,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
register_template(
|
||||||
name="minicpm_v",
|
name="minicpm_v",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1144,7 +1147,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from minicpm_v template
|
# copied from minicpm_v template
|
||||||
_register_template(
|
register_template(
|
||||||
name="minicpm_o",
|
name="minicpm_o",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1155,7 +1158,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# mistral tokenizer v3 tekken
|
# mistral tokenizer v3 tekken
|
||||||
_register_template(
|
register_template(
|
||||||
name="ministral",
|
name="ministral",
|
||||||
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
|
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
|
||||||
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
||||||
@ -1168,7 +1171,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# mistral tokenizer v3
|
# mistral tokenizer v3
|
||||||
_register_template(
|
register_template(
|
||||||
name="mistral",
|
name="mistral",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||||
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
||||||
@ -1182,7 +1185,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# mistral tokenizer v7 tekken (copied from ministral)
|
# mistral tokenizer v7 tekken (copied from ministral)
|
||||||
_register_template(
|
register_template(
|
||||||
name="mistral_small",
|
name="mistral_small",
|
||||||
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
|
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
|
||||||
format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
|
format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
|
||||||
@ -1193,21 +1196,21 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="olmo",
|
name="olmo",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
||||||
format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="openchat",
|
name="openchat",
|
||||||
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="openchat-3.6",
|
name="openchat-3.6",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
@ -1223,7 +1226,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
register_template(
|
||||||
name="opencoder",
|
name="opencoder",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1234,7 +1237,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="orion",
|
name="orion",
|
||||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
@ -1242,7 +1245,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from gemma template
|
# copied from gemma template
|
||||||
_register_template(
|
register_template(
|
||||||
name="paligemma",
|
name="paligemma",
|
||||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
|
||||||
@ -1254,7 +1257,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="phi",
|
name="phi",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
|
||||||
@ -1263,7 +1266,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="phi_small",
|
name="phi_small",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
|
||||||
@ -1273,7 +1276,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="phi4",
|
name="phi4",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(
|
||||||
slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
|
slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
|
||||||
@ -1285,7 +1288,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from ministral template
|
# copied from ministral template
|
||||||
_register_template(
|
register_template(
|
||||||
name="pixtral",
|
name="pixtral",
|
||||||
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
|
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
|
||||||
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
||||||
@ -1299,7 +1302,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
register_template(
|
||||||
name="qwen",
|
name="qwen",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1315,7 +1318,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
register_template(
|
||||||
name="qwen2_audio",
|
name="qwen2_audio",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1327,7 +1330,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from qwen template
|
# copied from qwen template
|
||||||
_register_template(
|
register_template(
|
||||||
name="qwen2_vl",
|
name="qwen2_vl",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1343,7 +1346,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="sailor",
|
name="sailor",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1357,7 +1360,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from llama3 template
|
# copied from llama3 template
|
||||||
_register_template(
|
register_template(
|
||||||
name="skywork_o1",
|
name="skywork_o1",
|
||||||
format_user=StringFormatter(
|
format_user=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
@ -1391,7 +1394,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="solar",
|
name="solar",
|
||||||
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
|
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
|
||||||
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
|
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
|
||||||
@ -1399,7 +1402,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="starchat",
|
name="starchat",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
|
||||||
@ -1408,14 +1411,14 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="telechat",
|
name="telechat",
|
||||||
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
||||||
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
|
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="telechat2",
|
name="telechat2",
|
||||||
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
||||||
format_system=StringFormatter(slots=["<_system>{{content}}"]),
|
format_system=StringFormatter(slots=["<_system>{{content}}"]),
|
||||||
@ -1425,7 +1428,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="vicuna",
|
name="vicuna",
|
||||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
@ -1436,7 +1439,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="video_llava",
|
name="video_llava",
|
||||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
@ -1447,7 +1450,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="xuanyuan",
|
name="xuanyuan",
|
||||||
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
||||||
default_system=(
|
default_system=(
|
||||||
@ -1458,13 +1461,13 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="xverse",
|
name="xverse",
|
||||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
|
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="yayi",
|
name="yayi",
|
||||||
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
|
||||||
@ -1485,7 +1488,7 @@ _register_template(
|
|||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
register_template(
|
||||||
name="yi",
|
name="yi",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
@ -1494,7 +1497,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="yi_vl",
|
name="yi_vl",
|
||||||
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
|
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||||
@ -1511,7 +1514,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="yuan",
|
name="yuan",
|
||||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
|
||||||
@ -1519,7 +1522,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="zephyr",
|
name="zephyr",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
|
||||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||||
@ -1527,7 +1530,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
register_template(
|
||||||
name="ziya",
|
name="ziya",
|
||||||
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
|
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||||
|
@@ -1549,7 +1549,7 @@ register_model_group(

 register_model_group(
     models={
-        "Pixtral-12B-Instruct": {
+        "Pixtral-12B": {
             DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
         }
@@ -16,7 +16,7 @@ from typing import Tuple

 import pytest

-from llamafactory.data.processors.processor_utils import infer_seqlen
+from llamafactory.data.processor.processor_utils import infer_seqlen


 @pytest.mark.parametrize(
tests/data/test_converter.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+from llamafactory.data import Role
+from llamafactory.data.converter import get_dataset_converter
+from llamafactory.data.parser import DatasetAttr
+from llamafactory.hparams import DataArguments
+
+
+def test_alpaca_converter():
+    dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
+    data_args = DataArguments()
+    example = {
+        "instruction": "Solve the math problem.",
+        "input": "3 + 4",
+        "output": "The answer is 7.",
+    }
+    dataset_converter = get_dataset_converter("alpaca", dataset_attr, data_args)
+    assert dataset_converter(example) == {
+        "_prompt": [{"role": Role.USER.value, "content": "Solve the math problem.\n3 + 4"}],
+        "_response": [{"role": Role.ASSISTANT.value, "content": "The answer is 7."}],
+        "_system": "",
+        "_tools": "",
+        "_images": None,
+        "_videos": None,
+        "_audios": None,
+    }
+
+
+def test_sharegpt_converter():
+    dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
+    data_args = DataArguments()
+    example = {
+        "conversations": [
+            {"from": "system", "value": "You are a helpful assistant."},
+            {"from": "human", "value": "Solve the math problem.\n3 + 4"},
+            {"from": "gpt", "value": "The answer is 7."},
+        ]
+    }
+    dataset_converter = get_dataset_converter("sharegpt", dataset_attr, data_args)
+    assert dataset_converter(example) == {
+        "_prompt": [{"role": Role.USER.value, "content": "Solve the math problem.\n3 + 4"}],
+        "_response": [{"role": Role.ASSISTANT.value, "content": "The answer is 7."}],
+        "_system": "You are a helpful assistant.",
+        "_tools": "",
+        "_images": None,
+        "_videos": None,
+        "_audios": None,
+    }