mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 19:52:50 +08:00
[trainer] new kto mismatch pair creation strategy (#7509)
This commit is contained in:
parent
2d421c57bf
commit
6d6e0f44fc
@ -83,8 +83,8 @@ class FeedbackDatasetProcessor(DatasetProcessor):
|
|||||||
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
||||||
|
|
||||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
# Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
|
||||||
kl_response = examples["_response"][::-1]
|
kl_response = [examples["_response"][-1]] + examples["_response"][:-1]
|
||||||
model_inputs = defaultdict(list)
|
model_inputs = defaultdict(list)
|
||||||
for i in range(len(examples["_prompt"])):
|
for i in range(len(examples["_prompt"])):
|
||||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user