diff --git a/README.md b/README.md index 4535fd88..1eec5b96 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) -[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564) +[![Citation](https://img.shields.io/badge/citation-210-green)](https://scholar.google.com/scholar?cites=12620864006390196564) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -195,6 +195,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | +| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | | [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | @@ -753,6 +754,7 @@ If you have a project that should be incorporated, please contact via email or c 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX. 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory. 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357) +1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention. @@ -760,7 +762,7 @@ If you have a project that should be incorporated, please contact via email or c This repository is licensed under the [Apache-2.0 License](LICENSE). 
-Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / 
[Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## Citation diff --git a/README_zh.md b/README_zh.md index 998779ea..352a483d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -4,7 +4,7 @@ [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) -[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564) +[![Citation](https://img.shields.io/badge/citation-210-green)](https://scholar.google.com/scholar?cites=12620864006390196564) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -196,6 +196,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | +| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | | [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | @@ -754,6 +755,7 @@ swanlab_run_name: test_run # 可选 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**:在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**:一个低代码构建多 Agent 大模型应用的开发工具,支持基于 LLaMA Factory 的模型微调. 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357) +1. 
**[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**:一个魔改后的代码库,通过 Ring Attention 支持长序列的 SFT 和 DPO 训练。 @@ -761,7 +763,7 @@ swanlab_run_name: test_run # 可选 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 -使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / 
[Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## 引用 diff --git a/docker/docker-cuda/Dockerfile b/docker/docker-cuda/Dockerfile index 34503290..b1914ed5 100644 --- a/docker/docker-cuda/Dockerfile +++ b/docker/docker-cuda/Dockerfile @@ -17,16 +17,28 @@ ARG INSTALL_LIGER_KERNEL=false ARG INSTALL_HQQ=false ARG INSTALL_EETQ=false ARG PIP_INDEX=https://pypi.org/simple +ARG HTTP_PROXY= # Set the working directory WORKDIR /app +# Set http proxy +RUN if [ -n "$HTTP_PROXY" ]; then \ + echo "Configuring proxy..."; \ + export http_proxy=$HTTP_PROXY; \ + export https_proxy=$HTTP_PROXY; \ + fi + # Install the requirements COPY requirements.txt /app RUN pip config set global.index-url "$PIP_INDEX" && \ pip config set global.extra-index-url "$PIP_INDEX" && \ python -m pip install --upgrade pip && \ - python -m pip install -r requirements.txt + if [ -n "$HTTP_PROXY" ]; then \ + python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \ + else \ + python -m pip install -r requirements.txt; \ + fi # Copy the rest of the application into the image COPY . /app @@ -51,13 +63,30 @@ RUN EXTRA_PACKAGES="metrics"; \ if [ "$INSTALL_EETQ" == "true" ]; then \ EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \ fi; \ - pip install -e ".[$EXTRA_PACKAGES]" + if [ -n "$HTTP_PROXY" ]; then \ + pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \ + else \ + pip install -e ".[$EXTRA_PACKAGES]"; \ + fi # Rebuild flash attention RUN pip uninstall -y transformer-engine flash-attn && \ if [ "$INSTALL_FLASHATTN" == "true" ]; then \ - pip uninstall -y ninja && pip install ninja && \ - pip install --no-cache-dir flash-attn --no-build-isolation; \ + pip uninstall -y ninja && \ + if [ -n "$HTTP_PROXY" ]; then \ + pip install --proxy=$HTTP_PROXY ninja && \ + pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \ + else \ + pip install ninja && \ + pip install --no-cache-dir flash-attn --no-build-isolation; \ + fi; \ + fi + + +# Unset http proxy +RUN if [ -n "$HTTP_PROXY" ]; then \ + unset http_proxy; \ + unset https_proxy; \ fi # Set up volumes diff --git a/docker/docker-npu/Dockerfile b/docker/docker-npu/Dockerfile index dc35de47..15d4eee4 100644 --- a/docker/docker-npu/Dockerfile +++ b/docker/docker-npu/Dockerfile @@ -12,16 +12,28 @@ ENV DEBIAN_FRONTEND=noninteractive ARG INSTALL_DEEPSPEED=false ARG PIP_INDEX=https://pypi.org/simple ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu +ARG HTTP_PROXY= # Set the working directory WORKDIR /app +# Set http proxy +RUN if [ -n "$HTTP_PROXY" ]; then \ + echo "Configuring proxy..."; \ + export http_proxy=$HTTP_PROXY; \ + export https_proxy=$HTTP_PROXY; \ + fi + # Install the requirements COPY requirements.txt /app RUN pip config set global.index-url "$PIP_INDEX" && \ pip config set global.extra-index-url "$TORCH_INDEX" && \ python -m pip install --upgrade pip && \ - python -m pip install -r requirements.txt + if [ -n "$HTTP_PROXY" ]; then \ + python 
-m pip install --proxy=$HTTP_PROXY -r requirements.txt; \ + else \ + python -m pip install -r requirements.txt; \ + fi # Copy the rest of the application into the image COPY . /app @@ -31,7 +43,17 @@ RUN EXTRA_PACKAGES="torch-npu,metrics"; \ if [ "$INSTALL_DEEPSPEED" == "true" ]; then \ EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \ fi; \ - pip install -e ".[$EXTRA_PACKAGES]" + if [ -n "$HTTP_PROXY" ]; then \ + pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \ + else \ + pip install -e ".[$EXTRA_PACKAGES]"; \ + fi + +# Unset http proxy +RUN if [ -n "$HTTP_PROXY" ]; then \ + unset http_proxy; \ + unset https_proxy; \ + fi # Set up volumes VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ] diff --git a/docker/docker-rocm/Dockerfile b/docker/docker-rocm/Dockerfile index 62bd78f5..86e96a37 100644 --- a/docker/docker-rocm/Dockerfile +++ b/docker/docker-rocm/Dockerfile @@ -13,16 +13,28 @@ ARG INSTALL_FLASHATTN=false ARG INSTALL_LIGER_KERNEL=false ARG INSTALL_HQQ=false ARG PIP_INDEX=https://pypi.org/simple +ARG HTTP_PROXY= # Set the working directory WORKDIR /app +# Set http proxy +RUN if [ -n "$HTTP_PROXY" ]; then \ + echo "Configuring proxy..."; \ + export http_proxy=$HTTP_PROXY; \ + export https_proxy=$HTTP_PROXY; \ + fi + # Install the requirements COPY requirements.txt /app RUN pip config set global.index-url "$PIP_INDEX" && \ pip config set global.extra-index-url "$PIP_INDEX" && \ python -m pip install --upgrade pip && \ - python -m pip install -r requirements.txt + if [ -n "$HTTP_PROXY" ]; then \ + python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \ + else \ + python -m pip install -r requirements.txt; \ + fi # Copy the rest of the application into the image COPY . /app @@ -44,13 +56,29 @@ RUN EXTRA_PACKAGES="metrics"; \ if [ "$INSTALL_HQQ" == "true" ]; then \ EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \ fi; \ - pip install -e ".[$EXTRA_PACKAGES]" + if [ -n "$HTTP_PROXY" ]; then \ + pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \ + else \ + pip install -e ".[$EXTRA_PACKAGES]"; \ + fi # Rebuild flash attention RUN pip uninstall -y transformer-engine flash-attn && \ if [ "$INSTALL_FLASHATTN" == "true" ]; then \ - pip uninstall -y ninja && pip install ninja && \ - pip install --no-cache-dir flash-attn --no-build-isolation; \ + pip uninstall -y ninja && \ + if [ -n "$HTTP_PROXY" ]; then \ + pip install --proxy=$HTTP_PROXY ninja && \ + pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \ + else \ + pip install ninja && \ + pip install --no-cache-dir flash-attn --no-build-isolation; \ + fi; \ + fi + +# Unset http proxy +RUN if [ -n "$HTTP_PROXY" ]; then \ + unset http_proxy; \ + unset https_proxy; \ fi # Set up volumes diff --git a/scripts/llama_pro.py b/scripts/llama_pro.py index b086583d..447890f4 100644 --- a/scripts/llama_pro.py +++ b/scripts/llama_pro.py @@ -24,7 +24,7 @@ import fire import torch from safetensors.torch import save_file from tqdm import tqdm -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel from transformers.modeling_utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, @@ -35,7 +35,7 @@ from transformers.modeling_utils import ( if TYPE_CHECKING: - from transformers import PretrainedConfig, PreTrainedModel + from transformers import PretrainedConfig def change_name(name: str, old_index: int, new_index: int) -> str: @@ -61,17 +61,18 @@ def block_expansion( 
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) tokenizer.save_pretrained(output_dir) - config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) # load the original one + config = AutoConfig.from_pretrained(model_name_or_path) # load the original one if save_safetensors: setattr(config, "tie_word_embeddings", False) # safetensors does not allow shared weights - model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_name_or_path, config=config, torch_dtype="auto", trust_remote_code=True, low_cpu_mem_usage=True, ) + assert isinstance(model, PreTrainedModel) # type hint state_dict = model.state_dict() if num_layers % num_expand != 0: @@ -85,7 +86,7 @@ def block_expansion( if f".{i:d}." in key: output_state_dict[change_name(key, i, layer_cnt)] = value - print(f"Add layer {layer_cnt} copied from layer {i}") + print(f"Add layer {layer_cnt} copied from layer {i}.") layer_cnt += 1 if (i + 1) % split == 0: for key, value in state_dict.items(): @@ -95,7 +96,7 @@ def block_expansion( else: output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value) - print(f"Add layer {layer_cnt} expanded from layer {i}") + print(f"Add layer {layer_cnt} expanded from layer {i}.") layer_cnt += 1 for key, value in state_dict.items(): @@ -112,12 +113,13 @@ def block_expansion( torch.save(shard, os.path.join(output_dir, shard_file)) if index is None: - print(f"Model weights saved in {os.path.join(output_dir, weights_name)}") + print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.") else: index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: json.dump(index, f, indent=2, sort_keys=True) - print(f"Model weights saved in {output_dir}") + + print(f"Model weights saved in {output_dir}.") print("- Fine-tune this model with:") print(f"model_name_or_path: {output_dir}") diff --git a/scripts/stat_utils/cal_lr.py b/scripts/stat_utils/cal_lr.py index a76d5827..21206a28 100644 --- a/scripts/stat_utils/cal_lr.py +++ b/scripts/stat_utils/cal_lr.py @@ -41,7 +41,7 @@ def calculate_lr( dataset: str = "alpaca_en_demo", dataset_dir: str = "data", template: str = "default", - cutoff_len: int = 1024, # i.e. maximum input length during training + cutoff_len: int = 2048, # i.e. 
maximum input length during training is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate, packing: bool = False, ): @@ -59,6 +59,7 @@ def calculate_lr( template=template, cutoff_len=cutoff_len, packing=packing, + preprocessing_num_workers=16, output_dir="dummy_dir", overwrite_cache=True, do_train=True, @@ -79,7 +80,7 @@ def calculate_lr( dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) valid_tokens, total_tokens = 0, 0 - for batch in tqdm(dataloader): + for batch in tqdm(dataloader, desc="Collecting valid tokens"): valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item() total_tokens += torch.numel(batch["labels"]) diff --git a/scripts/stat_utils/cal_ppl.py b/scripts/stat_utils/cal_ppl.py index 03b25d9b..32d50e64 100644 --- a/scripts/stat_utils/cal_ppl.py +++ b/scripts/stat_utils/cal_ppl.py @@ -63,7 +63,7 @@ def calculate_ppl( dataset: str = "alpaca_en_demo", dataset_dir: str = "data", template: str = "default", - cutoff_len: int = 1024, + cutoff_len: int = 2048, max_samples: Optional[int] = None, train_on_prompt: bool = False, ): @@ -82,6 +82,7 @@ def calculate_ppl( cutoff_len=cutoff_len, max_samples=max_samples, train_on_prompt=train_on_prompt, + preprocessing_num_workers=16, output_dir="dummy_dir", overwrite_cache=True, do_train=True, @@ -111,7 +112,7 @@ def calculate_ppl( perplexities = [] batch: Dict[str, "torch.Tensor"] with torch.no_grad(): - for batch in tqdm(dataloader): + for batch in tqdm(dataloader, desc="Computing perplexities"): batch = batch.to(model.device) outputs = model(**batch) shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :] diff --git a/scripts/stat_utils/length_cdf.py b/scripts/stat_utils/length_cdf.py index 4b2b5349..5cf25347 100644 --- a/scripts/stat_utils/length_cdf.py +++ b/scripts/stat_utils/length_cdf.py @@ -42,6 +42,7 @@ def length_cdf( dataset_dir=dataset_dir, template=template, cutoff_len=1_000_000, + preprocessing_num_workers=16, output_dir="dummy_dir", overwrite_cache=True, do_train=True, @@ -52,7 +53,7 @@ def length_cdf( trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"] total_num = len(trainset) length_dict = defaultdict(int) - for sample in tqdm(trainset["input_ids"]): + for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"): length_dict[len(sample) // interval * interval] += 1 length_tuples = list(length_dict.items()) diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py index 063f457c..c9a8cfb6 100644 --- a/scripts/vllm_infer.py +++ b/scripts/vllm_infer.py @@ -64,6 +64,7 @@ def vllm_infer( template=template, cutoff_len=cutoff_len, max_samples=max_samples, + preprocessing_num_workers=16, vllm_config=vllm_config, temperature=temperature, top_p=top_p, diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 9f015b38..c5a10ec9 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple from .processors.feedback import preprocess_feedback_dataset from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example -from .processors.pretrain import preprocess_pretrain_dataset +from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example from .processors.supervised import ( preprocess_packed_supervised_dataset, preprocess_supervised_dataset, @@ -47,7 
+47,7 @@ def get_preprocess_and_print_func( tokenizer=tokenizer, data_args=data_args, ) - print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) + print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer) elif stage == "sft" and not do_generate: if data_args.packing: if data_args.neat_packing: # hack datasets to have int32 attention mask diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py index 6d6b98d6..2cee40e5 100644 --- a/src/llamafactory/data/processors/pretrain.py +++ b/src/llamafactory/data/processors/pretrain.py @@ -52,3 +52,8 @@ def preprocess_pretrain_dataset( result["input_ids"][i][0] = tokenizer.bos_token_id return result + + +def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py index bc5ad34c..e21ebd42 100644 --- a/src/llamafactory/data/processors/unsupervised.py +++ b/src/llamafactory/data/processors/unsupervised.py @@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset( def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False))) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 82b5bc36..52e18dd0 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -618,6 +618,28 @@ register_model_group( ) +register_model_group( + models={ + "GPT-2-Small": { + DownloadSource.DEFAULT: "openai-community/gpt2", + DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2", + }, + "GPT-2-Medium": { + DownloadSource.DEFAULT: "openai-community/gpt2-medium", + DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2-medium", + }, + "GPT-2-Large": { + DownloadSource.DEFAULT: "openai-community/gpt2-large", + DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2-large", + }, + "GPT-2-XL": { + DownloadSource.DEFAULT: "openai-community/gpt2-xl", + DownloadSource.MODELSCOPE: "goodbai95/GPT2-xl", + }, + }, +) + + register_model_group( models={ "Granite-3.0-1B-A400M-Base": { diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 45998262..28ec25eb 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -111,40 +111,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): inputs: Dict[str, Union["torch.Tensor", Any]], prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None, + **gen_kwargs, ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]: r""" Removes the prompt part in the generated tokens. Subclass and override to inject custom behavior. """ - labels = inputs["labels"] if "labels" in inputs else None - if self.args.predict_with_generate: - assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor." 
- labels = labels.detach().clone() if labels is not None else None # backup labels - prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) - if prompt_len > label_len: - inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) - if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility) - inputs["labels"] = inputs["labels"][:, :prompt_len] + if self.args.predict_with_generate: # do not pass labels to model when generate + labels = inputs.pop("labels", None) + else: + labels = inputs.get("labels") - loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated) - model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + loss, generated_tokens, _ = super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs ) if generated_tokens is not None and self.args.predict_with_generate: - generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id + generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id generated_tokens = generated_tokens.contiguous() return loss, generated_tokens, labels - def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor": - r""" - Pads the tensor to the same length as the target tensor. - """ - assert self.processing_class.pad_token_id is not None, "Pad token is required." - padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor) - padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding - return padded_tensor.contiguous() # in contiguous memory - def save_predictions( self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True ) -> None: diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 5f4a09cc..1ccfa9ef 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -117,8 +117,6 @@ def run_sft( # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) - if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled - metrics.pop("eval_loss", None) trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) @@ -126,8 +124,6 @@ def run_sft( if training_args.do_predict: logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.") predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs) - if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled - predict_results.metrics.pop("predict_loss", None) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
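Usage sketch for the new HTTP_PROXY build argument introduced in the CUDA, NPU, and ROCm Dockerfiles above (an illustration, not part of the patch): the proxy address and image tag below are placeholders, and the command is assumed to run from the repository root.

# Route pip traffic through a proxy while building the CUDA image (placeholder proxy address and tag).
docker build -f docker/docker-cuda/Dockerfile \
    --build-arg HTTP_PROXY=http://proxy.example.com:8080 \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    -t llamafactory:cuda .

The same --build-arg applies to docker/docker-npu/Dockerfile and docker/docker-rocm/Dockerfile; when HTTP_PROXY is left empty, the conditional branches added above fall back to plain pip install.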