diff --git a/src/utils/common.py b/src/utils/common.py index 917bd867..648f226f 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -505,7 +505,7 @@ def preprocess_data( input_ids, labels = [], [] for i in range(len(dialog) // 2): - source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True) + source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0)) target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False) if len(source_ids) > data_args.max_source_length: