Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 03:10:35 +08:00)
tiny fix
@@ -1,6 +1,6 @@
 # Copyright 2024 the LlamaFactory team.
 #
-# This code is inspired by CarperAI's trlx library.
+# This code is inspired by the CarperAI's trlx library.
 # https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -89,8 +89,8 @@ class PairwiseTrainer(Trainer):
 
     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
         super()._save(output_dir, state_dict)
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
         if self.processor is not None:
+            output_dir = output_dir if output_dir is not None else self.args.output_dir
             getattr(self.processor, "image_processor").save_pretrained(output_dir)
 
     def compute_loss(
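For context, a minimal sketch of what the second hunk leaves behind: a Trainer subclass whose _save resolves output_dir only when there is actually a processor whose image processor must be written out alongside the model. The class name, constructor, and the way self.processor is wired up are illustrative assumptions; only the body of _save comes from the diff above.

# Sketch only: assumes a transformers.Trainer subclass that optionally holds a
# multimodal processor; names outside the diff are hypothetical.
from typing import Dict, Optional

import torch
from transformers import Trainer


class PairwiseTrainerSketch(Trainer):
    def __init__(self, *args, processor=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.processor = processor  # optional processor (e.g. for vision-language models)

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
        # Let the base Trainer persist the model weights first.
        super()._save(output_dir, state_dict)
        if self.processor is not None:
            # Resolve the target directory only when a processor actually needs
            # saving, which is the point of the "tiny fix" in this commit.
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            getattr(self.processor, "image_processor").save_pretrained(output_dir)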