mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 20:52:59 +08:00
fix generation
Former-commit-id: d9e62711a3349d7c6fd3512fb25c709bdfbb311a
This commit is contained in:
parent
edc15c62fa
commit
048f99354f
@ -49,6 +49,8 @@ class ChatModel:
|
|||||||
top_p=top_p or gen_kwargs["top_p"],
|
top_p=top_p or gen_kwargs["top_p"],
|
||||||
top_k=top_k or gen_kwargs["top_k"],
|
top_k=top_k or gen_kwargs["top_k"],
|
||||||
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id,
|
||||||
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
logits_processor=get_logits_processor(),
|
logits_processor=get_logits_processor(),
|
||||||
stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
|
stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
|
||||||
))
|
))
|
||||||
|
@ -74,6 +74,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||||||
|
|
||||||
# Keyword arguments for `model.generate`
|
# Keyword arguments for `model.generate`
|
||||||
gen_kwargs = self.generating_args.to_dict()
|
gen_kwargs = self.generating_args.to_dict()
|
||||||
|
gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
|
||||||
|
gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
|
||||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||||
gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
|
gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
|
||||||
|
|
||||||
|
@ -52,6 +52,8 @@ def run_sft(
|
|||||||
|
|
||||||
# Keyword arguments for `model.generate`
|
# Keyword arguments for `model.generate`
|
||||||
gen_kwargs = generating_args.to_dict()
|
gen_kwargs = generating_args.to_dict()
|
||||||
|
gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
|
||||||
|
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||||
gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)
|
gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user