Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-01 11:12:50 +08:00)

commit 1679930e00 (parent d50e04b805)
[breaking change] refactor data pipeline (#6901)

* refactor data
* rename file

Former-commit-id: 617c8ab467d32be5f7d5c94fa89c0e3d7d1963bc
src/llamafactory/data/aligner.py (deleted, -241 lines)
@@ -1,241 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .parser import DatasetAttr


logger = logging.get_logger(__name__)


def _regularize_medias(
    inputs: Union[Any, Sequence[Any]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List[Any]]:
    r"""
    Optionally concatenates media path to media dir when loading from local disk.
    """
    if not isinstance(inputs, list):
        inputs = [inputs]
    elif len(inputs) == 0:
        return None
    else:
        inputs = inputs[:]

    if dataset_attr.load_from in ["script", "file"]:
        for i in range(len(inputs)):
            if isinstance(inputs[i], str) and os.path.isfile(os.path.join(data_args.media_dir, inputs[i])):
                inputs[i] = os.path.join(data_args.media_dir, inputs[i])

    return inputs


def convert_alpaca(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts alpaca format dataset to the standard format.
    """
    prompt = []
    if dataset_attr.history and isinstance(example[dataset_attr.history], list):
        for old_prompt, old_response in example[dataset_attr.history]:
            prompt.append({"role": Role.USER.value, "content": old_prompt})
            prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

    query = []
    if dataset_attr.prompt and example[dataset_attr.prompt]:
        query.append(example[dataset_attr.prompt])

    if dataset_attr.query and example[dataset_attr.query]:
        query.append(example[dataset_attr.query])

    prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], str)
        and isinstance(example[dataset_attr.rejected], str)
    ):  # pairwise example
        response = [
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
        ]
    elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
    else:  # unsupervised
        response = []

    regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": example[dataset_attr.system] if dataset_attr.system else "",
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
        "_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
    }
    return output


def convert_sharegpt(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts sharegpt format dataset to the standard format.
    """
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
        dataset_attr.observation_tag: Role.OBSERVATION.value,
        dataset_attr.function_tag: Role.FUNCTION.value,
        dataset_attr.system_tag: Role.SYSTEM.value,
    }
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
    messages = example[dataset_attr.messages]
    if (
        dataset_attr.system_tag
        and len(messages) != 0
        and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
    ):
        system = messages[0][dataset_attr.content_tag]
        messages = messages[1:]
    else:
        system = example[dataset_attr.system] if dataset_attr.system else ""

    aligned_messages = []
    broken_data = False
    for turn_idx, message in enumerate(messages):
        if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
            logger.warning_rank0(f"Invalid role tag in {messages}.")
            broken_data = True
            break

        aligned_messages.append(
            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
        )

    if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
        dataset_attr.ranking and len(aligned_messages) % 2 == 0
    ):
        logger.warning_rank0(f"Invalid message count in {messages}.")
        broken_data = True

    if broken_data:
        logger.warning_rank0("Skipping this abnormal example.")
        prompt, response = [], []
    elif dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], dict)
        and isinstance(example[dataset_attr.rejected], dict)
    ):  # pairwise example
        chosen = example[dataset_attr.chosen]
        rejected = example[dataset_attr.rejected]
        if (
            chosen[dataset_attr.role_tag] not in accept_tags[-1]
            or rejected[dataset_attr.role_tag] not in accept_tags[-1]
        ):
            logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
            broken_data = True

        prompt = aligned_messages
        response = [
            {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
            {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
        ]
    else:  # normal example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]

    regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": system,
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
        "_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
    }
    return output


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _tools: "...",
        _images: [],
        _videos: [],
        _audios: [],
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
    else:
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    return dataset.map(
        convert_func,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
src/llamafactory/data/converter.py (new file, +269 lines)
@@ -0,0 +1,269 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .parser import DatasetAttr


logger = logging.get_logger(__name__)


@dataclass
class DatasetConverter:
    dataset_attr: "DatasetAttr"
    data_args: "DataArguments"

    def _find_medias(self, inputs: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
        r"""
        Optionally concatenates media path to media dir when loading from local disk.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]
        elif len(inputs) == 0:
            return None
        else:
            inputs = inputs[:]

        if self.dataset_attr.load_from in ["script", "file"]:
            for i in range(len(inputs)):
                if isinstance(inputs[i], str) and os.path.isfile(os.path.join(self.data_args.media_dir, inputs[i])):
                    inputs[i] = os.path.join(self.data_args.media_dir, inputs[i])

        return inputs

    @abstractmethod
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        r"""
        Converts a single example in the dataset to the standard format.
        """
        ...


@dataclass
class AlpacaDatasetConverter(DatasetConverter):
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        prompt = []
        if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
            for old_prompt, old_response in example[self.dataset_attr.history]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        query = []
        if self.dataset_attr.prompt and example[self.dataset_attr.prompt]:
            query.append(example[self.dataset_attr.prompt])

        if self.dataset_attr.query and example[self.dataset_attr.query]:
            query.append(example[self.dataset_attr.query])

        prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

        if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], str)
            and isinstance(example[self.dataset_attr.rejected], str)
        ):  # pairwise example
            response = [
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]},
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]},
            ]
        elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str):  # normal example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
        else:  # unsupervised
            response = []

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": example[self.dataset_attr.system] if self.dataset_attr.system else "",
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


@dataclass
class SharegptDatasetConverter(DatasetConverter):
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        tag_mapping = {
            self.dataset_attr.user_tag: Role.USER.value,
            self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
            self.dataset_attr.observation_tag: Role.OBSERVATION.value,
            self.dataset_attr.function_tag: Role.FUNCTION.value,
            self.dataset_attr.system_tag: Role.SYSTEM.value,
        }
        odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
        even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
        accept_tags = (odd_tags, even_tags)
        messages = example[self.dataset_attr.messages]
        if (
            self.dataset_attr.system_tag
            and len(messages) != 0
            and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
        ):
            system = messages[0][self.dataset_attr.content_tag]
            messages = messages[1:]
        else:
            system = example[self.dataset_attr.system] if self.dataset_attr.system else ""

        aligned_messages = []
        broken_data = False
        for turn_idx, message in enumerate(messages):
            if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
                logger.warning_rank0(f"Invalid role tag in {messages}.")
                broken_data = True
                break

            aligned_messages.append(
                {
                    "role": tag_mapping[message[self.dataset_attr.role_tag]],
                    "content": message[self.dataset_attr.content_tag],
                }
            )

        if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
            self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
        ):
            logger.warning_rank0(f"Invalid message count in {messages}.")
            broken_data = True

        if broken_data:
            logger.warning_rank0("Skipping this abnormal example.")
            prompt, response = [], []
        elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], dict)
            and isinstance(example[self.dataset_attr.rejected], dict)
        ):  # pairwise example
            chosen = example[self.dataset_attr.chosen]
            rejected = example[self.dataset_attr.rejected]
            if (
                chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
                or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
            ):
                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
                broken_data = True

            prompt = aligned_messages
            response = [
                {
                    "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
                    "content": chosen[self.dataset_attr.content_tag],
                },
                {
                    "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
                    "content": rejected[self.dataset_attr.content_tag],
                },
            ]
        else:  # normal example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": system,
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


DATASET_CONVERTERS = {
    "alpaca": AlpacaDatasetConverter,
    "sharegpt": SharegptDatasetConverter,
}


def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
    r"""
    Register a new dataset converter.
    """
    if name in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} already exists.")

    DATASET_CONVERTERS[name] = dataset_converter


def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
    r"""
    Gets a dataset converter.
    """
    if name not in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} not found.")

    return DATASET_CONVERTERS[name](dataset_attr, data_args)


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _tools: "...",
        _images: [],
        _videos: [],
        _audios: [],
    """

    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
    return dataset.map(
        dataset_converter,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
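For context on the registry this file introduces: a new format can be wired in through register_dataset_converter and then selected via the `formatting` field. A minimal sketch, assuming a hypothetical {"question": ..., "answer": ...} record format (the converter name and field names below are illustrative, not part of this commit):

from dataclasses import dataclass
from typing import Any, Dict

from llamafactory.data.converter import DatasetConverter, register_dataset_converter
from llamafactory.data.data_utils import Role


@dataclass
class QADatasetConverter(DatasetConverter):  # hypothetical custom converter
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        # Map the raw example onto the standard keys produced by the built-in converters.
        return {
            "_prompt": [{"role": Role.USER.value, "content": example["question"]}],
            "_response": [{"role": Role.ASSISTANT.value, "content": example["answer"]}],
            "_system": "",
            "_tools": "",
            "_images": None,
            "_videos": None,
            "_audios": None,
        }


register_dataset_converter("qa", QADatasetConverter)  # afterwards selectable as `formatting: qa`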
src/llamafactory/data/loader.py (modified)
@@ -22,10 +22,17 @@ from datasets import DatasetDict, load_dataset, load_from_disk
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
-from .aligner import align_dataset
+from .converter import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
-from .preprocess import get_preprocess_and_print_func
+from .processor import (
+    FeedbackDatasetProcessor,
+    PackedSupervisedDatasetProcessor,
+    PairwiseDatasetProcessor,
+    PretrainDatasetProcessor,
+    SupervisedDatasetProcessor,
+    UnsupervisedDatasetProcessor,
+)
 
 
 if TYPE_CHECKING:
@@ -35,6 +42,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, ModelArguments
     from .data_utils import DatasetModule
     from .parser import DatasetAttr
+    from .processor import DatasetProcessor
     from .template import Template
 
 
@@ -158,7 +166,7 @@ def _get_merged_dataset(
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
     r"""
-    Gets the merged datasets in the standard format.
+    Returns the merged datasets in the standard format.
     """
     if dataset_names is None:
         return None
@@ -173,6 +181,48 @@ def _get_merged_dataset(
     return merge_dataset(datasets, data_args, seed=training_args.seed)
 
 
+def _get_dataset_processor(
+    data_args: "DataArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    do_generate: bool = False,
+) -> "DatasetProcessor":
+    r"""
+    Returns the corresponding dataset processor.
+    """
+    if stage == "pt":
+        dataset_processor_class = PretrainDatasetProcessor
+    elif stage == "sft" and not do_generate:
+        if data_args.packing:
+            if data_args.neat_packing:  # hack datasets to have int32 attention mask
+                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
+
+                def __init__(self, data, **kwargs):
+                    return TypedSequence.__init__(
+                        self,
+                        data,
+                        type=kwargs.pop("type", None),
+                        try_type=kwargs.pop("try_type", None),
+                        optimized_int_type=kwargs.pop("optimized_int_type", None),
+                    )
+
+                OptimizedTypedSequence.__init__ = __init__
+            dataset_processor_class = PackedSupervisedDatasetProcessor
+        else:
+            dataset_processor_class = SupervisedDatasetProcessor
+
+    elif stage == "rm":
+        dataset_processor_class = PairwiseDatasetProcessor
+    elif stage == "kto":
+        dataset_processor_class = FeedbackDatasetProcessor
+    else:
+        dataset_processor_class = UnsupervisedDatasetProcessor
+
+    return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
+
+
 def _get_preprocessed_dataset(
     dataset: Optional[Union["Dataset", "IterableDataset"]],
     data_args: "DataArguments",
@@ -189,7 +239,7 @@ def _get_preprocessed_dataset(
     if dataset is None:
         return None
 
-    preprocess_func, print_function = get_preprocess_and_print_func(
+    dataset_processor = _get_dataset_processor(
         data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
     )
     column_names = list(next(iter(dataset)).keys())
@@ -202,7 +252,7 @@ def _get_preprocessed_dataset(
     )
 
     dataset = dataset.map(
-        preprocess_func,
+        dataset_processor.preprocess_dataset,
         batched=True,
         batch_size=data_args.preprocessing_batch_size,
         remove_columns=column_names,
@@ -212,7 +262,7 @@ def _get_preprocessed_dataset(
     if training_args.should_log:
         try:
            print("eval example:" if is_eval else "training example:")
-            print_function(next(iter(dataset)))
+            dataset_processor.print_data_example(next(iter(dataset)))
         except StopIteration:
             if stage == "pt":
                 raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
src/llamafactory/data/mm_plugin.py (modified)
@@ -4,7 +4,7 @@ import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
 
 import numpy as np
 import torch
@@ -1241,14 +1241,26 @@ PLUGINS = {
 }
 
 
+def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
+    r"""
+    Registers a multimodal plugin.
+    """
+    if name in PLUGINS:
+        raise ValueError(f"Multimodal plugin {name} already exists.")
+
+    PLUGINS[name] = plugin_class
+
+
 def get_mm_plugin(
     name: str,
     image_token: Optional[str] = None,
     video_token: Optional[str] = None,
     audio_token: Optional[str] = None,
 ) -> "BasePlugin":
-    plugin_class = PLUGINS.get(name, None)
-    if plugin_class is None:
+    r"""
+    Gets plugin for multimodal inputs.
+    """
+    if name not in PLUGINS:
         raise ValueError(f"Multimodal plugin `{name}` not found.")
 
-    return plugin_class(image_token, video_token, audio_token)
+    return PLUGINS[name](image_token, video_token, audio_token)
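register_mm_plugin mirrors the dataset-converter registry. A minimal sketch of how a third-party plugin could be registered and retrieved (MyPlugin is hypothetical; a real plugin would override BasePlugin's processing hooks rather than inherit them unchanged):

from llamafactory.data.mm_plugin import BasePlugin, get_mm_plugin, register_mm_plugin


class MyPlugin(BasePlugin):  # hypothetical plugin for illustration only
    pass


register_mm_plugin("my_plugin", MyPlugin)
plugin = get_mm_plugin("my_plugin", image_token="<image>")  # unknown names now raise ValueError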
src/llamafactory/data/parser.py (modified)
@@ -45,7 +45,7 @@ class DatasetAttr:
     images: Optional[str] = None
     videos: Optional[str] = None
     audios: Optional[str] = None
-    # rlhf columns
+    # dpo columns
     chosen: Optional[str] = None
     rejected: Optional[str] = None
     kto_tag: Optional[str] = None
@@ -71,6 +71,26 @@ class DatasetAttr:
     def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
         setattr(self, key, obj.get(key, default))
 
+    def join(self, attr: Dict[str, Any]) -> None:
+        self.set_attr("formatting", attr, default="alpaca")
+        self.set_attr("ranking", attr, default=False)
+        self.set_attr("subset", attr)
+        self.set_attr("split", attr, default="train")
+        self.set_attr("folder", attr)
+        self.set_attr("num_samples", attr)
+
+        if "columns" in attr:
+            column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
+            column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
+            for column_name in column_names:
+                self.set_attr(column_name, attr["columns"])
+
+        if "tags" in attr:
+            tag_names = ["role_tag", "content_tag"]
+            tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
+            for tag in tag_names:
+                self.set_attr(tag, attr["tags"])
+
 
 def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
     r"""
@@ -128,36 +148,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
         else:
             dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
 
-        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
-        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
-        dataset_attr.set_attr("subset", dataset_info[name])
-        dataset_attr.set_attr("split", dataset_info[name], default="train")
-        dataset_attr.set_attr("folder", dataset_info[name])
-        dataset_attr.set_attr("num_samples", dataset_info[name])
-
-        if "columns" in dataset_info[name]:
-            column_names = ["system", "tools", "images", "videos", "audios", "chosen", "rejected", "kto_tag"]
-            if dataset_attr.formatting == "alpaca":
-                column_names.extend(["prompt", "query", "response", "history"])
-            else:
-                column_names.extend(["messages"])
-
-            for column_name in column_names:
-                dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
-
-        if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
-            tag_names = (
-                "role_tag",
-                "content_tag",
-                "user_tag",
-                "assistant_tag",
-                "observation_tag",
-                "function_tag",
-                "system_tag",
-            )
-            for tag in tag_names:
-                dataset_attr.set_attr(tag, dataset_info[name]["tags"])
-
+        dataset_attr.join(dataset_info[name])
         dataset_list.append(dataset_attr)
 
     return dataset_list
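The new DatasetAttr.join consolidates the per-key wiring that get_dataset_list previously did inline. For illustration, a hypothetical dataset_info.json entry, shown as the parsed dict that join receives (all field values are made up):

attr = {
    "file_name": "my_dataset.json",  # consumed by the DatasetAttr constructor, not by join
    "formatting": "sharegpt",
    "ranking": False,
    "split": "train",
    "columns": {"messages": "conversations", "system": "system"},
    "tags": {"role_tag": "from", "content_tag": "value", "user_tag": "human", "assistant_tag": "gpt"},
}
# dataset_attr.join(attr) then sets formatting="sharegpt", messages="conversations",
# role_tag="from", and so on; keys absent from `attr` fall back to their defaults.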
src/llamafactory/data/preprocess.py (deleted, -111 lines)
@@ -1,111 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple

from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
    preprocess_packed_supervised_dataset,
    preprocess_supervised_dataset,
    print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ..hparams import DataArguments
    from .template import Template


def get_preprocess_and_print_func(
    data_args: "DataArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    do_generate: bool = False,
) -> Tuple[Callable, Callable]:
    if stage == "pt":
        preprocess_func = partial(
            preprocess_pretrain_dataset,
            tokenizer=tokenizer,
            data_args=data_args,
        )
        print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not do_generate:
        if data_args.packing:
            if data_args.neat_packing:  # hack datasets to have int32 attention mask
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.pop("type", None),
                        try_type=kwargs.pop("try_type", None),
                        optimized_int_type=kwargs.pop("optimized_int_type", None),
                    )

                OptimizedTypedSequence.__init__ = __init__
            preprocess_func = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )
        else:
            preprocess_func = partial(
                preprocess_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )

        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    elif stage == "rm":
        preprocess_func = partial(
            preprocess_pairwise_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    elif stage == "kto":
        preprocess_func = partial(
            preprocess_feedback_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    else:
        preprocess_func = partial(
            preprocess_unsupervised_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

    return preprocess_func, print_function
src/llamafactory/data/processor/__init__.py (new file, +17 lines)
@@ -0,0 +1,17 @@
from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
from .processor_utils import DatasetProcessor
from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
from .unsupervised import UnsupervisedDatasetProcessor


__all__ = [
    "DatasetProcessor",
    "FeedbackDatasetProcessor",
    "PairwiseDatasetProcessor",
    "PretrainDatasetProcessor",
    "PackedSupervisedDatasetProcessor",
    "SupervisedDatasetProcessor",
    "UnsupervisedDatasetProcessor",
]
src/llamafactory/data/processor/feedback.py (new file, +129 lines)
@@ -0,0 +1,129 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class FeedbackDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        kl_response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
        if response[0]["content"]:  # desired example
            kto_tag = True
            messages = prompt + [response[0]]
        else:  # undesired example
            kto_tag = False
            messages = prompt + [response[1]]

        if kl_response[0]["content"]:
            kl_messages = prompt + [kl_response[0]]
        else:
            kl_messages = prompt + [kl_response[1]]

        messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
        kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor)
        prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
        kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools)

        if self.template.efficient_eos:
            response_ids += [self.tokenizer.eos_token_id]
            kl_response_ids += [self.tokenizer.eos_token_id]

        prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )

        source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len)
        prompt_ids = prompt_ids[:source_len]
        response_ids = response_ids[:target_len]
        kl_source_len, kl_target_len = infer_seqlen(
            len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len
        )
        kl_prompt_ids = kl_prompt_ids[:kl_source_len]
        kl_response_ids = kl_response_ids[:kl_target_len]

        input_ids = prompt_ids + response_ids
        labels = [IGNORE_INDEX] * source_len + response_ids
        kl_input_ids = kl_prompt_ids + kl_response_ids
        kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
        return input_ids, labels, kl_input_ids, kl_labels, kto_tag

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
        kl_response = examples["_response"][::-1]
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                kl_response=kl_response[i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["kl_input_ids"].append(kl_input_ids)
            model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
            model_inputs["kl_labels"].append(kl_labels)
            model_inputs["kto_tags"].append(kto_tag)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
        undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
        if desirable_num == 0 or undesirable_num == 0:
            logger.warning_rank0("Your dataset only has one preference type.")

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
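On the KL-pair trick in preprocess_dataset above: reversing the batch's response list pairs each prompt with an unrelated response, which provides the mismatched samples used for KTO's KL estimate. A toy illustration (not part of the commit):

responses = [["A"], ["B"], ["C"]]  # stand-ins for examples["_response"]
kl_responses = responses[::-1]
assert kl_responses == [["C"], ["B"], ["A"]]  # prompt 0 is now matched with example 2's response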
src/llamafactory/data/processor/pairwise.py (new file, +118 lines)
@@ -0,0 +1,118 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class PairwiseDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int], List[int], List[int]]:
        chosen_messages = self.template.mm_plugin.process_messages(
            prompt + [response[0]], images, videos, audios, self.processor
        )
        rejected_messages = self.template.mm_plugin.process_messages(
            prompt + [response[1]], images, videos, audios, self.processor
        )
        prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools)
        _, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools)

        if self.template.efficient_eos:
            chosen_ids += [self.tokenizer.eos_token_id]
            rejected_ids += [self.tokenizer.eos_token_id]

        prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        # consider the response is more important
        source_len, target_len = infer_seqlen(
            len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len
        )
        prompt_ids = prompt_ids[:source_len]
        chosen_ids = chosen_ids[:target_len]
        rejected_ids = rejected_ids[:target_len]

        chosen_input_ids = prompt_ids + chosen_ids
        chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
        rejected_input_ids = prompt_ids + rejected_ids
        rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
        return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["chosen_input_ids"].append(chosen_input_ids)
            model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
            model_inputs["chosen_labels"].append(chosen_labels)
            model_inputs["rejected_input_ids"].append(rejected_input_ids)
            model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
            model_inputs["rejected_labels"].append(rejected_labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
        valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
        print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
        print(
            "chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))
        )
        print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
        print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
        print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
        print(
            "rejected_inputs:\n{}".format(
                self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)
            )
        )
        print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
        print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
src/llamafactory/data/processor/pretrain.py (new file, +57 lines)
@@ -0,0 +1,57 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List

from .processor_utils import DatasetProcessor


@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
        eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
        text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

        if not self.data_args.packing:
            if getattr(self.tokenizer, "add_bos_token", False):
                text_examples = [self.tokenizer.bos_token + example for example in text_examples]

            result = self.tokenizer(
                text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
            )
        else:
            tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False)
            concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
            total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
            block_size = self.data_args.cutoff_len
            total_length = (total_length // block_size) * block_size
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            if getattr(self.tokenizer, "add_bos_token", False):
                for i in range(len(result["input_ids"])):
                    result["input_ids"][i][0] = self.tokenizer.bos_token_id

        return result

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
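The packing branch above concatenates all tokenized texts and slices them into fixed-size blocks, dropping any tail shorter than one block. A toy illustration of that grouping (values are made up):

from itertools import chain

tokenized = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
concatenated = {k: list(chain(*v)) for k, v in tokenized.items()}  # input_ids -> [1..9]
block_size = 4  # plays the role of data_args.cutoff_len
total_length = (len(concatenated["input_ids"]) // block_size) * block_size  # 8; the last token is dropped
blocks = [concatenated["input_ids"][i : i + block_size] for i in range(0, total_length, block_size)]
assert blocks == [[1, 2, 3, 4], [5, 6, 7, 8]]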
src/llamafactory/data/processor/processor_utils.py (modified)
@@ -13,7 +13,42 @@
 # limitations under the License.
 
 import bisect
-from typing import List, Sequence, Tuple
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, ProcessorMixin
+
+    from ...hparams import DataArguments
+    from ..template import Template
+
+
+@dataclass
+class DatasetProcessor(ABC):
+    r"""
+    A class for data processors.
+    """
+
+    template: "Template"
+    tokenizer: "PreTrainedTokenizer"
+    processor: Optional["ProcessorMixin"]
+    data_args: "DataArguments"
+
+    @abstractmethod
+    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        r"""
+        Builds model inputs from the examples.
+        """
+        ...
+
+    @abstractmethod
+    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+        r"""
+        Print a data example to stdout.
+        """
+        ...
+
+
 def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
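With DatasetProcessor now an abstract dataclass, a custom stage can be implemented by subclassing it. A minimal sketch (EchoDatasetProcessor is hypothetical and far simpler than the real processors, which also handle responses, media, and truncation):

from collections import defaultdict
from typing import Any, Dict, List

from llamafactory.data.processor import DatasetProcessor


class EchoDatasetProcessor(DatasetProcessor):
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # Tokenize only the first user message of each prompt.
        model_inputs = defaultdict(list)
        for messages in examples["_prompt"]:
            input_ids = self.tokenizer.encode(messages[0]["content"])
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"])))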
200
src/llamafactory/data/processor/supervised.py
Normal file
200
src/llamafactory/data/processor/supervised.py
Normal file
@ -0,0 +1,200 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..mm_plugin import AudioInput, ImageInput, VideoInput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
|
||||
input_ids, labels = self.template.mm_plugin.process_token_ids(
|
||||
[], [], images, videos, audios, self.tokenizer, self.processor
|
||||
)
|
||||
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
|
||||
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
|
||||
if self.data_args.mask_history:
|
||||
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
|
||||
if total_length >= self.data_args.cutoff_len:
|
||||
break
|
||||
|
||||
source_len, target_len = infer_seqlen(
|
||||
len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
|
||||
)
|
||||
source_ids = source_ids[:source_len]
|
||||
target_ids = target_ids[:target_len]
|
||||
total_length += source_len + target_len
|
||||
|
||||
if self.data_args.train_on_prompt:
|
||||
source_label = source_ids
|
||||
elif self.template.efficient_eos:
|
||||
source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
|
||||
else:
|
||||
source_label = [IGNORE_INDEX] * source_len
|
||||
|
||||
if self.data_args.mask_history and turn_idx != 0: # train on the last turn only
|
||||
target_label = [IGNORE_INDEX] * target_len
|
||||
else:
|
||||
target_label = target_ids
|
||||
|
||||
if self.data_args.mask_history: # reversed sequences
|
||||
input_ids = source_ids + target_ids + input_ids
|
||||
labels = source_label + target_label + labels
|
||||
else:
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_label + target_label
|
||||
|
||||
if self.template.efficient_eos:
|
||||
input_ids += [self.tokenizer.eos_token_id]
|
||||
labels += [self.tokenizer.eos_token_id]
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
input_ids, labels = self._encode_data_example(
|
||||
prompt=examples["_prompt"][i],
|
||||
response=examples["_response"][i],
|
||||
system=examples["_system"][i],
|
||||
tools=examples["_tools"][i],
|
||||
images=examples["_images"][i] or [],
|
||||
videos=examples["_videos"][i] or [],
|
||||
audios=examples["_audios"][i] or [],
|
||||
)
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["images"].append(examples["_images"][i])
|
||||
model_inputs["videos"].append(examples["_videos"][i])
|
||||
model_inputs["audios"].append(examples["_audios"][i])
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
|
||||
|
||||
|
||||
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # TODO: use `position_ids` to achieve packing
        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
        valid_num = 0
        batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
        lengths = []
        length2indexes = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            length = len(input_ids)
            if length > self.data_args.cutoff_len:
                logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
            else:
                lengths.append(length)
                length2indexes[length].append(valid_num)
                batch_input_ids.append(input_ids)
                batch_labels.append(labels)
                batch_images.append(examples["_images"][i] or [])
                batch_videos.append(examples["_videos"][i] or [])
                batch_audios.append(examples["_audios"][i] or [])
                valid_num += 1

        model_inputs = defaultdict(list)
        knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
        for knapsack in knapsacks:
            packed_input_ids, packed_attention_masks, packed_labels = [], [], []
            packed_images, packed_videos, packed_audios = [], [], []
            for i, length in enumerate(knapsack):
                index = length2indexes[length].pop()
                packed_input_ids += batch_input_ids[index]
                packed_labels += batch_labels[index]
                packed_images += batch_images[index]
                packed_videos += batch_videos[index]
                packed_audios += batch_audios[index]
                if self.data_args.neat_packing:
                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
                else:
                    packed_attention_masks += [1] * len(batch_input_ids[index])

            if len(packed_input_ids) < self.data_args.cutoff_len + 1:  # avoid that flash_attn drops the attention mask
                pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
                packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
                packed_labels += [IGNORE_INDEX] * pad_length
                if self.data_args.neat_packing:
                    packed_attention_masks += [0] * pad_length
                else:
                    packed_attention_masks += [1] * pad_length  # more efficient flash_attn

            if len(packed_input_ids) != self.data_args.cutoff_len + 1:
                raise ValueError("The length of packed example should be identical to the cutoff length.")

            model_inputs["input_ids"].append(packed_input_ids)
            model_inputs["attention_mask"].append(packed_attention_masks)
            model_inputs["labels"].append(packed_labels)
            model_inputs["images"].append(packed_images or None)
            model_inputs["videos"].append(packed_videos or None)
            model_inputs["audios"].append(packed_audios or None)

        return model_inputs
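# --- illustrative sketch, not part of the commit ---
# The attention_mask values the packing loop above produces. With neat_packing,
# each sequence in a pack gets its own segment id (starting from 1) and padding
# gets 0, so packed sequences stay distinguishable; otherwise every real token
# is simply 1.
packed = [[101, 102, 103], [201, 202]]  # two hypothetical tokenized examples
cutoff_len = 7  # the code above pads to cutoff_len + 1

neat_mask = []
for seg_id, seq in enumerate(packed):
    neat_mask += [seg_id + 1] * len(seq)  # start from 1
neat_mask += [0] * (cutoff_len + 1 - len(neat_mask))  # padding positions
assert neat_mask == [1, 1, 1, 2, 2, 0, 0, 0]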
src/llamafactory/data/processor/unsupervised.py (new file)
@ -0,0 +1,91 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ..data_utils import Role
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class UnsupervisedDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int]]:
        if len(response) == 1:
            messages = prompt + response
        else:
            messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

        messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
        input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
        if self.template.efficient_eos:
            labels += [self.tokenizer.eos_token_id]

        input_ids, _ = self.template.mm_plugin.process_token_ids(
            input_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len)
        input_ids = input_ids[:source_len]
        labels = labels[:target_len]
        return input_ids, labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build inputs with format `<bos> X` and labels with format `Y <eos>`
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False)))
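# --- illustrative sketch, not part of the commit ---
# infer_seqlen (imported from processor_utils, implementation not shown in this
# diff) is used above to fit a (source, target) pair into cutoff_len. A rough
# mimic of that usage, assuming a proportional split of the token budget:
def infer_seqlen_sketch(source_len: int, target_len: int, cutoff_len: int):
    if source_len + target_len <= cutoff_len:
        return source_len, target_len  # nothing to truncate

    new_source_len = max(1, cutoff_len * source_len // (source_len + target_len))
    new_target_len = max(1, cutoff_len - new_source_len)
    return new_source_len, new_target_len

print(infer_seqlen_sketch(30, 10, 20))  # (15, 5): both sides shrink proportionally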
@ -1,137 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_feedback_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    kl_response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
    if response[0]["content"]:  # desired example
        kto_tag = True
        messages = prompt + [response[0]]
    else:  # undesired example
        kto_tag = False
        messages = prompt + [response[1]]

    if kl_response[0]["content"]:
        kl_messages = prompt + [kl_response[0]]
    else:
        kl_messages = prompt + [kl_response[1]]

    messages = template.mm_plugin.process_messages(messages, images, videos, audios, processor)
    kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, audios, processor)
    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
    kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)

    if template.efficient_eos:
        response_ids += [tokenizer.eos_token_id]
        kl_response_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(
        prompt_ids, None, images, videos, audios, tokenizer, processor
    )
    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(
        kl_prompt_ids, None, images, videos, audios, tokenizer, processor
    )

    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    response_ids = response_ids[:target_len]
    kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
    kl_prompt_ids = kl_prompt_ids[:kl_source_len]
    kl_response_ids = kl_response_ids[:kl_target_len]

    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * source_len + response_ids
    kl_input_ids = kl_prompt_ids + kl_response_ids
    kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
    return input_ids, labels, kl_input_ids, kl_labels, kto_tag


def preprocess_feedback_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
    kl_response = examples["_response"][::-1]
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            kl_response=kl_response[i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["kl_input_ids"].append(kl_input_ids)
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
        model_inputs["kl_labels"].append(kl_labels)
        model_inputs["kto_tags"].append(kto_tag)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])
        model_inputs["audios"].append(examples["_audios"][i])

    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
    if desirable_num == 0 or undesirable_num == 0:
        logger.warning_rank0("Your dataset only has one preference type.")

    return model_inputs
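# --- illustrative sketch, not part of the commit ---
# The KL-pair trick in preprocess_feedback_dataset above: reversing the batch's
# responses pairs each prompt with an (almost always) unrelated response, which
# gives KTO the mismatched examples it needs to estimate the KL term.
prompts = ["p0", "p1", "p2", "p3"]
responses = ["r0", "r1", "r2", "r3"]

kl_responses = responses[::-1]
assert list(zip(prompts, kl_responses)) == [
    ("p0", "r3"), ("p1", "r2"), ("p2", "r1"), ("p3", "r0")
]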
@ -1,124 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_pairwise_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, audios, processor)
    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, audios, processor)
    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)

    if template.efficient_eos:
        chosen_ids += [tokenizer.eos_token_id]
        rejected_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(
        prompt_ids, None, images, videos, audios, tokenizer, processor
    )
    # consider that the response is more important
    source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    chosen_ids = chosen_ids[:target_len]
    rejected_ids = rejected_ids[:target_len]

    chosen_input_ids = prompt_ids + chosen_ids
    chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
    rejected_input_ids = prompt_ids + rejected_ids
    rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels


def preprocess_pairwise_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["chosen_input_ids"].append(chosen_input_ids)
        model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
        model_inputs["chosen_labels"].append(chosen_labels)
        model_inputs["rejected_input_ids"].append(rejected_input_ids)
        model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
        model_inputs["rejected_labels"].append(rejected_labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])
        model_inputs["audios"].append(examples["_audios"][i])

    return model_inputs


def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
    valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
    print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
    print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
    print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
    print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
    print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
    print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
    print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
    print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
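# --- illustrative sketch, not part of the commit ---
# How _encode_pairwise_example assembles the two training sequences above: both
# share the truncated prompt, and the prompt positions are masked with
# IGNORE_INDEX (assumed -100) so only the responses contribute to the loss.
IGNORE_INDEX = -100
prompt_ids = [11, 12]
chosen_ids = [21, 22, 23]
rejected_ids = [31, 32]

chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
assert prompt_ids + chosen_ids == [11, 12, 21, 22, 23]
assert chosen_labels == [-100, -100, 21, 22, 23]
assert rejected_labels == [-100, -100, 31, 32]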
@ -1,59 +0,0 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from ...hparams import DataArguments


def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

    if not data_args.packing:
        if getattr(tokenizer, "add_bos_token", False):
            text_examples = [tokenizer.bos_token + example for example in text_examples]

        result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
    else:
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
        block_size = data_args.cutoff_len
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        if getattr(tokenizer, "add_bos_token", False):
            for i in range(len(result["input_ids"])):
                result["input_ids"][i][0] = tokenizer.bos_token_id

    return result


def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
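# --- illustrative sketch, not part of the commit ---
# The grouping arithmetic in the packing branch above: concatenate all token
# lists, drop the tail that does not fill a whole block, then slice into
# cutoff_len-sized blocks.
from itertools import chain

token_lists = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]  # hypothetical tokenized texts
block_size = 4

concatenated = list(chain(*token_lists))  # 9 tokens
total_length = (len(concatenated) // block_size) * block_size  # 9 -> 8
blocks = [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]
assert blocks == [[1, 2, 3, 4], [5, 6, 7, 8]]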
@ -1,226 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_supervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
    train_on_prompt: bool,
    mask_history: bool,
) -> Tuple[List[int], List[int]]:
    messages = template.mm_plugin.process_messages(prompt + response, images, videos, audios, processor)
    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, audios, tokenizer, processor)
    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
    if mask_history:
        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        if total_length >= cutoff_len:
            break

        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
        source_ids = source_ids[:source_len]
        target_ids = target_ids[:target_len]
        total_length += source_len + target_len

        if train_on_prompt:
            source_label = source_ids
        elif template.efficient_eos:
            source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
        else:
            source_label = [IGNORE_INDEX] * source_len

        if mask_history and turn_idx != 0:  # train on the last turn only
            target_label = [IGNORE_INDEX] * target_len
        else:
            target_label = target_ids

        if mask_history:  # reversed sequences
            input_ids = source_ids + target_ids + input_ids
            labels = source_label + target_label + labels
        else:
            input_ids += source_ids + target_ids
            labels += source_label + target_label

    if template.efficient_eos:
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]

    return input_ids, labels


def preprocess_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_supervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])
        model_inputs["audios"].append(examples["_audios"][i])

    return model_inputs


def preprocess_packed_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # TODO: use `position_ids` to achieve packing
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
    valid_num = 0
    batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
    lengths = []
    length2indexes = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_supervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        length = len(input_ids)
        if length > data_args.cutoff_len:
            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
        else:
            lengths.append(length)
            length2indexes[length].append(valid_num)
            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            batch_images.append(examples["_images"][i] or [])
            batch_videos.append(examples["_videos"][i] or [])
            batch_audios.append(examples["_audios"][i] or [])
            valid_num += 1

    model_inputs = defaultdict(list)
    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
    for knapsack in knapsacks:
        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
        packed_images, packed_videos, packed_audios = [], [], []
        for i, length in enumerate(knapsack):
            index = length2indexes[length].pop()
            packed_input_ids += batch_input_ids[index]
            packed_labels += batch_labels[index]
            packed_images += batch_images[index]
            packed_videos += batch_videos[index]
            packed_audios += batch_audios[index]
            if data_args.neat_packing:
                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
            else:
                packed_attention_masks += [1] * len(batch_input_ids[index])

        if len(packed_input_ids) < data_args.cutoff_len:
            pad_length = data_args.cutoff_len - len(packed_input_ids)
            packed_input_ids += [tokenizer.pad_token_id] * pad_length
            packed_labels += [IGNORE_INDEX] * pad_length
            if data_args.neat_packing:
                packed_attention_masks += [0] * pad_length
            else:
                packed_attention_masks += [1] * pad_length  # more efficient flash_attn

        if len(packed_input_ids) != data_args.cutoff_len:
            raise ValueError("The length of packed example should be identical to the cutoff length.")

        model_inputs["input_ids"].append(packed_input_ids)
        model_inputs["attention_mask"].append(packed_attention_masks)
        model_inputs["labels"].append(packed_labels)
        model_inputs["images"].append(packed_images or None)
        model_inputs["videos"].append(packed_videos or None)
        model_inputs["audios"].append(packed_audios or None)

    return model_inputs


def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
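# --- illustrative sketch, not part of the commit ---
# greedy_knapsack (from processor_utils, implementation not shown in this diff)
# groups example lengths into packs bounded by the cutoff. A rough mimic of one
# plausible greedy strategy, purely for intuition:
from typing import List


def greedy_knapsack_sketch(lengths: List[int], capacity: int) -> List[List[int]]:
    remaining = sorted(lengths, reverse=True)
    knapsacks = []
    while remaining:
        knapsack, space = [], capacity
        for length in remaining[:]:  # iterate over a copy while removing
            if length <= space:
                knapsack.append(length)
                space -= length
                remaining.remove(length)
        knapsacks.append(knapsack)
    return knapsacks


print(greedy_knapsack_sketch([7, 5, 4, 3, 1], capacity=8))  # [[7, 1], [5, 3], [4]]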
@ -1,107 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ..data_utils import Role
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_unsupervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int]]:
    if len(response) == 1:
        messages = prompt + response
    else:
        messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

    messages = template.mm_plugin.process_messages(messages, images, videos, audios, processor)
    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
    if template.efficient_eos:
        labels += [tokenizer.eos_token_id]

    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, audios, tokenizer, processor)
    source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
    input_ids = input_ids[:source_len]
    labels = labels[:target_len]
    return input_ids, labels


def preprocess_unsupervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build inputs with format `<bos> X` and labels with format `Y <eos>`
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_unsupervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            audios=examples["_audios"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])
        model_inputs["audios"].append(examples["_audios"][i])

    return model_inputs


def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
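# --- illustrative sketch, not part of the commit ---
# The mask_history behavior of _encode_supervised_example above: turns are
# visited last-to-first, only turn_idx == 0 (the final turn) keeps real labels,
# and prepending restores the original order. IGNORE_INDEX assumed -100.
IGNORE_INDEX = -100
turns = [([11], [21]), ([12], [22])]  # (source_ids, target_ids) per turn

input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(turns[::-1]):
    target_label = target_ids if turn_idx == 0 else [IGNORE_INDEX] * len(target_ids)
    input_ids = source_ids + target_ids + input_ids  # prepend reversed turns
    labels = [IGNORE_INDEX] * len(source_ids) + target_label + labels

assert input_ids == [11, 21, 12, 22]
assert labels == [-100, -100, -100, 22]  # loss only on the last response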
@ -405,7 +405,7 @@ class Llama2Template(Template):
TEMPLATES: Dict[str, "Template"] = {}


-def _register_template(
+def register_template(
    name: str,
    format_user: Optional["Formatter"] = None,
    format_assistant: Optional["Formatter"] = None,
@ -421,7 +421,7 @@ def _register_template(
    replace_eos: bool = False,
    replace_jinja_template: bool = False,
    mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
-    template_class: Type[Template] = Template,
+    template_class: Type["Template"] = Template,
) -> None:
    r"""
    Registers a chat template.
@ -436,7 +436,7 @@ def _register_template(

    The corresponding code should be:
    ```
-    _register_template(
+    register_template(
        name="custom",
        format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
        format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
@ -444,6 +444,9 @@ def _register_template(
    )
    ```
    """
+    if name in TEMPLATES:
+        raise ValueError(f"Template {name} already exists.")

    default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
    default_user_formatter = StringFormatter(slots=["{{content}}"])
    default_assistant_formatter = StringFormatter(slots=default_slots)
@ -562,7 +565,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
    return template

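# --- illustrative sketch, not part of the commit ---
# A simplified mimic (not the repo's StringFormatter) of how slot lists such as
# ["{{content}}", {"eos_token"}] are rendered: string slots substitute the
# message content, set/dict slots stand for special tokens taken from the
# tokenizer.
def render_slots(slots, content, eos_token="</s>"):
    pieces = []
    for slot in slots:
        if isinstance(slot, str):
            pieces.append(slot.replace("{{content}}", content))
        elif "eos_token" in slot:  # e.g. the {"eos_token"} set used below
            pieces.append(eos_token)
    return "".join(pieces)

print(render_slots(["### Instruction:\n{{content}}\n\n### Response:\n"], "Hi"))
print(render_slots(["{{content}}", {"eos_token"}], "Hello"))  # Hello</s>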
-_register_template(
+register_template(
    name="alpaca",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
@ -573,7 +576,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="aquila",
    format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}###"]),
@ -586,7 +589,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="atom",
    format_user=StringFormatter(
        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
@ -595,21 +598,21 @@ _register_template(
)


-_register_template(
+register_template(
    name="baichuan",
    format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
    efficient_eos=True,
)


-_register_template(
+register_template(
    name="baichuan2",
    format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
    efficient_eos=True,
)


-_register_template(
+register_template(
    name="belle",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
@ -617,13 +620,13 @@ _register_template(
)


-_register_template(
+register_template(
    name="bluelm",
    format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)


-_register_template(
+register_template(
    name="breeze",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
@ -631,7 +634,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="chatglm2",
    format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
@ -639,7 +642,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="chatglm3",
    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
@ -655,7 +658,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="chatml",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -668,7 +671,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="chatml_de",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -681,13 +684,13 @@ _register_template(
)


-_register_template(
+register_template(
    name="codegeex2",
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)


-_register_template(
+register_template(
    name="codegeex4",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
@ -704,7 +707,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="cohere",
    format_user=StringFormatter(
        slots=[
@ -719,7 +722,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="cpm",
    format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
@ -727,7 +730,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="cpm3",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -738,7 +741,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="dbrx",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -763,7 +766,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="deepseek",
    format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
@ -771,14 +774,14 @@ _register_template(
)


-_register_template(
+register_template(
    name="deepseek3",
    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


-_register_template(
+register_template(
    name="deepseekcoder",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
    format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
@ -792,7 +795,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="default",
    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
@ -800,13 +803,13 @@ _register_template(
)


-_register_template(
+register_template(
    name="empty",
    format_assistant=StringFormatter(slots=["{{content}}"]),
)


-_register_template(
+register_template(
    name="exaone",
    format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
@ -814,7 +817,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="falcon",
    format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
    format_assistant=StringFormatter(slots=["{{content}}\n"]),
@ -822,14 +825,14 @@ _register_template(
)


-_register_template(
+register_template(
    name="fewshot",
    format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
    efficient_eos=True,
)


-_register_template(
+register_template(
    name="gemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
@ -840,7 +843,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="glm4",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
@ -854,7 +857,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="granite3",
    format_user=StringFormatter(
        slots=[
@ -866,7 +869,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="index",
    format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
    format_system=StringFormatter(slots=["<unk>{{content}}"]),
@ -874,7 +877,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="intern",
    format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
    format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
@ -891,7 +894,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="intern2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -908,7 +911,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="llama2",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
@ -917,7 +920,7 @@ _register_template(


# copied from llama2 template
-_register_template(
+register_template(
    name="llama2_zh",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
@ -926,7 +929,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="llama3",
    format_user=StringFormatter(
        slots=[
@ -954,7 +957,7 @@ _register_template(


# copied from llama3 template
-_register_template(
+register_template(
    name="mllama",
    format_user=StringFormatter(
        slots=[
@ -983,7 +986,7 @@ _register_template(


# copied from vicuna template
-_register_template(
+register_template(
    name="llava",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
@ -995,7 +998,7 @@ _register_template(


# copied from vicuna template
-_register_template(
+register_template(
    name="llava_next",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
@ -1007,7 +1010,7 @@ _register_template(


# copied from llama3 template
-_register_template(
+register_template(
    name="llava_next_llama3",
    format_user=StringFormatter(
        slots=[
@ -1036,7 +1039,7 @@ _register_template(


# copied from mistral template
-_register_template(
+register_template(
    name="llava_next_mistral",
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
@ -1051,7 +1054,7 @@ _register_template(


# copied from qwen template
-_register_template(
+register_template(
    name="llava_next_qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1068,7 +1071,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="llava_next_yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1079,7 +1082,7 @@ _register_template(


# copied from vicuna template
-_register_template(
+register_template(
    name="llava_next_video",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
@ -1091,7 +1094,7 @@ _register_template(


# copied from mistral template
-_register_template(
+register_template(
    name="llava_next_video_mistral",
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
@ -1106,7 +1109,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="llava_next_video_yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1117,7 +1120,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="marco",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1133,7 +1136,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="minicpm_v",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1144,7 +1147,7 @@ _register_template(


# copied from minicpm_v template
-_register_template(
+register_template(
    name="minicpm_o",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1155,7 +1158,7 @@ _register_template(


# mistral tokenizer v3 tekken
-_register_template(
+register_template(
    name="ministral",
    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
@ -1168,7 +1171,7 @@ _register_template(


# mistral tokenizer v3
-_register_template(
+register_template(
    name="mistral",
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
@ -1182,7 +1185,7 @@ _register_template(


# mistral tokenizer v7 tekken (copied from ministral)
-_register_template(
+register_template(
    name="mistral_small",
    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
@ -1193,21 +1196,21 @@ _register_template(
)


-_register_template(
+register_template(
    name="olmo",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)


-_register_template(
+register_template(
    name="openchat",
    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


-_register_template(
+register_template(
    name="openchat-3.6",
    format_user=StringFormatter(
        slots=[
@ -1223,7 +1226,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="opencoder",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1234,7 +1237,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="orion",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
@ -1242,7 +1245,7 @@ _register_template(


# copied from gemma template
-_register_template(
+register_template(
    name="paligemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
@ -1254,7 +1257,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="phi",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
@ -1263,7 +1266,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="phi_small",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
@ -1273,7 +1276,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="phi4",
    format_user=StringFormatter(
        slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
@ -1285,7 +1288,7 @@ _register_template(


# copied from ministral template
-_register_template(
+register_template(
    name="pixtral",
    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
@ -1299,7 +1302,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1315,7 +1318,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="qwen2_audio",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1327,7 +1330,7 @@ _register_template(


# copied from qwen template
-_register_template(
+register_template(
    name="qwen2_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1343,7 +1346,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="sailor",
    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
@ -1357,7 +1360,7 @@ _register_template(


# copied from llama3 template
-_register_template(
+register_template(
    name="skywork_o1",
    format_user=StringFormatter(
        slots=[
@ -1391,7 +1394,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="solar",
    format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
    format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
@ -1399,7 +1402,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="starchat",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
@ -1408,14 +1411,14 @@ _register_template(
)


-_register_template(
+register_template(
    name="telechat",
    format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
    format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
)


-_register_template(
+register_template(
    name="telechat2",
    format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
    format_system=StringFormatter(slots=["<_system>{{content}}"]),
@ -1425,7 +1428,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="vicuna",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
@ -1436,7 +1439,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="video_llava",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
@ -1447,7 +1450,7 @@ _register_template(
)


-_register_template(
+register_template(
    name="xuanyuan",
    format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
    default_system=(
@ -1458,13 +1461,13 @@ _register_template(
)


-_register_template(
+register_template(
    name="xverse",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)


-_register_template(
+register_template(
    name="yayi",
    format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
    format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
@ -1485,7 +1488,7 @@ _register_template(


# copied from chatml template
-_register_template(
+register_template(
    name="yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
@ -1494,7 +1497,7 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
register_template(
|
||||
name="yi_vl",
|
||||
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||
@ -1511,7 +1514,7 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
register_template(
|
||||
name="yuan",
|
||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
|
||||
@ -1519,7 +1522,7 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
register_template(
|
||||
name="zephyr",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
@ -1527,7 +1530,7 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
register_template(
|
||||
name="ziya",
|
||||
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||
|
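The hunks above are a mechanical rename: the module-private `_register_template` becomes the public `register_template`, with every call site updated and the template bodies left untouched. As a rough illustration of the registry pattern behind these calls (a minimal sketch with simplified stand-in types, not LLaMA-Factory's actual `Template` or `StringFormatter` classes):

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Template:
    # Simplified stand-in: the real class carries many more formatters.
    name: str
    user_slots: List[str] = field(default_factory=list)


TEMPLATES: Dict[str, Template] = {}


def register_template(name: str, user_slots: List[str]) -> None:
    # Fail fast on duplicate names so two models cannot claim the same template.
    if name in TEMPLATES:
        raise ValueError(f"Template {name} already exists.")
    TEMPLATES[name] = Template(name=name, user_slots=user_slots)


# Mirrors the zephyr entry above; "{{eos}}" stands in for the eos_token slot.
register_template("zephyr", ["<|user|>\n{{content}}", "{{eos}}", "<|assistant|>\n"])

Dropping the leading underscore suggests the function is now intended as public API, e.g. for registering custom templates from user code.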
@@ -1549,7 +1549,7 @@ register_model_group(

 register_model_group(
     models={
-        "Pixtral-12B-Instruct": {
+        "Pixtral-12B": {
             DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
         }
@@ -16,7 +16,7 @@ from typing import Tuple

 import pytest

-from llamafactory.data.processors.processor_utils import infer_seqlen
+from llamafactory.data.processor.processor_utils import infer_seqlen


 @pytest.mark.parametrize(
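Only the import path changes here: the refactor moves the `processors` package to the singular `processor`. For orientation, a hedged usage sketch of `infer_seqlen`; the three-argument form and the return value below are assumptions, not a documented signature:

from llamafactory.data.processor.processor_utils import infer_seqlen

# Assumed behavior: shrink a (source, target) length pair so its sum fits the
# cutoff while preserving their ratio; the exact semantics live in processor_utils.
new_source_len, new_target_len = infer_seqlen(2000, 500, 1024)
assert new_source_len + new_target_len <= 1024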
tests/data/test_converter.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+from llamafactory.data import Role
+from llamafactory.data.converter import get_dataset_converter
+from llamafactory.data.parser import DatasetAttr
+from llamafactory.hparams import DataArguments
+
+
+def test_alpaca_converter():
+    dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
+    data_args = DataArguments()
+    example = {
+        "instruction": "Solve the math problem.",
+        "input": "3 + 4",
+        "output": "The answer is 7.",
+    }
+    dataset_converter = get_dataset_converter("alpaca", dataset_attr, data_args)
+    assert dataset_converter(example) == {
+        "_prompt": [{"role": Role.USER.value, "content": "Solve the math problem.\n3 + 4"}],
+        "_response": [{"role": Role.ASSISTANT.value, "content": "The answer is 7."}],
+        "_system": "",
+        "_tools": "",
+        "_images": None,
+        "_videos": None,
+        "_audios": None,
+    }
+
+
+def test_sharegpt_converter():
+    dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
+    data_args = DataArguments()
+    example = {
+        "conversations": [
+            {"from": "system", "value": "You are a helpful assistant."},
+            {"from": "human", "value": "Solve the math problem.\n3 + 4"},
+            {"from": "gpt", "value": "The answer is 7."},
+        ]
+    }
+    dataset_converter = get_dataset_converter("sharegpt", dataset_attr, data_args)
+    assert dataset_converter(example) == {
+        "_prompt": [{"role": Role.USER.value, "content": "Solve the math problem.\n3 + 4"}],
+        "_response": [{"role": Role.ASSISTANT.value, "content": "The answer is 7."}],
+        "_system": "You are a helpful assistant.",
+        "_tools": "",
+        "_images": None,
+        "_videos": None,
+        "_audios": None,
+    }
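The new tests pin the converter contract: both the alpaca and the sharegpt converter must normalize raw examples into the same standardized keys (`_prompt`, `_response`, `_system`, `_tools`, and the `_images`/`_videos`/`_audios` media fields). They can be run in isolation with:

pytest tests/data/test_converter.py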