diff --git a/.github/instructions-v0.md b/.github/instructions-v0.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/.github/instructions-v1.md b/.github/instructions-v1.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/scripts/convert_ckpt/tiny_qwen3.py b/scripts/convert_ckpt/tiny_qwen3.py
new file mode 100644
index 000000000..902c0ec09
--- /dev/null
+++ b/scripts/convert_ckpt/tiny_qwen3.py
@@ -0,0 +1,30 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import AutoTokenizer, Qwen3Config, Qwen3ForCausalLM
+
+
+if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
+    # Qwen3 is a text-only model, so the config carries no vision fields.
+    config = Qwen3Config(
+        hidden_size=1408,
+        intermediate_size=5632,
+        num_attention_heads=16,
+        num_hidden_layers=4,
+    )
+    model = Qwen3ForCausalLM(config)
+    model.save_pretrained("tiny-qwen3")
+    tokenizer.save_pretrained("tiny-qwen3")
+    model.push_to_hub("llamafactory/tiny-random-qwen3")
+    tokenizer.push_to_hub("llamafactory/tiny-random-qwen3")
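
The script above writes the checkpoint to `./tiny-qwen3` before pushing to the Hub. For a quick local sanity check of the converted checkpoint, a minimal sketch (assumes the directory the script just produced; not part of the PR itself):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiny-qwen3")
model = AutoModelForCausalLM.from_pretrained("tiny-qwen3")

# One forward pass is enough to confirm the random weights load and run.
inputs = tokenizer("hello", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

assert logits.shape[-1] == model.config.vocab_size
```
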
diff --git a/src/llamafactory/v1/core/utils/batching.py b/src/llamafactory/v1/core/utils/batching.py
index 7f4c724c6..65511288d 100644
--- a/src/llamafactory/v1/core/utils/batching.py
+++ b/src/llamafactory/v1/core/utils/batching.py
@@ -34,30 +34,29 @@ from ...accelerator.interface import DistributedInterface
 from ...config import BatchingStrategy
 from ...utils import logging
 from ...utils.helper import pad_and_truncate
-from ...utils.types import BatchInput, ModelInput, TorchDataset
+from ...utils.objects import StatefulBuffer
+from ...utils.types import BatchInfo, BatchInput, ModelInput, TorchDataset
 from .rendering import Renderer


 logger = logging.get_logger(__name__)


-def default_collate_fn(
-    buffer: list[ModelInput], buffer_tokens: int, micro_batch_size: int, num_micro_batch: int, cutoff_len: int
-) -> tuple[list[ModelInput], int, list[BatchInput]]:
+def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
+    micro_batch_size = batch_info["micro_batch_size"]
+    num_micro_batch = batch_info["num_micro_batch"]
+    cutoff_len = batch_info["cutoff_len"]
     batch_size = micro_batch_size * num_micro_batch
     if len(buffer) < batch_size:
-        return buffer, buffer_tokens, None
-
-    samples = buffer[:batch_size]
-    buffer = buffer[batch_size:]
-    buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples)
+        return None

+    samples = buffer.get(batch_size)
     batch = []
     for i in range(num_micro_batch):
         micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
         batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))

-    return buffer, buffer_tokens, batch
+    return batch


 class BatchGenerator(Iterator):
@@ -105,9 +104,14 @@ class BatchGenerator(Iterator):
         self._is_resuming: bool = False

         self._data_iter = iter(self._data_provider)
-        self._buffer: list[ModelInput] = []
-        self._buffer_tokens: int = 0
-        self._max_buffer_tokens: int = self.micro_batch_size * self.num_micro_batch * self.cutoff_len
+        self._buffer = StatefulBuffer()
+
+        self._batch_info: BatchInfo = {
+            "micro_batch_size": self.micro_batch_size,
+            "num_micro_batch": self.num_micro_batch,
+            "cutoff_len": self.cutoff_len,
+            "data_iter": self._data_iter,
+        }

         logger.info_rank0(
             f"Init unified data loader with global batch size {self.global_batch_size}, "
@@ -145,7 +149,7 @@ class BatchGenerator(Iterator):
         else:
             from ...plugins.trainer_plugins.batching import BatchingPlugin

-            self._length = BatchingPlugin(self.batching_strategy).compute_length()
+            self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
             raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")

     def __len__(self) -> int:
@@ -161,38 +165,34 @@ class BatchGenerator(Iterator):
         return self

     def __next__(self):
-        batch = self._next_batch()
+        self._fill_buffer()
+        batch = self._generate_batch()
         if batch is None:
             raise StopIteration

         return batch

-    def _next_batch(self) -> list[BatchInput] | None:
-        while self._buffer_tokens < self._max_buffer_tokens:
-            try:
-                samples: list[ModelInput] = next(self._data_iter)
-            except StopIteration:
-                break
-
-            num_tokens = sum(len(sample["input_ids"]) for sample in samples)
-            self._buffer.extend(samples)
-            self._buffer_tokens += num_tokens
-
-        return self._build_batch()
-
-    def _build_batch(self) -> list[BatchInput] | None:
+    def _fill_buffer(self) -> None:
         if self.batching_strategy == BatchingStrategy.NORMAL:
-            self._buffer, self._buffer_tokens, batch = default_collate_fn(
-                self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
-            )
-            return batch
+            while len(self._buffer) < self.micro_batch_size * self.num_micro_batch:
+                try:
+                    samples: list[ModelInput] = next(self._data_iter)
+                except StopIteration:
+                    break
+
+                self._buffer.put(samples)
         else:
             from ...plugins.trainer_plugins.batching import BatchingPlugin

-            self._buffer, self._buffer_tokens, batch = BatchingPlugin(self.batching_strategy)(
-                self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
-            )
-            return batch
+            BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
+
+    def _generate_batch(self) -> list[BatchInput] | None:
+        if self.batching_strategy == BatchingStrategy.NORMAL:
+            return default_collate_fn(self._buffer, self._batch_info)
+        else:
+            from ...plugins.trainer_plugins.batching import BatchingPlugin

+            return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)

     def state_dict(self) -> dict[str, Any]:
         return {
diff --git a/src/llamafactory/v1/plugins/model_plugins/rendering.py b/src/llamafactory/v1/plugins/model_plugins/rendering.py
index 1ada523e9..8ca8b43fc 100644
--- a/src/llamafactory/v1/plugins/model_plugins/rendering.py
+++ b/src/llamafactory/v1/plugins/model_plugins/rendering.py
@@ -22,7 +22,19 @@ from ...utils.types import Message, ModelInput, Processor, ToolCall


 class RenderingPlugin(BasePlugin):
-    pass
+    def render_messages(
+        self,
+        processor: Processor,
+        messages: list[Message],
+        tools: str | None = None,
+        is_generate: bool = False,
+    ) -> ModelInput:
+        """Render messages in the template format."""
+        return self["render_messages"](processor, messages, tools, is_generate)
+
+    def parse_messages(self, generated_text: str) -> Message:
+        """Parse messages in the template format."""
+        return self["parse_messages"](generated_text)


 def _update_model_input(
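
`RenderingPlugin` now exposes typed wrappers that dispatch to registered implementations through `self[...]`. A hypothetical registration sketch (the `my_template` name, the tag format, and the `processor.tokenizer` attribute are illustrative assumptions, not part of this PR):

```python
from llamafactory.v1.plugins.model_plugins.rendering import RenderingPlugin


@RenderingPlugin("my_template").register("render_messages")
def render_my_template(processor, messages, tools=None, is_generate=False):
    # Flatten the conversation into a single prompt string.
    prompt = "".join(f"<|{message['role']}|>{message['content']}" for message in messages)
    input_ids = processor.tokenizer.encode(prompt)
    return {"input_ids": input_ids, "labels": list(input_ids)}


@RenderingPlugin("my_template").register("parse_messages")
def parse_my_template(generated_text):
    return {"role": "assistant", "content": generated_text}


# Callers then go through the typed wrappers, e.g.:
#   RenderingPlugin("my_template").render_messages(processor, messages)
```
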
diff --git a/src/llamafactory/v1/plugins/trainer_plugins/batching.py b/src/llamafactory/v1/plugins/trainer_plugins/batching.py
index fa09dcff4..f61de78d1 100644
--- a/src/llamafactory/v1/plugins/trainer_plugins/batching.py
+++ b/src/llamafactory/v1/plugins/trainer_plugins/batching.py
@@ -12,8 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from ...utils.objects import StatefulBuffer
 from ...utils.plugin import BasePlugin
+from ...utils.types import BatchInfo, BatchInput, DataLoader


 class BatchingPlugin(BasePlugin):
-    pass
+    def compute_length(self, dataloader: DataLoader) -> int:
+        """Compute the length of the batch generator."""
+        return self["compute_length"](dataloader)
+
+    def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
+        """Fill the buffer with data."""
+        return self["fill_buffer"](buffer, batch_info)
+
+    def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
+        """Generate a batch from the buffer."""
+        return self["generate_batch"](buffer, batch_info)
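
With the hooks dispatching through `self[...]` (mirroring `RenderingPlugin`), a future strategy only needs to register implementations under its strategy name. A sketch of a hypothetical "packing" strategy's fill step (the name and the token-budget heuristic are illustrative; only the hook signatures, `StatefulBuffer`, and `BatchInfo` come from this PR):

```python
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin
from llamafactory.v1.utils.objects import StatefulBuffer
from llamafactory.v1.utils.types import BatchInfo


@BatchingPlugin("packing").register("fill_buffer")
def packing_fill_buffer(buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
    # Buffer roughly one global batch worth of tokens before packing.
    budget = batch_info["micro_batch_size"] * batch_info["num_micro_batch"] * batch_info["cutoff_len"]
    while buffer.size < budget:
        try:
            buffer.put(next(batch_info["data_iter"]))
        except StopIteration:
            break
```

`compute_length` and `generate_batch` would be registered the same way; `BatchGenerator` reaches them through the typed wrappers once the strategy is enabled.
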
diff --git a/src/llamafactory/v1/utils/objects.py b/src/llamafactory/v1/utils/objects.py
new file mode 100644
index 000000000..338f52365
--- /dev/null
+++ b/src/llamafactory/v1/utils/objects.py
@@ -0,0 +1,64 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .types import ModelInput
+
+
+class StatefulBuffer:
+    """A buffer that stores model inputs and tracks their total token count."""
+
+    def __init__(self, max_buffer_size: int = 1_000_000_000) -> None:
+        self._buffer: list[ModelInput] = []
+        self._buffer_size: int = 0
+        self._max_buffer_size: int = max_buffer_size
+
+    def __len__(self) -> int:
+        return len(self._buffer)
+
+    @property
+    def size(self) -> int:
+        return self._buffer_size
+
+    def put(self, samples: list[ModelInput]) -> None:
+        """Add samples to the buffer."""
+        num_tokens = sum(len(sample["input_ids"]) for sample in samples)
+        if self._buffer_size + num_tokens > self._max_buffer_size:
+            raise ValueError(f"Buffer size exceeds max buffer size {self._max_buffer_size}.")
+
+        self._buffer.extend(samples)
+        self._buffer_size += num_tokens
+
+    def get(self, num_samples: int) -> list[ModelInput]:
+        """Pop the first `num_samples` samples from the buffer."""
+        samples = self._buffer[:num_samples]
+        self._buffer_size -= sum(len(sample["input_ids"]) for sample in samples)
+        del self._buffer[:num_samples]
+        return samples
+
+    def clear(self) -> None:
+        """Clear the buffer."""
+        self._buffer = []
+        self._buffer_size = 0
+
+    def state_dict(self) -> dict:
+        """Return the state of the buffer."""
+        return {
+            "buffer": self._buffer,
+            "buffer_size": self._buffer_size,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        """Load the state into the buffer."""
+        self._buffer = state_dict["buffer"]
+        self._buffer_size = state_dict["buffer_size"]
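
A minimal usage sketch for `StatefulBuffer`, including the checkpoint round-trip that `BatchGenerator` relies on when resuming (toy samples; real entries are `ModelInput` dicts from the data pipeline):

```python
from llamafactory.v1.utils.objects import StatefulBuffer

buffer = StatefulBuffer()
buffer.put([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
assert len(buffer) == 2 and buffer.size == 5  # len counts samples, size counts tokens

first = buffer.get(1)  # pops the oldest sample; size drops by its token count
assert first[0]["input_ids"] == [1, 2, 3] and buffer.size == 2

state = buffer.state_dict()  # persisted alongside the generator state
resumed = StatefulBuffer()
resumed.load_state_dict(state)  # restores the un-consumed samples on resume
assert len(resumed) == 1 and resumed.size == 2
```
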
diff --git a/src/llamafactory/v1/utils/plugin.py b/src/llamafactory/v1/utils/plugin.py
index 136896492..4c06c88d3 100644
--- a/src/llamafactory/v1/utils/plugin.py
+++ b/src/llamafactory/v1/utils/plugin.py
@@ -15,6 +15,7 @@

 from collections import defaultdict
 from collections.abc import Callable
+from typing import Any

 from . import logging

@@ -26,33 +27,37 @@ class BasePlugin:
     """Base class for plugins.

     A plugin is a callable object that can be registered and called by name.
+
+    Example usage:
+    ```python
+    class PrintPlugin(BasePlugin):
+        def again(self):  # optional
+            self["again"]()
+
+
+    @PrintPlugin("hello").register()
+    def print_hello():
+        print("Hello world!")
+
+
+    @PrintPlugin("hello").register("again")
+    def print_hello_again():
+        print("Hello world! Again.")
+
+
+    PrintPlugin("hello")()
+    PrintPlugin("hello").again()
+    ```
     """

     _registry: dict[str, dict[str, Callable]] = defaultdict(dict)

-    def __init__(self, name: str | None = None):
-        """Initialize the plugin with a name.
-
-        Args:
-            name (str): The name of the plugin.
-        """
+    def __init__(self, name: str | None = None) -> None:
+        """Initialize the plugin with a name."""
         self.name = name

-    def register(self, method_name: str = "__call__"):
-        """Decorator to register a function as a plugin.
-
-        Example usage:
-        ```python
-        @PrintPlugin("hello").register()
-        def print_hello():
-            print("Hello world!")
-
-
-        @PrintPlugin("hello").register("again")
-        def print_hello_again():
-            print("Hello world! Again.")
-        ```
-        """
+    def register(self, method_name: str = "__call__") -> Callable:
+        """Decorator to register a function as a plugin."""
         if self.name is None:
             raise ValueError("Plugin name should be specified.")

@@ -65,27 +70,16 @@ class BasePlugin:

         return decorator

-    def __call__(self, *args, **kwargs):
-        """Call the registered function with the given arguments.
+    def __call__(self, *args, **kwargs) -> Any:
+        """Call the registered function with the given arguments."""
+        return self["__call__"](*args, **kwargs)
-
-        Example usage:
-        ```python
-        PrintPlugin("hello")()
-        ```
-        """
-        if "__call__" not in self._registry[self.name]:
-            raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
+    def __getattr__(self, method_name: str) -> Callable:
+        """Get the registered function with the given name."""
+        return self[method_name]

-        return self._registry[self.name]["__call__"](*args, **kwargs)
-
-    def __getattr__(self, method_name: str):
-        """Get the registered function with the given name.
-
-        Example usage:
-        ```python
-        PrintPlugin("hello").again()
-        ```
-        """
+    def __getitem__(self, method_name: str) -> Callable:
+        """Get the registered function with the given name."""
         if method_name not in self._registry[self.name]:
             raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")

@@ -98,7 +92,8 @@ if __name__ == "__main__":
     """

     class PrintPlugin(BasePlugin):
-        pass
+        def again(self):  # optional
+            self["again"]()

     @PrintPlugin("hello").register()
     def print_hello():
diff --git a/src/llamafactory/v1/utils/types.py b/src/llamafactory/v1/utils/types.py
index fdd2509b3..d2a1e52f3 100644
--- a/src/llamafactory/v1/utils/types.py
+++ b/src/llamafactory/v1/utils/types.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections.abc import Iterator
 from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union

@@ -161,3 +162,14 @@ class BatchInput(TypedDict, total=False):
     """Position ids for the model (optional)."""
     token_type_ids: NotRequired[Tensor]
     """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
+
+
+class BatchInfo(TypedDict):
+    micro_batch_size: int
+    """Micro batch size."""
+    num_micro_batch: int
+    """Number of micro batches."""
+    cutoff_len: int
+    """Cutoff length."""
+    data_iter: Iterator[list[ModelInput]]
+    """Data iterator."""
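
Tying the pieces together, the NORMAL path now reduces to `default_collate_fn(buffer, batch_info)`. An end-to-end sketch (toy samples; assumes `pad_and_truncate` accepts the plain-list `input_ids`/`labels` that the rendering pipeline produces):

```python
from llamafactory.v1.core.utils.batching import default_collate_fn
from llamafactory.v1.utils.objects import StatefulBuffer
from llamafactory.v1.utils.types import BatchInfo

buffer = StatefulBuffer()
buffer.put([{"input_ids": [1] * 8, "labels": [1] * 8} for _ in range(4)])

batch_info: BatchInfo = {
    "micro_batch_size": 2,
    "num_micro_batch": 2,
    "cutoff_len": 8,
    "data_iter": iter([]),  # consumed by _fill_buffer, not by the collate step
}

batch = default_collate_fn(buffer, batch_info)  # two micro-batches of two samples each
assert batch is not None and len(batch) == 2 and len(buffer) == 0
```
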
diff --git a/tests_v1/accelerator/test_interface.py b/tests_v1/accelerator/test_interface.py
index 38ec83b39..d3838f8b1 100644
--- a/tests_v1/accelerator/test_interface.py
+++ b/tests_v1/accelerator/test_interface.py
@@ -58,3 +58,10 @@ def test_multi_device():
     master_port = find_available_port()
     world_size = 2
     mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
+
+
+if __name__ == "__main__":
+    """
+    python tests_v1/accelerator/test_interface.py
+    """
+    test_all_device()
diff --git a/tests_v1/config/test_args_parser.py b/tests_v1/config/test_args_parser.py
index 37ba7bc36..8772f92d3 100644
--- a/tests_v1/config/test_args_parser.py
+++ b/tests_v1/config/test_args_parser.py
@@ -12,41 +12,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import pathlib
 import sys
+from pathlib import Path
 from unittest.mock import patch

 from llamafactory.v1.config.arg_parser import get_args


-def test_get_args_from_yaml(tmp_path: pathlib.Path):
+def test_get_args_from_yaml(tmp_path: Path):
     config_yaml = """
     ### model
-    model: "llamafactory/tiny-random-qwen2.5"
+    model: llamafactory/tiny-random-qwen3
     trust_remote_code: true
-    model_class: "llm"
+    model_class: llm
     kernel_config:
-      name: "auto"
-      include_kernels: "auto"  # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
+      name: auto
+      include_kernels: auto  # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
     peft_config:
-      name: "lora"
-      lora_rank: 0.8
+      name: lora
+      lora_rank: 0.8
     quant_config: null

     ### data
-    dataset: "llamafactory/tiny-supervised-dataset"
-    cutoff_len: 2048
+    dataset: llamafactory/v1-sft-demo

     ### training
-    output_dir: "outputs/test_run"
+    output_dir: outputs/test_run
     micro_batch_size: 1
     global_batch_size: 1
+    cutoff_len: 2048
     learning_rate: 1.0e-4
     bf16: false
     dist_config: null

     ### sample
-    sample_backend: "hf"
+    sample_backend: hf
     max_new_tokens: 128
     """
@@ -57,14 +57,26 @@
     with patch.object(sys, "argv", test_argv):
         data_args, model_args, training_args, sample_args = get_args()

+    assert data_args.dataset == "llamafactory/v1-sft-demo"
+    assert model_args.model == "llamafactory/tiny-random-qwen3"
+    assert model_args.kernel_config.name == "auto"
+    assert model_args.kernel_config.get("include_kernels") == "auto"
+    assert model_args.peft_config.name == "lora"
+    assert model_args.peft_config.get("lora_rank") == 0.8
     assert training_args.output_dir == "outputs/test_run"
     assert training_args.micro_batch_size == 1
     assert training_args.global_batch_size == 1
     assert training_args.learning_rate == 1.0e-4
     assert training_args.bf16 is False
     assert training_args.dist_config is None
-    assert model_args.model == "llamafactory/tiny-random-qwen2.5"
-    assert model_args.kernel_config.name == "auto"
-    assert model_args.kernel_config.get("include_kernels") == "auto"
-    assert model_args.peft_config.name == "lora"
-    assert model_args.peft_config.get("lora_rank") == 0.8
+    assert sample_args.sample_backend == "hf"
+
+
+if __name__ == "__main__":
+    """
+    python -m tests_v1.config.test_args_parser
+    """
+    import tempfile
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        test_get_args_from_yaml(tmp_path=Path(tmp_dir))
diff --git a/tests_v1/core/test_data_engine.py b/tests_v1/core/test_data_engine.py
index 65cb42668..d1b2223ad 100644
--- a/tests_v1/core/test_data_engine.py
+++ b/tests_v1/core/test_data_engine.py
@@ -33,4 +33,7 @@ def test_map_dataset(num_samples: int):


 if __name__ == "__main__":
+    """
+    python -m tests_v1.core.test_data_engine
+    """
     test_map_dataset(1)
diff --git a/tests_v1/core/test_model_loader.py b/tests_v1/core/test_model_loader.py
index bf426c93b..6228a3699 100644
--- a/tests_v1/core/test_model_loader.py
+++ b/tests_v1/core/test_model_loader.py
@@ -44,5 +44,8 @@


 if __name__ == "__main__":
+    """
+    python -m tests_v1.core.test_model_loader
+    """
     test_tiny_qwen()
     test_tiny_qwen_with_kernel_plugin()
diff --git a/tests_v1/core/utils/test_batching.py b/tests_v1/core/utils/test_batching.py
index 9d3461337..ba74d21d7 100644
--- a/tests_v1/core/utils/test_batching.py
+++ b/tests_v1/core/utils/test_batching.py
@@ -46,4 +46,7 @@


 if __name__ == "__main__":
+    """
+    python -m tests_v1.core.utils.test_batching
+    """
     test_normal_batching()
diff --git a/tests_v1/core/utils/test_rendering.py b/tests_v1/core/utils/test_rendering.py
index 40dd12532..3963ccb8a 100644
--- a/tests_v1/core/utils/test_rendering.py
+++ b/tests_v1/core/utils/test_rendering.py
@@ -219,6 +219,9 @@ def test_process_dpo_samples():


 if __name__ == "__main__":
+    """
+    python -m tests_v1.core.utils.test_rendering
+    """
     test_chatml_rendering()
     test_chatml_parse()
     test_chatml_rendering_remote(16)
diff --git a/tests_v1/plugins/data_plugins/test_converter.py b/tests_v1/plugins/data_plugins/test_converter.py
index d4106d9c1..0f47d8e55 100644
--- a/tests_v1/plugins/data_plugins/test_converter.py
+++ b/tests_v1/plugins/data_plugins/test_converter.py
@@ -120,6 +120,9 @@ def test_pair_converter(num_samples: int):


 if __name__ == "__main__":
+    """
+    python -m tests_v1.plugins.data_plugins.test_converter
+    """
     test_alpaca_converter(1)
     test_sharegpt_converter()
     test_pair_converter(1)
diff --git a/tests_v1/plugins/model_plugins/test_init_plugin.py b/tests_v1/plugins/model_plugins/test_init_plugin.py
index 1b1ec104c..7c4e154e8 100644
--- a/tests_v1/plugins/model_plugins/test_init_plugin.py
+++ b/tests_v1/plugins/model_plugins/test_init_plugin.py
@@ -52,3 +52,12 @@ def test_init_on_default():
     )
     model_engine = ModelEngine(model_args=model_args)
     assert model_engine.model.device == DistributedInterface().current_device
+
+
+if __name__ == "__main__":
+    """
+    python tests_v1/plugins/model_plugins/test_init_plugin.py
+    """
+    test_init_on_meta()
+    test_init_on_rank0()
+    test_init_on_default()
diff --git a/tests_v1/sampler/test_cli_sampler.py b/tests_v1/sampler/test_cli_sampler.py
index 69f75a88d..9f858e1f9 100644
--- a/tests_v1/sampler/test_cli_sampler.py
+++ b/tests_v1/sampler/test_cli_sampler.py
@@ -38,4 +38,7 @@


 if __name__ == "__main__":
+    """
+    python tests_v1/sampler/test_cli_sampler.py
+    """
     test_sync_sampler()