diff --git a/requirements.txt b/requirements.txt index 0a273440f..0ce57f3ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ datasets>=2.16.0,<=4.0.0 accelerate>=1.3.0,<=1.11.0 peft>=0.14.0,<=0.17.1 trl>=0.8.6,<=0.9.6 +torchdata # gui gradio>=4.38.0,<=5.45.0 matplotlib>=3.7.0 diff --git a/src/llamafactory/v1/core/data_loader.py b/src/llamafactory/v1/core/data_loader.py new file mode 100644 index 000000000..1d580589a --- /dev/null +++ b/src/llamafactory/v1/core/data_loader.py @@ -0,0 +1,277 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +import sys +from collections.abc import Generator, Iterator +from dataclasses import dataclass +from typing import Optional + +from torchdata.stateful_dataloader import StatefulDataLoader +from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler + +from ..utils.batching_queue import BaseBatchingQueue +from ..utils.logging import get_logger +from ..utils.types import Processor, TorchDataset +from .trainer_utils.data_collator import DataCollator + + +logger = get_logger(__name__) + + +# base dataloader +class DistributedDataloader(StatefulDataLoader): + """Base Distributed DataLoader.""" + + dataset: "TorchDataset" + sampler: "StatefulDistributedSampler" + + def set_epoch(self, epoch: int) -> None: + if self.sampler is not None and hasattr(self.sampler, "set_epoch"): + self.sampler.set_epoch(epoch) + elif hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + +@dataclass +class BaseDataLoader: + """Default DataLoader.""" + + processor: Processor + + def __init__(self, dataset: TorchDataset) -> None: + self.dataset = dataset + # guidlines: fetch until get fixed batchsize. + # save state_dict for buffer. + # resume with state + + # 1. Init stateful dataloader (tokenize) + # 2. Add to buffer (2 * max seq len per device) + # 3. 
Yield batch indexes (micro batch * grad acc) + # a ) non pack + non dynamic + # b ) non pack + dynamic + # c ) pack + non dynamic + # d ) pack + dynamic + + def init_dataloader(self) -> None: + ### init dataloader + pass + + def __iter__(self) -> Iterator: + pass + + def __next__(self) -> any: + pass + + +@dataclass +class DataLoader: + """Default DataLoader.""" + + processor: "Processor" + dataloader: "DistributedDataloader" + batching_queue: "BaseBatchingQueue" + collate_fn: "DataCollator" + num_micro_batch: int = 1 + length: int = 0 + drop_last: bool = True + + def __init__( + self, + dataloader: any, + collate_fn: "DataCollator", + num_micro_batch: int = 1, + length: int = 0, + drop_last: bool = True, + batching_queue: Optional["BaseBatchingQueue"] = None, + ) -> None: + self.batching_queue = batching_queue + self.num_micro_batch = num_micro_batch + self.step = 0 + self._collate_fn = collate_fn + self._dataloader = dataloader + self._drop_last = drop_last + self._data_iter: Iterator + self._resume = False + self._batch_data_iter: Generator + + if length > 0: + self._length = length + elif length == -1: + self._length = sys.maxsize + else: + self._length = len(self._dataloader) + + def __len__(self): + return self._length + + def __iter__(self) -> Iterator: + if not self._resume: + self.step = 0 + self._data_iter = iter(self._dataloader) + self._batch_data_iter = self.batch_data_generator() + self._resume = False + return self + + def __next__(self): + return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here + + def origin_batch_data_generator(self): + """Standard pass-through generator if do not use batching queue.""" + while True: + if self._length > 0 and self.step >= self._length: + return + + try: + batch = [] + data = next(self._data_iter) + # split data into micro batches + for i in range(0, len(data), self.num_micro_batch): + micro_batch = data[i : i + self.num_micro_batch] + if self._collate_fn: + micro_batch = self._collate_fn(micro_batch) + batch.append(micro_batch) + yield batch + self.step += 1 + except StopIteration: + if self.step < self._length: + # Restart iterator to fill the requested length + self._data_iter = iter(self._dataloader) + try: + batch = [] + data = next(self._data_iter) + for i in range(0, len(data), self.num_micro_batch): + micro_batch = data[i : i + self.num_micro_batch] + if self._collate_fn: + micro_batch = self._collate_fn(micro_batch) + batch.append(micro_batch) + yield batch + self.step += 1 + except StopIteration: + return + else: + return + except Exception as e: + logger.error(f"DataLoader origin_batch_data_generator exception: {e}") + raise + + def batch_data_generator(self): + if self.batching_queue is None: + yield from self.origin_batch_data_generator() + return + + batch = [] + + while True: + if self._length and self.step >= self._length: + return + + if self.batching_queue.is_full_filled(): + micro_batch = self.batching_queue.get_micro_batch(self.step) + if self._collate_fn: + micro_batch = self._collate_fn(micro_batch) + batch.append(micro_batch) + if len(batch) == self.num_micro_batch: + yield batch + self.step += 1 + batch = [] + + try: + processing_item = next(self._data_iter) + except Exception as e: + if isinstance(e, StopIteration): + if self.step < self._length: + # call iter until reach length + self._data_iter = iter(self._dataloader) + processing_item = next(self._data_iter) + elif not self._drop_last and not self.batching_queue.empty(): + while not self.batching_queue.empty(): + 
micro_batch = self.batching_queue.get_micro_batch(self.step) + if self._collate_fn: + micro_batch = self._collate_fn(micro_batch) + batch.append(micro_batch) + if len(batch) == self.num_micro_batch: + yield batch + self.step += 1 + batch = [] + + while len(batch) < self.num_micro_batch: + padding_batch = copy.deepcopy(micro_batch) + padding_batch["is_padded"] = True + batch.append(padding_batch) + yield batch + self.step += 1 + return + else: + return + else: + logger.error(f"DataLoader iter data exception: {e}") + raise + + # put processing_item to buffer + if isinstance(processing_item, dict): + processing_item = [processing_item] + + for item in processing_item: + self.batching_queue.put_item(item) + + def state_dict(self): + # save state + state = self.__dict__.copy() + # remove internal fields + for k in list(state.keys()): + if k.startswith("_"): + del state[k] + + # save dataloader state + if hasattr(self._dataloader, "state_dict"): + state["dataloader_state"] = self._dataloader.state_dict() + elif hasattr(self._dataloader, "__getstate__"): + state["dataloader_state"] = self._dataloader.__getstate__() + + batching_strategy = getattr(self, "batching_strategy", None) + if batching_strategy and hasattr(batching_strategy, "state_dict"): + state["batching_strategy_state"] = batching_strategy.state_dict() + if "batching_strategy" in state: + del state["batching_strategy"] + + return copy.deepcopy(state) + + def load_state_dict(self, state: dict[str, any]): + if state["num_micro_batch"] != self.num_micro_batch: + logger.warning( + f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer" + ) + del state["num_micro_batch"] + self.__dict__.update(state) + self._resume = True + + if hasattr(self._dataloader, "load_state_dict"): + self._dataloader.load_state_dict(state["dataloader_state"]) + elif hasattr(self._dataloader, "__getstate__"): + self._dataloader.__setstate__(state["dataloader_state"]) + + if "batching_strategy_state" in state: + batching_strategy = getattr(self, "batching_strategy", None) + if batching_strategy: + batching_strategy.load_state_dict(state["batching_strategy_state"]) + del state["batching_strategy_state"] + + self._data_iter = iter(self._dataloader) + self._batch_data_iter = self.batch_data_generator() + + def set_epoch(self, epoch: int) -> None: + if hasattr(self._dataloader, "set_epoch"): + self._dataloader.set_epoch(epoch) diff --git a/src/llamafactory/v1/core/trainer_utils/data_collator.py b/src/llamafactory/v1/core/trainer_utils/data_collator.py index f91a07d77..a425f02f9 100644 --- a/src/llamafactory/v1/core/trainer_utils/data_collator.py +++ b/src/llamafactory/v1/core/trainer_utils/data_collator.py @@ -12,36 +12,108 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- +from collections import defaultdict +from collections.abc import Sequence +from dataclasses import dataclass from typing import Any -from ...utils.types import Processor, Tensor, TorchDataset +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data._utils.collate import default_collate + +from ....extras.constants import IGNORE_INDEX +from ...plugins.data_plugins.template import Template +from ...utils.types import Processor, Tensor + + +def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor": # FIXME move to utils + """Convert sequence lengths to cumulative sequence lengths.""" + return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32) class DataCollator: """Default Data collator.""" - def __init__(self, processor: Processor) -> None: - self.processor = processor + processor: "Processor" # processor name -> map to encode_messages function + + def __post_init__(self): + # callback for text tokenizer + self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]: """Collate features into a batch.""" - for feature in features: - pass + batch = defaultdict(list) + # batching features + for feature in features: + for key in feature.keys(): + batch[key].append(feature[key]) + + for key in batch.keys(): + # process padding features + if key in ["input_ids", "attention_mask", "position_ids"]: + padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0 + batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value) + elif key in ["labels"]: + batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX) + else: + batch[key] = default_collate(batch[key]) + + return batch # sft: messages # dpo: chosen_messages, rejected_messages -class DataLoader: - """Default DataLoader.""" +@dataclass +class DefaultCollator(DataCollator): + """Example for now.""" - def __init__(self, dataset: TorchDataset) -> None: - self.dataset = dataset - # 1. Init stateful dataloader (tokenize) - # 2. Add to buffer (2 * max seq len per device) - # 3. 
Yield batch indexes (micro batch * grad acc) - # a ) non pack + non dynamic - # b ) non pack + dynamic - # c ) pack + non dynamic - # d ) pack + dynamic + processor: "Processor" # processor name -> map to encode_messages function + template: "Template" + + def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]: + features = [] + + # Check if data is already tokenized (contains input_ids) + if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]: + for feature in messages: + if not isinstance(feature, dict): + raise ValueError(f"Expected dict but got {type(feature)}") + tensor_feature = { + k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v + for k, v in feature.items() + } + features.append(tensor_feature) + else: + # raw messages need to be encoded + for message in messages: + encoded_message = self.template.encode_messages(self.tokenizer, message) + encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()} + features.append(encoded_message) + + return super().__call__(features) + + +@dataclass +class PairwiseCollator(DataCollator): + pass + + +@dataclass +class DataCollatorWithPacking(DefaultCollator): + """Data collator with packing.""" + + processor: "Processor" + template: "Template" + + def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]: + seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long) + batch = {"cu_seqlens": len2culen(seqlens)} + for input_name in features[0].keys(): + if input_name in ("input_ids", "attention_mask", "labels"): + batch[input_name] = torch.cat([feature[input_name] for feature in features]) + else: + batch[input_name] = default_collate([feature[input_name] for feature in features]) + + return batch diff --git a/src/llamafactory/v1/plugins/data_plugins/template.py b/src/llamafactory/v1/plugins/data_plugins/template.py index cf41389fe..32ec6f378 100644 --- a/src/llamafactory/v1/plugins/data_plugins/template.py +++ b/src/llamafactory/v1/plugins/data_plugins/template.py @@ -14,6 +14,7 @@ from dataclasses import dataclass +from typing import Union @dataclass @@ -22,5 +23,112 @@ class Template: assistant_template: str system_template: str - def render_message(self, message: "dict[str, str]") -> str: + def render_message(self, message: dict[str, str]) -> str: return self.user_template.format(**message) + + +@dataclass +class QwenTemplate: + message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool + thinking_template: str = "\n{content}\n\n\n" + + def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str: + if isinstance(content_data, str): + return content_data.strip() + + if isinstance(content_data, list): + parts = [] + for item in content_data: + if item.get("type") == "text": + parts.append(item.get("value", "")) + elif item.get("type") == "image_url": + pass + return "\n".join(parts).strip() + + return "" + + def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str: + role = message["role"] + content = self._extract_content(message.get("content", "")) + + if role == "assistant": + reasoning_content = message.get("reasoning_content", "") + if reasoning_content: + reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip()) + return self.message_template.format(role="assistant", content=reasoning_content + content) + else: + return 
self.message_template.format(role=role, content=content) + + def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any: + """Encode one message.""" + input_ids, attention_mask, labels = [], [], [] + for message in messages: + content_str = self.render_message(message) + content_ids = tokenizer.encode(content_str, add_special_tokens=False) + input_ids += content_ids + attention_mask += [1] * len(content_ids) + + if hasattr(message, "loss_weight"): + loss_weight = message["loss_weight"] + else: + loss_weight = 1 if message["role"] == "assistant" else 0 + if loss_weight == 1: + labels += content_ids + else: + labels += [-100] * len(content_ids) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} + model_inputs.update({"position_ids": list(range(len(input_ids)))}) + model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()} + return model_inputs + + +if __name__ == "__main__": + + def to_qwen3_messages(template: QwenTemplate, messages: list[dict]): + out = [] + for m in messages: + role = m["role"] + content = template._extract_content(m.get("content", "")) + if role == "assistant": + reasoning = (m.get("reasoning_content") or "").strip() + if reasoning: + content = template.thinking_template.format(content=reasoning) + content + out.append({"role": role, "content": content}) + return out + + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained( + "Qwen/Qwen3-30B-A3B-Thinking-2507", + trust_remote_code=True, + ) + + test_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [{"type": "text", "text": "1+1等于几?"}, {"type": "text", "text": "2+2等于几?"}], + }, + { + "role": "assistant", + "reasoning_content": "这是一个简单的数学问题。1加1的结果是2。", + "content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}], + }, + ] + + template = QwenTemplate() + rendered_custom = "".join([template.render_message(m) for m in test_messages]) + + qwen3_messages = to_qwen3_messages(template, test_messages) + rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False) + + print("==== custom ====") + print(rendered_custom) + print("==== hf ====") + print(rendered_hf) + + assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch" + + ids_custom = tok.encode(rendered_custom, add_special_tokens=False) + ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False) + assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}" diff --git a/src/llamafactory/v1/utils/batching_queue.py b/src/llamafactory/v1/utils/batching_queue.py new file mode 100644 index 000000000..ce71a6d29 --- /dev/null +++ b/src/llamafactory/v1/utils/batching_queue.py @@ -0,0 +1,220 @@ +# Copyright 2025 Bytedance Ltd. and the LlamaFactory team. +# +# This code is inspired by the Bytedance's VeOmni library. +# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + + +class DynamicBatchSizeBuffer: + """A buffer to store samples for dynamic batch size.""" + + def __init__(self): + self._buffer: list[dict[str, any]] = [] + self._buffer_sample_lengths: list[int] = [] + self._deleted_indices: set[int] = set() + self._current_index: int = 0 + self._total_token_count: int = 0 + + def append(self, item: dict[str, any]) -> None: + """Append a sample to the buffer. + + Args: + item: A sample to append to the buffer. + The sample should be a dict with the following keys: + - input_ids: torch.Tensor of shape (seq_len, ) + - attention_mask: torch.Tensor of shape (seq_len, ) + """ + self._buffer.append(item) + sample_length = int(item["attention_mask"].sum().item()) + self._buffer_sample_lengths.append(sample_length) + self._total_token_count += sample_length + + def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, any]]: + """Get samples from the buffer that fit within the token budget. + + Args: + max_tokens_per_iteration: Maximum number of tokens to retrieve. + force: If True, the first available sample will be returned even + if it exceeds the token budget. + + Returns: + A list of samples that fit within the token budget. + + Raises: + AssertionError: If no samples are found (should not happen in normal operation). + """ + cum_seq_len = 0 + samples = [] + + while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration: + if self._current_index in self._deleted_indices: + self._current_index += 1 + continue + + seq_len = self._buffer_sample_lengths[self._current_index] + remaining_tokens = max_tokens_per_iteration - cum_seq_len + + # Check if we can add this sample + can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens) + + if can_add: + cum_seq_len += seq_len + samples.append(self._buffer[self._current_index]) + self._deleted_indices.add(self._current_index) + + self._current_index += 1 + + assert len(samples) > 0, "No samples found in buffer" + return samples + + def __len__(self) -> int: + """Return the number of samples in the buffer.""" + return len(self._buffer) + + @property + def total_token_count(self) -> int: + """Return the total number of tokens in the buffer.""" + return self._total_token_count + + def flush(self) -> None: + tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices) + self._total_token_count -= tokens_to_remove + + buffer_length = len(self._buffer) + self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices] + self._buffer_sample_lengths = [ + self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices + ] + + self._current_index = 0 + self._deleted_indices.clear() + + +class BaseBatchingQueue(ABC): + """Base class for batching queue.""" + + @abstractmethod + def is_full_filled(self) -> bool: + raise NotImplementedError("Subclasses must implement `is_full_filled`") + + @abstractmethod + def put_item(self, item: dict[str, any]) -> None: + raise NotImplementedError("Subclasses must implement `put_item`") + + @abstractmethod + def get_micro_batch(self, step: int) -> list[dict[str, any]]: + raise NotImplementedError("Subclasses must implement `get_micro_batch`") + + @abstractmethod + def empty(self) -> bool: + raise NotImplementedError("Subclasses must implement `empty`") + + +class 
IdentityPacker: + def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken): + self.token_micro_bsz = token_micro_bsz + self.bsz_warmup_steps = bsz_warmup_steps + self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken + + def __call__(self, samples): + return samples + + def get_token_num_to_request(self, cur_step, warmup): + return ( + (self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps + + self.bsz_warmup_init_mbtoken + if warmup + else self.token_micro_bsz + ) + + +class TextBatchingQueue(BaseBatchingQueue): + """Batching text queue for text data.""" + + def __init__( + self, + token_micro_bsz, + buffer_size: int = 500, + bsz_warmup_steps: int = -1, + bsz_warmup_init_mbtoken: int = 200, + ) -> None: + super().__init__() + self._step = 0 + self.token_micro_bsz = token_micro_bsz + self.bsz_warmup_steps = bsz_warmup_steps + self.buffer_size = buffer_size # minimum samples in buffer + self.buffer = DynamicBatchSizeBuffer() + self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken # training warmup args + assert self.bsz_warmup_init_mbtoken >= 0 + + self.packer = IdentityPacker( + token_micro_bsz=token_micro_bsz, + bsz_warmup_steps=bsz_warmup_steps, + bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken, + ) + + def is_full_filled(self) -> bool: + return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz + + def put_item(self, item: dict[str, any]): + if len(item["input_ids"]) == 1: + print("WARNING: EMPTY STRING.") + return + self.buffer.append(item) + + def get_token_num_to_request(self): + if self.packer is not None: + warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0 + return self.packer.get_token_num_to_request(self._step, warmup=warmup) + else: + return self.get_cur_token_micro_bsz() + + def get_cur_token_micro_bsz(self): + warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0 + if warmup: + return ( + self.token_micro_bsz - self.bsz_warmup_init_mbtoken + ) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken + else: + return self.token_micro_bsz + + def get_micro_batch(self, step) -> any: + """Get a micro batch from the buffer according to the current step. + + Args: + step: the current step. + + Returns: + data: a list of samples. + """ + self._step = step + n_token_per_iter = self.get_token_num_to_request() + cur_token_micro_bsz = self.get_cur_token_micro_bsz() + assert cur_token_micro_bsz % n_token_per_iter == 0, ( + "The token num to get for each request should be divisible by token micro bsz." + ) + n_iter = int(cur_token_micro_bsz // n_token_per_iter) + data = [] + for _ in range(n_iter): + samples = self.buffer.get_samples(n_token_per_iter) + if self.packer: + samples = self.packer(samples) # maybe packed into one sample, but wrapped in list. + data.extend(samples) + self.buffer.flush() # remove the selected samples. + return data + + def empty(self) -> bool: + return len(self.buffer) == 0 diff --git a/tests_v1/core/test_data_loader.py b/tests_v1/core/test_data_loader.py new file mode 100644 index 000000000..cf25aa59e --- /dev/null +++ b/tests_v1/core/test_data_loader.py @@ -0,0 +1,173 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for DataLoader with different combinations of packing and dynamic batching. + +Tests the 4 scenarios: +a) non pack + non dynamic. +b) non pack + dynamic. +c) pack + non dynamic. +d) pack + dynamic. +""" + +import torch +from torch.utils.data import DataLoader as TorchDataLoader +from torch.utils.data import Dataset +from transformers import AutoTokenizer + +from llamafactory.v1.config.data_args import DataArguments +from llamafactory.v1.core.data_engine import DataEngine +from llamafactory.v1.core.data_loader import DataLoader +from llamafactory.v1.core.trainer_utils.data_collator import ( + DefaultCollator, +) +from llamafactory.v1.plugins.data_plugins.template import QwenTemplate +from llamafactory.v1.utils.batching_queue import TextBatchingQueue + + +class TensorDataset(Dataset): + """Wrapper dataset that converts DataEngine samples to tensor format.""" + + def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None): + self.data_engine = data_engine + self.processor = processor + self.template = template + self.max_samples = max_samples or len(data_engine) + self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor + + def __len__(self): + return min(self.max_samples, len(self.data_engine)) + + def __getitem__(self, idx): + # Get sample from DataEngine + sample = self.data_engine[idx] + + # Extract messages from sample + # DataEngine returns samples with format like {"messages": [...], ...} + # For llamafactory/v1-sft-demo, the format should have "messages" field + messages = None + if "messages" in sample: + messages = sample["messages"] + elif "conversations" in sample: + messages = sample["conversations"] + elif "conversation" in sample: + messages = sample["conversation"] + else: + # Try to find message-like fields (skip _dataset_name) + for key, value in sample.items(): + if key.startswith("_"): + continue + if isinstance(value, list) and len(value) > 0: + # Check if it looks like a message list + if isinstance(value[0], dict) and "role" in value[0]: + messages = value + break + + if messages is None: + raise ValueError(f"Could not find messages in sample: {list(sample.keys())}") + + # Encode messages using template + encoded = self.template.encode_messages(self.tokenizer, messages) + + # Convert to tensors + return { + "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long), + "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long), + "labels": torch.tensor(encoded["labels"], dtype=torch.long), + } + + +def create_real_dataset(max_samples: int = 20, batch_size: int = 4): + """Create a real dataset using DataEngine.""" + data_args = DataArguments(dataset="llamafactory/v1-sft-demo") + data_engine = DataEngine(data_args) + + # Create processor and template + processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5") + template = QwenTemplate() + + # Create tensor dataset + raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples) + + # Create torch DataLoader + torch_dataloader = TorchDataLoader( + 
raw_data_dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=lambda x: x, + ) + + return torch_dataloader, processor, template + + +class TestDataLoaderNonPackNonDynamic: + """Test case a) non pack + non dynamic.""" + + def test_basic_functionality(self): + """Test DataLoader without packing and without dynamic batching.""" + # Create real dataset + torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8) + + # Create collator (non-packing) + collator = DefaultCollator(processor=processor, template=template) + + # Create DataLoader without batching_queue (non-dynamic) + data_loader = DataLoader( + dataloader=torch_dataloader, + collate_fn=collator, + num_micro_batch=1, + batching_queue=None, + ) + + # Iterate and check results + batches = list(iter(data_loader)) + assert len(batches) > 0 + + # Check first batch + one_batch = batches[0] + micro_batches = one_batch[0] + assert "input_ids" in micro_batches + assert "attention_mask" in micro_batches + assert "labels" in micro_batches + assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1 + assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len] + + +class TestDataLoaderNonPackDynamic: + """Test case b) non pack + dynamic.""" + + def test_basic_functionality(self): + """Test DataLoader without packing but with dynamic batching.""" + # Create real dataset + torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8) + collator = DefaultCollator(processor=processor, template=template) + + # Create batching queue for dynamic batching + batching_queue = TextBatchingQueue( + token_micro_bsz=120, + buffer_size=8, + ) + + data_loader = DataLoader( + dataloader=torch_dataloader, + collate_fn=collator, + num_micro_batch=4, + batching_queue=batching_queue, + ) + + # Iterate and check + batches = list(iter(data_loader)) + micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]] + assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first) + assert len(batches) > 0 diff --git a/tests_v1/utils/test_batching_queue.py b/tests_v1/utils/test_batching_queue.py new file mode 100644 index 000000000..0f9a68224 --- /dev/null +++ b/tests_v1/utils/test_batching_queue.py @@ -0,0 +1,112 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from llamafactory.v1.utils.batching_queue import DynamicBatchSizeBuffer, TextBatchingQueue + + +def create_sample(length: int): + """Helper to create a mock sample with a specific token length.""" + return {"input_ids": torch.ones(length), "attention_mask": torch.ones(length)} + + +class TestDynamicBatchSizeBuffer: + def test_append_and_token_count(self): + buffer = DynamicBatchSizeBuffer() + buffer.append(create_sample(10)) + buffer.append(create_sample(20)) + + assert len(buffer) == 2 + assert buffer.total_token_count == 30 + + def test_get_samples_within_budget(self): + buffer = DynamicBatchSizeBuffer() + buffer.append(create_sample(10)) + buffer.append(create_sample(10)) + buffer.append(create_sample(50)) # This one is large + + # Request 25 tokens. Should get the first two (20 tokens total) + samples = buffer.get_samples(max_tokens_per_iteration=25) + assert len(samples) == 2 + + def test_force_return_first_sample(self): + buffer = DynamicBatchSizeBuffer() + buffer.append(create_sample(100)) + + # Even though budget is 50, force=True (default) should return the 100-token sample + samples = buffer.get_samples(max_tokens_per_iteration=50, force=True) + assert len(samples) == 1 + assert len(samples[0]["input_ids"]) == 100 + + def test_flush_removes_used_samples(self): + buffer = DynamicBatchSizeBuffer() + buffer.append(create_sample(10)) + buffer.append(create_sample(20)) + + # Take the first sample + buffer.get_samples(max_tokens_per_iteration=15) + buffer.flush() + + assert len(buffer) == 1 + assert buffer.total_token_count == 20 + # The remaining sample should now be at the start + remaining = buffer.get_samples(max_tokens_per_iteration=50) + assert len(remaining[0]["input_ids"]) == 20 + + +class TestTextBatchingQueue: + def test_is_full_filled(self): + queue = TextBatchingQueue(token_micro_bsz=100, buffer_size=2) + + queue.put_item(create_sample(10)) + assert not queue.is_full_filled() # Only 1 sample, buffer_size=2 + + queue.put_item(create_sample(10)) + assert not queue.is_full_filled() # 2 samples, but only 20 tokens (min 100) + + queue.put_item(create_sample(90)) + assert queue.is_full_filled() # Meets both conditions + + def test_warmup_logic(self): + # token_micro_bsz=1000, starts at 200, reaches 1000 at step 10 + queue = TextBatchingQueue(token_micro_bsz=1000, bsz_warmup_steps=10, bsz_warmup_init_mbtoken=200) + + # Step 0: should be init value + assert queue.get_cur_token_micro_bsz() == 200 + + # Step 5: halfway through warmup (200 + (800 * 5/10)) = 600 + queue._step = 5 + assert queue.get_cur_token_micro_bsz() == 600 + + # Step 11: past warmup + queue._step = 11 + assert queue.get_cur_token_micro_bsz() == 1000 + + def test_get_micro_batch_integration(self): + queue = TextBatchingQueue(token_micro_bsz=50, buffer_size=1) + queue.put_item(create_sample(20)) + queue.put_item(create_sample(20)) + queue.put_item(create_sample(20)) + + # At step 0 (warmup not triggered as bsz_warmup_steps is -1 default), + # it should take samples up to 50 tokens. + batch = queue.get_micro_batch(step=0) + + assert len(batch) == 2 + assert queue.empty() is False + + batch_2 = queue.get_micro_batch(step=1) + assert len(batch_2) == 1 + assert queue.empty() is True
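
Note on the packing path: DataCollatorWithPacking uses len2culen to turn per-sample lengths into the cu_seqlens tensor it emits for packed batches. A minimal sketch of that helper, using illustrative lengths rather than real dataset samples:

import torch
import torch.nn.functional as F

def len2culen(seqlens: torch.Tensor) -> torch.Tensor:
    # same helper as in trainer_utils/data_collator.py: cumulative lengths with a leading zero
    return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)

seqlens = torch.tensor([3, 5, 2])  # lengths of three samples packed into one sequence
print(len2culen(seqlens))          # tensor([ 0,  3,  8, 10], dtype=torch.int32)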
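
For reference, a worked example of the token-budget warmup computed by TextBatchingQueue.get_cur_token_micro_bsz, using the same numbers as test_warmup_logic (this only restates the formula already in batching_queue.py):

# token_micro_bsz=1000, bsz_warmup_init_mbtoken=200, bsz_warmup_steps=10
# step 0:  (1000 - 200) * 0  // 10 + 200 = 200
# step 5:  (1000 - 200) * 5  // 10 + 200 = 600
# step 10: (1000 - 200) * 10 // 10 + 200 = 1000  (warmup ends; the full budget is used afterwards)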