[breaking change] refactor data pipeline (#6901)

* refactor data

* rename file

Former-commit-id: 7a1a4ce6451cb782573d0bd9dd27a5e443e3a18b
This commit is contained in:
hoshi-hiyouga
2025-02-13 00:39:20 +08:00
committed by GitHub
parent 19ba216571
commit efef4eaefc
27 changed files with 1145 additions and 1132 deletions

View File

@@ -0,0 +1,17 @@
from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
from .processor_utils import DatasetProcessor
from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [
"DatasetProcessor",
"FeedbackDatasetProcessor",
"PairwiseDatasetProcessor",
"PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor",
]

View File

@@ -0,0 +1,129 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
else: # undesired example
kto_tag = False
messages = prompt + [response[1]]
if kl_response[0]["content"]:
kl_messages = prompt + [kl_response[0]]
else:
kl_messages = prompt + [kl_response[1]]
messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor)
prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools)
if self.template.efficient_eos:
response_ids += [self.tokenizer.eos_token_id]
kl_response_ids += [self.tokenizer.eos_token_id]
prompt_ids, _ = self.template.mm_plugin.process_token_ids(
prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
)
kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids(
kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len)
prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len]
kl_source_len, kl_target_len = infer_seqlen(
len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len
)
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
kl_response_ids = kl_response_ids[:kl_target_len]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
kl_response=kl_response[i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")

View File

@@ -0,0 +1,118 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor
)
rejected_messages = self.template.mm_plugin.process_messages(
prompt + [response[1]], images, videos, audios, self.processor
)
prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools)
_, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools)
if self.template.efficient_eos:
chosen_ids += [self.tokenizer.eos_token_id]
rejected_ids += [self.tokenizer.eos_token_id]
prompt_ids, _ = self.template.mm_plugin.process_token_ids(
prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
)
# consider the response is more important
source_len, target_len = infer_seqlen(
len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len
)
prompt_ids = prompt_ids[:source_len]
chosen_ids = chosen_ids[:target_len]
rejected_ids = rejected_ids[:target_len]
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print(
"chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))
)
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print(
"rejected_inputs:\n{}".format(
self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)
)
)
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")

View File

@@ -0,0 +1,57 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List
from .processor_utils import DatasetProcessor
@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not self.data_args.packing:
if getattr(self.tokenizer, "add_bos_token", False):
text_examples = [self.tokenizer.bos_token + example for example in text_examples]
result = self.tokenizer(
text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
)
else:
tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = self.data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if getattr(self.tokenizer, "add_bos_token", False):
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = self.tokenizer.bos_token_id
return result
def print_data_example(self, example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@@ -0,0 +1,100 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@dataclass
class DatasetProcessor(ABC):
r"""
A class for data processors.
"""
template: "Template"
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
data_args: "DataArguments"
@abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
r"""
Builds model inputs from the examples.
"""
...
@abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None:
r"""
Print a data example to stdout.
"""
...
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []
while numbers:
current_knapsack = []
remaining_capacity = capacity
while True:
index = search_for_fit(numbers, remaining_capacity)
if index == -1:
break # no more numbers fit in this knapsack
remaining_capacity -= numbers[index] # update the remaining capacity
current_knapsack.append(numbers.pop(index)) # add the number to knapsack
knapsacks.append(current_knapsack)
return knapsacks
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target
max_target_len = cutoff_len - source_len
else: # truncate both
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
new_target_len = min(max_target_len, target_len)
max_source_len = max(cutoff_len - new_target_len, 0)
new_source_len = min(max_source_len, source_len)
return new_source_len, new_target_len

View File

@@ -0,0 +1,200 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
)
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
if self.data_args.mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= self.data_args.cutoff_len:
break
source_len, target_len = infer_seqlen(
len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if self.data_args.train_on_prompt:
source_label = source_ids
elif self.template.efficient_eos:
source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
if self.data_args.mask_history and turn_idx != 0: # train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
if self.data_args.mask_history: # reversed sequences
input_ids = source_ids + target_ids + input_ids
labels = source_label + target_label + labels
else:
input_ids += source_ids + target_ids
labels += source_label + target_label
if self.template.efficient_eos:
input_ids += [self.tokenizer.eos_token_id]
labels += [self.tokenizer.eos_token_id]
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
length = len(input_ids)
if length > self.data_args.cutoff_len:
logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
else:
lengths.append(length)
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
batch_images.append(examples["_images"][i] or [])
batch_videos.append(examples["_videos"][i] or [])
batch_audios.append(examples["_audios"][i] or [])
valid_num += 1
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos, packed_audios = [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
packed_audios += batch_audios[index]
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < self.data_args.cutoff_len + 1: # avoid flash_attn drops attn mask
pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None)
return model_inputs

View File

@@ -0,0 +1,91 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ..data_utils import Role
from .processor_utils import DatasetProcessor, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
if self.template.efficient_eos:
labels += [self.tokenizer.eos_token_id]
input_ids, _ = self.template.mm_plugin.process_token_ids(
input_ids, None, images, videos, audios, self.tokenizer, self.processor
)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False)))