mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix baichuan templates
@@ -63,7 +63,9 @@ def preprocess_dataset(
         for query, response, history, system in construct_example(examples):
             input_ids, labels = [], []
 
-            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
+            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
+                tokenizer, query, response, history, system
+            )):
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]
                 if len(target_ids) > data_args.max_target_length:
@@ -72,8 +74,17 @@ def preprocess_dataset(
                 if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                     break
 
+                if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
+                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+                else:
+                    source_mask = [IGNORE_INDEX] * len(source_ids)
+
                 input_ids += source_ids + target_ids
-                labels += [IGNORE_INDEX] * len(source_ids) + target_ids
+                labels += source_mask + target_ids
+
+            if template.efficient_eos:
+                input_ids += [tokenizer.eos_token_id]
+                labels += [tokenizer.eos_token_id]
 
             model_inputs["input_ids"].append(input_ids)
             model_inputs["attention_mask"].append([1] * len(input_ids))
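
For context, a minimal, self-contained sketch of the masking this diff introduces, with toy token ids standing in for the repository's tokenizer and template objects (the function name, EOS_ID, and the example turns below are invented for illustration, and truncation against max_source_length/max_target_length is left out). The diff reads as: for templates flagged efficient_eos (baichuan, qwen, gpt2), the eos that closes the previous response sits at the head of the next turn's source ids, so its label is kept as eos_token_id instead of IGNORE_INDEX, and a single eos is appended to the end of the whole example.

# Sketch only; not the repository's actual code path.
IGNORE_INDEX = -100  # label value ignored by the loss, as in the repo
EOS_ID = 2           # hypothetical eos_token_id

def build_sft_example(turns, efficient_eos):
    """turns: list of (source_ids, target_ids) pairs for one multi-turn example."""
    input_ids, labels = [], []
    for turn_idx, (source_ids, target_ids) in enumerate(turns):
        if turn_idx != 0 and efficient_eos:
            # Keep the previous response's closing eos trainable instead of masking it.
            source_mask = [EOS_ID] + [IGNORE_INDEX] * (len(source_ids) - 1)
        else:
            source_mask = [IGNORE_INDEX] * len(source_ids)
        input_ids += source_ids + target_ids
        labels += source_mask + target_ids
    if efficient_eos:
        # One final eos closes the last response of the example.
        input_ids += [EOS_ID]
        labels += [EOS_ID]
    return input_ids, labels

# Two toy turns; the second source starts with the eos that ends the first response.
turns = [([11, 12, 13], [21, 22]), ([EOS_ID, 14, 15], [23, 24])]
print(build_sft_example(turns, efficient_eos=True))
# -> ([11, 12, 13, 21, 22, 2, 14, 15, 23, 24, 2],
#     [-100, -100, -100, 21, 22, 2, -100, -100, 23, 24, 2])

As far as the diff shows, the old code masked every source position, so efficient_eos templates never received a supervised signal to emit eos between turns or at the end of the sample; the new code restores that signal.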