update tests

Former-commit-id: 93d3b8f43f
Author: hiyouga
Date: 2024-11-02 12:21:41 +08:00
Parent: 25093c2d82
Commit: 3f7c874594
23 changed files with 53 additions and 62 deletions


@@ -80,18 +80,17 @@ def load_reference_model(
     is_trainable: bool = False,
     add_valuehead: bool = False,
 ) -> Union["PreTrainedModel", "LoraModel"]:
+    current_device = get_current_device()
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
-            model_path, torch_dtype=torch.float16, device_map=get_current_device()
+            model_path, torch_dtype=torch.float16, device_map=current_device
         )
         if not is_trainable:
             model.v_head = model.v_head.to(torch.float16)
 
         return model
 
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, torch_dtype=torch.float16, device_map=get_current_device()
-    )
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
     if use_lora or use_pissa:
         model = PeftModel.from_pretrained(
             model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
@@ -110,7 +109,7 @@ def load_train_dataset(**kwargs) -> "Dataset":
     return dataset_module["train_dataset"]
 
 
-def patch_valuehead_model():
+def patch_valuehead_model() -> None:
     def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
         state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
         self.v_head.load_state_dict(state_dict, strict=False)
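
For context, a minimal, self-contained sketch of how a patch_valuehead_model helper like the one in the second hunk might be applied. The inner post_init mirrors the diff above; the imports and the final class-level attachment are assumptions for illustration, not part of this commit.

# Sketch only: the inner post_init is taken from the hunk above; the imports and
# the class-level attachment at the end are assumptions about how the patch is applied.
from typing import Dict

import torch
from trl import AutoModelForCausalLMWithValueHead


def patch_valuehead_model() -> None:
    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
        # Keep only the value-head weights and strip the "v_head." prefix (7 characters)
        # so they load directly into self.v_head.
        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
        self.v_head.load_state_dict(state_dict, strict=False)

    # Assumption: monkey-patch the method onto the class so later from_pretrained calls use it.
    AutoModelForCausalLMWithValueHead.post_init = post_init

In a test suite, such a patch would typically be invoked once (for example in a fixture or at module import) before any value-head checkpoint is loaded.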