mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-24 06:39:08 +08:00
[v1] add batch generator (#9744)
This commit is contained in:
@@ -29,7 +29,6 @@ Train Phase:
|
||||
|
||||
from ..config.training_args import TrainingArguments
|
||||
from ..utils.types import HFModel, TorchDataset
|
||||
from .utils.data_collator import DataCollator
|
||||
from .utils.rendering import Renderer
|
||||
|
||||
|
||||
@@ -45,7 +44,6 @@ class BaseTrainer:
|
||||
self.model = model
|
||||
self.renderer = renderer
|
||||
self.dataset = dataset
|
||||
self.data_collator = DataCollator()
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
|
||||
|
||||
@@ -82,14 +82,17 @@ class DataEngine(Dataset):
|
||||
|
||||
def _load_dataset(self) -> None:
|
||||
"""Load datasets according to dataset info."""
|
||||
is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
|
||||
self.streaming = any(is_streaming)
|
||||
if all(is_streaming) != any(is_streaming):
|
||||
raise ValueError("All datasets must be streaming or non-streaming.")
|
||||
|
||||
for dataset_name, dataset_info in self.dataset_infos.items():
|
||||
split = dataset_info.get("split", "train")
|
||||
streaming = dataset_info.get("streaming", False)
|
||||
self.streaming |= streaming
|
||||
if dataset_info.get("source", "hf_hub") == "hf_hub":
|
||||
from datasets import load_dataset
|
||||
|
||||
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
|
||||
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
|
||||
else: # data loader plugin
|
||||
from ..plugins.data_plugins.loader import DataLoaderPlugin
|
||||
|
||||
@@ -98,8 +101,7 @@ class DataEngine(Dataset):
|
||||
def _build_data_index(self) -> None:
|
||||
"""Build dataset index."""
|
||||
for dataset_name, dataset in self.datasets.items():
|
||||
streaming = self.dataset_infos[dataset_name].get("streaming", False)
|
||||
if streaming:
|
||||
if self.streaming:
|
||||
data_index = [(dataset_name, -1) for _ in range(1000)]
|
||||
else:
|
||||
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
|
||||
@@ -185,8 +187,8 @@ class DataEngine(Dataset):
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
|
||||
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
|
||||
python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
|
||||
python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
|
||||
"""
|
||||
from ..config.arg_parser import get_args
|
||||
|
||||
|
||||
244
src/llamafactory/v1/core/utils/batching.py
Normal file
244
src/llamafactory/v1/core/utils/batching.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Batching utils supports stateful dataloader.
|
||||
|
||||
1. Init stateful dataloader (tokenize)
|
||||
2. Add to buffer
|
||||
3. Yield batch indexes (micro batch * grad acc)
|
||||
a) non pack + non dynamic
|
||||
b) non pack + dynamic
|
||||
c) pack + non dynamic
|
||||
d) pack + dynamic
|
||||
"""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from torch.utils.data import default_collate
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ...accelerator.interface import DistributedInterface
|
||||
from ...config import BatchingStrategy
|
||||
from ...utils import logging
|
||||
from ...utils.helper import pad_and_truncate
|
||||
from ...utils.types import BatchInput, ModelInput, TorchDataset
|
||||
from .rendering import Renderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def default_collate_fn(
    buffer: list[ModelInput], buffer_tokens: int, micro_batch_size: int, num_micro_batch: int, cutoff_len: int
) -> tuple[list[ModelInput], int, list[BatchInput] | None]:
    """Take one global batch from the head of ``buffer`` and collate it into micro batches.

    Args:
        buffer (list[ModelInput]): Tokenized samples waiting to be batched.
        buffer_tokens (int): Total number of ``input_ids`` tokens currently held in ``buffer``.
        micro_batch_size (int): Number of samples per micro batch.
        num_micro_batch (int): Number of micro batches per global batch (gradient accumulation).
        cutoff_len (int): Maximum sequence length used by ``pad_and_truncate``.

    Returns:
        tuple[list[ModelInput], int, list[BatchInput] | None]: The remaining buffer, the updated
            token count, and the collated micro batches — or ``None`` in place of the batch when
            the buffer does not yet hold a full global batch.
    """
    batch_size = micro_batch_size * num_micro_batch
    if len(buffer) < batch_size:
        # Not enough samples yet: hand the buffer back unchanged so the caller keeps filling it.
        # (Fixed return annotation: this path yields None for the batch.)
        return buffer, buffer_tokens, None

    samples = buffer[:batch_size]
    buffer = buffer[batch_size:]
    buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples)

    batch = []
    for i in range(num_micro_batch):
        micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
        batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))

    return buffer, buffer_tokens, batch
|
||||
|
||||
|
||||
class BatchGenerator(Iterator):
    """Stateful, resumable batch generator.

    Wraps a map-style dataset in a ``StatefulDataLoader`` + ``StatefulDistributedSampler``,
    accumulates rendered samples in a token-bounded buffer, and yields lists of collated
    micro batches (``num_micro_batch`` per step, i.e. micro batch * grad accumulation).
    ``state_dict``/``load_state_dict`` capture the buffer and the dataloader position so
    iteration can resume mid-epoch.
    """

    def __init__(
        self,
        dataset: TorchDataset,
        renderer: Renderer,
        micro_batch_size: int = 1,
        global_batch_size: int | None = None,
        cutoff_len: int = 2048,
        batching_workers: int = 0,
        batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
        pin_memory: bool = True,
        drop_last: bool = True,
    ) -> None:
        self.dataset = dataset
        self.renderer = renderer

        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.cutoff_len = cutoff_len
        self.batching_workers = batching_workers
        self.batching_strategy = batching_strategy
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        # TODO: support length and infinity

        # Data-parallel world size; the global batch is split across DP ranks.
        dp_size = DistributedInterface().get_world_size("dp")

        if self.global_batch_size is None:
            # Default: one micro batch per rank per step (no gradient accumulation).
            self.global_batch_size = dp_size * micro_batch_size
            self.num_micro_batch = 1
        elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
            self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
        else:
            raise ValueError(
                "Global batch size must be divisible by DP size and micro batch size. "
                f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
            )

        if not self.drop_last:
            raise ValueError("Drop last must be True.")

        self._init_data_provider()

        self._is_resuming: bool = False
        self._data_iter = iter(self._data_provider)
        self._buffer: list[ModelInput] = []
        self._buffer_tokens: int = 0
        # Buffer fill target: one full global batch's worth of tokens at max sequence length.
        self._max_buffer_tokens: int = self.micro_batch_size * self.num_micro_batch * self.cutoff_len

        logger.info_rank0(
            f"Init unified data loader with global batch size {self.global_batch_size}, "
            f"micro batch size {self.micro_batch_size}, "
            f"num micro batch {self.num_micro_batch}, "
            f"cutoff len {self.cutoff_len}, "
            f"batching workers {self.batching_workers}, "
            f"batching strategy {self.batching_strategy}."
        )

    def _init_data_provider(self) -> None:
        """Build the distributed sampler and stateful dataloader, and compute the epoch length."""
        # NOTE(review): len(dataset) == -1 appears to be a sentinel for iterable datasets
        # (see data_index of (-1) entries in DataEngine) — confirm against DataEngine.__len__.
        if len(self.dataset) != -1:
            sampler = StatefulDistributedSampler(
                self.dataset,
                num_replicas=DistributedInterface().get_world_size("dp"),
                rank=DistributedInterface().get_rank("dp"),
                shuffle=True,
                seed=0,
                drop_last=self.drop_last,
            )
        else:
            raise NotImplementedError("Iterable dataset is not supported yet.")

        self._data_provider = StatefulDataLoader(
            self.dataset,
            # Fetch one global batch (per rank) per dataloader step.
            batch_size=self.micro_batch_size * self.num_micro_batch,
            sampler=sampler,
            num_workers=self.batching_workers,
            # Rendering (tokenization) happens in the collate step, possibly in worker processes.
            collate_fn=self.renderer.process_samples,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
        )
        if self.batching_strategy == BatchingStrategy.NORMAL:
            self._length = len(self._data_provider)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            # NOTE(review): this assignment is immediately followed by the raise below,
            # so it never takes effect — presumably a placeholder until plugins land.
            self._length = BatchingPlugin(self.batching_strategy).compute_length()
            raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")

    def __len__(self) -> int:
        # Number of steps per epoch (dataloader batches under the NORMAL strategy).
        return self._length

    def __iter__(self):
        if not self._is_resuming:
            # Fresh epoch: discard any leftover samples from the previous epoch.
            self._buffer.clear()
            self._buffer_tokens = 0

        self._data_iter = iter(self._data_provider)
        self._is_resuming = False
        return self

    def __next__(self):
        batch = self._next_batch()
        if batch is None:
            raise StopIteration

        return batch

    def _next_batch(self) -> list[BatchInput] | None:
        """Refill the buffer up to the token target, then collate one batch (None when exhausted)."""
        while self._buffer_tokens < self._max_buffer_tokens:
            try:
                samples: list[ModelInput] = next(self._data_iter)
            except StopIteration:
                # Dataloader exhausted; fall through and try to build a final batch.
                break

            num_tokens = sum(len(sample["input_ids"]) for sample in samples)
            self._buffer.extend(samples)
            self._buffer_tokens += num_tokens

        return self._build_batch()

    def _build_batch(self) -> list[BatchInput] | None:
        """Collate buffered samples into micro batches via the configured strategy."""
        if self.batching_strategy == BatchingStrategy.NORMAL:
            self._buffer, self._buffer_tokens, batch = default_collate_fn(
                self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
            )
            return batch
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            # Plugin shares the collate-fn contract: (buffer, buffer_tokens, batch-or-None).
            self._buffer, self._buffer_tokens, batch = BatchingPlugin(self.batching_strategy)(
                self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
            )
            return batch

    def state_dict(self) -> dict[str, Any]:
        """Snapshot the buffer and the dataloader position for mid-epoch resume."""
        return {
            "buffer": self._buffer,
            "buffer_tokens": self._buffer_tokens,
            "data_provider": self._data_provider.state_dict(),
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Restore a snapshot produced by ``state_dict`` and mark the next ``__iter__`` as a resume."""
        self._buffer = state["buffer"]
        self._buffer_tokens = state["buffer_tokens"]
        self._data_provider.load_state_dict(state["data_provider"])
        self._is_resuming = True

    def set_epoch(self, epoch: int) -> None:
        """Forward the epoch to the sampler so per-epoch shuffling differs across epochs."""
        if hasattr(self._data_provider.sampler, "set_epoch"):
            self._data_provider.sampler.set_epoch(epoch)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build a DataEngine + ModelEngine from CLI args and print the first batch.
    """
    python -m llamafactory.v1.core.utils.batching \
        --model llamafactory/tiny-random-qwen2.5 \
        --dataset data/v1_sft_demo.yaml \
        --micro_batch_size 2 \
        --global_batch_size 4 \
        --batching_workers 0
    """
    from ...config.arg_parser import get_args
    from ..data_engine import DataEngine
    from ..model_engine import ModelEngine

    data_args, model_args, training_args, _ = get_args()
    data_engine = DataEngine(data_args=data_args)
    model_engine = ModelEngine(model_args=model_args)
    batch_generator = BatchGenerator(
        data_engine,
        model_engine.renderer,
        micro_batch_size=training_args.micro_batch_size,
        global_batch_size=training_args.global_batch_size,
        cutoff_len=training_args.cutoff_len,
        batching_workers=training_args.batching_workers,
        batching_strategy=training_args.batching_strategy,
    )
    # Inspect a single batch: the batch itself, its micro-batch count, and the padded shape.
    for batch in batch_generator:
        print(batch)
        print(len(batch))
        print(batch[0]["input_ids"].shape)
        break
|
||||
@@ -1,277 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from collections.abc import Generator, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ...utils.batching_queue import BaseBatchingQueue
|
||||
from ...utils.logging import get_logger
|
||||
from ...utils.types import Processor, TorchDataset
|
||||
from .data_collator import DataCollator
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# base dataloader
|
||||
# base dataloader
class DistributedDataloader(StatefulDataLoader):
    """Base Distributed DataLoader.

    Thin wrapper over ``StatefulDataLoader`` that routes ``set_epoch`` to the sampler
    when available, otherwise to the dataset (iterable datasets shuffle per-epoch themselves).
    """

    dataset: "TorchDataset"
    sampler: "StatefulDistributedSampler"

    def set_epoch(self, epoch: int) -> None:
        """Propagate the epoch number for deterministic per-epoch shuffling."""
        if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
|
||||
|
||||
|
||||
@dataclass
class BaseDataLoader:
    """Default DataLoader (skeleton — all methods are placeholders).

    Guidelines: fetch until a fixed batch size is reached; save ``state_dict``
    for the prefetch buffer; resume from the saved state.

    Plan:
        1. Init stateful dataloader (tokenize)
        2. Add to buffer (2 * max seq len per device)
        3. Yield batch indexes (micro batch * grad acc)
           a) non pack + non dynamic
           b) non pack + dynamic
           c) pack + non dynamic
           d) pack + dynamic
    """

    processor: Processor

    def __init__(self, dataset: TorchDataset) -> None:
        self.dataset = dataset

    def init_dataloader(self) -> None:
        """Initialize the underlying dataloader (placeholder)."""
        pass

    def __iter__(self) -> Iterator:
        pass

    # Was annotated `-> any`, which is the builtin function, not a type; the stub
    # returns None, so no return annotation is given until the real type is known.
    def __next__(self):
        pass
|
||||
|
||||
|
||||
@dataclass
class DataLoader:
    """Default DataLoader.

    Wraps an inner (stateful) dataloader and optionally a ``BaseBatchingQueue``.
    Without a queue it splits each fetched batch into ``num_micro_batch`` collated
    micro batches; with a queue it buffers items and drains micro batches from the
    queue. Supports save/restore via ``state_dict``/``load_state_dict`` and wraps
    around the inner iterator until ``length`` steps have been produced.
    """

    processor: "Processor"
    dataloader: "DistributedDataloader"
    batching_queue: "BaseBatchingQueue"
    collate_fn: "DataCollator"
    num_micro_batch: int = 1
    length: int = 0
    drop_last: bool = True

    def __init__(
        self,
        dataloader: any,
        collate_fn: "DataCollator",
        num_micro_batch: int = 1,
        length: int = 0,
        drop_last: bool = True,
        batching_queue: Optional["BaseBatchingQueue"] = None,
    ) -> None:
        self.batching_queue = batching_queue
        self.num_micro_batch = num_micro_batch
        self.step = 0
        self._collate_fn = collate_fn
        self._dataloader = dataloader
        self._drop_last = drop_last
        self._data_iter: Iterator
        self._resume = False
        self._batch_data_iter: Generator

        # length > 0: fixed number of steps; -1: effectively infinite; 0: inner dataloader's length.
        if length > 0:
            self._length = length
        elif length == -1:
            self._length = sys.maxsize
        else:
            self._length = len(self._dataloader)

    def __len__(self):
        return self._length

    def __iter__(self) -> Iterator:
        # When resuming from a checkpoint, keep the restored step counter.
        if not self._resume:
            self.step = 0
        self._data_iter = iter(self._dataloader)
        self._batch_data_iter = self.batch_data_generator()
        self._resume = False
        return self

    def __next__(self):
        return next(self._batch_data_iter)  # FIXME maybe we can move origin_batch_data_generator to here

    def origin_batch_data_generator(self):
        """Standard pass-through generator if do not use batching queue."""
        while True:
            if self._length > 0 and self.step >= self._length:
                return

            try:
                batch = []
                data = next(self._data_iter)
                # split data into micro batches
                for i in range(0, len(data), self.num_micro_batch):
                    micro_batch = data[i : i + self.num_micro_batch]
                    if self._collate_fn:
                        micro_batch = self._collate_fn(micro_batch)
                    batch.append(micro_batch)
                yield batch
                self.step += 1
            except StopIteration:
                if self.step < self._length:
                    # Restart iterator to fill the requested length
                    self._data_iter = iter(self._dataloader)
                    try:
                        batch = []
                        data = next(self._data_iter)
                        for i in range(0, len(data), self.num_micro_batch):
                            micro_batch = data[i : i + self.num_micro_batch]
                            if self._collate_fn:
                                micro_batch = self._collate_fn(micro_batch)
                            batch.append(micro_batch)
                        yield batch
                        self.step += 1
                    except StopIteration:
                        # Inner dataloader is empty even after restart; stop.
                        return
                else:
                    return
            except Exception as e:
                logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
                raise

    def batch_data_generator(self):
        """Queue-backed generator: buffer items, drain micro batches, pad the tail if needed."""
        if self.batching_queue is None:
            yield from self.origin_batch_data_generator()
            return

        batch = []

        while True:
            if self._length and self.step >= self._length:
                return

            # Drain one micro batch whenever the queue holds enough data.
            if self.batching_queue.is_full_filled():
                micro_batch = self.batching_queue.get_micro_batch(self.step)
                if self._collate_fn:
                    micro_batch = self._collate_fn(micro_batch)
                batch.append(micro_batch)
                if len(batch) == self.num_micro_batch:
                    yield batch
                    self.step += 1
                    batch = []

            try:
                processing_item = next(self._data_iter)
            except Exception as e:
                if isinstance(e, StopIteration):
                    if self.step < self._length:
                        # call iter until reach length
                        self._data_iter = iter(self._dataloader)
                        processing_item = next(self._data_iter)
                    elif not self._drop_last and not self.batching_queue.empty():
                        # Flush the queue, then pad the final (partial) batch with
                        # deep copies marked "is_padded" so its size stays uniform.
                        while not self.batching_queue.empty():
                            micro_batch = self.batching_queue.get_micro_batch(self.step)
                            if self._collate_fn:
                                micro_batch = self._collate_fn(micro_batch)
                            batch.append(micro_batch)
                            if len(batch) == self.num_micro_batch:
                                yield batch
                                self.step += 1
                                batch = []

                        while len(batch) < self.num_micro_batch:
                            padding_batch = copy.deepcopy(micro_batch)
                            padding_batch["is_padded"] = True
                            batch.append(padding_batch)
                        yield batch
                        self.step += 1
                        return
                    else:
                        return
                else:
                    logger.error(f"DataLoader iter data exception: {e}")
                    raise

            # put processing_item to buffer
            if isinstance(processing_item, dict):
                processing_item = [processing_item]

            for item in processing_item:
                self.batching_queue.put_item(item)

    def state_dict(self):
        """Return a deep-copied snapshot of public state plus the inner dataloader's state."""
        # save state
        state = self.__dict__.copy()
        # remove internal fields
        for k in list(state.keys()):
            if k.startswith("_"):
                del state[k]

        # save dataloader state
        if hasattr(self._dataloader, "state_dict"):
            state["dataloader_state"] = self._dataloader.state_dict()
        elif hasattr(self._dataloader, "__getstate__"):
            state["dataloader_state"] = self._dataloader.__getstate__()

        batching_strategy = getattr(self, "batching_strategy", None)
        if batching_strategy and hasattr(batching_strategy, "state_dict"):
            state["batching_strategy_state"] = batching_strategy.state_dict()
        if "batching_strategy" in state:
            del state["batching_strategy"]

        return copy.deepcopy(state)

    # NOTE(review): `any` here is the builtin function, not a type — should be typing.Any.
    def load_state_dict(self, state: dict[str, any]):
        """Restore a snapshot from ``state_dict`` and mark the next ``__iter__`` as a resume."""
        if state["num_micro_batch"] != self.num_micro_batch:
            logger.warning(
                f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
            )
            # Keep the current num_micro_batch instead of the checkpointed one.
            del state["num_micro_batch"]
        self.__dict__.update(state)
        self._resume = True

        if hasattr(self._dataloader, "load_state_dict"):
            self._dataloader.load_state_dict(state["dataloader_state"])
        elif hasattr(self._dataloader, "__getstate__"):
            self._dataloader.__setstate__(state["dataloader_state"])

        if "batching_strategy_state" in state:
            batching_strategy = getattr(self, "batching_strategy", None)
            if batching_strategy:
                batching_strategy.load_state_dict(state["batching_strategy_state"])
            del state["batching_strategy_state"]

        self._data_iter = iter(self._dataloader)
        self._batch_data_iter = self.batch_data_generator()

    def set_epoch(self, epoch: int) -> None:
        """Forward the epoch to the inner dataloader when it supports per-epoch shuffling."""
        if hasattr(self._dataloader, "set_epoch"):
            self._dataloader.set_epoch(epoch)
|
||||
@@ -12,10 +12,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Rendering utils.
|
||||
|
||||
How to use:
|
||||
renderer = Renderer(template, processor)
|
||||
renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs
|
||||
renderer.parse_message(text: str) -> Message
|
||||
renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.types import Message, ModelInput, Processor
|
||||
from ...utils.types import Message, ModelInput, Processor, Sample
|
||||
|
||||
|
||||
def render_chatml_messages(
|
||||
@@ -64,7 +74,7 @@ def render_chatml_messages(
|
||||
|
||||
|
||||
def parse_chatml_message(generated_text: str) -> Message:
|
||||
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
|
||||
"""Parse a message in ChatML format.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in ChatML format.
|
||||
@@ -83,6 +93,16 @@ class Renderer:
|
||||
def render_messages(
|
||||
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
||||
) -> ModelInput:
|
||||
"""Apply template to messages and convert them to model input.
|
||||
|
||||
Args:
|
||||
messages (list[Message]): The messages to render.
|
||||
tools (str | None, optional): The tools to use. Defaults to None.
|
||||
is_generate (bool, optional): Whether to render for generation. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ModelInput: The rendered model input.
|
||||
"""
|
||||
if self.template == "chatml":
|
||||
return render_chatml_messages(self.processor, messages, tools, is_generate)
|
||||
else:
|
||||
@@ -91,9 +111,59 @@ class Renderer:
|
||||
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
||||
|
||||
def parse_message(self, generated_text: str) -> Message:
|
||||
"""Parse a message in the template format.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in the template format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
if self.template == "chatml":
|
||||
return parse_chatml_message(generated_text)
|
||||
else:
|
||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||
|
||||
return RenderingPlugin(self.template).parse_message(generated_text)
|
||||
|
||||
    def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
        """Process samples to model input.

        Handles two sample shapes: plain SFT samples carrying ``messages``, and
        preference (DPO-style) samples carrying ``chosen_messages``/``rejected_messages``,
        which are rendered separately and concatenated into a single sequence with
        ``token_type_ids`` distinguishing chosen (0) from rejected (1) tokens.

        Args:
            samples (list[Sample]): The samples to process.

        Returns:
            list[ModelInput]: The processed model inputs.

        Raises:
            ValueError: If a sample has neither ``messages`` nor the
                ``chosen_messages``/``rejected_messages`` pair.
        """
        model_inputs = []
        for sample in samples:
            if "messages" in sample:
                model_input = self.render_messages(sample["messages"], sample.get("tools"))
            elif "chosen_messages" in sample and "rejected_messages" in sample:
                chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
                rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
                # Mark spans so the trainer can split the concatenated sequence back apart.
                chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"])
                rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"])
                model_input = ModelInput(
                    input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
                    attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
                    labels=chosen_input["labels"] + rejected_input["labels"],
                    loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
                    token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
                )
                if "position_ids" in chosen_input:
                    # position_ids are arrays (possibly 2D, e.g. mrope); concatenate on the last axis.
                    model_input["position_ids"] = np.concatenate(
                        [chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
                    )
            else:
                raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")

            # Pass through optional metadata untouched.
            if "extra_info" in sample:
                model_input["extra_info"] = sample["extra_info"]

            if "_dataset_name" in sample:
                model_input["_dataset_name"] = sample["_dataset_name"]

            model_inputs.append(model_input)

        return model_inputs
|
||||
|
||||
Reference in New Issue
Block a user