mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
Merge pull request #4417 from mMrBun/main
Add tool_format parameter to rewrite templates for different function call formats.
Former-commit-id: def6d280db
This commit is contained in:
@@ -148,7 +148,7 @@ def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
- template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||
+ template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
|
||||
@@ -379,6 +379,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
def get_template_and_fix_tokenizer(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
name: Optional[str] = None,
|
||||
+ tool_format: Optional[str] = None,
|
||||
) -> Template:
|
||||
if name is None:
|
||||
template = TEMPLATES["empty"] # placeholder
|
||||
@@ -386,6 +387,9 @@ def get_template_and_fix_tokenizer(
|
||||
template = TEMPLATES.get(name, None)
|
||||
if template is None:
|
||||
raise ValueError("Template {} does not exist.".format(name))
|
||||
|
||||
+ if tool_format:
|
||||
+ template.format_tools = ToolFormatter(tool_format=tool_format)
|
||||
|
||||
stop_words = template.stop_words
|
||||
if template.replace_eos:
|
||||
|
||||
Reference in New Issue
Block a user