Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-29 18:20:35 +08:00
[v1] add v1 launcher (#9236)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
src/llamafactory/v1/config/data_args.py (new file)
@@ -0,0 +1,33 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataArguments:
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the dataset."},
    )
    dataset_dir: str = field(
        default="data",
        metadata={"help": "Path to the folder containing the datasets."},
    )
    cutoff_len: int = field(
        default=2048,
        metadata={"help": "Cutoff length for the dataset."},
    )
src/llamafactory/v1/config/model_args.py (new file)
@@ -0,0 +1,27 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field


@dataclass
class ModelArguments:
    model: str = field(
        metadata={"help": "Path to the model or model identifier from Hugging Face."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Trust remote code from Hugging Face."},
    )
src/llamafactory/v1/config/parser.py (new file)
@@ -0,0 +1,63 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import sys
from pathlib import Path
from typing import Any, Optional, Union

from omegaconf import OmegaConf
from transformers import HfArgumentParser

from ...extras.misc import is_env_enabled
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments


def get_args(
    args: Optional[Union[dict[str, Any], list[str]]] = None,
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
    """Parse arguments from command line or config file."""
    parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")

    if args is None:
        if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
            override_config = OmegaConf.from_cli(sys.argv[2:])
            dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
        elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
            override_config = OmegaConf.from_cli(sys.argv[2:])
            dict_config = OmegaConf.create(json.loads(Path(sys.argv[1]).absolute().read_text()))
            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
        else:  # list of strings
            args = sys.argv[1:]

    if isinstance(args, dict):
        (*parsed_args,) = parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
    else:
        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args, return_remaining_strings=True)
        if unknown_args and not allow_extra_keys:
            print(parser.format_help())
            print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

    return tuple(parsed_args)


if __name__ == "__main__":
    print(get_args())
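For orientation, a minimal sketch of the dict path through get_args; the model id, dataset name, and output path below are made-up values, not anything defined by this commit. A dict is routed to HfArgumentParser.parse_dict and the four dataclasses come back in the order (data, model, training, sample):

from llamafactory.v1.config.parser import get_args

# Hypothetical values, chosen only to exercise every dataclass.
data_args, model_args, training_args, sample_args = get_args(
    {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "dataset": "alpaca_en_demo",
        "output_dir": "outputs/sft-demo",
        "learning_rate": 5e-5,
        "bf16": True,
        "max_new_tokens": 256,
    }
)
print(model_args.model, training_args.learning_rate, sample_args.max_new_tokens)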
src/llamafactory/v1/config/sample_args.py (new file)
@@ -0,0 +1,24 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field


@dataclass
class SampleArguments:
    max_new_tokens: int = field(
        default=128,
        metadata={"help": "Maximum number of new tokens to generate."},
    )
src/llamafactory/v1/config/training_args.py (new file)
@@ -0,0 +1,40 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field


@dataclass
class TrainingArguments:
    output_dir: str = field(
        default="",
        metadata={"help": "Path to the output directory."},
    )
    micro_batch_size: int = field(
        default=1,
        metadata={"help": "Micro batch size for training."},
    )
    global_batch_size: int = field(
        default=1,
        metadata={"help": "Global batch size for training."},
    )
    learning_rate: float = field(
        default=1e-4,
        metadata={"help": "Learning rate for training."},
    )
    bf16: bool = field(
        default=False,
        metadata={"help": "Use bf16 for training."},
    )
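The pair of batch-size fields implies gradient accumulation; the helper below is only a sketch of the usual arithmetic. This commit does not implement it, and world_size is an assumed data-parallel degree, not a field of TrainingArguments:

def grad_accum_steps(global_batch_size: int, micro_batch_size: int, world_size: int = 1) -> int:
    # One optimizer step should see `global_batch_size` samples, so accumulate
    # global / (micro * world_size) forward/backward passes before stepping.
    per_pass = micro_batch_size * world_size
    if global_batch_size % per_pass != 0:
        raise ValueError(f"global_batch_size={global_batch_size} is not divisible by {per_pass}.")
    return global_batch_size // per_pass


print(grad_accum_steps(global_batch_size=64, micro_batch_size=4, world_size=2))  # 8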
src/llamafactory/v1/core/base_trainer.py (new file)
@@ -0,0 +1,35 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..config.training_args import TrainingArguments
from ..extras.types import DataLoader, Model, Processor


class BaseTrainer:
    def __init__(
        self,
        args: TrainingArguments,
        model: Model,
        processor: Processor,
        data_loader: DataLoader,
    ) -> None:
        self.args = args
        self.model = model
        self.processor = processor
        self.data_loader = data_loader
        self.optimizer = None
        self.lr_scheduler = None

    def fit(self) -> None:
        pass
src/llamafactory/v1/core/chat_sampler.py (new file)
@@ -0,0 +1,20 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..config.sample_args import SampleArguments


class ChatSampler:
    def __init__(self, sample_args: SampleArguments) -> None:
        self.args = sample_args
src/llamafactory/v1/core/data_engine.py (new file)
@@ -0,0 +1,75 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from ..config.data_args import DataArguments
from ..extras.types import DataLoader, Dataset, Processor


class DataCollator:
    def __init__(self, processor: Processor) -> None:
        self.processor = processor


class DatasetPathMixin:
    args: DataArguments

    def _abspath(self, path: str) -> str:
        return os.path.abspath(os.path.expanduser(os.path.join(self.args.dataset_dir, path)))

    def _exists(self, path: str) -> bool:
        return os.path.exists(self._abspath(path))

    def _isfile(self, path: str) -> bool:
        return os.path.isfile(self._abspath(path))


class DataEngine(DatasetPathMixin):
    def __init__(self, data_args: DataArguments) -> None:
        self.args = data_args
        self.datasets: dict[str, Dataset] = {}
        dataset_info = self.get_dataset_info()
        self.load_dataset(dataset_info)

    def get_dataset_info(self) -> dict:
        """Get dataset info from dataset path.

        Returns:
            dict: Dataset info.
        """
        if self.args.dataset.endswith(".yaml") and self._isfile(self.args.dataset):  # local file
            return OmegaConf.load(self._abspath(self.args.dataset))
        elif self.args.dataset.endswith(".yaml"):  # hf hub uri
            repo_id, filename = os.path.split(self.args.dataset)
            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
            return OmegaConf.load(filepath)
        elif self._exists(self.args.dataset):  # local file(s)
            return {"default": {"file_name": self.args.dataset}}
        else:  # hf hub dataset
            return {"default": {"hf_hub_url": self.args.dataset}}

    def load_dataset(self, dataset_info: dict) -> None:
        for key, value in dataset_info.items():
            if "hf_hub_url" in value:
                dataset_info[key] = load_dataset(value["hf_hub_url"])
            elif "file_name" in value:
                dataset_info[key] = load_dataset(value["file_name"])

    def get_data_loader(self, processor: Processor) -> DataLoader:
        pass
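To make the branching in get_dataset_info concrete, a small sketch of the four cases; every path, repo id, and key below is illustrative, not shipped with the commit:

from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine

# 1. "dataset_info.yaml" that resolves to data/dataset_info.yaml     -> OmegaConf.load on the local file
# 2. "some-org/some-repo/dataset_info.yaml" with no such local file  -> hf_hub_download from the dataset repo
# 3. "alpaca_demo.json" that resolves to data/alpaca_demo.json       -> {"default": {"file_name": "alpaca_demo.json"}}
# 4. anything else, e.g. "tatsu-lab/alpaca"                          -> {"default": {"hf_hub_url": "tatsu-lab/alpaca"}}
engine = DataEngine(DataArguments(dataset="tatsu-lab/alpaca"))  # case 4: downloads the hub dataset on init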
src/llamafactory/v1/core/model_engine.py (new file)
@@ -0,0 +1,27 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..config.model_args import ModelArguments
from ..extras.types import Model, Processor


class ModelEngine:
    def __init__(self, model_args: ModelArguments) -> None:
        self.args = model_args

    def get_model(self) -> Model:
        pass

    def get_processor(self) -> Processor:
        pass
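Both methods are stubs in this commit. One plausible shape for them, assuming they will eventually wrap the transformers Auto classes and honor trust_remote_code; this is an assumption, not the committed behavior:

from transformers import AutoModelForCausalLM, AutoTokenizer

from llamafactory.v1.config.model_args import ModelArguments


def load_model_and_processor(args: ModelArguments):
    # Hypothetical helper mirroring what get_model/get_processor might do.
    processor = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
    model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
    return model, processor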
src/llamafactory/v1/extras/types.py (new file)
@@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Union


if TYPE_CHECKING:
    from datasets import Dataset as HFDataset
    from datasets import IterableDataset
    from torch.utils.data import DataLoader as TorchDataLoader
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

    Dataset = Union[HFDataset, IterableDataset]
    DataLoader = TorchDataLoader
    Model = PreTrainedModel
    Processor = Union[PreTrainedTokenizer, ProcessorMixin]
else:
    Dataset = None
    DataLoader = None
    Model = None
    Processor = None
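These names exist only for static type checkers; at runtime they are plain None, so annotating with them never imports torch, transformers, or datasets. A small sketch of the pattern in use (the function itself is illustrative, not part of the commit):

from llamafactory.v1.extras.types import Model


def count_parameters(model: Model) -> int:
    # At runtime `Model` is None, which Python accepts as an annotation value;
    # mypy/pyright resolve it to PreTrainedModel via the TYPE_CHECKING branch.
    return sum(p.numel() for p in model.parameters())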
@@ -12,22 +12,55 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-def run_train():
-    raise NotImplementedError("Please use `llamafactory-cli sft` or `llamafactory-cli rm`.")
-
-
-def run_chat():
-    from llamafactory.v1.core.chat_sampler import Sampler
-
-    Sampler().cli()
-
-
-def run_sft():
-    from llamafactory.v1.train.sft import SFTTrainer
-
-    SFTTrainer().run()
+import sys
+
+from ..extras.env import VERSION, print_env
+
+
+USAGE = (
+    "-" * 70
+    + "\n"
+    + "| Usage: |\n"
+    + "| llamafactory-cli sft -h: train models |\n"
+    + "| llamafactory-cli version: show version info |\n"
+    + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
+    + "-" * 70
+)
+
+WELCOME = (
+    "-" * 58
+    + "\n"
+    + f"| Welcome to LLaMA Factory, version {VERSION}"
+    + " " * (21 - len(VERSION))
+    + "|\n|"
+    + " " * 56
+    + "|\n"
+    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+    + "-" * 58
+)
+
+
+def launch():
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
+
+    if command == "sft":
+        from .trainers.sft_trainer import run_sft
+
+        run_sft()
+
+    elif command == "env":
+        print_env()
+
+    elif command == "version":
+        print(WELCOME)
+
+    elif command == "help":
+        print(USAGE)
+
+    else:
+        print(f"Unknown command: {command}.\n{USAGE}")


 if __name__ == "__main__":
-    run_train()
+    pass
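A quick sketch of driving the dispatcher; the module path in the import is a guess at where this modified file lives (the capture does not show it), and the console-script wiring is outside this hunk:

import sys

from llamafactory.v1.launcher import launch  # hypothetical module path for the modified file above

sys.argv = ["llamafactory-cli", "version"]  # equivalent to running `llamafactory-cli version`
launch()  # pops the command, then prints the WELCOME banner

sys.argv = ["llamafactory-cli"]  # no command at all
launch()  # falls back to "help" and prints USAGE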
src/llamafactory/v1/plugins/data_plugins/filter.py (new empty file)
src/llamafactory/v1/plugins/data_plugins/template.py (new file)
@@ -0,0 +1,26 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass


@dataclass
class Template:
    user_template: str
    assistant_template: str
    system_template: str

    def render_message(self, message: "dict[str, str]") -> str:
        return self.user_template.format(**message)
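For illustration, the dataclass in use; the ChatML-style template strings below are made up, and note that render_message currently formats every message with user_template only:

from llamafactory.v1.plugins.data_plugins.template import Template

# Hypothetical template strings, not shipped by the commit.
template = Template(
    user_template="<|im_start|>user\n{content}<|im_end|>\n",
    assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n",
    system_template="<|im_start|>system\n{content}<|im_end|>\n",
)
print(template.render_message({"content": "Hello!"}))
# <|im_start|>user
# Hello!
# <|im_end|>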
src/llamafactory/v1/trainers/sft_trainer.py (new file)
@@ -0,0 +1,34 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ..config.parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine


class SFTTrainer(BaseTrainer):
    pass


def run_sft():
    data_args, model_args, training_args, _ = get_args()  # get_args returns (data, model, training, sample) args
    model_engine = ModelEngine(model_args)
    data_engine = DataEngine(data_args)
    model = model_engine.get_model()
    processor = model_engine.get_processor()
    data_loader = data_engine.get_data_loader(processor)
    trainer = SFTTrainer(training_args, model, processor, data_loader)
    trainer.fit()
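A hedged end-to-end sketch: write a throwaway config and invoke run_sft() the way launch() would for an `sft` command with a YAML path. All values are illustrative, and since get_model, get_processor, and get_data_loader are still stubs in this commit, fit() receives None placeholders and returns immediately:

import sys
from pathlib import Path

from llamafactory.v1.trainers.sft_trainer import run_sft

Path("demo.yaml").write_text(
    "model: Qwen/Qwen2.5-0.5B-Instruct\n"
    "dataset: tatsu-lab/alpaca\n"
    "output_dir: outputs/sft-demo\n"
    "micro_batch_size: 1\n"
    "global_batch_size: 8\n"
    "learning_rate: 1.0e-4\n"
    "bf16: true\n"
)

sys.argv = ["llamafactory-cli", "demo.yaml"]  # get_args() reads sys.argv[1] when called with no args
run_sft()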