mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 07:42:49 +08:00
[v1] add v1 launcher (#9236)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
95b7188090
commit
10146029ba
4
.gitignore
vendored
4
.gitignore
vendored
@ -169,8 +169,8 @@ uv.lock
|
|||||||
hf_cache/
|
hf_cache/
|
||||||
ms_cache/
|
ms_cache/
|
||||||
om_cache/
|
om_cache/
|
||||||
cache/
|
llamaboard_cache/
|
||||||
config/
|
llamaboard_config/
|
||||||
saves/
|
saves/
|
||||||
output/
|
output/
|
||||||
wandb/
|
wandb/
|
||||||
|
@ -1,82 +0,0 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
|
||||||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
|
|
||||||
|
|
||||||
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
|
||||||
|
|
||||||
_DESCRIPTION = "BELLE multiturn chat dataset."
|
|
||||||
|
|
||||||
_CITATION = """\
|
|
||||||
@article{belle2023exploring,
|
|
||||||
title={Exploring the Impact of Instruction Data Scaling on Large Language Models},
|
|
||||||
author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li},
|
|
||||||
journal={arXiv preprint arXiv:2303.14742},
|
|
||||||
year={2023}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
|
|
||||||
_LICENSE = "gpl-3.0"
|
|
||||||
_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
|
|
||||||
|
|
||||||
|
|
||||||
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
|
||||||
VERSION = datasets.Version("0.0.0")
|
|
||||||
|
|
||||||
def _info(self):
|
|
||||||
features = datasets.Features(
|
|
||||||
{"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
|
|
||||||
)
|
|
||||||
return datasets.DatasetInfo(
|
|
||||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
|
||||||
)
|
|
||||||
|
|
||||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
|
||||||
file_path = dl_manager.download(_URL)
|
|
||||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
|
||||||
|
|
||||||
def _generate_examples(self, filepath: str):
|
|
||||||
with open(filepath, encoding="utf-8") as f:
|
|
||||||
for key, row in enumerate(f):
|
|
||||||
data = json.loads(row)
|
|
||||||
conversations = []
|
|
||||||
prompt = data["instruction"].strip()
|
|
||||||
response = data["output"].strip()
|
|
||||||
|
|
||||||
assist_idx = prompt.rfind("Assistant:")
|
|
||||||
human_idx = prompt.rfind("Human:")
|
|
||||||
query = prompt[human_idx + 6 : assist_idx].strip()
|
|
||||||
prompt = prompt[:human_idx].strip()
|
|
||||||
conversations.insert(0, {"from": "gpt", "value": response})
|
|
||||||
conversations.insert(0, {"from": "human", "value": query})
|
|
||||||
|
|
||||||
while prompt.rfind("Assistant:") != -1:
|
|
||||||
assist_idx = prompt.rfind("Assistant:")
|
|
||||||
human_idx = prompt.rfind("Human:")
|
|
||||||
if human_idx != -1:
|
|
||||||
old_query = prompt[human_idx + 6 : assist_idx].strip()
|
|
||||||
old_resp = prompt[assist_idx + 10 :].strip()
|
|
||||||
conversations.insert(0, {"from": "gpt", "value": old_resp})
|
|
||||||
conversations.insert(0, {"from": "human", "value": old_query})
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
prompt = prompt[:human_idx].strip()
|
|
||||||
|
|
||||||
yield key, {"conversations": conversations}
|
|
@ -143,14 +143,6 @@
|
|||||||
"hf_hub_url": "BelleGroup/school_math_0.25M",
|
"hf_hub_url": "BelleGroup/school_math_0.25M",
|
||||||
"ms_hub_url": "AI-ModelScope/school_math_0.25M"
|
"ms_hub_url": "AI-ModelScope/school_math_0.25M"
|
||||||
},
|
},
|
||||||
"belle_multiturn": {
|
|
||||||
"script_url": "belle_multiturn",
|
|
||||||
"formatting": "sharegpt"
|
|
||||||
},
|
|
||||||
"ultra_chat": {
|
|
||||||
"script_url": "ultra_chat",
|
|
||||||
"formatting": "sharegpt"
|
|
||||||
},
|
|
||||||
"open_platypus": {
|
"open_platypus": {
|
||||||
"hf_hub_url": "garage-bAInd/Open-Platypus",
|
"hf_hub_url": "garage-bAInd/Open-Platypus",
|
||||||
"ms_hub_url": "AI-ModelScope/Open-Platypus"
|
"ms_hub_url": "AI-ModelScope/Open-Platypus"
|
||||||
@ -583,16 +575,6 @@
|
|||||||
"system": "system"
|
"system": "system"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"hh_rlhf_en": {
|
|
||||||
"script_url": "hh_rlhf_en",
|
|
||||||
"ranking": true,
|
|
||||||
"columns": {
|
|
||||||
"prompt": "instruction",
|
|
||||||
"chosen": "chosen",
|
|
||||||
"rejected": "rejected",
|
|
||||||
"history": "history"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nectar_rm": {
|
"nectar_rm": {
|
||||||
"hf_hub_url": "AstraMindAI/RLAIF-Nectar",
|
"hf_hub_url": "AstraMindAI/RLAIF-Nectar",
|
||||||
"ms_hub_url": "AI-ModelScope/RLAIF-Nectar",
|
"ms_hub_url": "AI-ModelScope/RLAIF-Nectar",
|
||||||
|
@ -1,98 +0,0 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
|
||||||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
|
|
||||||
|
|
||||||
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
|
||||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
|
||||||
_CITATION = ""
|
|
||||||
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
|
|
||||||
_LICENSE = "mit"
|
|
||||||
_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
|
|
||||||
_URLS = {
|
|
||||||
"train": [
|
|
||||||
_URL + "harmless-base/train.jsonl.gz",
|
|
||||||
_URL + "helpful-base/train.jsonl.gz",
|
|
||||||
_URL + "helpful-online/train.jsonl.gz",
|
|
||||||
_URL + "helpful-rejection-sampled/train.jsonl.gz",
|
|
||||||
],
|
|
||||||
"test": [
|
|
||||||
_URL + "harmless-base/test.jsonl.gz",
|
|
||||||
_URL + "helpful-base/test.jsonl.gz",
|
|
||||||
_URL + "helpful-online/test.jsonl.gz",
|
|
||||||
_URL + "helpful-rejection-sampled/test.jsonl.gz",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
|
||||||
VERSION = datasets.Version("0.0.0")
|
|
||||||
|
|
||||||
def _info(self) -> datasets.DatasetInfo:
|
|
||||||
features = datasets.Features(
|
|
||||||
{
|
|
||||||
"instruction": datasets.Value("string"),
|
|
||||||
"chosen": datasets.Value("string"),
|
|
||||||
"rejected": datasets.Value("string"),
|
|
||||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return datasets.DatasetInfo(
|
|
||||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
|
||||||
)
|
|
||||||
|
|
||||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
|
||||||
file_path = dl_manager.download_and_extract(_URLS)
|
|
||||||
return [
|
|
||||||
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_path["train"]}),
|
|
||||||
datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
|
|
||||||
]
|
|
||||||
|
|
||||||
def _generate_examples(self, filepaths: list[str]):
|
|
||||||
key = 0
|
|
||||||
for filepath in filepaths:
|
|
||||||
with open(filepath, encoding="utf-8") as f:
|
|
||||||
for row in f:
|
|
||||||
data = json.loads(row)
|
|
||||||
chosen = data["chosen"]
|
|
||||||
rejected = data["rejected"]
|
|
||||||
|
|
||||||
assist_idx = rejected.rfind("\n\nAssistant: ")
|
|
||||||
r_reject = rejected[assist_idx + 13 :].strip()
|
|
||||||
assist_idx = chosen.rfind("\n\nAssistant: ")
|
|
||||||
r_accept = chosen[assist_idx + 13 :].strip()
|
|
||||||
|
|
||||||
human_idx = chosen.rfind("\n\nHuman: ")
|
|
||||||
query = chosen[human_idx + 9 : assist_idx].strip()
|
|
||||||
prompt = chosen[:human_idx]
|
|
||||||
history = []
|
|
||||||
|
|
||||||
while prompt.rfind("\n\nAssistant: ") != -1:
|
|
||||||
assist_idx = prompt.rfind("\n\nAssistant: ")
|
|
||||||
human_idx = prompt.rfind("\n\nHuman: ")
|
|
||||||
if human_idx != -1:
|
|
||||||
old_query = prompt[human_idx + 9 : assist_idx].strip()
|
|
||||||
old_resp = prompt[assist_idx + 13 :].strip()
|
|
||||||
history.insert(0, (old_query, old_resp))
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
prompt = prompt[:human_idx]
|
|
||||||
|
|
||||||
yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
|
|
||||||
key += 1
|
|
@ -1,74 +0,0 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
|
||||||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
|
|
||||||
|
|
||||||
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
|
||||||
|
|
||||||
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
|
|
||||||
|
|
||||||
_CITATION = """\
|
|
||||||
@misc{UltraChat,
|
|
||||||
author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and others},
|
|
||||||
title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
|
|
||||||
year = {2023},
|
|
||||||
publisher = {GitHub},
|
|
||||||
journal = {GitHub repository},
|
|
||||||
howpublished = {\\url{https://github.com/thunlp/ultrachat}},
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
|
|
||||||
_LICENSE = "cc-by-nc-4.0"
|
|
||||||
_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"
|
|
||||||
|
|
||||||
|
|
||||||
class UltraChat(datasets.GeneratorBasedBuilder):
|
|
||||||
VERSION = datasets.Version("0.0.0")
|
|
||||||
|
|
||||||
def _info(self):
|
|
||||||
features = datasets.Features(
|
|
||||||
{"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
|
|
||||||
)
|
|
||||||
return datasets.DatasetInfo(
|
|
||||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
|
||||||
)
|
|
||||||
|
|
||||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
|
||||||
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
|
|
||||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]
|
|
||||||
|
|
||||||
def _generate_examples(self, filepaths: list[str]):
|
|
||||||
for filepath in filepaths:
|
|
||||||
with open(filepath, encoding="utf-8") as f:
|
|
||||||
for row in f:
|
|
||||||
try:
|
|
||||||
data = json.loads(row)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
key: int = data["id"]
|
|
||||||
content: list[str] = data["data"]
|
|
||||||
if len(content) % 2 == 1:
|
|
||||||
content.pop(-1)
|
|
||||||
if len(content) < 2:
|
|
||||||
continue
|
|
||||||
conversations = [
|
|
||||||
{"from": "human" if i % 2 == 0 else "gpt", "value": content[i]} for i in range(len(content))
|
|
||||||
]
|
|
||||||
yield key, {"conversations": conversations}
|
|
8
data/v1_sft_demo.yaml
Normal file
8
data/v1_sft_demo.yaml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
identity:
|
||||||
|
file_name: identity.json
|
||||||
|
converter: alpaca
|
||||||
|
alpaca_en_demo:
|
||||||
|
file_name: alpaca_en_demo.json
|
||||||
|
dataset_dir: ~/data
|
||||||
|
converter: alpaca
|
||||||
|
num_samples: 500
|
@ -12,145 +12,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from copy import deepcopy
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
|
|
||||||
USAGE = (
|
|
||||||
"-" * 70
|
|
||||||
+ "\n"
|
|
||||||
+ "| Usage: |\n"
|
|
||||||
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
|
|
||||||
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
|
|
||||||
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
|
|
||||||
+ "| llamafactory-cli train -h: train models |\n"
|
|
||||||
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
|
|
||||||
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
|
|
||||||
+ "| llamafactory-cli env: show environment info |\n"
|
|
||||||
+ "| llamafactory-cli version: show version info |\n"
|
|
||||||
+ "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
|
|
||||||
+ "-" * 70
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
from .extras import logging
|
from .extras.misc import is_env_enabled
|
||||||
from .extras.env import VERSION, print_env
|
|
||||||
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
|
|
||||||
|
|
||||||
if is_env_enabled("USE_V1"):
|
if is_env_enabled("USE_V1"):
|
||||||
from .v1 import launcher
|
from .v1 import launcher
|
||||||
else:
|
else:
|
||||||
from . import launcher
|
from . import launcher
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
launcher.launch()
|
||||||
|
|
||||||
WELCOME = (
|
|
||||||
"-" * 58
|
|
||||||
+ "\n"
|
|
||||||
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
|
||||||
+ " " * (21 - len(VERSION))
|
|
||||||
+ "|\n|"
|
|
||||||
+ " " * 56
|
|
||||||
+ "|\n"
|
|
||||||
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
|
|
||||||
+ "-" * 58
|
|
||||||
)
|
|
||||||
|
|
||||||
COMMAND_MAP = {
|
|
||||||
"api": launcher.run_api,
|
|
||||||
"chat": launcher.run_chat,
|
|
||||||
"env": print_env,
|
|
||||||
"eval": launcher.run_eval,
|
|
||||||
"export": launcher.export_model,
|
|
||||||
"train": launcher.run_exp,
|
|
||||||
"webchat": launcher.run_web_demo,
|
|
||||||
"webui": launcher.run_web_ui,
|
|
||||||
"version": partial(print, WELCOME),
|
|
||||||
"help": partial(print, USAGE),
|
|
||||||
}
|
|
||||||
|
|
||||||
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
|
|
||||||
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
|
|
||||||
# launch distributed training
|
|
||||||
nnodes = os.getenv("NNODES", "1")
|
|
||||||
node_rank = os.getenv("NODE_RANK", "0")
|
|
||||||
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
|
|
||||||
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
|
||||||
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
|
|
||||||
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
|
|
||||||
if int(nnodes) > 1:
|
|
||||||
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
|
|
||||||
|
|
||||||
# elastic launch support
|
|
||||||
max_restarts = os.getenv("MAX_RESTARTS", "0")
|
|
||||||
rdzv_id = os.getenv("RDZV_ID")
|
|
||||||
min_nnodes = os.getenv("MIN_NNODES")
|
|
||||||
max_nnodes = os.getenv("MAX_NNODES")
|
|
||||||
|
|
||||||
env = deepcopy(os.environ)
|
|
||||||
if is_env_enabled("OPTIM_TORCH", "1"):
|
|
||||||
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
|
|
||||||
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
||||||
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
|
||||||
|
|
||||||
if rdzv_id is not None:
|
|
||||||
# launch elastic job with fault tolerant support when possible
|
|
||||||
# see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
|
|
||||||
rdzv_nnodes = nnodes
|
|
||||||
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
|
|
||||||
if min_nnodes is not None and max_nnodes is not None:
|
|
||||||
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
|
|
||||||
|
|
||||||
process = subprocess.run(
|
|
||||||
(
|
|
||||||
"torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
|
|
||||||
"--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
|
|
||||||
"--max-restarts {max_restarts} {file_name} {args}"
|
|
||||||
)
|
|
||||||
.format(
|
|
||||||
rdzv_nnodes=rdzv_nnodes,
|
|
||||||
nproc_per_node=nproc_per_node,
|
|
||||||
rdzv_id=rdzv_id,
|
|
||||||
master_addr=master_addr,
|
|
||||||
master_port=master_port,
|
|
||||||
max_restarts=max_restarts,
|
|
||||||
file_name=launcher.__file__,
|
|
||||||
args=" ".join(sys.argv[1:]),
|
|
||||||
)
|
|
||||||
.split(),
|
|
||||||
env=env,
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# NOTE: DO NOT USE shell=True to avoid security risk
|
|
||||||
process = subprocess.run(
|
|
||||||
(
|
|
||||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
|
||||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
|
||||||
)
|
|
||||||
.format(
|
|
||||||
nnodes=nnodes,
|
|
||||||
node_rank=node_rank,
|
|
||||||
nproc_per_node=nproc_per_node,
|
|
||||||
master_addr=master_addr,
|
|
||||||
master_port=master_port,
|
|
||||||
file_name=launcher.__file__,
|
|
||||||
args=" ".join(sys.argv[1:]),
|
|
||||||
)
|
|
||||||
.split(),
|
|
||||||
env=env,
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
sys.exit(process.returncode)
|
|
||||||
elif command in COMMAND_MAP:
|
|
||||||
COMMAND_MAP[command]()
|
|
||||||
else:
|
|
||||||
print(f"Unknown command: {command}.\n{USAGE}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -16,6 +16,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
VERSION = "0.9.4.dev0"
|
VERSION = "0.9.4.dev0"
|
||||||
|
|
||||||
|
|
||||||
@ -28,20 +31,20 @@ def print_env() -> None:
|
|||||||
import peft
|
import peft
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import trl
|
|
||||||
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
||||||
|
|
||||||
info = {
|
info = OrderedDict(
|
||||||
"`llamafactory` version": VERSION,
|
{
|
||||||
"Platform": platform.platform(),
|
"`llamafactory` version": VERSION,
|
||||||
"Python version": platform.python_version(),
|
"Platform": platform.platform(),
|
||||||
"PyTorch version": torch.__version__,
|
"Python version": platform.python_version(),
|
||||||
"Transformers version": transformers.__version__,
|
"PyTorch version": torch.__version__,
|
||||||
"Datasets version": datasets.__version__,
|
"Transformers version": transformers.__version__,
|
||||||
"Accelerate version": accelerate.__version__,
|
"Datasets version": datasets.__version__,
|
||||||
"PEFT version": peft.__version__,
|
"Accelerate version": accelerate.__version__,
|
||||||
"TRL version": trl.__version__,
|
"PEFT version": peft.__version__,
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if is_torch_cuda_available():
|
if is_torch_cuda_available():
|
||||||
info["PyTorch version"] += " (GPU)"
|
info["PyTorch version"] += " (GPU)"
|
||||||
@ -54,6 +57,13 @@ def print_env() -> None:
|
|||||||
info["NPU type"] = torch.npu.get_device_name()
|
info["NPU type"] = torch.npu.get_device_name()
|
||||||
info["CANN version"] = torch.version.cann
|
info["CANN version"] = torch.version.cann
|
||||||
|
|
||||||
|
try:
|
||||||
|
import trl # type: ignore
|
||||||
|
|
||||||
|
info["TRL version"] = trl.__version__
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import deepspeed # type: ignore
|
import deepspeed # type: ignore
|
||||||
|
|
||||||
|
@ -12,46 +12,169 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
def run_api():
|
import subprocess
|
||||||
from llamafactory.api.app import run_api as _run_api
|
import sys
|
||||||
|
from copy import deepcopy
|
||||||
_run_api()
|
|
||||||
|
|
||||||
|
|
||||||
def run_chat():
|
USAGE = (
|
||||||
from llamafactory.chat.chat_model import run_chat as _run_chat
|
"-" * 70
|
||||||
|
+ "\n"
|
||||||
return _run_chat()
|
+ "| Usage: |\n"
|
||||||
|
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
|
||||||
|
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
|
||||||
|
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
|
||||||
|
+ "| llamafactory-cli train -h: train models |\n"
|
||||||
|
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
|
||||||
|
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
|
||||||
|
+ "| llamafactory-cli env: show environment info |\n"
|
||||||
|
+ "| llamafactory-cli version: show version info |\n"
|
||||||
|
+ "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
|
||||||
|
+ "-" * 70
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_eval():
|
def launch():
|
||||||
raise NotImplementedError("Evaluation will be deprecated in the future.")
|
from .extras import logging
|
||||||
|
from .extras.env import VERSION, print_env
|
||||||
|
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
WELCOME = (
|
||||||
|
"-" * 58
|
||||||
|
+ "\n"
|
||||||
|
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
||||||
|
+ " " * (21 - len(VERSION))
|
||||||
|
+ "|\n|"
|
||||||
|
+ " " * 56
|
||||||
|
+ "|\n"
|
||||||
|
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
|
||||||
|
+ "-" * 58
|
||||||
|
)
|
||||||
|
|
||||||
def export_model():
|
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
|
||||||
from llamafactory.train.tuner import export_model as _export_model
|
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
|
||||||
|
# launch distributed training
|
||||||
|
nnodes = os.getenv("NNODES", "1")
|
||||||
|
node_rank = os.getenv("NODE_RANK", "0")
|
||||||
|
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
|
||||||
|
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||||
|
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
|
||||||
|
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
|
||||||
|
if int(nnodes) > 1:
|
||||||
|
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
|
||||||
|
|
||||||
return _export_model()
|
# elastic launch support
|
||||||
|
max_restarts = os.getenv("MAX_RESTARTS", "0")
|
||||||
|
rdzv_id = os.getenv("RDZV_ID")
|
||||||
|
min_nnodes = os.getenv("MIN_NNODES")
|
||||||
|
max_nnodes = os.getenv("MAX_NNODES")
|
||||||
|
|
||||||
|
env = deepcopy(os.environ)
|
||||||
|
if is_env_enabled("OPTIM_TORCH", "1"):
|
||||||
|
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
|
||||||
|
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
|
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||||
|
|
||||||
def run_exp():
|
if rdzv_id is not None:
|
||||||
from llamafactory.train.tuner import run_exp as _run_exp
|
# launch elastic job with fault tolerant support when possible
|
||||||
|
# see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
|
||||||
|
rdzv_nnodes = nnodes
|
||||||
|
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
|
||||||
|
if min_nnodes is not None and max_nnodes is not None:
|
||||||
|
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
|
||||||
|
|
||||||
return _run_exp() # use absolute import
|
process = subprocess.run(
|
||||||
|
(
|
||||||
|
"torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
|
||||||
|
"--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
|
||||||
|
"--max-restarts {max_restarts} {file_name} {args}"
|
||||||
|
)
|
||||||
|
.format(
|
||||||
|
rdzv_nnodes=rdzv_nnodes,
|
||||||
|
nproc_per_node=nproc_per_node,
|
||||||
|
rdzv_id=rdzv_id,
|
||||||
|
master_addr=master_addr,
|
||||||
|
master_port=master_port,
|
||||||
|
max_restarts=max_restarts,
|
||||||
|
file_name=__file__,
|
||||||
|
args=" ".join(sys.argv[1:]),
|
||||||
|
)
|
||||||
|
.split(),
|
||||||
|
env=env,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# NOTE: DO NOT USE shell=True to avoid security risk
|
||||||
|
process = subprocess.run(
|
||||||
|
(
|
||||||
|
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||||
|
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||||
|
)
|
||||||
|
.format(
|
||||||
|
nnodes=nnodes,
|
||||||
|
node_rank=node_rank,
|
||||||
|
nproc_per_node=nproc_per_node,
|
||||||
|
master_addr=master_addr,
|
||||||
|
master_port=master_port,
|
||||||
|
file_name=__file__,
|
||||||
|
args=" ".join(sys.argv[1:]),
|
||||||
|
)
|
||||||
|
.split(),
|
||||||
|
env=env,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
sys.exit(process.returncode)
|
||||||
|
|
||||||
def run_web_demo():
|
elif command == "api":
|
||||||
from llamafactory.webui.interface import run_web_demo as _run_web_demo
|
from .api.app import run_api
|
||||||
|
|
||||||
return _run_web_demo()
|
run_api()
|
||||||
|
|
||||||
|
elif command == "chat":
|
||||||
|
from .chat.chat_model import run_chat
|
||||||
|
|
||||||
def run_web_ui():
|
run_chat()
|
||||||
from llamafactory.webui.interface import run_web_ui as _run_web_ui
|
|
||||||
|
|
||||||
return _run_web_ui()
|
elif command == "eval":
|
||||||
|
raise NotImplementedError("Evaluation will be deprecated in the future.")
|
||||||
|
|
||||||
|
elif command == "export":
|
||||||
|
from .train.tuner import export_model
|
||||||
|
|
||||||
|
export_model()
|
||||||
|
|
||||||
|
elif command == "train":
|
||||||
|
from .train.tuner import run_exp
|
||||||
|
|
||||||
|
run_exp()
|
||||||
|
|
||||||
|
elif command == "webchat":
|
||||||
|
from .webui.interface import run_web_demo
|
||||||
|
|
||||||
|
run_web_demo()
|
||||||
|
|
||||||
|
elif command == "webui":
|
||||||
|
from .webui.interface import run_web_ui
|
||||||
|
|
||||||
|
run_web_ui()
|
||||||
|
|
||||||
|
elif command == "env":
|
||||||
|
print_env()
|
||||||
|
|
||||||
|
elif command == "version":
|
||||||
|
print(WELCOME)
|
||||||
|
|
||||||
|
elif command == "help":
|
||||||
|
print(USAGE)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"Unknown command: {command}.\n{USAGE}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
from llamafactory.train.tuner import run_exp # use absolute import
|
||||||
|
|
||||||
run_exp()
|
run_exp()
|
||||||
|
33
src/llamafactory/v1/config/data_args.py
Normal file
33
src/llamafactory/v1/config/data_args.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataArguments:
|
||||||
|
dataset: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to the dataset."},
|
||||||
|
)
|
||||||
|
dataset_dir: str = field(
|
||||||
|
default="data",
|
||||||
|
metadata={"help": "Path to the folder containing the datasets."},
|
||||||
|
)
|
||||||
|
cutoff_len: int = field(
|
||||||
|
default=2048,
|
||||||
|
metadata={"help": "Cutoff length for the dataset."},
|
||||||
|
)
|
27
src/llamafactory/v1/config/model_args.py
Normal file
27
src/llamafactory/v1/config/model_args.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArguments:
|
||||||
|
model: str = field(
|
||||||
|
metadata={"help": "Path to the model or model identifier from Hugging Face."},
|
||||||
|
)
|
||||||
|
trust_remote_code: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Trust remote code from Hugging Face."},
|
||||||
|
)
|
63
src/llamafactory/v1/config/parser.py
Normal file
63
src/llamafactory/v1/config/parser.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
from ...extras.misc import is_env_enabled
|
||||||
|
from .data_args import DataArguments
|
||||||
|
from .model_args import ModelArguments
|
||||||
|
from .sample_args import SampleArguments
|
||||||
|
from .training_args import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
def get_args(
|
||||||
|
args: Optional[Union[dict[str, Any], list[str]]] = None,
|
||||||
|
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
|
||||||
|
"""Parse arguments from command line or config file."""
|
||||||
|
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
|
||||||
|
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
|
||||||
|
|
||||||
|
if args is None:
|
||||||
|
if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
|
||||||
|
override_config = OmegaConf.from_cli(sys.argv[2:])
|
||||||
|
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
|
||||||
|
args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
|
||||||
|
elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
|
||||||
|
override_config = OmegaConf.from_cli(sys.argv[2:])
|
||||||
|
dict_config = OmegaConf.create(json.load(Path(sys.argv[1]).absolute()))
|
||||||
|
args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
|
||||||
|
else: # list of strings
|
||||||
|
args = sys.argv[1:]
|
||||||
|
|
||||||
|
if isinstance(args, dict):
|
||||||
|
(*parsed_args,) = parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
|
||||||
|
else:
|
||||||
|
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args, return_remaining_strings=True)
|
||||||
|
if unknown_args and not allow_extra_keys:
|
||||||
|
print(parser.format_help())
|
||||||
|
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||||
|
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||||
|
|
||||||
|
return tuple(parsed_args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(get_args())
|
24
src/llamafactory/v1/config/sample_args.py
Normal file
24
src/llamafactory/v1/config/sample_args.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SampleArguments:
|
||||||
|
max_new_tokens: int = field(
|
||||||
|
default=128,
|
||||||
|
metadata={"help": "Maximum number of new tokens to generate."},
|
||||||
|
)
|
40
src/llamafactory/v1/config/training_args.py
Normal file
40
src/llamafactory/v1/config/training_args.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
output_dir: str = field(
|
||||||
|
default="",
|
||||||
|
metadata={"help": "Path to the output directory."},
|
||||||
|
)
|
||||||
|
micro_batch_size: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Micro batch size for training."},
|
||||||
|
)
|
||||||
|
global_batch_size: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Global batch size for training."},
|
||||||
|
)
|
||||||
|
learning_rate: float = field(
|
||||||
|
default=1e-4,
|
||||||
|
metadata={"help": "Learning rate for training."},
|
||||||
|
)
|
||||||
|
bf16: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use bf16 for training."},
|
||||||
|
)
|
@ -0,0 +1,35 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ..config.training_args import TrainingArguments
|
||||||
|
from ..extras.types import DataLoader, Model, Processor
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTrainer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
model: Model,
|
||||||
|
processor: Processor,
|
||||||
|
data_loader: DataLoader,
|
||||||
|
) -> None:
|
||||||
|
self.args = args
|
||||||
|
self.model = model
|
||||||
|
self.processor = processor
|
||||||
|
self.data_loader = data_loader
|
||||||
|
self.optimizer = None
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
def fit(self) -> None:
|
||||||
|
pass
|
@ -0,0 +1,20 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ..config.sample_args import SampleArguments
|
||||||
|
|
||||||
|
|
||||||
|
class ChatSampler:
|
||||||
|
def __init__(self, sample_args: SampleArguments) -> None:
|
||||||
|
self.args = sample_args
|
75
src/llamafactory/v1/core/data_engine.py
Normal file
75
src/llamafactory/v1/core/data_engine.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from ..config.data_args import DataArguments
|
||||||
|
from ..extras.types import DataLoader, Dataset, Processor
|
||||||
|
|
||||||
|
|
||||||
|
class DataCollator:
|
||||||
|
def __init__(self, processor: Processor) -> None:
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetPathMixin:
|
||||||
|
args: DataArguments
|
||||||
|
|
||||||
|
def _abspath(self, path: str) -> str:
|
||||||
|
return os.path.abspath(os.path.expanduser(os.path.join(self.args.dataset_dir, path)))
|
||||||
|
|
||||||
|
def _exists(self, path: str) -> bool:
|
||||||
|
return os.path.exists(self._abspath(path))
|
||||||
|
|
||||||
|
def _isfile(self, path: str) -> bool:
|
||||||
|
return os.path.isfile(self._abspath(path))
|
||||||
|
|
||||||
|
|
||||||
|
class DataEngine(DatasetPathMixin):
|
||||||
|
def __init__(self, data_args: DataArguments) -> None:
|
||||||
|
self.args = data_args
|
||||||
|
self.datasets: dict[str, Dataset] = {}
|
||||||
|
dataset_info = self.get_dataset_info()
|
||||||
|
self.load_dataset(dataset_info)
|
||||||
|
|
||||||
|
def get_dataset_info(self) -> dict:
|
||||||
|
"""Get dataset info from dataset path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Dataset info.
|
||||||
|
"""
|
||||||
|
if self.args.dataset.endswith(".yaml") and self._isfile(self.args.dataset): # local file
|
||||||
|
return OmegaConf.load(self._abspath(self.args.dataset))
|
||||||
|
elif self.args.dataset.endswith(".yaml"): # hf hub uri
|
||||||
|
repo_id, filename = os.path.split(self.args.dataset)
|
||||||
|
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
|
||||||
|
return OmegaConf.load(filepath)
|
||||||
|
elif self._exists(self.args.dataset): # local file(s)
|
||||||
|
return {"default": {"file_name": self.args.dataset}}
|
||||||
|
else: # hf hub dataset
|
||||||
|
return {"default": {"hf_hub_url": self.args.dataset}}
|
||||||
|
|
||||||
|
def load_dataset(self, dataset_info: dict) -> None:
|
||||||
|
for key, value in dataset_info.items():
|
||||||
|
if "hf_hub_url" in value:
|
||||||
|
dataset_info[key] = load_dataset(value["hf_hub_url"])
|
||||||
|
elif "file_name" in value:
|
||||||
|
dataset_info[key] = load_dataset(value["file_name"])
|
||||||
|
|
||||||
|
def get_data_loader(self, processor: Processor) -> DataLoader:
|
||||||
|
pass
|
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ..config.model_args import ModelArguments
|
||||||
|
from ..extras.types import Model, Processor
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEngine:
|
||||||
|
def __init__(self, model_args: ModelArguments) -> None:
|
||||||
|
self.args = model_args
|
||||||
|
|
||||||
|
def get_model(self) -> Model:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_processor(self) -> Processor:
|
||||||
|
pass
|
32
src/llamafactory/v1/extras/types.py
Normal file
32
src/llamafactory/v1/extras/types.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from datasets import Dataset as HFDataset
|
||||||
|
from datasets import IterableDataset
|
||||||
|
from torch.utils.data import DataLoader as TorchDataLoader
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||||
|
|
||||||
|
Dataset = Union[HFDataset, IterableDataset]
|
||||||
|
DataLoader = TorchDataLoader
|
||||||
|
Model = PreTrainedModel
|
||||||
|
Processor = Union[PreTrainedTokenizer, ProcessorMixin]
|
||||||
|
else:
|
||||||
|
Dataset = None
|
||||||
|
DataLoader = None
|
||||||
|
Model = None
|
||||||
|
Processor = None
|
@ -12,22 +12,55 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
def run_train():
|
from ..extras.env import VERSION, print_env
|
||||||
raise NotImplementedError("Please use `llamafactory-cli sft` or `llamafactory-cli rm`.")
|
|
||||||
|
|
||||||
|
|
||||||
def run_chat():
|
USAGE = (
|
||||||
from llamafactory.v1.core.chat_sampler import Sampler
|
"-" * 70
|
||||||
|
+ "\n"
|
||||||
Sampler().cli()
|
+ "| Usage: |\n"
|
||||||
|
+ "| llamafactory-cli sft -h: train models |\n"
|
||||||
|
+ "| llamafactory-cli version: show version info |\n"
|
||||||
|
+ "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
|
||||||
|
+ "-" * 70
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_sft():
|
WELCOME = (
|
||||||
from llamafactory.v1.train.sft import SFTTrainer
|
"-" * 58
|
||||||
|
+ "\n"
|
||||||
|
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
||||||
|
+ " " * (21 - len(VERSION))
|
||||||
|
+ "|\n|"
|
||||||
|
+ " " * 56
|
||||||
|
+ "|\n"
|
||||||
|
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
|
||||||
|
+ "-" * 58
|
||||||
|
)
|
||||||
|
|
||||||
SFTTrainer().run()
|
|
||||||
|
def launch():
|
||||||
|
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
|
||||||
|
|
||||||
|
if command == "sft":
|
||||||
|
from .trainers.sft_trainer import run_sft
|
||||||
|
|
||||||
|
run_sft()
|
||||||
|
|
||||||
|
elif command == "env":
|
||||||
|
print_env()
|
||||||
|
|
||||||
|
elif command == "version":
|
||||||
|
print(WELCOME)
|
||||||
|
|
||||||
|
elif command == "help":
|
||||||
|
print(USAGE)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"Unknown command: {command}.\n{USAGE}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_train()
|
pass
|
||||||
|
0
src/llamafactory/v1/plugins/data_plugins/filter.py
Normal file
0
src/llamafactory/v1/plugins/data_plugins/filter.py
Normal file
26
src/llamafactory/v1/plugins/data_plugins/template.py
Normal file
26
src/llamafactory/v1/plugins/data_plugins/template.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Template:
|
||||||
|
user_template: str
|
||||||
|
assistant_template: str
|
||||||
|
system_template: str
|
||||||
|
|
||||||
|
def render_message(self, message: "dict[str, str]") -> str:
|
||||||
|
return self.user_template.format(**message)
|
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from ..config.parser import get_args
|
||||||
|
from ..core.base_trainer import BaseTrainer
|
||||||
|
from ..core.data_engine import DataEngine
|
||||||
|
from ..core.model_engine import ModelEngine
|
||||||
|
|
||||||
|
|
||||||
|
class SFTTrainer(BaseTrainer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def run_sft():
|
||||||
|
model_args, data_args, training_args, _ = get_args()
|
||||||
|
model_engine = ModelEngine(model_args)
|
||||||
|
data_engine = DataEngine(data_args)
|
||||||
|
model = model_engine.get_model()
|
||||||
|
processor = model_engine.get_processor()
|
||||||
|
data_loader = data_engine.get_data_loader(processor)
|
||||||
|
trainer = SFTTrainer(training_args, model, processor, data_loader)
|
||||||
|
trainer.fit()
|
@ -36,8 +36,8 @@ from ..extras.misc import use_modelscope, use_openmind
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
DEFAULT_CACHE_DIR = "cache"
|
DEFAULT_CACHE_DIR = "llamaboard_cache"
|
||||||
DEFAULT_CONFIG_DIR = "config"
|
DEFAULT_CONFIG_DIR = "llamaboard_config"
|
||||||
DEFAULT_DATA_DIR = "data"
|
DEFAULT_DATA_DIR = "data"
|
||||||
DEFAULT_SAVE_DIR = "saves"
|
DEFAULT_SAVE_DIR = "saves"
|
||||||
USER_CONFIG = "user_config.yaml"
|
USER_CONFIG = "user_config.yaml"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user