From 10146029bab1e4e8d4442a9fcb2fe6acd8a877f9 Mon Sep 17 00:00:00 2001 From: Yaowei Zheng Date: Tue, 7 Oct 2025 22:34:48 +0800 Subject: [PATCH] [v1] add v1 launcher (#9236) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .gitignore | 4 +- data/belle_multiturn/belle_multiturn.py | 82 --------- data/dataset_info.json | 18 -- data/hh_rlhf_en/hh_rlhf_en.py | 98 ---------- data/ultra_chat/ultra_chat.py | 74 -------- data/v1_sft_demo.yaml | 8 + src/llamafactory/cli.py | 133 +------------- src/llamafactory/extras/env.py | 34 ++-- src/llamafactory/launcher.py | 169 +++++++++++++++--- .../data_loader.py => config/__init__.py} | 0 src/llamafactory/v1/config/data_args.py | 33 ++++ src/llamafactory/v1/config/model_args.py | 27 +++ src/llamafactory/v1/config/parser.py | 63 +++++++ src/llamafactory/v1/config/sample_args.py | 24 +++ src/llamafactory/v1/config/training_args.py | 40 +++++ src/llamafactory/v1/core/base_trainer.py | 35 ++++ src/llamafactory/v1/core/chat_sampler.py | 20 +++ src/llamafactory/v1/core/data_engine.py | 75 ++++++++ src/llamafactory/v1/core/model_engine.py | 27 +++ src/llamafactory/v1/extras/types.py | 32 ++++ src/llamafactory/v1/launcher.py | 53 ++++-- .../v1/plugins/data_plugins/filter.py | 0 .../v1/plugins/data_plugins/template.py | 26 +++ .../v1/plugins/model_plugins/added_token.py | 0 src/llamafactory/v1/trainers/sft_trainer.py | 34 ++++ src/llamafactory/webui/common.py | 4 +- 26 files changed, 661 insertions(+), 452 deletions(-) delete mode 100644 data/belle_multiturn/belle_multiturn.py delete mode 100644 data/hh_rlhf_en/hh_rlhf_en.py delete mode 100644 data/ultra_chat/ultra_chat.py create mode 100644 data/v1_sft_demo.yaml rename src/llamafactory/v1/{core/data_loader.py => config/__init__.py} (100%) create mode 100644 src/llamafactory/v1/config/data_args.py create mode 100644 src/llamafactory/v1/config/model_args.py create mode 100644 src/llamafactory/v1/config/parser.py create mode 100644 src/llamafactory/v1/config/sample_args.py create mode 100644 src/llamafactory/v1/config/training_args.py create mode 100644 src/llamafactory/v1/core/data_engine.py create mode 100644 src/llamafactory/v1/extras/types.py create mode 100644 src/llamafactory/v1/plugins/data_plugins/filter.py create mode 100644 src/llamafactory/v1/plugins/data_plugins/template.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/added_token.py diff --git a/.gitignore b/.gitignore index 0a3a47bd..f425bd59 100644 --- a/.gitignore +++ b/.gitignore @@ -169,8 +169,8 @@ uv.lock hf_cache/ ms_cache/ om_cache/ -cache/ -config/ +llamaboard_cache/ +llamaboard_config/ saves/ output/ wandb/ diff --git a/data/belle_multiturn/belle_multiturn.py b/data/belle_multiturn/belle_multiturn.py deleted file mode 100644 index 2c2ed4da..00000000 --- a/data/belle_multiturn/belle_multiturn.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os - -import datasets - - -_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co") - -_DESCRIPTION = "BELLE multiturn chat dataset." - -_CITATION = """\ -@article{belle2023exploring, - title={Exploring the Impact of Instruction Data Scaling on Large Language Models}, - author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li}, - journal={arXiv preprint arXiv:2303.14742}, - year={2023} -} -""" - -_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M" -_LICENSE = "gpl-3.0" -_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json" - - -class BelleMultiturn(datasets.GeneratorBasedBuilder): - VERSION = datasets.Version("0.0.0") - - def _info(self): - features = datasets.Features( - {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]} - ) - return datasets.DatasetInfo( - description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION - ) - - def _split_generators(self, dl_manager: datasets.DownloadManager): - file_path = dl_manager.download(_URL) - return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})] - - def _generate_examples(self, filepath: str): - with open(filepath, encoding="utf-8") as f: - for key, row in enumerate(f): - data = json.loads(row) - conversations = [] - prompt = data["instruction"].strip() - response = data["output"].strip() - - assist_idx = prompt.rfind("Assistant:") - human_idx = prompt.rfind("Human:") - query = prompt[human_idx + 6 : assist_idx].strip() - prompt = prompt[:human_idx].strip() - conversations.insert(0, {"from": "gpt", "value": response}) - conversations.insert(0, {"from": "human", "value": query}) - - while prompt.rfind("Assistant:") != -1: - assist_idx = prompt.rfind("Assistant:") - human_idx = prompt.rfind("Human:") - if human_idx != -1: - old_query = prompt[human_idx + 6 : assist_idx].strip() - old_resp = prompt[assist_idx + 10 :].strip() - conversations.insert(0, {"from": "gpt", "value": old_resp}) - conversations.insert(0, {"from": "human", "value": old_query}) - else: - break - prompt = prompt[:human_idx].strip() - - yield key, {"conversations": conversations} diff --git a/data/dataset_info.json b/data/dataset_info.json index 855c35d9..3615952c 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -143,14 +143,6 @@ "hf_hub_url": "BelleGroup/school_math_0.25M", "ms_hub_url": "AI-ModelScope/school_math_0.25M" }, - "belle_multiturn": { - "script_url": "belle_multiturn", - "formatting": "sharegpt" - }, - "ultra_chat": { - "script_url": "ultra_chat", - "formatting": "sharegpt" - }, "open_platypus": { "hf_hub_url": "garage-bAInd/Open-Platypus", "ms_hub_url": "AI-ModelScope/Open-Platypus" @@ -583,16 +575,6 @@ "system": "system" } }, - "hh_rlhf_en": { - "script_url": "hh_rlhf_en", - "ranking": true, - "columns": { - "prompt": "instruction", - "chosen": "chosen", - "rejected": "rejected", - "history": "history" - } - }, "nectar_rm": { "hf_hub_url": "AstraMindAI/RLAIF-Nectar", "ms_hub_url": "AI-ModelScope/RLAIF-Nectar", diff --git a/data/hh_rlhf_en/hh_rlhf_en.py b/data/hh_rlhf_en/hh_rlhf_en.py deleted file mode 100644 index 287eac40..00000000 --- a/data/hh_rlhf_en/hh_rlhf_en.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2025 the LlamaFactory team. 
-# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os - -import datasets - - -_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co") -_DESCRIPTION = "Human preference data about helpfulness and harmlessness." -_CITATION = "" -_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf" -_LICENSE = "mit" -_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/" -_URLS = { - "train": [ - _URL + "harmless-base/train.jsonl.gz", - _URL + "helpful-base/train.jsonl.gz", - _URL + "helpful-online/train.jsonl.gz", - _URL + "helpful-rejection-sampled/train.jsonl.gz", - ], - "test": [ - _URL + "harmless-base/test.jsonl.gz", - _URL + "helpful-base/test.jsonl.gz", - _URL + "helpful-online/test.jsonl.gz", - _URL + "helpful-rejection-sampled/test.jsonl.gz", - ], -} - - -class HhRlhfEn(datasets.GeneratorBasedBuilder): - VERSION = datasets.Version("0.0.0") - - def _info(self) -> datasets.DatasetInfo: - features = datasets.Features( - { - "instruction": datasets.Value("string"), - "chosen": datasets.Value("string"), - "rejected": datasets.Value("string"), - "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))), - } - ) - return datasets.DatasetInfo( - description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION - ) - - def _split_generators(self, dl_manager: datasets.DownloadManager): - file_path = dl_manager.download_and_extract(_URLS) - return [ - datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_path["train"]}), - datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}), - ] - - def _generate_examples(self, filepaths: list[str]): - key = 0 - for filepath in filepaths: - with open(filepath, encoding="utf-8") as f: - for row in f: - data = json.loads(row) - chosen = data["chosen"] - rejected = data["rejected"] - - assist_idx = rejected.rfind("\n\nAssistant: ") - r_reject = rejected[assist_idx + 13 :].strip() - assist_idx = chosen.rfind("\n\nAssistant: ") - r_accept = chosen[assist_idx + 13 :].strip() - - human_idx = chosen.rfind("\n\nHuman: ") - query = chosen[human_idx + 9 : assist_idx].strip() - prompt = chosen[:human_idx] - history = [] - - while prompt.rfind("\n\nAssistant: ") != -1: - assist_idx = prompt.rfind("\n\nAssistant: ") - human_idx = prompt.rfind("\n\nHuman: ") - if human_idx != -1: - old_query = prompt[human_idx + 9 : assist_idx].strip() - old_resp = prompt[assist_idx + 13 :].strip() - history.insert(0, (old_query, old_resp)) - else: - break - prompt = prompt[:human_idx] - - yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history} - key += 1 diff --git a/data/ultra_chat/ultra_chat.py b/data/ultra_chat/ultra_chat.py deleted file mode 100644 index 2ce17204..00000000 --- a/data/ultra_chat/ultra_chat.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2025 the LlamaFactory team. 
-# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os - -import datasets - - -_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co") - -_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data." - -_CITATION = """\ -@misc{UltraChat, - author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and others}, - title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data}, - year = {2023}, - publisher = {GitHub}, - journal = {GitHub repository}, - howpublished = {\\url{https://github.com/thunlp/ultrachat}}, -} -""" - -_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat" -_LICENSE = "cc-by-nc-4.0" -_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl" - - -class UltraChat(datasets.GeneratorBasedBuilder): - VERSION = datasets.Version("0.0.0") - - def _info(self): - features = datasets.Features( - {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]} - ) - return datasets.DatasetInfo( - description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION - ) - - def _split_generators(self, dl_manager: datasets.DownloadManager): - file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards - return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})] - - def _generate_examples(self, filepaths: list[str]): - for filepath in filepaths: - with open(filepath, encoding="utf-8") as f: - for row in f: - try: - data = json.loads(row) - except Exception: - continue - key: int = data["id"] - content: list[str] = data["data"] - if len(content) % 2 == 1: - content.pop(-1) - if len(content) < 2: - continue - conversations = [ - {"from": "human" if i % 2 == 0 else "gpt", "value": content[i]} for i in range(len(content)) - ] - yield key, {"conversations": conversations} diff --git a/data/v1_sft_demo.yaml b/data/v1_sft_demo.yaml new file mode 100644 index 00000000..dcee9134 --- /dev/null +++ b/data/v1_sft_demo.yaml @@ -0,0 +1,8 @@ +identity: + file_name: identity.json + converter: alpaca +alpaca_en_demo: + file_name: alpaca_en_demo.json + dataset_dir: ~/data + converter: alpaca + num_samples: 500 diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 9649cd69..d574bf1d 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -12,145 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import subprocess -import sys -from copy import deepcopy -from functools import partial - - -USAGE = ( - "-" * 70 - + "\n" - + "| Usage: |\n" - + "| llamafactory-cli api -h: launch an OpenAI-style API server |\n" - + "| llamafactory-cli chat -h: launch a chat interface in CLI |\n" - + "| llamafactory-cli export -h: merge LoRA adapters and export model |\n" - + "| llamafactory-cli train -h: train models |\n" - + "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n" - + "| llamafactory-cli webui: launch LlamaBoard |\n" - + "| llamafactory-cli env: show environment info |\n" - + "| llamafactory-cli version: show version info |\n" - + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n" - + "-" * 70 -) - def main(): - from .extras import logging - from .extras.env import VERSION, print_env - from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray + from .extras.misc import is_env_enabled if is_env_enabled("USE_V1"): from .v1 import launcher else: from . import launcher - logger = logging.get_logger(__name__) - - WELCOME = ( - "-" * 58 - + "\n" - + f"| Welcome to LLaMA Factory, version {VERSION}" - + " " * (21 - len(VERSION)) - + "|\n|" - + " " * 56 - + "|\n" - + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n" - + "-" * 58 - ) - - COMMAND_MAP = { - "api": launcher.run_api, - "chat": launcher.run_chat, - "env": print_env, - "eval": launcher.run_eval, - "export": launcher.export_model, - "train": launcher.run_exp, - "webchat": launcher.run_web_demo, - "webui": launcher.run_web_ui, - "version": partial(print, WELCOME), - "help": partial(print, USAGE), - } - - command = sys.argv.pop(1) if len(sys.argv) > 1 else "help" - if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())): - # launch distributed training - nnodes = os.getenv("NNODES", "1") - node_rank = os.getenv("NODE_RANK", "0") - nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count())) - master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") - master_port = os.getenv("MASTER_PORT", str(find_available_port())) - logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}") - if int(nnodes) > 1: - logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}") - - # elastic launch support - max_restarts = os.getenv("MAX_RESTARTS", "0") - rdzv_id = os.getenv("RDZV_ID") - min_nnodes = os.getenv("MIN_NNODES") - max_nnodes = os.getenv("MAX_NNODES") - - env = deepcopy(os.environ) - if is_env_enabled("OPTIM_TORCH", "1"): - # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539 - env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - if rdzv_id is not None: - # launch elastic job with fault tolerant support when possible - # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html - rdzv_nnodes = nnodes - # elastic number of nodes if MIN_NNODES and MAX_NNODES are set - if min_nnodes is not None and max_nnodes is not None: - rdzv_nnodes = f"{min_nnodes}:{max_nnodes}" - - process = subprocess.run( - ( - "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} " - "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} " - "--max-restarts {max_restarts} {file_name} {args}" - ) - .format( - rdzv_nnodes=rdzv_nnodes, - nproc_per_node=nproc_per_node, - rdzv_id=rdzv_id, - master_addr=master_addr, - master_port=master_port, - 
max_restarts=max_restarts, - file_name=launcher.__file__, - args=" ".join(sys.argv[1:]), - ) - .split(), - env=env, - check=True, - ) - else: - # NOTE: DO NOT USE shell=True to avoid security risk - process = subprocess.run( - ( - "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " - "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" - ) - .format( - nnodes=nnodes, - node_rank=node_rank, - nproc_per_node=nproc_per_node, - master_addr=master_addr, - master_port=master_port, - file_name=launcher.__file__, - args=" ".join(sys.argv[1:]), - ) - .split(), - env=env, - check=True, - ) - - sys.exit(process.returncode) - elif command in COMMAND_MAP: - COMMAND_MAP[command]() - else: - print(f"Unknown command: {command}.\n{USAGE}") + launcher.launch() if __name__ == "__main__": diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py index 82d89b69..d39b9945 100644 --- a/src/llamafactory/extras/env.py +++ b/src/llamafactory/extras/env.py @@ -16,6 +16,9 @@ # limitations under the License. +from collections import OrderedDict + + VERSION = "0.9.4.dev0" @@ -28,20 +31,20 @@ def print_env() -> None: import peft import torch import transformers - import trl from transformers.utils import is_torch_cuda_available, is_torch_npu_available - info = { - "`llamafactory` version": VERSION, - "Platform": platform.platform(), - "Python version": platform.python_version(), - "PyTorch version": torch.__version__, - "Transformers version": transformers.__version__, - "Datasets version": datasets.__version__, - "Accelerate version": accelerate.__version__, - "PEFT version": peft.__version__, - "TRL version": trl.__version__, - } + info = OrderedDict( + { + "`llamafactory` version": VERSION, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "PyTorch version": torch.__version__, + "Transformers version": transformers.__version__, + "Datasets version": datasets.__version__, + "Accelerate version": accelerate.__version__, + "PEFT version": peft.__version__, + } + ) if is_torch_cuda_available(): info["PyTorch version"] += " (GPU)" @@ -54,6 +57,13 @@ def print_env() -> None: info["NPU type"] = torch.npu.get_device_name() info["CANN version"] = torch.version.cann + try: + import trl # type: ignore + + info["TRL version"] = trl.__version__ + except Exception: + pass + try: import deepspeed # type: ignore diff --git a/src/llamafactory/launcher.py b/src/llamafactory/launcher.py index 8b4ae586..ab955c03 100644 --- a/src/llamafactory/launcher.py +++ b/src/llamafactory/launcher.py @@ -12,46 +12,169 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- -def run_api(): - from llamafactory.api.app import run_api as _run_api - - _run_api() +import os +import subprocess +import sys +from copy import deepcopy -def run_chat(): - from llamafactory.chat.chat_model import run_chat as _run_chat - - return _run_chat() +USAGE = ( + "-" * 70 + + "\n" + + "| Usage: |\n" + + "| llamafactory-cli api -h: launch an OpenAI-style API server |\n" + + "| llamafactory-cli chat -h: launch a chat interface in CLI |\n" + + "| llamafactory-cli export -h: merge LoRA adapters and export model |\n" + + "| llamafactory-cli train -h: train models |\n" + + "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n" + + "| llamafactory-cli webui: launch LlamaBoard |\n" + + "| llamafactory-cli env: show environment info |\n" + + "| llamafactory-cli version: show version info |\n" + + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n" + + "-" * 70 +) -def run_eval(): - raise NotImplementedError("Evaluation will be deprecated in the future.") +def launch(): + from .extras import logging + from .extras.env import VERSION, print_env + from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray + logger = logging.get_logger(__name__) + WELCOME = ( + "-" * 58 + + "\n" + + f"| Welcome to LLaMA Factory, version {VERSION}" + + " " * (21 - len(VERSION)) + + "|\n|" + + " " * 56 + + "|\n" + + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n" + + "-" * 58 + ) -def export_model(): - from llamafactory.train.tuner import export_model as _export_model + command = sys.argv.pop(1) if len(sys.argv) > 1 else "help" + if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())): + # launch distributed training + nnodes = os.getenv("NNODES", "1") + node_rank = os.getenv("NODE_RANK", "0") + nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count())) + master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") + master_port = os.getenv("MASTER_PORT", str(find_available_port())) + logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}") + if int(nnodes) > 1: + logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}") - return _export_model() + # elastic launch support + max_restarts = os.getenv("MAX_RESTARTS", "0") + rdzv_id = os.getenv("RDZV_ID") + min_nnodes = os.getenv("MIN_NNODES") + max_nnodes = os.getenv("MAX_NNODES") + env = deepcopy(os.environ) + if is_env_enabled("OPTIM_TORCH", "1"): + # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539 + env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" -def run_exp(): - from llamafactory.train.tuner import run_exp as _run_exp + if rdzv_id is not None: + # launch elastic job with fault tolerant support when possible + # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html + rdzv_nnodes = nnodes + # elastic number of nodes if MIN_NNODES and MAX_NNODES are set + if min_nnodes is not None and max_nnodes is not None: + rdzv_nnodes = f"{min_nnodes}:{max_nnodes}" - return _run_exp() # use absolute import + process = subprocess.run( + ( + "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} " + "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} " + "--max-restarts {max_restarts} {file_name} {args}" + ) + .format( + rdzv_nnodes=rdzv_nnodes, + nproc_per_node=nproc_per_node, + rdzv_id=rdzv_id, + master_addr=master_addr, + 
master_port=master_port, + max_restarts=max_restarts, + file_name=__file__, + args=" ".join(sys.argv[1:]), + ) + .split(), + env=env, + check=True, + ) + else: + # NOTE: DO NOT USE shell=True to avoid security risk + process = subprocess.run( + ( + "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " + "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" + ) + .format( + nnodes=nnodes, + node_rank=node_rank, + nproc_per_node=nproc_per_node, + master_addr=master_addr, + master_port=master_port, + file_name=__file__, + args=" ".join(sys.argv[1:]), + ) + .split(), + env=env, + check=True, + ) + sys.exit(process.returncode) -def run_web_demo(): - from llamafactory.webui.interface import run_web_demo as _run_web_demo + elif command == "api": + from .api.app import run_api - return _run_web_demo() + run_api() + elif command == "chat": + from .chat.chat_model import run_chat -def run_web_ui(): - from llamafactory.webui.interface import run_web_ui as _run_web_ui + run_chat() - return _run_web_ui() + elif command == "eval": + raise NotImplementedError("Evaluation will be deprecated in the future.") + + elif command == "export": + from .train.tuner import export_model + + export_model() + + elif command == "train": + from .train.tuner import run_exp + + run_exp() + + elif command == "webchat": + from .webui.interface import run_web_demo + + run_web_demo() + + elif command == "webui": + from .webui.interface import run_web_ui + + run_web_ui() + + elif command == "env": + print_env() + + elif command == "version": + print(WELCOME) + + elif command == "help": + print(USAGE) + + else: + print(f"Unknown command: {command}.\n{USAGE}") if __name__ == "__main__": + from llamafactory.train.tuner import run_exp # use absolute import + run_exp() diff --git a/src/llamafactory/v1/core/data_loader.py b/src/llamafactory/v1/config/__init__.py similarity index 100% rename from src/llamafactory/v1/core/data_loader.py rename to src/llamafactory/v1/config/__init__.py diff --git a/src/llamafactory/v1/config/data_args.py b/src/llamafactory/v1/config/data_args.py new file mode 100644 index 00000000..28829642 --- /dev/null +++ b/src/llamafactory/v1/config/data_args.py @@ -0,0 +1,33 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class DataArguments: + dataset: Optional[str] = field( + default=None, + metadata={"help": "Path to the dataset."}, + ) + dataset_dir: str = field( + default="data", + metadata={"help": "Path to the folder containing the datasets."}, + ) + cutoff_len: int = field( + default=2048, + metadata={"help": "Cutoff length for the dataset."}, + ) diff --git a/src/llamafactory/v1/config/model_args.py b/src/llamafactory/v1/config/model_args.py new file mode 100644 index 00000000..b9a98660 --- /dev/null +++ b/src/llamafactory/v1/config/model_args.py @@ -0,0 +1,27 @@ +# Copyright 2025 the LlamaFactory team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class ModelArguments:
+    model: str = field(
+        metadata={"help": "Path to the model or model identifier from Hugging Face."},
+    )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Trust remote code from Hugging Face."},
+    )
diff --git a/src/llamafactory/v1/config/parser.py b/src/llamafactory/v1/config/parser.py
new file mode 100644
index 00000000..eca1749f
--- /dev/null
+++ b/src/llamafactory/v1/config/parser.py
@@ -0,0 +1,63 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import sys
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from omegaconf import OmegaConf
+from transformers import HfArgumentParser
+
+from ...extras.misc import is_env_enabled
+from .data_args import DataArguments
+from .model_args import ModelArguments
+from .sample_args import SampleArguments
+from .training_args import TrainingArguments
+
+
+def get_args(
+    args: Optional[Union[dict[str, Any], list[str]]] = None,
+) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
+    """Parse arguments from command line or config file."""
+    parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
+    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
+
+    if args is None:
+        if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
+            override_config = OmegaConf.from_cli(sys.argv[2:])
+            dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
+            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
+        elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
+            override_config = OmegaConf.from_cli(sys.argv[2:])
+            dict_config = OmegaConf.create(json.loads(Path(sys.argv[1]).absolute().read_text()))
+            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
+        else:  # list of strings
+            args = sys.argv[1:]
+
+    if isinstance(args, dict):
+        (*parsed_args,) = parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
+    else:
+        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args, return_remaining_strings=True)
+        if unknown_args and not allow_extra_keys:
+            print(parser.format_help())
+            print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
+
+    return tuple(parsed_args)
+
+
+if __name__ == "__main__": + print(get_args()) diff --git a/src/llamafactory/v1/config/sample_args.py b/src/llamafactory/v1/config/sample_args.py new file mode 100644 index 00000000..666efb01 --- /dev/null +++ b/src/llamafactory/v1/config/sample_args.py @@ -0,0 +1,24 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass, field + + +@dataclass +class SampleArguments: + max_new_tokens: int = field( + default=128, + metadata={"help": "Maximum number of new tokens to generate."}, + ) diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py new file mode 100644 index 00000000..38d62ecf --- /dev/null +++ b/src/llamafactory/v1/config/training_args.py @@ -0,0 +1,40 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass, field + + +@dataclass +class TrainingArguments: + output_dir: str = field( + default="", + metadata={"help": "Path to the output directory."}, + ) + micro_batch_size: int = field( + default=1, + metadata={"help": "Micro batch size for training."}, + ) + global_batch_size: int = field( + default=1, + metadata={"help": "Global batch size for training."}, + ) + learning_rate: float = field( + default=1e-4, + metadata={"help": "Learning rate for training."}, + ) + bf16: bool = field( + default=False, + metadata={"help": "Use bf16 for training."}, + ) diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index e69de29b..9dca8a0a 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -0,0 +1,35 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ..config.training_args import TrainingArguments +from ..extras.types import DataLoader, Model, Processor + + +class BaseTrainer: + def __init__( + self, + args: TrainingArguments, + model: Model, + processor: Processor, + data_loader: DataLoader, + ) -> None: + self.args = args + self.model = model + self.processor = processor + self.data_loader = data_loader + self.optimizer = None + self.lr_scheduler = None + + def fit(self) -> None: + pass diff --git a/src/llamafactory/v1/core/chat_sampler.py b/src/llamafactory/v1/core/chat_sampler.py index e69de29b..3f213914 100644 --- a/src/llamafactory/v1/core/chat_sampler.py +++ b/src/llamafactory/v1/core/chat_sampler.py @@ -0,0 +1,20 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..config.sample_args import SampleArguments + + +class ChatSampler: + def __init__(self, sample_args: SampleArguments) -> None: + self.args = sample_args diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py new file mode 100644 index 00000000..3e58667c --- /dev/null +++ b/src/llamafactory/v1/core/data_engine.py @@ -0,0 +1,75 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from datasets import load_dataset +from huggingface_hub import hf_hub_download +from omegaconf import OmegaConf + +from ..config.data_args import DataArguments +from ..extras.types import DataLoader, Dataset, Processor + + +class DataCollator: + def __init__(self, processor: Processor) -> None: + self.processor = processor + + +class DatasetPathMixin: + args: DataArguments + + def _abspath(self, path: str) -> str: + return os.path.abspath(os.path.expanduser(os.path.join(self.args.dataset_dir, path))) + + def _exists(self, path: str) -> bool: + return os.path.exists(self._abspath(path)) + + def _isfile(self, path: str) -> bool: + return os.path.isfile(self._abspath(path)) + + +class DataEngine(DatasetPathMixin): + def __init__(self, data_args: DataArguments) -> None: + self.args = data_args + self.datasets: dict[str, Dataset] = {} + dataset_info = self.get_dataset_info() + self.load_dataset(dataset_info) + + def get_dataset_info(self) -> dict: + """Get dataset info from dataset path. + + Returns: + dict: Dataset info. 
+        """
+        if self.args.dataset.endswith(".yaml") and self._isfile(self.args.dataset):  # local file
+            return OmegaConf.load(self._abspath(self.args.dataset))
+        elif self.args.dataset.endswith(".yaml"):  # hf hub uri
+            repo_id, filename = os.path.split(self.args.dataset)
+            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
+            return OmegaConf.load(filepath)
+        elif self._exists(self.args.dataset):  # local file(s)
+            return {"default": {"file_name": self.args.dataset}}
+        else:  # hf hub dataset
+            return {"default": {"hf_hub_url": self.args.dataset}}
+
+    def load_dataset(self, dataset_info: dict) -> None:
+        for key, value in dataset_info.items():
+            if "hf_hub_url" in value:
+                self.datasets[key] = load_dataset(value["hf_hub_url"], split="train")
+            elif "file_name" in value:  # local files need an explicit builder; the demo data is json
+                self.datasets[key] = load_dataset("json", data_files=self._abspath(value["file_name"]), split="train")
+
+    def get_data_loader(self, processor: Processor) -> DataLoader:
+        pass
diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/core/model_engine.py
index e69de29b..24d2d4b7 100644
--- a/src/llamafactory/v1/core/model_engine.py
+++ b/src/llamafactory/v1/core/model_engine.py
@@ -0,0 +1,27 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..config.model_args import ModelArguments
+from ..extras.types import Model, Processor
+
+
+class ModelEngine:
+    def __init__(self, model_args: ModelArguments) -> None:
+        self.args = model_args
+
+    def get_model(self) -> Model:
+        pass
+
+    def get_processor(self) -> Processor:
+        pass
diff --git a/src/llamafactory/v1/extras/types.py b/src/llamafactory/v1/extras/types.py
new file mode 100644
index 00000000..a160e3a7
--- /dev/null
+++ b/src/llamafactory/v1/extras/types.py
@@ -0,0 +1,32 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from typing import TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from datasets import Dataset as HFDataset + from datasets import IterableDataset + from torch.utils.data import DataLoader as TorchDataLoader + from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin + + Dataset = Union[HFDataset, IterableDataset] + DataLoader = TorchDataLoader + Model = PreTrainedModel + Processor = Union[PreTrainedTokenizer, ProcessorMixin] +else: + Dataset = None + DataLoader = None + Model = None + Processor = None diff --git a/src/llamafactory/v1/launcher.py b/src/llamafactory/v1/launcher.py index 36f87403..6160835b 100644 --- a/src/llamafactory/v1/launcher.py +++ b/src/llamafactory/v1/launcher.py @@ -12,22 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys -def run_train(): - raise NotImplementedError("Please use `llamafactory-cli sft` or `llamafactory-cli rm`.") +from ..extras.env import VERSION, print_env -def run_chat(): - from llamafactory.v1.core.chat_sampler import Sampler - - Sampler().cli() +USAGE = ( + "-" * 70 + + "\n" + + "| Usage: |\n" + + "| llamafactory-cli sft -h: train models |\n" + + "| llamafactory-cli version: show version info |\n" + + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n" + + "-" * 70 +) -def run_sft(): - from llamafactory.v1.train.sft import SFTTrainer +WELCOME = ( + "-" * 58 + + "\n" + + f"| Welcome to LLaMA Factory, version {VERSION}" + + " " * (21 - len(VERSION)) + + "|\n|" + + " " * 56 + + "|\n" + + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n" + + "-" * 58 +) - SFTTrainer().run() + +def launch(): + command = sys.argv.pop(1) if len(sys.argv) > 1 else "help" + + if command == "sft": + from .trainers.sft_trainer import run_sft + + run_sft() + + elif command == "env": + print_env() + + elif command == "version": + print(WELCOME) + + elif command == "help": + print(USAGE) + + else: + print(f"Unknown command: {command}.\n{USAGE}") if __name__ == "__main__": - run_train() + pass diff --git a/src/llamafactory/v1/plugins/data_plugins/filter.py b/src/llamafactory/v1/plugins/data_plugins/filter.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/plugins/data_plugins/template.py b/src/llamafactory/v1/plugins/data_plugins/template.py new file mode 100644 index 00000000..cf41389f --- /dev/null +++ b/src/llamafactory/v1/plugins/data_plugins/template.py @@ -0,0 +1,26 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+from dataclasses import dataclass
+
+
+@dataclass
+class Template:
+    user_template: str
+    assistant_template: str
+    system_template: str
+
+    def render_message(self, message: "dict[str, str]") -> str:
+        return self.user_template.format(**message)
diff --git a/src/llamafactory/v1/plugins/model_plugins/added_token.py b/src/llamafactory/v1/plugins/model_plugins/added_token.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/llamafactory/v1/trainers/sft_trainer.py b/src/llamafactory/v1/trainers/sft_trainer.py
index e69de29b..5d254e4c 100644
--- a/src/llamafactory/v1/trainers/sft_trainer.py
+++ b/src/llamafactory/v1/trainers/sft_trainer.py
@@ -0,0 +1,34 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..config.parser import get_args
+from ..core.base_trainer import BaseTrainer
+from ..core.data_engine import DataEngine
+from ..core.model_engine import ModelEngine
+
+
+class SFTTrainer(BaseTrainer):
+    pass
+
+
+def run_sft():
+    data_args, model_args, training_args, _ = get_args()  # parser order: data, model, training, sample
+    model_engine = ModelEngine(model_args)
+    data_engine = DataEngine(data_args)
+    model = model_engine.get_model()
+    processor = model_engine.get_processor()
+    data_loader = data_engine.get_data_loader(processor)
+    trainer = SFTTrainer(training_args, model, processor, data_loader)
+    trainer.fit()
diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py
index b7004a91..a8e829f4 100644
--- a/src/llamafactory/webui/common.py
+++ b/src/llamafactory/webui/common.py
@@ -36,8 +36,8 @@ from ..extras.misc import use_modelscope, use_openmind
 logger = logging.get_logger(__name__)
 
-DEFAULT_CACHE_DIR = "cache"
-DEFAULT_CONFIG_DIR = "config"
+DEFAULT_CACHE_DIR = "llamaboard_cache"
+DEFAULT_CONFIG_DIR = "llamaboard_config"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user_config.yaml"
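
Notes (editor sketch, not part of the applied patch):

The v1 stack is opt-in: `llamafactory/cli.py` imports `llamafactory.v1.launcher` only when
the `USE_V1` environment variable is enabled, and the v1 launcher currently dispatches the
`sft`, `env`, `version`, and `help` commands. A minimal launch sketch follows; the config
file name and model identifier are hypothetical placeholders, but every key maps onto the
new v1 dataclasses, and `v1_sft_demo.yaml` is the dataset-info file added by this patch:

    # v1_config.yaml (hypothetical)
    model: Qwen/Qwen2-0.5B-Instruct   # ModelArguments.model (placeholder identifier)
    dataset: v1_sft_demo.yaml         # resolved against DataArguments.dataset_dir ("data")
    output_dir: saves/v1_sft_demo     # TrainingArguments.output_dir
    micro_batch_size: 1
    learning_rate: 1.0e-4
    bf16: true

    # launch v1 SFT; trailing key=value pairs are merged over the YAML by OmegaConf.from_cli
    USE_V1=1 llamafactory-cli sft v1_config.yaml learning_rate=5.0e-5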
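
`DataEngine.get_dataset_info` resolves `DataArguments.dataset` in four ways. Roughly, with
the first and third values taken from files in this patch and the hub names purely
illustrative:

    dataset: v1_sft_demo.yaml     ->  local dataset-info YAML under dataset_dir
    dataset: org/repo/info.yaml   ->  dataset-info YAML fetched via hf_hub_download
    dataset: alpaca_en_demo.json  ->  wrapped as {"default": {"file_name": ...}}
    dataset: tatsu-lab/alpaca     ->  wrapped as {"default": {"hf_hub_url": ...}}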
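
The torchrun bootstrap moves from cli.py into launcher.launch() unchanged, so multi-device
`train` runs keep reading the same environment variables (the example config path below is
illustrative):

    # fixed-topology multi-node launch
    FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=10.0.0.1 MASTER_PORT=29500 \
        llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml

    # elastic launch: RDZV_ID switches to the c10d rendezvous backend, with
    # MIN_NNODES:MAX_NNODES as the node range and MAX_RESTARTS for fault tolerance
    RDZV_ID=lf_job_1 MIN_NNODES=1 MAX_NNODES=2 MAX_RESTARTS=3 \
        llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml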