update patcher

Former-commit-id: 3b040e8e0f
This commit is contained in:
hiyouga
2024-06-19 21:27:00 +08:00
parent 80e9f8e000
commit 030b4811c7
3 changed files with 10 additions and 7 deletions

View File

@@ -70,5 +70,5 @@ def test_upcast_lmhead_output():
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
outputs: "torch.Tensor" = model.lm_head(inputs)
outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
assert outputs.dtype == torch.float32