diff --git a/.gitignore b/.gitignore index f497809b7..b9013087f 100644 --- a/.gitignore +++ b/.gitignore @@ -176,6 +176,7 @@ llamaboard_cache/ llamaboard_config/ saves/ output/ +outputs/ wandb/ swanlog/ generated_predictions.jsonl diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py index b464198b2..6a4c51962 100644 --- a/src/llamafactory/v1/accelerator/interface.py +++ b/src/llamafactory/v1/accelerator/interface.py @@ -174,7 +174,7 @@ class DistributedInterface: """Get device mesh for specified dimension.""" if dim is None: raise ValueError("dim must be specified.") - elif self.model_device_mesh is None: + elif not self._is_distributed: return None elif dim in self.strategy.data_mesh_dim_names: return self.data_device_mesh[dim.value] @@ -183,14 +183,14 @@ class DistributedInterface: def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]: """Get process group for specified dimension.""" - if self.model_device_mesh is None or dim is None: + if not self._is_distributed or dim is None: return None else: return self.get_device_mesh(dim).get_group() def get_rank(self, dim: Dim | None = None) -> int: """Get parallel rank for specified dimension.""" - if self.model_device_mesh is None: + if not self._is_distributed: return 0 elif dim is None: return self._rank @@ -199,7 +199,7 @@ class DistributedInterface: def get_world_size(self, dim: Dim | None = None) -> int: """Get parallel size for specified dimension.""" - if self.model_device_mesh is None: + if not self._is_distributed: return 1 elif dim is None: return self._world_size @@ -216,7 +216,7 @@ class DistributedInterface: def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike: """Gather tensor across specified parallel group.""" - if self.model_device_mesh is not None: + if self._is_distributed: return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim)) else: return data @@ -225,29 +225,32 @@ class DistributedInterface: self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP ) -> TensorLike: """Reduce tensor across specified parallel group.""" - if self.model_device_mesh is not None: + if self._is_distributed: return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim)) else: return data def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike: """Broadcast tensor across specified parallel group.""" - if self.model_device_mesh is not None: + if self._is_distributed: return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim)) else: return data def sync(self) -> None: """Synchronize all processes.""" - helper.synchronize() + if self._is_distributed: + helper.synchronize() def barrier(self) -> None: """Barrier all processes.""" - barrier() + if self._is_distributed: + barrier() def destroy(self) -> None: """Destroy all processes.""" - destroy_process_group() + if self._is_distributed: + destroy_process_group() if __name__ == "__main__": diff --git a/src/llamafactory/v1/config/arg_parser.py b/src/llamafactory/v1/config/arg_parser.py index aee30efaf..38eddd54a 100644 --- a/src/llamafactory/v1/config/arg_parser.py +++ b/src/llamafactory/v1/config/arg_parser.py @@ -30,9 +30,9 @@ from .training_args import TrainingArguments InputArgument = dict[str, Any] | list[str] | None -def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]: +def 
get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments, TrainingArguments, SampleArguments]: """Parse arguments from command line or config file.""" - parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments]) + parser = HfArgumentParser([ModelArguments, DataArguments, TrainingArguments, SampleArguments]) allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS") if args is None: diff --git a/src/llamafactory/v1/config/data_args.py b/src/llamafactory/v1/config/data_args.py index f9b5d06a9..8693df429 100644 --- a/src/llamafactory/v1/config/data_args.py +++ b/src/llamafactory/v1/config/data_args.py @@ -18,7 +18,11 @@ from dataclasses import dataclass, field @dataclass class DataArguments: - dataset: str | None = field( + train_dataset: str | None = field( default=None, - metadata={"help": "Path to the dataset."}, + metadata={"help": "Path to the training dataset."}, + ) + eval_dataset: str | None = field( + default=None, + metadata={"help": "Path to the evaluation dataset."}, ) diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index e937170f7..67d09653c 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -33,13 +33,21 @@ class TrainingArguments: default=None, metadata={"help": "Global batch size for training, default to DP size * micro batch size."}, ) + cutoff_len: int = field( + default=2048, + metadata={"help": "Maximum sequence length for training."}, + ) learning_rate: float = field( default=1e-4, metadata={"help": "Learning rate for training."}, ) - cutoff_len: int = field( - default=2048, - metadata={"help": "Maximum sequence length for training."}, + num_train_epochs: int = field( + default=3, + metadata={"help": "Number of training epochs."}, + ) + max_grad_norm: float = field( + default=1.0, + metadata={"help": "Maximum gradient norm for training."}, ) bf16: bool = field( default=False, @@ -53,10 +61,24 @@ class TrainingArguments: default=16, metadata={"help": "Number of workers for batching."}, ) + enable_activation_checkpointing: bool = field( + default=True, + metadata={"help": "Enable activation checkpointing for training."}, + ) dist_config: PluginConfig | None = field( default=None, metadata={"help": "Distribution configuration for training."}, ) + optim_config: PluginConfig | None = field( + default=None, + metadata={"help": "Optimizer configuration for training."}, + ) + lr_scheduler_config: PluginConfig | None = field( + default=None, + metadata={"help": "Learning rate scheduler configuration for training."}, + ) def __post_init__(self) -> None: self.dist_config = get_plugin_config(self.dist_config) + self.optim_config = get_plugin_config(self.optim_config) + self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config) diff --git a/src/llamafactory/v1/core/base_sampler.py b/src/llamafactory/v1/core/base_sampler.py index efa90bec1..1080b3f13 100644 --- a/src/llamafactory/v1/core/base_sampler.py +++ b/src/llamafactory/v1/core/base_sampler.py @@ -12,115 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import asyncio -import os -from abc import ABC, abstractmethod from collections.abc import AsyncGenerator -from threading import Thread -import torch -from transformers import AsyncTextIteratorStreamer - -from ..accelerator.interface import DistributedInterface from ..config import ModelArguments, SampleArguments, SampleBackend -from ..utils.helper import get_tokenizer from ..utils.types import HFModel, Message, Sample, TorchDataset +from .utils.inference_engine import HuggingFaceEngine from .utils.rendering import Renderer -class BaseEngine(ABC): - @abstractmethod - def __init__( - self, - args: SampleArguments, - model_args: ModelArguments, - model: HFModel, - renderer: Renderer, - ) -> None: - """Initialize the engine. - - Args: - args: Sample arguments. - model_args: Model arguments. - model: Model. - renderer: Renderer. - """ - ... - - @abstractmethod - async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]: - """Generate tokens asynchronously. - - Args: - messages: List of messages. - tools: Tools string. - - Yields: - Generated tokens. - """ - ... - - @abstractmethod - async def batch_infer(self, dataset: TorchDataset) -> list[Sample]: - """Batch infer samples. - - Args: - dataset: Torch dataset. - - Returns: - List of samples. - """ - ... - - -class HuggingFaceEngine(BaseEngine): - def __init__( - self, - args: SampleArguments, - model_args: ModelArguments, - model: HFModel, - renderer: Renderer, - ) -> None: - self.args = args - self.model_args = model_args - self.model = model - self.renderer = renderer - self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1"))) - - @torch.inference_mode() - async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]: - async with self.semaphore: - model_inputs = self.renderer.render_messages(messages, tools, is_generate=True) - streamer = AsyncTextIteratorStreamer( - tokenizer=get_tokenizer(self.renderer.processor), - skip_prompt=True, - skip_special_tokens=True, # TODO: configurable - ) - device = DistributedInterface().current_device - kwargs = { - "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device), - "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device), - "max_new_tokens": self.args.max_new_tokens, - "streamer": streamer, - } - thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True) - thread.start() - - async for token in streamer: - yield token - - async def batch_infer(self, dataset: TorchDataset) -> list[Sample]: - """Batch infer samples. - - Args: - dataset: Torch dataset. - - Returns: - List of samples. - """ - raise NotImplementedError("Batch infer is not implemented.") - - class BaseSampler: """Base sampler. diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index df9af26e4..e474b37a9 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -16,42 +16,166 @@ Init Phase: -1. Init dataloader. +1. Init batch generator. 2. Init optimizer (deepspeed). 3. Shard model. 4. Init optimizer (fsdp). -5. Init scheduler. +5. Init lr scheduler. Train Phase: 1. 
Train Loop """ -from ..config.training_args import TrainingArguments -from ..utils.types import HFModel, TorchDataset +from abc import abstractmethod + +import torch +import torch.nn.functional as F + +from ..accelerator.helper import ReduceOp +from ..accelerator.interface import Dim, DistributedInterface +from ..config import TrainingArguments +from ..utils import logging +from ..utils.helper import compute_valid_tokens +from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset +from .utils.batching import BatchGenerator from .utils.rendering import Renderer +logger = logging.get_logger(__name__) + + class BaseTrainer: def __init__( self, args: TrainingArguments, model: HFModel, renderer: Renderer, - dataset: TorchDataset, + train_dataset: TorchDataset, ) -> None: self.args = args self.model = model self.renderer = renderer - self.dataset = dataset - self.optimizer = None - self.lr_scheduler = None + self.train_dataset = train_dataset - def _create_dataloader(self) -> None: + # info + self.global_step = 0 + + # cached variables + self.device = DistributedInterface().current_device + self.dp_size = DistributedInterface().get_world_size(Dim.DP) + self.model_input_names = self.renderer.processor.model_input_names + + self._create_batch_generator() + self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator) + + if self.args.enable_activation_checkpointing: + self.model.gradient_checkpointing_enable({"use_reentrant": False}) + + if self.args.dist_config is not None: + shard_need_optimizer = self.args.dist_config.name == "deepspeed" + else: + shard_need_optimizer = False + + if shard_need_optimizer: + self._init_optimizer() + self._shard_model() + else: + self._shard_model() + self._init_optimizer() + + self._init_lr_scheduler() + + def _create_batch_generator(self) -> None: + self.train_batch_generator = BatchGenerator( + dataset=self.train_dataset, + renderer=self.renderer, + micro_batch_size=self.args.micro_batch_size, + global_batch_size=self.args.global_batch_size, + cutoff_len=self.args.cutoff_len, + batching_workers=self.args.batching_workers, + batching_strategy=self.args.batching_strategy, + ) + + def _shard_model(self) -> None: pass - def _init_model_and_optimizer(self) -> None: - pass + def _init_optimizer(self) -> None: + """Init optimizer.""" + if self.args.optim_config is None: + _trainable_params = [p for p in self.model.parameters() if p.requires_grad] + self.optimizer = torch.optim.AdamW(_trainable_params, lr=self.args.learning_rate) + else: + from ..plugins.trainer_plugins.optimizer import OptimizerPlugin + + self.optimizer = OptimizerPlugin(self.args.optim_config.name)(self.model, self.args.optim_config) + + def _init_lr_scheduler(self) -> None: + """Init lr scheduler.""" + if self.args.lr_scheduler_config is None: + self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1.0) + else: + from ..plugins.trainer_plugins.lr_scheduler import LRSchedulerPlugin + + self.lr_scheduler = LRSchedulerPlugin(self.args.lr_scheduler_config.name)( + self.optimizer, self.num_training_steps, self.args.lr_scheduler_config + ) + + def compute_log_probs(self, model: HFModel, batch: BatchInput) -> Tensor: + """Compute log probs. 
+ + log_probs: Tensor of shape (batch_size, seq_len - 1) + """ + batch_size, _ = batch["labels"].shape + model_inputs = { + k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names + } + labels = batch["labels"].to(self.device, non_blocking=True) + outputs: ModelOutput = model(**model_inputs) + logits = outputs.logits.float() + shift_labels = labels[..., 1:].contiguous().view(-1) + shift_logits = logits[..., :-1, :].contiguous().view(shift_labels.size(0), -1) + return -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1) + + @abstractmethod + def compute_loss(self, batch: BatchInput) -> Tensor: + """Compute the scalar loss.""" + ... def fit(self) -> None: - pass + """Train the model.""" + self.model.train() + for epoch in range(self.args.num_train_epochs): + self.train_batch_generator.set_epoch(epoch) + for micro_batches in self.train_batch_generator: + self.global_step += 1 + step_loss = 0 + step_valid_tokens = compute_valid_tokens(micro_batches) + step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM) + for micro_batch in micro_batches: + loss = self.compute_loss(micro_batch) + mini_step_valid_tokens = compute_valid_tokens([micro_batch]) + # fsdp uses mean reduction so we need to scale the loss by dp_size + loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6) + + loss.backward() + step_loss += loss.item() + + grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item() + if not torch.isfinite(grad_norm): + logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}") + else: + self.optimizer.step() + + self.lr_scheduler.step() + self.optimizer.zero_grad() + + step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm]) + DistributedInterface().sync() + print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}") + + def save_model(self) -> None: + """Save the model.""" + self.model.save_pretrained(self.args.output_dir) + self.renderer.processor.save_pretrained(self.args.output_dir) + logger.info_rank0(f"Model saved to {self.args.output_dir}") diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py index 60a2667d0..b2cbf536e 100644 --- a/src/llamafactory/v1/core/data_engine.py +++ b/src/llamafactory/v1/core/data_engine.py @@ -15,7 +15,7 @@ """The definition of data engine. How to use: -data_engine = DataEngine(data_args) +data_engine = DataEngine(data_args.train_dataset) data_engine[i]: Get the sample via index. Init workflow: @@ -41,7 +41,6 @@ from huggingface_hub import hf_hub_download from omegaconf import OmegaConf from torch.utils.data import Dataset -from ..config.data_args import DataArguments from ..utils.types import DatasetInfo, HFDataset, Sample @@ -52,9 +51,9 @@ class DataEngine(Dataset): data_args: Data arguments. 
""" - def __init__(self, data_args: DataArguments) -> None: - self.args = data_args - """Data arguments.""" + def __init__(self, dataset_path: str) -> None: + self.path = dataset_path + """Dataset path.""" self.datasets: dict[str, HFDataset] = {} """Dict of (dataset_name, dataset)""" self.dataset_infos: dict[str, DatasetInfo] = {} @@ -69,16 +68,16 @@ class DataEngine(Dataset): def _get_dataset_info(self) -> None: """Get dataset info from data arguments.""" - if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file - self.dataset_infos = OmegaConf.load(self.args.dataset) - elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml - repo_id, filename = os.path.split(self.args.dataset) + if self.path.endswith(".yaml") and os.path.isfile(self.path): # local file + self.dataset_infos = OmegaConf.load(self.path) + elif self.path.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml + repo_id, filename = os.path.split(self.path) filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset") self.dataset_infos = OmegaConf.load(filepath) - elif os.path.exists(self.args.dataset): # local file(s) - self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}} + elif os.path.exists(self.path): # local file(s) + self.dataset_infos = {"default": {"path": self.path, "source": "local"}} else: # hf hub dataset, e.g. llamafactory/v1-sft-demo - self.dataset_infos = {"default": {"path": self.args.dataset}} + self.dataset_infos = {"default": {"path": self.path}} def _load_dataset(self) -> None: """Load datasets according to dataset info.""" @@ -187,11 +186,11 @@ class DataEngine(Dataset): if __name__ == "__main__": """ - python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml - python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml + python -m llamafactory.v1.core.data_engine --train_dataset data/v1_sft_demo.yaml + python -m llamafactory.v1.core.data_engine --train_dataset data/v1_dpo_demo.yaml """ from ..config.arg_parser import get_args - data_args, *_ = get_args() - data_engine = DataEngine(data_args=data_args) + _, data_args, *_ = get_args() + data_engine = DataEngine(data_args.train_dataset) print(data_engine[0]) diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/core/model_engine.py index 849d37fa9..0a5a00204 100644 --- a/src/llamafactory/v1/core/model_engine.py +++ b/src/llamafactory/v1/core/model_engine.py @@ -153,7 +153,7 @@ if __name__ == "__main__": """ from ..config.arg_parser import get_args - _, model_args, *_ = get_args() + model_args, *_ = get_args() model_engine = ModelEngine(model_args=model_args) print(model_engine.processor) print(model_engine.model_config) diff --git a/src/llamafactory/v1/core/utils/batching.py b/src/llamafactory/v1/core/utils/batching.py index 65511288d..100bae4e5 100644 --- a/src/llamafactory/v1/core/utils/batching.py +++ b/src/llamafactory/v1/core/utils/batching.py @@ -216,7 +216,7 @@ if __name__ == "__main__": """ python -m llamafactory.v1.core.utils.batching \ --model llamafactory/tiny-random-qwen2.5 \ - --dataset data/v1_sft_demo.yaml \ + --train_dataset data/v1_sft_demo.yaml \ --micro_batch_size 2 \ --global_batch_size 4 \ --batching_workers 0 @@ -225,8 +225,8 @@ if __name__ == "__main__": from ..data_engine import DataEngine from ..model_engine import ModelEngine - data_args, model_args, training_args, _ = get_args() - data_engine = 
DataEngine(data_args=data_args) + model_args, data_args, training_args, _ = get_args() + data_engine = DataEngine(data_args.train_dataset) model_engine = ModelEngine(model_args=model_args) batch_generator = BatchGenerator( data_engine, diff --git a/src/llamafactory/v1/core/utils/data_collator.py b/src/llamafactory/v1/core/utils/data_collator.py deleted file mode 100644 index a425f02f9..000000000 --- a/src/llamafactory/v1/core/utils/data_collator.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import defaultdict -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any - -import torch -import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data._utils.collate import default_collate - -from ....extras.constants import IGNORE_INDEX -from ...plugins.data_plugins.template import Template -from ...utils.types import Processor, Tensor - - -def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor": # FIXME move to utils - """Convert sequence lengths to cumulative sequence lengths.""" - return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32) - - -class DataCollator: - """Default Data collator.""" - - processor: "Processor" # processor name -> map to encode_messages function - - def __post_init__(self): - # callback for text tokenizer - self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor - - def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]: - """Collate features into a batch.""" - batch = defaultdict(list) - - # batching features - for feature in features: - for key in feature.keys(): - batch[key].append(feature[key]) - - for key in batch.keys(): - # process padding features - if key in ["input_ids", "attention_mask", "position_ids"]: - padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0 - batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value) - elif key in ["labels"]: - batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX) - else: - batch[key] = default_collate(batch[key]) - - return batch - # sft: messages - # dpo: chosen_messages, rejected_messages - - -@dataclass -class DefaultCollator(DataCollator): - """Example for now.""" - - processor: "Processor" # processor name -> map to encode_messages function - template: "Template" - - def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]: - features = [] - - # Check if data is already tokenized (contains input_ids) - if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]: - for feature in messages: - if not isinstance(feature, dict): - raise ValueError(f"Expected dict but got {type(feature)}") - tensor_feature = { - k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v - for k, v in feature.items() - } - features.append(tensor_feature) - else: - # raw 
messages need to be encoded - for message in messages: - encoded_message = self.template.encode_messages(self.tokenizer, message) - encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()} - features.append(encoded_message) - - return super().__call__(features) - - -@dataclass -class PairwiseCollator(DataCollator): - pass - - -@dataclass -class DataCollatorWithPacking(DefaultCollator): - """Data collator with packing.""" - - processor: "Processor" - template: "Template" - - def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]: - seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long) - batch = {"cu_seqlens": len2culen(seqlens)} - for input_name in features[0].keys(): - if input_name in ("input_ids", "attention_mask", "labels"): - batch[input_name] = torch.cat([feature[input_name] for feature in features]) - else: - batch[input_name] = default_collate([feature[input_name] for feature in features]) - - return batch diff --git a/src/llamafactory/v1/core/utils/inference_engine.py b/src/llamafactory/v1/core/utils/inference_engine.py new file mode 100644 index 000000000..b090dae81 --- /dev/null +++ b/src/llamafactory/v1/core/utils/inference_engine.py @@ -0,0 +1,121 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator +from threading import Thread + +import torch +from transformers import AsyncTextIteratorStreamer + +from ...accelerator.interface import DistributedInterface +from ...config import ModelArguments, SampleArguments +from ...utils.helper import get_tokenizer +from ...utils.types import HFModel, Message, Sample, TorchDataset +from .rendering import Renderer + + +class BaseEngine(ABC): + @abstractmethod + def __init__( + self, + args: SampleArguments, + model_args: ModelArguments, + model: HFModel, + renderer: Renderer, + ) -> None: + """Initialize the engine. + + Args: + args: Sample arguments. + model_args: Model arguments. + model: Model. + renderer: Renderer. + """ + ... + + @abstractmethod + async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]: + """Generate tokens asynchronously. + + Args: + messages: List of messages. + tools: Tools string. + + Yields: + Generated tokens. + """ + ... + + @abstractmethod + async def batch_infer(self, dataset: TorchDataset) -> list[Sample]: + """Batch infer samples. + + Args: + dataset: Torch dataset. + + Returns: + List of samples. + """ + ... 
+ + +class HuggingFaceEngine(BaseEngine): + def __init__( + self, + args: SampleArguments, + model_args: ModelArguments, + model: HFModel, + renderer: Renderer, + ) -> None: + self.args = args + self.model_args = model_args + self.model = model + self.renderer = renderer + self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1"))) + + @torch.inference_mode() + async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]: + async with self.semaphore: + model_inputs = self.renderer.render_messages(messages, tools, is_generate=True) + streamer = AsyncTextIteratorStreamer( + tokenizer=get_tokenizer(self.renderer.processor), + skip_prompt=True, + skip_special_tokens=True, # TODO: configurable + ) + device = DistributedInterface().current_device + kwargs = { + "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device), + "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device), + "max_new_tokens": self.args.max_new_tokens, + "streamer": streamer, + } + thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True) + thread.start() + + async for token in streamer: + yield token + + async def batch_infer(self, dataset: TorchDataset) -> list[Sample]: + """Batch infer samples. + + Args: + dataset: Torch dataset. + + Returns: + List of samples. + """ + raise NotImplementedError("Batch infer is not implemented.") diff --git a/src/llamafactory/v1/core/utils/lr_scheduler.py b/src/llamafactory/v1/core/utils/lr_scheduler.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/llamafactory/v1/core/utils/rendering.py b/src/llamafactory/v1/core/utils/rendering.py index b4c0b02d6..cbe8383c4 100644 --- a/src/llamafactory/v1/core/utils/rendering.py +++ b/src/llamafactory/v1/core/utils/rendering.py @@ -142,8 +142,8 @@ class Renderer: elif "chosen_messages" in sample and "rejected_messages" in sample: chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools")) rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools")) - chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"]) - rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"]) + chosen_input["token_type_ids"] = [1] * len(chosen_input["input_ids"]) + rejected_input["token_type_ids"] = [2] * len(rejected_input["input_ids"]) model_input = ModelInput( input_ids=chosen_input["input_ids"] + rejected_input["input_ids"], attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"], diff --git a/src/llamafactory/v1/plugins/trainer_plugins/batching.py b/src/llamafactory/v1/plugins/trainer_plugins/batching.py index f61de78d1..aef22eac2 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/batching.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/batching.py @@ -18,8 +18,11 @@ from ...utils.types import BatchInfo, BatchInput, DataLoader class BatchingPlugin(BasePlugin): - def compute_length(self, dataloader: DataLoader) -> int: - """Compute the length of the batch generator.""" + def compute_length(self, data_provider: DataLoader) -> int: + """Compute the length of the batch generator. + + The approximate length is used to calculate the lr schedule. 
+ """ raise NotImplementedError() def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None: diff --git a/src/llamafactory/v1/plugins/trainer_plugins/lr_scheduler.py b/src/llamafactory/v1/plugins/trainer_plugins/lr_scheduler.py new file mode 100644 index 000000000..02c9e8b03 --- /dev/null +++ b/src/llamafactory/v1/plugins/trainer_plugins/lr_scheduler.py @@ -0,0 +1,19 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils.plugin import BasePlugin + + +class LRSchedulerPlugin(BasePlugin): + pass diff --git a/src/llamafactory/v1/plugins/trainer_plugins/optimizer.py b/src/llamafactory/v1/plugins/trainer_plugins/optimizer.py new file mode 100644 index 000000000..d040b0e29 --- /dev/null +++ b/src/llamafactory/v1/plugins/trainer_plugins/optimizer.py @@ -0,0 +1,19 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ...utils.plugin import BasePlugin + + +class OptimizerPlugin(BasePlugin): + pass diff --git a/src/llamafactory/v1/samplers/cli_sampler.py b/src/llamafactory/v1/samplers/cli_sampler.py index d102ac84b..40647165f 100644 --- a/src/llamafactory/v1/samplers/cli_sampler.py +++ b/src/llamafactory/v1/samplers/cli_sampler.py @@ -73,14 +73,14 @@ class SyncSampler(BaseSampler): def run_chat(args: InputArgument = None): - data_args, model_args, _, sample_args = get_args(args) + model_args, data_args, _, sample_args = get_args(args) if sample_args.sample_backend != SampleBackend.HF: model_args.init_plugin = {"name": "init_on_meta"} model_engine = ModelEngine(model_args) sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer) - if data_args.dataset is not None: - dataset = DataEngine(data_args) + if data_args.train_dataset is not None: + dataset = DataEngine(data_args.train_dataset) sampler.batch_infer(dataset) else: if os.name != "nt": diff --git a/src/llamafactory/v1/trainers/sft_trainer.py b/src/llamafactory/v1/trainers/sft_trainer.py index 3a91c4f4c..898b54f98 100644 --- a/src/llamafactory/v1/trainers/sft_trainer.py +++ b/src/llamafactory/v1/trainers/sft_trainer.py @@ -18,21 +18,35 @@ from ..config import InputArgument, get_args from ..core.base_trainer import BaseTrainer from ..core.data_engine import DataEngine from ..core.model_engine import ModelEngine +from ..utils.types import BatchInput, Tensor class SFTTrainer(BaseTrainer): - pass + def compute_loss(self, batch: BatchInput) -> Tensor: + shift_loss_weights = batch["loss_weights"].to(self.device, non_blocking=True)[..., 1:] + log_probs = self.compute_log_probs(self.model, batch) + loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6) + return loss def run_sft(args: InputArgument = None): model_args, data_args, training_args, _ = get_args(args) DistributedInterface(training_args.dist_config) - data_engine = DataEngine(data_args) + train_dataset = DataEngine(data_args.train_dataset) model_engine = ModelEngine(model_args) trainer = SFTTrainer( args=training_args, model=model_engine.model, renderer=model_engine.renderer, - dataset=data_engine, + train_dataset=train_dataset, ) trainer.fit() + trainer.save_model() + DistributedInterface().destroy() + + +if __name__ == "__main__": + """ + python -m llamafactory.v1.trainers.sft_trainer --model Qwen/Qwen3-0.6B --train_dataset data/v1_sft_demo.yaml + """ + run_sft() diff --git a/src/llamafactory/v1/utils/helper.py b/src/llamafactory/v1/utils/helper.py index 8b94e71a7..3f7b75505 100644 --- a/src/llamafactory/v1/utils/helper.py +++ b/src/llamafactory/v1/utils/helper.py @@ -16,6 +16,7 @@ import torch from transformers import PreTrainedTokenizer +from ..accelerator.interface import DistributedInterface from .constants import IGNORE_INDEX from .types import BatchInput, ModelInput, Processor, Tensor @@ -73,3 +74,20 @@ def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchIn padded_samples.append(padded_sample) return padded_samples + + +def compute_valid_tokens(batches: list[BatchInput]) -> int: + """Compute valid tokens in batches. + + Args: + batches: Batches. + + Returns: + Number of valid tokens. 
+ """ + device = DistributedInterface().current_device + return sum( + (batch["labels"].to(device, non_blocking=True) != IGNORE_INDEX).sum().item() + for batch in batches + if "labels" in batch + ) diff --git a/src/llamafactory/v1/utils/types.py b/src/llamafactory/v1/utils/types.py index d2a1e52f3..2f3906968 100644 --- a/src/llamafactory/v1/utils/types.py +++ b/src/llamafactory/v1/utils/types.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Iterator -from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union if TYPE_CHECKING: @@ -146,7 +146,7 @@ class ModelInput(TypedDict, total=False): position_ids: NotRequired[list[int] | list[list[int]]] """Position ids for the model (optional).""" token_type_ids: NotRequired[list[int]] - """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages.""" + """Token type ids used in DPO, 1 represents the chosen messages, 2 represents the rejected messages.""" class BatchInput(TypedDict, total=False): @@ -161,7 +161,7 @@ class BatchInput(TypedDict, total=False): position_ids: NotRequired[Tensor] """Position ids for the model (optional).""" token_type_ids: NotRequired[Tensor] - """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages.""" + """Token type ids used in DPO, 1 represents the chosen messages, 2 represents the rejected messages.""" class BatchInfo(TypedDict): @@ -173,3 +173,8 @@ class BatchInfo(TypedDict): """Cutoff length.""" data_iter: Iterator[list[ModelInput]] """Data iterator.""" + + +class ModelOutput(NamedTuple): + logits: Tensor + """Logits for the model.""" diff --git a/tests/version.txt b/tests/version.txt index e082f373b..a1f9032b0 100644 --- a/tests/version.txt +++ b/tests/version.txt @@ -1,2 +1,2 @@ # change if test fails or cache is outdated -0.9.5.104 +0.9.5.105 diff --git a/tests_v1/config/test_args_parser.py b/tests_v1/config/test_args_parser.py index 8772f92d3..db235ab54 100644 --- a/tests_v1/config/test_args_parser.py +++ b/tests_v1/config/test_args_parser.py @@ -34,7 +34,7 @@ def test_get_args_from_yaml(tmp_path: Path): quant_config: null ### data - dataset: llamafactory/v1-sft-demo + train_dataset: llamafactory/v1-sft-demo ### training output_dir: outputs/test_run @@ -56,8 +56,8 @@ def test_get_args_from_yaml(tmp_path: Path): test_argv = ["test_args_parser.py", str(config_file)] with patch.object(sys, "argv", test_argv): - data_args, model_args, training_args, sample_args = get_args() - assert data_args.dataset == "llamafactory/v1-sft-demo" + model_args, data_args, training_args, sample_args = get_args() + assert data_args.train_dataset == "llamafactory/v1-sft-demo" assert model_args.model == "llamafactory/tiny-random-qwen3" assert model_args.kernel_config.name == "auto" assert model_args.kernel_config.get("include_kernels") == "auto" diff --git a/tests_v1/core/test_data_engine.py b/tests_v1/core/test_data_engine.py index d1b2223ad..373069a66 100644 --- a/tests_v1/core/test_data_engine.py +++ b/tests_v1/core/test_data_engine.py @@ -23,8 +23,8 @@ from llamafactory.v1.core.data_engine import DataEngine @pytest.mark.parametrize("num_samples", [16]) def test_map_dataset(num_samples: int): - data_args = DataArguments(dataset="llamafactory/v1-sft-demo") - data_engine = DataEngine(data_args) + data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo") + data_engine = DataEngine(data_args.train_dataset) 
original_data = load_dataset("llamafactory/v1-sft-demo", split="train") indexes = random.choices(range(len(data_engine)), k=num_samples) for index in indexes: diff --git a/tests_v1/core/utils/test_batching.py b/tests_v1/core/utils/test_batching.py index ba74d21d7..87e8a89cb 100644 --- a/tests_v1/core/utils/test_batching.py +++ b/tests_v1/core/utils/test_batching.py @@ -19,8 +19,8 @@ from llamafactory.v1.core.utils.batching import BatchGenerator def test_normal_batching(): - data_args = DataArguments(dataset="llamafactory/v1-sft-demo") - data_engine = DataEngine(data_args=data_args) + data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo") + data_engine = DataEngine(data_args.train_dataset) model_args = ModelArguments(model="llamafactory/tiny-random-qwen3") model_engine = ModelEngine(model_args=model_args) training_args = TrainingArguments( diff --git a/tests_v1/core/utils/test_rendering.py b/tests_v1/core/utils/test_rendering.py index 3963ccb8a..f3e5f83c6 100644 --- a/tests_v1/core/utils/test_rendering.py +++ b/tests_v1/core/utils/test_rendering.py @@ -111,8 +111,8 @@ def test_chatml_parse(): def test_chatml_rendering_remote(num_samples: int): tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3") renderer = Renderer(template="chatml", processor=tokenizer) - data_args = DataArguments(dataset="llamafactory/v1-sft-demo") - data_engine = DataEngine(data_args) + data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo") + data_engine = DataEngine(data_args.train_dataset) for index in range(num_samples): v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True) prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False) @@ -167,8 +167,8 @@ def test_qwen3_nothink_parse(): def test_qwen3_nothink_rendering_remote(num_samples: int): tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507") renderer = Renderer(template="qwen3_nothink", processor=tokenizer) - data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500") - data_engine = DataEngine(data_args) + data_args = DataArguments(train_dataset="llamafactory/reason-tool-use-demo-1500") + data_engine = DataEngine(data_args.train_dataset) for index in range(num_samples): v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"]) prefix_text = ( @@ -213,7 +213,7 @@ def test_process_dpo_samples(): model_inputs = renderer.process_samples(samples) assert len(model_inputs) == 1 assert model_inputs[0]["input_ids"] == hf_inputs * 2 - assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs) + assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs) assert model_inputs[0]["extra_info"] == "test" assert model_inputs[0]["_dataset_name"] == "default" diff --git a/tests_v1/plugins/data_plugins/test_converter.py b/tests_v1/plugins/data_plugins/test_converter.py index 0f47d8e55..1722b4a67 100644 --- a/tests_v1/plugins/data_plugins/test_converter.py +++ b/tests_v1/plugins/data_plugins/test_converter.py @@ -24,8 +24,8 @@ from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin @pytest.mark.parametrize("num_samples", [16]) def test_alpaca_converter(num_samples: int): - data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml") - data_engine = DataEngine(data_args) + data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml") + 
data_engine = DataEngine(data_args.train_dataset) original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train") indexes = random.choices(range(len(data_engine)), k=num_samples) for index in indexes: @@ -73,8 +73,8 @@ def test_sharegpt_converter(): @pytest.mark.parametrize("num_samples", [16]) def test_pair_converter(num_samples: int): - data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml") - data_engine = DataEngine(data_args) + data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml") + data_engine = DataEngine(data_args.train_dataset) original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs") indexes = random.choices(range(len(data_engine)), k=num_samples) for index in indexes: diff --git a/tests_v1/plugins/model_plugins/test_init_plugin.py b/tests_v1/plugins/model_plugins/test_init_plugin.py index 7c4e154e8..947f18bd9 100644 --- a/tests_v1/plugins/model_plugins/test_init_plugin.py +++ b/tests_v1/plugins/model_plugins/test_init_plugin.py @@ -19,7 +19,7 @@ from llamafactory.v1.core.model_engine import ModelEngine def test_init_on_meta(): - _, model_args, *_ = get_args( + model_args, *_ = get_args( dict( model="llamafactory/tiny-random-qwen3", init_config={"name": "init_on_meta"}, @@ -30,7 +30,7 @@ def test_init_on_meta(): def test_init_on_rank0(): - _, model_args, *_ = get_args( + model_args, *_ = get_args( dict( model="llamafactory/tiny-random-qwen3", init_config={"name": "init_on_rank0"}, @@ -44,7 +44,7 @@ def test_init_on_rank0(): def test_init_on_default(): - _, model_args, *_ = get_args( + model_args, *_ = get_args( dict( model="llamafactory/tiny-random-qwen3", init_config={"name": "init_on_default"}, diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py index 90524b7ba..10e8e413a 100644 --- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py +++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py @@ -43,7 +43,8 @@ def test_apply_kernel(mock_get_accelerator: MagicMock): reload_kernels() from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels - model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3") + # NOTE: use a special model to avoid contamination by other tests + model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5") original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward original_swiglu_forward = model.model.layers[0].mlp.forward model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm") @@ -62,7 +63,8 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock): reload_kernels() from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels - model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3") + # NOTE: use a special model to avoid contamination by other tests + model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5") original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward original_swiglu_forward = model.model.layers[0].mlp.forward
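
Note on the DistributedInterface changes above: every collective now checks _is_distributed instead of model_device_mesh, so sync/barrier/destroy and the tensor collectives degrade to no-ops in single-process runs. A minimal sketch of the resulting behavior, assuming the no-argument constructor acts as the accessor used throughout the diff:

from llamafactory.v1.accelerator.interface import DistributedInterface

dist = DistributedInterface()      # launched without torchrun -> not distributed
assert dist.get_world_size() == 1  # falls back to 1 instead of reading a device mesh
assert dist.get_rank() == 0
loss = dist.all_reduce(0.5)        # data is returned unchanged when not distributed
dist.barrier()                     # no-op instead of requiring an initialized process group
dist.destroy()                     # safe to call unconditionally at the end of run_sft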
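
Note on the scaling inside BaseTrainer.fit: each micro-batch loss is multiplied by mini_step_valid_tokens * dp_size / step_valid_tokens, so after FSDP's mean gradient reduction (which divides by dp_size) the accumulated gradient matches a token-weighted mean over the whole global batch. A small worked example with made-up numbers:

dp_size = 2
local_token_counts = [100, 200]          # valid tokens per micro-batch on this rank
remote_tokens = 300                      # tokens contributed by the other DP rank
step_valid_tokens = sum(local_token_counts) + remote_tokens  # SUM all-reduce -> 600

micro_batch_means = [2.0, 1.0]           # token-mean loss of each micro-batch (compute_loss)
scaled = [
    m * n * dp_size / (step_valid_tokens + 1e-6)
    for m, n in zip(micro_batch_means, local_token_counts)
]
# sum(scaled) / dp_size == (2.0 * 100 + 1.0 * 200) / 600, i.e. this rank's share of the
# token-weighted global mean: the dp_size factor cancels FSDP's mean gradient reduction.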
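
Note on the rendering change: the chosen/rejected markers in token_type_ids move from 0/1 to 1/2. The diff does not state why; a plausible reading is that 0 is being freed for padding or non-preference tokens. A hypothetical illustration (not code from this PR) of splitting a packed preference example by these markers:

import torch

token_type_ids = torch.tensor([1, 1, 1, 2, 2, 2, 0])   # chosen, rejected, padding (assumed)
per_token_logps = torch.tensor([-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, 0.0])
chosen_logp = per_token_logps[token_type_ids == 1].sum()
rejected_logp = per_token_logps[token_type_ids == 2].sum()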