Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 03:10:35 +08:00)

Commit: [v1] add accelerator (#9607)
@@ -1,4 +1,4 @@
 dpo_zh_demo:
-  hf_hub_url: HuggingFaceH4/orca_dpo_pairs
+  path: HuggingFaceH4/orca_dpo_pairs
   split: train_prefs
   converter: pair
@@ -1,8 +1,9 @@
 identity:
-  file_name: data/identity.json
+  path: data/identity.json
+  source: local
   converter: alpaca
 alpaca_en_demo:
-  file_name: alpaca_en_demo.json
-  dataset_dir: data
+  path: data/alpaca_en_demo.json
+  source: local
   converter: alpaca
-  num_samples: 500
+  size: 500
@@ -1,4 +1,7 @@
-# Copyright 2025 the LlamaFactory team.
+# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
+#
+# This code is inspired by the Bytedance's VeOmni library.
+# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/utils/dist_utils.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,12 +15,68 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+from contextlib import contextmanager
+from enum import Enum, unique
 from functools import lru_cache
+from typing import TYPE_CHECKING, Optional
 
+import numpy as np
 import torch
+import torch.distributed as dist
+
+from ..utils.types import Tensor, TensorLike
 
 
-def get_current_accelerator(check_available: bool = True):
+if TYPE_CHECKING:
+    from torch.distributed import ProcessGroup
+
+
+@unique
+class DeviceType(str, Enum):
+    CPU = "cpu"
+    CUDA = "cuda"
+    META = "meta"
+    MPS = "mps"
+    NPU = "npu"
+    XPU = "xpu"
+
+
+@unique
+class ReduceOp(str, Enum):
+    SUM = "sum"
+    MEAN = "mean"
+    MAX = "max"
+    MIN = "min"
+
+
+def is_distributed() -> bool:
+    """Check if distributed environment is available."""
+    return os.getenv("RANK") is not None
+
+
+def get_rank() -> int:
+    """Get rank."""
+    return int(os.getenv("RANK", "0"))
+
+
+def get_local_rank() -> int:
+    """Get local rank."""
+    return int(os.getenv("LOCAL_RANK", "0"))
+
+
+def get_world_size() -> int:
+    """Get world size."""
+    return int(os.getenv("WORLD_SIZE", "1"))
+
+
+def get_local_world_size() -> int:
+    """Get local world size."""
+    return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
+
+
+@lru_cache
+def get_current_accelerator(check_available: bool = True) -> torch.device:
     """Get current accelerator.
 
     Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
@@ -27,26 +86,78 @@ def get_current_accelerator(check_available: bool = True):
 
     accelerator = torch.accelerator.current_accelerator(check_available=check_available)
     if accelerator is None:
-        return torch.device("cpu")
+        return torch.device(DeviceType.CPU.value)
 
     return accelerator
 
 
-@lru_cache
-def is_torch_npu_available():
-    return get_current_accelerator().type == "npu"
-
-
-@lru_cache
 def is_torch_cuda_available():
-    return get_current_accelerator().type == "cuda"
+    return get_current_accelerator().type == DeviceType.CUDA
 
 
-@lru_cache
-def is_torch_xpu_available():
-    return get_current_accelerator().type == "xpu"
-
-
-@lru_cache
 def is_torch_mps_available():
-    return get_current_accelerator().type == "mps"
+    return get_current_accelerator().type == DeviceType.MPS
+
+
+def is_torch_npu_available():
+    return get_current_accelerator().type == DeviceType.NPU
+
+
+def is_torch_xpu_available():
+    return get_current_accelerator().type == DeviceType.XPU
+
+
+def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
+    """Gathers the tensor from all ranks and concats them along the first dim."""
+    world_size = get_world_size()
+    device = get_current_accelerator()
+    output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
+    dist.all_gather_into_tensor(output_tensor, tensor, group=group)
+    return output_tensor.view(-1, *tensor.size()[1:])
+
+
+def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
+    """Performs all reduce in the given process group."""
+    device = get_current_accelerator()
+    is_ndarray = isinstance(data, np.ndarray)
+    is_tensor = isinstance(data, torch.Tensor)
+
+    if is_ndarray:
+        data = torch.from_numpy(data)
+    elif not is_tensor:
+        data = torch.tensor(data, dtype=torch.float, device=device)
+
+    reduce_ops = {
+        ReduceOp.MEAN: dist.ReduceOp.SUM,
+        ReduceOp.SUM: dist.ReduceOp.SUM,
+        ReduceOp.MAX: dist.ReduceOp.MAX,
+        ReduceOp.MIN: dist.ReduceOp.MIN,
+    }
+    dist.all_reduce(data, op=reduce_ops[op], group=group)
+    if op == ReduceOp.MEAN:  # ReduceOp.AVG is not supported by the NPU backend
+        data /= dist.get_world_size(group=group)
+
+    if is_tensor:
+        return data
+    elif is_ndarray:
+        return data.numpy()
+    elif data.numel() == 1:
+        return data.item()
+    else:
+        return data.tolist()
+
+
+@contextmanager
+def main_process_first(local_only: bool = True) -> None:
+    """A context manager for torch distributed environment to do something on the main process firstly."""
+    if get_world_size() > 1:
+        is_main_process = get_local_rank() == 0 if local_only else get_rank() == 0
+        try:
+            if not is_main_process:
+                dist.barrier()
+
+            yield
+        finally:
+            if is_main_process:
+                dist.barrier()
+    else:
+        yield
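Usage sketch (editor's addition, not part of the commit): how the new helpers are meant to compose in a launcher script. The absolute import path and the torchrun launch are assumptions based on the file location.

# Hypothetical driver; import path assumed from src/llamafactory/v1/accelerator/helper.py.
import torch
import torch.distributed as dist

from llamafactory.v1.accelerator.helper import (
    ReduceOp,
    all_reduce,
    get_current_accelerator,
    get_rank,
    is_distributed,
    main_process_first,
)

if is_distributed():  # RANK is set, e.g. by `torchrun --nproc_per_node=2 driver.py`
    dist.init_process_group()

with main_process_first():
    pass  # e.g. download or tokenize a dataset once per node before the other ranks proceed

device = get_current_accelerator()  # cuda/npu/xpu/mps, or cpu when no accelerator is found
loss = 0.5 * (get_rank() + 1)       # a made-up per-rank metric
if is_distributed():
    # MEAN is emulated as SUM followed by a division because ReduceOp.AVG is not
    # supported by the NPU backend (see the comment inside all_reduce above).
    loss = all_reduce(loss, op=ReduceOp.MEAN)

print(f"rank {get_rank()} on {device}: averaged loss = {loss}")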
src/llamafactory/v1/accelerator/interface.py (new file, 123 lines)
@@ -0,0 +1,123 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Optional
+
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+
+from ..utils.types import TensorLike
+from .helper import ReduceOp, all_reduce, get_current_accelerator, get_rank, get_world_size, is_distributed
+
+
+@dataclass
+class DistributedStrategy:
+    """Distributed strategy."""
+
+    dp_size: Optional[int] = None
+    tp_size: int = 1
+
+    def __post_init__(self) -> None:
+        if not is_distributed():
+            self.dp_size = 1
+        elif self.dp_size is None:
+            self.dp_size = get_world_size() // self.tp_size
+        elif self.dp_size * self.tp_size != get_world_size():
+            raise ValueError(
+                f"dp_size * tp_size must equal to world_size, "
+                f"got {self.dp_size} * {self.tp_size} != {get_world_size()}."
+            )
+
+    @property
+    def mesh_shape(self) -> tuple[int, int]:
+        """Mesh shape."""
+        return (self.dp_size, self.tp_size)
+
+    @property
+    def mesh_dim_names(self) -> tuple[str, str]:
+        """Mesh dimension names."""
+        return ("dp", "tp")
+
+
+class DistributedInterface:
+    """Distributed interface."""
+
+    _instance: Optional["DistributedInterface"] = None
+    _initialized: bool = False
+
+    is_distributed = is_distributed()
+    """Check if distributed environment is available."""
+    rank = get_rank()
+    """Global rank."""
+    world_size = get_world_size()
+    """Global world size."""
+    device_mesh: Optional[DeviceMesh] = None
+    """Device mesh."""
+    current_accelerator = get_current_accelerator()
+    """Current accelerator."""
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
+        """Singleton pattern."""
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+
+        return cls._instance
+
+    def __init__(self, strategy: DistributedStrategy) -> None:
+        if self._initialized:
+            return
+
+        self.strategy = strategy
+        if self.is_distributed:
+            self.device_mesh = init_device_mesh(
+                device_type=self.current_accelerator.type,
+                mesh_shape=strategy.mesh_shape,
+                mesh_dim_names=strategy.mesh_dim_names,
+            )
+        else:
+            self.device_mesh = None
+
+        self._initialized = True
+
+    def __str__(self) -> str:
+        return (
+            f"DistributedInterface(strategy={self.strategy}), is_distributed={self.is_distributed}, "
+            f"rank={self.rank}, world_size={self.world_size}, "
+            f"device_mesh={self.device_mesh}, current_accelerator={self.current_accelerator}"
+        )
+
+    def dp_rank(self) -> int:
+        """Data parallel rank."""
+        if self.device_mesh is None:
+            return 0
+
+        return self.device_mesh["dp"].get_rank()
+
+    def dp_size(self) -> int:
+        """Data parallel size."""
+        if self.device_mesh is None:
+            return 1
+
+        return self.device_mesh["dp"].size()
+
+    def all_reduce_over_dp(self, data: TensorLike, op: ReduceOp = ReduceOp.MEAN) -> TensorLike:
+        """All reduce tensor."""
+        if self.device_mesh is None:
+            return data
+
+        return all_reduce(data, op, self.device_mesh["dp"].get_group())
+
+
+if __name__ == "__main__":
+    print(DistributedInterface(DistributedStrategy()))
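Sketch of driving the new singleton (editor's addition; the import path and the launch command are assumed, not shown in the diff).

# Hypothetical driver for DistributedInterface / DistributedStrategy.
from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy

# Under `torchrun --nproc_per_node=8 driver.py`, tp_size=2 leaves dp_size=4 because
# __post_init__ fills dp_size = world_size // tp_size; in a single process dp_size falls
# back to 1 and no device mesh is built.
interface = DistributedInterface(DistributedStrategy(tp_size=2))
print(interface)                                  # rank, world size, (dp, tp) device mesh
print(interface.dp_rank(), interface.dp_size())   # coordinates along the "dp" mesh dimension

# The class is a singleton: later constructions return the same, already-initialized object.
assert DistributedInterface(DistributedStrategy()) is interface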
@@ -28,6 +28,21 @@ from .sample_args import SampleArguments
 from .training_args import TrainingArguments
 
 
+def validate_args(
+    data_args: DataArguments,
+    model_args: ModelArguments,
+    training_args: TrainingArguments,
+    sample_args: SampleArguments,
+):
+    """Validate arguments."""
+    if (
+        model_args.quant_config is not None
+        and training_args.dist_config is not None
+        and training_args.dist_config.name == "deepspeed"
+    ):
+        raise ValueError("Quantization is not supported with deepspeed backend.")
+
+
 def get_args(
     args: Optional[Union[dict[str, Any], list[str]]] = None,
 ) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
@@ -56,6 +71,8 @@ def get_args(
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
 
+    validate_args(*parsed_args)
+
     return tuple(parsed_args)
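Illustration of the new validate_args guard (editor's addition). The plugin names and extra fields below are placeholders; only quant_config combined with a deepspeed dist_config is the combination the guard rejects.

# Hypothetical call; get_args is assumed importable as llamafactory.v1.config.arg_parser.
from llamafactory.v1.config.arg_parser import get_args

try:
    get_args(
        {
            "model": "qwen/Qwen2.5-0.5B-Instruct",         # illustrative model id
            "quant_config": '{"name": "bnb", "bits": 4}',  # hypothetical quantization plugin
            "dist_config": '{"name": "deepspeed"}',
        }
    )
except ValueError as err:
    print(err)  # "Quantization is not supported with deepspeed backend."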
src/llamafactory/v1/config/arg_utils.py (new file, 105 lines)
@@ -0,0 +1,105 @@
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/training_args.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+from enum import Enum, unique
+from typing import Any, Optional, Union
+
+
+class PluginConfig(dict):
+    """Dictionary that allows attribute access."""
+
+    @property
+    def name(self) -> str:
+        """Plugin name."""
+        if "name" not in self:
+            raise ValueError("Plugin configuration must have a 'name' field.")
+
+        return self["name"]
+
+    def __getattr__(self, key: str) -> Any:
+        try:
+            return self[key]
+        except KeyError:
+            raise AttributeError(f"Attribute {key} not found.")
+
+    def __setattr__(self, key: str, value: Any):
+        self[key] = value
+
+
+PluginArgument = Optional[Union[PluginConfig, dict, str]]
+
+
+@unique
+class AutoClass(str, Enum):
+    """Auto class for model config."""
+
+    CAUSALLM = "llm"
+    CLASSIFICATION = "cls"
+    OTHER = "other"
+
+
+@unique
+class SampleBackend(str, Enum):
+    HF = "hf"
+    VLLM = "vllm"
+
+
+def _convert_str_dict(data: dict) -> dict:
+    """Parse string representation inside the dictionary.
+
+    Args:
+        data: The string or dictionary to convert.
+
+    Returns:
+        The converted dictionary.
+    """
+    for key, value in data.items():
+        if isinstance(value, dict):
+            data[key] = _convert_str_dict(value)
+        elif isinstance(value, str):
+            if value.lower() in ("true", "false"):
+                data[key] = value.lower() == "true"
+            elif value.isdigit():
+                data[key] = int(value)
+            elif value.replace(".", "", 1).isdigit():
+                data[key] = float(value)
+
+    return data
+
+
+def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]:
+    """Get the plugin configuration from the argument value.
+
+    Args:
+        config: The argument value to get the plugin configuration from.
+
+    Returns:
+        The plugin configuration.
+    """
+    if config is None:
+        return None
+
+    if isinstance(config, str) and config.startswith("{"):
+        config = json.loads(config)
+
+    config = _convert_str_dict(config)
+    if "name" not in config:
+        raise ValueError("Plugin configuration must have a 'name' field.")
+
+    return PluginConfig(config)
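A small worked example of the plugin-config parsing above (editor's addition; the import path is assumed, and the "lora" keys come from the LoraConfigDict later in this commit).

from llamafactory.v1.config.arg_utils import PluginConfig, get_plugin_config

cfg = get_plugin_config('{"name": "lora", "r": "16", "lora_alpha": "32"}')
assert isinstance(cfg, PluginConfig)
assert cfg.name == "lora"   # attribute access via the name property / __getattr__
assert cfg.r == 16          # _convert_str_dict turned the string "16" into an int

assert get_plugin_config(None) is None   # unset plugin arguments pass through untouched
# get_plugin_config({"r": 16}) would raise ValueError because the "name" field is missing.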
@@ -15,6 +15,8 @@
 
 from dataclasses import dataclass, field
 
+from .arg_utils import AutoClass, PluginConfig, get_plugin_config
+
 
 @dataclass
 class ModelArguments:
@@ -29,7 +31,24 @@ class ModelArguments:
         default=True,
         metadata={"help": "Use fast processor from Hugging Face."},
     )
-    auto_model_class: str = field(
-        default="causallm",
+    auto_class: AutoClass = field(
+        default=AutoClass.CAUSALLM,
         metadata={"help": "Model class from Hugging Face."},
     )
+    peft_config: PluginConfig = field(
+        default=None,
+        metadata={"help": "PEFT configuration for the model."},
+    )
+    kernel_config: PluginConfig = field(
+        default=None,
+        metadata={"help": "Kernel configuration for the model."},
+    )
+    quant_config: PluginConfig = field(
+        default=None,
+        metadata={"help": "Quantization configuration for the model."},
+    )
+
+    def __post_init__(self) -> None:
+        self.peft_config = get_plugin_config(self.peft_config)
+        self.kernel_config = get_plugin_config(self.kernel_config)
+        self.quant_config = get_plugin_config(self.quant_config)
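How the new fields behave on ModelArguments, as a hedged sketch (editor's addition; other required fields may exist, and the `model` field name is an assumption based on the `self.args.model` usage elsewhere in this commit).

from llamafactory.v1.config.model_args import AutoClass, ModelArguments

args = ModelArguments(
    model="qwen/Qwen2.5-0.5B-Instruct",   # assumed field, consumed as self.args.model in ModelWorker
    auto_class=AutoClass.CAUSALLM,
    peft_config='{"name": "lora", "r": 8, "lora_alpha": 16}',
)
# __post_init__ ran get_plugin_config, so the JSON string is now a PluginConfig with attribute access.
print(args.peft_config.name, args.peft_config.r)  # lora 8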
@@ -14,12 +14,8 @@
 
 
 from dataclasses import dataclass, field
-from enum import Enum
 
-
-class SampleBackend(Enum):
-    HF = "hf"
-    VLLM = "vllm"
+from .arg_utils import SampleBackend
 
 
 @dataclass
@@ -15,6 +15,8 @@
 
 from dataclasses import dataclass, field
 
+from .arg_utils import PluginArgument, get_plugin_config
+
 
 @dataclass
 class TrainingArguments:
@@ -38,3 +40,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Use bf16 for training."},
     )
+    dist_config: PluginArgument = field(
+        default=None,
+        metadata={"help": "Distribution configuration for training."},
+    )
+
+    def __post_init__(self) -> None:
+        self.dist_config = get_plugin_config(self.dist_config)
@@ -29,7 +29,7 @@ Train Phase:
 """
 
 from ..config.training_args import TrainingArguments
-from ..extras.types import TorchDataset
+from ..utils.types import TorchDataset
 from .model_worker import ModelWorker
 from .trainer_utils.data_collator import DataCollator
 
@@ -49,13 +49,10 @@ class BaseTrainer:
         self.optimizer = None
         self.lr_scheduler = None
 
-    def init_device_mesh(self) -> None:
-        pass
-
     def init_model_and_optimizer(self) -> None:
-        self.model_config = self.model_worker.get_model_config()
+        self.model_worker.init_model_config()
         # with self.dist_plugin.get_model_init_context():
-        #     self.model = self.model_worker.get_model(self.model_config)
+        #     self.model = self.model_worker.init_model(self.model_config)
 
     def create_dataloader(self) -> None:
         pass
@@ -34,7 +34,7 @@ from omegaconf import OmegaConf
 from torch.utils.data import Dataset
 
 from ..config.data_args import DataArguments
-from ..extras.types import DatasetInfo, HFDataset, Sample
+from ..utils.types import DatasetInfo, HFDataset, Sample
 
 
 class DataEngine(Dataset):
@@ -64,9 +64,9 @@ class DataEngine(Dataset):
             filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
             self.dataset_infos = OmegaConf.load(filepath)
         elif os.path.exists(self.args.dataset):  # local file(s)
-            self.dataset_infos = {"default": {"file_name": self.args.dataset}}
+            self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
-            self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
+            self.dataset_infos = {"default": {"path": self.args.dataset}}
 
     def load_dataset(self) -> None:
         """Load datasets according to dataset info."""
@@ -74,14 +74,14 @@ class DataEngine(Dataset):
             split = value.get("split", "train")
             streaming = value.get("streaming", False)
             self.streaming |= streaming
-            if "hf_hub_url" in value:
+            if value.get("source", "hf_hub") == "hf_hub":
                 from datasets import load_dataset
 
-                self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
+                self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
             else:  # data loader plugin
                 from ..plugins.data_plugins.loader import DataLoaderPlugin
 
-                self.datasets[key] = DataLoaderPlugin().auto_load_data(value)
+                self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)
 
     def build_data_index(self) -> None:
         """Build dataset index."""
@@ -112,9 +112,9 @@ class DataEngine(Dataset):
         """
         converter = self.dataset_infos[dataset_name].get("converter")
         if converter is not None:
-            from ..plugins.data_plugins.converter import get_converter
+            from ..plugins.data_plugins.converter import DataConverterPlugin
 
-            return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
+            return {"_dataset_name": dataset_name, **DataConverterPlugin(converter)(raw_sample)}
         else:
             return {"_dataset_name": dataset_name, **raw_sample}
 
@@ -147,7 +147,7 @@ class DataEngine(Dataset):
         else:
             from ..plugins.data_plugins.loader import DataSelectorPlugin
 
-            selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
+            selected_index = DataSelectorPlugin().select(self.data_index, index)
             if isinstance(selected_index, list):
                 return [
                     self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
@@ -187,7 +187,7 @@ class DataEngine(Dataset):
 
 
 if __name__ == "__main__":
-    from ..config.parser import get_args
+    from ..config.arg_parser import get_args
 
     data_args, *_ = get_args()
    data_engine = DataEngine(data_args=data_args)
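For reference, the two dataset_infos shapes DataEngine now builds (editor's note; file and dataset names are placeholders).

# Local file(s): os.path.exists(dataset) is true, so source is pinned to "local".
local_info = {"default": {"path": "data/alpaca_en_demo.json", "source": "local"}}
# Hub dataset: only a path is recorded; load_dataset() treats a missing source as "hf_hub".
hub_info = {"default": {"path": "llamafactory/v1-sft-demo"}}

for info in (local_info, hub_info):
    value = info["default"]
    # "hf_hub" goes through datasets.load_dataset, anything else through DataLoaderPlugin(source).load(value)
    print(value["path"], "->", value.get("source", "hf_hub"))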
@@ -24,10 +24,12 @@ Init Phase:
 
 from typing import Optional
 
+import torch
 from transformers import AutoConfig, AutoProcessor
 
-from ..config.model_args import ModelArguments
-from ..extras.types import DistModel, HFConfig, HFModel, Processor
+from ..accelerator.helper import DeviceType
+from ..config.model_args import AutoClass, ModelArguments
+from ..utils.types import HFConfig, HFModel, Processor
 
 
 class ModelWorker:
@@ -38,16 +40,15 @@ class ModelWorker:
         """Tokenizer or multi-modal processor."""
         self.model_config: Optional[HFConfig] = None
         """Model configuration."""
-        self.unwrapped_model: Optional[HFModel] = None
-        """Unwrapped model."""
-        self.model: Optional[DistModel] = None
-        """Distributed model."""
-        self.init_processor()
-        self.init_model_config()
-        self.init_model()
-        self.init_adapter()
+        self.model: Optional[HFModel] = None
+        """HF model."""
+        self.is_adapter = False
+        """Whether the model has adapter."""
 
     def init_processor(self) -> None:
+        if self.processor is not None:
+            return
+
         self.processor = AutoProcessor.from_pretrained(
             self.args.model,
             trust_remote_code=self.args.trust_remote_code,
@@ -55,38 +56,58 @@ class ModelWorker:
         )
 
     def init_model_config(self) -> None:
+        if self.model_config is not None:
+            return
+
         self.model_config = AutoConfig.from_pretrained(
             self.args.model,
             trust_remote_code=self.args.trust_remote_code,
         )
 
     def init_model(self) -> None:
-        if self.args.auto_model_class == "causallm":
+        if self.model is not None:
+            return
+
+        self.init_model_config()
+
+        if self.args.auto_class == AutoClass.CAUSALLM:
             from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
 
             if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
-                AutoClass = AutoModelForImageTextToText
+                ModelClass = AutoModelForImageTextToText
             else:
-                AutoClass = AutoModelForCausalLM
-        elif self.args.auto_model_class == "classification":
+                ModelClass = AutoModelForCausalLM
+        elif self.args.auto_class == AutoClass.CLASSIFICATION:
             from transformers import AutoModelForTokenClassification
 
-            AutoClass = AutoModelForTokenClassification
+            ModelClass = AutoModelForTokenClassification
         else:
             from transformers import AutoModel
 
-            AutoClass = AutoModel
+            ModelClass = AutoModel
 
-        self.unwrapped_model = AutoClass.from_pretrained(
-            self.args.model,
-            config=self.model_config,
-            dtype="auto",
-            device_map="cpu",
-            trust_remote_code=self.args.trust_remote_code,
-        )
+        default_device_type = torch.get_default_device().type
+        if default_device_type == DeviceType.META:
+            self.model = ModelClass.from_config(self.model_config)
+        else:
+            self.model = ModelClass.from_pretrained(
+                self.args.model,
+                config=self.model_config,
+                dtype="auto",
+                device_map=default_device_type,
+                trust_remote_code=self.args.trust_remote_code,
+            )
 
     def init_adapter(self) -> None:
-        pass
+        if self.is_adapter:
+            return
+
+        if self.args.peft_config is not None:
+            from ..plugins.model_plugins.peft import PeftPlugin
+
+            self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)
+
+        self.is_adapter = True
 
     def get_processor(self) -> Processor:
         return self.processor
@@ -95,4 +116,4 @@ class ModelWorker:
         return self.model_config
 
     def get_model(self) -> HFModel:
-        return self.unwrapped_model
+        return self.model
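The meta-device branch added to init_model, shown directly against torch/transformers (editor's sketch; the model id is illustrative and ModelWorker itself is not invoked here).

import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("qwen/Qwen2.5-0.5B-Instruct")

# When the default device is "meta" (torch.get_default_device().type == DeviceType.META),
# init_model skips from_pretrained and builds an empty-weight skeleton instead.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

print(next(model.parameters()).device)  # meta — weights get loaded/sharded by a later stage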
@@ -15,7 +15,7 @@
 
 from typing import Any
 
-from ...extras.types import Processor, Tensor, TorchDataset
+from ...utils.types import Processor, Tensor, TorchDataset
 
 
 class DataCollator:
@@ -44,7 +44,7 @@ WELCOME = (
 def launch():
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
 
-    if command == "sft":
+    if command == "sft":  # train command will fallback to sft command
         from .trainers.sft_trainer import run_sft
 
         run_sft()
@@ -13,12 +13,13 @@
 # limitations under the License.
 
 
-from typing import Callable, TypedDict
+from typing import Any, Literal, TypedDict
 
-from typing_extensions import NotRequired, Required
+from typing_extensions import NotRequired
 
-from ....extras import logging
-from ...extras.types import DPOSample, Sample, SFTSample
+from ...utils import logging
+from ...utils.plugin import BasePlugin
+from ...utils.types import DPOSample, Sample, SFTSample
 
 
 logger = logging.get_logger(__name__)
@@ -26,35 +27,48 @@ logger = logging.get_logger(__name__)
 
 class AlpacaSample(TypedDict, total=False):
     system: NotRequired[str]
-    instruction: NotRequired[str]
+    instruction: str
     input: NotRequired[str]
-    output: NotRequired[str]
+    output: str
 
 
-ShareGPTMessage = TypedDict(
-    "ShareGPTMessage",
-    {
-        "from": Required[str],  # Role of the message sender (e.g., "human", "gpt", "system")
-        "value": Required[str],  # Content of the message
-    },
+SharegptMessage = TypedDict(
+    "SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
 )
 
 
-class ShareGPTSample(TypedDict, total=False):
-    """Type definition for raw ShareGPT sample."""
-
-    conversations: Required[list[ShareGPTMessage]]
+class SharegptSample(TypedDict, total=False):
+    conversations: list[SharegptMessage]
+    tools: NotRequired[str]
+
+
+class OpenaiMessage(TypedDict, total=False):
+    role: Literal["user", "assistant", "tool"]
+    content: str
+
+
+class OpenaiSample(TypedDict, total=False):
+    messages: list[OpenaiMessage]
 
 
 class PairSample(TypedDict, total=False):
-    prompt: NotRequired[str]
-    chosen: NotRequired[list[dict]]
-    rejected: NotRequired[list[dict]]
+    chosen: list[OpenaiMessage]
+    rejected: list[OpenaiMessage]
+
+
+class DataConverterPlugin(BasePlugin):
+    """Plugin for data converters."""
+
+    def __call__(self, raw_sample: dict[str, Any]) -> Sample:
+        return super().__call__(raw_sample)
+
+
+@DataConverterPlugin("alpaca").register
 def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
     """Convert Alpaca sample to SFT sample.
 
+    See raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en
+
     Args:
         raw_sample (AlpacaSample): Alpaca sample.
 
@@ -67,20 +81,6 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
             {"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
         )
 
-    if "history" in raw_sample:
-        for idx, item in enumerate(raw_sample["history"]):
-            if len(item) != 2:
-                logger.warning_rank0(
-                    f"Warning: History item at index {idx} has invalid length (expected 2, got {len(item)}). Skipping."
-                )
-                continue
-
-            old_prompt, old_response = item
-            messages.append({"role": "user", "content": [{"type": "text", "value": old_prompt}], "loss_weight": 0.0})
-            messages.append(
-                {"role": "assistant", "content": [{"type": "text", "value": old_response}], "loss_weight": 1.0}
-            )
-
     if "instruction" in raw_sample or "input" in raw_sample:
         messages.append(
             {
@@ -100,149 +100,85 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
     return {"messages": messages}
 
 
-def sharegpt_converter(raw_sample: ShareGPTSample) -> SFTSample:
-    """Converts a raw ShareGPT sample into a formatted SFT (Supervised Fine-Tuning) sample.
-
-    Retains only SFT-relevant scenarios and removes parity checks.
-
-    Args:
-        raw_sample (ShareGPTSample): A raw sample in ShareGPT format.
-
-    Returns:
-        dict: A dictionary containing the formatted 'messages' list for SFT training.
-            Returns an empty list if the input data is invalid.
-    """
-    tag_mapping = {
-        "human": "user",
-        "gpt": "assistant",
-        "observation": "observation",
-        "function_call": "function",
-    }
-    messages = raw_sample.get("conversations", [])
-    aligned_messages = []
-    system_content = ""
-
-    # Extract system message if present (typically the first message)
-    if messages and messages[0]["from"] == "system":
-        system_content = messages[0]["value"]
-        messages = messages[1:]
-
-    if system_content:
-        aligned_messages.append(
-            {"role": "system", "content": [{"type": "text", "value": system_content}], "loss_weight": 0.0}
-        )
-
-    has_invalid_role = False
-    for message in messages:
-        sender = message["from"]
-        # validate sender is in supported tags
-        if sender not in tag_mapping:
-            logger.warning_rank0(f"Unsupported role tag '{sender}' in message: {message}")
-            has_invalid_role = True
-            break
-
-        aligned_messages.append(
-            {
-                "role": tag_mapping[sender],
-                "content": [{"type": "text", "value": message["value"]}],
-                "loss_weight": 0.0 if sender in ("human", "observation") else 1.0,
-            }
-        )
-
-    if has_invalid_role:
-        logger.warning_rank0("Skipping invalid example due to unsupported role tags.")
-        return {"messages": []}
-
-    return {"messages": aligned_messages}
-
-
-def pair_converter(raw_sample: PairSample) -> DPOSample:
-    """Convert Pair sample to standard DPO sample.
-
-    Args:
-        raw_sample (PairSample): pair sample with prompt, chosen, rejected fields.
-            see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
-
-    Returns:
-        DPOSample: DPO sample with chosen_messages and rejected_messages.
-            see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl
-    """
-    chosen_messages = []
-    assert "chosen" in raw_sample, "chosen field is required in pair sample."
-    assert "rejected" in raw_sample, "rejected field is required in pair sample."
-    assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), (
-        "chosen and rejected field should be a list[dict], or you may need to implement your custom converter."
-    )
-
-    if "chosen" in raw_sample:
-        value = raw_sample.get("chosen", "")
-        for item in value:
-            if item.get("role", "") == "system":
-                chosen_messages.append(
-                    {
-                        "role": "system",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 0.0,
-                    }
-                )
-            if item.get("role", "") == "user":
-                chosen_messages.append(
-                    {
-                        "role": "user",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 0.0,
-                    }
-                )
-            if item.get("role", "") == "assistant":
-                chosen_messages.append(
-                    {
-                        "role": "assistant",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 1.0,
-                    }
-                )
-
-    rejected_messages = []
-    if "rejected" in raw_sample:
-        value = raw_sample.get("rejected", "")
-        for item in value:
-            if item.get("role", "") == "system":
-                rejected_messages.append(
-                    {
-                        "role": "system",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 0.0,
-                    }
-                )
-            if item.get("role", "") == "user":
-                rejected_messages.append(
-                    {
-                        "role": "user",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 0.0,
-                    }
-                )
-            if item.get("role", "") == "assistant":
-                rejected_messages.append(
-                    {
-                        "role": "assistant",
-                        "content": [{"type": "text", "value": item.get("content", "")}],
-                        "loss_weight": 1.0,
-                    }
-                )
-
-    return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
-
-
-CONVERTERS = {
-    "alpaca": alpaca_converter,
-    "pair": pair_converter,
-    "sharegpt": sharegpt_converter,
-}
-
-
-def get_converter(converter_name: str) -> Callable[[dict], Sample]:
-    if converter_name not in CONVERTERS:
-        raise ValueError(f"Converter {converter_name} not found.")
-
-    return CONVERTERS[converter_name]
+@DataConverterPlugin("sharegpt").register
+def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
+    """Convert ShareGPT sample to SFT sample.
+
+    See raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en
+
+    Args:
+        raw_sample (SharegptSample): ShareGPT sample.
+
+    Returns:
+        SFTSample: SFT sample.
+    """
+    tag_mapping = {
+        "system": "system",
+        "human": "user",
+        "gpt": "assistant",
+        "observation": "tool",
+        "function_call": "assistant",
+    }
+    messages = []
+    tools = raw_sample.get("tools", "")
+
+    for message in raw_sample.get("conversations", []):
+        tag = message["from"]
+        if tag not in tag_mapping:
+            logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
+        elif tag == "function_call":
+            messages.append(
+                {
+                    "role": "assistant",
+                    "content": [{"type": "tool_calls", "value": message["value"]}],
+                    "loss_weight": 1.0,
+                }
+            )
+        else:
+            messages.append(
+                {
+                    "role": tag_mapping[tag],
+                    "content": [{"type": "text", "value": message["value"]}],
+                    "loss_weight": 1.0 if tag == "gpt" else 0.0,
+                }
+            )
+
+    if tools:
+        if messages and messages[0]["role"] == "system":
+            messages[0]["content"].append({"type": "tools", "value": tools})
+        else:
+            messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
+
+    return {"messages": messages}
+
+
+@DataConverterPlugin("pair").register
+def pair_converter(raw_sample: PairSample) -> DPOSample:
+    """Convert Pair sample to DPO sample.
+
+    See raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
+
+    Args:
+        raw_sample (PairSample): pair sample with chosen, rejected fields.
+
+    Returns:
+        DPOSample: DPO sample with chosen_messages and rejected_messages.
+    """
+
+    def process_message(raw_messages: list[OpenaiMessage]):
+        messages = []
+        for message in raw_messages:
+            messages.append(
+                {
+                    "role": message["role"],
+                    "content": [{"type": "text", "value": message["content"]}],
+                    "loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
+                }
+            )
+
+        return messages
+
+    chosen_messages = process_message(raw_sample.get("chosen", []))
+    rejected_messages = process_message(raw_sample.get("rejected", []))
+
+    return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
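A worked example of the new pair converter (editor's addition; the conversation content is made up, the shape follows OpenaiMessage above).

raw = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}

# DataEngine dispatches this as DataConverterPlugin("pair")(raw); process_message gives
# loss_weight 0.0 to user turns and 1.0 to assistant turns on both sides:
# {"chosen_messages": [
#      {"role": "user", "content": [{"type": "text", "value": "What is 2 + 2?"}], "loss_weight": 0.0},
#      {"role": "assistant", "content": [{"type": "text", "value": "4"}], "loss_weight": 1.0}],
#  "rejected_messages": [... same prompt, rejected answer "5" ...]}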
@@ -14,57 +14,59 @@
 
 
 import os
-from dataclasses import dataclass
+import random
 from typing import Any, Literal, Optional, Union
 
 from datasets import load_dataset
 
-from ...extras.types import DatasetInfo, HFDataset
+from ...utils.plugin import BasePlugin
+from ...utils.types import DatasetInfo, HFDataset
 
 
-@dataclass
-class DataLoaderPlugin:
+class DataLoaderPlugin(BasePlugin):
     """Plugin for loading dataset."""
 
-    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
-        """Get dataset builder name.
-
-        Args:
-            path (str): Dataset path.
-
-        Returns:
-            Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
-        """
-        return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
-
-    def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
-        dataset_dir = dataset_info.get("dataset_dir", ".")
+    def load(self, dataset_info: DatasetInfo) -> HFDataset:
+        path = dataset_info["path"]
         split = dataset_info.get("split", "train")
         streaming = dataset_info.get("streaming", False)
-        if "file_name" in dataset_info:
-            filepath = os.path.join(dataset_dir, dataset_info["file_name"])
-            return self.load_data_from_file(filepath, split, streaming)
-        else:
-            raise NotImplementedError()
-
-    def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
-        if os.path.isdir(filepath):
-            filetype = self._get_builder_name(os.listdir(filepath)[0])
-            dataset = load_dataset(filetype, data_dir=filepath, split=split)
-        elif os.path.isfile(filepath):
-            filetype = self._get_builder_name(filepath)
-            dataset = load_dataset(filetype, data_files=filepath, split=split)
-        else:
-            raise ValueError(f"Can not load dataset from {filepath}.")
-
-        if streaming:
-            dataset = dataset.to_iterable_dataset()
-
-        return dataset
-
-
-@dataclass
-class DataIndexPlugin:
+        return super().__call__(path, split, streaming)
+
+
+def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
+    """Get dataset builder name.
+
+    Args:
+        path (str): Dataset path.
+
+    Returns:
+        Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
+    """
+    filetype = os.path.splitext(path)[-1][1:]
+    if filetype in ["arrow", "csv", "json", "jsonl", "parquet", "txt"]:
+        return filetype.replace("jsonl", "json").replace("txt", "text")
+    else:
+        raise ValueError(f"Unknown dataset filetype: {filetype}.")
+
+
+@DataLoaderPlugin("local").register
+def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
+    if os.path.isdir(filepath):
+        filetype = _get_builder_name(os.listdir(filepath)[0])
+        dataset = load_dataset(filetype, data_dir=filepath, split=split)
+    elif os.path.isfile(filepath):
+        filetype = _get_builder_name(filepath)
+        dataset = load_dataset(filetype, data_files=filepath, split=split)
+    else:
+        raise ValueError(f"Can not load dataset from {filepath}.")
+
+    if streaming:  # faster when data is streamed from local files
+        dataset = dataset.to_iterable_dataset()
+
+    return dataset
+
+
+class DataIndexPlugin(BasePlugin):
     """Plugin for adjusting dataset index."""
 
     def adjust_data_index(
@@ -81,39 +83,32 @@ class DataIndexPlugin:
             list[tuple[str, int]]: Adjusted dataset index.
         """
         if size is not None:
-            data_index = self.adjust_by_size(data_index, size)
+            data_index = random.choices(data_index, k=size)
 
         if weight is not None:
-            data_index = self.adjust_by_weight(data_index, weight)
+            data_index = random.choices(data_index, k=int(len(data_index) * weight))
 
         return data_index
 
-    def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
-        raise NotImplementedError()
-
-    def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
-        raise NotImplementedError()
-
 
-@dataclass
-class DataSelectorPlugin:
+class DataSelectorPlugin(BasePlugin):
     """Plugin for selecting dataset samples."""
 
-    data_index: list[tuple[str, int]]
-    """List of (dataset_name, sample_index)"""
-
-    def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
+    def select(
+        self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
+    ) -> Union[tuple[str, int], list[tuple[str, int]]]:
         """Select dataset samples.
 
         Args:
+            data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
             index (Union[slice, list[int], Any]): Index of dataset samples.
 
         Returns:
             Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
         """
         if isinstance(index, slice):
-            return [self.data_index[i] for i in range(*index.indices(len(self.data_index)))]
+            return [data_index[i] for i in range(*index.indices(len(data_index)))]
         elif isinstance(index, list):
-            return [self.data_index[i] for i in index]
+            return [data_index[i] for i in index]
         else:
             raise ValueError(f"Invalid index type {type(index)}.")
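Sketch of the plugin-style dispatch for local files (editor's addition; BasePlugin's registration internals are not part of this diff, so the call form is inferred from DataEngine's DataLoaderPlugin(value["source"]).load(value) usage, and the import path is assumed).

from llamafactory.v1.plugins.data_plugins.loader import DataLoaderPlugin

info = {"path": "data/alpaca_en_demo.json", "source": "local", "split": "train"}
dataset = DataLoaderPlugin(info["source"]).load(info)  # routes to the registered load_data_from_file
print(dataset[0])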
@@ -21,10 +21,3 @@ class KernelType(str, Enum):
     FLASH_ATTENTION = "flash_attention"
     ROPE = "rope"
     MOE = "moe"
-
-
-class DeviceType(str, Enum):
-    CPU = "cpu"
-    CUDA = "cuda"
-    NPU = "npu"
-    XPU = "xpu"
@@ -18,10 +18,10 @@ import torch
 import torch.nn.functional as F
 import torch_npu
 
-from .....accelerator.helper import is_torch_npu_available
-from .....extras.packages import is_transformers_version_greater_than
-from .....extras.types import HFModel
-from ..constants import DeviceType, KernelType
+from .....accelerator.helper import DeviceType, is_torch_npu_available
+from .....utils.packages import is_transformers_version_greater_than
+from .....utils.types import HFModel
+from ..constants import KernelType
 from ..registry import MetaMoEKernel
 
 
@@ -17,9 +17,9 @@ import types
 
 import torch
 
-from .....accelerator.helper import is_torch_npu_available
-from .....extras.types import HFModel
-from ..constants import DeviceType, KernelType
+from .....accelerator.helper import DeviceType, is_torch_npu_available
+from .....utils.types import HFModel
+from ..constants import KernelType
 from ..registry import MetaSwiGluKernel
 
 
@@ -13,11 +13,11 @@
 # limitations under the License.
 
 from abc import ABC, ABCMeta, abstractmethod
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 
-from ....accelerator.helper import get_current_accelerator
-from ....extras.types import HFModel
-from .constants import DeviceType, KernelType
+from ....accelerator.helper import DeviceType, get_current_accelerator
+from ....utils.types import HFModel
+from .constants import KernelType
 
 
 class KernelRegistry:
@@ -27,11 +27,13 @@ class KernelRegistry:
     def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
         if cls._instance is None:
             cls._instance = super().__new__(cls)
+
         return cls._instance
 
     def __init__(self) -> None:
         if self._initialized:
             return
+
         self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
         self._initialized = True
 
@@ -218,7 +220,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
|
|||||||
return discovered_kernels
|
return discovered_kernels
|
||||||
|
|
||||||
# Iterate through registry and collect all kernels for current device
|
# Iterate through registry and collect all kernels for current device
|
||||||
for kernel_type, devices in KERNEL_REGISTRY._registry.items():
|
for devices in KERNEL_REGISTRY._registry.values():
|
||||||
kernel_cls = devices.get(device_type)
|
kernel_cls = devices.get(device_type)
|
||||||
if kernel_cls is not None:
|
if kernel_cls is not None:
|
||||||
discovered_kernels.append(kernel_cls)
|
discovered_kernels.append(kernel_cls)
|
||||||
@@ -226,7 +228,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
     return discovered_kernels
 
 
-def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
+def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
     """Call the MetaKernel's `apply` to perform the replacement.
 
     Corresponding replacement logic is maintained inside each kernel; the only
@@ -238,16 +240,18 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
         model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
         model = apply_kernel(model, NpuRMSNormKernel)
     """
-    if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
-        return kernel.apply(model, **kwargs)
-
-    raise ValueError(
-        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
-    )
+    if not issubclass(kernel, MetaKernel):
+        raise ValueError(f"{kernel} must be a MetaKernel instance.")
+
+    if kernel.device != get_current_accelerator().type:
+        raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
+
+    return kernel.apply(model, **kwargs)
 
 
 def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
     """Apply all available kernels to the model."""
     for kernel in discover_kernels(model):
         model = apply_kernel(model, kernel, **kwargs)
 
     return model
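The rewrite above splits the old combined check into two explicit guards: first the type check, then the device match, and only then the delegation to `kernel.apply`. A minimal self-contained sketch of that control flow (the stub class and the `current_device` parameter are illustrative, not the project's real API):

```python
class MetaKernelStub:
    """Stand-in for the project's MetaKernel interface (illustrative only)."""

    device = "npu"

    @classmethod
    def apply(cls, model, **kwargs):
        return model  # a real kernel would patch modules on the model here


def apply_kernel_sketch(model, kernel, current_device: str = "npu", **kwargs):
    # guard 1: the argument must be a MetaKernel subclass
    if not (isinstance(kernel, type) and issubclass(kernel, MetaKernelStub)):
        raise ValueError(f"{kernel} must be a MetaKernel subclass.")

    # guard 2: the kernel must target the accelerator we are running on
    if kernel.device != current_device:
        raise ValueError(f"{kernel} targets {kernel.device}, current device is {current_device}.")

    # only then delegate to the kernel's own replacement logic
    return kernel.apply(model, **kwargs)


model = object()
assert apply_kernel_sketch(model, MetaKernelStub) is model
```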
@@ -14,9 +14,9 @@
 import re
 import types
 
-from .....accelerator.helper import is_torch_npu_available
-from .....extras.types import HFModel
-from ..constants import DeviceType, KernelType
+from .....accelerator.helper import DeviceType, is_torch_npu_available
+from .....utils.types import HFModel
+from ..constants import KernelType
 from ..registry import MetaRMSNormKernel
 
 
@@ -16,9 +16,9 @@ import sys
 
 import torch
 
-from .....accelerator.helper import is_torch_npu_available
-from .....extras.types import HFModel
-from ..constants import DeviceType, KernelType
+from .....accelerator.helper import DeviceType, is_torch_npu_available
+from .....utils.types import HFModel
+from ..constants import KernelType
 from ..registry import MetaRoPEKernel
 
 
@@ -0,0 +1,42 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal, TypedDict
+
+from peft import LoraConfig, PeftModel, get_peft_model
+
+from ...utils.plugin import BasePlugin
+from ...utils.types import HFModel
+
+
+class LoraConfigDict(TypedDict, total=False):
+    name: Literal["lora"]
+    """Plugin name."""
+    r: int
+    """Lora rank."""
+    lora_alpha: int
+    """Lora alpha."""
+    target_modules: list[str]
+    """Target modules."""
+
+
+class PeftPlugin(BasePlugin):
+    pass
+
+
+@PeftPlugin("lora").register
+def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
+    peft_config = LoraConfig(**config)
+    model = get_peft_model(model, peft_config)
+    return model
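Since `get_lora_model` is registered under the name `"lora"`, callers would presumably invoke it through the plugin rather than importing the function directly. A hedged usage sketch (the import path, model name, and config values are assumptions for illustration only):

```python
from transformers import AutoModelForCausalLM

from llamafactory.v1.plugins.model_plugins.peft import PeftPlugin  # module path assumed

model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
peft_model = PeftPlugin("lora")(
    model,
    {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]},
)
peft_model.print_trainable_parameters()
```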
@@ -1,13 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
@@ -13,22 +13,21 @@
 # limitations under the License.
 
 
-from ..config.parser import get_args
+from ..accelerator.interface import DistributedInterface, DistributedStrategy
+from ..config.arg_parser import get_args
 from ..core.base_trainer import BaseTrainer
 from ..core.data_engine import DataEngine
-from ..core.model_engine import ModelEngine
+from ..core.model_worker import ModelWorker
 
 
 class SFTTrainer(BaseTrainer):
     pass
 
 
-def run_sft():
-    model_args, data_args, training_args, _ = get_args()
-    model_engine = ModelEngine(model_args)
+def run_sft(user_args):
+    model_args, data_args, training_args, _ = get_args(user_args)
+    DistributedInterface(DistributedStrategy())
     data_engine = DataEngine(data_args)
-    model = model_engine.get_model()
-    processor = model_engine.get_processor()
-    data_loader = data_engine.get_data_loader(processor)
-    trainer = SFTTrainer(training_args, model, processor, data_loader)
+    model_worker = ModelWorker(model_args)
+    trainer = SFTTrainer(training_args, model_worker, data_engine)
     trainer.fit()
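The SFT entry point now receives the parsed user arguments and sets up the distributed context before building the data engine and model worker. A hypothetical invocation, assuming `user_args` is the usual dict of config keys accepted by `get_args` (the import path and argument format are assumptions, not confirmed by this diff):

```python
from llamafactory.v1.trainers.sft_trainer import run_sft  # module path assumed

user_args = {
    "model": "qwen/qwen2.5-0.5B",          # illustrative values only
    "dataset": "data/dataset_info.yaml",
    "output_dir": "outputs/sft_demo",
}
run_sft(user_args)
```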
 src/llamafactory/v1/utils/__init__.py  |   0  (new file)
 src/llamafactory/v1/utils/constants.py |  13  (new file)
@@ -0,0 +1,13 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 src/llamafactory/v1/utils/logging.py | 123  (new file)
@@ -0,0 +1,123 @@
+# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/utils/logging.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import threading
+from functools import lru_cache
+from typing import Optional
+
+
+_thread_lock = threading.RLock()
+_default_handler: Optional["logging.Handler"] = None
+_default_log_level: "logging._Level" = logging.INFO
+
+
+class _Logger(logging.Logger):
+    r"""A logger that supports rank0 logging."""
+
+    def info_rank0(self, *args, **kwargs) -> None:
+        self.info(*args, **kwargs)
+
+    def warning_rank0(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+    def warning_rank0_once(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+
+def _get_default_logging_level() -> "logging._Level":
+    r"""Return the default logging level."""
+    env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
+    if env_level_str:
+        if env_level_str.upper() in logging._nameToLevel:
+            return logging._nameToLevel[env_level_str.upper()]
+        else:
+            raise ValueError(f"Unknown logging level: {env_level_str}.")
+
+    return _default_log_level
+
+
+def _get_library_name() -> str:
+    return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> "_Logger":
+    return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+    r"""Configure root logger using a stdout stream handler with an explicit format."""
+    global _default_handler
+
+    with _thread_lock:
+        if _default_handler:  # already configured
+            return
+
+        formatter = logging.Formatter(
+            fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler(sys.stdout)
+        _default_handler.setFormatter(formatter)
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.addHandler(_default_handler)
+        library_root_logger.setLevel(_get_default_logging_level())
+        library_root_logger.propagate = False
+
+
+def get_logger(name: Optional[str] = None) -> "_Logger":
+    r"""Return a logger with the specified name. It is not supposed to be accessed externally."""
+    if name is None:
+        name = _get_library_name()
+
+    _configure_library_root_logger()
+    return logging.getLogger(name)
+
+
+def add_handler(handler: "logging.Handler") -> None:
+    r"""Add a handler to the root logger."""
+    _configure_library_root_logger()
+    _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+    r"""Remove a handler from the root logger."""
+    _configure_library_root_logger()
+    _get_library_root_logger().removeHandler(handler)
+
+
+def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.info(*args, **kwargs)
+
+
+def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+@lru_cache(None)
+def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+logging.Logger.info_rank0 = info_rank0
+logging.Logger.warning_rank0 = warning_rank0
+logging.Logger.warning_rank0_once = warning_rank0_once
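The module above gives the library logger `info_rank0` / `warning_rank0` / `warning_rank0_once` methods that only emit on the process with `LOCAL_RANK == 0`, and reads the default level from `LLAMAFACTORY_VERBOSITY`. A small usage sketch (the import path follows the new file location; the messages are illustrative):

```python
from llamafactory.v1.utils import logging

logger = logging.get_logger(__name__)

logger.info_rank0("building data engine...")  # printed only on local rank 0
logger.warning_rank0_once("flash-attn is unavailable, falling back")  # de-duplicated via lru_cache
```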
 src/llamafactory/v1/utils/plugin.py | 86  (new file)
@@ -0,0 +1,86 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Callable, Optional
+
+from . import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class BasePlugin:
+    """Base class for plugins.
+
+    A plugin is a callable object that can be registered and called by name.
+    """
+
+    _registry: dict[str, Callable] = {}
+
+    def __init__(self, name: Optional[str] = None):
+        """Initialize the plugin with a name.
+
+        Args:
+            name (str): The name of the plugin.
+        """
+        self.name = name
+
+    @property
+    def register(self) -> Callable:
+        """Decorator to register a function as a plugin.
+
+        Example usage:
+        ```python
+        @PrintPlugin("hello").register
+        def print_hello():
+            print("Hello world!")
+        ```
+        """
+        if self.name is None:
+            raise ValueError("Plugin name is not specified.")
+
+        if self.name in self._registry:
+            logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
+
+        def decorator(func: Callable) -> Callable:
+            self._registry[self.name] = func
+            return func
+
+        return decorator
+
+    def __call__(self, *args, **kwargs) -> Callable:
+        """Call the registered function with the given arguments.
+
+        Example usage:
+        ```python
+        PrintPlugin("hello")()
+        ```
+        """
+        if self.name not in self._registry:
+            raise ValueError(f"Plugin {self.name} is not registered.")
+
+        return self._registry[self.name](*args, **kwargs)
+
+
+if __name__ == "__main__":
+
+    class PrintPlugin(BasePlugin):
+        pass
+
+    @PrintPlugin("hello").register
+    def print_hello():
+        print("Hello world!")
+
+    PrintPlugin("hello")()
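Registration and dispatch are both keyed purely by the plugin name, so a plugin family is just a `BasePlugin` subclass plus named `register` decorations; the converter and PEFT plugins elsewhere in this change follow the same shape. A hedged sketch with an invented family (everything except `BasePlugin` is illustrative):

```python
from llamafactory.v1.utils.plugin import BasePlugin


class GreetPlugin(BasePlugin):  # hypothetical plugin family
    pass


@GreetPlugin("en").register
def greet_en(name: str) -> str:
    return f"Hello, {name}!"


print(GreetPlugin("en")("LlamaFactory"))  # looks up "en" in the registry and calls greet_en
```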
@@ -19,23 +19,27 @@ from typing_extensions import NotRequired
 
 if TYPE_CHECKING:
     import datasets
+    import numpy as np
     import torch
     import torch.utils.data
     import transformers
+    from torch.distributed.fsdp import FullyShardedDataParallel
 
     Tensor = torch.Tensor
+    TensorLike = Union[int, float, list[int], list[float], np.ndarray, Tensor]
     TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]
     HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
     DataCollator = transformers.DataCollator
     DataLoader = torch.utils.data.DataLoader
     HFConfig = transformers.PretrainedConfig
     HFModel = transformers.PreTrainedModel
-    DistModel = torch.nn.parallel.DistributedDataParallel
+    DistModel = Union[torch.nn.parallel.DistributedDataParallel, FullyShardedDataParallel]
     Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
     Optimizer = torch.optim.Optimizer
     Scheduler = torch.optim.lr_scheduler.LRScheduler
 else:
     Tensor = None
+    TensorLike = None
     TorchDataset = None
     HFDataset = None
     DataCollator = None
@@ -49,12 +53,10 @@ else:
 
 
 class DatasetInfo(TypedDict, total=False):
-    hf_hub_url: NotRequired[str]
-    """HF hub dataset uri."""
-    file_name: NotRequired[str]
+    path: str
     """Local file path."""
-    dataset_dir: NotRequired[str]
-    """Dataset directory, default to args.dataset_dir."""
+    source: NotRequired[Literal["hf_hub", "ms_hub", "local"]]
+    """Dataset source, default to "hf_hub"."""
     split: NotRequired[str]
     """Dataset split, default to "train"."""
     converter: NotRequired[str]
@@ -68,12 +70,12 @@ class DatasetInfo(TypedDict, total=False):
 
 
 class Content(TypedDict):
-    type: Literal["text", "tools", "reasoning", "tool_calls", "image_url"]
+    type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
     value: str
 
 
 class Message(TypedDict):
-    role: Literal["system", "user", "assistant"]
+    role: Literal["system", "user", "assistant", "tool"]
     content: list[Content]
     loss_weight: float
 
@@ -12,25 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+import os
 
-from torch.distributed.device_mesh import DeviceMesh
+from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy
 
 
-class DeviceMeshManager:
-    """Device mesh manager."""
-
-    _instance: Optional["DeviceMeshManager"] = None
-    _initialized: bool = False
-
-    def __new__(cls) -> "DeviceMeshManager":
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
-    def __init__(self) -> None:
-        if self._initialized:
-            return
-
-        self.device_mesh: Optional[DeviceMesh] = None
-        self._initialized = True
+def test_distributed_interface():
+    DistributedInterface(DistributedStrategy())
+    assert DistributedInterface.rank == int(os.getenv("RANK", "0"))
+    assert DistributedInterface.world_size == int(os.getenv("WORLD_SIZE", "1"))
@@ -19,7 +19,7 @@ from datasets import load_dataset
 
 from llamafactory.v1.config.data_args import DataArguments
 from llamafactory.v1.core.data_engine import DataEngine
-from llamafactory.v1.plugins.data_plugins.converter import get_converter
+from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin
 
 
 @pytest.mark.parametrize("num_samples", [16])
@@ -49,99 +49,27 @@ def test_alpaca_converter(num_samples: int):
         assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
 
 
-def test_sharegpt_converter_invalid():
+def test_sharegpt_converter():
     example = {
         "conversations": [
-            {
-                "from": "system",
-                "value": "Processes historical market data to generate trading signals "
-                "based on specified technical indicators.",
-            },
-            {
-                "from": "human",
-                "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
-                "Could you proceed with these function calls to assist me with the task?",
-            },
-            {
-                "from": "gpt",
-                "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
-                "'name': 'backtest_trading_signals'}```\n",
-            },
-            {
-                "from": "tool",
-                "value": '<tool id="D2">\n{"analysis": {"RSI_signals": [{"date": "2025-01-10", '
-                '"symbol": "AAPL", "signal": "Buy"}]}}}\n</tool>\n',
-            },
+            {"from": "system", "value": "System"},
+            {"from": "human", "value": "User"},
+            {"from": "gpt", "value": "Assistant"},
         ]
     }
-    dataset_converter = get_converter("sharegpt")
-    assert dataset_converter(example) == {"messages": []}
-
-
-def test_sharegpt_converter_valid():
-    example = {
-        "conversations": [
-            {
-                "from": "system",
-                "value": "Processes historical market data to generate trading signals based on "
-                "specified technical indicators.",
-            },
-            {
-                "from": "human",
-                "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
-                "Could you proceed with these function calls to assist me with the task?",
-            },
-            {
-                "from": "gpt",
-                "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
-                "'name': 'backtest_trading_signals'}```\n",
-            },
-        ]
-    }
-    dataset_converter = get_converter("sharegpt")
     expected_data = {
         "messages": [
-            {
-                "content": [
-                    {
-                        "type": "text",
-                        "value": "Processes historical market data to generate trading signals based on "
-                        "specified technical indicators.",
-                    }
-                ],
-                "loss_weight": 0.0,
-                "role": "system",
-            },
-            {
-                "content": [
-                    {
-                        "type": "text",
-                        "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
-                        "Could you proceed with these function calls to assist me with the task?",
-                    }
-                ],
-                "loss_weight": 0.0,
-                "role": "user",
-            },
-            {
-                "content": [
-                    {
-                        "type": "text",
-                        "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
-                        "'name': 'backtest_trading_signals'}```\n",
-                    }
-                ],
-                "loss_weight": 1.0,
-                "role": "assistant",
-            },
+            {"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
+            {"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
+            {"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
        ]
     }
-    assert dataset_converter(example) == expected_data
+    assert DataConverterPlugin("sharegpt")(example) == expected_data
 
 
 @pytest.mark.parametrize("num_samples", [16])
 def test_pair_converter(num_samples: int):
-    data_args = DataArguments(dataset="frozenleaves/tiny-dpo/dataset_info.yaml")
+    data_args = DataArguments(dataset="llamafactory/tiny-preference-dataset/dataset_info.yaml")
     data_engine = DataEngine(data_args)
     original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
     indexes = random.choices(range(len(data_engine)), k=num_samples)
@@ -189,6 +117,5 @@ def test_pair_converter(num_samples: int):
 
 if __name__ == "__main__":
     test_alpaca_converter(1)
-    test_sharegpt_converter_invalid()
-    test_sharegpt_converter_valid()
+    test_sharegpt_converter()
     test_pair_converter(1)
@@ -17,10 +17,13 @@ from unittest.mock import MagicMock, patch
 
 from transformers import AutoModelForCausalLM
 
+from llamafactory.v1.accelerator.helper import get_current_accelerator
 
 
 class TestKernelPlugin(unittest.TestCase):
     @patch("torch.accelerator.current_accelerator")
     def test_apply_kernel(self, mock_get_accelerator):
+        get_current_accelerator.cache_clear()
         mock_device = MagicMock()
         mock_device.type = "npu"
         mock_get_accelerator.return_value = mock_device
@@ -47,6 +50,7 @@ class TestKernelPlugin(unittest.TestCase):
 class Test_Use_V1_Kernels(unittest.TestCase):
     @patch("torch.accelerator.current_accelerator")
     def test_use_v1_kernels(self, mock_get_accelerator):
+        get_current_accelerator.cache_clear()
         mock_device = MagicMock()
         mock_device.type = "npu"
         mock_get_accelerator.return_value = mock_device
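Both tests now call `get_current_accelerator.cache_clear()` before relying on the patched `torch.accelerator.current_accelerator`, because the helper is memoized: a value cached by an earlier test would silently bypass the mock. A self-contained illustration of that interaction (names are stand-ins, not the project's code):

```python
from functools import lru_cache
from unittest.mock import patch


def query_backend() -> str:  # stand-in for torch.accelerator.current_accelerator().type
    return "cpu"


@lru_cache
def current_device_type() -> str:  # stand-in for the memoized get_current_accelerator
    return query_backend()


current_device_type()  # an earlier call warms the cache with "cpu"

with patch(f"{__name__}.query_backend", return_value="npu"):
    assert current_device_type() == "cpu"  # stale cached value, the mock is bypassed
    current_device_type.cache_clear()      # what the tests do first
    assert current_device_type() == "npu"  # now the patched backend is visible
```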