add tool test

Former-commit-id: 83dbfce8c30bc1183a5c86fb77a5676bf0b753af
hiyouga 2024-01-18 10:26:26 +08:00
parent 4e3bfb799d
commit 1323b8f102
9 changed files with 63 additions and 37 deletions

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
 from transformers import GenerationConfig, TextIteratorStreamer
-from ..data import get_template_and_fix_tokenizer
+from ..data import get_template_and_fix_tokenizer, Role
 from ..extras.misc import get_logits_processor
 from ..model import dispatch_model, load_model_and_tokenizer
 from ..hparams import get_infer_args
@@ -36,10 +36,19 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
+        messages = []
+        if history is not None:
+            for old_prompt, old_response in history:
+                messages.append({"role": Role.USER, "content": old_prompt})
+                messages.append({"role": Role.ASSISTANT, "content": old_response})
+        messages.append({"role": Role.USER, "content": query})
+        messages.append({"role": Role.ASSISTANT, "content": ""})
+
         prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
+            tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
         )
         prompt_length = len(prompt)
         input_ids = torch.tensor([prompt], device=self.model.device)
@@ -90,6 +99,7 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> List[Response]:
         r"""
@@ -97,7 +107,7 @@ class ChatModel:
         Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
         """
-        gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
+        gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
         response = self.tokenizer.batch_decode(
@@ -122,9 +132,10 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
+        gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer

View File

@@ -1,6 +1,6 @@
 from .loader import get_dataset
 from .template import get_template_and_fix_tokenizer, templates
-from .utils import split_dataset
+from .utils import split_dataset, Role

-__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset"]
+__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]

View File

@@ -97,7 +97,7 @@ def preprocess_packed_supervised_dataset(
         messages = examples["prompt"][i] + examples["response"][i]
         for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tool"][i], 1_000_000
+            tokenizer, messages, examples["system"][i], examples["tool"][i]
         )):
             if data_args.train_on_prompt:
                 source_mask = source_ids
View File

@@ -33,13 +33,13 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
-        cutoff_len: int
+        tools: str,
+        cutoff_len: Optional[int] = 1_000_000
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
+        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
         prompt_ids = []
         for query_ids, resp_ids in encoded_pairs[:-1]:
             prompt_ids = prompt_ids + query_ids + resp_ids
@@ -52,13 +52,13 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
-        cutoff_len: int
+        tools: str,
+        cutoff_len: Optional[int] = 1_000_000
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
+        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
         return encoded_pairs

     def _encode(
@@ -66,7 +66,7 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
+        tools: str,
         cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -78,8 +78,8 @@ class Template:
         encoded_messages = []
         for i, message in enumerate(messages):
             elements = []
-            if i == 0 and (system or tool):
-                tool_text = self.format_tool(content=tool)[0] if tool else ""
+            if i == 0 and (system or tools):
+                tool_text = self.format_tool(content=tools)[0] if tools else ""
                 elements += self.format_system(content=(system + tool_text))
             elif i > 0 and i % 2 == 0:
                 elements += self.separator
@@ -131,7 +131,7 @@ class Llama2Template(Template):
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
+        tools: str,
         cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -144,8 +144,8 @@ class Llama2Template(Template):
         for i, message in enumerate(messages):
             elements = []
             system_text = ""
-            if i == 0 and (system or tool):
-                tool_text = self.format_tool(content=tool)[0] if tool else ""
+            if i == 0 and (system or tools):
+                tool_text = self.format_tool(content=tools)[0] if tools else ""
                 system_text = self.format_system(content=(system + tool_text))[0]
             elif i > 0 and i % 2 == 0:
                 elements += self.separator
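
Net effect of the hunks above: the `tool` argument becomes `tools`, and the public `encode_oneturn`/`encode_multiturn` gain a `cutoff_len` default of `1_000_000`, which is why the preprocessing call site could drop its explicit constant. A hedged sketch of the new call shape (the tokenizer path, template name, and argument order of `get_template_and_fix_tokenizer` are assumptions):

```python
from transformers import AutoTokenizer

from llmtuner.data import Role, get_template_and_fix_tokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/model")  # hypothetical path
template = get_template_and_fix_tokenizer("default", tokenizer)  # argument order assumed

# Messages alternate user/assistant roles, matching how _encode pairs turns.
messages = [
    {"role": Role.USER, "content": "Ping?"},
    {"role": Role.ASSISTANT, "content": "Pong."},
]
# cutoff_len can now be omitted thanks to the new default.
prompt_ids, response_ids = template.encode_oneturn(
    tokenizer=tokenizer, messages=messages, system="", tools=""
)
```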

View File

@@ -65,17 +65,17 @@ class Evaluator:
         inputs, outputs, labels = [], [], []
         for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
             support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
-            query, resp, history = self.eval_template.format_example(
+            messages = self.eval_template.format_example(
                 target_data=dataset[self.data_args.split][i],
                 support_set=support_set,
-                subject_name=categorys[subject]["name"],
-                use_history=self.template.use_history
+                subject_name=categorys[subject]["name"]
             )
             input_ids, _ = self.template.encode_oneturn(
-                tokenizer=self.tokenizer, query=query, resp=resp, history=history
+                tokenizer=self.tokenizer, messages=messages
             )
             inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
-            labels.append(resp)
+            labels.append(messages[-1]["content"])

         for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
             batch_input = self.tokenizer.pad(

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Tuple
 from ..extras.constants import CHOICES
+from ..data import Role

 if TYPE_CHECKING:
     from datasets import Dataset
@@ -28,20 +29,23 @@ class EvalTemplate:
         support_set: "Dataset",
         subject_name: str,
         use_history: bool
-    ) -> Tuple[str, str, List[Tuple[str, str]]]:
-        query, resp = self.parse_example(target_data)
-        history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
+    ) -> List[Dict[str, str]]:
+        messages = []
+        for k in range(len(support_set)):
+            prompt, response = self.parse_example(support_set[k])
+            messages.append({"role": Role.USER, "content": prompt})
+            messages.append({"role": Role.ASSISTANT, "content": response})

-        if len(history):
-            temp = history.pop(0)
-            history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
-        else:
-            query = self.system.format(subject=subject_name) + query
+        prompt, response = self.parse_example(target_data)
+        messages.append({"role": Role.USER, "content": prompt})
+        messages.append({"role": Role.ASSISTANT, "content": response})

+        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]

         if not use_history:
-            query = "\n\n".join(["".join(item) for item in history] + [query])
-            history = []
+            messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}, messages[-1]]

-        return query.strip(), resp, history
+        return messages

 eval_templates: Dict[str, "EvalTemplate"] = {}
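
For a concrete picture of the new return shape: with one support example and `use_history=True`, `format_example` now yields an alternating user/assistant list whose last entry doubles as the evaluation label (the contents below are placeholders):

```python
# Placeholder contents; the subject-specific system prompt is prepended
# to the first user message rather than returned separately.
messages = [
    {"role": Role.USER, "content": "<system prompt about subject>\n<support question>"},
    {"role": Role.ASSISTANT, "content": "<support answer>"},
    {"role": Role.USER, "content": "<target question>"},
    {"role": Role.ASSISTANT, "content": "<target answer>"},  # becomes the label
]
```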

View File

@@ -105,6 +105,7 @@ class WebChatModel(ChatModel):
         query: str,
         history: List[Tuple[str, str]],
         system: str,
+        tools: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float
@@ -112,7 +113,7 @@ class WebChatModel(ChatModel):
         chatbot.append([query, ""])
         response = ""
         for new_text in self.stream_chat(
-            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             new_history = history + [(query, response)]

View File

@@ -18,6 +18,7 @@ def create_chat_box(
     with gr.Row():
         with gr.Column(scale=4):
             system = gr.Textbox(show_label=False)
+            tools = gr.Textbox(show_label=False, lines=2)
             query = gr.Textbox(show_label=False, lines=8)
             submit_btn = gr.Button(variant="primary")
@@ -30,7 +31,7 @@ def create_chat_box(
     submit_btn.click(
         engine.chatter.predict,
-        [chatbot, query, history, system, max_new_tokens, top_p, temperature],
+        [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
         [chatbot, history],
         show_progress=True
     ).then(
@@ -41,6 +42,7 @@ def create_chat_box(
     return chat_box, chatbot, history, dict(
         system=system,
+        tools=tools,
         query=query,
         submit_btn=submit_btn,
         clear_btn=clear_btn,

View File

@@ -521,6 +521,14 @@ LOCALES = {
             "placeholder": "系统提示词(非必填)"
         }
     },
+    "tools": {
+        "en": {
+            "placeholder": "Tools (optional)"
+        },
+        "zh": {
+            "placeholder": "工具列表(非必填)"
+        }
+    },
     "query": {
         "en": {
             "placeholder": "Input..."