support DPO training (2305.18290)

Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
This commit is contained in:
hiyouga
2023-08-11 03:02:53 +08:00
parent 72dfd74005
commit ca719a8697
33 changed files with 513 additions and 192 deletions

View File

@@ -18,7 +18,6 @@ class ChatModel:
self.model = self.model.eval() # change to eval mode
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.source_prefix = data_args.source_prefix
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
def process_args(
@@ -53,7 +52,7 @@ class ChatModel:
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor(),
stopping_criteria=get_stopping_criteria(self.stop_ids)
stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
))
if max_length: