update get template

Former-commit-id: dabad5570bf4a6b1044c963d8f27717030f373ef
hiyouga 2024-09-04 22:36:20 +08:00
parent 1dfd1aaf82
commit d5ea05cfff
17 changed files with 57 additions and 56 deletions
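
In short: `get_dataset` no longer builds the chat template internally and no longer returns it. Callers now construct the template themselves via `get_template_and_fix_tokenizer(tokenizer, data_args)`, which takes the whole `DataArguments` object instead of separate `data_args.template` / `data_args.tool_format` arguments, and pass it to `get_dataset` as the first parameter. The commit also reverts the recently added `dataset_map_batch_size` option. A minimal sketch of the new call order, assuming an SFT run (the model and dataset names below are illustrative, not part of this commit):

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer

# Illustrative config; any valid LLaMA-Factory training config works here.
model_args, data_args, training_args, _, _ = get_train_args(
    dict(
        stage="sft",
        model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",
        dataset="alpaca_en_demo",
        template="llama3",
        cutoff_len=1024,
        output_dir="dummy_dir",
    )
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

# New order: build the template first (this also fixes the tokenizer), then load the data.
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)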

View File

@@ -25,7 +25,7 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
 
-from llamafactory.data import get_dataset
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_tokenizer
@@ -66,7 +66,8 @@ def calculate_lr(
     )
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":

View File

@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
 
-from llamafactory.data import get_dataset
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_model, load_tokenizer
@@ -88,7 +88,8 @@ def cal_ppl(
     )
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
     model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

View File

@@ -18,7 +18,7 @@ from collections import defaultdict
 import fire
 from tqdm import tqdm
 
-from llamafactory.data import get_dataset
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_tokenizer
@@ -48,7 +48,8 @@ def length_cdf(
         )
     )
     tokenizer_module = load_tokenizer(model_args)
-    trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)["train_dataset"]
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
     total_num = len(trainset)
     length_dict = defaultdict(int)
     for sample in tqdm(trainset["input_ids"]):

View File

@@ -54,7 +54,7 @@ class HuggingfaceEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.model = load_model(
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab

View File

@@ -68,7 +68,7 @@ class VllmEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.generating_args = generating_args.to_dict()
 
         engine_args = {

View File

@@ -14,7 +14,7 @@
 import os
 import sys
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
@@ -27,7 +27,6 @@ from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer
 
 
 if TYPE_CHECKING:
@@ -179,9 +178,6 @@ def _get_preprocessed_dataset(
             load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Running tokenizer on dataset",
         )
-    if data_args.dataset_map_batch_size:
-        # Set the batch size conditionally without considering the default variable of the batch size in the map function
-        kwargs.update(batch_size=data_args.dataset_map_batch_size)
 
     dataset = dataset.map(
         preprocess_func,
@@ -205,17 +201,14 @@
 def get_dataset(
+    template: "Template",
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
-) -> Tuple["DatasetModule", "Template"]:
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
-    if data_args.train_on_prompt and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
-
+) -> "DatasetModule":
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
@@ -233,7 +226,7 @@ def get_dataset(
             if data_args.streaming:
                 dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
 
-            return dataset_module, template
+            return dataset_module
 
     if data_args.streaming:
         raise ValueError("Turn off `streaming` when saving dataset to disk.")
@@ -280,7 +273,8 @@ def get_dataset(
     dataset_module = {}
     if "train" in dataset_dict:
         dataset_module["train_dataset"] = dataset_dict["train"]
 
     if "validation" in dataset_dict:
         dataset_module["eval_dataset"] = dataset_dict["validation"]
 
-    return dataset_module, template
+    return dataset_module
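
Continuing the sketch above (same variables): since the template is now an input rather than an output, `get_dataset` returns only the `DatasetModule` dict, and call sites stop unpacking a tuple:

# Old: dataset_module, template = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
train_dataset = dataset_module["train_dataset"]    # present when a train split exists
eval_dataset = dataset_module.get("eval_dataset")  # present only when a validation split exists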

View File

@@ -27,6 +27,7 @@ from .mm_plugin import get_mm_plugin
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
+    from ..hparams import DataArguments
     from .formatter import SLOTS, Formatter
     from .mm_plugin import BasePlugin
@@ -344,28 +345,27 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     return jinja_template
 
 
-def get_template_and_fix_tokenizer(
-    tokenizer: "PreTrainedTokenizer",
-    name: Optional[str] = None,
-    tool_format: Optional[str] = None,
-) -> Template:
-    if name in ["llava", "paligemma", "qwen2_vl"]:
+def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
         require_version(
             "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
         )
 
-    if name is None:
+    if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
-        template = TEMPLATES.get(name, None)
+        template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(name))
+            raise ValueError("Template {} does not exist.".format(data_args.template))
 
-    if tool_format is not None:
-        logger.info("Using tool format: {}.".format(tool_format))
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
+    if data_args.tool_format is not None:
+        logger.info("Using tool format: {}.".format(data_args.tool_format))
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
-        template.format_tools = ToolFormatter(tool_format=tool_format)
+        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
+        template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 
     stop_words = template.stop_words
     if template.replace_eos:
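
Note that the `train_on_prompt` compatibility check moved from `get_dataset` into `get_template_and_fix_tokenizer`, so it now runs anywhere a template is built. A standalone sketch mirroring the updated tests (the tokenizer id is illustrative):

from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
data_args = DataArguments(template="llama3")  # tool_format and train_on_prompt keep their defaults
template = get_template_and_fix_tokenizer(tokenizer, data_args)

messages = [{"role": "user", "content": "How are you"}, {"role": "assistant", "content": "I am fine!"}]
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)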

View File

@@ -59,7 +59,7 @@ class Evaluator:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args)
         self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
         self.eval_template = get_eval_template(self.eval_args.lang)
         self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

View File

@@ -113,10 +113,6 @@ class DataArguments:
         default=None,
         metadata={"help": "Path to save or load the tokenized datasets."},
     )
-    dataset_map_batch_size: Optional[int] = field(
-        default=None,
-        metadata={"help": "Batch size for dataset mapping."},
-    )
 
     def __post_init__(self):
         def split_arg(arg):

View File

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
@@ -41,7 +41,8 @@ def run_dpo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     data_collator = PairwiseDataCollatorWithPadding(

View File

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import KTODataCollatorWithPadding, get_dataset
+from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
@@ -41,7 +41,8 @@ def run_kto(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="kto", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     data_collator = KTODataCollatorWithPadding(

View File

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset
+from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint
@@ -41,7 +41,8 @@ def run_ppo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
 
     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training

View File

@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
 from transformers import DataCollatorForLanguageModeling
 
-from ...data import get_dataset
+from ...data import get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
@@ -42,7 +42,8 @@ def run_pt(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, _ = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

View File

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint
@@ -41,7 +41,8 @@ def run_rm(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)

View File

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
+from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
@@ -43,7 +43,8 @@ def run_sft(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     if getattr(model, "is_quantized", False) and not training_args.do_train:
@@ -62,7 +63,7 @@ def run_sft(
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal dataset
 
     # Metric utils
     metric_module = {}

View File

@@ -19,7 +19,7 @@ from peft import PeftModel
 from transformers import AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead
 
-from ..data import get_dataset
+from ..data import get_dataset, get_template_and_fix_tokenizer
 from ..extras.misc import get_current_device
 from ..hparams import get_infer_args, get_train_args
 from ..model import load_model, load_tokenizer
@@ -105,7 +105,8 @@ def load_reference_model(
 def load_train_dataset(**kwargs) -> "Dataset":
     model_args, data_args, training_args, _, _ = get_train_args(kwargs)
     tokenizer_module = load_tokenizer(model_args)
-    dataset_module, _ = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, kwargs["stage"], **tokenizer_module)
     return dataset_module["train_dataset"]
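
The helper's public shape is unchanged; only its internals adopt the new two-step flow. A hypothetical invocation (every value below is a placeholder, not taken from this commit):

train_dataset = load_train_dataset(
    stage="sft",
    model_name_or_path="/path/to/any/model",  # placeholder
    dataset="alpaca_en_demo",                 # placeholder dataset name
    template="llama3",                        # placeholder template name
    cutoff_len=1024,
    output_dir="dummy_dir",
)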

View File

@@ -19,6 +19,7 @@ import pytest
 from transformers import AutoTokenizer
 
 from llamafactory.data import get_template_and_fix_tokenizer
+from llamafactory.hparams import DataArguments
 
 
 if TYPE_CHECKING:
@@ -51,7 +52,7 @@ def _check_single_template(
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
     content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
     content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
-    template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     assert content_str == prompt_str + answer_str + extra_str
     assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
@@ -78,7 +79,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_oneturn(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     prompt_str = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
@@ -93,7 +94,7 @@ def test_encode_oneturn(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_multiturn(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
     prompt_str_1 = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"