mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00

add tool test

Former-commit-id: 83dbfce8c30bc1183a5c86fb77a5676bf0b753af
parent 4e3bfb799d
commit 1323b8f102
@@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
 from transformers import GenerationConfig, TextIteratorStreamer

-from ..data import get_template_and_fix_tokenizer
+from ..data import get_template_and_fix_tokenizer, Role
 from ..extras.misc import get_logits_processor
 from ..model import dispatch_model, load_model_and_tokenizer
 from ..hparams import get_infer_args
@@ -36,10 +36,19 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
+        messages = []
+        if history is not None:
+            for old_prompt, old_response in history:
+                messages.append({"role": Role.USER, "content": old_prompt})
+                messages.append({"role": Role.ASSISTANT, "content": old_response})
+
+        messages.append({"role": Role.USER, "content": query})
+        messages.append({"role": Role.ASSISTANT, "content": ""})
         prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
+            tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
         )
         prompt_length = len(prompt)
         input_ids = torch.tensor([prompt], device=self.model.device)
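
A minimal stand-alone sketch of the history-to-messages conversion introduced above, assuming the Role members serialize to plain strings such as "user" and "assistant" (the real enum is imported from ..data):

# Sketch only, not the project's code: legacy (query, response) history
# tuples become role-tagged dicts, ending with an empty assistant turn
# for the model to complete.
def to_messages(query, history=None):
    messages = []
    for old_prompt, old_response in history or []:
        messages.append({"role": "user", "content": old_prompt})
        messages.append({"role": "assistant", "content": old_response})
    messages.append({"role": "user", "content": query})
    messages.append({"role": "assistant", "content": ""})
    return messages

print(to_messages("Hi", history=[("Hello", "Hey!")]))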
@@ -90,6 +99,7 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> List[Response]:
         r"""
@@ -97,7 +107,7 @@ class ChatModel:

         Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
         """
-        gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
+        gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
         response = self.tokenizer.batch_decode(
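
A hypothetical call against the extended chat signature; the tool schema below is a made-up placeholder, ChatModel is assumed to pick up its inference arguments from the CLI or a config dict, and the response_text field is inferred from the docstring above:

import json

chat_model = ChatModel()  # assumed: inference args come from CLI/config
tools = json.dumps([{  # placeholder OpenAI-style tool schema
    "name": "get_weather",
    "description": "Query the current weather",
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}])
responses = chat_model.chat("What's the weather in Paris?", tools=tools)
print(responses[0].response_text)  # field name inferred from the docstring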
@@ -122,9 +132,10 @@ class ChatModel:
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
+        tools: Optional[str] = None,
         **input_kwargs
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
+        gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer

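The streaming variant works the same way, reusing chat_model and tools from the sketch above:

# Tokens are yielded as the TextIteratorStreamer produces them.
for new_text in chat_model.stream_chat("Which tools can you call?", tools=tools):
    print(new_text, end="", flush=True)
print()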
@@ -1,6 +1,6 @@
 from .loader import get_dataset
 from .template import get_template_and_fix_tokenizer, templates
-from .utils import split_dataset
+from .utils import split_dataset, Role


-__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset"]
+__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]
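The re-exported Role is defined in .utils, which this diff does not show; a minimal shape consistent with how the commit uses it might be:

from enum import Enum, unique

@unique
class Role(str, Enum):
    # assumed definition: the real enum lives in data/utils.py and may
    # define more members (e.g. a system or observation role)
    USER = "user"
    ASSISTANT = "assistant"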
@@ -97,7 +97,7 @@ def preprocess_packed_supervised_dataset(

         messages = examples["prompt"][i] + examples["response"][i]
         for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tool"][i], 1_000_000
+            tokenizer, messages, examples["system"][i], examples["tool"][i]
         )):
             if data_args.train_on_prompt:
                 source_mask = source_ids
@@ -33,13 +33,13 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
-        cutoff_len: int
+        tools: str,
+        cutoff_len: Optional[int] = 1_000_000
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
+        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
         prompt_ids = []
         for query_ids, resp_ids in encoded_pairs[:-1]:
             prompt_ids = prompt_ids + query_ids + resp_ids
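
The hunk is cut off after the flattening loop; presumably the final pair's query ids extend the prompt and its response ids become the target. Under that assumption, encode_oneturn's flattening behaves like:

# Stand-alone sketch on fake token ids.
def flatten_oneturn(encoded_pairs):
    prompt_ids = []
    for query_ids, resp_ids in encoded_pairs[:-1]:
        prompt_ids = prompt_ids + query_ids + resp_ids
    last_query_ids, answer_ids = encoded_pairs[-1]  # assumed continuation
    return prompt_ids + last_query_ids, answer_ids

print(flatten_oneturn([([1, 2], [3]), ([4], [5, 6])]))  # ([1, 2, 3, 4], [5, 6])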
@@ -52,13 +52,13 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
-        cutoff_len: int
+        tools: str,
+        cutoff_len: Optional[int] = 1_000_000
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
+        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len)
         return encoded_pairs

     def _encode(
@@ -66,7 +66,7 @@ class Template:
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
+        tools: str,
         cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -78,8 +78,8 @@ class Template:
         encoded_messages = []
         for i, message in enumerate(messages):
             elements = []
-            if i == 0 and (system or tool):
-                tool_text = self.format_tool(content=tool)[0] if tool else ""
+            if i == 0 and (system or tools):
+                tool_text = self.format_tool(content=tools)[0] if tools else ""
                 elements += self.format_system(content=(system + tool_text))
             elif i > 0 and i % 2 == 0:
                 elements += self.separator
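
A toy illustration of the first-turn composition above; the format_tool and format_system stand-ins and their markers are invented for the example:

# Invented stand-ins for the template's formatting hooks.
format_tool = lambda content: ["\nAvailable tools: " + content]
format_system = lambda content: ["<<SYS>>" + content + "<</SYS>>"]

system, tools = "You are helpful.", '[{"name": "get_weather"}]'
tool_text = format_tool(content=tools)[0] if tools else ""
print(format_system(content=(system + tool_text)))
# ['<<SYS>>You are helpful.\nAvailable tools: [{"name": "get_weather"}]<</SYS>>']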
@@ -131,7 +131,7 @@ class Llama2Template(Template):
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
         system: str,
-        tool: str,
+        tools: str,
         cutoff_len: int
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
@@ -144,8 +144,8 @@ class Llama2Template(Template):
         for i, message in enumerate(messages):
             elements = []
             system_text = ""
-            if i == 0 and (system or tool):
-                tool_text = self.format_tool(content=tool)[0] if tool else ""
+            if i == 0 and (system or tools):
+                tool_text = self.format_tool(content=tools)[0] if tools else ""
                 system_text = self.format_system(content=(system + tool_text))[0]
             elif i > 0 and i % 2 == 0:
                 elements += self.separator
@@ -65,17 +65,17 @@ class Evaluator:
         inputs, outputs, labels = [], [], []
         for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
             support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
-            query, resp, history = self.eval_template.format_example(
+            messages = self.eval_template.format_example(
                 target_data=dataset[self.data_args.split][i],
                 support_set=support_set,
-                subject_name=categorys[subject]["name"],
-                use_history=self.template.use_history
+                subject_name=categorys[subject]["name"]
             )

             input_ids, _ = self.template.encode_oneturn(
-                tokenizer=self.tokenizer, query=query, resp=resp, history=history
+                tokenizer=self.tokenizer, messages=messages
             )
             inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
-            labels.append(resp)
+            labels.append(messages[-1]["content"])

         for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
             batch_input = self.tokenizer.pad(
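Since format_example now returns a flat messages list, the gold label is simply the content of the trailing assistant message (toy data below):

messages = [
    {"role": "user", "content": "1 + 1 = ?\nA. 1\nB. 2\nAnswer:"},
    {"role": "assistant", "content": "B"},
]
label = messages[-1]["content"]  # "B", what labels.append(...) stores above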
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Tuple

 from ..extras.constants import CHOICES
+from ..data import Role

 if TYPE_CHECKING:
     from datasets import Dataset
@@ -28,20 +29,23 @@ class EvalTemplate:
         support_set: "Dataset",
         subject_name: str,
         use_history: bool
-    ) -> Tuple[str, str, List[Tuple[str, str]]]:
-        query, resp = self.parse_example(target_data)
-        history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
+    ) -> List[Dict[str, str]]:
+        messages = []
+        for k in range(len(support_set)):
+            prompt, response = self.parse_example(support_set[k])
+            messages.append({"role": Role.USER, "content": prompt})
+            messages.append({"role": Role.ASSISTANT, "content": response})

-        if len(history):
-            temp = history.pop(0)
-            history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
-        else:
-            query = self.system.format(subject=subject_name) + query
+        prompt, response = self.parse_example(target_data)
+        messages.append({"role": Role.USER, "content": prompt})
+        messages.append({"role": Role.ASSISTANT, "content": response})
+
+        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]

         if not use_history:
-            query = "\n\n".join(["".join(item) for item in history] + [query])
-            history = []
-        return query.strip(), resp, history
+            messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
+
+        return messages


 eval_templates: Dict[str, "EvalTemplate"] = {}
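On toy data the rewritten format_example behaves roughly as follows; roles are shown as plain strings, and the final collapse mirrors the hunk as captured here, which keeps only the joined user turn:

# Toy walk-through, not the project's code.
support = [("Q1?", "A"), ("Q2?", "C")]  # fake few-shot pairs
target = ("Q3?", "B")

messages = []
for prompt, response in support + [target]:
    messages.append({"role": "user", "content": prompt})
    messages.append({"role": "assistant", "content": response})

messages[0]["content"] = "About math.\n\n" + messages[0]["content"]

use_history = False
if not use_history:  # collapse everything before the answer into one user turn
    messages = [{"role": "user", "content": "\n\n".join(m["content"] for m in messages[:-1])}]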
@@ -105,6 +105,7 @@ class WebChatModel(ChatModel):
         query: str,
         history: List[Tuple[str, str]],
         system: str,
+        tools: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float
@@ -112,7 +113,7 @@ class WebChatModel(ChatModel):
         chatbot.append([query, ""])
         response = ""
         for new_text in self.stream_chat(
-            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             new_history = history + [(query, response)]
@@ -18,6 +18,7 @@ def create_chat_box(
     with gr.Row():
         with gr.Column(scale=4):
             system = gr.Textbox(show_label=False)
+            tools = gr.Textbox(show_label=False, lines=2)
             query = gr.Textbox(show_label=False, lines=8)
             submit_btn = gr.Button(variant="primary")

@@ -30,7 +31,7 @@ def create_chat_box(

     submit_btn.click(
         engine.chatter.predict,
-        [chatbot, query, history, system, max_new_tokens, top_p, temperature],
+        [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
         [chatbot, history],
         show_progress=True
     ).then(
@@ -41,6 +42,7 @@ def create_chat_box(

     return chat_box, chatbot, history, dict(
         system=system,
+        tools=tools,
         query=query,
         submit_btn=submit_btn,
         clear_btn=clear_btn,
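A minimal self-contained Gradio sketch of the same wiring pattern; the predict stub and the reduced component set are simplified stand-ins for engine.chatter.predict and create_chat_box:

import gradio as gr

def predict(query, tools):  # stub in place of engine.chatter.predict
    return f"(would call ChatModel with tools={tools!r}) {query}"

with gr.Blocks() as demo:
    tools = gr.Textbox(show_label=False, lines=2, placeholder="Tools (optional)")
    query = gr.Textbox(show_label=False, lines=8, placeholder="Input...")
    output = gr.Textbox()
    submit_btn = gr.Button(variant="primary")
    submit_btn.click(predict, [query, tools], [output], show_progress=True)

# demo.launch()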
@@ -521,6 +521,14 @@ LOCALES = {
             "placeholder": "系统提示词(非必填)"
         }
     },
+    "tools": {
+        "en": {
+            "placeholder": "Tools (optional)"
+        },
+        "zh": {
+            "placeholder": "工具列表(非必填)"
+        }
+    },
     "query": {
         "en": {
             "placeholder": "Input..."
|
Loading…
x
Reference in New Issue
Block a user