support report custom args

This commit is contained in:
hiyouga
2024-12-19 14:57:09 +00:00
parent 84cd1188ac
commit 5111cac6f8
20 changed files with 164 additions and 124 deletions

View File

@@ -171,7 +171,10 @@ class HuggingfaceEngine(BaseEngine):
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)
gen_kwargs[key] = value.to(dtype=model.dtype, device=model.device)
if torch.is_floating_point(value):
value = value.to(model.dtype)
gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length