From 46203856fc65f2701cec6a0ccbd258b6f2215deb Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 13 Feb 2025 00:39:20 +0800 Subject: [PATCH] [breaking change] refactor data pipeline (#6901) * refactor data * rename file Former-commit-id: 7a1a4ce6451cb782573d0bd9dd27a5e443e3a18b --- src/llamafactory/data/aligner.py | 241 ---------------- src/llamafactory/data/converter.py | 269 ++++++++++++++++++ src/llamafactory/data/loader.py | 62 +++- src/llamafactory/data/mm_plugin.py | 20 +- src/llamafactory/data/parser.py | 53 ++-- src/llamafactory/data/preprocess.py | 111 -------- src/llamafactory/data/processor/__init__.py | 17 ++ src/llamafactory/data/processor/feedback.py | 129 +++++++++ src/llamafactory/data/processor/pairwise.py | 118 ++++++++ src/llamafactory/data/processor/pretrain.py | 57 ++++ .../processor_utils.py | 37 ++- src/llamafactory/data/processor/supervised.py | 200 +++++++++++++ .../data/processor/unsupervised.py | 91 ++++++ src/llamafactory/data/processors/__init__.py | 0 src/llamafactory/data/processors/feedback.py | 137 --------- src/llamafactory/data/processors/pairwise.py | 124 -------- src/llamafactory/data/processors/pretrain.py | 59 ---- .../data/processors/supervised.py | 226 --------------- .../data/processors/unsupervised.py | 107 ------- src/llamafactory/data/template.py | 169 +++++------ src/llamafactory/extras/constants.py | 2 +- .../test_feedback.py | 0 .../test_pairwise.py | 0 .../test_processor_utils.py | 2 +- .../test_supervised.py | 0 .../test_unsupervised.py | 0 tests/data/test_converter.py | 46 +++ 27 files changed, 1145 insertions(+), 1132 deletions(-) delete mode 100644 src/llamafactory/data/aligner.py create mode 100644 src/llamafactory/data/converter.py delete mode 100644 src/llamafactory/data/preprocess.py create mode 100644 src/llamafactory/data/processor/__init__.py create mode 100644 src/llamafactory/data/processor/feedback.py create mode 100644 src/llamafactory/data/processor/pairwise.py create mode 100644 src/llamafactory/data/processor/pretrain.py rename src/llamafactory/data/{processors => processor}/processor_utils.py (71%) create mode 100644 src/llamafactory/data/processor/supervised.py create mode 100644 src/llamafactory/data/processor/unsupervised.py delete mode 100644 src/llamafactory/data/processors/__init__.py delete mode 100644 src/llamafactory/data/processors/feedback.py delete mode 100644 src/llamafactory/data/processors/pairwise.py delete mode 100644 src/llamafactory/data/processors/pretrain.py delete mode 100644 src/llamafactory/data/processors/supervised.py delete mode 100644 src/llamafactory/data/processors/unsupervised.py rename tests/data/{processors => processor}/test_feedback.py (100%) rename tests/data/{processors => processor}/test_pairwise.py (100%) rename tests/data/{processors => processor}/test_processor_utils.py (94%) rename tests/data/{processors => processor}/test_supervised.py (100%) rename tests/data/{processors => processor}/test_unsupervised.py (100%) create mode 100644 tests/data/test_converter.py diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py deleted file mode 100644 index b71964b0..00000000 --- a/src/llamafactory/data/aligner.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from functools import partial -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union - -from ..extras import logging -from .data_utils import Role - - -if TYPE_CHECKING: - from datasets import Dataset, IterableDataset - from transformers import Seq2SeqTrainingArguments - - from ..hparams import DataArguments - from .parser import DatasetAttr - - -logger = logging.get_logger(__name__) - - -def _regularize_medias( - inputs: Union[Any, Sequence[Any]], - dataset_attr: "DatasetAttr", - data_args: "DataArguments", -) -> Optional[List[Any]]: - r""" - Optionally concatenates media path to media dir when loading from local disk. - """ - if not isinstance(inputs, list): - inputs = [inputs] - elif len(inputs) == 0: - return None - else: - inputs = inputs[:] - - if dataset_attr.load_from in ["script", "file"]: - for i in range(len(inputs)): - if isinstance(inputs[i], str) and os.path.isfile(os.path.join(data_args.media_dir, inputs[i])): - inputs[i] = os.path.join(data_args.media_dir, inputs[i]) - - return inputs - - -def convert_alpaca( - example: Dict[str, Any], - dataset_attr: "DatasetAttr", - data_args: "DataArguments", -) -> Dict[str, Any]: - r""" - Converts alpaca format dataset to the standard format. - """ - prompt = [] - if dataset_attr.history and isinstance(example[dataset_attr.history], list): - for old_prompt, old_response in example[dataset_attr.history]: - prompt.append({"role": Role.USER.value, "content": old_prompt}) - prompt.append({"role": Role.ASSISTANT.value, "content": old_response}) - - query = [] - if dataset_attr.prompt and example[dataset_attr.prompt]: - query.append(example[dataset_attr.prompt]) - - if dataset_attr.query and example[dataset_attr.query]: - query.append(example[dataset_attr.query]) - - prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery" - - if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example - response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}] - if example[dataset_attr.kto_tag]: - response = response + [{"role": Role.ASSISTANT.value, "content": ""}] - else: - response = [{"role": Role.ASSISTANT.value, "content": ""}] + response - elif ( - dataset_attr.ranking - and isinstance(example[dataset_attr.chosen], str) - and isinstance(example[dataset_attr.rejected], str) - ): # pairwise example - response = [ - {"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]}, - {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]}, - ] - elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example - response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}] - else: # unsupervised - response = [] - - regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args) - output = { - "_prompt": prompt, - "_response": response, - "_system": example[dataset_attr.system] if dataset_attr.system else "", - "_tools": example[dataset_attr.tools] if dataset_attr.tools else "", - "_images": 
regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None, - "_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None, - "_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None, - } - return output - - -def convert_sharegpt( - example: Dict[str, Any], - dataset_attr: "DatasetAttr", - data_args: "DataArguments", -) -> Dict[str, Any]: - r""" - Converts sharegpt format dataset to the standard format. - """ - tag_mapping = { - dataset_attr.user_tag: Role.USER.value, - dataset_attr.assistant_tag: Role.ASSISTANT.value, - dataset_attr.observation_tag: Role.OBSERVATION.value, - dataset_attr.function_tag: Role.FUNCTION.value, - dataset_attr.system_tag: Role.SYSTEM.value, - } - odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag) - even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) - accept_tags = (odd_tags, even_tags) - messages = example[dataset_attr.messages] - if ( - dataset_attr.system_tag - and len(messages) != 0 - and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag - ): - system = messages[0][dataset_attr.content_tag] - messages = messages[1:] - else: - system = example[dataset_attr.system] if dataset_attr.system else "" - - aligned_messages = [] - broken_data = False - for turn_idx, message in enumerate(messages): - if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: - logger.warning_rank0(f"Invalid role tag in {messages}.") - broken_data = True - break - - aligned_messages.append( - {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} - ) - - if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( - dataset_attr.ranking and len(aligned_messages) % 2 == 0 - ): - logger.warning_rank0(f"Invalid message count in {messages}.") - broken_data = True - - if broken_data: - logger.warning_rank0("Skipping this abnormal example.") - prompt, response = [], [] - elif dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example - prompt = aligned_messages[:-1] - response = aligned_messages[-1:] - if example[dataset_attr.kto_tag]: - response = response + [{"role": Role.ASSISTANT.value, "content": ""}] - else: - response = [{"role": Role.ASSISTANT.value, "content": ""}] + response - elif ( - dataset_attr.ranking - and isinstance(example[dataset_attr.chosen], dict) - and isinstance(example[dataset_attr.rejected], dict) - ): # pairwise example - chosen = example[dataset_attr.chosen] - rejected = example[dataset_attr.rejected] - if ( - chosen[dataset_attr.role_tag] not in accept_tags[-1] - or rejected[dataset_attr.role_tag] not in accept_tags[-1] - ): - logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.") - broken_data = True - - prompt = aligned_messages - response = [ - {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]}, - {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]}, - ] - else: # normal example - prompt = aligned_messages[:-1] - response = aligned_messages[-1:] - - regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args) - output = { - "_prompt": prompt, - "_response": response, - "_system": system, - "_tools": example[dataset_attr.tools] if dataset_attr.tools else "", - "_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None, - "_videos": 
regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None, - "_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None, - } - return output - - -def align_dataset( - dataset: Union["Dataset", "IterableDataset"], - dataset_attr: "DatasetAttr", - data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", -) -> Union["Dataset", "IterableDataset"]: - r""" - Aligned dataset: - _prompt: [{"role": "user", "content": "..."}] * (2T - 1) - _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) - _system: "..." - _tools: "...", - _images: [], - _videos: [], - _audios: [], - """ - if dataset_attr.formatting == "alpaca": - convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args) - else: - convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args) - - column_names = list(next(iter(dataset)).keys()) - kwargs = {} - if not data_args.streaming: - kwargs = dict( - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), - desc="Converting format of dataset", - ) - - return dataset.map( - convert_func, - batched=False, - remove_columns=column_names, - **kwargs, - ) diff --git a/src/llamafactory/data/converter.py b/src/llamafactory/data/converter.py new file mode 100644 index 00000000..2cab9b08 --- /dev/null +++ b/src/llamafactory/data/converter.py @@ -0,0 +1,269 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union + +from ..extras import logging +from .data_utils import Role + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import Seq2SeqTrainingArguments + + from ..hparams import DataArguments + from .parser import DatasetAttr + +logger = logging.get_logger(__name__) + + +@dataclass +class DatasetConverter: + dataset_attr: "DatasetAttr" + data_args: "DataArguments" + + def _find_medias(self, inputs: Union[Any, Sequence[Any]]) -> Optional[List[Any]]: + r""" + Optionally concatenates media path to media dir when loading from local disk. + """ + if not isinstance(inputs, list): + inputs = [inputs] + elif len(inputs) == 0: + return None + else: + inputs = inputs[:] + + if self.dataset_attr.load_from in ["script", "file"]: + for i in range(len(inputs)): + if isinstance(inputs[i], str) and os.path.isfile(os.path.join(self.data_args.media_dir, inputs[i])): + inputs[i] = os.path.join(self.data_args.media_dir, inputs[i]) + + return inputs + + @abstractmethod + def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: + r""" + Converts a single example in the dataset to the standard format. + """ + ... 
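The abstract `DatasetConverter` above, together with the `register_dataset_converter` and `get_dataset_converter` helpers added further down in this file, turns format conversion into a small registry. A minimal sketch of how a downstream user might plug in a custom format (the `QADatasetConverter` name and its `question`/`answer` columns are hypothetical, not part of this patch):

from llamafactory.data.converter import DatasetConverter, register_dataset_converter
from llamafactory.data.data_utils import Role


class QADatasetConverter(DatasetConverter):
    # Hypothetical converter: maps a {"question", "answer"} schema onto the standard keys.
    def __call__(self, example):
        return {
            "_prompt": [{"role": Role.USER.value, "content": example["question"]}],
            "_response": [{"role": Role.ASSISTANT.value, "content": example["answer"]}],
            "_system": "",
            "_tools": "",
            "_images": None,
            "_videos": None,
            "_audios": None,
        }


register_dataset_converter("qa", QADatasetConverter)
# align_dataset then resolves it by name, passing the parsed DatasetAttr and DataArguments:
# converter = get_dataset_converter("qa", dataset_attr, data_args)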
+ + +@dataclass +class AlpacaDatasetConverter(DatasetConverter): + def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: + prompt = [] + if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list): + for old_prompt, old_response in example[self.dataset_attr.history]: + prompt.append({"role": Role.USER.value, "content": old_prompt}) + prompt.append({"role": Role.ASSISTANT.value, "content": old_response}) + + query = [] + if self.dataset_attr.prompt and example[self.dataset_attr.prompt]: + query.append(example[self.dataset_attr.prompt]) + + if self.dataset_attr.query and example[self.dataset_attr.query]: + query.append(example[self.dataset_attr.query]) + + prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery" + + if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example + response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}] + if example[self.dataset_attr.kto_tag]: + response = response + [{"role": Role.ASSISTANT.value, "content": ""}] + else: + response = [{"role": Role.ASSISTANT.value, "content": ""}] + response + elif ( + self.dataset_attr.ranking + and isinstance(example[self.dataset_attr.chosen], str) + and isinstance(example[self.dataset_attr.rejected], str) + ): # pairwise example + response = [ + {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]}, + {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]}, + ] + elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str): # normal example + response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}] + else: # unsupervised + response = [] + + output = { + "_prompt": prompt, + "_response": response, + "_system": example[self.dataset_attr.system] if self.dataset_attr.system else "", + "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "", + "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, + "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, + "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, + } + return output + + +@dataclass +class SharegptDatasetConverter(DatasetConverter): + def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: + tag_mapping = { + self.dataset_attr.user_tag: Role.USER.value, + self.dataset_attr.assistant_tag: Role.ASSISTANT.value, + self.dataset_attr.observation_tag: Role.OBSERVATION.value, + self.dataset_attr.function_tag: Role.FUNCTION.value, + self.dataset_attr.system_tag: Role.SYSTEM.value, + } + odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag) + even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag) + accept_tags = (odd_tags, even_tags) + messages = example[self.dataset_attr.messages] + if ( + self.dataset_attr.system_tag + and len(messages) != 0 + and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag + ): + system = messages[0][self.dataset_attr.content_tag] + messages = messages[1:] + else: + system = example[self.dataset_attr.system] if self.dataset_attr.system else "" + + aligned_messages = [] + broken_data = False + for turn_idx, message in enumerate(messages): + if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: + logger.warning_rank0(f"Invalid role tag in {messages}.") + 
broken_data = True + break + + aligned_messages.append( + { + "role": tag_mapping[message[self.dataset_attr.role_tag]], + "content": message[self.dataset_attr.content_tag], + } + ) + + if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( + self.dataset_attr.ranking and len(aligned_messages) % 2 == 0 + ): + logger.warning_rank0(f"Invalid message count in {messages}.") + broken_data = True + + if broken_data: + logger.warning_rank0("Skipping this abnormal example.") + prompt, response = [], [] + elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + if example[self.dataset_attr.kto_tag]: + response = response + [{"role": Role.ASSISTANT.value, "content": ""}] + else: + response = [{"role": Role.ASSISTANT.value, "content": ""}] + response + elif ( + self.dataset_attr.ranking + and isinstance(example[self.dataset_attr.chosen], dict) + and isinstance(example[self.dataset_attr.rejected], dict) + ): # pairwise example + chosen = example[self.dataset_attr.chosen] + rejected = example[self.dataset_attr.rejected] + if ( + chosen[self.dataset_attr.role_tag] not in accept_tags[-1] + or rejected[self.dataset_attr.role_tag] not in accept_tags[-1] + ): + logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.") + broken_data = True + + prompt = aligned_messages + response = [ + { + "role": tag_mapping[chosen[self.dataset_attr.role_tag]], + "content": chosen[self.dataset_attr.content_tag], + }, + { + "role": tag_mapping[rejected[self.dataset_attr.role_tag]], + "content": rejected[self.dataset_attr.content_tag], + }, + ] + else: # normal example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + + output = { + "_prompt": prompt, + "_response": response, + "_system": system, + "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "", + "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, + "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, + "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, + } + return output + + +DATASET_CONVERTERS = { + "alpaca": AlpacaDatasetConverter, + "sharegpt": SharegptDatasetConverter, +} + + +def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None: + r""" + Register a new dataset converter. + """ + if name in DATASET_CONVERTERS: + raise ValueError(f"Dataset converter {name} already exists.") + + DATASET_CONVERTERS[name] = dataset_converter + + +def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter": + r""" + Gets a dataset converter. + """ + if name not in DATASET_CONVERTERS: + raise ValueError(f"Dataset converter {name} not found.") + + return DATASET_CONVERTERS[name](dataset_attr, data_args) + + +def align_dataset( + dataset: Union["Dataset", "IterableDataset"], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Union["Dataset", "IterableDataset"]: + r""" + Aligned dataset: + _prompt: [{"role": "user", "content": "..."}] * (2T - 1) + _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) + _system: "..." 
+ _tools: "...", + _images: [], + _videos: [], + _audios: [], + """ + + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), + desc="Converting format of dataset", + ) + + dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args) + return dataset.map( + dataset_converter, + batched=False, + remove_columns=column_names, + **kwargs, + ) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 7e972c88..079fe4d0 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -22,10 +22,17 @@ from datasets import DatasetDict, load_dataset, load_from_disk from ..extras import logging from ..extras.constants import FILEEXT2TYPE from ..extras.misc import check_version, has_tokenized_data -from .aligner import align_dataset +from .converter import align_dataset from .data_utils import merge_dataset, split_dataset from .parser import get_dataset_list -from .preprocess import get_preprocess_and_print_func +from .processor import ( + FeedbackDatasetProcessor, + PackedSupervisedDatasetProcessor, + PairwiseDatasetProcessor, + PretrainDatasetProcessor, + SupervisedDatasetProcessor, + UnsupervisedDatasetProcessor, +) if TYPE_CHECKING: @@ -35,6 +42,7 @@ if TYPE_CHECKING: from ..hparams import DataArguments, ModelArguments from .data_utils import DatasetModule from .parser import DatasetAttr + from .processor import DatasetProcessor from .template import Template @@ -158,7 +166,7 @@ def _get_merged_dataset( stage: Literal["pt", "sft", "rm", "ppo", "kto"], ) -> Optional[Union["Dataset", "IterableDataset"]]: r""" - Gets the merged datasets in the standard format. + Returns the merged datasets in the standard format. """ if dataset_names is None: return None @@ -173,6 +181,48 @@ def _get_merged_dataset( return merge_dataset(datasets, data_args, seed=training_args.seed) +def _get_dataset_processor( + data_args: "DataArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + do_generate: bool = False, +) -> "DatasetProcessor": + r""" + Returns the corresponding dataset processor. 
+ """ + if stage == "pt": + dataset_processor_class = PretrainDatasetProcessor + elif stage == "sft" and not do_generate: + if data_args.packing: + if data_args.neat_packing: # hack datasets to have int32 attention mask + from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence + + def __init__(self, data, **kwargs): + return TypedSequence.__init__( + self, + data, + type=kwargs.pop("type", None), + try_type=kwargs.pop("try_type", None), + optimized_int_type=kwargs.pop("optimized_int_type", None), + ) + + OptimizedTypedSequence.__init__ = __init__ + dataset_processor_class = PackedSupervisedDatasetProcessor + else: + dataset_processor_class = SupervisedDatasetProcessor + + elif stage == "rm": + dataset_processor_class = PairwiseDatasetProcessor + elif stage == "kto": + dataset_processor_class = FeedbackDatasetProcessor + else: + dataset_processor_class = UnsupervisedDatasetProcessor + + return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args) + + def _get_preprocessed_dataset( dataset: Optional[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", @@ -189,7 +239,7 @@ def _get_preprocessed_dataset( if dataset is None: return None - preprocess_func, print_function = get_preprocess_and_print_func( + dataset_processor = _get_dataset_processor( data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval) ) column_names = list(next(iter(dataset)).keys()) @@ -202,7 +252,7 @@ def _get_preprocessed_dataset( ) dataset = dataset.map( - preprocess_func, + dataset_processor.preprocess_dataset, batched=True, batch_size=data_args.preprocessing_batch_size, remove_columns=column_names, @@ -212,7 +262,7 @@ def _get_preprocessed_dataset( if training_args.should_log: try: print("eval example:" if is_eval else "training example:") - print_function(next(iter(dataset))) + dataset_processor.print_data_example(next(iter(dataset))) except StopIteration: if stage == "pt": raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.") diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 2430a74d..8d69b5a2 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -4,7 +4,7 @@ import re from copy import deepcopy from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union import numpy as np import torch @@ -1241,14 +1241,26 @@ PLUGINS = { } +def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None: + r""" + Registers a multimodal plugin. + """ + if name in PLUGINS: + raise ValueError(f"Multimodal plugin {name} already exists.") + + PLUGINS[name] = plugin_class + + def get_mm_plugin( name: str, image_token: Optional[str] = None, video_token: Optional[str] = None, audio_token: Optional[str] = None, ) -> "BasePlugin": - plugin_class = PLUGINS.get(name, None) - if plugin_class is None: + r""" + Gets plugin for multimodal inputs. 
+ """ + if name not in PLUGINS: raise ValueError(f"Multimodal plugin `{name}` not found.") - return plugin_class(image_token, video_token, audio_token) + return PLUGINS[name](image_token, video_token, audio_token) diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index c2cb8dd7..ac6bc932 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -45,7 +45,7 @@ class DatasetAttr: images: Optional[str] = None videos: Optional[str] = None audios: Optional[str] = None - # rlhf columns + # dpo columns chosen: Optional[str] = None rejected: Optional[str] = None kto_tag: Optional[str] = None @@ -71,6 +71,26 @@ class DatasetAttr: def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: setattr(self, key, obj.get(key, default)) + def join(self, attr: Dict[str, Any]) -> None: + self.set_attr("formatting", attr, default="alpaca") + self.set_attr("ranking", attr, default=False) + self.set_attr("subset", attr) + self.set_attr("split", attr, default="train") + self.set_attr("folder", attr) + self.set_attr("num_samples", attr) + + if "columns" in attr: + column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"] + column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"] + for column_name in column_names: + self.set_attr(column_name, attr["columns"]) + + if "tags" in attr: + tag_names = ["role_tag", "content_tag"] + tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"] + for tag in tag_names: + self.set_attr(tag, attr["tags"]) + def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: r""" @@ -128,36 +148,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - else: dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) - dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") - dataset_attr.set_attr("ranking", dataset_info[name], default=False) - dataset_attr.set_attr("subset", dataset_info[name]) - dataset_attr.set_attr("split", dataset_info[name], default="train") - dataset_attr.set_attr("folder", dataset_info[name]) - dataset_attr.set_attr("num_samples", dataset_info[name]) - - if "columns" in dataset_info[name]: - column_names = ["system", "tools", "images", "videos", "audios", "chosen", "rejected", "kto_tag"] - if dataset_attr.formatting == "alpaca": - column_names.extend(["prompt", "query", "response", "history"]) - else: - column_names.extend(["messages"]) - - for column_name in column_names: - dataset_attr.set_attr(column_name, dataset_info[name]["columns"]) - - if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: - tag_names = ( - "role_tag", - "content_tag", - "user_tag", - "assistant_tag", - "observation_tag", - "function_tag", - "system_tag", - ) - for tag in tag_names: - dataset_attr.set_attr(tag, dataset_info[name]["tags"]) - + dataset_attr.join(dataset_info[name]) dataset_list.append(dataset_attr) return dataset_list diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py deleted file mode 100644 index 27363791..00000000 --- a/src/llamafactory/data/preprocess.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple - -from .processors.feedback import preprocess_feedback_dataset -from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example -from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example -from .processors.supervised import ( - preprocess_packed_supervised_dataset, - preprocess_supervised_dataset, - print_supervised_dataset_example, -) -from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example - - -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, ProcessorMixin - - from ..hparams import DataArguments - from .template import Template - - -def get_preprocess_and_print_func( - data_args: "DataArguments", - stage: Literal["pt", "sft", "rm", "ppo", "kto"], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - do_generate: bool = False, -) -> Tuple[Callable, Callable]: - if stage == "pt": - preprocess_func = partial( - preprocess_pretrain_dataset, - tokenizer=tokenizer, - data_args=data_args, - ) - print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer) - elif stage == "sft" and not do_generate: - if data_args.packing: - if data_args.neat_packing: # hack datasets to have int32 attention mask - from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence - - def __init__(self, data, **kwargs): - return TypedSequence.__init__( - self, - data, - type=kwargs.pop("type", None), - try_type=kwargs.pop("try_type", None), - optimized_int_type=kwargs.pop("optimized_int_type", None), - ) - - OptimizedTypedSequence.__init__ = __init__ - preprocess_func = partial( - preprocess_packed_supervised_dataset, - template=template, - tokenizer=tokenizer, - processor=processor, - data_args=data_args, - ) - else: - preprocess_func = partial( - preprocess_supervised_dataset, - template=template, - tokenizer=tokenizer, - processor=processor, - data_args=data_args, - ) - - print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) - elif stage == "rm": - preprocess_func = partial( - preprocess_pairwise_dataset, - template=template, - tokenizer=tokenizer, - processor=processor, - data_args=data_args, - ) - print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) - elif stage == "kto": - preprocess_func = partial( - preprocess_feedback_dataset, - template=template, - tokenizer=tokenizer, - processor=processor, - data_args=data_args, - ) - print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) - else: - preprocess_func = partial( - preprocess_unsupervised_dataset, - template=template, - tokenizer=tokenizer, - processor=processor, - data_args=data_args, - ) - print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) - - return preprocess_func, print_function diff --git a/src/llamafactory/data/processor/__init__.py b/src/llamafactory/data/processor/__init__.py new file mode 100644 index 00000000..fa82a88e --- /dev/null 
+++ b/src/llamafactory/data/processor/__init__.py @@ -0,0 +1,17 @@ +from .feedback import FeedbackDatasetProcessor +from .pairwise import PairwiseDatasetProcessor +from .pretrain import PretrainDatasetProcessor +from .processor_utils import DatasetProcessor +from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor +from .unsupervised import UnsupervisedDatasetProcessor + + +__all__ = [ + "DatasetProcessor", + "FeedbackDatasetProcessor", + "PairwiseDatasetProcessor", + "PretrainDatasetProcessor", + "PackedSupervisedDatasetProcessor", + "SupervisedDatasetProcessor", + "UnsupervisedDatasetProcessor", +] diff --git a/src/llamafactory/data/processor/feedback.py b/src/llamafactory/data/processor/feedback.py new file mode 100644 index 00000000..fb3c4803 --- /dev/null +++ b/src/llamafactory/data/processor/feedback.py @@ -0,0 +1,129 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + +from ...extras import logging +from ...extras.constants import IGNORE_INDEX +from .processor_utils import DatasetProcessor, infer_seqlen + + +if TYPE_CHECKING: + from ..mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +class FeedbackDatasetProcessor(DatasetProcessor): + def _encode_data_example( + self, + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + kl_response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + audios: Sequence["AudioInput"], + ) -> Tuple[List[int], List[int], List[int], List[int], bool]: + if response[0]["content"]: # desired example + kto_tag = True + messages = prompt + [response[0]] + else: # undesired example + kto_tag = False + messages = prompt + [response[1]] + + if kl_response[0]["content"]: + kl_messages = prompt + [kl_response[0]] + else: + kl_messages = prompt + [kl_response[1]] + + messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor) + kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor) + prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools) + kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools) + + if self.template.efficient_eos: + response_ids += [self.tokenizer.eos_token_id] + kl_response_ids += [self.tokenizer.eos_token_id] + + prompt_ids, _ = self.template.mm_plugin.process_token_ids( + prompt_ids, None, images, videos, audios, self.tokenizer, self.processor + ) + kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids( + kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor + ) + + source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len) + prompt_ids = 
prompt_ids[:source_len] + response_ids = response_ids[:target_len] + kl_source_len, kl_target_len = infer_seqlen( + len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len + ) + kl_prompt_ids = kl_prompt_ids[:kl_source_len] + kl_response_ids = kl_response_ids[:kl_target_len] + + input_ids = prompt_ids + response_ids + labels = [IGNORE_INDEX] * source_len + response_ids + kl_input_ids = kl_prompt_ids + kl_response_ids + kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids + return input_ids, labels, kl_input_ids, kl_labels, kto_tag + + def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs + kl_response = examples["_response"][::-1] + model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + kl_response=kl_response[i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["kl_input_ids"].append(kl_input_ids) + model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) + model_inputs["kl_labels"].append(kl_labels) + model_inputs["kto_tags"].append(kto_tag) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + model_inputs["audios"].append(examples["_audios"][i]) + + desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) + undesirable_num = len(model_inputs["kto_tags"]) - desirable_num + if desirable_num == 0 or undesirable_num == 0: + logger.warning_rank0("Your dataset only has one preference type.") + + return model_inputs + + def print_data_example(self, example: Dict[str, List[int]]) -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}") diff --git a/src/llamafactory/data/processor/pairwise.py b/src/llamafactory/data/processor/pairwise.py new file mode 100644 index 00000000..f30ebbf8 --- /dev/null +++ b/src/llamafactory/data/processor/pairwise.py @@ -0,0 +1,118 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + +from ...extras import logging +from ...extras.constants import IGNORE_INDEX +from .processor_utils import DatasetProcessor, infer_seqlen + + +if TYPE_CHECKING: + from ..mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +class PairwiseDatasetProcessor(DatasetProcessor): + def _encode_data_example( + self, + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + audios: Sequence["AudioInput"], + ) -> Tuple[List[int], List[int], List[int], List[int]]: + chosen_messages = self.template.mm_plugin.process_messages( + prompt + [response[0]], images, videos, audios, self.processor + ) + rejected_messages = self.template.mm_plugin.process_messages( + prompt + [response[1]], images, videos, audios, self.processor + ) + prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools) + _, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools) + + if self.template.efficient_eos: + chosen_ids += [self.tokenizer.eos_token_id] + rejected_ids += [self.tokenizer.eos_token_id] + + prompt_ids, _ = self.template.mm_plugin.process_token_ids( + prompt_ids, None, images, videos, audios, self.tokenizer, self.processor + ) + # consider the response is more important + source_len, target_len = infer_seqlen( + len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len + ) + prompt_ids = prompt_ids[:source_len] + chosen_ids = chosen_ids[:target_len] + rejected_ids = rejected_ids[:target_len] + + chosen_input_ids = prompt_ids + chosen_ids + chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids + rejected_input_ids = prompt_ids + rejected_ids + rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids + return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels + + def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + model_inputs["chosen_input_ids"].append(chosen_input_ids) + model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) + model_inputs["chosen_labels"].append(chosen_labels) + model_inputs["rejected_input_ids"].append(rejected_input_ids) + model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) + model_inputs["rejected_labels"].append(rejected_labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + model_inputs["audios"].append(examples["_audios"][i]) + + return model_inputs + + def print_data_example(self, example: Dict[str, List[int]]) -> None: + 
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"])) + valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) + print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) + print( + "chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)) + ) + print("chosen_label_ids:\n{}".format(example["chosen_labels"])) + print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}") + print("rejected_input_ids:\n{}".format(example["rejected_input_ids"])) + print( + "rejected_inputs:\n{}".format( + self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False) + ) + ) + print("rejected_label_ids:\n{}".format(example["rejected_labels"])) + print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}") diff --git a/src/llamafactory/data/processor/pretrain.py b/src/llamafactory/data/processor/pretrain.py new file mode 100644 index 00000000..87e35ad1 --- /dev/null +++ b/src/llamafactory/data/processor/pretrain.py @@ -0,0 +1,57 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from itertools import chain +from typing import Any, Dict, List + +from .processor_utils import DatasetProcessor + + +@dataclass +class PretrainDatasetProcessor(DatasetProcessor): + def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + # build grouped texts with format `X1 X2 X3 ...` if packing is enabled + eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token + text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] + + if not self.data_args.packing: + if getattr(self.tokenizer, "add_bos_token", False): + text_examples = [self.tokenizer.bos_token + example for example in text_examples] + + result = self.tokenizer( + text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len + ) + else: + tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False) + concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} + total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) + block_size = self.data_args.cutoff_len + total_length = (total_length // block_size) * block_size + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + if getattr(self.tokenizer, "add_bos_token", False): + for i in range(len(result["input_ids"])): + result["input_ids"][i][0] = self.tokenizer.bos_token_id + + return result + + def print_data_example(self, example: Dict[str, List[int]]) -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processor/processor_utils.py similarity index 71% rename from src/llamafactory/data/processors/processor_utils.py rename to src/llamafactory/data/processor/processor_utils.py index 95198623..9e5cb086 100644 --- a/src/llamafactory/data/processors/processor_utils.py +++ b/src/llamafactory/data/processor/processor_utils.py @@ -13,7 +13,42 @@ # limitations under the License. import bisect -from typing import List, Sequence, Tuple +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, ProcessorMixin + + from ...hparams import DataArguments + from ..template import Template + + +@dataclass +class DatasetProcessor(ABC): + r""" + A class for data processors. + """ + + template: "Template" + tokenizer: "PreTrainedTokenizer" + processor: Optional["ProcessorMixin"] + data_args: "DataArguments" + + @abstractmethod + def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + r""" + Builds model inputs from the examples. + """ + ... + + @abstractmethod + def print_data_example(self, example: Dict[str, List[int]]) -> None: + r""" + Print a data example to stdout. + """ + ... def search_for_fit(numbers: Sequence[int], capacity: int) -> int: diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py new file mode 100644 index 00000000..e83de97b --- /dev/null +++ b/src/llamafactory/data/processor/supervised.py @@ -0,0 +1,200 @@ +# Copyright 2025 the LlamaFactory team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + +from ...extras import logging +from ...extras.constants import IGNORE_INDEX +from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen + + +if TYPE_CHECKING: + from ..mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +@dataclass +class SupervisedDatasetProcessor(DatasetProcessor): + def _encode_data_example( + self, + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + audios: Sequence["AudioInput"], + ) -> Tuple[List[int], List[int]]: + messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor) + input_ids, labels = self.template.mm_plugin.process_token_ids( + [], [], images, videos, audios, self.tokenizer, self.processor + ) + encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools) + total_length = len(input_ids) + (1 if self.template.efficient_eos else 0) + if self.data_args.mask_history: + encoded_pairs = encoded_pairs[::-1] # high priority for last turns + + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): + if total_length >= self.data_args.cutoff_len: + break + + source_len, target_len = infer_seqlen( + len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length + ) + source_ids = source_ids[:source_len] + target_ids = target_ids[:target_len] + total_length += source_len + target_len + + if self.data_args.train_on_prompt: + source_label = source_ids + elif self.template.efficient_eos: + source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) + else: + source_label = [IGNORE_INDEX] * source_len + + if self.data_args.mask_history and turn_idx != 0: # train on the last turn only + target_label = [IGNORE_INDEX] * target_len + else: + target_label = target_ids + + if self.data_args.mask_history: # reversed sequences + input_ids = source_ids + target_ids + input_ids + labels = source_label + target_label + labels + else: + input_ids += source_ids + target_ids + labels += source_label + target_label + + if self.template.efficient_eos: + input_ids += [self.tokenizer.eos_token_id] + labels += [self.tokenizer.eos_token_id] + + return input_ids, labels + + def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. 
+ model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + model_inputs["audios"].append(examples["_audios"][i]) + + return model_inputs + + def print_data_example(self, example: Dict[str, List[int]]) -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}") + + +@dataclass +class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): + def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + # TODO: use `position_ids` to achieve packing + # build inputs with format ` X1 Y1 X2 Y2 ` + # and labels with format ` ... Y1 ... Y2 ` + valid_num = 0 + batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], [] + lengths = [] + length2indexes = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + length = len(input_ids) + if length > self.data_args.cutoff_len: + logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.") + else: + lengths.append(length) + length2indexes[length].append(valid_num) + batch_input_ids.append(input_ids) + batch_labels.append(labels) + batch_images.append(examples["_images"][i] or []) + batch_videos.append(examples["_videos"][i] or []) + batch_audios.append(examples["_audios"][i] or []) + valid_num += 1 + + model_inputs = defaultdict(list) + knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len) + for knapsack in knapsacks: + packed_input_ids, packed_attention_masks, packed_labels = [], [], [] + packed_images, packed_videos, packed_audios = [], [], [] + for i, length in enumerate(knapsack): + index = length2indexes[length].pop() + packed_input_ids += batch_input_ids[index] + packed_labels += batch_labels[index] + packed_images += batch_images[index] + packed_videos += batch_videos[index] + packed_audios += batch_audios[index] + if self.data_args.neat_packing: + packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 
1 + else: + packed_attention_masks += [1] * len(batch_input_ids[index]) + + if len(packed_input_ids) < self.data_args.cutoff_len + 1: # avoid flash_attn drops attn mask + pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1 + packed_input_ids += [self.tokenizer.pad_token_id] * pad_length + packed_labels += [IGNORE_INDEX] * pad_length + if self.data_args.neat_packing: + packed_attention_masks += [0] * pad_length + else: + packed_attention_masks += [1] * pad_length # more efficient flash_attn + + if len(packed_input_ids) != self.data_args.cutoff_len + 1: + raise ValueError("The length of packed example should be identical to the cutoff length.") + + model_inputs["input_ids"].append(packed_input_ids) + model_inputs["attention_mask"].append(packed_attention_masks) + model_inputs["labels"].append(packed_labels) + model_inputs["images"].append(packed_images or None) + model_inputs["videos"].append(packed_videos or None) + model_inputs["audios"].append(packed_audios or None) + + return model_inputs diff --git a/src/llamafactory/data/processor/unsupervised.py b/src/llamafactory/data/processor/unsupervised.py new file mode 100644 index 00000000..38a0b442 --- /dev/null +++ b/src/llamafactory/data/processor/unsupervised.py @@ -0,0 +1,91 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + +from ...extras import logging +from ..data_utils import Role +from .processor_utils import DatasetProcessor, infer_seqlen + + +if TYPE_CHECKING: + from ..mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +class UnsupervisedDatasetProcessor(DatasetProcessor): + def _encode_data_example( + self, + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + audios: Sequence["AudioInput"], + ) -> Tuple[List[int], List[int]]: + if len(response) == 1: + messages = prompt + response + else: + messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] + + messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor) + input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools) + if self.template.efficient_eos: + labels += [self.tokenizer.eos_token_id] + + input_ids, _ = self.template.mm_plugin.process_token_ids( + input_ids, None, images, videos, audios, self.tokenizer, self.processor + ) + source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len) + input_ids = input_ids[:source_len] + labels = labels[:target_len] + return input_ids, labels + + def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + # build inputs with format ` X` and labels with format `Y ` + model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + model_inputs["audios"].append(examples["_audios"][i]) + + return model_inputs + + def print_data_example(self, example: Dict[str, List[int]]) -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False))) diff --git a/src/llamafactory/data/processors/__init__.py b/src/llamafactory/data/processors/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py deleted file mode 100644 index 2451fefb..00000000 --- a/src/llamafactory/data/processors/feedback.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple - -from ...extras import logging -from ...extras.constants import IGNORE_INDEX -from .processor_utils import infer_seqlen - - -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, ProcessorMixin - - from ...hparams import DataArguments - from ..mm_plugin import AudioInput, ImageInput, VideoInput - from ..template import Template - - -logger = logging.get_logger(__name__) - - -def _encode_feedback_example( - prompt: Sequence[Dict[str, str]], - response: Sequence[Dict[str, str]], - kl_response: Sequence[Dict[str, str]], - system: Optional[str], - tools: Optional[str], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - cutoff_len: int, -) -> Tuple[List[int], List[int], List[int], List[int], bool]: - if response[0]["content"]: # desired example - kto_tag = True - messages = prompt + [response[0]] - else: # undesired example - kto_tag = False - messages = prompt + [response[1]] - - if kl_response[0]["content"]: - kl_messages = prompt + [kl_response[0]] - else: - kl_messages = prompt + [kl_response[1]] - - messages = template.mm_plugin.process_messages(messages, images, videos, audios, processor) - kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, audios, processor) - prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) - kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) - - if template.efficient_eos: - response_ids += [tokenizer.eos_token_id] - kl_response_ids += [tokenizer.eos_token_id] - - prompt_ids, _ = template.mm_plugin.process_token_ids( - prompt_ids, None, images, videos, audios, tokenizer, processor - ) - kl_prompt_ids, _ = template.mm_plugin.process_token_ids( - kl_prompt_ids, None, images, videos, audios, tokenizer, processor - ) - - source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len) - prompt_ids = prompt_ids[:source_len] - response_ids = response_ids[:target_len] - kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len) - kl_prompt_ids = kl_prompt_ids[:kl_source_len] - kl_response_ids = kl_response_ids[:kl_target_len] - - input_ids = prompt_ids + response_ids - labels = [IGNORE_INDEX] * source_len + response_ids - kl_input_ids = kl_prompt_ids + kl_response_ids - kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids - return input_ids, labels, kl_input_ids, kl_labels, kto_tag - - -def preprocess_feedback_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[Any]]: - # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs - kl_response = examples["_response"][::-1] - model_inputs = defaultdict(list) - for i in 
range(len(examples["_prompt"])): - if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: - logger.warning_rank0( - "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) - ) - continue - - input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( - prompt=examples["_prompt"][i], - response=examples["_response"][i], - kl_response=kl_response[i], - system=examples["_system"][i], - tools=examples["_tools"][i], - images=examples["_images"][i] or [], - videos=examples["_videos"][i] or [], - audios=examples["_audios"][i] or [], - template=template, - tokenizer=tokenizer, - processor=processor, - cutoff_len=data_args.cutoff_len, - ) - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - model_inputs["kl_input_ids"].append(kl_input_ids) - model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) - model_inputs["kl_labels"].append(kl_labels) - model_inputs["kto_tags"].append(kto_tag) - model_inputs["images"].append(examples["_images"][i]) - model_inputs["videos"].append(examples["_videos"][i]) - model_inputs["audios"].append(examples["_audios"][i]) - - desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) - undesirable_num = len(model_inputs["kto_tags"]) - desirable_num - if desirable_num == 0 or undesirable_num == 0: - logger.warning_rank0("Your dataset only has one preference type.") - - return model_inputs diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py deleted file mode 100644 index 2445de69..00000000 --- a/src/llamafactory/data/processors/pairwise.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple - -from ...extras import logging -from ...extras.constants import IGNORE_INDEX -from .processor_utils import infer_seqlen - - -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, ProcessorMixin - - from ...hparams import DataArguments - from ..mm_plugin import AudioInput, ImageInput, VideoInput - from ..template import Template - - -logger = logging.get_logger(__name__) - - -def _encode_pairwise_example( - prompt: Sequence[Dict[str, str]], - response: Sequence[Dict[str, str]], - system: Optional[str], - tools: Optional[str], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - cutoff_len: int, -) -> Tuple[List[int], List[int], List[int], List[int]]: - chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, audios, processor) - rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, audios, processor) - prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) - _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools) - - if template.efficient_eos: - chosen_ids += [tokenizer.eos_token_id] - rejected_ids += [tokenizer.eos_token_id] - - prompt_ids, _ = template.mm_plugin.process_token_ids( - prompt_ids, None, images, videos, audios, tokenizer, processor - ) - # consider the response is more important - source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len) - prompt_ids = prompt_ids[:source_len] - chosen_ids = chosen_ids[:target_len] - rejected_ids = rejected_ids[:target_len] - - chosen_input_ids = prompt_ids + chosen_ids - chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids - rejected_input_ids = prompt_ids + rejected_ids - rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids - return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels - - -def preprocess_pairwise_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[Any]]: - # build input pairs with format ` X`, `Y1 ` and `Y2 ` - model_inputs = defaultdict(list) - for i in range(len(examples["_prompt"])): - if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: - logger.warning_rank0( - "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) - ) - continue - - chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( - prompt=examples["_prompt"][i], - response=examples["_response"][i], - system=examples["_system"][i], - tools=examples["_tools"][i], - images=examples["_images"][i] or [], - videos=examples["_videos"][i] or [], - audios=examples["_audios"][i] or [], - template=template, - tokenizer=tokenizer, - processor=processor, - cutoff_len=data_args.cutoff_len, - ) - model_inputs["chosen_input_ids"].append(chosen_input_ids) - model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) - model_inputs["chosen_labels"].append(chosen_labels) - model_inputs["rejected_input_ids"].append(rejected_input_ids) - model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) - 
model_inputs["rejected_labels"].append(rejected_labels) - model_inputs["images"].append(examples["_images"][i]) - model_inputs["videos"].append(examples["_videos"][i]) - model_inputs["audios"].append(examples["_audios"][i]) - - return model_inputs - - -def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"])) - valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) - print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) - print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))) - print("chosen_label_ids:\n{}".format(example["chosen_labels"])) - print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}") - print("rejected_input_ids:\n{}".format(example["rejected_input_ids"])) - print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False))) - print("rejected_label_ids:\n{}".format(example["rejected_labels"])) - print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}") diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py deleted file mode 100644 index c8fa3b26..00000000 --- a/src/llamafactory/data/processors/pretrain.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. -# -# This code is inspired by the HuggingFace's transformers library. -# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, List - - -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer - - from ...hparams import DataArguments - - -def preprocess_pretrain_dataset( - examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" -) -> Dict[str, List[Any]]: - # build grouped texts with format `X1 X2 X3 ...` if packing is enabled - eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token - text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] - - if not data_args.packing: - if getattr(tokenizer, "add_bos_token", False): - text_examples = [tokenizer.bos_token + example for example in text_examples] - - result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len) - else: - tokenized_examples = tokenizer(text_examples, add_special_tokens=False) - concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} - total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) - block_size = data_args.cutoff_len - total_length = (total_length // block_size) * block_size - result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - if getattr(tokenizer, "add_bos_token", False): - for i in range(len(result["input_ids"])): - result["input_ids"][i][0] = tokenizer.bos_token_id - - return result - - -def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py deleted file mode 100644 index 771b3cba..00000000 --- a/src/llamafactory/data/processors/supervised.py +++ /dev/null @@ -1,226 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple - -from ...extras import logging -from ...extras.constants import IGNORE_INDEX -from .processor_utils import greedy_knapsack, infer_seqlen - - -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, ProcessorMixin - - from ...hparams import DataArguments - from ..mm_plugin import AudioInput, ImageInput, VideoInput - from ..template import Template - - -logger = logging.get_logger(__name__) - - -def _encode_supervised_example( - prompt: Sequence[Dict[str, str]], - response: Sequence[Dict[str, str]], - system: Optional[str], - tools: Optional[str], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - cutoff_len: int, - train_on_prompt: bool, - mask_history: bool, -) -> Tuple[List[int], List[int]]: - messages = template.mm_plugin.process_messages(prompt + response, images, videos, audios, processor) - input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, audios, tokenizer, processor) - encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) - total_length = len(input_ids) + (1 if template.efficient_eos else 0) - if mask_history: - encoded_pairs = encoded_pairs[::-1] # high priority for last turns - - for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): - if total_length >= cutoff_len: - break - - source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length) - source_ids = source_ids[:source_len] - target_ids = target_ids[:target_len] - total_length += source_len + target_len - - if train_on_prompt: - source_label = source_ids - elif template.efficient_eos: - source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) - else: - source_label = [IGNORE_INDEX] * source_len - - if mask_history and turn_idx != 0: # train on the last turn only - target_label = [IGNORE_INDEX] * target_len - else: - target_label = target_ids - - if mask_history: # reversed sequences - input_ids = source_ids + target_ids + input_ids - labels = source_label + target_label + labels - else: - input_ids += source_ids + target_ids - labels += source_label + target_label - - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] - - return input_ids, labels - - -def preprocess_supervised_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[Any]]: - # build inputs with format ` X Y ` and labels with format ` ... Y ` - # for multiturn examples, we only mask the prompt part in each prompt-response pair. 
- model_inputs = defaultdict(list) - for i in range(len(examples["_prompt"])): - if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: - logger.warning_rank0( - "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) - ) - continue - - input_ids, labels = _encode_supervised_example( - prompt=examples["_prompt"][i], - response=examples["_response"][i], - system=examples["_system"][i], - tools=examples["_tools"][i], - images=examples["_images"][i] or [], - videos=examples["_videos"][i] or [], - audios=examples["_audios"][i] or [], - template=template, - tokenizer=tokenizer, - processor=processor, - cutoff_len=data_args.cutoff_len, - train_on_prompt=data_args.train_on_prompt, - mask_history=data_args.mask_history, - ) - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - model_inputs["images"].append(examples["_images"][i]) - model_inputs["videos"].append(examples["_videos"][i]) - model_inputs["audios"].append(examples["_audios"][i]) - - return model_inputs - - -def preprocess_packed_supervised_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[Any]]: - # TODO: use `position_ids` to achieve packing - # build inputs with format ` X1 Y1 X2 Y2 ` - # and labels with format ` ... Y1 ... Y2 ` - valid_num = 0 - batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], [] - lengths = [] - length2indexes = defaultdict(list) - for i in range(len(examples["_prompt"])): - if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: - logger.warning_rank0( - "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) - ) - continue - - input_ids, labels = _encode_supervised_example( - prompt=examples["_prompt"][i], - response=examples["_response"][i], - system=examples["_system"][i], - tools=examples["_tools"][i], - images=examples["_images"][i] or [], - videos=examples["_videos"][i] or [], - audios=examples["_audios"][i] or [], - template=template, - tokenizer=tokenizer, - processor=processor, - cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token - train_on_prompt=data_args.train_on_prompt, - mask_history=data_args.mask_history, - ) - length = len(input_ids) - if length > data_args.cutoff_len: - logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.") - else: - lengths.append(length) - length2indexes[length].append(valid_num) - batch_input_ids.append(input_ids) - batch_labels.append(labels) - batch_images.append(examples["_images"][i] or []) - batch_videos.append(examples["_videos"][i] or []) - batch_audios.append(examples["_audios"][i] or []) - valid_num += 1 - - model_inputs = defaultdict(list) - knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token - for knapsack in knapsacks: - packed_input_ids, packed_attention_masks, packed_labels = [], [], [] - packed_images, packed_videos, packed_audios = [], [], [] - for i, length in enumerate(knapsack): - index = length2indexes[length].pop() - packed_input_ids += batch_input_ids[index] - packed_labels += batch_labels[index] - packed_images += batch_images[index] - packed_videos += batch_videos[index] - packed_audios += batch_audios[index] - if data_args.neat_packing: - packed_attention_masks += 
[i + 1] * len(batch_input_ids[index]) # start from 1 - else: - packed_attention_masks += [1] * len(batch_input_ids[index]) - - if len(packed_input_ids) < data_args.cutoff_len: - pad_length = data_args.cutoff_len - len(packed_input_ids) - packed_input_ids += [tokenizer.pad_token_id] * pad_length - packed_labels += [IGNORE_INDEX] * pad_length - if data_args.neat_packing: - packed_attention_masks += [0] * pad_length - else: - packed_attention_masks += [1] * pad_length # more efficient flash_attn - - if len(packed_input_ids) != data_args.cutoff_len: - raise ValueError("The length of packed example should be identical to the cutoff length.") - - model_inputs["input_ids"].append(packed_input_ids) - model_inputs["attention_mask"].append(packed_attention_masks) - model_inputs["labels"].append(packed_labels) - model_inputs["images"].append(packed_images or None) - model_inputs["videos"].append(packed_videos or None) - model_inputs["audios"].append(packed_audios or None) - - return model_inputs - - -def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) - print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) - print("label_ids:\n{}".format(example["labels"])) - print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}") diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py deleted file mode 100644 index 4e62d5c9..00000000 --- a/src/llamafactory/data/processors/unsupervised.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple - -from ...extras import logging -from ..data_utils import Role -from .processor_utils import infer_seqlen - - -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, ProcessorMixin - - from ...hparams import DataArguments - from ..mm_plugin import AudioInput, ImageInput, VideoInput - from ..template import Template - - -logger = logging.get_logger(__name__) - - -def _encode_unsupervised_example( - prompt: Sequence[Dict[str, str]], - response: Sequence[Dict[str, str]], - system: Optional[str], - tools: Optional[str], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - cutoff_len: int, -) -> Tuple[List[int], List[int]]: - if len(response) == 1: - messages = prompt + response - else: - messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] - - messages = template.mm_plugin.process_messages(messages, images, videos, audios, processor) - input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools) - if template.efficient_eos: - labels += [tokenizer.eos_token_id] - - input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, audios, tokenizer, processor) - source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len) - input_ids = input_ids[:source_len] - labels = labels[:target_len] - return input_ids, labels - - -def preprocess_unsupervised_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[Any]]: - # build inputs with format ` X` and labels with format `Y ` - model_inputs = defaultdict(list) - for i in range(len(examples["_prompt"])): - if len(examples["_prompt"][i]) % 2 != 1: - logger.warning_rank0( - "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) - ) - continue - - input_ids, labels = _encode_unsupervised_example( - prompt=examples["_prompt"][i], - response=examples["_response"][i], - system=examples["_system"][i], - tools=examples["_tools"][i], - images=examples["_images"][i] or [], - videos=examples["_videos"][i] or [], - audios=examples["_audios"][i] or [], - template=template, - tokenizer=tokenizer, - processor=processor, - cutoff_len=data_args.cutoff_len, - ) - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - model_inputs["images"].append(examples["_images"][i]) - model_inputs["videos"].append(examples["_videos"][i]) - model_inputs["audios"].append(examples["_audios"][i]) - - return model_inputs - - -def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) - print("label_ids:\n{}".format(example["labels"])) - print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False))) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index ba043f6c..ebe9a3ce 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -405,7 +405,7 @@ class Llama2Template(Template): TEMPLATES: 
Dict[str, "Template"] = {} -def _register_template( +def register_template( name: str, format_user: Optional["Formatter"] = None, format_assistant: Optional["Formatter"] = None, @@ -421,7 +421,7 @@ def _register_template( replace_eos: bool = False, replace_jinja_template: bool = False, mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), - template_class: Type[Template] = Template, + template_class: Type["Template"] = Template, ) -> None: r""" Registers a chat template. @@ -436,7 +436,7 @@ def _register_template( The corresponding code should be: ``` - _register_template( + register_template( name="custom", format_user=StringFormatter(slots=["{{content}}\n"]), format_assistant=StringFormatter(slots=["{{content}}\n"]), @@ -444,6 +444,9 @@ def _register_template( ) ``` """ + if name in TEMPLATES: + raise ValueError(f"Template {name} already exists.") + default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}] default_user_formatter = StringFormatter(slots=["{{content}}"]) default_assistant_formatter = StringFormatter(slots=default_slots) @@ -562,7 +565,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: return template -_register_template( +register_template( name="alpaca", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), @@ -573,7 +576,7 @@ _register_template( ) -_register_template( +register_template( name="aquila", format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), format_assistant=StringFormatter(slots=["{{content}}###"]), @@ -586,7 +589,7 @@ _register_template( ) -_register_template( +register_template( name="atom", format_user=StringFormatter( slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"] @@ -595,21 +598,21 @@ _register_template( ) -_register_template( +register_template( name="baichuan", format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), efficient_eos=True, ) -_register_template( +register_template( name="baichuan2", format_user=StringFormatter(slots=["{{content}}"]), efficient_eos=True, ) -_register_template( +register_template( name="belle", format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), @@ -617,13 +620,13 @@ _register_template( ) -_register_template( +register_template( name="bluelm", format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), ) -_register_template( +register_template( name="breeze", format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), @@ -631,7 +634,7 @@ _register_template( ) -_register_template( +register_template( name="chatglm2", format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), @@ -639,7 +642,7 @@ _register_template( ) -_register_template( +register_template( name="chatglm3", format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]), @@ -655,7 +658,7 @@ _register_template( ) -_register_template( +register_template( name="chatml", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), 
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -668,7 +671,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="chatml_de", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -681,13 +684,13 @@ _register_template( ) -_register_template( +register_template( name="codegeex2", format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), ) -_register_template( +register_template( name="codegeex4", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), @@ -704,7 +707,7 @@ _register_template( ) -_register_template( +register_template( name="cohere", format_user=StringFormatter( slots=[ @@ -719,7 +722,7 @@ _register_template( ) -_register_template( +register_template( name="cpm", format_user=StringFormatter(slots=["<用户>{{content}}"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), @@ -727,7 +730,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="cpm3", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -738,7 +741,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="dbrx", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -763,7 +766,7 @@ _register_template( ) -_register_template( +register_template( name="deepseek", format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), format_system=StringFormatter(slots=["{{content}}\n\n"]), @@ -771,14 +774,14 @@ _register_template( ) -_register_template( +register_template( name="deepseek3", format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) -_register_template( +register_template( name="deepseekcoder", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]), @@ -792,7 +795,7 @@ _register_template( ) -_register_template( +register_template( name="default", format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), @@ -800,13 +803,13 @@ _register_template( ) -_register_template( +register_template( name="empty", format_assistant=StringFormatter(slots=["{{content}}"]), ) -_register_template( +register_template( name="exaone", format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), @@ -814,7 +817,7 @@ _register_template( ) -_register_template( +register_template( name="falcon", format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), format_assistant=StringFormatter(slots=["{{content}}\n"]), @@ -822,14 +825,14 @@ _register_template( ) -_register_template( +register_template( name="fewshot", format_assistant=StringFormatter(slots=["{{content}}\n\n"]), efficient_eos=True, ) -_register_template( +register_template( name="gemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), 
format_assistant=StringFormatter(slots=["{{content}}\n"]), @@ -840,7 +843,7 @@ _register_template( ) -_register_template( +register_template( name="glm4", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_assistant=StringFormatter(slots=["\n{{content}}"]), @@ -854,7 +857,7 @@ _register_template( ) -_register_template( +register_template( name="granite3", format_user=StringFormatter( slots=[ @@ -866,7 +869,7 @@ _register_template( ) -_register_template( +register_template( name="index", format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]), format_system=StringFormatter(slots=["{{content}}"]), @@ -874,7 +877,7 @@ _register_template( ) -_register_template( +register_template( name="intern", format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), format_assistant=StringFormatter(slots=["{{content}}\n"]), @@ -891,7 +894,7 @@ _register_template( ) -_register_template( +register_template( name="intern2", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -908,7 +911,7 @@ _register_template( ) -_register_template( +register_template( name="llama2", format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), @@ -917,7 +920,7 @@ _register_template( # copied from llama2 template -_register_template( +register_template( name="llama2_zh", format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), @@ -926,7 +929,7 @@ _register_template( ) -_register_template( +register_template( name="llama3", format_user=StringFormatter( slots=[ @@ -954,7 +957,7 @@ _register_template( # copied from llama3 template -_register_template( +register_template( name="mllama", format_user=StringFormatter( slots=[ @@ -983,7 +986,7 @@ _register_template( # copied from vicuna template -_register_template( +register_template( name="llava", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), default_system=( @@ -995,7 +998,7 @@ _register_template( # copied from vicuna template -_register_template( +register_template( name="llava_next", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), default_system=( @@ -1007,7 +1010,7 @@ _register_template( # copied from llama3 template -_register_template( +register_template( name="llava_next_llama3", format_user=StringFormatter( slots=[ @@ -1036,7 +1039,7 @@ _register_template( # copied from mistral template -_register_template( +register_template( name="llava_next_mistral", format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), @@ -1051,7 +1054,7 @@ _register_template( # copied from qwen template -_register_template( +register_template( name="llava_next_qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1068,7 +1071,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="llava_next_yi", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1079,7 +1082,7 @@ _register_template( # copied from vicuna template 
-_register_template( +register_template( name="llava_next_video", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), default_system=( @@ -1091,7 +1094,7 @@ _register_template( # copied from mistral template -_register_template( +register_template( name="llava_next_video_mistral", format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), @@ -1106,7 +1109,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="llava_next_video_yi", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1117,7 +1120,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="marco", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1133,7 +1136,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="minicpm_v", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1144,7 +1147,7 @@ _register_template( # copied from minicpm_v template -_register_template( +register_template( name="minicpm_o", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1155,7 +1158,7 @@ _register_template( # mistral tokenizer v3 tekken -_register_template( +register_template( name="ministral", format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]), format_system=StringFormatter(slots=["{{content}}\n\n"]), @@ -1168,7 +1171,7 @@ _register_template( # mistral tokenizer v3 -_register_template( +register_template( name="mistral", format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), @@ -1182,7 +1185,7 @@ _register_template( # mistral tokenizer v7 tekken (copied from ministral) -_register_template( +register_template( name="mistral_small", format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]), format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]), @@ -1193,21 +1196,21 @@ _register_template( ) -_register_template( +register_template( name="olmo", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), format_prefix=EmptyFormatter(slots=[{"eos_token"}]), ) -_register_template( +register_template( name="openchat", format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) -_register_template( +register_template( name="openchat-3.6", format_user=StringFormatter( slots=[ @@ -1223,7 +1226,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="opencoder", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1234,7 +1237,7 @@ _register_template( ) -_register_template( +register_template( name="orion", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), 
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), @@ -1242,7 +1245,7 @@ _register_template( # copied from gemma template -_register_template( +register_template( name="paligemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), format_assistant=StringFormatter(slots=["{{content}}\n"]), @@ -1254,7 +1257,7 @@ _register_template( ) -_register_template( +register_template( name="phi", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]), @@ -1263,7 +1266,7 @@ _register_template( ) -_register_template( +register_template( name="phi_small", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]), @@ -1273,7 +1276,7 @@ _register_template( ) -_register_template( +register_template( name="phi4", format_user=StringFormatter( slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"] @@ -1285,7 +1288,7 @@ _register_template( # copied from ministral template -_register_template( +register_template( name="pixtral", format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]), format_system=StringFormatter(slots=["{{content}}\n\n"]), @@ -1299,7 +1302,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1315,7 +1318,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="qwen2_audio", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1327,7 +1330,7 @@ _register_template( # copied from qwen template -_register_template( +register_template( name="qwen2_vl", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1343,7 +1346,7 @@ _register_template( ) -_register_template( +register_template( name="sailor", format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1357,7 +1360,7 @@ _register_template( # copied from llama3 template -_register_template( +register_template( name="skywork_o1", format_user=StringFormatter( slots=[ @@ -1391,7 +1394,7 @@ _register_template( ) -_register_template( +register_template( name="solar", format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]), format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]), @@ -1399,7 +1402,7 @@ _register_template( ) -_register_template( +register_template( name="starchat", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]), format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]), @@ -1408,14 +1411,14 @@ _register_template( ) -_register_template( +register_template( name="telechat", format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]), format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]), ) -_register_template( +register_template( name="telechat2", format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]), 
format_system=StringFormatter(slots=["<_system>{{content}}"]), @@ -1425,7 +1428,7 @@ _register_template( ) -_register_template( +register_template( name="vicuna", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), default_system=( @@ -1436,7 +1439,7 @@ _register_template( ) -_register_template( +register_template( name="video_llava", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), default_system=( @@ -1447,7 +1450,7 @@ _register_template( ) -_register_template( +register_template( name="xuanyuan", format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), default_system=( @@ -1458,13 +1461,13 @@ _register_template( ) -_register_template( +register_template( name="xverse", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]), ) -_register_template( +register_template( name="yayi", format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), format_assistant=StringFormatter(slots=["{{content}}\n\n"]), @@ -1485,7 +1488,7 @@ _register_template( # copied from chatml template -_register_template( +register_template( name="yi", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), @@ -1494,7 +1497,7 @@ _register_template( ) -_register_template( +register_template( name="yi_vl", format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]), format_assistant=StringFormatter(slots=["{{content}}\n"]), @@ -1511,7 +1514,7 @@ _register_template( ) -_register_template( +register_template( name="yuan", format_user=StringFormatter(slots=["{{content}}", {"token": ""}]), format_assistant=StringFormatter(slots=["{{content}}\n"]), @@ -1519,7 +1522,7 @@ _register_template( ) -_register_template( +register_template( name="zephyr", format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), @@ -1527,7 +1530,7 @@ _register_template( ) -_register_template( +register_template( name="ziya", format_user=StringFormatter(slots=[":{{content}}\n:"]), format_assistant=StringFormatter(slots=["{{content}}\n"]), diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 5470f791..520d3958 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -1549,7 +1549,7 @@ register_model_group( register_model_group( models={ - "Pixtral-12B-Instruct": { + "Pixtral-12B": { DownloadSource.DEFAULT: "mistral-community/pixtral-12b", DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b", } diff --git a/tests/data/processors/test_feedback.py b/tests/data/processor/test_feedback.py similarity index 100% rename from tests/data/processors/test_feedback.py rename to tests/data/processor/test_feedback.py diff --git a/tests/data/processors/test_pairwise.py b/tests/data/processor/test_pairwise.py similarity index 100% rename from tests/data/processors/test_pairwise.py rename to tests/data/processor/test_pairwise.py diff --git a/tests/data/processors/test_processor_utils.py b/tests/data/processor/test_processor_utils.py similarity index 94% rename from tests/data/processors/test_processor_utils.py rename to tests/data/processor/test_processor_utils.py index 9cf31220..64d2ab91 100644 --- a/tests/data/processors/test_processor_utils.py +++ b/tests/data/processor/test_processor_utils.py @@ -16,7 +16,7 @@ from 
typing import Tuple import pytest -from llamafactory.data.processors.processor_utils import infer_seqlen +from llamafactory.data.processor.processor_utils import infer_seqlen @pytest.mark.parametrize( diff --git a/tests/data/processors/test_supervised.py b/tests/data/processor/test_supervised.py similarity index 100% rename from tests/data/processors/test_supervised.py rename to tests/data/processor/test_supervised.py diff --git a/tests/data/processors/test_unsupervised.py b/tests/data/processor/test_unsupervised.py similarity index 100% rename from tests/data/processors/test_unsupervised.py rename to tests/data/processor/test_unsupervised.py diff --git a/tests/data/test_converter.py b/tests/data/test_converter.py new file mode 100644 index 00000000..0308d3ee --- /dev/null +++ b/tests/data/test_converter.py @@ -0,0 +1,46 @@ +from llamafactory.data import Role +from llamafactory.data.converter import get_dataset_converter +from llamafactory.data.parser import DatasetAttr +from llamafactory.hparams import DataArguments + + +def test_alpaca_converter(): + dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset") + data_args = DataArguments() + example = { + "instruction": "Solve the math problem.", + "input": "3 + 4", + "output": "The answer is 7.", + } + dataset_converter = get_dataset_converter("alpaca", dataset_attr, data_args) + assert dataset_converter(example) == { + "_prompt": [{"role": Role.USER.value, "content": "Solve the math problem.\n3 + 4"}], + "_response": [{"role": Role.ASSISTANT.value, "content": "The answer is 7."}], + "_system": "", + "_tools": "", + "_images": None, + "_videos": None, + "_audios": None, + } + + +def test_sharegpt_converter(): + dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset") + data_args = DataArguments() + example = { + "conversations": [ + {"from": "system", "value": "You are a helpful assistant."}, + {"from": "human", "value": "Solve the math problem.\n3 + 4"}, + {"from": "gpt", "value": "The answer is 7."}, + ] + } + dataset_converter = get_dataset_converter("sharegpt", dataset_attr, data_args) + assert dataset_converter(example) == { + "_prompt": [{"role": Role.USER.value, "content": "Solve the math problem.\n3 + 4"}], + "_response": [{"role": Role.ASSISTANT.value, "content": "The answer is 7."}], + "_system": "You are a helpful assistant.", + "_tools": "", + "_images": None, + "_videos": None, + "_audios": None, + }
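
A minimal usage sketch for the new class-based processors, for trying the refactored pipeline out of tree. It assumes the `DatasetProcessor` base is a dataclass holding `template`, `tokenizer`, `processor`, and `data_args` (as the attribute accesses in `SupervisedDatasetProcessor` imply) and that `processor/__init__.py` re-exports the class; the real wiring lives in `src/llamafactory/data/loader.py`, and the model id below is only illustrative.

```
# Sketch only: constructor fields and the re-export are assumptions, not confirmed by this patch.
from datasets import Dataset
from transformers import AutoTokenizer

from llamafactory.data import Role
from llamafactory.data.processor import SupervisedDatasetProcessor  # assumes __init__.py re-exports it
from llamafactory.data.template import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model id
data_args = DataArguments(template="qwen", cutoff_len=512)
template = get_template_and_fix_tokenizer(tokenizer, data_args)

# One example already converted to the standard `_prompt`/`_response` columns by the dataset converter.
dataset = Dataset.from_dict({
    "_prompt": [[{"role": Role.USER.value, "content": "Solve the math problem.\n3 + 4"}]],
    "_response": [[{"role": Role.ASSISTANT.value, "content": "The answer is 7."}]],
    "_system": [""], "_tools": [""], "_images": [None], "_videos": [None], "_audios": [None],
})

dataset_processor = SupervisedDatasetProcessor(
    template=template, tokenizer=tokenizer, processor=None, data_args=data_args
)
dataset = dataset.map(dataset_processor.preprocess_dataset, batched=True, remove_columns=dataset.column_names)
dataset_processor.print_data_example(dataset[0])
```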
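
The `neat_packing` branch in `PackedSupervisedDatasetProcessor` tags each packed sequence with its own segment id (1, 2, 3, ...) in `attention_mask`, with 0 reserved for padding, so later stages can rebuild a block-diagonal mask and keep packed samples from attending to one another. A small self-contained illustration of the resulting mask layout (the sequence lengths and cutoff value are made up for the example):

```
# Illustration of the neat_packing attention-mask layout produced by the packed processor.
# Assumes one knapsack holding two sequences of lengths 3 and 4 with cutoff_len = 8,
# so one extra position is reserved for padding (cutoff_len + 1 = 9 total).
cutoff_len = 8
sequence_lengths = [3, 4]

packed_attention_masks = []
for i, length in enumerate(sequence_lengths):
    packed_attention_masks += [i + 1] * length  # segment ids start from 1

pad_length = cutoff_len - len(packed_attention_masks) + 1
packed_attention_masks += [0] * pad_length  # padding positions are masked out entirely

print(packed_attention_masks)  # [1, 1, 1, 2, 2, 2, 2, 0, 0]
```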