mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-18 12:50:38 +08:00
@@ -1,13 +1,15 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from transformers.trainer import PredictionOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -23,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
model: "PreTrainedModel",
|
||||
inputs: Dict[str, torch.Tensor],
|
||||
return_outputs: Optional[bool] = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
@@ -46,7 +48,7 @@ class PairwisePeftTrainer(PeftTrainer):
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: PredictionOutput
|
||||
predict_results: "PredictionOutput"
|
||||
) -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
|
||||
Reference in New Issue
Block a user