mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-16 00:28:10 +08:00

Merge branch 'hiyouga:main' into main

Commit: f94b54b776
Former-commit-id: 014acaa7845b7ac2876596d216b1be369a8e9311

README.md (25 changed lines)
@@ -5,7 +5,7 @@
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
 [](#projects-using-llama-factory)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://twitter.com/llamafactory_ai)
@@ -310,13 +310,19 @@ huggingface-cli login
 
 ### Installation
 
+> [!IMPORTANT]
+> Installation is mandatory.
+
 ```bash
 git clone https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e .[metrics]
+pip install -e .[torch,metrics]
 ```
 
-Extra dependencies available: metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
+Extra dependencies available: torch, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
 
+> [!TIP]
+> Use `pip install --no-deps -e .` to resolve package conflicts.
+
 <details><summary>For Windows users</summary>
 
@@ -363,12 +369,18 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
 CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
 ```
 
-<details><summary>For Alibaba Cloud users</summary>
+<details><summary>For Alibaba Cloud PAI or AutoDL users</summary>
 
-If you encountered display problems in LLaMA Board on Alibaba Cloud, try using the following command to set environment variables before starting LLaMA Board:
+If you encountered display problems in LLaMA Board on Alibaba Cloud PAI, try using the following command to set environment variables before starting LLaMA Board:
 
 ```bash
-export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
+export GRADIO_SERVER_PORT=7860 GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
+```
+
+If you are using AutoDL, please install a specific version of Gradio:
+
+```bash
+pip install gradio==4.10.0
 ```
 
 </details>
@@ -467,6 +479,7 @@ If you have a project that should be incorporated, please contact via email or c
 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
 
 </details>
 
README_zh.md (25 changed lines)

@@ -5,7 +5,7 @@
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
 [](#使用了-llama-factory-的项目)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://twitter.com/llamafactory_ai)
@@ -310,13 +310,19 @@ huggingface-cli login
 
 ### 安装 LLaMA Factory
 
+> [!IMPORTANT]
+> 此步骤为必需。
+
 ```bash
 git clone https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e .[metrics]
+pip install -e .[torch,metrics]
 ```
 
-可选的额外依赖项:metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
+可选的额外依赖项:torch、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
 
+> [!TIP]
+> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
+
 <details><summary>Windows 用户指南</summary>
 
@@ -363,12 +369,18 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s
 CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
 ```
 
-<details><summary>阿里云用户指南</summary>
+<details><summary>阿里云 PAI 和 AutoDL 用户指南</summary>
 
-如果您在阿里云上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
+如果您在阿里云 PAI 上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
 
 ```bash
-export GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
+export GRADIO_SERVER_PORT=7860 GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
+```
+
+如果您正在使用 AutoDL,请安装下述 Gradio 版本:
+
+```bash
+pip install gradio==4.10.0
 ```
 
 </details>
@@ -467,6 +479,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
 
 </details>
 

@@ -1,4 +1,3 @@
-torch>=1.13.1
 transformers>=4.37.2
 datasets>=2.14.3
 accelerate>=0.27.2

setup.py (1 changed line)

@@ -20,6 +20,7 @@ def get_requires():
 
 
 extra_require = {
+    "torch": ["torch>=1.13.1"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
     "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
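
Note on the two hunks above: PyTorch moves from a hard requirement into an optional `torch` extra. A minimal packaging sketch of that split follows; only the `extra_require` entries mirror the diff, while the `setup()` wiring, project name, and trimmed requirement list are assumptions for illustration.

```python
# Packaging sketch (not the repository's actual setup.py): core requirements no longer
# include torch, and each extra_require key becomes an opt-in extra, so `pip install -e .`
# skips PyTorch while `pip install -e .[torch,metrics]` pulls it back in.
from setuptools import find_packages, setup

extra_require = {
    "torch": ["torch>=1.13.1"],
    "metrics": ["nltk", "jieba", "rouge-chinese"],
}

setup(
    name="llmtuner-extras-sketch",              # placeholder project name
    version="0.0.0",
    packages=find_packages("src"),
    package_dir={"": "src"},
    install_requires=["transformers>=4.37.2"],  # torch intentionally absent
    extras_require=extra_require,
)
```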

@@ -1,6 +1,8 @@
 import os
 from contextlib import asynccontextmanager
-from typing import Annotated, Optional
+from typing import Optional
 
+from typing_extensions import Annotated
+
 from ..chat import ChatModel
 from ..extras.misc import torch_gc
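
The import shuffle above moves `Annotated` out of `typing`; `typing.Annotated` only exists on Python 3.9+, so importing it from `typing_extensions` keeps the module importable on older interpreters while behaving identically on newer ones. A small self-contained sketch of the pattern (the function is illustrative only):

```python
# Compatibility sketch: Annotated comes from typing_extensions so the code also runs on
# Python 3.8; on 3.9+ it is the same construct as typing.Annotated.
from typing import Optional

from typing_extensions import Annotated


def describe(name: Annotated[str, "display name"], note: Optional[str] = None) -> str:
    return name if note is None else f"{name} ({note})"


print(describe("LLaMA-Factory", note="fine-tuning toolkit"))
```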

@@ -11,7 +11,7 @@ from .aligner import align_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
 from .template import get_template_and_fix_tokenizer
-from .utils import checksum, merge_dataset
+from .utils import merge_dataset
 
 
 if TYPE_CHECKING:

@@ -61,8 +61,6 @@ def load_single_dataset(
 
         if data_path is None:
             raise ValueError("File extension must be txt, csv, json or jsonl.")
-
-        checksum(data_files, dataset_attr.file_sha1)
     else:
         raise NotImplementedError
 

@@ -21,7 +21,6 @@ class DatasetAttr:
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
     dataset_name: str
     """ extra configs """
-    file_sha1: Optional[str] = None
     subset: Optional[str] = None
     folder: Optional[str] = None
     ranking: bool = False

@@ -99,7 +98,6 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
         else:
             dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
 
-        dataset_attr.set_attr("file_sha1", dataset_info[name])
         dataset_attr.set_attr("subset", dataset_info[name])
         dataset_attr.set_attr("folder", dataset_info[name])
         dataset_attr.set_attr("ranking", dataset_info[name], default=False)

@@ -1,6 +1,5 @@
-import hashlib
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
 from datasets import concatenate_datasets, interleave_datasets
 

@@ -26,21 +25,6 @@ class Role(str, Enum):
     OBSERVATION = "observation"
 
 
-def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
-    if file_sha1 is None:
-        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
-        return
-
-    if len(data_files) != 1:
-        logger.warning("Checksum failed: too many files.")
-        return
-
-    with open(data_files[0], "rb") as f:
-        sha1 = hashlib.sha1(f.read()).hexdigest()
-        if sha1 != file_sha1:
-            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
-
-
 def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
     max_target_len = int(max_len * (target_len / (source_len + target_len)))
     max_target_len = max(max_target_len, reserved_label_len)
@@ -139,12 +139,14 @@ class LogCallback(TrainerCallback):
         r"""
         Event called after an evaluation phase.
         """
+        if not self.do_train:
             self._close_thread_pool()
 
     def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a successful prediction.
         """
+        if not self.do_train:
             self._close_thread_pool()
 
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
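
The two added guards make the thread-pool shutdown conditional: after an evaluation or prediction it only runs when the callback is not driving a training run, so periodic evaluations inside training keep the logging pool alive. A stripped-down sketch of the pattern follows; the attribute names mirror the diff, everything else is simplified and not the real `LogCallback`.

```python
# Minimal sketch of the guarded shutdown: close the pool only for eval/predict-only runs.
from concurrent.futures import ThreadPoolExecutor

from transformers import TrainerCallback


class LogCallbackSketch(TrainerCallback):
    def __init__(self, do_train: bool = False):
        self.do_train = do_train
        self.thread_pool = ThreadPoolExecutor(max_workers=1)  # used for async log writes

    def _close_thread_pool(self) -> None:
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    def on_evaluate(self, args, state, control, **kwargs):
        if not self.do_train:  # keep the pool alive for in-training evaluations
            self._close_thread_pool()

    def on_predict(self, args, state, control, **kwargs):
        if not self.do_train:
            self._close_thread_pool()
```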

@@ -715,11 +715,11 @@ register_model_group(
     models={
         "Phi3-3.8B-4k-Chat": {
             DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
-            DownloadSource.DEFAULT: "LLM-Research/Phi-3-mini-4k-instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
         },
         "Phi3-3.8B-128k-Chat": {
             DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
-            DownloadSource.DEFAULT: "LLM-Research/Phi-3-mini-128k-instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
         },
     },
     module="qkv_proj",
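
The hunk above fixes a duplicate dictionary key: in a Python dict literal the second occurrence of a key silently overwrites the first, so both hub paths could not coexist under `DownloadSource.DEFAULT`. A tiny sketch (the enum below is a stand-in for the real one in the constants module):

```python
# Duplicate-key sketch: the "broken" literal keeps only the last DEFAULT entry,
# while the fixed one keeps both download sources.
from enum import Enum


class DownloadSource(str, Enum):  # stand-in for the real enum
    DEFAULT = "hf"
    MODELSCOPE = "ms"


broken = {
    DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
    DownloadSource.DEFAULT: "LLM-Research/Phi-3-mini-4k-instruct",  # overwrites the line above
}
fixed = {
    DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
    DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
}
print(len(broken), len(fixed))  # 1 2
```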

@@ -46,6 +46,9 @@ def init_adapter(
         if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             model = model.float()
 
+        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
+            model.vision_tower.requires_grad_(False)
+
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = (

@@ -106,7 +106,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
 
     model = None
     lazy_load = False
@@ -15,8 +15,8 @@ from .utils.longlora import configure_longlora
 from .utils.moe import add_z3_leaf_module, configure_moe
 from .utils.quantization import configure_quantization
 from .utils.rope import configure_rope
-from .utils.valuehead import configure_valuehead, prepare_valuehead_model
-from .utils.visual import autocast_projector_dtype
+from .utils.valuehead import prepare_valuehead_model
+from .utils.visual import autocast_projector_dtype, configure_hidden_size
 
 
 if TYPE_CHECKING:

@@ -40,7 +40,6 @@ def patch_config(
     model_args: "ModelArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
-    add_valuehead: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

@@ -50,9 +49,7 @@ def patch_config(
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
-
-    if add_valuehead:
-        configure_valuehead(config)
+    configure_hidden_size(config)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)

@@ -8,7 +8,7 @@ from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel
+    from transformers import PreTrainedModel
 
     from ...hparams import ModelArguments
 

@@ -16,11 +16,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def configure_valuehead(config: "PretrainedConfig") -> None:
-    if getattr(config, "model_type", None) == "llava":
-        setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
-
-
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.

@@ -6,7 +6,7 @@ from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import PretrainedConfig, PreTrainedModel
 
     from ...hparams import ModelArguments
 

@@ -14,6 +14,11 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+def configure_hidden_size(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "llava":
+        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
+
+
 def autocast_projector_dtype(
     model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
 ) -> None:

@@ -22,7 +27,7 @@ def autocast_projector_dtype(
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
 
-    if hasattr(model, mm_projector_name):
+    if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None):
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
         mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
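
The valuehead/visual hunks replace `configure_valuehead` (which copied the vision tower's `intermediate_size`) with `configure_hidden_size`, which copies the language model's width from `text_config` and now runs for every llava-style config. A hedged sketch of the effect, using `LlavaConfig` defaults purely for illustration:

```python
# Sketch: for llava-type configs, the top-level hidden_size is taken from the language
# model's text_config, so code that reads config.hidden_size (e.g. when attaching a
# value head) sees the LLM width rather than a vision-tower dimension.
from transformers import LlavaConfig

config = LlavaConfig()  # default sub-configs, illustration only
if getattr(config, "model_type", None) == "llava":
    setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

print(config.hidden_size)  # hidden size of the underlying language model
```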

@@ -52,7 +52,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
         raise ValueError("Please merge adapters before quantizing the model.")
 
-    tokenizer = load_tokenizer(model_args)["tokenizer"]
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    processor = tokenizer_module["processor"]
     get_template_and_fix_tokenizer(tokenizer, data_args.template)
     model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab
 

@@ -66,6 +68,8 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
         output_dtype = getattr(model.config, "torch_dtype", torch.float16)
         setattr(model.config, "torch_dtype", output_dtype)
         model = model.to(output_dtype)
+    else:
+        setattr(model.config, "torch_dtype", torch.float16)
 
     model.save_pretrained(
         save_directory=model_args.export_dir,

@@ -86,5 +90,12 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
         tokenizer.save_pretrained(model_args.export_dir)
         if model_args.export_hub_model_id is not None:
             tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
+
+        if model_args.visual_inputs and processor is not None:
+            getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
+            if model_args.export_hub_model_id is not None:
+                getattr(processor, "image_processor").push_to_hub(
+                    model_args.export_hub_model_id, token=model_args.hf_hub_token
+                )
     except Exception:
         logger.warning("Cannot save tokenizer, please copy the files manually.")
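
For orientation on the export hunks: with `visual_inputs` enabled, the processor's `image_processor` is now saved (and optionally pushed) next to the tokenizer. A hedged usage sketch follows; the import path, model and adapter paths, and the exact argument set are assumptions, only the argument names visible in the diff are taken as given.

```python
# Usage sketch (assumptions: export_model lives in llmtuner.train.tuner and the
# placeholder paths below point at a real base model and LoRA adapter).
from llmtuner.train.tuner import export_model

export_model(
    {
        "model_name_or_path": "llava-hf/llava-1.5-7b-hf",  # placeholder base model
        "adapter_name_or_path": "saves/llava_lora",         # placeholder adapter dir
        "template": "vicuna",
        "finetuning_type": "lora",
        "visual_inputs": True,                               # triggers image_processor export
        "export_dir": "exported_llava",
        # "export_hub_model_id": "user/exported-llava",     # optional Hub upload
    }
)
```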

@@ -71,14 +71,12 @@ def create_web_demo() -> gr.Blocks:
 
 
 def run_web_ui() -> None:
-    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
-    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
     gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
-    create_ui().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port)
+    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
+    create_ui().queue().launch(share=gradio_share, server_name=server_name)
 
 
 def run_web_demo() -> None:
-    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
-    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
     gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
-    create_web_demo().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port)
+    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
+    create_web_demo().queue().launch(share=gradio_share, server_name=server_name)

@@ -4,10 +4,9 @@ from llmtuner.webui.interface import create_ui
 
 
 def main():
-    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
-    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
     gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
-    create_ui().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port)
+    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
+    create_ui().queue().launch(share=gradio_share, server_name=server_name)
 
 
 if __name__ == "__main__":
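
A note on the web UI hunks: the explicit `GRADIO_SERVER_PORT` handling is dropped and the port is left to Gradio, which falls back to the `GRADIO_SERVER_PORT` environment variable when `launch()` receives no `server_port`; this matches the PAI instructions added to the README, which now export that variable. A minimal sketch of the new launch flow, with a placeholder `Blocks` standing in for `create_ui()`:

```python
# Launch sketch: only GRADIO_SHARE and GRADIO_SERVER_NAME are read explicitly; the port
# is resolved by Gradio itself from GRADIO_SERVER_PORT if that variable is set.
import os

import gradio as gr


def run_demo() -> None:
    gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    with gr.Blocks() as demo:
        gr.Markdown("LLaMA Board placeholder")  # stand-in for create_ui()
    demo.queue().launch(share=gradio_share, server_name=server_name)


if __name__ == "__main__":
    run_demo()
```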