mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-14 10:56:56 +08:00
fix system prompt
Former-commit-id: 411e775aa939bdd154a3f1e92921ede90d989f18
This commit is contained in:
@@ -16,19 +16,19 @@ class ChatModel:
|
||||
self.model = dispatch_model(self.model)
|
||||
self.model = self.model.eval() # enable evaluation mode
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.system_prompt = data_args.system_prompt
|
||||
|
||||
def process_args(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
prefix = prefix or self.source_prefix
|
||||
system = system or self.system_prompt
|
||||
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
||||
)
|
||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||
prompt_length = len(input_ids[0])
|
||||
@@ -68,10 +68,10 @@ class ChatModel:
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
|
||||
generation_output = self.model.generate(**gen_kwargs)
|
||||
outputs = generation_output.tolist()[0][prompt_length:]
|
||||
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
@@ -83,10 +83,10 @@ class ChatModel:
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
|
||||
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user