[v1] add sft (#9752)

This commit is contained in:
Yaowei Zheng
2026-01-12 03:15:01 +08:00
committed by GitHub
parent 4d3621e3d3
commit 958b9c3468
29 changed files with 439 additions and 305 deletions

.gitignore (vendored), 1 line added
View File

@@ -176,6 +176,7 @@ llamaboard_cache/
llamaboard_config/
saves/
output/
outputs/
wandb/
swanlog/
generated_predictions.jsonl

View File

@@ -174,7 +174,7 @@ class DistributedInterface:
"""Get device mesh for specified dimension."""
if dim is None:
raise ValueError("dim must be specified.")
elif self.model_device_mesh is None:
elif not self._is_distributed:
return None
elif dim in self.strategy.data_mesh_dim_names:
return self.data_device_mesh[dim.value]
@@ -183,14 +183,14 @@ class DistributedInterface:
def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
"""Get process group for specified dimension."""
if self.model_device_mesh is None or dim is None:
if not self._is_distributed or dim is None:
return None
else:
return self.get_device_mesh(dim).get_group()
def get_rank(self, dim: Dim | None = None) -> int:
"""Get parallel rank for specified dimension."""
if self.model_device_mesh is None:
if not self._is_distributed:
return 0
elif dim is None:
return self._rank
@@ -199,7 +199,7 @@ class DistributedInterface:
def get_world_size(self, dim: Dim | None = None) -> int:
"""Get parallel size for specified dimension."""
if self.model_device_mesh is None:
if not self._is_distributed:
return 1
elif dim is None:
return self._world_size
@@ -216,7 +216,7 @@ class DistributedInterface:
def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
"""Gather tensor across specified parallel group."""
if self.model_device_mesh is not None:
if self._is_distributed:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
else:
return data
@@ -225,29 +225,32 @@ class DistributedInterface:
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
) -> TensorLike:
"""Reduce tensor across specified parallel group."""
if self.model_device_mesh is not None:
if self._is_distributed:
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
else:
return data
def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
"""Broadcast tensor across specified parallel group."""
if self.model_device_mesh is not None:
if self._is_distributed:
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
else:
return data
def sync(self) -> None:
"""Synchronize all processes."""
helper.synchronize()
if self._is_distributed:
helper.synchronize()
def barrier(self) -> None:
"""Barrier all processes."""
barrier()
if self._is_distributed:
barrier()
def destroy(self) -> None:
"""Destroy all processes."""
destroy_process_group()
if self._is_distributed:
destroy_process_group()
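The net effect of the new _is_distributed guard is that every collective degrades to a safe default in single-process runs. A minimal sketch of the expected fallbacks (single-process run assumed, import path taken from the relative imports elsewhere in this PR):

from llamafactory.v1.accelerator.interface import DistributedInterface

di = DistributedInterface()               # single-process run assumed
assert di.get_rank() == 0 and di.get_world_size() == 1
assert di.get_group() is None             # no process group outside torch.distributed
data = di.all_reduce([1.0, 2.0])          # data is returned unchanged
di.sync(); di.barrier(); di.destroy()     # all three are no-ops when not distributed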
if __name__ == "__main__":

View File

@@ -30,9 +30,9 @@ from .training_args import TrainingArguments
InputArgument = dict[str, Any] | list[str] | None
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments, TrainingArguments, SampleArguments]:
"""Parse arguments from command line or config file."""
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
parser = HfArgumentParser([ModelArguments, DataArguments, TrainingArguments, SampleArguments])
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
if args is None:
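Since the return order now mirrors the parser order (model arguments first), call sites unpack accordingly. A small sketch with hypothetical values, using the dict form of InputArgument:

from llamafactory.v1.config.arg_parser import get_args

# model args now come first, then data, training and sample args
model_args, data_args, training_args, sample_args = get_args(
    dict(model="llamafactory/tiny-random-qwen3", train_dataset="data/v1_sft_demo.yaml")
)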

View File

@@ -18,7 +18,11 @@ from dataclasses import dataclass, field
@dataclass
class DataArguments:
dataset: str | None = field(
train_dataset: str | None = field(
default=None,
metadata={"help": "Path to the dataset."},
metadata={"help": "Path to the training dataset."},
)
eval_dataset: str | None = field(
default=None,
metadata={"help": "Path to the evaluation dataset."},
)

View File

@@ -33,13 +33,21 @@ class TrainingArguments:
default=None,
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
)
learning_rate: float = field(
default=1e-4,
metadata={"help": "Learning rate for training."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
num_train_epochs: int = field(
default=3,
metadata={"help": "Number of training epochs."},
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm for training."},
)
bf16: bool = field(
default=False,
@@ -53,10 +61,24 @@ class TrainingArguments:
default=16,
metadata={"help": "Number of workers for batching."},
)
enable_activation_checkpointing: bool = field(
default=True,
metadata={"help": "Enable activation checkpointing for training."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)
optim_config: PluginConfig | None = field(
default=None,
metadata={"help": "Optimizer configuration for training."},
)
lr_scheduler_config: PluginConfig | None = field(
default=None,
metadata={"help": "Learning rate scheduler configuration for training."},
)
def __post_init__(self) -> None:
self.dist_config = get_plugin_config(self.dist_config)
self.optim_config = get_plugin_config(self.optim_config)
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
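The new knobs map one-to-one onto config keys. A hedged sketch of training overrides built from the fields above (values are hypothetical):

training_overrides = dict(
    learning_rate=1e-5,                    # default 1e-4
    num_train_epochs=1,                    # default 3
    max_grad_norm=1.0,
    cutoff_len=4096,
    bf16=True,
    enable_activation_checkpointing=True,
    # leaving optim_config / lr_scheduler_config unset keeps the built-in
    # AdamW optimizer and the constant LambdaLR schedule
)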

View File

@@ -12,115 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from threading import Thread
import torch
from transformers import AsyncTextIteratorStreamer
from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.helper import get_tokenizer
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.inference_engine import HuggingFaceEngine
from .utils.rendering import Renderer
class BaseEngine(ABC):
@abstractmethod
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
"""Initialize the engine.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
...
@abstractmethod
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
...
@abstractmethod
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
...
class HuggingFaceEngine(BaseEngine):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
self.args = args
self.model_args = model_args
self.model = model
self.renderer = renderer
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@torch.inference_mode()
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = AsyncTextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()
async for token in streamer:
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
raise NotImplementedError("Batch infer is not implemented.")
class BaseSampler:
"""Base sampler.

View File

@@ -16,42 +16,166 @@
Init Phase:
1. Init dataloader.
1. Init batch generator.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
5. Init scheduler.
5. Init lr scheduler.
Train Phase:
1. Train Loop
"""
from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, TorchDataset
from abc import abstractmethod
import torch
import torch.nn.functional as F
from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..utils import logging
from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator
from .utils.rendering import Renderer
logger = logging.get_logger(__name__)
class BaseTrainer:
def __init__(
self,
args: TrainingArguments,
model: HFModel,
renderer: Renderer,
dataset: TorchDataset,
train_dataset: TorchDataset,
) -> None:
self.args = args
self.model = model
self.renderer = renderer
self.dataset = dataset
self.optimizer = None
self.lr_scheduler = None
self.train_dataset = train_dataset
def _create_dataloader(self) -> None:
# training progress
self.global_step = 0
# cached variables
self.device = DistributedInterface().current_device
self.dp_size = DistributedInterface().get_world_size(Dim.DP)
self.model_input_names = self.renderer.processor.model_input_names
self._create_batch_generator()
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False})
if self.args.dist_config is not None:
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
else:
shard_need_optimizer = False
if shard_need_optimizer:
self._init_optimizer()
self._shard_model()
else:
self._shard_model()
self._init_optimizer()
self._init_lr_scheduler()
def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
renderer=self.renderer,
micro_batch_size=self.args.micro_batch_size,
global_batch_size=self.args.global_batch_size,
cutoff_len=self.args.cutoff_len,
batching_workers=self.args.batching_workers,
batching_strategy=self.args.batching_strategy,
)
def _shard_model(self) -> None:
pass
def _init_model_and_optimizer(self) -> None:
pass
def _init_optimizer(self) -> None:
"""Init optimizer."""
if self.args.optim_config is None:
_trainable_params = [p for p in self.model.parameters() if p.requires_grad]
self.optimizer = torch.optim.AdamW(_trainable_params, lr=self.args.learning_rate)
else:
from ..plugins.trainer_plugins.optimizer import OptimizerPlugin
self.optimizer = OptimizerPlugin(self.args.optim_config.name)(self.model, self.args.optim_config)
def _init_lr_scheduler(self) -> None:
"""Init lr scheduler."""
if self.args.lr_scheduler_config is None:
self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1.0)
else:
from ..plugins.trainer_plugins.lr_scheduler import LRSchedulerPlugin
self.lr_scheduler = LRSchedulerPlugin(self.args.lr_scheduler_config.name)(
self.optimizer, self.num_training_steps, self.args.lr_scheduler_config
)
def compute_log_probs(self, model: HFModel, batch: BatchInput) -> Tensor:
"""Compute log probs.
log_probs: Tensor of shape (batch_size, seq_len - 1)
"""
batch_size, _ = batch["labels"].shape
model_inputs = {
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
}
labels = batch["labels"].to(self.device, non_blocking=True)
outputs: ModelOutput = model(**model_inputs)
logits = outputs.logits.float()
shift_labels = labels[..., 1:].contiguous().view(-1)
shift_logits = logits[..., :-1, :].contiguous().view(shift_labels.size(0), -1)
return -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
@abstractmethod
def compute_loss(self, batch: BatchInput) -> Tensor:
"""Compute the scalar loss."""
...
def fit(self) -> None:
pass
"""Train the model."""
self.model.train()
for epoch in range(self.args.num_train_epochs):
self.train_batch_generator.set_epoch(epoch)
for micro_batches in self.train_batch_generator:
self.global_step += 1
step_loss = 0
step_valid_tokens = compute_valid_tokens(micro_batches)
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
for micro_batch in micro_batches:
loss = self.compute_loss(micro_batch)
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
# FSDP mean-reduces gradients across DP ranks, so scale the loss by dp_size to keep a global per-token average
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
loss.backward()
step_loss += loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
if not torch.isfinite(grad_norm):
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
def save_model(self) -> None:
"""Save the model."""
self.model.save_pretrained(self.args.output_dir)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -15,7 +15,7 @@
"""The definition of data engine.
How to use:
data_engine = DataEngine(data_args)
data_engine = DataEngine(data_args.train_dataset)
data_engine[i]: Get the sample via index.
Init workflow:
@@ -41,7 +41,6 @@ from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch.utils.data import Dataset
from ..config.data_args import DataArguments
from ..utils.types import DatasetInfo, HFDataset, Sample
@@ -52,9 +51,9 @@ class DataEngine(Dataset):
data_args: Data arguments.
"""
def __init__(self, data_args: DataArguments) -> None:
self.args = data_args
"""Data arguments."""
def __init__(self, dataset_path: str) -> None:
self.path = dataset_path
"""Dataset path."""
self.datasets: dict[str, HFDataset] = {}
"""Dict of (dataset_name, dataset)"""
self.dataset_infos: dict[str, DatasetInfo] = {}
@@ -69,16 +68,16 @@ class DataEngine(Dataset):
def _get_dataset_info(self) -> None:
"""Get dataset info from data arguments."""
if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file
self.dataset_infos = OmegaConf.load(self.args.dataset)
elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
repo_id, filename = os.path.split(self.args.dataset)
if self.path.endswith(".yaml") and os.path.isfile(self.path): # local file
self.dataset_infos = OmegaConf.load(self.path)
elif self.path.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
repo_id, filename = os.path.split(self.path)
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
self.dataset_infos = OmegaConf.load(filepath)
elif os.path.exists(self.args.dataset): # local file(s)
self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
elif os.path.exists(self.path): # local file(s)
self.dataset_infos = {"default": {"path": self.path, "source": "local"}}
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"path": self.args.dataset}}
self.dataset_infos = {"default": {"path": self.path}}
def _load_dataset(self) -> None:
"""Load datasets according to dataset info."""
@@ -187,11 +186,11 @@ class DataEngine(Dataset):
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
python -m llamafactory.v1.core.data_engine --train_dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --train_dataset data/v1_dpo_demo.yaml
"""
from ..config.arg_parser import get_args
data_args, *_ = get_args()
data_engine = DataEngine(data_args=data_args)
_, data_args, *_ = get_args()
data_engine = DataEngine(data_args.train_dataset)
print(data_engine[0])
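A short usage sketch of the new constructor (paths and ids are hypothetical): the engine now takes a dataset path or hub uri directly instead of a DataArguments object.

from llamafactory.v1.core.data_engine import DataEngine

train_dataset = DataEngine("llamafactory/v1-sft-demo")   # hub id, or a local yaml/file path
print(len(train_dataset), train_dataset[0])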

View File

@@ -153,7 +153,7 @@ if __name__ == "__main__":
"""
from ..config.arg_parser import get_args
_, model_args, *_ = get_args()
model_args, *_ = get_args()
model_engine = ModelEngine(model_args=model_args)
print(model_engine.processor)
print(model_engine.model_config)

View File

@@ -216,7 +216,7 @@ if __name__ == "__main__":
"""
python -m llamafactory.v1.core.utils.batching \
--model llamafactory/tiny-random-qwen2.5 \
--dataset data/v1_sft_demo.yaml \
--train_dataset data/v1_sft_demo.yaml \
--micro_batch_size 2 \
--global_batch_size 4 \
--batching_workers 0
@@ -225,8 +225,8 @@ if __name__ == "__main__":
from ..data_engine import DataEngine
from ..model_engine import ModelEngine
data_args, model_args, training_args, _ = get_args()
data_engine = DataEngine(data_args=data_args)
model_args, data_args, training_args, _ = get_args()
data_engine = DataEngine(data_args.train_dataset)
model_engine = ModelEngine(model_args=model_args)
batch_generator = BatchGenerator(
data_engine,

View File

@@ -1,119 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate
from ....extras.constants import IGNORE_INDEX
from ...plugins.data_plugins.template import Template
from ...utils.types import Processor, Tensor
def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor": # FIXME move to utils
"""Convert sequence lengths to cumulative sequence lengths."""
return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
class DataCollator:
"""Default Data collator."""
processor: "Processor" # processor name -> map to encode_messages function
def __post_init__(self):
# callback for text tokenizer
self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
"""Collate features into a batch."""
batch = defaultdict(list)
# batching features
for feature in features:
for key in feature.keys():
batch[key].append(feature[key])
for key in batch.keys():
# process padding features
if key in ["input_ids", "attention_mask", "position_ids"]:
padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
elif key in ["labels"]:
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
else:
batch[key] = default_collate(batch[key])
return batch
# sft: messages
# dpo: chosen_messages, rejected_messages
@dataclass
class DefaultCollator(DataCollator):
"""Example for now."""
processor: "Processor" # processor name -> map to encode_messages function
template: "Template"
def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
features = []
# Check if data is already tokenized (contains input_ids)
if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
for feature in messages:
if not isinstance(feature, dict):
raise ValueError(f"Expected dict but got {type(feature)}")
tensor_feature = {
k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
for k, v in feature.items()
}
features.append(tensor_feature)
else:
# raw messages need to be encoded
for message in messages:
encoded_message = self.template.encode_messages(self.tokenizer, message)
encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
features.append(encoded_message)
return super().__call__(features)
@dataclass
class PairwiseCollator(DataCollator):
pass
@dataclass
class DataCollatorWithPacking(DefaultCollator):
"""Data collator with packing."""
processor: "Processor"
template: "Template"
def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
batch = {"cu_seqlens": len2culen(seqlens)}
for input_name in features[0].keys():
if input_name in ("input_ids", "attention_mask", "labels"):
batch[input_name] = torch.cat([feature[input_name] for feature in features])
else:
batch[input_name] = default_collate([feature[input_name] for feature in features])
return batch

View File

@@ -0,0 +1,121 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from threading import Thread
import torch
from transformers import AsyncTextIteratorStreamer
from ...accelerator.interface import DistributedInterface
from ...config import ModelArguments, SampleArguments
from ...utils.helper import get_tokenizer
from ...utils.types import HFModel, Message, Sample, TorchDataset
from .rendering import Renderer
class BaseEngine(ABC):
@abstractmethod
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
"""Initialize the engine.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
...
@abstractmethod
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
...
@abstractmethod
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
...
class HuggingFaceEngine(BaseEngine):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
self.args = args
self.model_args = model_args
self.model = model
self.renderer = renderer
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@torch.inference_mode()
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = AsyncTextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()
async for token in streamer:
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
raise NotImplementedError("Batch infer is not implemented.")

View File

@@ -142,8 +142,8 @@ class Renderer:
elif "chosen_messages" in sample and "rejected_messages" in sample:
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"])
rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"])
chosen_input["token_type_ids"] = [1] * len(chosen_input["input_ids"])
rejected_input["token_type_ids"] = [2] * len(rejected_input["input_ids"])
model_input = ModelInput(
input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],

View File

@@ -18,8 +18,11 @@ from ...utils.types import BatchInfo, BatchInput, DataLoader
class BatchingPlugin(BasePlugin):
def compute_length(self, dataloader: DataLoader) -> int:
"""Compute the length of the batch generator."""
def compute_length(self, data_provider: DataLoader) -> int:
"""Compute the length of the batch generator.
The approximate length is used to calculate the lr schedule.
"""
raise NotImplementedError()
def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:

View File

@@ -0,0 +1,19 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
class LRSchedulerPlugin(BasePlugin):
pass

View File

@@ -0,0 +1,19 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
class OptimizerPlugin(BasePlugin):
pass

View File

@@ -73,14 +73,14 @@ class SyncSampler(BaseSampler):
def run_chat(args: InputArgument = None):
data_args, model_args, _, sample_args = get_args(args)
model_args, data_args, _, sample_args = get_args(args)
if sample_args.sample_backend != SampleBackend.HF:
model_args.init_plugin = {"name": "init_on_meta"}
model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
if data_args.dataset is not None:
dataset = DataEngine(data_args)
if data_args.train_dataset is not None:
dataset = DataEngine(data_args.train_dataset)
sampler.batch_infer(dataset)
else:
if os.name != "nt":

View File

@@ -18,21 +18,35 @@ from ..config import InputArgument, get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..utils.types import BatchInput, Tensor
class SFTTrainer(BaseTrainer):
pass
def compute_loss(self, batch: BatchInput) -> Tensor:
shift_loss_weights = batch["loss_weights"].to(self.device, non_blocking=True)[..., 1:]
log_probs = self.compute_log_probs(self.model, batch)
loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)
return loss
def run_sft(args: InputArgument = None):
model_args, data_args, training_args, _ = get_args(args)
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
train_dataset = DataEngine(data_args.train_dataset)
model_engine = ModelEngine(model_args)
trainer = SFTTrainer(
args=training_args,
model=model_engine.model,
renderer=model_engine.renderer,
dataset=data_engine,
train_dataset=train_dataset,
)
trainer.fit()
trainer.save_model()
DistributedInterface().destroy()
if __name__ == "__main__":
"""
python -m llamafactory.v1.trainers.sft_trainer --model Qwen/Qwen3-0.6B --train_dataset data/v1_sft_demo.yaml
"""
run_sft()
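Besides the CLI entry point above, run_sft can be driven programmatically through the dict form of InputArgument; a sketch with hypothetical values (field names follow the dataclasses in this PR):

from llamafactory.v1.trainers.sft_trainer import run_sft

run_sft(
    dict(
        model="Qwen/Qwen3-0.6B",
        train_dataset="data/v1_sft_demo.yaml",
        output_dir="outputs/sft_demo",
        micro_batch_size=2,
        num_train_epochs=1,
    )
)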

View File

@@ -16,6 +16,7 @@
import torch
from transformers import PreTrainedTokenizer
from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
@@ -73,3 +74,20 @@ def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchIn
padded_samples.append(padded_sample)
return padded_samples
def compute_valid_tokens(batches: list[BatchInput]) -> int:
"""Compute valid tokens in batches.
Args:
batches: Batches.
Returns:
Number of valid tokens.
"""
device = DistributedInterface().current_device
return sum(
(batch["labels"].to(device, non_blocking=True) != IGNORE_INDEX).sum().item()
for batch in batches
if "labels" in batch
)
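A toy check of the helper; IGNORE_INDEX is assumed to be -100 here (the usual convention), the real value is imported from .constants above.

import torch

batch = {"labels": torch.tensor([[-100, -100, 5, 7], [-100, 9, -100, -100]])}
# compute_valid_tokens([batch]) counts labels != IGNORE_INDEX, i.e. 3 valid tokens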

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union
if TYPE_CHECKING:
@@ -146,7 +146,7 @@ class ModelInput(TypedDict, total=False):
position_ids: NotRequired[list[int] | list[list[int]]]
"""Position ids for the model (optional)."""
token_type_ids: NotRequired[list[int]]
"""Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
"""Token type ids used in DPO, 1 represents the chosen messages, 2 represents the rejected messages."""
class BatchInput(TypedDict, total=False):
@@ -161,7 +161,7 @@ class BatchInput(TypedDict, total=False):
position_ids: NotRequired[Tensor]
"""Position ids for the model (optional)."""
token_type_ids: NotRequired[Tensor]
"""Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
"""Token type ids used in DPO, 1 represents the chosen messages, 2 represents the rejected messages."""
class BatchInfo(TypedDict):
@@ -173,3 +173,8 @@ class BatchInfo(TypedDict):
"""Cutoff length."""
data_iter: Iterator[list[ModelInput]]
"""Data iterator."""
class ModelOutput(NamedTuple):
logits: Tensor
"""Logits for the model."""

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.5.104
0.9.5.105

View File

@@ -34,7 +34,7 @@ def test_get_args_from_yaml(tmp_path: Path):
quant_config: null
### data
dataset: llamafactory/v1-sft-demo
train_dataset: llamafactory/v1-sft-demo
### training
output_dir: outputs/test_run
@@ -56,8 +56,8 @@ def test_get_args_from_yaml(tmp_path: Path):
test_argv = ["test_args_parser.py", str(config_file)]
with patch.object(sys, "argv", test_argv):
data_args, model_args, training_args, sample_args = get_args()
assert data_args.dataset == "llamafactory/v1-sft-demo"
model_args, data_args, training_args, sample_args = get_args()
assert data_args.train_dataset == "llamafactory/v1-sft-demo"
assert model_args.model == "llamafactory/tiny-random-qwen3"
assert model_args.kernel_config.name == "auto"
assert model_args.kernel_config.get("include_kernels") == "auto"

View File

@@ -23,8 +23,8 @@ from llamafactory.v1.core.data_engine import DataEngine
@pytest.mark.parametrize("num_samples", [16])
def test_map_dataset(num_samples: int):
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args.train_dataset)
original_data = load_dataset("llamafactory/v1-sft-demo", split="train")
indexes = random.choices(range(len(data_engine)), k=num_samples)
for index in indexes:

View File

@@ -19,8 +19,8 @@ from llamafactory.v1.core.utils.batching import BatchGenerator
def test_normal_batching():
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args=data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args.train_dataset)
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
model_engine = ModelEngine(model_args=model_args)
training_args = TrainingArguments(

View File

@@ -111,8 +111,8 @@ def test_chatml_parse():
def test_chatml_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args.train_dataset)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
@@ -167,8 +167,8 @@ def test_qwen3_nothink_parse():
def test_qwen3_nothink_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/reason-tool-use-demo-1500")
data_engine = DataEngine(data_args.train_dataset)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
prefix_text = (
@@ -213,7 +213,7 @@ def test_process_dpo_samples():
model_inputs = renderer.process_samples(samples)
assert len(model_inputs) == 1
assert model_inputs[0]["input_ids"] == hf_inputs * 2
assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs)
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"

View File

@@ -24,8 +24,8 @@ from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin
@pytest.mark.parametrize("num_samples", [16])
def test_alpaca_converter(num_samples: int):
data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
data_engine = DataEngine(data_args.train_dataset)
original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
indexes = random.choices(range(len(data_engine)), k=num_samples)
for index in indexes:
@@ -73,8 +73,8 @@ def test_sharegpt_converter():
@pytest.mark.parametrize("num_samples", [16])
def test_pair_converter(num_samples: int):
data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
data_engine = DataEngine(data_args.train_dataset)
original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
indexes = random.choices(range(len(data_engine)), k=num_samples)
for index in indexes:

View File

@@ -19,7 +19,7 @@ from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta():
_, model_args, *_ = get_args(
model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_meta"},
@@ -30,7 +30,7 @@ def test_init_on_meta():
def test_init_on_rank0():
_, model_args, *_ = get_args(
model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_rank0"},
@@ -44,7 +44,7 @@ def test_init_on_rank0():
def test_init_on_default():
_, model_args, *_ = get_args(
model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_default"},

View File

@@ -43,7 +43,8 @@ def test_apply_kernel(mock_get_accelerator: MagicMock):
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
# NOTE: use a different model to avoid contamination from other tests
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
@@ -62,7 +63,8 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
# NOTE: use a different model to avoid contamination from other tests
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward