[trainer] new kto mismatch pair creation strategy (#7509)

This commit is contained in:
Hao 2025-04-01 15:21:53 +08:00 committed by GitHub
parent 2d421c57bf
commit 6d6e0f44fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -83,8 +83,8 @@ class FeedbackDatasetProcessor(DatasetProcessor):
return input_ids, labels, kl_input_ids, kl_labels, kto_tag return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs # Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
kl_response = examples["_response"][::-1] kl_response = [examples["_response"][-1]] + examples["_response"][:-1]
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: