Support streaming data; fix #284, #274, #268

This commit is contained in:
hiyouga
2023-07-31 23:33:00 +08:00
parent 513e1f1ec9
commit 0411a4b3e1
28 changed files with 478 additions and 344 deletions

View File

@@ -1,13 +1,15 @@
import os
import json
import torch
from typing import Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
logger = get_logger(__name__)
@@ -23,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):
def compute_loss(
self,
model: PreTrainedModel,
model: "PreTrainedModel",
inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -46,7 +48,7 @@ class PairwisePeftTrainer(PeftTrainer):
def save_predictions(
self,
predict_results: PredictionOutput
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.