fix lora target

Former-commit-id: d822e41e7ac7e310ee49e347fc45754284ce30b8
Author: hiyouga
Date: 2023-09-09 17:04:45 +08:00
Parent: 7143c551ab
Commit: f91c5f2638
7 changed files with 63 additions and 43 deletions


@@ -1,7 +1,7 @@
 import torch
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
-from transformers import TextIteratorStreamer
+from transformers import GenerationConfig, TextIteratorStreamer
 
 from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template_and_fix_tokenizer
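
The newly imported GenerationConfig lets the decoding settings be packaged into a single validated object instead of being passed to generate() as loose keyword arguments, which is what the second hunk below does. A minimal sketch of the idea (the values are illustrative, not taken from this commit):

    from transformers import GenerationConfig

    generating_args = {"do_sample": True, "temperature": 0.7, "top_p": 0.9, "max_new_tokens": 512}
    # GenerationConfig validates the settings and stores them on one object
    generation_config = GenerationConfig(**generating_args)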
@@ -40,26 +40,30 @@ class ChatModel:
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
 
-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs.update(dict(
-            input_ids=input_ids,
-            do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
-            temperature=temperature or gen_kwargs["temperature"],
-            top_p=top_p or gen_kwargs["top_p"],
-            top_k=top_k or gen_kwargs["top_k"],
-            repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+            temperature=temperature or generating_args["temperature"],
+            top_p=top_p or generating_args["top_p"],
+            top_k=top_k or generating_args["top_k"],
+            repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            pad_token_id=self.tokenizer.pad_token_id,
-            logits_processor=get_logits_processor()
+            pad_token_id=self.tokenizer.pad_token_id
         ))
 
         if max_length:
-            gen_kwargs.pop("max_new_tokens", None)
-            gen_kwargs["max_length"] = max_length
+            generating_args.pop("max_new_tokens", None)
+            generating_args["max_length"] = max_length
 
         if max_new_tokens:
-            gen_kwargs.pop("max_length", None)
-            gen_kwargs["max_new_tokens"] = max_new_tokens
+            generating_args.pop("max_length", None)
+            generating_args["max_new_tokens"] = max_new_tokens
 
+        gen_kwargs = dict(
+            inputs=input_ids,
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor()
+        )
+
         return gen_kwargs, prompt_length
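
The returned gen_kwargs now carries the prompt under inputs and all decoding settings inside a GenerationConfig, so it can be forwarded to model.generate() unchanged. A minimal sketch of a plausible consumer, assuming the streaming pattern suggested by the Thread and TextIteratorStreamer imports in the first hunk (stream_reply is a hypothetical helper, not part of this commit):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_reply(model, tokenizer, gen_kwargs):
        # Decode tokens into text as soon as the model emits them.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
        # gen_kwargs already holds inputs=..., generation_config=... and logits_processor=...
        Thread(target=model.generate, kwargs=gen_kwargs).start()
        yield from streamer

Running generate() on a background thread keeps the caller free to iterate over the streamer, which blocks until each new chunk of text arrives.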