Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-01-12 17:10:36 +08:00)

Commit: [v1] add sft (#9752)

.gitignore (vendored, 1 addition)
@@ -176,6 +176,7 @@ llamaboard_cache/
 llamaboard_config/
 saves/
 output/
+outputs/
 wandb/
 swanlog/
 generated_predictions.jsonl
@@ -174,7 +174,7 @@ class DistributedInterface:
         """Get device mesh for specified dimension."""
         if dim is None:
             raise ValueError("dim must be specified.")
-        elif self.model_device_mesh is None:
+        elif not self._is_distributed:
             return None
         elif dim in self.strategy.data_mesh_dim_names:
             return self.data_device_mesh[dim.value]
@@ -183,14 +183,14 @@ class DistributedInterface:

     def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
         """Get process group for specified dimension."""
-        if self.model_device_mesh is None or dim is None:
+        if not self._is_distributed or dim is None:
             return None
         else:
             return self.get_device_mesh(dim).get_group()

     def get_rank(self, dim: Dim | None = None) -> int:
         """Get parallel rank for specified dimension."""
-        if self.model_device_mesh is None:
+        if not self._is_distributed:
             return 0
         elif dim is None:
             return self._rank
@@ -199,7 +199,7 @@ class DistributedInterface:

     def get_world_size(self, dim: Dim | None = None) -> int:
         """Get parallel size for specified dimension."""
-        if self.model_device_mesh is None:
+        if not self._is_distributed:
             return 1
         elif dim is None:
             return self._world_size
@@ -216,7 +216,7 @@ class DistributedInterface:

     def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
         """Gather tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
             return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
         else:
             return data
@@ -225,29 +225,32 @@ class DistributedInterface:
         self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
     ) -> TensorLike:
         """Reduce tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
             return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
         else:
             return data

     def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
         """Broadcast tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
             return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
         else:
             return data

     def sync(self) -> None:
         """Synchronize all processes."""
-        helper.synchronize()
+        if self._is_distributed:
+            helper.synchronize()

     def barrier(self) -> None:
         """Barrier all processes."""
-        barrier()
+        if self._is_distributed:
+            barrier()

     def destroy(self) -> None:
         """Destroy all processes."""
-        destroy_process_group()
+        if self._is_distributed:
+            destroy_process_group()


 if __name__ == "__main__":
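Note: the hunks above swap every check on `model_device_mesh` for a single `_is_distributed` flag, so the collectives and teardown calls become safe no-ops in single-process runs. A minimal usage sketch, assuming the import path implied elsewhere in this commit:

    # Hypothetical single-process usage; behavior inferred from the hunks above.
    from llamafactory.v1.accelerator.interface import Dim, DistributedInterface

    dist = DistributedInterface()        # not distributed: _is_distributed is False
    print(dist.get_world_size(Dim.DP))   # -> 1
    print(dist.all_reduce(0.73))         # -> 0.73, returned unchanged
    dist.barrier()                       # no-op instead of failing without a process group
    dist.destroy()                       # safe even though no process group was created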
@@ -30,9 +30,9 @@ from .training_args import TrainingArguments
 InputArgument = dict[str, Any] | list[str] | None


-def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
+def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments, TrainingArguments, SampleArguments]:
     """Parse arguments from command line or config file."""
-    parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
+    parser = HfArgumentParser([ModelArguments, DataArguments, TrainingArguments, SampleArguments])
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")

     if args is None:
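Call sites therefore unpack model arguments first. A small sketch; the dict form of InputArgument follows the tests later in this commit, and the exact keys accepted are an assumption:

    from llamafactory.v1.config.arg_parser import get_args

    model_args, data_args, training_args, sample_args = get_args(
        {"model": "llamafactory/tiny-random-qwen3", "train_dataset": "llamafactory/v1-sft-demo"}
    )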
@@ -18,7 +18,11 @@ from dataclasses import dataclass, field

 @dataclass
 class DataArguments:
-    dataset: str | None = field(
+    train_dataset: str | None = field(
         default=None,
-        metadata={"help": "Path to the dataset."},
+        metadata={"help": "Path to the training dataset."},
+    )
+    eval_dataset: str | None = field(
+        default=None,
+        metadata={"help": "Path to the evaluation dataset."},
     )
@@ -33,13 +33,21 @@ class TrainingArguments:
         default=None,
         metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
     )
+    cutoff_len: int = field(
+        default=2048,
+        metadata={"help": "Maximum sequence length for training."},
+    )
     learning_rate: float = field(
         default=1e-4,
         metadata={"help": "Learning rate for training."},
     )
-    cutoff_len: int = field(
-        default=2048,
-        metadata={"help": "Maximum sequence length for training."},
+    num_train_epochs: int = field(
+        default=3,
+        metadata={"help": "Number of training epochs."},
+    )
+    max_grad_norm: float = field(
+        default=1.0,
+        metadata={"help": "Maximum gradient norm for training."},
     )
     bf16: bool = field(
         default=False,
@@ -53,10 +61,24 @@ class TrainingArguments:
         default=16,
         metadata={"help": "Number of workers for batching."},
     )
+    enable_activation_checkpointing: bool = field(
+        default=True,
+        metadata={"help": "Enable activation checkpointing for training."},
+    )
     dist_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "Distribution configuration for training."},
     )
+    optim_config: PluginConfig | None = field(
+        default=None,
+        metadata={"help": "Optimizer configuration for training."},
+    )
+    lr_scheduler_config: PluginConfig | None = field(
+        default=None,
+        metadata={"help": "Learning rate scheduler configuration for training."},
+    )

     def __post_init__(self) -> None:
         self.dist_config = get_plugin_config(self.dist_config)
+        self.optim_config = get_plugin_config(self.optim_config)
+        self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
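Putting the two argument changes together, a hedged construction sketch; field names follow the dataclasses above, while defaults and any required fields beyond those shown are assumptions:

    from llamafactory.v1.config.data_args import DataArguments
    from llamafactory.v1.config.training_args import TrainingArguments

    data_args = DataArguments(train_dataset="data/v1_sft_demo.yaml", eval_dataset=None)
    training_args = TrainingArguments(
        learning_rate=1e-4,
        cutoff_len=2048,                       # maximum sequence length for training
        num_train_epochs=3,                    # new: epoch-based training loop
        max_grad_norm=1.0,                     # new: clip threshold used in BaseTrainer.fit()
        enable_activation_checkpointing=True,  # new: toggles gradient checkpointing
        optim_config=None,                     # None -> plain AdamW in BaseTrainer
        lr_scheduler_config=None,              # None -> constant LambdaLR schedule
    )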
@@ -12,115 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
-import os
-from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator
-from threading import Thread
-
-import torch
-from transformers import AsyncTextIteratorStreamer
-
-from ..accelerator.interface import DistributedInterface
 from ..config import ModelArguments, SampleArguments, SampleBackend
-from ..utils.helper import get_tokenizer
 from ..utils.types import HFModel, Message, Sample, TorchDataset
+from .utils.inference_engine import HuggingFaceEngine
 from .utils.rendering import Renderer

[The remainder of this hunk removes the BaseEngine and HuggingFaceEngine class definitions from the sampler module; that code moves verbatim into the new file src/llamafactory/v1/core/utils/inference_engine.py, reproduced in full below.]

 class BaseSampler:
     """Base sampler.
@@ -16,42 +16,166 @@

 Init Phase:

-1. Init dataloader.
+1. Init batch generator.
 2. Init optimizer (deepspeed).
 3. Shard model.
 4. Init optimizer (fsdp).
-5. Init scheduler.
+5. Init lr scheduler.

 Train Phase:
 1. Train Loop

 """

-from ..config.training_args import TrainingArguments
-from ..utils.types import HFModel, TorchDataset
+from abc import abstractmethod
+
+import torch
+import torch.nn.functional as F
+
+from ..accelerator.helper import ReduceOp
+from ..accelerator.interface import Dim, DistributedInterface
+from ..config import TrainingArguments
+from ..utils import logging
+from ..utils.helper import compute_valid_tokens
+from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
+from .utils.batching import BatchGenerator
 from .utils.rendering import Renderer


+logger = logging.get_logger(__name__)
+
+
 class BaseTrainer:
     def __init__(
         self,
         args: TrainingArguments,
         model: HFModel,
         renderer: Renderer,
-        dataset: TorchDataset,
+        train_dataset: TorchDataset,
     ) -> None:
         self.args = args
         self.model = model
         self.renderer = renderer
-        self.dataset = dataset
-        self.optimizer = None
-        self.lr_scheduler = None
+        self.train_dataset = train_dataset
+
+        # info
+        self.global_step = 0
+
+        # cached variables
+        self.device = DistributedInterface().current_device
+        self.dp_size = DistributedInterface().get_world_size(Dim.DP)
+        self.model_input_names = self.renderer.processor.model_input_names
+
+        self._create_batch_generator()
+        self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
+
+        if self.args.enable_activation_checkpointing:
+            self.model.gradient_checkpointing_enable({"use_reentrant": False})
+
+        if self.args.dist_config is not None:
+            shard_need_optimizer = self.args.dist_config.name == "deepspeed"
+        else:
+            shard_need_optimizer = False
+
+        if shard_need_optimizer:
+            self._init_optimizer()
+            self._shard_model()
+        else:
+            self._shard_model()
+            self._init_optimizer()
+
+        self._init_lr_scheduler()

-    def _create_dataloader(self) -> None:
+    def _create_batch_generator(self) -> None:
+        self.train_batch_generator = BatchGenerator(
+            dataset=self.train_dataset,
+            renderer=self.renderer,
+            micro_batch_size=self.args.micro_batch_size,
+            global_batch_size=self.args.global_batch_size,
+            cutoff_len=self.args.cutoff_len,
+            batching_workers=self.args.batching_workers,
+            batching_strategy=self.args.batching_strategy,
+        )
+
+    def _shard_model(self) -> None:
         pass

-    def _init_model_and_optimizer(self) -> None:
-        pass
+    def _init_optimizer(self) -> None:
+        """Init optimizer."""
+        if self.args.optim_config is None:
+            _trainable_params = [p for p in self.model.parameters() if p.requires_grad]
+            self.optimizer = torch.optim.AdamW(_trainable_params, lr=self.args.learning_rate)
+        else:
+            from ..plugins.trainer_plugins.optimizer import OptimizerPlugin
+
+            self.optimizer = OptimizerPlugin(self.args.optim_config.name)(self.model, self.args.optim_config)
+
+    def _init_lr_scheduler(self) -> None:
+        """Init lr scheduler."""
+        if self.args.lr_scheduler_config is None:
+            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1.0)
+        else:
+            from ..plugins.trainer_plugins.lr_scheduler import LRSchedulerPlugin
+
+            self.lr_scheduler = LRSchedulerPlugin(self.args.lr_scheduler_config.name)(
+                self.optimizer, self.num_training_steps, self.args.lr_scheduler_config
+            )
+
+    def compute_log_probs(self, model: HFModel, batch: BatchInput) -> Tensor:
+        """Compute log probs.
+
+        log_probs: Tensor of shape (batch_size, seq_len - 1)
+        """
+        batch_size, _ = batch["labels"].shape
+        model_inputs = {
+            k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
+        }
+        labels = batch["labels"].to(self.device, non_blocking=True)
+        outputs: ModelOutput = model(**model_inputs)
+        logits = outputs.logits.float()
+        shift_labels = labels[..., 1:].contiguous().view(-1)
+        shift_logits = logits[..., :-1, :].contiguous().view(shift_labels.size(0), -1)
+        return -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
+
+    @abstractmethod
+    def compute_loss(self, batch: BatchInput) -> Tensor:
+        """Compute the scalar loss."""
+        ...

     def fit(self) -> None:
-        pass
+        """Train the model."""
+        self.model.train()
+        for epoch in range(self.args.num_train_epochs):
+            self.train_batch_generator.set_epoch(epoch)
+            for micro_batches in self.train_batch_generator:
+                self.global_step += 1
+                step_loss = 0
+                step_valid_tokens = compute_valid_tokens(micro_batches)
+                step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
+                for micro_batch in micro_batches:
+                    loss = self.compute_loss(micro_batch)
+                    mini_step_valid_tokens = compute_valid_tokens([micro_batch])
+                    # fsdp uses mean reduction so we need to scale the loss by dp_size
+                    loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
+
+                    loss.backward()
+                    step_loss += loss.item()
+
+                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
+                if not torch.isfinite(grad_norm):
+                    logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
+                else:
+                    self.optimizer.step()
+
+                self.lr_scheduler.step()
+                self.optimizer.zero_grad()
+
+                step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
+                DistributedInterface().sync()
+                print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
+
+    def save_model(self) -> None:
+        """Save the model."""
+        self.model.save_pretrained(self.args.output_dir)
+        self.renderer.processor.save_pretrained(self.args.output_dir)
+        logger.info_rank0(f"Model saved to {self.args.output_dir}")
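The scaling line in fit() is worth unpacking: each micro-batch loss is multiplied by its own valid-token count and by dp_size, then divided by the step-wide token total, so that gradient accumulation over variable-length micro-batches plus FSDP's mean reduction across data-parallel ranks yields a per-token average over the whole global batch. A toy check with made-up numbers:

    # Illustrative only; token counts and per-token losses are invented.
    dp_size = 2
    step_valid_tokens = 80                  # SUM-reduced over all ranks
    local_batches = [(1.2, 30), (0.8, 10)]  # (per-token loss, valid tokens) on this rank

    rank_contrib = sum(loss * tokens * dp_size / step_valid_tokens for loss, tokens in local_batches)
    # After the mean (divide-by-dp_size) reduction, this rank contributes
    # sum(loss_i * tokens_i) / step_valid_tokens; summed over ranks, that is the
    # global token-weighted average loss.
    print(rank_contrib / dp_size)           # 0.55 = (1.2 * 30 + 0.8 * 10) / 80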
@@ -15,7 +15,7 @@
 """The definition of data engine.

 How to use:
-    data_engine = DataEngine(data_args)
+    data_engine = DataEngine(data_args.train_dataset)
     data_engine[i]: Get the sample via index.

 Init workflow:
@@ -41,7 +41,6 @@ from huggingface_hub import hf_hub_download
 from omegaconf import OmegaConf
 from torch.utils.data import Dataset

-from ..config.data_args import DataArguments
 from ..utils.types import DatasetInfo, HFDataset, Sample


@@ -52,9 +51,9 @@ class DataEngine(Dataset):
         data_args: Data arguments.
     """

-    def __init__(self, data_args: DataArguments) -> None:
-        self.args = data_args
-        """Data arguments."""
+    def __init__(self, dataset_path: str) -> None:
+        self.path = dataset_path
+        """Dataset path."""
         self.datasets: dict[str, HFDataset] = {}
         """Dict of (dataset_name, dataset)"""
         self.dataset_infos: dict[str, DatasetInfo] = {}
@@ -69,16 +68,16 @@ class DataEngine(Dataset):

     def _get_dataset_info(self) -> None:
         """Get dataset info from data arguments."""
-        if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset):  # local file
-            self.dataset_infos = OmegaConf.load(self.args.dataset)
-        elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
-            repo_id, filename = os.path.split(self.args.dataset)
+        if self.path.endswith(".yaml") and os.path.isfile(self.path):  # local file
+            self.dataset_infos = OmegaConf.load(self.path)
+        elif self.path.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
+            repo_id, filename = os.path.split(self.path)
             filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
             self.dataset_infos = OmegaConf.load(filepath)
-        elif os.path.exists(self.args.dataset):  # local file(s)
-            self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
+        elif os.path.exists(self.path):  # local file(s)
+            self.dataset_infos = {"default": {"path": self.path, "source": "local"}}
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
-            self.dataset_infos = {"default": {"path": self.args.dataset}}
+            self.dataset_infos = {"default": {"path": self.path}}

     def _load_dataset(self) -> None:
         """Load datasets according to dataset info."""
@@ -187,11 +186,11 @@ class DataEngine(Dataset):

 if __name__ == "__main__":
     """
-    python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
-    python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
+    python -m llamafactory.v1.core.data_engine --train_dataset data/v1_sft_demo.yaml
+    python -m llamafactory.v1.core.data_engine --train_dataset data/v1_dpo_demo.yaml
     """
     from ..config.arg_parser import get_args

-    data_args, *_ = get_args()
-    data_engine = DataEngine(data_args=data_args)
+    _, data_args, *_ = get_args()
+    data_engine = DataEngine(data_args.train_dataset)
     print(data_engine[0])
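DataEngine now receives a plain path or hub id instead of the whole DataArguments object; the branches of _get_dataset_info resolve it roughly as sketched below (the .jsonl path is a made-up example, the other identifiers appear elsewhere in this commit):

    from llamafactory.v1.core.data_engine import DataEngine

    DataEngine("data/v1_sft_demo.yaml")                       # local dataset_info YAML
    DataEngine("llamafactory/v1-sft-demo/dataset_info.yaml")  # YAML resolved via the HF hub
    DataEngine("data/my_local_file.jsonl")                    # existing local path -> source "local"
    DataEngine("llamafactory/v1-sft-demo")                    # anything else -> HF hub dataset id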
@@ -153,7 +153,7 @@ if __name__ == "__main__":
     """
     from ..config.arg_parser import get_args

-    _, model_args, *_ = get_args()
+    model_args, *_ = get_args()
     model_engine = ModelEngine(model_args=model_args)
     print(model_engine.processor)
     print(model_engine.model_config)
@@ -216,7 +216,7 @@ if __name__ == "__main__":
     """
     python -m llamafactory.v1.core.utils.batching \
         --model llamafactory/tiny-random-qwen2.5 \
-        --dataset data/v1_sft_demo.yaml \
+        --train_dataset data/v1_sft_demo.yaml \
         --micro_batch_size 2 \
         --global_batch_size 4 \
         --batching_workers 0
@@ -225,8 +225,8 @@ if __name__ == "__main__":
     from ..data_engine import DataEngine
     from ..model_engine import ModelEngine

-    data_args, model_args, training_args, _ = get_args()
-    data_engine = DataEngine(data_args=data_args)
+    model_args, data_args, training_args, _ = get_args()
+    data_engine = DataEngine(data_args.train_dataset)
     model_engine = ModelEngine(model_args=model_args)
     batch_generator = BatchGenerator(
         data_engine,
Deleted file (119 lines; the file path is not shown in this view):
@@ -1,119 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import defaultdict
-from collections.abc import Sequence
-from dataclasses import dataclass
-from typing import Any
-
-import torch
-import torch.nn.functional as F
-from torch.nn.utils.rnn import pad_sequence
-from torch.utils.data._utils.collate import default_collate
-
-from ....extras.constants import IGNORE_INDEX
-from ...plugins.data_plugins.template import Template
-from ...utils.types import Processor, Tensor
-
-
-def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor":  # FIXME move to utils
-    """Convert sequence lengths to cumulative sequence lengths."""
-    return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
-
-
-class DataCollator:
-    """Default Data collator."""
-
-    processor: "Processor"  # processor name -> map to encode_messages function
-
-    def __post_init__(self):
-        # callback for text tokenizer
-        self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
-
-    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
-        """Collate features into a batch."""
-        batch = defaultdict(list)
-
-        # batching features
-        for feature in features:
-            for key in feature.keys():
-                batch[key].append(feature[key])
-
-        for key in batch.keys():
-            # process padding features
-            if key in ["input_ids", "attention_mask", "position_ids"]:
-                padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
-                batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
-            elif key in ["labels"]:
-                batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
-            else:
-                batch[key] = default_collate(batch[key])
-
-        return batch
-
-
-# sft: messages
-# dpo: chosen_messages, rejected_messages
-@dataclass
-class DefaultCollator(DataCollator):
-    """Example for now."""
-
-    processor: "Processor"  # processor name -> map to encode_messages function
-    template: "Template"
-
-    def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
-        features = []
-
-        # Check if data is already tokenized (contains input_ids)
-        if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
-            for feature in messages:
-                if not isinstance(feature, dict):
-                    raise ValueError(f"Expected dict but got {type(feature)}")
-                tensor_feature = {
-                    k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
-                    for k, v in feature.items()
-                }
-                features.append(tensor_feature)
-        else:
-            # raw messages need to be encoded
-            for message in messages:
-                encoded_message = self.template.encode_messages(self.tokenizer, message)
-                encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
-                features.append(encoded_message)
-
-        return super().__call__(features)
-
-
-@dataclass
-class PairwiseCollator(DataCollator):
-    pass
-
-
-@dataclass
-class DataCollatorWithPacking(DefaultCollator):
-    """Data collator with packing."""
-
-    processor: "Processor"
-    template: "Template"
-
-    def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
-        seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
-        batch = {"cu_seqlens": len2culen(seqlens)}
-        for input_name in features[0].keys():
-            if input_name in ("input_ids", "attention_mask", "labels"):
-                batch[input_name] = torch.cat([feature[input_name] for feature in features])
-            else:
-                batch[input_name] = default_collate([feature[input_name] for feature in features])
-
-        return batch
src/llamafactory/v1/core/utils/inference_engine.py (new file, 121 lines)
@@ -0,0 +1,121 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
+from threading import Thread
+
+import torch
+from transformers import AsyncTextIteratorStreamer
+
+from ...accelerator.interface import DistributedInterface
+from ...config import ModelArguments, SampleArguments
+from ...utils.helper import get_tokenizer
+from ...utils.types import HFModel, Message, Sample, TorchDataset
+from .rendering import Renderer
+
+
+class BaseEngine(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        args: SampleArguments,
+        model_args: ModelArguments,
+        model: HFModel,
+        renderer: Renderer,
+    ) -> None:
+        """Initialize the engine.
+
+        Args:
+            args: Sample arguments.
+            model_args: Model arguments.
+            model: Model.
+            renderer: Renderer.
+        """
+        ...
+
+    @abstractmethod
+    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
+        """Generate tokens asynchronously.
+
+        Args:
+            messages: List of messages.
+            tools: Tools string.
+
+        Yields:
+            Generated tokens.
+        """
+        ...
+
+    @abstractmethod
+    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
+        """Batch infer samples.
+
+        Args:
+            dataset: Torch dataset.
+
+        Returns:
+            List of samples.
+        """
+        ...
+
+
+class HuggingFaceEngine(BaseEngine):
+    def __init__(
+        self,
+        args: SampleArguments,
+        model_args: ModelArguments,
+        model: HFModel,
+        renderer: Renderer,
+    ) -> None:
+        self.args = args
+        self.model_args = model_args
+        self.model = model
+        self.renderer = renderer
+        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
+
+    @torch.inference_mode()
+    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
+        async with self.semaphore:
+            model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
+            streamer = AsyncTextIteratorStreamer(
+                tokenizer=get_tokenizer(self.renderer.processor),
+                skip_prompt=True,
+                skip_special_tokens=True,  # TODO: configurable
+            )
+            device = DistributedInterface().current_device
+            kwargs = {
+                "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
+                "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
+                "max_new_tokens": self.args.max_new_tokens,
+                "streamer": streamer,
+            }
+            thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
+            thread.start()
+
+            async for token in streamer:
+                yield token
+
+    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
+        """Batch infer samples.
+
+        Args:
+            dataset: Torch dataset.
+
+        Returns:
+            List of samples.
+        """
+        raise NotImplementedError("Batch infer is not implemented.")
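A hedged driver for the streaming path of HuggingFaceEngine.generate; how the engine instance is built (via ModelEngine and the samplers) sits outside this file, so the setup is only sketched:

    import asyncio

    async def stream_reply(engine, messages):
        # engine is assumed to be a HuggingFaceEngine constructed elsewhere in the stack
        async for token in engine.generate(messages):
            print(token, end="", flush=True)

    # asyncio.run(stream_reply(engine, [{"role": "user", "content": "Hello"}]))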
@@ -142,8 +142,8 @@ class Renderer:
         elif "chosen_messages" in sample and "rejected_messages" in sample:
             chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
             rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
-            chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"])
-            rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"])
+            chosen_input["token_type_ids"] = [1] * len(chosen_input["input_ids"])
+            rejected_input["token_type_ids"] = [2] * len(rejected_input["input_ids"])
             model_input = ModelInput(
                 input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
                 attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
@@ -18,8 +18,11 @@ from ...utils.types import BatchInfo, BatchInput, DataLoader


 class BatchingPlugin(BasePlugin):
-    def compute_length(self, dataloader: DataLoader) -> int:
-        """Compute the length of the batch generator."""
+    def compute_length(self, data_provider: DataLoader) -> int:
+        """Compute the length of the batch generator.
+
+        The approximate length is used to calculate the lr schedule.
+        """
         raise NotImplementedError()

     def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
src/llamafactory/v1/plugins/trainer_plugins/lr_scheduler.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils.plugin import BasePlugin
+
+
+class LRSchedulerPlugin(BasePlugin):
+    pass
src/llamafactory/v1/plugins/trainer_plugins/optimizer.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils.plugin import BasePlugin
+
+
+class OptimizerPlugin(BasePlugin):
+    pass
@@ -73,14 +73,14 @@ class SyncSampler(BaseSampler):


 def run_chat(args: InputArgument = None):
-    data_args, model_args, _, sample_args = get_args(args)
+    model_args, data_args, _, sample_args = get_args(args)
     if sample_args.sample_backend != SampleBackend.HF:
         model_args.init_plugin = {"name": "init_on_meta"}

     model_engine = ModelEngine(model_args)
     sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
-    if data_args.dataset is not None:
-        dataset = DataEngine(data_args)
+    if data_args.train_dataset is not None:
+        dataset = DataEngine(data_args.train_dataset)
         sampler.batch_infer(dataset)
     else:
         if os.name != "nt":
@@ -18,21 +18,35 @@ from ..config import InputArgument, get_args
 from ..core.base_trainer import BaseTrainer
 from ..core.data_engine import DataEngine
 from ..core.model_engine import ModelEngine
+from ..utils.types import BatchInput, Tensor


 class SFTTrainer(BaseTrainer):
-    pass
+    def compute_loss(self, batch: BatchInput) -> Tensor:
+        shift_loss_weights = batch["loss_weights"].to(self.device, non_blocking=True)[..., 1:]
+        log_probs = self.compute_log_probs(self.model, batch)
+        loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)
+        return loss


 def run_sft(args: InputArgument = None):
     model_args, data_args, training_args, _ = get_args(args)
     DistributedInterface(training_args.dist_config)
-    data_engine = DataEngine(data_args)
+    train_dataset = DataEngine(data_args.train_dataset)
     model_engine = ModelEngine(model_args)
     trainer = SFTTrainer(
         args=training_args,
         model=model_engine.model,
         renderer=model_engine.renderer,
-        dataset=data_engine,
+        train_dataset=train_dataset,
     )
     trainer.fit()
+    trainer.save_model()
+    DistributedInterface().destroy()
+
+
+if __name__ == "__main__":
+    """
+    python -m llamafactory.v1.trainers.sft_trainer --model Qwen/Qwen3-0.6B --train_dataset data/v1_sft_demo.yaml
+    """
+    run_sft()
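SFTTrainer.compute_loss is a weighted negative log-likelihood: compute_log_probs returns per-token log-probs of shape (batch, seq_len - 1), and loss_weights (assumed to be nonzero only on supervised response tokens) both masks and weights them. A toy check:

    import torch

    log_probs = torch.tensor([[-0.5, -1.0, -2.0]])          # shape (batch, seq_len - 1)
    shift_loss_weights = torch.tensor([[0.0, 1.0, 1.0]])    # prompt position masked out
    loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)
    print(loss)  # ~1.5, the mean NLL over the two supervised tokens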
@@ -16,6 +16,7 @@
 import torch
 from transformers import PreTrainedTokenizer

+from ..accelerator.interface import DistributedInterface
 from .constants import IGNORE_INDEX
 from .types import BatchInput, ModelInput, Processor, Tensor

@@ -73,3 +74,20 @@ def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchInput]:
         padded_samples.append(padded_sample)

     return padded_samples
+
+
+def compute_valid_tokens(batches: list[BatchInput]) -> int:
+    """Compute valid tokens in batches.
+
+    Args:
+        batches: Batches.
+
+    Returns:
+        Number of valid tokens.
+    """
+    device = DistributedInterface().current_device
+    return sum(
+        (batch["labels"].to(device, non_blocking=True) != IGNORE_INDEX).sum().item()
+        for batch in batches
+        if "labels" in batch
+    )
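compute_valid_tokens counts label positions that are not IGNORE_INDEX (conventionally -100), which is what the trainer uses to weight micro-batch losses. A toy check:

    import torch

    IGNORE_INDEX = -100  # assumed value of the project constant
    labels = torch.tensor([[-100, -100, 42, 43], [-100, 7, 8, -100]])
    print((labels != IGNORE_INDEX).sum().item())  # 5 valid (supervised) tokens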
@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union


 if TYPE_CHECKING:
@@ -146,7 +146,7 @@ class ModelInput(TypedDict, total=False):
     position_ids: NotRequired[list[int] | list[list[int]]]
     """Position ids for the model (optional)."""
     token_type_ids: NotRequired[list[int]]
-    """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
+    """Token type ids used in DPO, 1 represents the chosen messages, 2 represents the rejected messages."""


 class BatchInput(TypedDict, total=False):
@@ -161,7 +161,7 @@ class BatchInput(TypedDict, total=False):
     position_ids: NotRequired[Tensor]
     """Position ids for the model (optional)."""
     token_type_ids: NotRequired[Tensor]
-    """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages."""
+    """Token type ids used in DPO, 1 represents the chosen messages, 2 represents the rejected messages."""


 class BatchInfo(TypedDict):
@@ -173,3 +173,8 @@ class BatchInfo(TypedDict):
     """Cutoff length."""
     data_iter: Iterator[list[ModelInput]]
     """Data iterator."""
+
+
+class ModelOutput(NamedTuple):
+    logits: Tensor
+    """Logits for the model."""
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.5.104
+0.9.5.105
@@ -34,7 +34,7 @@ def test_get_args_from_yaml(tmp_path: Path):
    quant_config: null

    ### data
-    dataset: llamafactory/v1-sft-demo
+    train_dataset: llamafactory/v1-sft-demo

    ### training
    output_dir: outputs/test_run
@@ -56,8 +56,8 @@ def test_get_args_from_yaml(tmp_path: Path):
     test_argv = ["test_args_parser.py", str(config_file)]

     with patch.object(sys, "argv", test_argv):
-        data_args, model_args, training_args, sample_args = get_args()
-    assert data_args.dataset == "llamafactory/v1-sft-demo"
+        model_args, data_args, training_args, sample_args = get_args()
+    assert data_args.train_dataset == "llamafactory/v1-sft-demo"
     assert model_args.model == "llamafactory/tiny-random-qwen3"
     assert model_args.kernel_config.name == "auto"
     assert model_args.kernel_config.get("include_kernels") == "auto"
@@ -23,8 +23,8 @@ from llamafactory.v1.core.data_engine import DataEngine

 @pytest.mark.parametrize("num_samples", [16])
 def test_map_dataset(num_samples: int):
-    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
+    data_engine = DataEngine(data_args.train_dataset)
     original_data = load_dataset("llamafactory/v1-sft-demo", split="train")
     indexes = random.choices(range(len(data_engine)), k=num_samples)
     for index in indexes:
@@ -19,8 +19,8 @@ from llamafactory.v1.core.utils.batching import BatchGenerator


 def test_normal_batching():
-    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
-    data_engine = DataEngine(data_args=data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
+    data_engine = DataEngine(data_args.train_dataset)
     model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
     model_engine = ModelEngine(model_args=model_args)
     training_args = TrainingArguments(
@@ -111,8 +111,8 @@ def test_chatml_parse():
 def test_chatml_rendering_remote(num_samples: int):
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
+    data_engine = DataEngine(data_args.train_dataset)
     for index in range(num_samples):
         v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
         prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
@@ -167,8 +167,8 @@ def test_qwen3_nothink_parse():
 def test_qwen3_nothink_rendering_remote(num_samples: int):
     tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
     renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
-    data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/reason-tool-use-demo-1500")
+    data_engine = DataEngine(data_args.train_dataset)
     for index in range(num_samples):
         v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
         prefix_text = (
@@ -213,7 +213,7 @@ def test_process_dpo_samples():
     model_inputs = renderer.process_samples(samples)
     assert len(model_inputs) == 1
     assert model_inputs[0]["input_ids"] == hf_inputs * 2
-    assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs)
+    assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
     assert model_inputs[0]["extra_info"] == "test"
     assert model_inputs[0]["_dataset_name"] == "default"
@@ -24,8 +24,8 @@ from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin

 @pytest.mark.parametrize("num_samples", [16])
 def test_alpaca_converter(num_samples: int):
-    data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
+    data_engine = DataEngine(data_args.train_dataset)
     original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
     indexes = random.choices(range(len(data_engine)), k=num_samples)
     for index in indexes:
@@ -73,8 +73,8 @@ def test_sharegpt_converter():

 @pytest.mark.parametrize("num_samples", [16])
 def test_pair_converter(num_samples: int):
-    data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
-    data_engine = DataEngine(data_args)
+    data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
+    data_engine = DataEngine(data_args.train_dataset)
     original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
     indexes = random.choices(range(len(data_engine)), k=num_samples)
     for index in indexes:
@@ -19,7 +19,7 @@ from llamafactory.v1.core.model_engine import ModelEngine


 def test_init_on_meta():
-    _, model_args, *_ = get_args(
+    model_args, *_ = get_args(
         dict(
             model="llamafactory/tiny-random-qwen3",
             init_config={"name": "init_on_meta"},
@@ -30,7 +30,7 @@ def test_init_on_meta():


 def test_init_on_rank0():
-    _, model_args, *_ = get_args(
+    model_args, *_ = get_args(
         dict(
             model="llamafactory/tiny-random-qwen3",
             init_config={"name": "init_on_rank0"},
@@ -44,7 +44,7 @@ def test_init_on_rank0():


 def test_init_on_default():
-    _, model_args, *_ = get_args(
+    model_args, *_ = get_args(
         dict(
             model="llamafactory/tiny-random-qwen3",
             init_config={"name": "init_on_default"},
@@ -43,7 +43,8 @@ def test_apply_kernel(mock_get_accelerator: MagicMock):
     reload_kernels()
     from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
+    # NOTE: use a special model to avoid contamination by other tests
+    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
     original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
     original_swiglu_forward = model.model.layers[0].mlp.forward
     model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
@@ -62,7 +63,8 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
     reload_kernels()
     from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
+    # NOTE: use a special model to avoid contamination by other tests
+    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

     original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
     original_swiglu_forward = model.model.layers[0].mlp.forward