mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
add logits processor
This commit is contained in:
@@ -46,7 +46,8 @@ from .other import (
|
||||
)
|
||||
|
||||
check_min_version("4.29.1")
|
||||
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
|
||||
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
|
||||
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
|
||||
|
||||
@@ -84,8 +85,7 @@ def init_adapter(
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
||||
if len(model_args.checkpoint_dir) > 1:
|
||||
logger.warning("Only LoRA tuning accepts multiple checkpoints.")
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
@@ -154,8 +154,7 @@ def load_pretrained(
|
||||
config_kwargs = {}
|
||||
if model_args.quantization_bit is not None:
|
||||
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
|
||||
|
||||
require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
|
||||
from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
|
||||
cuda = get_cuda_lib_handle()
|
||||
cc = get_compute_capability(cuda)
|
||||
@@ -179,7 +178,6 @@ def load_pretrained(
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() # cast all params to float16 for inference
|
||||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
|
||||
Reference in New Issue
Block a user