Merge branch 'hiyouga:main' into minicpmv

Former-commit-id: ab87bd6b1398b379b1a7a95f01a6539743b9db2d
This commit is contained in:
Zhangchi Feng 2025-01-04 11:20:33 +08:00 committed by GitHub
commit a0188a430f
16 changed files with 155 additions and 54 deletions

View File

@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![Citation](https://img.shields.io/badge/citation-210-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@ -195,6 +195,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
@ -753,6 +754,7 @@ If you have a project that should be incorporated, please contact via email or c
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way to build multi-agent LLM applications, with support for model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
</details>
@ -760,7 +762,7 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation

View File

@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![Citation](https://img.shields.io/badge/citation-210-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@ -196,6 +196,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
@ -754,6 +755,7 @@ swanlab_run_name: test_run # optional
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: An SDK for fine-tuning large language models on Windows PCs with NVIDIA RTX devices.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: A low-code development tool for building multi-agent LLM applications, with support for model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full-pipeline codebase for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified codebase that supports long-sequence SFT and DPO training via ring attention.
</details>
@ -761,7 +763,7 @@ swanlab_run_name: test_run # optional
The code in this repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation

View File

@ -17,16 +17,28 @@ ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG INSTALL_EETQ=false
ARG PIP_INDEX=https://pypi.org/simple
ARG HTTP_PROXY=
# Set the working directory
WORKDIR /app
# Set http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
echo "Configuring proxy..."; \
export http_proxy=$HTTP_PROXY; \
export https_proxy=$HTTP_PROXY; \
fi
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
pip config set global.extra-index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
if [ -n "$HTTP_PROXY" ]; then \
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
else \
python -m pip install -r requirements.txt; \
fi
# Copy the rest of the application into the image
COPY . /app
@ -51,13 +63,30 @@ RUN EXTRA_PACKAGES="metrics"; \
if [ "$INSTALL_EETQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
else \
pip install -e ".[$EXTRA_PACKAGES]"; \
fi
# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && pip install ninja && \
pip uninstall -y ninja && \
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY ninja && \
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
else \
pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi; \
fi
# Unset http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
unset http_proxy; \
unset https_proxy; \
fi
# Set up volumes

View File

@ -12,16 +12,28 @@ ENV DEBIAN_FRONTEND=noninteractive
ARG INSTALL_DEEPSPEED=false
ARG PIP_INDEX=https://pypi.org/simple
ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu
ARG HTTP_PROXY=
# Set the working directory
WORKDIR /app
# Set http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
echo "Configuring proxy..."; \
export http_proxy=$HTTP_PROXY; \
export https_proxy=$HTTP_PROXY; \
fi
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
pip config set global.extra-index-url "$TORCH_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
if [ -n "$HTTP_PROXY" ]; then \
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
else \
python -m pip install -r requirements.txt; \
fi
# Copy the rest of the application into the image
COPY . /app
@ -31,7 +43,17 @@ RUN EXTRA_PACKAGES="torch-npu,metrics"; \
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
else \
pip install -e ".[$EXTRA_PACKAGES]"; \
fi
# Unset http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
unset http_proxy; \
unset https_proxy; \
fi
# Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]

View File

@ -13,16 +13,28 @@ ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG PIP_INDEX=https://pypi.org/simple
ARG HTTP_PROXY=
# Set the working directory
WORKDIR /app
# Set http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
echo "Configuring proxy..."; \
export http_proxy=$HTTP_PROXY; \
export https_proxy=$HTTP_PROXY; \
fi
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
pip config set global.extra-index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
if [ -n "$HTTP_PROXY" ]; then \
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
else \
python -m pip install -r requirements.txt; \
fi
# Copy the rest of the application into the image
COPY . /app
@ -44,13 +56,29 @@ RUN EXTRA_PACKAGES="metrics"; \
if [ "$INSTALL_HQQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
else \
pip install -e ".[$EXTRA_PACKAGES]"; \
fi
# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && pip install ninja && \
pip uninstall -y ninja && \
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY ninja && \
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
else \
pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi; \
fi
# Unset http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
unset http_proxy; \
unset https_proxy; \
fi
# Set up volumes

View File

@ -24,7 +24,7 @@ import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@ -35,7 +35,7 @@ from transformers.modeling_utils import (
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from transformers import PretrainedConfig
def change_name(name: str, old_index: int, new_index: int) -> str:
@ -61,17 +61,18 @@ def block_expansion(
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.save_pretrained(output_dir)
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) # load the original one
config = AutoConfig.from_pretrained(model_name_or_path) # load the original one
if save_safetensors:
setattr(config, "tie_word_embeddings", False) # safetensors does not allow shared weights
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
config=config,
torch_dtype="auto",
trust_remote_code=True,
low_cpu_mem_usage=True,
)
assert isinstance(model, PreTrainedModel) # type hint
state_dict = model.state_dict()
if num_layers % num_expand != 0:
@ -85,7 +86,7 @@ def block_expansion(
if f".{i:d}." in key:
output_state_dict[change_name(key, i, layer_cnt)] = value
print(f"Add layer {layer_cnt} copied from layer {i}")
print(f"Add layer {layer_cnt} copied from layer {i}.")
layer_cnt += 1
if (i + 1) % split == 0:
for key, value in state_dict.items():
@ -95,7 +96,7 @@ def block_expansion(
else:
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
print(f"Add layer {layer_cnt} expanded from layer {i}")
print(f"Add layer {layer_cnt} expanded from layer {i}.")
layer_cnt += 1
for key, value in state_dict.items():
@ -112,12 +113,13 @@ def block_expansion(
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print(f"Model weights saved in {output_dir}")
print(f"Model weights saved in {output_dir}.")
print("- Fine-tune this model with:")
print(f"model_name_or_path: {output_dir}")

View File

@ -41,7 +41,7 @@ def calculate_lr(
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training
cutoff_len: int = 2048, # i.e. maximum input length during training
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
@ -59,6 +59,7 @@ def calculate_lr(
template=template,
cutoff_len=cutoff_len,
packing=packing,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
@ -79,7 +80,7 @@ def calculate_lr(
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):
for batch in tqdm(dataloader, desc="Collecting valid tokens"):
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"])
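The loop above estimates the fraction of supervised (non-masked) tokens per batch. A minimal standalone sketch of the same counting logic, assuming the collator marks prompt tokens with `IGNORE_INDEX = -100` as elsewhere in the codebase:

```python
import torch

IGNORE_INDEX = -100  # label value assigned to tokens excluded from the loss (assumption)

labels = torch.tensor([[-100, -100, 5, 6], [-100, 7, 8, 9]])  # toy batch of label ids
valid_tokens = torch.sum(labels != IGNORE_INDEX).item()       # 5 supervised tokens
total_tokens = torch.numel(labels)                            # 8 tokens overall
print(valid_tokens / total_tokens)                            # 0.625, the valid-token ratio fed into the LR estimate
```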

View File

@ -63,7 +63,7 @@ def calculate_ppl(
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024,
cutoff_len: int = 2048,
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
@ -82,6 +82,7 @@ def calculate_ppl(
cutoff_len=cutoff_len,
max_samples=max_samples,
train_on_prompt=train_on_prompt,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
@ -111,7 +112,7 @@ def calculate_ppl(
perplexities = []
batch: Dict[str, "torch.Tensor"]
with torch.no_grad():
for batch in tqdm(dataloader):
for batch in tqdm(dataloader, desc="Computing perplexities"):
batch = batch.to(model.device)
outputs = model(**batch)
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]

View File

@ -42,6 +42,7 @@ def length_cdf(
dataset_dir=dataset_dir,
template=template,
cutoff_len=1_000_000,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
@ -52,7 +53,7 @@ def length_cdf(
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
total_num = len(trainset)
length_dict = defaultdict(int)
for sample in tqdm(trainset["input_ids"]):
for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):
length_dict[len(sample) // interval * interval] += 1
length_tuples = list(length_dict.items())
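The `len(sample) // interval * interval` expression above floors each sequence length to the start of its bucket before the CDF is accumulated; a tiny illustration with made-up lengths:

```python
from collections import defaultdict

interval = 1000
lengths = [120, 980, 1024, 2300, 2999]            # hypothetical tokenized sample lengths
length_dict = defaultdict(int)
for n in lengths:
    length_dict[n // interval * interval] += 1    # bucket each length by its interval start
print(dict(length_dict))                          # {0: 2, 1000: 1, 2000: 2}
```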

View File

@ -64,6 +64,7 @@ def vllm_infer(
template=template,
cutoff_len=cutoff_len,
max_samples=max_samples,
preprocessing_num_workers=16,
vllm_config=vllm_config,
temperature=temperature,
top_p=top_p,

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
@ -47,7 +47,7 @@ def get_preprocess_and_print_func(
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask

View File

@ -52,3 +52,8 @@ def preprocess_pretrain_dataset(
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset(
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))

View File

@ -618,6 +618,28 @@ register_model_group(
)
register_model_group(
models={
"GPT-2-Small": {
DownloadSource.DEFAULT: "openai-community/gpt2",
DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2",
},
"GPT-2-Medium": {
DownloadSource.DEFAULT: "openai-community/gpt2-medium",
DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2-medium",
},
"GPT-2-Large": {
DownloadSource.DEFAULT: "openai-community/gpt2-large",
DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2-large",
},
"GPT-2-XL": {
DownloadSource.DEFAULT: "openai-community/gpt2-xl",
DownloadSource.MODELSCOPE: "goodbai95/GPT2-xl",
},
},
)
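As a quick sanity check (not part of the commit), any of the newly registered GPT-2 checkpoints can be pulled directly from the Hugging Face Hub; GPT-2 small, for instance, has 12 layers and a 768-dimensional hidden size:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
print(model.config.n_layer, model.config.n_embd)  # 12 768 for the ~0.1B "GPT-2-Small" entry
```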
register_model_group(
models={
"Granite-3.0-1B-A400M-Base": {

View File

@ -111,40 +111,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
inputs: Dict[str, Union["torch.Tensor", Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
**gen_kwargs,
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Removes the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
"""
labels = inputs["labels"] if "labels" in inputs else None
if self.args.predict_with_generate:
assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
labels = labels.detach().clone() if labels is not None else None # backup labels
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
inputs["labels"] = inputs["labels"][:, :prompt_len]
if self.args.predict_with_generate: # do not pass labels to model when generate
labels = inputs.pop("labels", None)
else:
labels = inputs.get("labels")
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
)
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
generated_tokens = generated_tokens.contiguous()
return loss, generated_tokens, labels
def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
r"""
Pads the tensor to the same length as the target tensor.
"""
assert self.processing_class.pad_token_id is not None, "Pad token is required."
padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory
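The rewritten `prediction_step` above masks the prompt portion of the generated ids with the pad token instead of padding the labels; a minimal sketch of that masking with made-up, left-padded tensors (pad id chosen arbitrarily for illustration):

```python
import torch

pad_token_id = 0                                           # assumed pad id for this sketch
input_ids = torch.tensor([[0, 0, 11, 12],
                          [0, 13, 14, 15]])                # left-padded prompts
generated = torch.tensor([[0, 0, 11, 12, 21, 22],
                          [0, 13, 14, 15, 23, 24]])        # prompts followed by newly generated tokens
generated[:, : input_ids.size(-1)] = pad_token_id          # blank out the prompt part before decoding
print(generated.contiguous())
# tensor([[ 0,  0,  0,  0, 21, 22],
#         [ 0,  0,  0,  0, 23, 24]])
```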
def save_predictions(
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
) -> None:

View File

@ -117,8 +117,6 @@ def run_sft(
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
@ -126,8 +124,6 @@ def run_sft(
if training_args.do_predict:
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)