fix sft encode

Former-commit-id: b2f7cb446591e3722b5be8d250ddfe0caa226384
This commit is contained in:
hiyouga 2023-07-11 19:50:33 +08:00
parent 685ba1e02e
commit 925a026853

View File

@ -505,7 +505,7 @@ def preprocess_data(
input_ids, labels = [], [] input_ids, labels = [], []
for i in range(len(dialog) // 2): for i in range(len(dialog) // 2):
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True) source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False) target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length: if len(source_ids) > data_args.max_source_length: