Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-04 20:52:59 +08:00
tiny fix
Former-commit-id: ace3f85a7273fbbc531adfe6ad73bf76a5fff52d
This commit is contained in:
parent 2bc685feee
commit d04585df59
@@ -47,10 +47,11 @@
 | [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
 | [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
-| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
+| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
 | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
 | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
 | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
+| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.5B | Wqkv | - |
 
 > [!NOTE]
 > **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
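Note: the default module in the table's third column is what `--lora_target` falls back to, while `--lora_target all` targets every available module. As a purely hypothetical illustration (not LLaMA-Factory's actual resolver), "all the available modules" could be enumerated from a loaded model like this:

```python
import torch.nn as nn

def find_all_linear_names(model: nn.Module) -> list:
    """Hypothetical helper: collect the unique leaf names of every nn.Linear
    module, e.g. "q_proj", "W_pack", "c_attn", "query_key_value"."""
    names = {name.split(".")[-1] for name, module in model.named_modules()
             if isinstance(module, nn.Linear)}
    return sorted(names)
```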
@@ -157,7 +158,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 ```
 
-We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
+We **strongly recommend** using the all-in-one Web UI for newcomers since it can also generate training scripts automatically, even without a GPU environment.
 
 > [!WARNING]
 > Currently the web UI only supports training on **a single GPU**.
@@ -457,6 +458,7 @@ Please follow the model licenses to use the corresponding model weights:
 - [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
 - [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
 - [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
+- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx)
 
 ## Citation
 
@@ -47,10 +47,11 @@
 | [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
 | [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
-| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
+| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
 | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
 | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
 | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
+| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.5B | Wqkv | - |
 
 > [!NOTE]
 > **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
@@ -157,7 +158,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 ```
 
-我们极力推荐新手使用浏览器一体化界面,因为它还可以**自动**生成运行所需的命令行脚本。
+我们**极力推荐**新手使用浏览器一体化界面,因为它还可以不依赖 GPU 环境自动生成在 GPU 上运行的命令行脚本。
 
 > [!WARNING]
 > 目前网页 UI 仅支持**单卡训练**。
@@ -456,6 +457,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 - [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
 - [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
 - [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
+- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx)
 
 ## 引用
 
@@ -13,14 +13,14 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase
 )
-from transformers.utils import check_min_version, is_torch_npu_available
+from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
 try:
-    from transformers.deepspeed import is_deepspeed_zero3_enabled
-except ImportError:
     from transformers.integrations import is_deepspeed_zero3_enabled
+except ImportError:
+    from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters
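Note: the reordered try/except now prefers the newer `transformers.integrations` location of `is_deepspeed_zero3_enabled` and only falls back to the deprecated `transformers.deepspeed` module on older releases. The same pattern in isolation, as a sketch assuming at least one of the two import paths is available in the installed transformers version:

```python
# Version-tolerant import: recent transformers moved is_deepspeed_zero3_enabled
# to transformers.integrations; older releases still ship transformers.deepspeed.
try:
    from transformers.integrations import is_deepspeed_zero3_enabled
except ImportError:
    from transformers.deepspeed import is_deepspeed_zero3_enabled

print(is_deepspeed_zero3_enabled())  # False unless a DeepSpeed ZeRO-3 config is active
```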
@@ -85,7 +85,7 @@ def load_model_and_tokenizer(
     config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
 
     # Fix config (for Qwen)
-    if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
+    if hasattr(config, "fp16") and hasattr(config, "bf16"):
         setattr(config, "fp16", model_args.compute_dtype == torch.float16)
         setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
 
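Note: dropping the `is_trainable` guard means the Qwen-style `fp16`/`bf16` config flags are now kept in sync with `model_args.compute_dtype` for inference as well as training. A standalone sketch of that alignment, using `types.SimpleNamespace` as a hypothetical stand-in for the real config object and `sync_qwen_dtype_flags` as an illustrative helper name:

```python
import types
import torch

def sync_qwen_dtype_flags(config, compute_dtype: torch.dtype) -> None:
    # Keep config.fp16/config.bf16 consistent with the dtype actually used
    # for computation, regardless of train vs. inference mode.
    if hasattr(config, "fp16") and hasattr(config, "bf16"):
        setattr(config, "fp16", compute_dtype == torch.float16)
        setattr(config, "bf16", compute_dtype == torch.bfloat16)

cfg = types.SimpleNamespace(fp16=False, bf16=False)
sync_qwen_dtype_flags(cfg, torch.bfloat16)
print(cfg.fp16, cfg.bf16)  # False True
```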
@@ -215,11 +215,7 @@
     # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        if is_torch_npu_available():
-            infer_dtype = torch.float16
-        else:
-            infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
-        model = model.to(infer_dtype) if model_args.quantization_bit is None else model
+        model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
 
     trainable_params, all_param = count_parameters(model)
     logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
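Note: with the per-loader dtype detection removed, the inference path simply freezes the parameters and casts the model to `model_args.compute_dtype`, which is now filled in centrally by the argument parser (see the hunks below). A minimal sketch of the resulting behavior, with `prepare_for_inference`, `compute_dtype`, and `quantization_bit` as hypothetical stand-ins for the loader logic and the corresponding `ModelArguments` fields:

```python
import torch
import torch.nn as nn

def prepare_for_inference(model: nn.Module, compute_dtype: torch.dtype, quantization_bit=None) -> nn.Module:
    model.requires_grad_(False)  # freeze all parameters, as in the loader above
    # Quantized weights keep their storage dtype; otherwise cast to the chosen compute dtype.
    return model.to(compute_dtype) if quantization_bit is None else model

model = prepare_for_inference(nn.Linear(4, 4), torch.float16)
print(next(model.parameters()).dtype)  # torch.float16
```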
@@ -8,6 +8,14 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint
 
+try:
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    is_bf16_available = is_torch_bf16_gpu_available()
+    is_npu_available = is_torch_npu_available()
+except ImportError:
+    is_bf16_available = torch.cuda.is_bf16_supported()
+    is_npu_available = False
+
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import (
     ModelArguments,
@@ -197,7 +205,7 @@ def get_train_args(
 
     # postprocess model_args
     if training_args.bf16:
-        if not torch.cuda.is_bf16_supported():
+        if not is_bf16_available:
             raise ValueError("Current device does not support bf16 training.")
         model_args.compute_dtype = torch.bfloat16
     elif training_args.fp16:
@@ -243,4 +251,12 @@ def get_infer_args(
     if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
+    # auto-detect cuda capability
+    if is_npu_available:
+        model_args.compute_dtype = torch.float16
+    elif is_bf16_available:
+        model_args.compute_dtype = torch.bfloat16
+    else:
+        model_args.compute_dtype = torch.float16
+
     return model_args, data_args, finetuning_args, generating_args
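Note: taken together, the parser changes centralize dtype handling: `get_train_args` validates bf16 training against the probed `is_bf16_available` flag, and `get_infer_args` now auto-detects `compute_dtype` (NPU devices and non-bf16 GPUs get float16, bf16-capable GPUs get bfloat16), which the loader above then uses when casting the model. A self-contained sketch of that selection order, with `detect_compute_dtype` as a hypothetical helper name:

```python
import torch

def detect_compute_dtype(is_npu_available: bool, is_bf16_available: bool) -> torch.dtype:
    # Mirrors the auto-detection added to get_infer_args: NPUs stay in float16,
    # bf16-capable CUDA devices use bfloat16, everything else falls back to float16.
    if is_npu_available:
        return torch.float16
    if is_bf16_available:
        return torch.bfloat16
    return torch.float16

print(detect_compute_dtype(False, torch.cuda.is_available() and torch.cuda.is_bf16_supported()))
```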