mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00

Compare commits

2 Commits: d5bb4e6394 ... 10146029ba

| Author | SHA1 | Date |
|---|---|---|
|  | 10146029ba |  |
|  | 95b7188090 |  |

.gitignore (vendored, 4 changes)

@@ -169,8 +169,8 @@ uv.lock
hf_cache/
ms_cache/
om_cache/
cache/
config/
llamaboard_cache/
llamaboard_config/
saves/
output/
wandb/
data/belle_multiturn/belle_multiturn.py (deleted)

@@ -1,82 +0,0 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import datasets


_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")

_DESCRIPTION = "BELLE multiturn chat dataset."

_CITATION = """\
@article{belle2023exploring,
  title={Exploring the Impact of Instruction Data Scaling on Large Language Models},
  author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li},
  journal={arXiv preprint arXiv:2303.14742},
  year={2023}
}
"""

_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
_LICENSE = "gpl-3.0"
_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"


class BelleMultiturn(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    def _info(self):
        features = datasets.Features(
            {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        file_path = dl_manager.download(_URL)
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

    def _generate_examples(self, filepath: str):
        with open(filepath, encoding="utf-8") as f:
            for key, row in enumerate(f):
                data = json.loads(row)
                conversations = []
                prompt = data["instruction"].strip()
                response = data["output"].strip()

                assist_idx = prompt.rfind("Assistant:")
                human_idx = prompt.rfind("Human:")
                query = prompt[human_idx + 6 : assist_idx].strip()
                prompt = prompt[:human_idx].strip()
                conversations.insert(0, {"from": "gpt", "value": response})
                conversations.insert(0, {"from": "human", "value": query})

                while prompt.rfind("Assistant:") != -1:
                    assist_idx = prompt.rfind("Assistant:")
                    human_idx = prompt.rfind("Human:")
                    if human_idx != -1:
                        old_query = prompt[human_idx + 6 : assist_idx].strip()
                        old_resp = prompt[assist_idx + 10 :].strip()
                        conversations.insert(0, {"from": "gpt", "value": old_resp})
                        conversations.insert(0, {"from": "human", "value": old_query})
                    else:
                        break
                    prompt = prompt[:human_idx].strip()

                yield key, {"conversations": conversations}
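
For context: the deleted loader assumed each BELLE record's `instruction` ends with a trailing `Assistant:` marker and rebuilt earlier turns by scanning backwards. A minimal standalone sketch of that parsing, with invented sample strings:

# Standalone sketch of the deleted loader's turn splitting (sample data is made up).
def split_turns(prompt: str, response: str) -> list[dict]:
    conversations = []
    assist_idx = prompt.rfind("Assistant:")  # trailing marker at the end of the instruction
    human_idx = prompt.rfind("Human:")
    conversations.insert(0, {"from": "gpt", "value": response})
    conversations.insert(0, {"from": "human", "value": prompt[human_idx + 6 : assist_idx].strip()})
    prompt = prompt[:human_idx].strip()
    while prompt.rfind("Assistant:") != -1:
        assist_idx = prompt.rfind("Assistant:")
        human_idx = prompt.rfind("Human:")
        if human_idx == -1:
            break
        conversations.insert(0, {"from": "gpt", "value": prompt[assist_idx + 10 :].strip()})
        conversations.insert(0, {"from": "human", "value": prompt[human_idx + 6 : assist_idx].strip()})
        prompt = prompt[:human_idx].strip()
    return conversations


print(split_turns("Human: Hi\nAssistant: Hello!\nHuman: Bye\nAssistant:", "Bye!"))
# [{'from': 'human', 'value': 'Hi'}, {'from': 'gpt', 'value': 'Hello!'},
#  {'from': 'human', 'value': 'Bye'}, {'from': 'gpt', 'value': 'Bye!'}]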
data/dataset_info.json

@@ -143,14 +143,6 @@
    "hf_hub_url": "BelleGroup/school_math_0.25M",
 | 
			
		||||
    "ms_hub_url": "AI-ModelScope/school_math_0.25M"
 | 
			
		||||
  },
 | 
			
		||||
  "belle_multiturn": {
 | 
			
		||||
    "script_url": "belle_multiturn",
 | 
			
		||||
    "formatting": "sharegpt"
 | 
			
		||||
  },
 | 
			
		||||
  "ultra_chat": {
 | 
			
		||||
    "script_url": "ultra_chat",
 | 
			
		||||
    "formatting": "sharegpt"
 | 
			
		||||
  },
 | 
			
		||||
  "open_platypus": {
 | 
			
		||||
    "hf_hub_url": "garage-bAInd/Open-Platypus",
 | 
			
		||||
    "ms_hub_url": "AI-ModelScope/Open-Platypus"
 | 
			
		||||
@ -583,16 +575,6 @@
 | 
			
		||||
      "system": "system"
 | 
			
		||||
    }
 | 
			
		||||
  },
 | 
			
		||||
  "hh_rlhf_en": {
 | 
			
		||||
    "script_url": "hh_rlhf_en",
 | 
			
		||||
    "ranking": true,
 | 
			
		||||
    "columns": {
 | 
			
		||||
      "prompt": "instruction",
 | 
			
		||||
      "chosen": "chosen",
 | 
			
		||||
      "rejected": "rejected",
 | 
			
		||||
      "history": "history"
 | 
			
		||||
    }
 | 
			
		||||
  },
 | 
			
		||||
  "nectar_rm": {
 | 
			
		||||
    "hf_hub_url": "AstraMindAI/RLAIF-Nectar",
 | 
			
		||||
    "ms_hub_url": "AI-ModelScope/RLAIF-Nectar",
 | 
			
		||||
 | 
			
		||||
data/hh_rlhf_en/hh_rlhf_en.py (deleted)

@@ -1,98 +0,0 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import datasets


_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
_URLS = {
    "train": [
        _URL + "harmless-base/train.jsonl.gz",
        _URL + "helpful-base/train.jsonl.gz",
        _URL + "helpful-online/train.jsonl.gz",
        _URL + "helpful-rejection-sampled/train.jsonl.gz",
    ],
    "test": [
        _URL + "harmless-base/test.jsonl.gz",
        _URL + "helpful-base/test.jsonl.gz",
        _URL + "helpful-online/test.jsonl.gz",
        _URL + "helpful-rejection-sampled/test.jsonl.gz",
    ],
}


class HhRlhfEn(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features(
            {
                "instruction": datasets.Value("string"),
                "chosen": datasets.Value("string"),
                "rejected": datasets.Value("string"),
                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        file_path = dl_manager.download_and_extract(_URLS)
        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_path["train"]}),
            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
        ]

    def _generate_examples(self, filepaths: list[str]):
        key = 0
        for filepath in filepaths:
            with open(filepath, encoding="utf-8") as f:
                for row in f:
                    data = json.loads(row)
                    chosen = data["chosen"]
                    rejected = data["rejected"]

                    assist_idx = rejected.rfind("\n\nAssistant: ")
                    r_reject = rejected[assist_idx + 13 :].strip()
                    assist_idx = chosen.rfind("\n\nAssistant: ")
                    r_accept = chosen[assist_idx + 13 :].strip()

                    human_idx = chosen.rfind("\n\nHuman: ")
                    query = chosen[human_idx + 9 : assist_idx].strip()
                    prompt = chosen[:human_idx]
                    history = []

                    while prompt.rfind("\n\nAssistant: ") != -1:
                        assist_idx = prompt.rfind("\n\nAssistant: ")
                        human_idx = prompt.rfind("\n\nHuman: ")
                        if human_idx != -1:
                            old_query = prompt[human_idx + 9 : assist_idx].strip()
                            old_resp = prompt[assist_idx + 13 :].strip()
                            history.insert(0, (old_query, old_resp))
                        else:
                            break
                        prompt = prompt[:human_idx]

                    yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
                    key += 1
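
The deleted loader emitted flat preference pairs. The shape of one yielded record, with toy values invented for illustration:

# Shape of one hh_rlhf_en example (toy strings, not real data).
example = {
    "instruction": "How do I brew green tea?",            # last human turn in "chosen"
    "chosen": "Use water around 80C and steep briefly.",  # final accepted reply
    "rejected": "Figure it out yourself.",                # final rejected reply
    "history": [("Hi", "Hello! How can I help?")],        # earlier (query, response) pairs
}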
data/ultra_chat/ultra_chat.py (deleted)

@@ -1,74 +0,0 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import datasets


_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")

_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."

_CITATION = """\
@misc{UltraChat,
  author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and others},
  title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\\url{https://github.com/thunlp/ultrachat}},
}
"""

_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
_LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"


class UltraChat(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    def _info(self):
        features = datasets.Features(
            {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)]  # multiple shards
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]

    def _generate_examples(self, filepaths: list[str]):
        for filepath in filepaths:
            with open(filepath, encoding="utf-8") as f:
                for row in f:
                    try:
                        data = json.loads(row)
                    except Exception:
                        continue
                    key: int = data["id"]
                    content: list[str] = data["data"]
                    if len(content) % 2 == 1:
                        content.pop(-1)
                    if len(content) < 2:
                        continue
                    conversations = [
                        {"from": "human" if i % 2 == 0 else "gpt", "value": content[i]} for i in range(len(content))
                    ]
                    yield key, {"conversations": conversations}

data/v1_sft_demo.yaml (new file, 8 lines)

@@ -0,0 +1,8 @@
identity:
  file_name: identity.json
  converter: alpaca
alpaca_en_demo:
  file_name: alpaca_en_demo.json
  dataset_dir: ~/data
  converter: alpaca
  num_samples: 500
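
The new v1 spec is plain YAML keyed by dataset name. A hypothetical sketch of reading it with OmegaConf (the new data_engine.py at the end of this diff also imports OmegaConf; the exact field handling here is an assumption):

# Hypothetical sketch of consuming the v1 dataset spec (field handling assumed).
from omegaconf import OmegaConf

specs = OmegaConf.load("data/v1_sft_demo.yaml")
for name, spec in specs.items():
    # every entry names a file and a converter; num_samples is optional
    print(name, spec.file_name, spec.converter, spec.get("num_samples"))
# identity identity.json alpaca None
# alpaca_en_demo alpaca_en_demo.json alpaca 500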
src/llamafactory/api/app.py

@@ -26,7 +26,7 @@ from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
-from .common import dictify, jsonify
+from .common import check_lfi_path, check_ssrf_url, dictify, jsonify
from .protocol import (
    ChatCompletionMessage,
    ChatCompletionResponse,
@@ -121,8 +121,10 @@ def _process_request(
                    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
                        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
                    elif os.path.isfile(image_url):  # local file
+                        check_lfi_path(image_url)
                        image_stream = open(image_url, "rb")
                    else:  # web uri
+                        check_ssrf_url(image_url)
                        image_stream = requests.get(image_url, stream=True).raw

                    images.append(Image.open(image_stream).convert("RGB"))
@@ -132,8 +134,10 @@ def _process_request(
                    if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url):  # base64 video
                        video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
                    elif os.path.isfile(video_url):  # local file
+                        check_lfi_path(video_url)
                        video_stream = video_url
                    else:  # web uri
+                        check_ssrf_url(video_url)
                        video_stream = requests.get(video_url, stream=True).raw

                    videos.append(video_stream)
@@ -143,8 +147,10 @@ def _process_request(
                    if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url):  # base64 audio
                        audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
                    elif os.path.isfile(audio_url):  # local file
+                        check_lfi_path(audio_url)
                        audio_stream = audio_url
                    else:  # web uri
+                        check_ssrf_url(audio_url)
                        audio_stream = requests.get(audio_url, stream=True).raw

                    audios.append(audio_stream)
src/llamafactory/api/common.py

@@ -12,14 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ipaddress
import json
import os
import socket
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available


if is_fastapi_available():
    from fastapi import HTTPException, status


if TYPE_CHECKING:
    from pydantic import BaseModel


SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")


def dictify(data: "BaseModel") -> dict[str, Any]:
    try:  # pydantic v2
        return data.model_dump(exclude_unset=True)
@@ -32,3 +47,50 @@ def jsonify(data: "BaseModel") -> str:
        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
    except AttributeError:  # pydantic v1
        return data.json(exclude_unset=True, ensure_ascii=False)


def check_lfi_path(path: str) -> None:
    """Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe."""
    if not ALLOW_LOCAL_FILES:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")

    try:
        os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
        real_path = os.path.realpath(path)
        safe_path = os.path.realpath(SAFE_MEDIA_PATH)

        if not real_path.startswith(safe_path):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
            )
    except Exception:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.")


def check_ssrf_url(url: str) -> None:
    """Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe."""
    try:
        parsed_url = urlparse(url)
        if parsed_url.scheme not in ["http", "https"]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")

        hostname = parsed_url.hostname
        if not hostname:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")

        ip_info = socket.getaddrinfo(hostname, parsed_url.port)
        ip_address_str = ip_info[0][4][0]
        ip = ipaddress.ip_address(ip_address_str)

        if not ip.is_global:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access to private or reserved IP addresses is not allowed.",
            )

    except socket.gaierror:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}"
        )
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")
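
Taken together, the two guards gate the three media branches in app.py above. An illustrative usage sketch (inputs are invented; exact status codes follow the code above, where an inner 403 may be rewrapped by the broad except):

# Illustrative calls to the new guards (example inputs, not from the diff).
from llamafactory.api.common import check_lfi_path, check_ssrf_url

check_ssrf_url("https://example.com/cat.png")     # passes if the host resolves to a global IP
check_ssrf_url("http://169.254.169.254/latest/")  # raises HTTPException (link-local, non-global address)
check_lfi_path("/etc/passwd")                     # raises HTTPException (outside SAFE_MEDIA_PATH)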
src/llamafactory/cli.py

@@ -12,145 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import os
-import subprocess
-import sys
-from copy import deepcopy
-from functools import partial
-
-
-USAGE = (
-    "-" * 70
-    + "\n"
-    + "| Usage:                                                             |\n"
-    + "|   llamafactory-cli api -h: launch an OpenAI-style API server       |\n"
-    + "|   llamafactory-cli chat -h: launch a chat interface in CLI         |\n"
-    + "|   llamafactory-cli export -h: merge LoRA adapters and export model |\n"
-    + "|   llamafactory-cli train -h: train models                          |\n"
-    + "|   llamafactory-cli webchat -h: launch a chat interface in Web UI   |\n"
-    + "|   llamafactory-cli webui: launch LlamaBoard                        |\n"
-    + "|   llamafactory-cli env: show environment info                      |\n"
-    + "|   llamafactory-cli version: show version info                      |\n"
-    + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`.      |\n"
-    + "-" * 70
-)
-
-
def main():
-    from .extras import logging
-    from .extras.env import VERSION, print_env
-    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
+    from .extras.misc import is_env_enabled
+
+    if is_env_enabled("USE_V1"):
+        from .v1 import launcher
+    else:
+        from . import launcher

-    logger = logging.get_logger(__name__)
-
-    WELCOME = (
-        "-" * 58
-        + "\n"
-        + f"| Welcome to LLaMA Factory, version {VERSION}"
-        + " " * (21 - len(VERSION))
-        + "|\n|"
-        + " " * 56
-        + "|\n"
-        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
-        + "-" * 58
-    )
-
-    COMMAND_MAP = {
-        "api": launcher.run_api,
-        "chat": launcher.run_chat,
-        "env": print_env,
-        "eval": launcher.run_eval,
-        "export": launcher.export_model,
-        "train": launcher.run_exp,
-        "webchat": launcher.run_web_demo,
-        "webui": launcher.run_web_ui,
-        "version": partial(print, WELCOME),
-        "help": partial(print, USAGE),
-    }
-
-    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
-    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
-        # launch distributed training
-        nnodes = os.getenv("NNODES", "1")
-        node_rank = os.getenv("NODE_RANK", "0")
-        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
-        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
-        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
-        if int(nnodes) > 1:
-            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
-
-        # elastic launch support
-        max_restarts = os.getenv("MAX_RESTARTS", "0")
-        rdzv_id = os.getenv("RDZV_ID")
-        min_nnodes = os.getenv("MIN_NNODES")
-        max_nnodes = os.getenv("MAX_NNODES")
-
-        env = deepcopy(os.environ)
-        if is_env_enabled("OPTIM_TORCH", "1"):
-            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
-            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-        if rdzv_id is not None:
-            # launch elastic job with fault tolerant support when possible
-            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
-            rdzv_nnodes = nnodes
-            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
-            if min_nnodes is not None and max_nnodes is not None:
-                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
-
-            process = subprocess.run(
-                (
-                    "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
-                    "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
-                    "--max-restarts {max_restarts} {file_name} {args}"
-                )
-                .format(
-                    rdzv_nnodes=rdzv_nnodes,
-                    nproc_per_node=nproc_per_node,
-                    rdzv_id=rdzv_id,
-                    master_addr=master_addr,
-                    master_port=master_port,
-                    max_restarts=max_restarts,
-                    file_name=launcher.__file__,
-                    args=" ".join(sys.argv[1:]),
-                )
-                .split(),
-                env=env,
-                check=True,
-            )
-        else:
-            # NOTE: DO NOT USE shell=True to avoid security risk
-            process = subprocess.run(
-                (
-                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
-                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-                )
-                .format(
-                    nnodes=nnodes,
-                    node_rank=node_rank,
-                    nproc_per_node=nproc_per_node,
-                    master_addr=master_addr,
-                    master_port=master_port,
-                    file_name=launcher.__file__,
-                    args=" ".join(sys.argv[1:]),
-                )
-                .split(),
-                env=env,
-                check=True,
-            )
-
-        sys.exit(process.returncode)
-    elif command in COMMAND_MAP:
-        COMMAND_MAP[command]()
-    else:
-        print(f"Unknown command: {command}.\n{USAGE}")
+    launcher.launch()


if __name__ == "__main__":
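
The slimmed-down cli.py now only selects a launcher: setting USE_V1 routes every command through llamafactory.v1.launcher, otherwise the existing launcher handles it. A sketch of the new entry path (assumes the v1 launcher exposes the same launch() hook that cli.py calls; the config file argument is illustrative):

# Sketch of the new dispatch (assumptions noted above).
import os
import sys

os.environ["USE_V1"] = "1"  # opt in to the v1 code path
sys.argv = ["llamafactory-cli", "train", "data/v1_sft_demo.yaml"]

from llamafactory.cli import main

main()  # resolves to llamafactory.v1.launcher.launch()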
src/llamafactory/extras/env.py

@@ -16,6 +16,9 @@
# limitations under the License.


+from collections import OrderedDict
+
+
VERSION = "0.9.4.dev0"


@@ -28,20 +31,20 @@ def print_env() -> None:
    import peft
    import torch
    import transformers
-    import trl
    from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-    info = {
-        "`llamafactory` version": VERSION,
-        "Platform": platform.platform(),
-        "Python version": platform.python_version(),
-        "PyTorch version": torch.__version__,
-        "Transformers version": transformers.__version__,
-        "Datasets version": datasets.__version__,
-        "Accelerate version": accelerate.__version__,
-        "PEFT version": peft.__version__,
-        "TRL version": trl.__version__,
-    }
+    info = OrderedDict(
+        {
+            "`llamafactory` version": VERSION,
+            "Platform": platform.platform(),
+            "Python version": platform.python_version(),
+            "PyTorch version": torch.__version__,
+            "Transformers version": transformers.__version__,
+            "Datasets version": datasets.__version__,
+            "Accelerate version": accelerate.__version__,
+            "PEFT version": peft.__version__,
+        }
+    )

    if is_torch_cuda_available():
        info["PyTorch version"] += " (GPU)"
@@ -54,6 +57,13 @@ def print_env() -> None:
        info["NPU type"] = torch.npu.get_device_name()
        info["CANN version"] = torch.version.cann

+    try:
+        import trl  # type: ignore
+
+        info["TRL version"] = trl.__version__
+    except Exception:
+        pass
+
    try:
        import deepspeed  # type: ignore
src/llamafactory/launcher.py

@@ -12,46 +12,169 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-def run_api():
-    from llamafactory.api.app import run_api as _run_api
-
-    _run_api()
+import os
+import subprocess
+import sys
+from copy import deepcopy


-def run_chat():
-    from llamafactory.chat.chat_model import run_chat as _run_chat
-
-    return _run_chat()
+USAGE = (
+    "-" * 70
+    + "\n"
+    + "| Usage:                                                             |\n"
+    + "|   llamafactory-cli api -h: launch an OpenAI-style API server       |\n"
+    + "|   llamafactory-cli chat -h: launch a chat interface in CLI         |\n"
+    + "|   llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+    + "|   llamafactory-cli train -h: train models                          |\n"
+    + "|   llamafactory-cli webchat -h: launch a chat interface in Web UI   |\n"
+    + "|   llamafactory-cli webui: launch LlamaBoard                        |\n"
+    + "|   llamafactory-cli env: show environment info                      |\n"
+    + "|   llamafactory-cli version: show version info                      |\n"
+    + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`.      |\n"
+    + "-" * 70
+)


-def run_eval():
-    raise NotImplementedError("Evaluation will be deprecated in the future.")
+def launch():
+    from .extras import logging
+    from .extras.env import VERSION, print_env
+    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
+
+    logger = logging.get_logger(__name__)
+    WELCOME = (
+        "-" * 58
+        + "\n"
+        + f"| Welcome to LLaMA Factory, version {VERSION}"
+        + " " * (21 - len(VERSION))
+        + "|\n|"
+        + " " * 56
+        + "|\n"
+        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+        + "-" * 58
+    )

-def export_model():
-    from llamafactory.train.tuner import export_model as _export_model
-
-    return _export_model()
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
+    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+        # launch distributed training
+        nnodes = os.getenv("NNODES", "1")
+        node_rank = os.getenv("NODE_RANK", "0")
+        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
+        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
+        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
+        if int(nnodes) > 1:
+            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
+
+        # elastic launch support
+        max_restarts = os.getenv("MAX_RESTARTS", "0")
+        rdzv_id = os.getenv("RDZV_ID")
+        min_nnodes = os.getenv("MIN_NNODES")
+        max_nnodes = os.getenv("MAX_NNODES")
+
+        env = deepcopy(os.environ)
+        if is_env_enabled("OPTIM_TORCH", "1"):
+            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
+            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

-def run_exp():
-    from llamafactory.train.tuner import run_exp as _run_exp
-
-    return _run_exp()  # use absolute import
+        if rdzv_id is not None:
+            # launch elastic job with fault tolerant support when possible
+            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
+            rdzv_nnodes = nnodes
+            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
+            if min_nnodes is not None and max_nnodes is not None:
+                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
+
+            process = subprocess.run(
+                (
+                    "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
+                    "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
+                    "--max-restarts {max_restarts} {file_name} {args}"
+                )
+                .format(
+                    rdzv_nnodes=rdzv_nnodes,
+                    nproc_per_node=nproc_per_node,
+                    rdzv_id=rdzv_id,
+                    master_addr=master_addr,
+                    master_port=master_port,
+                    max_restarts=max_restarts,
+                    file_name=__file__,
+                    args=" ".join(sys.argv[1:]),
+                )
+                .split(),
+                env=env,
+                check=True,
+            )
+        else:
+            # NOTE: DO NOT USE shell=True to avoid security risk
+            process = subprocess.run(
+                (
+                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
+                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
+                )
+                .format(
+                    nnodes=nnodes,
+                    node_rank=node_rank,
+                    nproc_per_node=nproc_per_node,
+                    master_addr=master_addr,
+                    master_port=master_port,
+                    file_name=__file__,
+                    args=" ".join(sys.argv[1:]),
+                )
+                .split(),
+                env=env,
+                check=True,
+            )
+
+        sys.exit(process.returncode)

-def run_web_demo():
-    from llamafactory.webui.interface import run_web_demo as _run_web_demo
-
-    return _run_web_demo()
+    elif command == "api":
+        from .api.app import run_api
+
+        run_api()
+
+    elif command == "chat":
+        from .chat.chat_model import run_chat
+
+        run_chat()

-def run_web_ui():
-    from llamafactory.webui.interface import run_web_ui as _run_web_ui
-
-    return _run_web_ui()
+    elif command == "eval":
+        raise NotImplementedError("Evaluation will be deprecated in the future.")
+
+    elif command == "export":
+        from .train.tuner import export_model
+
+        export_model()
+
+    elif command == "train":
+        from .train.tuner import run_exp
+
+        run_exp()
+
+    elif command == "webchat":
+        from .webui.interface import run_web_demo
+
+        run_web_demo()
+
+    elif command == "webui":
+        from .webui.interface import run_web_ui
+
+        run_web_ui()
+
+    elif command == "env":
+        print_env()
+
+    elif command == "version":
+        print(WELCOME)
+
+    elif command == "help":
+        print(USAGE)
+
+    else:
+        print(f"Unknown command: {command}.\n{USAGE}")


if __name__ == "__main__":
    from llamafactory.train.tuner import run_exp  # use absolute import

    run_exp()
src/llamafactory/v1/config/data_args.py (new file, 33 lines)

@@ -0,0 +1,33 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataArguments:
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the dataset."},
    )
    dataset_dir: str = field(
        default="data",
        metadata={"help": "Path to the folder containing the datasets."},
    )
    cutoff_len: int = field(
        default=2048,
        metadata={"help": "Cutoff length for the dataset."},
    )
src/llamafactory/v1/config/model_args.py (new file, 27 lines)

@@ -0,0 +1,27 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field


@dataclass
class ModelArguments:
    model: str = field(
        metadata={"help": "Path to the model or model identifier from Hugging Face."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Trust remote code from Hugging Face."},
    )
src/llamafactory/v1/config/parser.py (new file, 63 lines)

@@ -0,0 +1,63 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import sys
from pathlib import Path
from typing import Any, Optional, Union

from omegaconf import OmegaConf
from transformers import HfArgumentParser

from ...extras.misc import is_env_enabled
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments


def get_args(
    args: Optional[Union[dict[str, Any], list[str]]] = None,
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
    """Parse arguments from command line or config file."""
    parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")

    if args is None:
        if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
            override_config = OmegaConf.from_cli(sys.argv[2:])
            dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
        elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
            override_config = OmegaConf.from_cli(sys.argv[2:])
            dict_config = OmegaConf.create(json.load(Path(sys.argv[1]).absolute()))
            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
        else:  # list of strings
            args = sys.argv[1:]

    if isinstance(args, dict):
        (*parsed_args,) = parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
    else:
        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args, return_remaining_strings=True)
        if unknown_args and not allow_extra_keys:
            print(parser.format_help())
            print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

    return tuple(parsed_args)


if __name__ == "__main__":
    print(get_args())
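
A hypothetical invocation of the new v1 parser; the returned order follows the dataclass list above, and the argument names come from the new config files (the model id and paths are invented for illustration):

# Hypothetical invocation (example values, not from the diff).
from llamafactory.v1.config.parser import get_args

data_args, model_args, training_args, sample_args = get_args(
    {"model": "Qwen/Qwen2-0.5B-Instruct", "dataset": "identity", "output_dir": "saves/v1_test"}
)
print(model_args.model, data_args.dataset, training_args.learning_rate, sample_args.max_new_tokens)
# Qwen/Qwen2-0.5B-Instruct identity 0.0001 128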
src/llamafactory/v1/config/sample_args.py (new file, 24 lines)

@@ -0,0 +1,24 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field


@dataclass
class SampleArguments:
    max_new_tokens: int = field(
        default=128,
        metadata={"help": "Maximum number of new tokens to generate."},
    )
src/llamafactory/v1/config/training_args.py (new file, 40 lines)

@@ -0,0 +1,40 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field


@dataclass
class TrainingArguments:
    output_dir: str = field(
        default="",
        metadata={"help": "Path to the output directory."},
    )
    micro_batch_size: int = field(
        default=1,
        metadata={"help": "Micro batch size for training."},
    )
    global_batch_size: int = field(
        default=1,
        metadata={"help": "Global batch size for training."},
    )
    learning_rate: float = field(
        default=1e-4,
        metadata={"help": "Learning rate for training."},
    )
    bf16: bool = field(
        default=False,
        metadata={"help": "Use bf16 for training."},
    )
base trainer (new file, 35 lines; v1 module path not captured in this mirror)

@@ -0,0 +1,35 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..config.training_args import TrainingArguments
from ..extras.types import DataLoader, Model, Processor


class BaseTrainer:
    def __init__(
        self,
        args: TrainingArguments,
        model: Model,
        processor: Processor,
        data_loader: DataLoader,
    ) -> None:
        self.args = args
        self.model = model
        self.processor = processor
        self.data_loader = data_loader
        self.optimizer = None
        self.lr_scheduler = None

    def fit(self) -> None:
        pass
 | 
			
		||||
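A rough sketch (not part of the commit; assumes the module path above and that fit() is the intended override point) of how a concrete trainer could build on BaseTrainer:

    # Hypothetical subclass: overrides fit() to walk the injected data
    # loader; a real trainer would also build the optimizer and compute loss.
    from llamafactory.v1.core.base_trainer import BaseTrainer


    class DryRunTrainer(BaseTrainer):
        def fit(self) -> None:
            for step, batch in enumerate(self.data_loader):
                print(f"step {step}: got batch of {len(batch)}")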
@ -0,0 +1,20 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..config.sample_args import SampleArguments


class ChatSampler:
    def __init__(self, sample_args: SampleArguments) -> None:
        self.args = sample_args

75	src/llamafactory/v1/core/data_engine.py	Normal file

@ -0,0 +1,75 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from ..config.data_args import DataArguments
from ..extras.types import DataLoader, Dataset, Processor


class DataCollator:
    def __init__(self, processor: Processor) -> None:
        self.processor = processor


class DatasetPathMixin:
    args: DataArguments

    def _abspath(self, path: str) -> str:
        return os.path.abspath(os.path.expanduser(os.path.join(self.args.dataset_dir, path)))

    def _exists(self, path: str) -> bool:
        return os.path.exists(self._abspath(path))

    def _isfile(self, path: str) -> bool:
        return os.path.isfile(self._abspath(path))


class DataEngine(DatasetPathMixin):
    def __init__(self, data_args: DataArguments) -> None:
        self.args = data_args
        self.datasets: dict[str, Dataset] = {}
        dataset_info = self.get_dataset_info()
        self.load_dataset(dataset_info)

    def get_dataset_info(self) -> dict:
        """Get dataset info from dataset path.

        Returns:
            dict: Dataset info.
        """
        if self.args.dataset.endswith(".yaml") and self._isfile(self.args.dataset):  # local file
            return OmegaConf.load(self._abspath(self.args.dataset))
        elif self.args.dataset.endswith(".yaml"):  # hf hub uri
            repo_id, filename = os.path.split(self.args.dataset)
            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
            return OmegaConf.load(filepath)
        elif self._exists(self.args.dataset):  # local file(s)
            return {"default": {"file_name": self.args.dataset}}
        else:  # hf hub dataset
            return {"default": {"hf_hub_url": self.args.dataset}}

    def load_dataset(self, dataset_info: dict) -> None:
        for key, value in dataset_info.items():
            if "hf_hub_url" in value:
                dataset_info[key] = load_dataset(value["hf_hub_url"])
            elif "file_name" in value:
                dataset_info[key] = load_dataset(value["file_name"])

    def get_data_loader(self, processor: Processor) -> DataLoader:
        pass
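The four-way branch in get_dataset_info turns the single data_args.dataset string into a dataset-info mapping. A sketch of what each input form would resolve to (all paths and repo names hypothetical):

    # Hypothetical inputs and the branch each would take in get_dataset_info:
    #   "infos/dataset_info.yaml"     -> local YAML file: OmegaConf.load(abspath)
    #   "org/repo/dataset_info.yaml"  -> hf_hub_download(repo_id="org/repo",
    #                                    filename="dataset_info.yaml", repo_type="dataset")
    #   "alpaca.json"                 -> {"default": {"file_name": "alpaca.json"}}
    #   "tatsu-lab/alpaca"            -> {"default": {"hf_hub_url": "tatsu-lab/alpaca"}}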
@ -0,0 +1,27 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..config.model_args import ModelArguments
from ..extras.types import Model, Processor


class ModelEngine:
    def __init__(self, model_args: ModelArguments) -> None:
        self.args = model_args

    def get_model(self) -> Model:
        pass

    def get_processor(self) -> Processor:
        pass

32	src/llamafactory/v1/extras/types.py	Normal file

@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Union


if TYPE_CHECKING:
    from datasets import Dataset as HFDataset
    from datasets import IterableDataset
    from torch.utils.data import DataLoader as TorchDataLoader
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

    Dataset = Union[HFDataset, IterableDataset]
    DataLoader = TorchDataLoader
    Model = PreTrainedModel
    Processor = Union[PreTrainedTokenizer, ProcessorMixin]
else:
    Dataset = None
    DataLoader = None
    Model = None
    Processor = None
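Because these aliases are None at runtime, they are meant for annotations only; a consumer can keep them lazy with postponed evaluation. A hypothetical example:

    # Hypothetical consumer; `from __future__ import annotations` keeps the
    # annotation as a string, so the runtime value of Model (None) is never used.
    from __future__ import annotations

    from llamafactory.v1.extras.types import Model


    def count_parameters(model: Model) -> int:
        return sum(p.numel() for p in model.parameters())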
@ -12,22 +12,55 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import sys
+
-def run_train():
-    raise NotImplementedError("Please use `llamafactory-cli sft` or `llamafactory-cli rm`.")
+from ..extras.env import VERSION, print_env
 
 
-def run_chat():
-    from llamafactory.v1.core.chat_sampler import Sampler
-
-    Sampler().cli()
+USAGE = (
+    "-" * 70
+    + "\n"
+    + "| Usage:                                                             |\n"
+    + "|   llamafactory-cli sft -h: train models                            |\n"
+    + "|   llamafactory-cli version: show version info                      |\n"
+    + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`.      |\n"
+    + "-" * 70
+)
 
 
-def run_sft():
-    from llamafactory.v1.train.sft import SFTTrainer
+WELCOME = (
+    "-" * 58
+    + "\n"
+    + f"| Welcome to LLaMA Factory, version {VERSION}"
+    + " " * (21 - len(VERSION))
+    + "|\n|"
+    + " " * 56
+    + "|\n"
+    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+    + "-" * 58
+)
 
-    SFTTrainer().run()
 
+def launch():
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
+
+    if command == "sft":
+        from .trainers.sft_trainer import run_sft
+
+        run_sft()
+
+    elif command == "env":
+        print_env()
+
+    elif command == "version":
+        print(WELCOME)
+
+    elif command == "help":
+        print(USAGE)
+
+    else:
+        print(f"Unknown command: {command}.\n{USAGE}")
 
 
 if __name__ == "__main__":
-    run_train()
+    pass
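The new launch() pops the first CLI argument and dispatches on it. A hypothetical in-process invocation (the launcher module path is assumed; the entry-point name comes from the usage banner):

    # Hypothetical sketch: simulate `llamafactory-cli version` in-process.
    import sys

    from llamafactory.v1.launcher import launch  # module path assumed

    sys.argv = ["llamafactory-cli", "version"]
    launch()  # prints the WELCOME banner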
							
								
								
									
0	src/llamafactory/v1/plugins/data_plugins/filter.py	Normal file

26	src/llamafactory/v1/plugins/data_plugins/template.py	Normal file

@ -0,0 +1,26 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass


@dataclass
class Template:
    user_template: str
    assistant_template: str
    system_template: str

    def render_message(self, message: "dict[str, str]") -> str:
        return self.user_template.format(**message)
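A minimal sketch of render_message (template strings invented for illustration); note that, as written, it formats every message with user_template regardless of role:

    # Hypothetical usage; import path mirrors the file added above.
    from llamafactory.v1.plugins.data_plugins.template import Template

    chatml = Template(
        user_template="<|im_start|>user\n{content}<|im_end|>\n",
        assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n",
        system_template="<|im_start|>system\n{content}<|im_end|>\n",
    )
    print(chatml.render_message({"content": "Hello!"}))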
@ -0,0 +1,34 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ..config.parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine


class SFTTrainer(BaseTrainer):
    pass


def run_sft():
    model_args, data_args, training_args, _ = get_args()
    model_engine = ModelEngine(model_args)
    data_engine = DataEngine(data_args)
    model = model_engine.get_model()
    processor = model_engine.get_processor()
    data_loader = data_engine.get_data_loader(processor)
    trainer = SFTTrainer(training_args, model, processor, data_loader)
    trainer.fit()
@ -36,8 +36,8 @@ from ..extras.misc import use_modelscope, use_openmind
 
 logger = logging.get_logger(__name__)
 
-DEFAULT_CACHE_DIR = "cache"
-DEFAULT_CONFIG_DIR = "config"
+DEFAULT_CACHE_DIR = "llamaboard_cache"
+DEFAULT_CONFIG_DIR = "llamaboard_config"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user_config.yaml"