commit af18b0dce7 (parent b240b6792f)
Author: hiyouga
Date: 2023-10-14 19:20:11 +08:00
4 changed files with 46 additions and 27 deletions


@@ -22,6 +22,9 @@ def preprocess_dataset(
     column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
             query, response = examples["prompt"][i], examples["response"][i]
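Note: the guard presumably exists because an `efficient_eos` template carries the EOS label of a finished turn in the first position of the next turn's source mask (see the `elif` branch further down), and `train_on_prompt` would overwrite that position with a real prompt token. A minimal sketch of the guard in isolation; the dataclasses below are stand-ins, not the project's real `Template` or `DataArguments`:

from dataclasses import dataclass

@dataclass
class Template:            # stand-in for the project's template object
    efficient_eos: bool

@dataclass
class DataArguments:       # stand-in for the project's DataArguments
    train_on_prompt: bool = False

def check_train_on_prompt(data_args: DataArguments, template: Template) -> None:
    # reject the unsupported combination before any preprocessing happens
    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")

check_train_on_prompt(DataArguments(train_on_prompt=True), Template(efficient_eos=False))  # passes
# check_train_on_prompt(DataArguments(train_on_prompt=True), Template(efficient_eos=True))  # raises ValueError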
@@ -32,13 +35,12 @@ def preprocess_dataset(
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build grouped texts with format `X1 X2 X3 ...`
-        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
-            kwargs = dict(allowed_special="all") # for tiktoken tokenizer (Qwen)
+        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
         else:
             kwargs = dict(add_special_tokens=True)
 
-        if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
-            setattr(tokenizer, "add_bos_token", True) # for LLaMA tokenizer
+        if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
             setattr(tokenizer, "add_eos_token", True)
 
         tokenized_examples = tokenizer(examples["prompt"], **kwargs)
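For context: Qwen's tokenizer wraps a raw `tiktoken.Encoding`, which wants `allowed_special="all"` so that special tokens appearing inside the raw text are accepted, while ordinary Hugging Face tokenizers take `add_special_tokens` instead; LLaMA-style tokenizers only append `</s>` when `add_eos_token` is set. A minimal sketch of the two branches in isolation (the helper names are illustrative, not from the repo):

import tiktoken  # only needed for the isinstance check

def pretrain_tokenize_kwargs(tokenizer) -> dict:
    # Qwen exposes the underlying tiktoken.Encoding as `tokenizer.tokenizer`
    if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
        return dict(allowed_special="all")      # accept special tokens in raw text
    return dict(add_special_tokens=True)        # regular HF tokenizers

def enable_eos(tokenizer) -> None:
    # LLaMA tokenizers only append </s> to each sample when this flag is set
    if hasattr(tokenizer, "add_eos_token"):
        setattr(tokenizer, "add_eos_token", True)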
@@ -74,7 +76,9 @@ def preprocess_dataset(
                 if len(target_ids) > max_target_len:
                     target_ids = target_ids[:max_target_len]
 
-                if turn_idx != 0 and template.efficient_eos:
+                if data_args.train_on_prompt:
+                    source_mask = source_ids
+                elif turn_idx != 0 and template.efficient_eos:
                     source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
                 else:
                     source_mask = [IGNORE_INDEX] * len(source_ids)
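A self-contained sketch of the three masking branches with toy token ids, showing what the new `train_on_prompt` option does to the labels (`IGNORE_INDEX = -100` matches the usual Hugging Face convention; everything else below is a stand-in for the project's own objects):

IGNORE_INDEX = -100   # positions with this label are skipped by the loss
EOS_ID = 2            # toy EOS token id

def build_labels(turns, train_on_prompt=False, efficient_eos=False):
    """turns: list of (source_ids, target_ids) pairs for one conversation."""
    input_ids, labels = [], []
    for turn_idx, (source_ids, target_ids) in enumerate(turns):
        if train_on_prompt:
            source_mask = source_ids                         # prompt stays in the loss
        elif turn_idx != 0 and efficient_eos:
            source_mask = [EOS_ID] + [IGNORE_INDEX] * (len(source_ids) - 1)
        else:
            source_mask = [IGNORE_INDEX] * len(source_ids)   # prompt fully masked
        input_ids += source_ids + target_ids
        labels += source_mask + target_ids
    return input_ids, labels

turns = [([11, 12, 13], [21, 22]), ([14, 15], [23])]
print(build_labels(turns))                         # prompts labelled with -100
print(build_labels(turns, train_on_prompt=True))   # labels == input_ids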
@@ -97,15 +101,17 @@ def preprocess_dataset(
         return model_inputs
 
     def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build inputs with format `<bos> X Y <eos>` and labels with format `<bos> X Y <eos>`
-        # we do not mask the inputs in packed training.
+        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
+        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         input_ids, labels = [], []
         for query, response, history, system in construct_example(examples):
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
-                if turn_idx != 0 and template.efficient_eos:
+                if data_args.train_on_prompt:
+                    source_mask = source_ids
+                elif turn_idx != 0 and template.efficient_eos:
                     source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
                 else:
                     source_mask = [IGNORE_INDEX] * len(source_ids)
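The packed variant concatenates many conversations into one token stream and later slices it into fixed-length blocks; the updated comment records that prompt positions remain `<ignore>` (IGNORE_INDEX) even in packed training rather than being trained on. Roughly, and only as a sketch (the block-slicing details below are an assumption, not the project's exact code):

IGNORE_INDEX = -100

def pack_examples(encoded, block_size):
    """encoded: list of (input_ids, labels) per conversation, labels already masked."""
    input_stream, label_stream = [], []
    for ids, labs in encoded:
        input_stream += ids
        label_stream += labs
    total = (len(input_stream) // block_size) * block_size  # drop the incomplete tail
    return [
        {"input_ids": input_stream[i:i + block_size],
         "attention_mask": [1] * block_size,
         "labels": label_stream[i:i + block_size]}
        for i in range(0, total, block_size)
    ]

packed = pack_examples(
    [([11, 12, 21, 2], [-100, -100, 21, 2]),   # conversation 1, prompt masked
     ([13, 22, 2], [-100, 22, 2])],            # conversation 2, prompt masked
    block_size=4,
)
print(packed)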
@@ -229,5 +235,9 @@ def preprocess_dataset(
             **kwargs
         )
 
-        print_function(next(iter(dataset)))
+        try:
+            print_function(next(iter(dataset)))
+        except StopIteration:
+            raise ValueError("Empty dataset!")
+
         return dataset
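The same guard works for any iterable, including streaming datasets where `len()` is unavailable; a minimal sketch of the pattern:

def peek_first(dataset):
    """Return the first example, or fail loudly when the dataset is empty."""
    try:
        return next(iter(dataset))
    except StopIteration:
        raise ValueError("Empty dataset!")

print(peek_first([{"input_ids": [1, 2, 3]}]))   # works for any iterable
# peek_first([])                                # raises ValueError("Empty dataset!")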