mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 04:32:50 +08:00
add bf16 lora option
Former-commit-id: b6ec112bebcb379caa32617a135df4d5d3cf865b
This commit is contained in:
parent a9fe47a848
commit e199967391
@@ -55,9 +55,13 @@ class LoraArguments:
                     Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
                     Others choices: the same as LLaMA."}
     )
+    lora_bf16_mode: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
+    )
     create_new_adapter: Optional[bool] = field(
         default=False,
-        metadata={"help": "Whether to create a new adapter with randomly initialized weight or not."}
+        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
     )
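For context, a minimal, self-contained sketch of how an Optional[bool] dataclass field such as lora_bf16_mode is exposed as a command-line flag through transformers.HfArgumentParser, which is how these finetuning dataclasses are typically parsed; the standalone file layout and the __main__ usage below are illustrative assumptions, not the repository's actual module structure.

# Sketch only: shows an Optional[bool] field parsed into a CLI flag via
# HfArgumentParser. Field names mirror the diff; the layout is assumed.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class LoraArguments:
    lora_bf16_mode: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
    )
    create_new_adapter: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
    )


if __name__ == "__main__":
    # e.g. `python lora_args_sketch.py --lora_bf16_mode True` (hypothetical file name)
    (lora_args,) = HfArgumentParser(LoraArguments).parse_args_into_dataclasses()
    print(lora_args.lora_bf16_mode)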
@@ -125,7 +125,7 @@ def init_adapter(
         model = get_peft_model(model, lora_config)
 
     for param in filter(lambda p: p.requires_grad, model.parameters()):
-        param.data = param.data.to(torch.float32)
+        param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
 
     if model_args.adapter_name_or_path is not None:
         logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
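The functional change is in init_adapter: trainable LoRA parameters, previously always upcast to float32, are kept in bfloat16 when lora_bf16_mode is enabled, which roughly halves the memory held by the adapter weights at the cost of reduced numerical precision. Below is a minimal sketch of that cast as a standalone helper; the helper name cast_trainable_params and the tiny demo model are illustrative assumptions, and only the logic of the changed line is taken from the diff.

# Sketch of the dtype cast applied to trainable LoRA parameters.
# cast_trainable_params is an illustrative helper, not a LLaMA-Factory API.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM


def cast_trainable_params(model: torch.nn.Module, lora_bf16_mode: bool) -> None:
    # After get_peft_model() only the adapter weights have requires_grad=True,
    # so the frozen base-model weights are left untouched.
    target_dtype = torch.bfloat16 if lora_bf16_mode else torch.float32
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(target_dtype)


if __name__ == "__main__":
    base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # tiny model, for illustration only
    model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))
    cast_trainable_params(model, lora_bf16_mode=True)
    print({p.dtype for p in model.parameters() if p.requires_grad})  # expected: {torch.bfloat16}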