mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix #1232
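
Summary (from the diff below): batch_inference previously read next-word logits from logits[:, -1], which only points at the last real token when batches are left-padded. The fix gathers each row's final real token via the attention mask, and evaluate() now sets the tokenizer's padding side explicitly instead of asserting left padding.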
@@ -84,10 +84,12 @@ def batch_inference(
     prefix_char: str
 ) -> List[str]:
     logits = chat_model.model(**batch_input).logits
+    lengths = torch.sum(batch_input["attention_mask"], dim=-1)
+    nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
     probs = torch.nn.functional.softmax(
         torch.stack(
             [
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
+                nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
                 for choice in choices
             ],
             dim=-1
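
For readers outside the repo: the added lines gather the logits of each sequence's final real token, which logits[:, -1] fails to do once batches are right-padded. A minimal standalone sketch of the same indexing, with illustrative shapes and mask values (not from the commit):

import torch

# Toy shapes: batch of 2, window of 5, vocabulary of 10.
batch_size, seq_len, vocab_size = 2, 5, 10
logits = torch.randn(batch_size, seq_len, vocab_size)

# attention_mask is 1 on real tokens and 0 on padding; with right padding,
# sequence i's last real token sits at index lengths[i] - 1.
batch_input = {
    "attention_mask": torch.tensor([
        [1, 1, 1, 0, 0],  # 3 real tokens, last one at index 2
        [1, 1, 1, 1, 1],  # 5 real tokens, last one at index 4
    ])
}
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
assert nextword_logits.shape == (batch_size, vocab_size)

# The removed logits[:, -1] lookup reads position 4 for every row, which is
# a pad position for row 0 here; it is only safe when batches are left-padded.
assert torch.equal(nextword_logits[1], logits[1, -1])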
@@ -120,8 +122,8 @@ def evaluate(
         checkpoint_dir=checkpoint_dir,
         template=template
     ))
+    chat_model.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
     eval_template = eval_templates[lang]
-    assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."
 
     category_corrects: Dict[str, np.ndarray] = {
         subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
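
The evaluate() change is the other half of the fix: instead of asserting that the tokenizer happens to left-pad, the evaluator pins the padding side itself, and the length-based indexing above keeps the last-token lookup correct. A small sketch of how padding_side moves the last real token, using the standard transformers tokenizer API with "gpt2" as a stand-in checkpoint (not the model used by the repo):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
texts = ["A short prompt", "A noticeably longer prompt in the same batch"]

tokenizer.padding_side = "right"
batch = tokenizer(texts, padding=True, return_tensors="pt")
lengths = batch["attention_mask"].sum(dim=-1)
# Right padding: row i ends at index lengths[i] - 1, matching batch_inference above.
last_real = batch["input_ids"][0, lengths[0] - 1]

tokenizer.padding_side = "left"
batch = tokenizer(texts, padding=True, return_tensors="pt")
# Left padding: every row ends at index -1, which is what the old
# logits[:, -1] lookup (and the removed assert) depended on.
assert batch["input_ids"][0, -1] == last_real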