[v1] add accelerator (#9607)

This commit is contained in:
Yaowei Zheng
2025-12-12 19:22:06 +08:00
committed by GitHub
parent 4fd94141a4
commit 203069e11c
36 changed files with 941 additions and 443 deletions

View File

@@ -1,4 +1,4 @@
dpo_zh_demo:
hf_hub_url: HuggingFaceH4/orca_dpo_pairs
path: HuggingFaceH4/orca_dpo_pairs
split: train_prefs
converter: pair

View File

@@ -1,8 +1,9 @@
identity:
file_name: data/identity.json
path: data/identity.json
source: local
converter: alpaca
alpaca_en_demo:
file_name: alpaca_en_demo.json
dataset_dir: data
path: data/alpaca_en_demo.json
source: local
converter: alpaca
num_samples: 500
size: 500

View File

@@ -1,4 +1,7 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/utils/dist_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +15,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
import numpy as np
import torch
import torch.distributed as dist
from ..utils.types import Tensor, TensorLike
def get_current_accelerator(check_available: bool = True):
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
@unique
class DeviceType(str, Enum):
CPU = "cpu"
CUDA = "cuda"
META = "meta"
MPS = "mps"
NPU = "npu"
XPU = "xpu"
@unique
class ReduceOp(str, Enum):
SUM = "sum"
MEAN = "mean"
MAX = "max"
MIN = "min"
def is_distributed() -> bool:
"""Check if distributed environment is available."""
return os.getenv("RANK") is not None
def get_rank() -> int:
"""Get rank."""
return int(os.getenv("RANK", "0"))
def get_local_rank() -> int:
"""Get local rank."""
return int(os.getenv("LOCAL_RANK", "0"))
def get_world_size() -> int:
"""Get world size."""
return int(os.getenv("WORLD_SIZE", "1"))
def get_local_world_size() -> int:
"""Get local world size."""
return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@lru_cache
def get_current_accelerator(check_available: bool = True) -> torch.device:
"""Get current accelerator.
Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
@@ -27,26 +86,78 @@ def get_current_accelerator(check_available: bool = True):
accelerator = torch.accelerator.current_accelerator(check_available=check_available)
if accelerator is None:
return torch.device("cpu")
return torch.device(DeviceType.CPU.value)
return accelerator
@lru_cache
def is_torch_npu_available():
return get_current_accelerator().type == "npu"
@lru_cache
def is_torch_cuda_available():
return get_current_accelerator().type == "cuda"
return get_current_accelerator().type == DeviceType.CUDA
@lru_cache
def is_torch_xpu_available():
return get_current_accelerator().type == "xpu"
@lru_cache
def is_torch_mps_available():
return get_current_accelerator().type == "mps"
return get_current_accelerator().type == DeviceType.MPS
def is_torch_npu_available():
return get_current_accelerator().type == DeviceType.NPU
def is_torch_xpu_available():
return get_current_accelerator().type == DeviceType.XPU
def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
"""Gathers the tensor from all ranks and concats them along the first dim."""
world_size = get_world_size()
device = get_current_accelerator()
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
return output_tensor.view(-1, *tensor.size()[1:])
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
"""Performs all reduce in the given process group."""
device = get_current_accelerator()
is_ndarray = isinstance(data, np.ndarray)
is_tensor = isinstance(data, torch.Tensor)
if is_ndarray:
data = torch.from_numpy(data)
elif not is_tensor:
data = torch.tensor(data, dtype=torch.float, device=device)
reduce_ops = {
ReduceOp.MEAN: dist.ReduceOp.SUM,
ReduceOp.SUM: dist.ReduceOp.SUM,
ReduceOp.MAX: dist.ReduceOp.MAX,
ReduceOp.MIN: dist.ReduceOp.MIN,
}
dist.all_reduce(data, op=reduce_ops[op], group=group)
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
data /= dist.get_world_size(group=group)
if is_tensor:
return data
elif is_ndarray:
return data.numpy()
elif data.numel() == 1:
return data.item()
else:
return data.tolist()
@contextmanager
def main_process_first(local_only: bool = True) -> None:
"""A context manager for torch distributed environment to do something on the main process firstly."""
if get_world_size() > 1:
is_main_process = get_local_rank() == 0 if local_only else get_rank() == 0
try:
if not is_main_process:
dist.barrier()
yield
finally:
if is_main_process:
dist.barrier()
else:
yield

View File

@@ -0,0 +1,123 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Optional
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils.types import TensorLike
from .helper import ReduceOp, all_reduce, get_current_accelerator, get_rank, get_world_size, is_distributed
@dataclass
class DistributedStrategy:
"""Distributed strategy."""
dp_size: Optional[int] = None
tp_size: int = 1
def __post_init__(self) -> None:
if not is_distributed():
self.dp_size = 1
elif self.dp_size is None:
self.dp_size = get_world_size() // self.tp_size
elif self.dp_size * self.tp_size != get_world_size():
raise ValueError(
f"dp_size * tp_size must equal to world_size, "
f"got {self.dp_size} * {self.tp_size} != {get_world_size()}."
)
@property
def mesh_shape(self) -> tuple[int, int]:
"""Mesh shape."""
return (self.dp_size, self.tp_size)
@property
def mesh_dim_names(self) -> tuple[str, str]:
"""Mesh dimension names."""
return ("dp", "tp")
class DistributedInterface:
"""Distributed interface."""
_instance: Optional["DistributedInterface"] = None
_initialized: bool = False
is_distributed = is_distributed()
"""Check if distributed environment is available."""
rank = get_rank()
"""Global rank."""
world_size = get_world_size()
"""Global world size."""
device_mesh: Optional[DeviceMesh] = None
"""Device mesh."""
current_accelerator = get_current_accelerator()
"""Current accelerator."""
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
"""Singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, strategy: DistributedStrategy) -> None:
if self._initialized:
return
self.strategy = strategy
if self.is_distributed:
self.device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
mesh_shape=strategy.mesh_shape,
mesh_dim_names=strategy.mesh_dim_names,
)
else:
self.device_mesh = None
self._initialized = True
def __str__(self) -> str:
return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self.is_distributed}, "
f"rank={self.rank}, world_size={self.world_size}, "
f"device_mesh={self.device_mesh}, current_accelerator={self.current_accelerator}"
)
def dp_rank(self) -> int:
"""Data parallel rank."""
if self.device_mesh is None:
return 0
return self.device_mesh["dp"].get_rank()
def dp_size(self) -> int:
"""Data parallel size."""
if self.device_mesh is None:
return 1
return self.device_mesh["dp"].size()
def all_reduce_over_dp(self, data: TensorLike, op: ReduceOp = ReduceOp.MEAN) -> TensorLike:
"""All reduce tensor."""
if self.device_mesh is None:
return data
return all_reduce(data, op, self.device_mesh["dp"].get_group())
if __name__ == "__main__":
print(DistributedInterface(DistributedStrategy()))

View File

@@ -28,6 +28,21 @@ from .sample_args import SampleArguments
from .training_args import TrainingArguments
def validate_args(
data_args: DataArguments,
model_args: ModelArguments,
training_args: TrainingArguments,
sample_args: SampleArguments,
):
"""Validate arguments."""
if (
model_args.quant_config is not None
and training_args.dist_config is not None
and training_args.dist_config.name == "deepspeed"
):
raise ValueError("Quantization is not supported with deepspeed backend.")
def get_args(
args: Optional[Union[dict[str, Any], list[str]]] = None,
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
@@ -56,6 +71,8 @@ def get_args(
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
validate_args(*parsed_args)
return tuple(parsed_args)

View File

@@ -0,0 +1,105 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/training_args.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from enum import Enum, unique
from typing import Any, Optional, Union
class PluginConfig(dict):
"""Dictionary that allows attribute access."""
@property
def name(self) -> str:
"""Plugin name."""
if "name" not in self:
raise ValueError("Plugin configuration must have a 'name' field.")
return self["name"]
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(f"Attribute {key} not found.")
def __setattr__(self, key: str, value: Any):
self[key] = value
PluginArgument = Optional[Union[PluginConfig, dict, str]]
@unique
class AutoClass(str, Enum):
"""Auto class for model config."""
CAUSALLM = "llm"
CLASSIFICATION = "cls"
OTHER = "other"
@unique
class SampleBackend(str, Enum):
HF = "hf"
VLLM = "vllm"
def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.
Args:
data: The string or dictionary to convert.
Returns:
The converted dictionary.
"""
for key, value in data.items():
if isinstance(value, dict):
data[key] = _convert_str_dict(value)
elif isinstance(value, str):
if value.lower() in ("true", "false"):
data[key] = value.lower() == "true"
elif value.isdigit():
data[key] = int(value)
elif value.replace(".", "", 1).isdigit():
data[key] = float(value)
return data
def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]:
"""Get the plugin configuration from the argument value.
Args:
config: The argument value to get the plugin configuration from.
Returns:
The plugin configuration.
"""
if config is None:
return None
if isinstance(config, str) and config.startswith("{"):
config = json.loads(config)
config = _convert_str_dict(config)
if "name" not in config:
raise ValueError("Plugin configuration must have a 'name' field.")
return PluginConfig(config)

View File

@@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from .arg_utils import AutoClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
@@ -29,7 +31,24 @@ class ModelArguments:
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
auto_model_class: str = field(
default="causallm",
auto_class: AutoClass = field(
default=AutoClass.CAUSALLM,
metadata={"help": "Model class from Hugging Face."},
)
peft_config: PluginConfig = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
)
kernel_config: PluginConfig = field(
default=None,
metadata={"help": "Kernel configuration for the model."},
)
quant_config: PluginConfig = field(
default=None,
metadata={"help": "Quantization configuration for the model."},
)
def __post_init__(self) -> None:
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
self.quant_config = get_plugin_config(self.quant_config)

View File

@@ -14,12 +14,8 @@
from dataclasses import dataclass, field
from enum import Enum
class SampleBackend(Enum):
HF = "hf"
VLLM = "vllm"
from .arg_utils import SampleBackend
@dataclass

View File

@@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from .arg_utils import PluginArgument, get_plugin_config
@dataclass
class TrainingArguments:
@@ -38,3 +40,10 @@ class TrainingArguments:
default=False,
metadata={"help": "Use bf16 for training."},
)
dist_config: PluginArgument = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)
def __post_init__(self) -> None:
self.dist_config = get_plugin_config(self.dist_config)

View File

@@ -29,7 +29,7 @@ Train Phase:
"""
from ..config.training_args import TrainingArguments
from ..extras.types import TorchDataset
from ..utils.types import TorchDataset
from .model_worker import ModelWorker
from .trainer_utils.data_collator import DataCollator
@@ -49,13 +49,10 @@ class BaseTrainer:
self.optimizer = None
self.lr_scheduler = None
def init_device_mesh(self) -> None:
pass
def init_model_and_optimizer(self) -> None:
self.model_config = self.model_worker.get_model_config()
self.model_worker.init_model_config()
# with self.dist_plugin.get_model_init_context():
# self.model = self.model_worker.get_model(self.model_config)
# self.model = self.model_worker.init_model(self.model_config)
def create_dataloader(self) -> None:
pass

View File

@@ -34,7 +34,7 @@ from omegaconf import OmegaConf
from torch.utils.data import Dataset
from ..config.data_args import DataArguments
from ..extras.types import DatasetInfo, HFDataset, Sample
from ..utils.types import DatasetInfo, HFDataset, Sample
class DataEngine(Dataset):
@@ -64,9 +64,9 @@ class DataEngine(Dataset):
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
self.dataset_infos = OmegaConf.load(filepath)
elif os.path.exists(self.args.dataset): # local file(s)
self.dataset_infos = {"default": {"file_name": self.args.dataset}}
self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
self.dataset_infos = {"default": {"path": self.args.dataset}}
def load_dataset(self) -> None:
"""Load datasets according to dataset info."""
@@ -74,14 +74,14 @@ class DataEngine(Dataset):
split = value.get("split", "train")
streaming = value.get("streaming", False)
self.streaming |= streaming
if "hf_hub_url" in value:
if value.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset
self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
self.datasets[key] = DataLoaderPlugin().auto_load_data(value)
self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)
def build_data_index(self) -> None:
"""Build dataset index."""
@@ -112,9 +112,9 @@ class DataEngine(Dataset):
"""
converter = self.dataset_infos[dataset_name].get("converter")
if converter is not None:
from ..plugins.data_plugins.converter import get_converter
from ..plugins.data_plugins.converter import DataConverterPlugin
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
return {"_dataset_name": dataset_name, **DataConverterPlugin(converter)(raw_sample)}
else:
return {"_dataset_name": dataset_name, **raw_sample}
@@ -147,7 +147,7 @@ class DataEngine(Dataset):
else:
from ..plugins.data_plugins.loader import DataSelectorPlugin
selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
selected_index = DataSelectorPlugin().select(self.data_index, index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
@@ -187,7 +187,7 @@ class DataEngine(Dataset):
if __name__ == "__main__":
from ..config.parser import get_args
from ..config.arg_parser import get_args
data_args, *_ = get_args()
data_engine = DataEngine(data_args=data_args)

View File

@@ -24,10 +24,12 @@ Init Phase:
from typing import Optional
import torch
from transformers import AutoConfig, AutoProcessor
from ..config.model_args import ModelArguments
from ..extras.types import DistModel, HFConfig, HFModel, Processor
from ..accelerator.helper import DeviceType
from ..config.model_args import AutoClass, ModelArguments
from ..utils.types import HFConfig, HFModel, Processor
class ModelWorker:
@@ -38,16 +40,15 @@ class ModelWorker:
"""Tokenizer or multi-modal processor."""
self.model_config: Optional[HFConfig] = None
"""Model configuration."""
self.unwrapped_model: Optional[HFModel] = None
"""Unwrapped model."""
self.model: Optional[DistModel] = None
"""Distributed model."""
self.init_processor()
self.init_model_config()
self.init_model()
self.init_adapter()
self.model: Optional[HFModel] = None
"""HF model."""
self.is_adapter = False
"""Whether the model has adapter."""
def init_processor(self) -> None:
if self.processor is not None:
return
self.processor = AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
@@ -55,38 +56,58 @@ class ModelWorker:
)
def init_model_config(self) -> None:
if self.model_config is not None:
return
self.model_config = AutoConfig.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def init_model(self) -> None:
if self.args.auto_model_class == "causallm":
if self.model is not None:
return
self.init_model_config()
if self.args.auto_class == AutoClass.CAUSALLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
AutoClass = AutoModelForImageTextToText
ModelClass = AutoModelForImageTextToText
else:
AutoClass = AutoModelForCausalLM
elif self.args.auto_model_class == "classification":
ModelClass = AutoModelForCausalLM
elif self.args.auto_class == AutoClass.CLASSIFICATION:
from transformers import AutoModelForTokenClassification
AutoClass = AutoModelForTokenClassification
ModelClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
AutoClass = AutoModel
ModelClass = AutoModel
self.unwrapped_model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map="cpu",
trust_remote_code=self.args.trust_remote_code,
)
default_device_type = torch.get_default_device().type
if default_device_type == DeviceType.META:
self.model = ModelClass.from_config(self.model_config)
else:
self.model = ModelClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=default_device_type,
trust_remote_code=self.args.trust_remote_code,
)
def init_adapter(self) -> None:
pass
if self.is_adapter:
return
if self.args.peft_config is not None:
from ..plugins.model_plugins.peft import PeftPlugin
self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)
self.is_adapter = True
def get_processor(self) -> Processor:
return self.processor
@@ -95,4 +116,4 @@ class ModelWorker:
return self.model_config
def get_model(self) -> HFModel:
return self.unwrapped_model
return self.model

View File

@@ -15,7 +15,7 @@
from typing import Any
from ...extras.types import Processor, Tensor, TorchDataset
from ...utils.types import Processor, Tensor, TorchDataset
class DataCollator:

View File

@@ -44,7 +44,7 @@ WELCOME = (
def launch():
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
if command == "sft":
if command == "sft": # train command will fallback to sft command
from .trainers.sft_trainer import run_sft
run_sft()

View File

@@ -13,12 +13,13 @@
# limitations under the License.
from typing import Callable, TypedDict
from typing import Any, Literal, TypedDict
from typing_extensions import NotRequired, Required
from typing_extensions import NotRequired
from ....extras import logging
from ...extras.types import DPOSample, Sample, SFTSample
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample
logger = logging.get_logger(__name__)
@@ -26,35 +27,48 @@ logger = logging.get_logger(__name__)
class AlpacaSample(TypedDict, total=False):
system: NotRequired[str]
instruction: NotRequired[str]
instruction: str
input: NotRequired[str]
output: NotRequired[str]
output: str
ShareGPTMessage = TypedDict(
"ShareGPTMessage",
{
"from": Required[str], # Role of the message sender (e.g., "human", "gpt", "system")
"value": Required[str], # Content of the message
},
SharegptMessage = TypedDict(
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
)
class ShareGPTSample(TypedDict, total=False):
"""Type definition for raw ShareGPT sample."""
class SharegptSample(TypedDict, total=False):
conversations: list[SharegptMessage]
tools: NotRequired[str]
conversations: Required[list[ShareGPTMessage]]
class OpenaiMessage(TypedDict, total=False):
role: Literal["user", "assistant", "tool"]
content: str
class OpenaiSample(TypedDict, total=False):
messages: list[OpenaiMessage]
class PairSample(TypedDict, total=False):
prompt: NotRequired[str]
chosen: NotRequired[list[dict]]
rejected: NotRequired[list[dict]]
chosen: list[OpenaiMessage]
rejected: list[OpenaiMessage]
class DataConverterPlugin(BasePlugin):
"""Plugin for data converters."""
def __call__(self, raw_sample: dict[str, Any]) -> Sample:
return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample.
See raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en
Args:
raw_sample (AlpacaSample): Alpaca sample.
@@ -67,20 +81,6 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
{"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
)
if "history" in raw_sample:
for idx, item in enumerate(raw_sample["history"]):
if len(item) != 2:
logger.warning_rank0(
f"Warning: History item at index {idx} has invalid length (expected 2, got {len(item)}). Skipping."
)
continue
old_prompt, old_response = item
messages.append({"role": "user", "content": [{"type": "text", "value": old_prompt}], "loss_weight": 0.0})
messages.append(
{"role": "assistant", "content": [{"type": "text", "value": old_response}], "loss_weight": 1.0}
)
if "instruction" in raw_sample or "input" in raw_sample:
messages.append(
{
@@ -100,149 +100,85 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages}
def sharegpt_converter(raw_sample: ShareGPTSample) -> SFTSample:
"""Converts a raw ShareGPT sample into a formatted SFT (Supervised Fine-Tuning) sample.
@DataConverterPlugin("sharegpt").register
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample.
Retains only SFT-relevant scenarios and removes parity checks.
See raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en
Args:
raw_sample (ShareGPTSample): A raw sample in ShareGPT format.
raw_sample (SharegptSample): ShareGPT sample.
Returns:
dict: A dictionary containing the formatted 'messages' list for SFT training.
Returns an empty list if the input data is invalid.
SFTSample: SFT sample.
"""
tag_mapping = {
"system": "system",
"human": "user",
"gpt": "assistant",
"observation": "observation",
"function_call": "function",
"observation": "tool",
"function_call": "assistant",
}
messages = raw_sample.get("conversations", [])
aligned_messages = []
system_content = ""
messages = []
tools = raw_sample.get("tools", "")
# Extract system message if present (typically the first message)
if messages and messages[0]["from"] == "system":
system_content = messages[0]["value"]
messages = messages[1:]
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call":
messages.append(
{
"role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}],
"loss_weight": 1.0,
}
)
else:
messages.append(
{
"role": tag_mapping[tag],
"content": [{"type": "text", "value": message["value"]}],
"loss_weight": 1.0 if tag == "gpt" else 0.0,
}
)
if system_content:
aligned_messages.append(
{"role": "system", "content": [{"type": "text", "value": system_content}], "loss_weight": 0.0}
)
if tools:
if messages and messages[0]["role"] == "system":
messages[0]["content"].append({"type": "tools", "value": tools})
else:
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
has_invalid_role = False
for message in messages:
sender = message["from"]
# validate sender is in supported tags
if sender not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag '{sender}' in message: {message}")
has_invalid_role = True
break
aligned_messages.append(
{
"role": tag_mapping[sender],
"content": [{"type": "text", "value": message["value"]}],
"loss_weight": 0.0 if sender in ("human", "observation") else 1.0,
}
)
if has_invalid_role:
logger.warning_rank0("Skipping invalid example due to unsupported role tags.")
return {"messages": []}
return {"messages": aligned_messages}
return {"messages": messages}
@DataConverterPlugin("pair").register
def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to standard DPO sample.
"""Convert Pair sample to DPO sample.
See raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
Args:
raw_sample (PairSample): pair sample with prompt, chosen, rejected fields.
see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
raw_sample (PairSample): pair sample with chosen, rejected fields.
Returns:
DPOSample: DPO sample with chosen_messages and rejected_messages.
see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl
"""
chosen_messages = []
assert "chosen" in raw_sample, "chosen field is required in pair sample."
assert "rejected" in raw_sample, "rejected field is required in pair sample."
assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), (
"chosen and rejected field should be a list[dict], or you may need to implement your custom converter."
)
if "chosen" in raw_sample:
value = raw_sample.get("chosen", "")
for item in value:
if item.get("role", "") == "system":
chosen_messages.append(
{
"role": "system",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "user":
chosen_messages.append(
{
"role": "user",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "assistant":
chosen_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 1.0,
}
)
def process_message(raw_messages: list[OpenaiMessage]):
messages = []
for message in raw_messages:
messages.append(
{
"role": message["role"],
"content": [{"type": "text", "value": message["content"]}],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
rejected_messages = []
if "rejected" in raw_sample:
value = raw_sample.get("rejected", "")
for item in value:
if item.get("role", "") == "system":
rejected_messages.append(
{
"role": "system",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "user":
rejected_messages.append(
{
"role": "user",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "assistant":
rejected_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 1.0,
}
)
return messages
chosen_messages = process_message(raw_sample.get("chosen", []))
rejected_messages = process_message(raw_sample.get("rejected", []))
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
CONVERTERS = {
"alpaca": alpaca_converter,
"pair": pair_converter,
"sharegpt": sharegpt_converter,
}
def get_converter(converter_name: str) -> Callable[[dict], Sample]:
if converter_name not in CONVERTERS:
raise ValueError(f"Converter {converter_name} not found.")
return CONVERTERS[converter_name]

View File

@@ -14,57 +14,59 @@
import os
from dataclasses import dataclass
import random
from typing import Any, Literal, Optional, Union
from datasets import load_dataset
from ...extras.types import DatasetInfo, HFDataset
from ...utils.plugin import BasePlugin
from ...utils.types import DatasetInfo, HFDataset
@dataclass
class DataLoaderPlugin:
class DataLoaderPlugin(BasePlugin):
"""Plugin for loading dataset."""
def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
"""Get dataset builder name.
Args:
path (str): Dataset path.
Returns:
Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
"""
return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
dataset_dir = dataset_info.get("dataset_dir", ".")
def load(self, dataset_info: DatasetInfo) -> HFDataset:
path = dataset_info["path"]
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
if "file_name" in dataset_info:
filepath = os.path.join(dataset_dir, dataset_info["file_name"])
return self.load_data_from_file(filepath, split, streaming)
else:
raise NotImplementedError()
def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath):
filetype = self._get_builder_name(os.listdir(filepath)[0])
dataset = load_dataset(filetype, data_dir=filepath, split=split)
elif os.path.isfile(filepath):
filetype = self._get_builder_name(filepath)
dataset = load_dataset(filetype, data_files=filepath, split=split)
else:
raise ValueError(f"Can not load dataset from {filepath}.")
if streaming:
dataset = dataset.to_iterable_dataset()
return dataset
return super().__call__(path, split, streaming)
@dataclass
class DataIndexPlugin:
def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
"""Get dataset builder name.
Args:
path (str): Dataset path.
Returns:
Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
"""
filetype = os.path.splitext(path)[-1][1:]
if filetype in ["arrow", "csv", "json", "jsonl", "parquet", "txt"]:
return filetype.replace("jsonl", "json").replace("txt", "text")
else:
raise ValueError(f"Unknown dataset filetype: {filetype}.")
@DataLoaderPlugin("local").register
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath):
filetype = _get_builder_name(os.listdir(filepath)[0])
dataset = load_dataset(filetype, data_dir=filepath, split=split)
elif os.path.isfile(filepath):
filetype = _get_builder_name(filepath)
dataset = load_dataset(filetype, data_files=filepath, split=split)
else:
raise ValueError(f"Can not load dataset from {filepath}.")
if streaming: # faster when data is streamed from local files
dataset = dataset.to_iterable_dataset()
return dataset
class DataIndexPlugin(BasePlugin):
"""Plugin for adjusting dataset index."""
def adjust_data_index(
@@ -81,39 +83,32 @@ class DataIndexPlugin:
list[tuple[str, int]]: Adjusted dataset index.
"""
if size is not None:
data_index = self.adjust_by_size(data_index, size)
data_index = random.choices(data_index, k=size)
if weight is not None:
data_index = self.adjust_by_weight(data_index, weight)
data_index = random.choices(data_index, k=int(len(data_index) * weight))
return data_index
def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
raise NotImplementedError()
def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
raise NotImplementedError()
@dataclass
class DataSelectorPlugin:
class DataSelectorPlugin(BasePlugin):
"""Plugin for selecting dataset samples."""
data_index: list[tuple[str, int]]
"""List of (dataset_name, sample_index)"""
def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
def select(
self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
) -> Union[tuple[str, int], list[tuple[str, int]]]:
"""Select dataset samples.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
index (Union[slice, list[int], Any]): Index of dataset samples.
Returns:
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
"""
if isinstance(index, slice):
return [self.data_index[i] for i in range(*index.indices(len(self.data_index)))]
return [data_index[i] for i in range(*index.indices(len(data_index)))]
elif isinstance(index, list):
return [self.data_index[i] for i in index]
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")

View File

@@ -21,10 +21,3 @@ class KernelType(str, Enum):
FLASH_ATTENTION = "flash_attention"
ROPE = "rope"
MOE = "moe"
class DeviceType(str, Enum):
CPU = "cpu"
CUDA = "cuda"
NPU = "npu"
XPU = "xpu"

View File

@@ -18,10 +18,10 @@ import torch
import torch.nn.functional as F
import torch_npu
from .....accelerator.helper import is_torch_npu_available
from .....extras.packages import is_transformers_version_greater_than
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.packages import is_transformers_version_greater_than
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaMoEKernel

View File

@@ -17,9 +17,9 @@ import types
import torch
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaSwiGluKernel

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from abc import ABC, ABCMeta, abstractmethod
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
from ....accelerator.helper import get_current_accelerator
from ....extras.types import HFModel
from .constants import DeviceType, KernelType
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
from .constants import KernelType
class KernelRegistry:
@@ -27,11 +27,13 @@ class KernelRegistry:
def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
self._initialized = True
@@ -218,7 +220,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
return discovered_kernels
# Iterate through registry and collect all kernels for current device
for kernel_type, devices in KERNEL_REGISTRY._registry.items():
for devices in KERNEL_REGISTRY._registry.values():
kernel_cls = devices.get(device_type)
if kernel_cls is not None:
discovered_kernels.append(kernel_cls)
@@ -226,7 +228,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
return discovered_kernels
def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
"""Call the MetaKernel's `apply` to perform the replacement.
Corresponding replacement logic is maintained inside each kernel; the only
@@ -238,16 +240,18 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
model = apply_kernel(model, NpuRMSNormKernel)
"""
if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
return kernel.apply(model, **kwargs)
if not issubclass(kernel, MetaKernel):
raise ValueError(f"{kernel} must be a MetaKernel instance.")
raise ValueError(
f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
)
if kernel.device != get_current_accelerator().type:
raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
return kernel.apply(model, **kwargs)
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
"""Apply all available kernels to the model."""
for kernel in discover_kernels(model):
model = apply_kernel(model, kernel, **kwargs)
return model

View File

@@ -14,9 +14,9 @@
import re
import types
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRMSNormKernel

View File

@@ -16,9 +16,9 @@ import sys
import torch
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRoPEKernel

View File

@@ -0,0 +1,42 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, TypedDict
from peft import LoraConfig, PeftModel, get_peft_model
from ...utils.plugin import BasePlugin
from ...utils.types import HFModel
class LoraConfigDict(TypedDict, total=False):
name: Literal["lora"]
"""Plugin name."""
r: int
"""Lora rank."""
lora_alpha: int
"""Lora alpha."""
target_modules: list[str]
"""Target modules."""
class PeftPlugin(BasePlugin):
pass
@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model

View File

@@ -1,13 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -13,22 +13,21 @@
# limitations under the License.
from ..config.parser import get_args
from ..accelerator.interface import DistributedInterface, DistributedStrategy
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..core.model_worker import ModelWorker
class SFTTrainer(BaseTrainer):
pass
def run_sft():
model_args, data_args, training_args, _ = get_args()
model_engine = ModelEngine(model_args)
def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
DistributedInterface(DistributedStrategy())
data_engine = DataEngine(data_args)
model = model_engine.get_model()
processor = model_engine.get_processor()
data_loader = data_engine.get_data_loader(processor)
trainer = SFTTrainer(training_args, model, processor, data_loader)
model_worker = ModelWorker(model_args)
trainer = SFTTrainer(training_args, model_worker, data_engine)
trainer.fit()

View File

View File

@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,123 @@
# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/utils/logging.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import threading
from functools import lru_cache
from typing import Optional
_thread_lock = threading.RLock()
_default_handler: Optional["logging.Handler"] = None
_default_log_level: "logging._Level" = logging.INFO
class _Logger(logging.Logger):
r"""A logger that supports rank0 logging."""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def warning_rank0_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def _get_default_logging_level() -> "logging._Level":
r"""Return the default logging level."""
env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
else:
raise ValueError(f"Unknown logging level: {env_level_str}.")
return _default_log_level
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> "_Logger":
return logging.getLogger(_get_library_name())
def _configure_library_root_logger() -> None:
r"""Configure root logger using a stdout stream handler with an explicit format."""
global _default_handler
with _thread_lock:
if _default_handler: # already configured
return
formatter = logging.Formatter(
fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter)
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "_Logger":
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)
def add_handler(handler: "logging.Handler") -> None:
r"""Add a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""Remove a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.info(*args, **kwargs)
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
@lru_cache(None)
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_rank0_once = warning_rank0_once

View File

@@ -0,0 +1,86 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
from . import logging
logger = logging.get_logger(__name__)
class BasePlugin:
"""Base class for plugins.
A plugin is a callable object that can be registered and called by name.
"""
_registry: dict[str, Callable] = {}
def __init__(self, name: Optional[str] = None):
"""Initialize the plugin with a name.
Args:
name (str): The name of the plugin.
"""
self.name = name
@property
def register(self) -> Callable:
"""Decorator to register a function as a plugin.
Example usage:
```python
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
```
"""
if self.name is None:
raise ValueError("Plugin name is not specified.")
if self.name in self._registry:
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
def decorator(func: Callable) -> Callable:
self._registry[self.name] = func
return func
return decorator
def __call__(self, *args, **kwargs) -> Callable:
"""Call the registered function with the given arguments.
Example usage:
```python
PrintPlugin("hello")()
```
"""
if self.name not in self._registry:
raise ValueError(f"Plugin {self.name} is not registered.")
return self._registry[self.name](*args, **kwargs)
if __name__ == "__main__":
class PrintPlugin(BasePlugin):
pass
@PrintPlugin("hello").register
def print_hello():
print("Hello world!")
PrintPlugin("hello")()

View File

@@ -19,23 +19,27 @@ from typing_extensions import NotRequired
if TYPE_CHECKING:
import datasets
import numpy as np
import torch
import torch.utils.data
import transformers
from torch.distributed.fsdp import FullyShardedDataParallel
Tensor = torch.Tensor
TensorLike = Union[int, float, list[int], list[float], np.ndarray, Tensor]
TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]
HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
DataCollator = transformers.DataCollator
DataLoader = torch.utils.data.DataLoader
HFConfig = transformers.PretrainedConfig
HFModel = transformers.PreTrainedModel
DistModel = torch.nn.parallel.DistributedDataParallel
DistModel = Union[torch.nn.parallel.DistributedDataParallel, FullyShardedDataParallel]
Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
Optimizer = torch.optim.Optimizer
Scheduler = torch.optim.lr_scheduler.LRScheduler
else:
Tensor = None
TensorLike = None
TorchDataset = None
HFDataset = None
DataCollator = None
@@ -49,12 +53,10 @@ else:
class DatasetInfo(TypedDict, total=False):
hf_hub_url: NotRequired[str]
"""HF hub dataset uri."""
file_name: NotRequired[str]
path: str
"""Local file path."""
dataset_dir: NotRequired[str]
"""Dataset directory, default to args.dataset_dir."""
source: NotRequired[Literal["hf_hub", "ms_hub", "local"]]
"""Dataset source, default to "hf_hub"."""
split: NotRequired[str]
"""Dataset split, default to "train"."""
converter: NotRequired[str]
@@ -68,12 +70,12 @@ class DatasetInfo(TypedDict, total=False):
class Content(TypedDict):
type: Literal["text", "tools", "reasoning", "tool_calls", "image_url"]
type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
value: str
class Message(TypedDict):
role: Literal["system", "user", "assistant"]
role: Literal["system", "user", "assistant", "tool"]
content: list[Content]
loss_weight: float

View File

@@ -12,25 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from torch.distributed.device_mesh import DeviceMesh
import os
from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy
class DeviceMeshManager:
"""Device mesh manager."""
_instance: Optional["DeviceMeshManager"] = None
_initialized: bool = False
def __new__(cls) -> "DeviceMeshManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self.device_mesh: Optional[DeviceMesh] = None
self._initialized = True
def test_distributed_interface():
DistributedInterface(DistributedStrategy())
assert DistributedInterface.rank == int(os.getenv("RANK", "0"))
assert DistributedInterface.world_size == int(os.getenv("WORLD_SIZE", "1"))

View File

@@ -19,7 +19,7 @@ from datasets import load_dataset
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.plugins.data_plugins.converter import get_converter
from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin
@pytest.mark.parametrize("num_samples", [16])
@@ -49,99 +49,27 @@ def test_alpaca_converter(num_samples: int):
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
def test_sharegpt_converter_invalid():
def test_sharegpt_converter():
example = {
"conversations": [
{
"from": "system",
"value": "Processes historical market data to generate trading signals "
"based on specified technical indicators.",
},
{
"from": "human",
"value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
"Could you proceed with these function calls to assist me with the task?",
},
{
"from": "gpt",
"value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
"'name': 'backtest_trading_signals'}```\n",
},
{
"from": "tool",
"value": '<tool id="D2">\n{"analysis": {"RSI_signals": [{"date": "2025-01-10", '
'"symbol": "AAPL", "signal": "Buy"}]}}}\n</tool>\n',
},
{"from": "system", "value": "System"},
{"from": "human", "value": "User"},
{"from": "gpt", "value": "Assistant"},
]
}
dataset_converter = get_converter("sharegpt")
assert dataset_converter(example) == {"messages": []}
def test_sharegpt_converter_valid():
example = {
"conversations": [
{
"from": "system",
"value": "Processes historical market data to generate trading signals based on "
"specified technical indicators.",
},
{
"from": "human",
"value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
"Could you proceed with these function calls to assist me with the task?",
},
{
"from": "gpt",
"value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
"'name': 'backtest_trading_signals'}```\n",
},
]
}
dataset_converter = get_converter("sharegpt")
expected_data = {
"messages": [
{
"content": [
{
"type": "text",
"value": "Processes historical market data to generate trading signals based on "
"specified technical indicators.",
}
],
"loss_weight": 0.0,
"role": "system",
},
{
"content": [
{
"type": "text",
"value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
"Could you proceed with these function calls to assist me with the task?",
}
],
"loss_weight": 0.0,
"role": "user",
},
{
"content": [
{
"type": "text",
"value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
"'name': 'backtest_trading_signals'}```\n",
}
],
"loss_weight": 1.0,
"role": "assistant",
},
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
]
}
assert dataset_converter(example) == expected_data
assert DataConverterPlugin("sharegpt")(example) == expected_data
@pytest.mark.parametrize("num_samples", [16])
def test_pair_converter(num_samples: int):
data_args = DataArguments(dataset="frozenleaves/tiny-dpo/dataset_info.yaml")
data_args = DataArguments(dataset="llamafactory/tiny-preference-dataset/dataset_info.yaml")
data_engine = DataEngine(data_args)
original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
indexes = random.choices(range(len(data_engine)), k=num_samples)
@@ -189,6 +117,5 @@ def test_pair_converter(num_samples: int):
if __name__ == "__main__":
test_alpaca_converter(1)
test_sharegpt_converter_invalid()
test_sharegpt_converter_valid()
test_sharegpt_converter()
test_pair_converter(1)

View File

@@ -17,10 +17,13 @@ from unittest.mock import MagicMock, patch
from transformers import AutoModelForCausalLM
from llamafactory.v1.accelerator.helper import get_current_accelerator
class TestKernelPlugin(unittest.TestCase):
@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(self, mock_get_accelerator):
get_current_accelerator.cache_clear()
mock_device = MagicMock()
mock_device.type = "npu"
mock_get_accelerator.return_value = mock_device
@@ -47,6 +50,7 @@ class TestKernelPlugin(unittest.TestCase):
class Test_Use_V1_Kernels(unittest.TestCase):
@patch("torch.accelerator.current_accelerator")
def test_use_v1_kernels(self, mock_get_accelerator):
get_current_accelerator.cache_clear()
mock_device = MagicMock()
mock_device.type = "npu"
mock_get_accelerator.return_value = mock_device