mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
support QDoRA
Former-commit-id: 19ef4826490b79e0c2aee20ad67430aa0e4724a7
This commit is contained in:
parent 096c31bfb6
commit bbf272f96e
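QDoRA here means running DoRA (weight-decomposed low-rank adaptation) on top of a quantized base model. As a minimal sketch of the combination this commit unblocks, assuming the standard peft + bitsandbytes APIs (the model name, rank, and target modules below are illustrative, not taken from this repo):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 4-bit base model: the "Q" in QDoRA.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative choice of base model
    quantization_config=bnb_config,
)

# DoRA decomposes each adapted weight into magnitude and direction;
# peft>=0.9.1.dev0 is what first allowed use_dora=True on quantized layers.
peft_config = LoraConfig(
    r=8,                                  # illustrative rank
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # illustrative targets
    use_dora=True,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()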
@@ -8,6 +8,7 @@ import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_torch_bf16_gpu_available
+from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies
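The newly imported require_version is transformers' runtime version gate. A standalone sketch of its behavior (the requirement and hint strings are the ones used later in this diff):

from transformers.utils.versions import require_version

# No-op when the installed peft already satisfies the requirement;
# otherwise raises ImportError with the hint appended to the message.
require_version("peft>=0.9.1.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")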
@@ -129,7 +130,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     if finetuning_args.use_dora:
         if model_args.quantization_bit is not None:
-            raise ValueError("DoRA does not support quantization.")
+            require_version("peft>=0.9.1.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
 
         if model_args.use_unsloth:
             raise ValueError("Unsloth does not support DoRA.")
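The net effect: use_dora combined with quantization_bit used to be rejected outright, and is now accepted as long as peft is recent enough. A hypothetical standalone rendition of the post-commit check (the function wrapper is for illustration; the names mirror the diff):

def check_dora_args(finetuning_args, model_args) -> None:
    # Hypothetical extraction of the validation logic above.
    if finetuning_args.use_dora:
        if model_args.quantization_bit is not None:
            # QDoRA path: no longer a hard error, just a version gate.
            require_version(
                "peft>=0.9.1.dev0",
                "To fix: pip install git+https://github.com/huggingface/peft.git",
            )
        if model_args.use_unsloth:
            # Unsloth's patched modules still lack DoRA support.
            raise ValueError("Unsloth does not support DoRA.")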
@@ -167,6 +168,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
         logger.warning("We recommend enable mixed precision training.")
 
+    if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
+        logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
+
     if (not training_args.do_train) and model_args.quantization_bit is not None:
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
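The added GaLore check warns only during training and only when pure_bf16 is off, presumably because mixed precision keeps fp32 master weights and optimizer state alongside GaLore's projection buffers (a plausible reading of the warning, not stated in the commit). A tiny sketch of the predicate with hypothetical flag values:

# Hypothetical mirror of the new warning's condition.
def galore_memory_warning(do_train: bool, use_galore: bool, pure_bf16: bool) -> bool:
    return do_train and use_galore and not pure_bf16

assert galore_memory_warning(True, True, False)       # AMP training: warns
assert not galore_memory_warning(True, True, True)    # pure_bf16: silent
assert not galore_memory_warning(False, True, False)  # eval only: silent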