Merge pull request #4417 from mMrBun/main

Add tool_format parameter to rewrite templates for different function call formats.

Former-commit-id: def6d280db
This commit is contained in:
hoshi-hiyouga
2024-06-24 23:17:55 +08:00
committed by GitHub
6 changed files with 14 additions and 6 deletions

View File

@@ -148,7 +148,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")

View File

@@ -379,6 +379,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
tool_format: Optional[str] = None,
) -> Template:
if name is None:
template = TEMPLATES["empty"] # placeholder
@@ -386,6 +387,9 @@ def get_template_and_fix_tokenizer(
template = TEMPLATES.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
if tool_format:
template.format_tools = ToolFormatter(tool_format=tool_format)
stop_words = template.stop_words
if template.replace_eos: