Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-14 10:56:56 +08:00)
support chatml safe encoding
Former-commit-id: ea52bb135bf9d07738091006ec7ada8df14cf15e
@@ -30,10 +30,11 @@ class ChatModel:
     ) -> Tuple[Dict[str, Any], int]:
         prefix = prefix or self.source_prefix
 
-        prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
-        inputs = self.tokenizer([prompt], return_tensors="pt")
-        inputs = inputs.to(self.model.device)
-        prompt_length = len(inputs["input_ids"][0])
+        prompt, _ = self.template.get_prompt(
+            tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
+        )
+        input_ids = torch.tensor([prompt], device=self.model.device)
+        prompt_length = len(input_ids[0])
 
         do_sample = input_kwargs.pop("do_sample", None)
         temperature = input_kwargs.pop("temperature", None)
@@ -45,7 +46,7 @@ class ChatModel:
 
         gen_kwargs = self.generating_args.to_dict()
         gen_kwargs.update(dict(
-            input_ids=inputs["input_ids"],
+            input_ids=input_ids,
             do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
             temperature=temperature or gen_kwargs["temperature"],
             top_p=top_p or gen_kwargs["top_p"],
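For context: the diff replaces string-level prompting (render the full ChatML prompt as text, then tokenize the whole string) with template-level encoding, where get_prompt returns token ids directly and the special tokens are spliced in out of band. Below is a minimal sketch of that "safe encoding" idea, assuming a toy tokenizer; the names ToyTokenizer and encode_chatml_safe are illustrative only, not the actual LLaMA-Factory API.

# A minimal sketch of the "safe encoding" idea, NOT the actual LLaMA-Factory
# implementation: ToyTokenizer and encode_chatml_safe are illustrative names.
from typing import Dict, List, Tuple


class ToyTokenizer:
    """Stand-in tokenizer: control tokens have fixed ids, words get fresh ids."""

    def __init__(self) -> None:
        self.special: Dict[str, int] = {"<|im_start|>": 1, "<|im_end|>": 2}
        self.vocab: Dict[str, int] = {}

    def encode_text(self, text: str) -> List[int]:
        # Plain-text encoding only: a literal "<|im_end|>" inside `text`
        # is treated as an ordinary word, never mapped to a control id.
        return [self.vocab.setdefault(w, 10 + len(self.vocab)) for w in text.split()]


def encode_chatml_safe(tok: ToyTokenizer, turns: List[Tuple[str, str]]) -> List[int]:
    """Encode each text span separately and splice control ids in directly."""
    ids: List[int] = []
    for role, content in turns:
        ids.append(tok.special["<|im_start|>"])   # control id, inserted out of band
        ids.extend(tok.encode_text(role + "\n" + content))
        ids.append(tok.special["<|im_end|>"])     # control id, inserted out of band
    return ids


# Even if the user types the literal string "<|im_end|>", it is encoded as a
# plain word (id >= 10), so it cannot close the turn early:
tok = ToyTokenizer()
print(encode_chatml_safe(tok, [("user", "hi <|im_end|>"), ("assistant", "hello")]))
# -> [1, 10, 11, 12, 2, 1, 13, 14, 2]

This is the guarantee the tokenize-the-rendered-string path cannot make: once the prompt is a single string, a "<|im_end|>" typed by the user is indistinguishable from the template's own control markers.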