Former-commit-id: ace3f85a7273fbbc531adfe6ad73bf76a5fff52d
This commit is contained in:
hiyouga 2023-09-21 15:25:29 +08:00
parent 2bc685feee
commit d04585df59
4 changed files with 30 additions and 14 deletions

View File

@@ -47,10 +47,11 @@
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
-| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
+| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
+| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.5B | Wqkv | - |
> [!NOTE]
> **Default module** is used for the `--lora_target` argument; you can use `--lora_target all` to specify all the available modules.
@@ -157,7 +158,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
-We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
+We **strongly recommend** using the all-in-one Web UI for newcomers since it can also generate training scripts automatically, even without a GPU environment.
> [!WARNING]
> Currently the web UI only supports training on **a single GPU**.
@@ -457,6 +458,7 @@ Please follow the model licenses to use the corresponding model weights:
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
+- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx)
## Citation

View File

@@ -47,10 +47,11 @@
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
-| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
+| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
+| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.5B | Wqkv | - |
> [!NOTE]
> The **default module** should be used as the default value for the `--lora_target` argument; you can use `--lora_target all` to specify all the available modules.
@@ -157,7 +158,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
-We strongly recommend that newcomers use the all-in-one web UI, since it can also generate the required command-line scripts **automatically**.
+We **strongly recommend** that newcomers use the all-in-one web UI, since it can automatically generate the command-line scripts for GPU training even without a GPU environment.
> [!WARNING]
> Currently the web UI only supports **single-GPU training**.
@@ -456,6 +457,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
+- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx)
## Citation

View File

@@ -13,14 +13,14 @@ from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase
)
-from transformers.utils import check_min_version, is_torch_npu_available
+from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
try:
-    from transformers.deepspeed import is_deepspeed_zero3_enabled
-except ImportError:
    from transformers.integrations import is_deepspeed_zero3_enabled
+except ImportError:
+    from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters
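The import swap above prefers the newer `transformers.integrations` location of `is_deepspeed_zero3_enabled` and only falls back to the legacy `transformers.deepspeed` module on older releases, where the new path does not exist yet. A minimal, self-contained sketch of the same compatibility pattern (the version note in the comment is an assumption, not pinned by this commit):

```python
# Newer transformers expose this helper under transformers.integrations;
# older releases only provide it in the deprecated transformers.deepspeed module.
try:
    from transformers.integrations import is_deepspeed_zero3_enabled
except ImportError:
    from transformers.deepspeed import is_deepspeed_zero3_enabled

if is_deepspeed_zero3_enabled():
    # Under ZeRO stage 3 the model weights are sharded across ranks at load time.
    print("DeepSpeed ZeRO-3 is enabled")
```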
@@ -85,7 +85,7 @@ def load_model_and_tokenizer(
    config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)

    # Fix config (for Qwen)
-    if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
+    if hasattr(config, "fp16") and hasattr(config, "bf16"):
        setattr(config, "fp16", model_args.compute_dtype == torch.float16)
        setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
@@ -215,11 +215,7 @@ def load_model_and_tokenizer(
    # Prepare model for inference
    if not is_trainable:
        model.requires_grad_(False) # fix all model params
-        if is_torch_npu_available():
-            infer_dtype = torch.float16
-        else:
-            infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
-        model = model.to(infer_dtype) if model_args.quantization_bit is None else model
+        model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model

    trainable_params, all_param = count_parameters(model)
    logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
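The inference path above no longer probes NPU or CUDA capability inside the loader; it freezes all parameters and casts the model to the user-requested `model_args.compute_dtype`, leaving quantized models untouched because their dtype is fixed by the quantization backend. A self-contained sketch of that preparation step (function and argument names are illustrative, not the repo's API):

```python
from typing import Optional

import torch


def prepare_for_inference(model, compute_dtype: torch.dtype, quantization_bit: Optional[int]):
    model.requires_grad_(False)  # freeze every parameter; no gradients are needed at inference
    if quantization_bit is None:
        # Non-quantized weights are cast to the requested compute dtype
        # (e.g. torch.bfloat16); quantized weights keep their backend dtype.
        model = model.to(compute_dtype)
    return model
```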

View File

@@ -8,6 +8,14 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.utils.versions import require_version
from transformers.trainer_utils import get_last_checkpoint
+try:
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    is_bf16_available = is_torch_bf16_gpu_available()
+    is_npu_available = is_torch_npu_available()
+except ImportError:
+    is_bf16_available = torch.cuda.is_bf16_supported()
+    is_npu_available = False

from llmtuner.extras.logging import get_logger
from llmtuner.hparams import (
    ModelArguments,
@@ -197,7 +205,7 @@ def get_train_args(
    # postprocess model_args
    if training_args.bf16:
-        if not torch.cuda.is_bf16_supported():
+        if not is_bf16_available:
            raise ValueError("Current device does not support bf16 training.")
        model_args.compute_dtype = torch.bfloat16
    elif training_args.fp16:
@@ -243,4 +251,12 @@ def get_infer_args(
    if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
        raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")

+    # auto-detect cuda capability
+    if is_npu_available:
+        model_args.compute_dtype = torch.float16
+    elif is_bf16_available:
+        model_args.compute_dtype = torch.bfloat16
+    else:
+        model_args.compute_dtype = torch.float16

    return model_args, data_args, finetuning_args, generating_args
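Taken together, the two parser hunks probe the environment once at import time and then pick an inference dtype from those flags: float16 on NPUs, bfloat16 on GPUs that support it, and float16 otherwise. A compact sketch of that selection logic, reusing the availability probes from the diff; the wrapper function name is illustrative:

```python
import torch

# Newer transformers expose both probes; very old versions may lack them,
# in which case we fall back to the torch-level bf16 check.
try:
    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
    is_bf16_available = is_torch_bf16_gpu_available()
    is_npu_available = is_torch_npu_available()
except ImportError:
    is_bf16_available = torch.cuda.is_bf16_supported()
    is_npu_available = False


def auto_compute_dtype() -> torch.dtype:
    # NPUs currently favor float16; prefer bfloat16 only where the GPU supports it.
    if is_npu_available:
        return torch.float16
    if is_bf16_available:
        return torch.bfloat16
    return torch.float16
```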