Mirror of https://github.com/hiyouga/LLaMA-Factory.git (last synced 2025-12-14 10:56:56 +08:00).
Support DPO training (arXiv:2305.18290)
Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
This commit is contained in:
@@ -18,7 +18,6 @@ class ChatModel:
|
||||
self.model = self.model.eval() # change to eval mode
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
- self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
|
||||
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
|
||||
|
||||
def process_args(
|
||||
@@ -53,7 +52,7 @@ class ChatModel:
|
||||
top_k=top_k or gen_kwargs["top_k"],
|
||||
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
||||
logits_processor=get_logits_processor(),
|
||||
- stopping_criteria=get_stopping_criteria(self.stop_ids)
|
||||
+ stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
|
||||
))
|
||||
|
||||
if max_length:
|
||||
|
||||
Reference in New Issue
Block a user