mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-12 09:00:35 +08:00
[v1] add sft (#9752)
.gitignore (vendored): 1 addition
@@ -176,6 +176,7 @@ llamaboard_cache/
llamaboard_config/
saves/
output/
outputs/
wandb/
swanlog/
generated_predictions.jsonl
@@ -174,7 +174,7 @@ class DistributedInterface:
        """Get device mesh for specified dimension."""
        if dim is None:
            raise ValueError("dim must be specified.")
-        elif self.model_device_mesh is None:
+        elif not self._is_distributed:
            return None
        elif dim in self.strategy.data_mesh_dim_names:
            return self.data_device_mesh[dim.value]
@@ -183,14 +183,14 @@ class DistributedInterface:

    def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
        """Get process group for specified dimension."""
-        if self.model_device_mesh is None or dim is None:
+        if not self._is_distributed or dim is None:
            return None
        else:
            return self.get_device_mesh(dim).get_group()

    def get_rank(self, dim: Dim | None = None) -> int:
        """Get parallel rank for specified dimension."""
-        if self.model_device_mesh is None:
+        if not self._is_distributed:
            return 0
        elif dim is None:
            return self._rank
@@ -199,7 +199,7 @@ class DistributedInterface:

    def get_world_size(self, dim: Dim | None = None) -> int:
        """Get parallel size for specified dimension."""
-        if self.model_device_mesh is None:
+        if not self._is_distributed:
            return 1
        elif dim is None:
            return self._world_size
@@ -216,7 +216,7 @@ class DistributedInterface:

    def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
        """Gather tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
            return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
        else:
            return data
@@ -225,29 +225,32 @@ class DistributedInterface:
        self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
    ) -> TensorLike:
        """Reduce tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
            return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
        else:
            return data

    def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
        """Broadcast tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
            return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
        else:
            return data

    def sync(self) -> None:
        """Synchronize all processes."""
-        helper.synchronize()
+        if self._is_distributed:
+            helper.synchronize()

    def barrier(self) -> None:
        """Barrier all processes."""
-        barrier()
+        if self._is_distributed:
+            barrier()

    def destroy(self) -> None:
        """Destroy all processes."""
-        destroy_process_group()
+        if self._is_distributed:
+            destroy_process_group()


if __name__ == "__main__":
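
Taken together, these hunks replace every `self.model_device_mesh is None` check with a single `self._is_distributed` flag, so collectives, sync, barrier and destroy all short-circuit cleanly when no process group exists. A minimal sketch of the pattern (the flag initialization below is an assumption for illustration; it is not shown in this diff):

    import torch.distributed as dist

    class _DistributedInterfaceSketch:
        def __init__(self) -> None:
            # Assumed: the flag reflects whether a process group is initialized.
            self._is_distributed = dist.is_available() and dist.is_initialized()

        def get_rank(self) -> int:
            # Single-process runs return rank 0 without touching any device mesh.
            return dist.get_rank() if self._is_distributed else 0
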
@@ -30,9 +30,9 @@ from .training_args import TrainingArguments
InputArgument = dict[str, Any] | list[str] | None


-def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
+def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments, TrainingArguments, SampleArguments]:
    """Parse arguments from command line or config file."""
-    parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
+    parser = HfArgumentParser([ModelArguments, DataArguments, TrainingArguments, SampleArguments])
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")

    if args is None:
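
The return order of get_args changes here: ModelArguments now comes first, then DataArguments. A short usage sketch consistent with the new signature (illustrative only; the same unpacking order is used throughout the updated callers and tests in this commit):

    # Unpack in the new (model, data, training, sample) order.
    model_args, data_args, training_args, sample_args = get_args()
    print(model_args.model, data_args.train_dataset, training_args.learning_rate)
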
@@ -18,7 +18,11 @@ from dataclasses import dataclass, field

@dataclass
class DataArguments:
-    dataset: str | None = field(
+    train_dataset: str | None = field(
        default=None,
-        metadata={"help": "Path to the dataset."},
+        metadata={"help": "Path to the training dataset."},
    )
+    eval_dataset: str | None = field(
+        default=None,
+        metadata={"help": "Path to the evaluation dataset."},
+    )
@@ -33,13 +33,21 @@ class TrainingArguments:
        default=None,
        metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
    )
    cutoff_len: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length for training."},
    )
    learning_rate: float = field(
        default=1e-4,
        metadata={"help": "Learning rate for training."},
    )
    cutoff_len: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length for training."},
    num_train_epochs: int = field(
        default=3,
        metadata={"help": "Number of training epochs."},
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Maximum gradient norm for training."},
    )
    bf16: bool = field(
        default=False,
@@ -53,10 +61,24 @@ class TrainingArguments:
        default=16,
        metadata={"help": "Number of workers for batching."},
    )
    enable_activation_checkpointing: bool = field(
        default=True,
        metadata={"help": "Enable activation checkpointing for training."},
    )
    dist_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Distribution configuration for training."},
    )
    optim_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Optimizer configuration for training."},
    )
    lr_scheduler_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Learning rate scheduler configuration for training."},
    )

    def __post_init__(self) -> None:
        self.dist_config = get_plugin_config(self.dist_config)
        self.optim_config = get_plugin_config(self.optim_config)
        self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
@@ -12,115 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from threading import Thread

import torch
from transformers import AsyncTextIteratorStreamer

from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.helper import get_tokenizer
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.inference_engine import HuggingFaceEngine
from .utils.rendering import Renderer


class BaseEngine(ABC):
    @abstractmethod
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        """Initialize the engine.

        Args:
            args: Sample arguments.
            model_args: Model arguments.
            model: Model.
            renderer: Renderer.
        """
        ...

    @abstractmethod
    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        """Generate tokens asynchronously.

        Args:
            messages: List of messages.
            tools: Tools string.

        Yields:
            Generated tokens.
        """
        ...

    @abstractmethod
    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        ...


class HuggingFaceEngine(BaseEngine):
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        self.args = args
        self.model_args = model_args
        self.model = model
        self.renderer = renderer
        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @torch.inference_mode()
    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        async with self.semaphore:
            model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
            streamer = AsyncTextIteratorStreamer(
                tokenizer=get_tokenizer(self.renderer.processor),
                skip_prompt=True,
                skip_special_tokens=True,  # TODO: configurable
            )
            device = DistributedInterface().current_device
            kwargs = {
                "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
                "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
                "max_new_tokens": self.args.max_new_tokens,
                "streamer": streamer,
            }
            thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
            thread.start()

            async for token in streamer:
                yield token

    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        raise NotImplementedError("Batch infer is not implemented.")


class BaseSampler:
    """Base sampler.
@@ -16,42 +16,166 @@

Init Phase:

-1. Init dataloader.
+1. Init batch generator.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
-5. Init scheduler.
+5. Init lr scheduler.

Train Phase:
1. Train Loop

"""

-from ..config.training_args import TrainingArguments
-from ..utils.types import HFModel, TorchDataset
from abc import abstractmethod

import torch
import torch.nn.functional as F

from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..utils import logging
from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator
from .utils.rendering import Renderer


logger = logging.get_logger(__name__)


class BaseTrainer:
    def __init__(
        self,
        args: TrainingArguments,
        model: HFModel,
        renderer: Renderer,
-        dataset: TorchDataset,
+        train_dataset: TorchDataset,
    ) -> None:
        self.args = args
        self.model = model
        self.renderer = renderer
-        self.dataset = dataset
+        self.optimizer = None
+        self.lr_scheduler = None
+        self.train_dataset = train_dataset

-    def _create_dataloader(self) -> None:
        # info
        self.global_step = 0

        # cached variables
        self.device = DistributedInterface().current_device
        self.dp_size = DistributedInterface().get_world_size(Dim.DP)
        self.model_input_names = self.renderer.processor.model_input_names

        self._create_batch_generator()
        self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)

        if self.args.enable_activation_checkpointing:
            self.model.gradient_checkpointing_enable({"use_reentrant": False})

        if self.args.dist_config is not None:
            shard_need_optimizer = self.args.dist_config.name == "deepspeed"
        else:
            shard_need_optimizer = False

        if shard_need_optimizer:
            self._init_optimizer()
            self._shard_model()
        else:
            self._shard_model()
            self._init_optimizer()

        self._init_lr_scheduler()

    def _create_batch_generator(self) -> None:
        self.train_batch_generator = BatchGenerator(
            dataset=self.train_dataset,
            renderer=self.renderer,
            micro_batch_size=self.args.micro_batch_size,
            global_batch_size=self.args.global_batch_size,
            cutoff_len=self.args.cutoff_len,
            batching_workers=self.args.batching_workers,
            batching_strategy=self.args.batching_strategy,
        )

    def _shard_model(self) -> None:
        pass

-    def _init_model_and_optimizer(self) -> None:
-        pass
    def _init_optimizer(self) -> None:
        """Init optimizer."""
        if self.args.optim_config is None:
            _trainable_params = [p for p in self.model.parameters() if p.requires_grad]
            self.optimizer = torch.optim.AdamW(_trainable_params, lr=self.args.learning_rate)
        else:
            from ..plugins.trainer_plugins.optimizer import OptimizerPlugin

            self.optimizer = OptimizerPlugin(self.args.optim_config.name)(self.model, self.args.optim_config)

    def _init_lr_scheduler(self) -> None:
        """Init lr scheduler."""
        if self.args.lr_scheduler_config is None:
            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1.0)
        else:
            from ..plugins.trainer_plugins.lr_scheduler import LRSchedulerPlugin

            self.lr_scheduler = LRSchedulerPlugin(self.args.lr_scheduler_config.name)(
                self.optimizer, self.num_training_steps, self.args.lr_scheduler_config
            )

    def compute_log_probs(self, model: HFModel, batch: BatchInput) -> Tensor:
        """Compute log probs.

        log_probs: Tensor of shape (batch_size, seq_len - 1)
        """
        batch_size, _ = batch["labels"].shape
        model_inputs = {
            k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
        }
        labels = batch["labels"].to(self.device, non_blocking=True)
        outputs: ModelOutput = model(**model_inputs)
        logits = outputs.logits.float()
        shift_labels = labels[..., 1:].contiguous().view(-1)
        shift_logits = logits[..., :-1, :].contiguous().view(shift_labels.size(0), -1)
        return -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
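
In other words, compute_log_probs scores each label token against the logits of the previous position, so a sequence of length L yields L - 1 per-token log probabilities. A small illustration of the shift (shapes only, assuming batch_size = 1):

    # labels: [l0, l1, l2, l3]      -> shift_labels: [l1, l2, l3]
    # logits: positions 0..3        -> shift_logits: positions 0..2
    # log_probs[i] = log p(l[i+1] | tokens up to i), shape (1, 3)
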

    @abstractmethod
    def compute_loss(self, batch: BatchInput) -> Tensor:
        """Compute the scalar loss."""
        ...

    def fit(self) -> None:
-        pass
        """Train the model."""
        self.model.train()
        for epoch in range(self.args.num_train_epochs):
            self.train_batch_generator.set_epoch(epoch)
            for micro_batches in self.train_batch_generator:
                self.global_step += 1
                step_loss = 0
                step_valid_tokens = compute_valid_tokens(micro_batches)
                step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
                for micro_batch in micro_batches:
                    loss = self.compute_loss(micro_batch)
                    mini_step_valid_tokens = compute_valid_tokens([micro_batch])
                    # fsdp uses mean reduction so we need to scale the loss by dp_size
                    loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)

                    loss.backward()
                    step_loss += loss.item()

                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
                if not torch.isfinite(grad_norm):
                    logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
                else:
                    self.optimizer.step()

                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
                DistributedInterface().sync()
                print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")

    def save_model(self) -> None:
        """Save the model."""
        self.model.save_pretrained(self.args.output_dir)
        self.renderer.processor.save_pretrained(self.args.output_dir)
        logger.info_rank0(f"Model saved to {self.args.output_dir}")
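
The scaling inside the micro-batch loop makes gradient accumulation token-weighted rather than micro-batch-weighted. A worked example under the comment's FSDP mean-reduction assumption: with dp_size = 2 and a global step containing 100 valid tokens across all ranks, a micro batch holding 30 of those tokens contributes loss * 30 * 2 / 100; after FSDP averages gradients over the 2 ranks, every valid token ends up carrying weight 1/100, so per-token weighting is preserved regardless of how tokens are split across micro batches and ranks.
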
@@ -15,7 +15,7 @@
"""The definition of data engine.

How to use:
-    data_engine = DataEngine(data_args)
+    data_engine = DataEngine(data_args.train_dataset)
    data_engine[i]: Get the sample via index.

Init workflow:
@@ -41,7 +41,6 @@ from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch.utils.data import Dataset

-from ..config.data_args import DataArguments
from ..utils.types import DatasetInfo, HFDataset, Sample


@@ -52,9 +51,9 @@ class DataEngine(Dataset):
        data_args: Data arguments.
    """

-    def __init__(self, data_args: DataArguments) -> None:
-        self.args = data_args
-        """Data arguments."""
+    def __init__(self, dataset_path: str) -> None:
+        self.path = dataset_path
+        """Dataset path."""
        self.datasets: dict[str, HFDataset] = {}
        """Dict of (dataset_name, dataset)"""
        self.dataset_infos: dict[str, DatasetInfo] = {}
@@ -69,16 +68,16 @@ class DataEngine(Dataset):

    def _get_dataset_info(self) -> None:
        """Get dataset info from data arguments."""
-        if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset):  # local file
-            self.dataset_infos = OmegaConf.load(self.args.dataset)
-        elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
-            repo_id, filename = os.path.split(self.args.dataset)
+        if self.path.endswith(".yaml") and os.path.isfile(self.path):  # local file
+            self.dataset_infos = OmegaConf.load(self.path)
+        elif self.path.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
+            repo_id, filename = os.path.split(self.path)
            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
            self.dataset_infos = OmegaConf.load(filepath)
-        elif os.path.exists(self.args.dataset):  # local file(s)
-            self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
+        elif os.path.exists(self.path):  # local file(s)
+            self.dataset_infos = {"default": {"path": self.path, "source": "local"}}
        else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
-            self.dataset_infos = {"default": {"path": self.args.dataset}}
+            self.dataset_infos = {"default": {"path": self.path}}

    def _load_dataset(self) -> None:
        """Load datasets according to dataset info."""
@@ -187,11 +186,11 @@ class DataEngine(Dataset):

if __name__ == "__main__":
    """
-    python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
-    python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
+    python -m llamafactory.v1.core.data_engine --train_dataset data/v1_sft_demo.yaml
+    python -m llamafactory.v1.core.data_engine --train_dataset data/v1_dpo_demo.yaml
    """
    from ..config.arg_parser import get_args

-    data_args, *_ = get_args()
-    data_engine = DataEngine(data_args=data_args)
+    _, data_args, *_ = get_args()
+    data_engine = DataEngine(data_args.train_dataset)
    print(data_engine[0])
@@ -153,7 +153,7 @@ if __name__ == "__main__":
    """
    from ..config.arg_parser import get_args

-    _, model_args, *_ = get_args()
+    model_args, *_ = get_args()
    model_engine = ModelEngine(model_args=model_args)
    print(model_engine.processor)
    print(model_engine.model_config)
@@ -216,7 +216,7 @@ if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.utils.batching \
        --model llamafactory/tiny-random-qwen2.5 \
-        --dataset data/v1_sft_demo.yaml \
+        --train_dataset data/v1_sft_demo.yaml \
        --micro_batch_size 2 \
        --global_batch_size 4 \
        --batching_workers 0
@@ -225,8 +225,8 @@ if __name__ == "__main__":
    from ..data_engine import DataEngine
    from ..model_engine import ModelEngine

-    data_args, model_args, training_args, _ = get_args()
-    data_engine = DataEngine(data_args=data_args)
+    model_args, data_args, training_args, _ = get_args()
+    data_engine = DataEngine(data_args.train_dataset)
    model_engine = ModelEngine(model_args=model_args)
    batch_generator = BatchGenerator(
        data_engine,
@@ -1,119 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate

from ....extras.constants import IGNORE_INDEX
from ...plugins.data_plugins.template import Template
from ...utils.types import Processor, Tensor


def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor":  # FIXME move to utils
    """Convert sequence lengths to cumulative sequence lengths."""
    return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
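
For instance, sequence lengths [3, 2, 4] become cumulative boundaries [0, 3, 5, 9], the cu_seqlens layout consumed by packed-attention kernels. A quick check of the helper above:

    import torch
    seqlens = torch.tensor([3, 2, 4])
    print(len2culen(seqlens))  # tensor([0, 3, 5, 9], dtype=torch.int32)
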


class DataCollator:
    """Default Data collator."""

    processor: "Processor"  # processor name -> map to encode_messages function

    def __post_init__(self):
        # callback for text tokenizer
        self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
        """Collate features into a batch."""
        batch = defaultdict(list)

        # batching features
        for feature in features:
            for key in feature.keys():
                batch[key].append(feature[key])

        for key in batch.keys():
            # process padding features
            if key in ["input_ids", "attention_mask", "position_ids"]:
                padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
                batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
            elif key in ["labels"]:
                batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
            else:
                batch[key] = default_collate(batch[key])

        return batch
        # sft: messages
        # dpo: chosen_messages, rejected_messages


@dataclass
class DefaultCollator(DataCollator):
    """Example for now."""

    processor: "Processor"  # processor name -> map to encode_messages function
    template: "Template"

    def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
        features = []

        # Check if data is already tokenized (contains input_ids)
        if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
            for feature in messages:
                if not isinstance(feature, dict):
                    raise ValueError(f"Expected dict but got {type(feature)}")
                tensor_feature = {
                    k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
                    for k, v in feature.items()
                }
                features.append(tensor_feature)
        else:
            # raw messages need to be encoded
            for message in messages:
                encoded_message = self.template.encode_messages(self.tokenizer, message)
                encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
                features.append(encoded_message)

        return super().__call__(features)


@dataclass
class PairwiseCollator(DataCollator):
    pass


@dataclass
class DataCollatorWithPacking(DefaultCollator):
    """Data collator with packing."""

    processor: "Processor"
    template: "Template"

    def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
        seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
        batch = {"cu_seqlens": len2culen(seqlens)}
        for input_name in features[0].keys():
            if input_name in ("input_ids", "attention_mask", "labels"):
                batch[input_name] = torch.cat([feature[input_name] for feature in features])
            else:
                batch[input_name] = default_collate([feature[input_name] for feature in features])

        return batch
src/llamafactory/v1/core/utils/inference_engine.py (new file, 121 lines)
@@ -0,0 +1,121 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from threading import Thread

import torch
from transformers import AsyncTextIteratorStreamer

from ...accelerator.interface import DistributedInterface
from ...config import ModelArguments, SampleArguments
from ...utils.helper import get_tokenizer
from ...utils.types import HFModel, Message, Sample, TorchDataset
from .rendering import Renderer


class BaseEngine(ABC):
    @abstractmethod
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        """Initialize the engine.

        Args:
            args: Sample arguments.
            model_args: Model arguments.
            model: Model.
            renderer: Renderer.
        """
        ...

    @abstractmethod
    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        """Generate tokens asynchronously.

        Args:
            messages: List of messages.
            tools: Tools string.

        Yields:
            Generated tokens.
        """
        ...

    @abstractmethod
    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        ...


class HuggingFaceEngine(BaseEngine):
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        self.args = args
        self.model_args = model_args
        self.model = model
        self.renderer = renderer
        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @torch.inference_mode()
    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        async with self.semaphore:
            model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
            streamer = AsyncTextIteratorStreamer(
                tokenizer=get_tokenizer(self.renderer.processor),
                skip_prompt=True,
                skip_special_tokens=True,  # TODO: configurable
            )
            device = DistributedInterface().current_device
            kwargs = {
                "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
                "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
                "max_new_tokens": self.args.max_new_tokens,
                "streamer": streamer,
            }
            thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
            thread.start()

            async for token in streamer:
                yield token

    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        raise NotImplementedError("Batch infer is not implemented.")
@@ -142,8 +142,8 @@ class Renderer:
        elif "chosen_messages" in sample and "rejected_messages" in sample:
            chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
            rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
-            chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"])
-            rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"])
+            chosen_input["token_type_ids"] = [1] * len(chosen_input["input_ids"])
+            rejected_input["token_type_ids"] = [2] * len(rejected_input["input_ids"])
            model_input = ModelInput(
                input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
                attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
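
With this change the chosen and rejected halves of a DPO sample are tagged 1 and 2 instead of 0 and 1. For example, a chosen sequence of 3 tokens concatenated with a rejected sequence of 2 tokens now yields token_type_ids [1, 1, 1, 2, 2], matching the updated assertion in the rendering tests below.
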
@@ -18,8 +18,11 @@ from ...utils.types import BatchInfo, BatchInput, DataLoader


class BatchingPlugin(BasePlugin):
-    def compute_length(self, dataloader: DataLoader) -> int:
-        """Compute the length of the batch generator."""
+    def compute_length(self, data_provider: DataLoader) -> int:
+        """Compute the length of the batch generator.
+
+        The approximate length is used to calculate the lr schedule.
+        """
        raise NotImplementedError()

    def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
src/llamafactory/v1/plugins/trainer_plugins/lr_scheduler.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils.plugin import BasePlugin


class LRSchedulerPlugin(BasePlugin):
    pass
src/llamafactory/v1/plugins/trainer_plugins/optimizer.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils.plugin import BasePlugin


class OptimizerPlugin(BasePlugin):
    pass
@@ -73,14 +73,14 @@ class SyncSampler(BaseSampler):


def run_chat(args: InputArgument = None):
-    data_args, model_args, _, sample_args = get_args(args)
+    model_args, data_args, _, sample_args = get_args(args)
    if sample_args.sample_backend != SampleBackend.HF:
        model_args.init_plugin = {"name": "init_on_meta"}

    model_engine = ModelEngine(model_args)
    sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
-    if data_args.dataset is not None:
-        dataset = DataEngine(data_args)
+    if data_args.train_dataset is not None:
+        dataset = DataEngine(data_args.train_dataset)
        sampler.batch_infer(dataset)
    else:
        if os.name != "nt":
@@ -18,21 +18,35 @@ from ..config import InputArgument, get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
+from ..utils.types import BatchInput, Tensor


class SFTTrainer(BaseTrainer):
-    pass
+    def compute_loss(self, batch: BatchInput) -> Tensor:
+        shift_loss_weights = batch["loss_weights"].to(self.device, non_blocking=True)[..., 1:]
+        log_probs = self.compute_log_probs(self.model, batch)
+        loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)
+        return loss


def run_sft(args: InputArgument = None):
    model_args, data_args, training_args, _ = get_args(args)
    DistributedInterface(training_args.dist_config)
-    data_engine = DataEngine(data_args)
+    train_dataset = DataEngine(data_args.train_dataset)
    model_engine = ModelEngine(model_args)
    trainer = SFTTrainer(
        args=training_args,
        model=model_engine.model,
        renderer=model_engine.renderer,
-        dataset=data_engine,
+        train_dataset=train_dataset,
    )
    trainer.fit()
    trainer.save_model()
    DistributedInterface().destroy()


if __name__ == "__main__":
    """
    python -m llamafactory.v1.trainers.sft_trainer --model Qwen/Qwen3-0.6B --train_dataset data/v1_sft_demo.yaml
    """
    run_sft()
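
compute_loss thus reduces the per-token log probabilities to a weighted mean: each shifted position is multiplied by its loss weight and the sum is normalized by the total weight (plus a small epsilon to avoid division by zero). For a batch whose loss_weights mark only response tokens, prompt positions contribute nothing and the result is the average negative log-likelihood over response tokens alone.
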
@@ -16,6 +16,7 @@
import torch
from transformers import PreTrainedTokenizer

+from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor

@@ -73,3 +74,20 @@ def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchInput]:
        padded_samples.append(padded_sample)

    return padded_samples


+def compute_valid_tokens(batches: list[BatchInput]) -> int:
+    """Compute valid tokens in batches.
+
+    Args:
+        batches: Batches.
+
+    Returns:
+        Number of valid tokens.
+    """
+    device = DistributedInterface().current_device
+    return sum(
+        (batch["labels"].to(device, non_blocking=True) != IGNORE_INDEX).sum().item()
+        for batch in batches
+        if "labels" in batch
+    )
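
Valid tokens are simply label positions not equal to IGNORE_INDEX, summed over all micro batches in a step. For example, two micro batches whose label tensors contain 7 and 5 non-ignored positions respectively yield a count of 12, which is the denominator used for the token-weighted loss scaling in BaseTrainer.fit.
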
@@ -13,7 +13,7 @@
# limitations under the License.

from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union


if TYPE_CHECKING:
@@ -146,7 +146,7 @@ class ModelInput(TypedDict, total=False):
    position_ids: NotRequired[list[int] | list[list[int]]]
    """Position ids for the model (optional)."""
    token_type_ids: NotRequired[list[int]]
-    """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
+    """Token type ids used in DPO, 1 represents the chosen messages, 2 represents the rejected messages."""


class BatchInput(TypedDict, total=False):
@@ -161,7 +161,7 @@ class BatchInput(TypedDict, total=False):
    position_ids: NotRequired[Tensor]
    """Position ids for the model (optional)."""
    token_type_ids: NotRequired[Tensor]
-    """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
+    """Token type ids used in DPO, 1 represents the chosen messages, 2 represents the rejected messages."""


class BatchInfo(TypedDict):
@@ -173,3 +173,8 @@ class BatchInfo(TypedDict):
    """Cutoff length."""
    data_iter: Iterator[list[ModelInput]]
    """Data iterator."""
+
+
+class ModelOutput(NamedTuple):
+    logits: Tensor
+    """Logits for the model."""
@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
-0.9.5.104
+0.9.5.105
@@ -34,7 +34,7 @@ def test_get_args_from_yaml(tmp_path: Path):
    quant_config: null

    ### data
-    dataset: llamafactory/v1-sft-demo
+    train_dataset: llamafactory/v1-sft-demo

    ### training
    output_dir: outputs/test_run
@@ -56,8 +56,8 @@ def test_get_args_from_yaml(tmp_path: Path):
    test_argv = ["test_args_parser.py", str(config_file)]

    with patch.object(sys, "argv", test_argv):
-        data_args, model_args, training_args, sample_args = get_args()
-    assert data_args.dataset == "llamafactory/v1-sft-demo"
+        model_args, data_args, training_args, sample_args = get_args()
+    assert data_args.train_dataset == "llamafactory/v1-sft-demo"
    assert model_args.model == "llamafactory/tiny-random-qwen3"
    assert model_args.kernel_config.name == "auto"
    assert model_args.kernel_config.get("include_kernels") == "auto"
@@ -23,8 +23,8 @@ from llamafactory.v1.core.data_engine import DataEngine


@pytest.mark.parametrize("num_samples", [16])
def test_map_dataset(num_samples: int):
-    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
+    data_engine = DataEngine(data_args.train_dataset)
    original_data = load_dataset("llamafactory/v1-sft-demo", split="train")
    indexes = random.choices(range(len(data_engine)), k=num_samples)
    for index in indexes:

@@ -19,8 +19,8 @@ from llamafactory.v1.core.utils.batching import BatchGenerator


def test_normal_batching():
-    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
-    data_engine = DataEngine(data_args=data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
+    data_engine = DataEngine(data_args.train_dataset)
    model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
    model_engine = ModelEngine(model_args=model_args)
    training_args = TrainingArguments(
@@ -111,8 +111,8 @@ def test_chatml_parse():
def test_chatml_rendering_remote(num_samples: int):
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
-    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
+    data_engine = DataEngine(data_args.train_dataset)
    for index in range(num_samples):
        v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
        prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
@@ -167,8 +167,8 @@ def test_qwen3_nothink_parse():
def test_qwen3_nothink_rendering_remote(num_samples: int):
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
-    data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/reason-tool-use-demo-1500")
+    data_engine = DataEngine(data_args.train_dataset)
    for index in range(num_samples):
        v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
        prefix_text = (
@@ -213,7 +213,7 @@ def test_process_dpo_samples():
    model_inputs = renderer.process_samples(samples)
    assert len(model_inputs) == 1
    assert model_inputs[0]["input_ids"] == hf_inputs * 2
-    assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs)
+    assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
    assert model_inputs[0]["extra_info"] == "test"
    assert model_inputs[0]["_dataset_name"] == "default"
@@ -24,8 +24,8 @@ from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin


@pytest.mark.parametrize("num_samples", [16])
def test_alpaca_converter(num_samples: int):
-    data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
+    data_engine = DataEngine(data_args.train_dataset)
    original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
    indexes = random.choices(range(len(data_engine)), k=num_samples)
    for index in indexes:
@@ -73,8 +73,8 @@ def test_sharegpt_converter():


@pytest.mark.parametrize("num_samples", [16])
def test_pair_converter(num_samples: int):
-    data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
+    data_engine = DataEngine(data_args.train_dataset)
    original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
    indexes = random.choices(range(len(data_engine)), k=num_samples)
    for index in indexes:
@@ -19,7 +19,7 @@ from llamafactory.v1.core.model_engine import ModelEngine


def test_init_on_meta():
-    _, model_args, *_ = get_args(
+    model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen3",
            init_config={"name": "init_on_meta"},
@@ -30,7 +30,7 @@ def test_init_on_meta():


def test_init_on_rank0():
-    _, model_args, *_ = get_args(
+    model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen3",
            init_config={"name": "init_on_rank0"},
@@ -44,7 +44,7 @@ def test_init_on_rank0():


def test_init_on_default():
-    _, model_args, *_ = get_args(
+    model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen3",
            init_config={"name": "init_on_default"},
@@ -43,7 +43,8 @@ def test_apply_kernel(mock_get_accelerator: MagicMock):
    reload_kernels()
    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
+    # NOTE: use a special model to avoid contamination by other tests
+    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
    original_swiglu_forward = model.model.layers[0].mlp.forward
    model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
@@ -62,7 +63,8 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
    reload_kernels()
    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
+    # NOTE: use a special model to avoid contamination by other tests
+    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
    original_swiglu_forward = model.model.layers[0].mlp.forward