Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 12:20:37 +08:00)
@@ -80,18 +80,17 @@ def load_reference_model(
     is_trainable: bool = False,
     add_valuehead: bool = False,
 ) -> Union["PreTrainedModel", "LoraModel"]:
+    current_device = get_current_device()
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
-            model_path, torch_dtype=torch.float16, device_map=get_current_device()
+            model_path, torch_dtype=torch.float16, device_map=current_device
         )
         if not is_trainable:
             model.v_head = model.v_head.to(torch.float16)

         return model

-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, torch_dtype=torch.float16, device_map=get_current_device()
-    )
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
     if use_lora or use_pissa:
         model = PeftModel.from_pretrained(
             model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
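This hunk caches get_current_device() in a local current_device and collapses the multi-line AutoModelForCausalLM.from_pretrained(...) call onto one line; the loading behaviour is unchanged. For readers who want to try the loading pattern in isolation, here is a minimal, self-contained sketch. It is not part of the diff: the checkpoint name is a placeholder and a plain torch device expression stands in for the project's get_current_device() helper.

import torch
from transformers import AutoModelForCausalLM

# Stand-in for LLaMA-Factory's get_current_device(); for illustration only.
current_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Same call shape as the refactored line above: half precision, pinned to one device.
model = AutoModelForCausalLM.from_pretrained(
    "sshleifer/tiny-gpt2",  # placeholder checkpoint, swap in a real model path
    torch_dtype=torch.float16,
    device_map=current_device,
)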
@@ -110,7 +109,7 @@ def load_train_dataset(**kwargs) -> "Dataset":
     return dataset_module["train_dataset"]


-def patch_valuehead_model():
+def patch_valuehead_model() -> None:
     def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
         state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
         self.v_head.load_state_dict(state_dict, strict=False)
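The post_init hook shown above keeps only the "v_head.*" entries of the incoming state dict and strips the 7-character "v_head." prefix before loading them into the value head. Below is a small self-contained sketch of that comprehension; the key names and tensor shapes are illustrative (they follow trl's value-head layout) rather than taken from this commit.

import torch

# Illustrative checkpoint fragment: only the "v_head.*" keys should survive.
state_dict = {
    "v_head.summary.weight": torch.zeros(1, 4),
    "v_head.summary.bias": torch.zeros(1),
    "lm_head.weight": torch.zeros(8, 4),  # unrelated key, dropped by the filter
}

# Same comprehension as in post_init: filter on the prefix, then drop the first
# 7 characters ("v_head.") so the keys match the value-head submodule.
v_head_state = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
print(sorted(v_head_state))  # ['summary.bias', 'summary.weight']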