support rank0 logger

This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent bd08b8c441
commit c38aa29336
42 changed files with 316 additions and 252 deletions

View File

@@ -24,7 +24,7 @@ import torch
from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger
from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class PairwiseTrainer(Trainer):
@@ -118,7 +118,7 @@ class PairwiseTrainer(Trainer):
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer: