mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-14 19:06:26 +08:00
[v1] add accelerator (#9607)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
dpo_zh_demo:
|
||||
hf_hub_url: HuggingFaceH4/orca_dpo_pairs
|
||||
path: HuggingFaceH4/orca_dpo_pairs
|
||||
split: train_prefs
|
||||
converter: pair
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
identity:
|
||||
file_name: data/identity.json
|
||||
path: data/identity.json
|
||||
source: local
|
||||
converter: alpaca
|
||||
alpaca_en_demo:
|
||||
file_name: alpaca_en_demo.json
|
||||
dataset_dir: data
|
||||
path: data/alpaca_en_demo.json
|
||||
source: local
|
||||
converter: alpaca
|
||||
num_samples: 500
|
||||
size: 500
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Bytedance's VeOmni library.
|
||||
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/utils/dist_utils.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,12 +15,68 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, unique
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from ..utils.types import Tensor, TensorLike
|
||||
|
||||
|
||||
def get_current_accelerator(check_available: bool = True):
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
@unique
|
||||
class DeviceType(str, Enum):
|
||||
CPU = "cpu"
|
||||
CUDA = "cuda"
|
||||
META = "meta"
|
||||
MPS = "mps"
|
||||
NPU = "npu"
|
||||
XPU = "xpu"
|
||||
|
||||
|
||||
@unique
|
||||
class ReduceOp(str, Enum):
|
||||
SUM = "sum"
|
||||
MEAN = "mean"
|
||||
MAX = "max"
|
||||
MIN = "min"
|
||||
|
||||
|
||||
def is_distributed() -> bool:
|
||||
"""Check if distributed environment is available."""
|
||||
return os.getenv("RANK") is not None
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
"""Get rank."""
|
||||
return int(os.getenv("RANK", "0"))
|
||||
|
||||
|
||||
def get_local_rank() -> int:
|
||||
"""Get local rank."""
|
||||
return int(os.getenv("LOCAL_RANK", "0"))
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Get world size."""
|
||||
return int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
|
||||
def get_local_world_size() -> int:
|
||||
"""Get local world size."""
|
||||
return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_current_accelerator(check_available: bool = True) -> torch.device:
|
||||
"""Get current accelerator.
|
||||
|
||||
Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
|
||||
@@ -27,26 +86,78 @@ def get_current_accelerator(check_available: bool = True):
|
||||
|
||||
accelerator = torch.accelerator.current_accelerator(check_available=check_available)
|
||||
if accelerator is None:
|
||||
return torch.device("cpu")
|
||||
return torch.device(DeviceType.CPU.value)
|
||||
|
||||
return accelerator
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_torch_npu_available():
|
||||
return get_current_accelerator().type == "npu"
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_torch_cuda_available():
|
||||
return get_current_accelerator().type == "cuda"
|
||||
return get_current_accelerator().type == DeviceType.CUDA
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_torch_xpu_available():
|
||||
return get_current_accelerator().type == "xpu"
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_torch_mps_available():
|
||||
return get_current_accelerator().type == "mps"
|
||||
return get_current_accelerator().type == DeviceType.MPS
|
||||
|
||||
|
||||
def is_torch_npu_available():
|
||||
return get_current_accelerator().type == DeviceType.NPU
|
||||
|
||||
|
||||
def is_torch_xpu_available():
|
||||
return get_current_accelerator().type == DeviceType.XPU
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
|
||||
"""Gathers the tensor from all ranks and concats them along the first dim."""
|
||||
world_size = get_world_size()
|
||||
device = get_current_accelerator()
|
||||
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
|
||||
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
|
||||
return output_tensor.view(-1, *tensor.size()[1:])
|
||||
|
||||
|
||||
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
|
||||
"""Performs all reduce in the given process group."""
|
||||
device = get_current_accelerator()
|
||||
is_ndarray = isinstance(data, np.ndarray)
|
||||
is_tensor = isinstance(data, torch.Tensor)
|
||||
|
||||
if is_ndarray:
|
||||
data = torch.from_numpy(data)
|
||||
elif not is_tensor:
|
||||
data = torch.tensor(data, dtype=torch.float, device=device)
|
||||
|
||||
reduce_ops = {
|
||||
ReduceOp.MEAN: dist.ReduceOp.SUM,
|
||||
ReduceOp.SUM: dist.ReduceOp.SUM,
|
||||
ReduceOp.MAX: dist.ReduceOp.MAX,
|
||||
ReduceOp.MIN: dist.ReduceOp.MIN,
|
||||
}
|
||||
dist.all_reduce(data, op=reduce_ops[op], group=group)
|
||||
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
|
||||
data /= dist.get_world_size(group=group)
|
||||
|
||||
if is_tensor:
|
||||
return data
|
||||
elif is_ndarray:
|
||||
return data.numpy()
|
||||
elif data.numel() == 1:
|
||||
return data.item()
|
||||
else:
|
||||
return data.tolist()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def main_process_first(local_only: bool = True) -> None:
|
||||
"""A context manager for torch distributed environment to do something on the main process firstly."""
|
||||
if get_world_size() > 1:
|
||||
is_main_process = get_local_rank() == 0 if local_only else get_rank() == 0
|
||||
try:
|
||||
if not is_main_process:
|
||||
dist.barrier()
|
||||
yield
|
||||
finally:
|
||||
if is_main_process:
|
||||
dist.barrier()
|
||||
else:
|
||||
yield
|
||||
|
||||
123
src/llamafactory/v1/accelerator/interface.py
Normal file
123
src/llamafactory/v1/accelerator/interface.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||
|
||||
from ..utils.types import TensorLike
|
||||
from .helper import ReduceOp, all_reduce, get_current_accelerator, get_rank, get_world_size, is_distributed
|
||||
|
||||
|
||||
@dataclass
|
||||
class DistributedStrategy:
|
||||
"""Distributed strategy."""
|
||||
|
||||
dp_size: Optional[int] = None
|
||||
tp_size: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not is_distributed():
|
||||
self.dp_size = 1
|
||||
elif self.dp_size is None:
|
||||
self.dp_size = get_world_size() // self.tp_size
|
||||
elif self.dp_size * self.tp_size != get_world_size():
|
||||
raise ValueError(
|
||||
f"dp_size * tp_size must equal to world_size, "
|
||||
f"got {self.dp_size} * {self.tp_size} != {get_world_size()}."
|
||||
)
|
||||
|
||||
@property
|
||||
def mesh_shape(self) -> tuple[int, int]:
|
||||
"""Mesh shape."""
|
||||
return (self.dp_size, self.tp_size)
|
||||
|
||||
@property
|
||||
def mesh_dim_names(self) -> tuple[str, str]:
|
||||
"""Mesh dimension names."""
|
||||
return ("dp", "tp")
|
||||
|
||||
|
||||
class DistributedInterface:
|
||||
"""Distributed interface."""
|
||||
|
||||
_instance: Optional["DistributedInterface"] = None
|
||||
_initialized: bool = False
|
||||
|
||||
is_distributed = is_distributed()
|
||||
"""Check if distributed environment is available."""
|
||||
rank = get_rank()
|
||||
"""Global rank."""
|
||||
world_size = get_world_size()
|
||||
"""Global world size."""
|
||||
device_mesh: Optional[DeviceMesh] = None
|
||||
"""Device mesh."""
|
||||
current_accelerator = get_current_accelerator()
|
||||
"""Current accelerator."""
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
|
||||
"""Singleton pattern."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, strategy: DistributedStrategy) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.strategy = strategy
|
||||
if self.is_distributed:
|
||||
self.device_mesh = init_device_mesh(
|
||||
device_type=self.current_accelerator.type,
|
||||
mesh_shape=strategy.mesh_shape,
|
||||
mesh_dim_names=strategy.mesh_dim_names,
|
||||
)
|
||||
else:
|
||||
self.device_mesh = None
|
||||
|
||||
self._initialized = True
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"DistributedInterface(strategy={self.strategy}), is_distributed={self.is_distributed}, "
|
||||
f"rank={self.rank}, world_size={self.world_size}, "
|
||||
f"device_mesh={self.device_mesh}, current_accelerator={self.current_accelerator}"
|
||||
)
|
||||
|
||||
def dp_rank(self) -> int:
|
||||
"""Data parallel rank."""
|
||||
if self.device_mesh is None:
|
||||
return 0
|
||||
|
||||
return self.device_mesh["dp"].get_rank()
|
||||
|
||||
def dp_size(self) -> int:
|
||||
"""Data parallel size."""
|
||||
if self.device_mesh is None:
|
||||
return 1
|
||||
|
||||
return self.device_mesh["dp"].size()
|
||||
|
||||
def all_reduce_over_dp(self, data: TensorLike, op: ReduceOp = ReduceOp.MEAN) -> TensorLike:
|
||||
"""All reduce tensor."""
|
||||
if self.device_mesh is None:
|
||||
return data
|
||||
|
||||
return all_reduce(data, op, self.device_mesh["dp"].get_group())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(DistributedInterface(DistributedStrategy()))
|
||||
@@ -28,6 +28,21 @@ from .sample_args import SampleArguments
|
||||
from .training_args import TrainingArguments
|
||||
|
||||
|
||||
def validate_args(
|
||||
data_args: DataArguments,
|
||||
model_args: ModelArguments,
|
||||
training_args: TrainingArguments,
|
||||
sample_args: SampleArguments,
|
||||
):
|
||||
"""Validate arguments."""
|
||||
if (
|
||||
model_args.quant_config is not None
|
||||
and training_args.dist_config is not None
|
||||
and training_args.dist_config.name == "deepspeed"
|
||||
):
|
||||
raise ValueError("Quantization is not supported with deepspeed backend.")
|
||||
|
||||
|
||||
def get_args(
|
||||
args: Optional[Union[dict[str, Any], list[str]]] = None,
|
||||
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
|
||||
@@ -56,6 +71,8 @@ def get_args(
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
validate_args(*parsed_args)
|
||||
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
105
src/llamafactory/v1/config/arg_utils.py
Normal file
105
src/llamafactory/v1/config/arg_utils.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/training_args.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
|
||||
class PluginConfig(dict):
|
||||
"""Dictionary that allows attribute access."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin name."""
|
||||
if "name" not in self:
|
||||
raise ValueError("Plugin configuration must have a 'name' field.")
|
||||
|
||||
return self["name"]
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {key} not found.")
|
||||
|
||||
def __setattr__(self, key: str, value: Any):
|
||||
self[key] = value
|
||||
|
||||
|
||||
PluginArgument = Optional[Union[PluginConfig, dict, str]]
|
||||
|
||||
|
||||
@unique
|
||||
class AutoClass(str, Enum):
|
||||
"""Auto class for model config."""
|
||||
|
||||
CAUSALLM = "llm"
|
||||
CLASSIFICATION = "cls"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
@unique
|
||||
class SampleBackend(str, Enum):
|
||||
HF = "hf"
|
||||
VLLM = "vllm"
|
||||
|
||||
|
||||
def _convert_str_dict(data: dict) -> dict:
|
||||
"""Parse string representation inside the dictionary.
|
||||
|
||||
Args:
|
||||
data: The string or dictionary to convert.
|
||||
|
||||
Returns:
|
||||
The converted dictionary.
|
||||
"""
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
data[key] = _convert_str_dict(value)
|
||||
elif isinstance(value, str):
|
||||
if value.lower() in ("true", "false"):
|
||||
data[key] = value.lower() == "true"
|
||||
elif value.isdigit():
|
||||
data[key] = int(value)
|
||||
elif value.replace(".", "", 1).isdigit():
|
||||
data[key] = float(value)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]:
|
||||
"""Get the plugin configuration from the argument value.
|
||||
|
||||
Args:
|
||||
config: The argument value to get the plugin configuration from.
|
||||
|
||||
Returns:
|
||||
The plugin configuration.
|
||||
"""
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
if isinstance(config, str) and config.startswith("{"):
|
||||
config = json.loads(config)
|
||||
|
||||
config = _convert_str_dict(config)
|
||||
if "name" not in config:
|
||||
raise ValueError("Plugin configuration must have a 'name' field.")
|
||||
|
||||
return PluginConfig(config)
|
||||
@@ -15,6 +15,8 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .arg_utils import AutoClass, PluginConfig, get_plugin_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
@@ -29,7 +31,24 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Use fast processor from Hugging Face."},
|
||||
)
|
||||
auto_model_class: str = field(
|
||||
default="causallm",
|
||||
auto_class: AutoClass = field(
|
||||
default=AutoClass.CAUSALLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
)
|
||||
peft_config: PluginConfig = field(
|
||||
default=None,
|
||||
metadata={"help": "PEFT configuration for the model."},
|
||||
)
|
||||
kernel_config: PluginConfig = field(
|
||||
default=None,
|
||||
metadata={"help": "Kernel configuration for the model."},
|
||||
)
|
||||
quant_config: PluginConfig = field(
|
||||
default=None,
|
||||
metadata={"help": "Quantization configuration for the model."},
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.peft_config = get_plugin_config(self.peft_config)
|
||||
self.kernel_config = get_plugin_config(self.kernel_config)
|
||||
self.quant_config = get_plugin_config(self.quant_config)
|
||||
|
||||
@@ -14,12 +14,8 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SampleBackend(Enum):
|
||||
HF = "hf"
|
||||
VLLM = "vllm"
|
||||
from .arg_utils import SampleBackend
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .arg_utils import PluginArgument, get_plugin_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
@@ -38,3 +40,10 @@ class TrainingArguments:
|
||||
default=False,
|
||||
metadata={"help": "Use bf16 for training."},
|
||||
)
|
||||
dist_config: PluginArgument = field(
|
||||
default=None,
|
||||
metadata={"help": "Distribution configuration for training."},
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
|
||||
@@ -29,7 +29,7 @@ Train Phase:
|
||||
"""
|
||||
|
||||
from ..config.training_args import TrainingArguments
|
||||
from ..extras.types import TorchDataset
|
||||
from ..utils.types import TorchDataset
|
||||
from .model_worker import ModelWorker
|
||||
from .trainer_utils.data_collator import DataCollator
|
||||
|
||||
@@ -49,13 +49,10 @@ class BaseTrainer:
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
|
||||
def init_device_mesh(self) -> None:
|
||||
pass
|
||||
|
||||
def init_model_and_optimizer(self) -> None:
|
||||
self.model_config = self.model_worker.get_model_config()
|
||||
self.model_worker.init_model_config()
|
||||
# with self.dist_plugin.get_model_init_context():
|
||||
# self.model = self.model_worker.get_model(self.model_config)
|
||||
# self.model = self.model_worker.init_model(self.model_config)
|
||||
|
||||
def create_dataloader(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -34,7 +34,7 @@ from omegaconf import OmegaConf
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from ..config.data_args import DataArguments
|
||||
from ..extras.types import DatasetInfo, HFDataset, Sample
|
||||
from ..utils.types import DatasetInfo, HFDataset, Sample
|
||||
|
||||
|
||||
class DataEngine(Dataset):
|
||||
@@ -64,9 +64,9 @@ class DataEngine(Dataset):
|
||||
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
|
||||
self.dataset_infos = OmegaConf.load(filepath)
|
||||
elif os.path.exists(self.args.dataset): # local file(s)
|
||||
self.dataset_infos = {"default": {"file_name": self.args.dataset}}
|
||||
self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
|
||||
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
|
||||
self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
|
||||
self.dataset_infos = {"default": {"path": self.args.dataset}}
|
||||
|
||||
def load_dataset(self) -> None:
|
||||
"""Load datasets according to dataset info."""
|
||||
@@ -74,14 +74,14 @@ class DataEngine(Dataset):
|
||||
split = value.get("split", "train")
|
||||
streaming = value.get("streaming", False)
|
||||
self.streaming |= streaming
|
||||
if "hf_hub_url" in value:
|
||||
if value.get("source", "hf_hub") == "hf_hub":
|
||||
from datasets import load_dataset
|
||||
|
||||
self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
|
||||
self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
|
||||
else: # data loader plugin
|
||||
from ..plugins.data_plugins.loader import DataLoaderPlugin
|
||||
|
||||
self.datasets[key] = DataLoaderPlugin().auto_load_data(value)
|
||||
self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)
|
||||
|
||||
def build_data_index(self) -> None:
|
||||
"""Build dataset index."""
|
||||
@@ -112,9 +112,9 @@ class DataEngine(Dataset):
|
||||
"""
|
||||
converter = self.dataset_infos[dataset_name].get("converter")
|
||||
if converter is not None:
|
||||
from ..plugins.data_plugins.converter import get_converter
|
||||
from ..plugins.data_plugins.converter import DataConverterPlugin
|
||||
|
||||
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
|
||||
return {"_dataset_name": dataset_name, **DataConverterPlugin(converter)(raw_sample)}
|
||||
else:
|
||||
return {"_dataset_name": dataset_name, **raw_sample}
|
||||
|
||||
@@ -147,7 +147,7 @@ class DataEngine(Dataset):
|
||||
else:
|
||||
from ..plugins.data_plugins.loader import DataSelectorPlugin
|
||||
|
||||
selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
|
||||
selected_index = DataSelectorPlugin().select(self.data_index, index)
|
||||
if isinstance(selected_index, list):
|
||||
return [
|
||||
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||
@@ -187,7 +187,7 @@ class DataEngine(Dataset):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ..config.parser import get_args
|
||||
from ..config.arg_parser import get_args
|
||||
|
||||
data_args, *_ = get_args()
|
||||
data_engine = DataEngine(data_args=data_args)
|
||||
|
||||
@@ -24,10 +24,12 @@ Init Phase:
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoProcessor
|
||||
|
||||
from ..config.model_args import ModelArguments
|
||||
from ..extras.types import DistModel, HFConfig, HFModel, Processor
|
||||
from ..accelerator.helper import DeviceType
|
||||
from ..config.model_args import AutoClass, ModelArguments
|
||||
from ..utils.types import HFConfig, HFModel, Processor
|
||||
|
||||
|
||||
class ModelWorker:
|
||||
@@ -38,16 +40,15 @@ class ModelWorker:
|
||||
"""Tokenizer or multi-modal processor."""
|
||||
self.model_config: Optional[HFConfig] = None
|
||||
"""Model configuration."""
|
||||
self.unwrapped_model: Optional[HFModel] = None
|
||||
"""Unwrapped model."""
|
||||
self.model: Optional[DistModel] = None
|
||||
"""Distributed model."""
|
||||
self.init_processor()
|
||||
self.init_model_config()
|
||||
self.init_model()
|
||||
self.init_adapter()
|
||||
self.model: Optional[HFModel] = None
|
||||
"""HF model."""
|
||||
self.is_adapter = False
|
||||
"""Whether the model has adapter."""
|
||||
|
||||
def init_processor(self) -> None:
|
||||
if self.processor is not None:
|
||||
return
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
self.args.model,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
@@ -55,38 +56,58 @@ class ModelWorker:
|
||||
)
|
||||
|
||||
def init_model_config(self) -> None:
|
||||
if self.model_config is not None:
|
||||
return
|
||||
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
self.args.model,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
def init_model(self) -> None:
|
||||
if self.args.auto_model_class == "causallm":
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
self.init_model_config()
|
||||
|
||||
if self.args.auto_class == AutoClass.CAUSALLM:
|
||||
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
||||
|
||||
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
|
||||
AutoClass = AutoModelForImageTextToText
|
||||
ModelClass = AutoModelForImageTextToText
|
||||
else:
|
||||
AutoClass = AutoModelForCausalLM
|
||||
elif self.args.auto_model_class == "classification":
|
||||
ModelClass = AutoModelForCausalLM
|
||||
elif self.args.auto_class == AutoClass.CLASSIFICATION:
|
||||
from transformers import AutoModelForTokenClassification
|
||||
|
||||
AutoClass = AutoModelForTokenClassification
|
||||
ModelClass = AutoModelForTokenClassification
|
||||
else:
|
||||
from transformers import AutoModel
|
||||
|
||||
AutoClass = AutoModel
|
||||
ModelClass = AutoModel
|
||||
|
||||
self.unwrapped_model = AutoClass.from_pretrained(
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map="cpu",
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
default_device_type = torch.get_default_device().type
|
||||
if default_device_type == DeviceType.META:
|
||||
self.model = ModelClass.from_config(self.model_config)
|
||||
else:
|
||||
self.model = ModelClass.from_pretrained(
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=default_device_type,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
def init_adapter(self) -> None:
|
||||
pass
|
||||
if self.is_adapter:
|
||||
return
|
||||
|
||||
if self.args.peft_config is not None:
|
||||
from ..plugins.model_plugins.peft import PeftPlugin
|
||||
|
||||
self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)
|
||||
|
||||
self.is_adapter = True
|
||||
|
||||
def get_processor(self) -> Processor:
|
||||
return self.processor
|
||||
@@ -95,4 +116,4 @@ class ModelWorker:
|
||||
return self.model_config
|
||||
|
||||
def get_model(self) -> HFModel:
|
||||
return self.unwrapped_model
|
||||
return self.model
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ...extras.types import Processor, Tensor, TorchDataset
|
||||
from ...utils.types import Processor, Tensor, TorchDataset
|
||||
|
||||
|
||||
class DataCollator:
|
||||
|
||||
@@ -44,7 +44,7 @@ WELCOME = (
|
||||
def launch():
|
||||
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
|
||||
|
||||
if command == "sft":
|
||||
if command == "sft": # train command will fallback to sft command
|
||||
from .trainers.sft_trainer import run_sft
|
||||
|
||||
run_sft()
|
||||
|
||||
@@ -13,12 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Callable, TypedDict
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from typing_extensions import NotRequired, Required
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from ....extras import logging
|
||||
from ...extras.types import DPOSample, Sample, SFTSample
|
||||
from ...utils import logging
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import DPOSample, Sample, SFTSample
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -26,35 +27,48 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
class AlpacaSample(TypedDict, total=False):
|
||||
system: NotRequired[str]
|
||||
instruction: NotRequired[str]
|
||||
instruction: str
|
||||
input: NotRequired[str]
|
||||
output: NotRequired[str]
|
||||
output: str
|
||||
|
||||
|
||||
ShareGPTMessage = TypedDict(
|
||||
"ShareGPTMessage",
|
||||
{
|
||||
"from": Required[str], # Role of the message sender (e.g., "human", "gpt", "system")
|
||||
"value": Required[str], # Content of the message
|
||||
},
|
||||
SharegptMessage = TypedDict(
|
||||
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTSample(TypedDict, total=False):
|
||||
"""Type definition for raw ShareGPT sample."""
|
||||
class SharegptSample(TypedDict, total=False):
|
||||
conversations: list[SharegptMessage]
|
||||
tools: NotRequired[str]
|
||||
|
||||
conversations: Required[list[ShareGPTMessage]]
|
||||
|
||||
class OpenaiMessage(TypedDict, total=False):
|
||||
role: Literal["user", "assistant", "tool"]
|
||||
content: str
|
||||
|
||||
|
||||
class OpenaiSample(TypedDict, total=False):
|
||||
messages: list[OpenaiMessage]
|
||||
|
||||
|
||||
class PairSample(TypedDict, total=False):
|
||||
prompt: NotRequired[str]
|
||||
chosen: NotRequired[list[dict]]
|
||||
rejected: NotRequired[list[dict]]
|
||||
chosen: list[OpenaiMessage]
|
||||
rejected: list[OpenaiMessage]
|
||||
|
||||
|
||||
class DataConverterPlugin(BasePlugin):
|
||||
"""Plugin for data converters."""
|
||||
|
||||
def __call__(self, raw_sample: dict[str, Any]) -> Sample:
|
||||
return super().__call__(raw_sample)
|
||||
|
||||
|
||||
@DataConverterPlugin("alpaca").register
|
||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
"""Convert Alpaca sample to SFT sample.
|
||||
|
||||
See raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en
|
||||
|
||||
Args:
|
||||
raw_sample (AlpacaSample): Alpaca sample.
|
||||
|
||||
@@ -67,20 +81,6 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
{"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
|
||||
)
|
||||
|
||||
if "history" in raw_sample:
|
||||
for idx, item in enumerate(raw_sample["history"]):
|
||||
if len(item) != 2:
|
||||
logger.warning_rank0(
|
||||
f"Warning: History item at index {idx} has invalid length (expected 2, got {len(item)}). Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
old_prompt, old_response = item
|
||||
messages.append({"role": "user", "content": [{"type": "text", "value": old_prompt}], "loss_weight": 0.0})
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [{"type": "text", "value": old_response}], "loss_weight": 1.0}
|
||||
)
|
||||
|
||||
if "instruction" in raw_sample or "input" in raw_sample:
|
||||
messages.append(
|
||||
{
|
||||
@@ -100,149 +100,85 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
def sharegpt_converter(raw_sample: ShareGPTSample) -> SFTSample:
|
||||
"""Converts a raw ShareGPT sample into a formatted SFT (Supervised Fine-Tuning) sample.
|
||||
@DataConverterPlugin("sharegpt").register
|
||||
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"""Convert ShareGPT sample to SFT sample.
|
||||
|
||||
Retains only SFT-relevant scenarios and removes parity checks.
|
||||
See raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en
|
||||
|
||||
Args:
|
||||
raw_sample (ShareGPTSample): A raw sample in ShareGPT format.
|
||||
raw_sample (SharegptSample): ShareGPT sample.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the formatted 'messages' list for SFT training.
|
||||
Returns an empty list if the input data is invalid.
|
||||
SFTSample: SFT sample.
|
||||
"""
|
||||
tag_mapping = {
|
||||
"system": "system",
|
||||
"human": "user",
|
||||
"gpt": "assistant",
|
||||
"observation": "observation",
|
||||
"function_call": "function",
|
||||
"observation": "tool",
|
||||
"function_call": "assistant",
|
||||
}
|
||||
messages = raw_sample.get("conversations", [])
|
||||
aligned_messages = []
|
||||
system_content = ""
|
||||
messages = []
|
||||
tools = raw_sample.get("tools", "")
|
||||
|
||||
# Extract system message if present (typically the first message)
|
||||
if messages and messages[0]["from"] == "system":
|
||||
system_content = messages[0]["value"]
|
||||
messages = messages[1:]
|
||||
for message in raw_sample.get("conversations", []):
|
||||
tag = message["from"]
|
||||
if tag not in tag_mapping:
|
||||
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
|
||||
elif tag == "function_call":
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_calls", "value": message["value"]}],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
else:
|
||||
messages.append(
|
||||
{
|
||||
"role": tag_mapping[tag],
|
||||
"content": [{"type": "text", "value": message["value"]}],
|
||||
"loss_weight": 1.0 if tag == "gpt" else 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
if system_content:
|
||||
aligned_messages.append(
|
||||
{"role": "system", "content": [{"type": "text", "value": system_content}], "loss_weight": 0.0}
|
||||
)
|
||||
if tools:
|
||||
if messages and messages[0]["role"] == "system":
|
||||
messages[0]["content"].append({"type": "tools", "value": tools})
|
||||
else:
|
||||
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
|
||||
|
||||
has_invalid_role = False
|
||||
for message in messages:
|
||||
sender = message["from"]
|
||||
# validate sender is in supported tags
|
||||
if sender not in tag_mapping:
|
||||
logger.warning_rank0(f"Unsupported role tag '{sender}' in message: {message}")
|
||||
has_invalid_role = True
|
||||
break
|
||||
|
||||
aligned_messages.append(
|
||||
{
|
||||
"role": tag_mapping[sender],
|
||||
"content": [{"type": "text", "value": message["value"]}],
|
||||
"loss_weight": 0.0 if sender in ("human", "observation") else 1.0,
|
||||
}
|
||||
)
|
||||
|
||||
if has_invalid_role:
|
||||
logger.warning_rank0("Skipping invalid example due to unsupported role tags.")
|
||||
return {"messages": []}
|
||||
|
||||
return {"messages": aligned_messages}
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@DataConverterPlugin("pair").register
|
||||
def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
"""Convert Pair sample to standard DPO sample.
|
||||
"""Convert Pair sample to DPO sample.
|
||||
|
||||
See raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
|
||||
|
||||
Args:
|
||||
raw_sample (PairSample): pair sample with prompt, chosen, rejected fields.
|
||||
see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
|
||||
raw_sample (PairSample): pair sample with chosen, rejected fields.
|
||||
|
||||
Returns:
|
||||
DPOSample: DPO sample with chosen_messages and rejected_messages.
|
||||
see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl
|
||||
"""
|
||||
chosen_messages = []
|
||||
assert "chosen" in raw_sample, "chosen field is required in pair sample."
|
||||
assert "rejected" in raw_sample, "rejected field is required in pair sample."
|
||||
assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), (
|
||||
"chosen and rejected field should be a list[dict], or you may need to implement your custom converter."
|
||||
)
|
||||
|
||||
if "chosen" in raw_sample:
|
||||
value = raw_sample.get("chosen", "")
|
||||
for item in value:
|
||||
if item.get("role", "") == "system":
|
||||
chosen_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "user":
|
||||
chosen_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "assistant":
|
||||
chosen_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
def process_message(raw_messages: list[OpenaiMessage]):
|
||||
messages = []
|
||||
for message in raw_messages:
|
||||
messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": [{"type": "text", "value": message["content"]}],
|
||||
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
rejected_messages = []
|
||||
if "rejected" in raw_sample:
|
||||
value = raw_sample.get("rejected", "")
|
||||
for item in value:
|
||||
if item.get("role", "") == "system":
|
||||
rejected_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "user":
|
||||
rejected_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "assistant":
|
||||
rejected_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
return messages
|
||||
|
||||
chosen_messages = process_message(raw_sample.get("chosen", []))
|
||||
rejected_messages = process_message(raw_sample.get("rejected", []))
|
||||
|
||||
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
|
||||
|
||||
|
||||
CONVERTERS = {
|
||||
"alpaca": alpaca_converter,
|
||||
"pair": pair_converter,
|
||||
"sharegpt": sharegpt_converter,
|
||||
}
|
||||
|
||||
|
||||
def get_converter(converter_name: str) -> Callable[[dict], Sample]:
|
||||
if converter_name not in CONVERTERS:
|
||||
raise ValueError(f"Converter {converter_name} not found.")
|
||||
|
||||
return CONVERTERS[converter_name]
|
||||
|
||||
@@ -14,57 +14,59 @@
|
||||
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
import random
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from ...extras.types import DatasetInfo, HFDataset
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import DatasetInfo, HFDataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataLoaderPlugin:
|
||||
class DataLoaderPlugin(BasePlugin):
|
||||
"""Plugin for loading dataset."""
|
||||
|
||||
def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
|
||||
"""Get dataset builder name.
|
||||
|
||||
Args:
|
||||
path (str): Dataset path.
|
||||
|
||||
Returns:
|
||||
Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
|
||||
"""
|
||||
return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
|
||||
|
||||
def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
|
||||
dataset_dir = dataset_info.get("dataset_dir", ".")
|
||||
def load(self, dataset_info: DatasetInfo) -> HFDataset:
|
||||
path = dataset_info["path"]
|
||||
split = dataset_info.get("split", "train")
|
||||
streaming = dataset_info.get("streaming", False)
|
||||
if "file_name" in dataset_info:
|
||||
filepath = os.path.join(dataset_dir, dataset_info["file_name"])
|
||||
return self.load_data_from_file(filepath, split, streaming)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
|
||||
if os.path.isdir(filepath):
|
||||
filetype = self._get_builder_name(os.listdir(filepath)[0])
|
||||
dataset = load_dataset(filetype, data_dir=filepath, split=split)
|
||||
elif os.path.isfile(filepath):
|
||||
filetype = self._get_builder_name(filepath)
|
||||
dataset = load_dataset(filetype, data_files=filepath, split=split)
|
||||
else:
|
||||
raise ValueError(f"Can not load dataset from {filepath}.")
|
||||
|
||||
if streaming:
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
|
||||
return dataset
|
||||
return super().__call__(path, split, streaming)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataIndexPlugin:
|
||||
def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
|
||||
"""Get dataset builder name.
|
||||
|
||||
Args:
|
||||
path (str): Dataset path.
|
||||
|
||||
Returns:
|
||||
Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
|
||||
"""
|
||||
filetype = os.path.splitext(path)[-1][1:]
|
||||
if filetype in ["arrow", "csv", "json", "jsonl", "parquet", "txt"]:
|
||||
return filetype.replace("jsonl", "json").replace("txt", "text")
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset filetype: {filetype}.")
|
||||
|
||||
|
||||
@DataLoaderPlugin("local").register
|
||||
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
|
||||
if os.path.isdir(filepath):
|
||||
filetype = _get_builder_name(os.listdir(filepath)[0])
|
||||
dataset = load_dataset(filetype, data_dir=filepath, split=split)
|
||||
elif os.path.isfile(filepath):
|
||||
filetype = _get_builder_name(filepath)
|
||||
dataset = load_dataset(filetype, data_files=filepath, split=split)
|
||||
else:
|
||||
raise ValueError(f"Can not load dataset from {filepath}.")
|
||||
|
||||
if streaming: # faster when data is streamed from local files
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
class DataIndexPlugin(BasePlugin):
|
||||
"""Plugin for adjusting dataset index."""
|
||||
|
||||
def adjust_data_index(
|
||||
@@ -81,39 +83,32 @@ class DataIndexPlugin:
|
||||
list[tuple[str, int]]: Adjusted dataset index.
|
||||
"""
|
||||
if size is not None:
|
||||
data_index = self.adjust_by_size(data_index, size)
|
||||
data_index = random.choices(data_index, k=size)
|
||||
|
||||
if weight is not None:
|
||||
data_index = self.adjust_by_weight(data_index, weight)
|
||||
data_index = random.choices(data_index, k=int(len(data_index) * weight))
|
||||
|
||||
return data_index
|
||||
|
||||
def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataSelectorPlugin:
|
||||
class DataSelectorPlugin(BasePlugin):
|
||||
"""Plugin for selecting dataset samples."""
|
||||
|
||||
data_index: list[tuple[str, int]]
|
||||
"""List of (dataset_name, sample_index)"""
|
||||
|
||||
def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
|
||||
def select(
|
||||
self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
|
||||
) -> Union[tuple[str, int], list[tuple[str, int]]]:
|
||||
"""Select dataset samples.
|
||||
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
index (Union[slice, list[int], Any]): Index of dataset samples.
|
||||
|
||||
Returns:
|
||||
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
|
||||
"""
|
||||
if isinstance(index, slice):
|
||||
return [self.data_index[i] for i in range(*index.indices(len(self.data_index)))]
|
||||
return [data_index[i] for i in range(*index.indices(len(data_index)))]
|
||||
elif isinstance(index, list):
|
||||
return [self.data_index[i] for i in index]
|
||||
return [data_index[i] for i in index]
|
||||
else:
|
||||
raise ValueError(f"Invalid index type {type(index)}.")
|
||||
|
||||
@@ -21,10 +21,3 @@ class KernelType(str, Enum):
|
||||
FLASH_ATTENTION = "flash_attention"
|
||||
ROPE = "rope"
|
||||
MOE = "moe"
|
||||
|
||||
|
||||
class DeviceType(str, Enum):
|
||||
CPU = "cpu"
|
||||
CUDA = "cuda"
|
||||
NPU = "npu"
|
||||
XPU = "xpu"
|
||||
|
||||
@@ -18,10 +18,10 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.packages import is_transformers_version_greater_than
|
||||
from .....extras.types import HFModel
|
||||
from ..constants import DeviceType, KernelType
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.packages import is_transformers_version_greater_than
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaMoEKernel
|
||||
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ import types
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ..constants import DeviceType, KernelType
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaSwiGluKernel
|
||||
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from ....accelerator.helper import get_current_accelerator
|
||||
from ....extras.types import HFModel
|
||||
from .constants import DeviceType, KernelType
|
||||
from ....accelerator.helper import DeviceType, get_current_accelerator
|
||||
from ....utils.types import HFModel
|
||||
from .constants import KernelType
|
||||
|
||||
|
||||
class KernelRegistry:
|
||||
@@ -27,11 +27,13 @@ class KernelRegistry:
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
|
||||
self._initialized = True
|
||||
|
||||
@@ -218,7 +220,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
|
||||
return discovered_kernels
|
||||
|
||||
# Iterate through registry and collect all kernels for current device
|
||||
for kernel_type, devices in KERNEL_REGISTRY._registry.items():
|
||||
for devices in KERNEL_REGISTRY._registry.values():
|
||||
kernel_cls = devices.get(device_type)
|
||||
if kernel_cls is not None:
|
||||
discovered_kernels.append(kernel_cls)
|
||||
@@ -226,7 +228,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
|
||||
return discovered_kernels
|
||||
|
||||
|
||||
def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
|
||||
def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
|
||||
"""Call the MetaKernel's `apply` to perform the replacement.
|
||||
|
||||
Corresponding replacement logic is maintained inside each kernel; the only
|
||||
@@ -238,16 +240,18 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
|
||||
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
|
||||
model = apply_kernel(model, NpuRMSNormKernel)
|
||||
"""
|
||||
if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
|
||||
return kernel.apply(model, **kwargs)
|
||||
if not issubclass(kernel, MetaKernel):
|
||||
raise ValueError(f"{kernel} must be a MetaKernel instance.")
|
||||
|
||||
raise ValueError(
|
||||
f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
|
||||
)
|
||||
if kernel.device != get_current_accelerator().type:
|
||||
raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
|
||||
|
||||
return kernel.apply(model, **kwargs)
|
||||
|
||||
|
||||
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
|
||||
"""Apply all available kernels to the model."""
|
||||
for kernel in discover_kernels(model):
|
||||
model = apply_kernel(model, kernel, **kwargs)
|
||||
|
||||
return model
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
import re
|
||||
import types
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ..constants import DeviceType, KernelType
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaRMSNormKernel
|
||||
|
||||
|
||||
|
||||
@@ -16,9 +16,9 @@ import sys
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ..constants import DeviceType, KernelType
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaRoPEKernel
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import HFModel
|
||||
|
||||
|
||||
class LoraConfigDict(TypedDict, total=False):
|
||||
name: Literal["lora"]
|
||||
"""Plugin name."""
|
||||
r: int
|
||||
"""Lora rank."""
|
||||
lora_alpha: int
|
||||
"""Lora alpha."""
|
||||
target_modules: list[str]
|
||||
"""Target modules."""
|
||||
|
||||
|
||||
class PeftPlugin(BasePlugin):
|
||||
pass
|
||||
|
||||
|
||||
@PeftPlugin("lora").register
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
|
||||
peft_config = LoraConfig(**config)
|
||||
model = get_peft_model(model, peft_config)
|
||||
return model
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -13,22 +13,21 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ..config.parser import get_args
|
||||
from ..accelerator.interface import DistributedInterface, DistributedStrategy
|
||||
from ..config.arg_parser import get_args
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_engine import ModelEngine
|
||||
from ..core.model_worker import ModelWorker
|
||||
|
||||
|
||||
class SFTTrainer(BaseTrainer):
|
||||
pass
|
||||
|
||||
|
||||
def run_sft():
|
||||
model_args, data_args, training_args, _ = get_args()
|
||||
model_engine = ModelEngine(model_args)
|
||||
def run_sft(user_args):
|
||||
model_args, data_args, training_args, _ = get_args(user_args)
|
||||
DistributedInterface(DistributedStrategy())
|
||||
data_engine = DataEngine(data_args)
|
||||
model = model_engine.get_model()
|
||||
processor = model_engine.get_processor()
|
||||
data_loader = data_engine.get_data_loader(processor)
|
||||
trainer = SFTTrainer(training_args, model, processor, data_loader)
|
||||
model_worker = ModelWorker(model_args)
|
||||
trainer = SFTTrainer(training_args, model_worker, data_engine)
|
||||
trainer.fit()
|
||||
|
||||
0
src/llamafactory/v1/utils/__init__.py
Normal file
0
src/llamafactory/v1/utils/__init__.py
Normal file
13
src/llamafactory/v1/utils/constants.py
Normal file
13
src/llamafactory/v1/utils/constants.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
123
src/llamafactory/v1/utils/logging.py
Normal file
123
src/llamafactory/v1/utils/logging.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/utils/logging.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
|
||||
_thread_lock = threading.RLock()
|
||||
_default_handler: Optional["logging.Handler"] = None
|
||||
_default_log_level: "logging._Level" = logging.INFO
|
||||
|
||||
|
||||
class _Logger(logging.Logger):
|
||||
r"""A logger that supports rank0 logging."""
|
||||
|
||||
def info_rank0(self, *args, **kwargs) -> None:
|
||||
self.info(*args, **kwargs)
|
||||
|
||||
def warning_rank0(self, *args, **kwargs) -> None:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
def warning_rank0_once(self, *args, **kwargs) -> None:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
def _get_default_logging_level() -> "logging._Level":
|
||||
r"""Return the default logging level."""
|
||||
env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
|
||||
if env_level_str:
|
||||
if env_level_str.upper() in logging._nameToLevel:
|
||||
return logging._nameToLevel[env_level_str.upper()]
|
||||
else:
|
||||
raise ValueError(f"Unknown logging level: {env_level_str}.")
|
||||
|
||||
return _default_log_level
|
||||
|
||||
|
||||
def _get_library_name() -> str:
|
||||
return __name__.split(".")[0]
|
||||
|
||||
|
||||
def _get_library_root_logger() -> "_Logger":
|
||||
return logging.getLogger(_get_library_name())
|
||||
|
||||
|
||||
def _configure_library_root_logger() -> None:
|
||||
r"""Configure root logger using a stdout stream handler with an explicit format."""
|
||||
global _default_handler
|
||||
|
||||
with _thread_lock:
|
||||
if _default_handler: # already configured
|
||||
return
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
_default_handler = logging.StreamHandler(sys.stdout)
|
||||
_default_handler.setFormatter(formatter)
|
||||
library_root_logger = _get_library_root_logger()
|
||||
library_root_logger.addHandler(_default_handler)
|
||||
library_root_logger.setLevel(_get_default_logging_level())
|
||||
library_root_logger.propagate = False
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> "_Logger":
|
||||
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
|
||||
if name is None:
|
||||
name = _get_library_name()
|
||||
|
||||
_configure_library_root_logger()
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def add_handler(handler: "logging.Handler") -> None:
|
||||
r"""Add a handler to the root logger."""
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().addHandler(handler)
|
||||
|
||||
|
||||
def remove_handler(handler: logging.Handler) -> None:
|
||||
r"""Remove a handler to the root logger."""
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().removeHandler(handler)
|
||||
|
||||
|
||||
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
self.info(*args, **kwargs)
|
||||
|
||||
|
||||
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
logging.Logger.info_rank0 = info_rank0
|
||||
logging.Logger.warning_rank0 = warning_rank0
|
||||
logging.Logger.warning_rank0_once = warning_rank0_once
|
||||
86
src/llamafactory/v1/utils/plugin.py
Normal file
86
src/llamafactory/v1/utils/plugin.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from . import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
"""Base class for plugins.
|
||||
|
||||
A plugin is a callable object that can be registered and called by name.
|
||||
"""
|
||||
|
||||
_registry: dict[str, Callable] = {}
|
||||
|
||||
def __init__(self, name: Optional[str] = None):
|
||||
"""Initialize the plugin with a name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the plugin.
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def register(self) -> Callable:
|
||||
"""Decorator to register a function as a plugin.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
@PrintPlugin("hello").register()
|
||||
def print_hello():
|
||||
print("Hello world!")
|
||||
```
|
||||
"""
|
||||
if self.name is None:
|
||||
raise ValueError("Plugin name is not specified.")
|
||||
|
||||
if self.name in self._registry:
|
||||
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self._registry[self.name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Callable:
|
||||
"""Call the registered function with the given arguments.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
PrintPlugin("hello")()
|
||||
```
|
||||
"""
|
||||
if self.name not in self._registry:
|
||||
raise ValueError(f"Plugin {self.name} is not registered.")
|
||||
|
||||
return self._registry[self.name](*args, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
class PrintPlugin(BasePlugin):
|
||||
pass
|
||||
|
||||
@PrintPlugin("hello").register
|
||||
def print_hello():
|
||||
print("Hello world!")
|
||||
|
||||
PrintPlugin("hello")()
|
||||
@@ -19,23 +19,27 @@ from typing_extensions import NotRequired
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import transformers
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel
|
||||
|
||||
Tensor = torch.Tensor
|
||||
TensorLike = Union[int, float, list[int], list[float], np.ndarray, Tensor]
|
||||
TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]
|
||||
HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
|
||||
DataCollator = transformers.DataCollator
|
||||
DataLoader = torch.utils.data.DataLoader
|
||||
HFConfig = transformers.PretrainedConfig
|
||||
HFModel = transformers.PreTrainedModel
|
||||
DistModel = torch.nn.parallel.DistributedDataParallel
|
||||
DistModel = Union[torch.nn.parallel.DistributedDataParallel, FullyShardedDataParallel]
|
||||
Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
|
||||
Optimizer = torch.optim.Optimizer
|
||||
Scheduler = torch.optim.lr_scheduler.LRScheduler
|
||||
else:
|
||||
Tensor = None
|
||||
TensorLike = None
|
||||
TorchDataset = None
|
||||
HFDataset = None
|
||||
DataCollator = None
|
||||
@@ -49,12 +53,10 @@ else:
|
||||
|
||||
|
||||
class DatasetInfo(TypedDict, total=False):
|
||||
hf_hub_url: NotRequired[str]
|
||||
"""HF hub dataset uri."""
|
||||
file_name: NotRequired[str]
|
||||
path: str
|
||||
"""Local file path."""
|
||||
dataset_dir: NotRequired[str]
|
||||
"""Dataset directory, default to args.dataset_dir."""
|
||||
source: NotRequired[Literal["hf_hub", "ms_hub", "local"]]
|
||||
"""Dataset source, default to "hf_hub"."""
|
||||
split: NotRequired[str]
|
||||
"""Dataset split, default to "train"."""
|
||||
converter: NotRequired[str]
|
||||
@@ -68,12 +70,12 @@ class DatasetInfo(TypedDict, total=False):
|
||||
|
||||
|
||||
class Content(TypedDict):
|
||||
type: Literal["text", "tools", "reasoning", "tool_calls", "image_url"]
|
||||
type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
|
||||
value: str
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
role: Literal["system", "user", "assistant"]
|
||||
role: Literal["system", "user", "assistant", "tool"]
|
||||
content: list[Content]
|
||||
loss_weight: float
|
||||
|
||||
@@ -12,25 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
import os
|
||||
|
||||
from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy
|
||||
|
||||
|
||||
class DeviceMeshManager:
|
||||
"""Device mesh manager."""
|
||||
|
||||
_instance: Optional["DeviceMeshManager"] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls) -> "DeviceMeshManager":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.device_mesh: Optional[DeviceMesh] = None
|
||||
self._initialized = True
|
||||
def test_distributed_interface():
|
||||
DistributedInterface(DistributedStrategy())
|
||||
assert DistributedInterface.rank == int(os.getenv("RANK", "0"))
|
||||
assert DistributedInterface.world_size == int(os.getenv("WORLD_SIZE", "1"))
|
||||
@@ -19,7 +19,7 @@ from datasets import load_dataset
|
||||
|
||||
from llamafactory.v1.config.data_args import DataArguments
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
from llamafactory.v1.plugins.data_plugins.converter import get_converter
|
||||
from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
@@ -49,99 +49,27 @@ def test_alpaca_converter(num_samples: int):
|
||||
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
|
||||
|
||||
|
||||
def test_sharegpt_converter_invalid():
|
||||
def test_sharegpt_converter():
|
||||
example = {
|
||||
"conversations": [
|
||||
{
|
||||
"from": "system",
|
||||
"value": "Processes historical market data to generate trading signals "
|
||||
"based on specified technical indicators.",
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
|
||||
"Could you proceed with these function calls to assist me with the task?",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
|
||||
"'name': 'backtest_trading_signals'}```\n",
|
||||
},
|
||||
{
|
||||
"from": "tool",
|
||||
"value": '<tool id="D2">\n{"analysis": {"RSI_signals": [{"date": "2025-01-10", '
|
||||
'"symbol": "AAPL", "signal": "Buy"}]}}}\n</tool>\n',
|
||||
},
|
||||
{"from": "system", "value": "System"},
|
||||
{"from": "human", "value": "User"},
|
||||
{"from": "gpt", "value": "Assistant"},
|
||||
]
|
||||
}
|
||||
dataset_converter = get_converter("sharegpt")
|
||||
assert dataset_converter(example) == {"messages": []}
|
||||
|
||||
|
||||
def test_sharegpt_converter_valid():
|
||||
example = {
|
||||
"conversations": [
|
||||
{
|
||||
"from": "system",
|
||||
"value": "Processes historical market data to generate trading signals based on "
|
||||
"specified technical indicators.",
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
|
||||
"Could you proceed with these function calls to assist me with the task?",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
|
||||
"'name': 'backtest_trading_signals'}```\n",
|
||||
},
|
||||
]
|
||||
}
|
||||
dataset_converter = get_converter("sharegpt")
|
||||
expected_data = {
|
||||
"messages": [
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"value": "Processes historical market data to generate trading signals based on "
|
||||
"specified technical indicators.",
|
||||
}
|
||||
],
|
||||
"loss_weight": 0.0,
|
||||
"role": "system",
|
||||
},
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
|
||||
"Could you proceed with these function calls to assist me with the task?",
|
||||
}
|
||||
],
|
||||
"loss_weight": 0.0,
|
||||
"role": "user",
|
||||
},
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
|
||||
"'name': 'backtest_trading_signals'}```\n",
|
||||
}
|
||||
],
|
||||
"loss_weight": 1.0,
|
||||
"role": "assistant",
|
||||
},
|
||||
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
|
||||
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
|
||||
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
|
||||
]
|
||||
}
|
||||
assert dataset_converter(example) == expected_data
|
||||
assert DataConverterPlugin("sharegpt")(example) == expected_data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_pair_converter(num_samples: int):
|
||||
data_args = DataArguments(dataset="frozenleaves/tiny-dpo/dataset_info.yaml")
|
||||
data_args = DataArguments(dataset="llamafactory/tiny-preference-dataset/dataset_info.yaml")
|
||||
data_engine = DataEngine(data_args)
|
||||
original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
|
||||
indexes = random.choices(range(len(data_engine)), k=num_samples)
|
||||
@@ -189,6 +117,5 @@ def test_pair_converter(num_samples: int):
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_alpaca_converter(1)
|
||||
test_sharegpt_converter_invalid()
|
||||
test_sharegpt_converter_valid()
|
||||
test_sharegpt_converter()
|
||||
test_pair_converter(1)
|
||||
|
||||
@@ -17,10 +17,13 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from llamafactory.v1.accelerator.helper import get_current_accelerator
|
||||
|
||||
|
||||
class TestKernelPlugin(unittest.TestCase):
|
||||
@patch("torch.accelerator.current_accelerator")
|
||||
def test_apply_kernel(self, mock_get_accelerator):
|
||||
get_current_accelerator.cache_clear()
|
||||
mock_device = MagicMock()
|
||||
mock_device.type = "npu"
|
||||
mock_get_accelerator.return_value = mock_device
|
||||
@@ -47,6 +50,7 @@ class TestKernelPlugin(unittest.TestCase):
|
||||
class Test_Use_V1_Kernels(unittest.TestCase):
|
||||
@patch("torch.accelerator.current_accelerator")
|
||||
def test_use_v1_kernels(self, mock_get_accelerator):
|
||||
get_current_accelerator.cache_clear()
|
||||
mock_device = MagicMock()
|
||||
mock_device.type = "npu"
|
||||
mock_get_accelerator.return_value = mock_device
|
||||
|
||||
Reference in New Issue
Block a user