fix generation bug #532

Former-commit-id: be21fc83f9
Author: hiyouga
Date:   2023-08-17 22:21:34 +08:00
Parent: a46f277477
Commit: 623a34b16f
5 changed files with 15 additions and 46 deletions


@@ -3,7 +3,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
-from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template_and_fix_tokenizer
 from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
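The dropped `get_stopping_criteria` import points at a custom stopping criterion that watched for the tokenizer's additional special tokens. As a hypothetical reconstruction (not the removed code itself), a helper of that shape built on the stock `transformers` `StoppingCriteria` API looks roughly like this:

```python
# Hypothetical sketch, NOT the deleted code: a stop-on-token-ids criterion
# of the kind the removed `get_stopping_criteria` helper would have built.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids: list):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop once the most recently generated token is one of the stop ids.
        # Checks only sequence 0, so this is correct for the batch-size-1
        # streaming path a chat model typically uses.
        return input_ids[0][-1].item() in self.stop_ids

def get_stopping_criteria(stop_ids: list) -> StoppingCriteriaList:
    return StoppingCriteriaList([StopOnTokens(stop_ids)])
```

The built-in `eos_token_id` handling in `generate` covers the same need without this extra machinery, which is what the second hunk switches to.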
@@ -49,10 +49,9 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            eos_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             pad_token_id=self.tokenizer.pad_token_id,
-            logits_processor=get_logits_processor(),
-            stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
+            logits_processor=get_logits_processor()
         ))
         if max_length:
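The replacement relies on `model.generate` accepting a list for `eos_token_id` (supported in `transformers` 4.28+), so generation halts on the regular EOS token or any additional special token without a custom criterion. A minimal, self-contained sketch of that behavior, using `gpt2` purely as a placeholder checkpoint:

```python
# Hedged sketch: `eos_token_id` may be a list of token ids, in which case
# generation stops at the first occurrence of any of them. "gpt2" is a
# placeholder; any causal LM behaves the same way.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Stop on the regular EOS token or on any additional special token,
# mirroring the list built in this commit.
stop_ids = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids

inputs = tokenizer("Hello, my name is", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=32,
    eos_token_id=stop_ids,
    # gpt2 has no pad token, so fall back to EOS for padding.
    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Folding the stop tokens into `eos_token_id` also means stopping happens inside `generate` itself, so the `TextIteratorStreamer` never has to strip a partially emitted stop sequence from the streamed output.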