add bf16 lora option

Former-commit-id: 58e7d7ff0cf9bf30e53b3eb12576f38d31976413
Author: hiyouga
Date:   2024-01-19 16:29:03 +08:00
Parent: 9b390c4bea
Commit: 384f0e7678

2 changed files with 6 additions and 2 deletions

@@ -55,9 +55,13 @@ class LoraArguments:
                   Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
                   Others choices: the same as LLaMA."}
     )
+    lora_bf16_mode: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
+    )
     create_new_adapter: Optional[bool] = field(
         default=False,
-        metadata={"help": "Whether to create a new adapter with randomly initialized weight or not."}
+        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
     )
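
Not part of the commit, but for context: fields declared on LoraArguments in this style are typically surfaced as command-line flags through transformers' HfArgumentParser. A minimal sketch, assuming the parser's standard handling of Optional[bool] fields:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class LoraArguments:
    lora_bf16_mode: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
    )


parser = HfArgumentParser(LoraArguments)
# e.g. passing "--lora_bf16_mode True" on the command line enables the bf16 adapter mode
(lora_args,) = parser.parse_args_into_dataclasses(args=["--lora_bf16_mode", "True"])
print(lora_args.lora_bf16_mode)  # True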

@@ -125,7 +125,7 @@ def init_adapter(
         model = get_peft_model(model, lora_config)
         for param in filter(lambda p: p.requires_grad, model.parameters()):
-            param.data = param.data.to(torch.float32)
+            param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
     if model_args.adapter_name_or_path is not None:
         logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
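
For illustration only (not from the repository): the changed line above means that after get_peft_model wraps the model, every trainable adapter parameter is cast to bfloat16 when lora_bf16_mode is enabled and upcast to float32 otherwise, while frozen base weights keep their original dtype. A standalone sketch of that cast, with a hypothetical helper name:

import torch
import torch.nn as nn


def cast_trainable_params(model: nn.Module, lora_bf16_mode: bool) -> None:
    # Mirror of the logic above: trainable (adapter) params go to bf16 or fp32;
    # frozen parameters are untouched because requires_grad filters them out.
    target_dtype = torch.bfloat16 if lora_bf16_mode else torch.float32
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(target_dtype)


# toy usage: a bf16 layer whose trainable weights are upcast to fp32 (the default mode)
layer = nn.Linear(4, 4).to(torch.bfloat16)
cast_trainable_params(layer, lora_bf16_mode=False)
print(layer.weight.dtype)  # torch.float32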