alter rewards data type

Author: hiyouga
Date: 2023-06-02 14:19:51 +08:00
Parent: e6126244c1
Commit: 50d9a20f81
12 changed files with 40 additions and 50 deletions


@@ -2,7 +2,7 @@ import torch
 from typing import Dict, Optional, Sequence, Union
-from transformers import DataCollatorWithPadding
+from transformers import DataCollatorWithPadding, BatchEncoding
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -34,7 +34,7 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         attention_mask = attention_mask.bool()
         return attention_mask
-    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
         r"""
         Pads batched data to the longest sequence in the batch.
@@ -64,4 +64,4 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         batch["input_ids"] = input_ids
         batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
-        return batch
+        return BatchEncoding(batch)
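
Why the wrapper matters: transformers.BatchEncoding is a mapping, so code that consumed the old Dict[str, torch.Tensor] return value (including model(**batch)) keeps working, while callers gain conveniences such as batch.to(device) and attribute-style field access. The following is a minimal, self-contained sketch of that usage; the tensor values are dummy data chosen for illustration and are not taken from the repository, only BatchEncoding itself comes from transformers.

import torch
from transformers import BatchEncoding

# Stand-in for what the collator's __call__ now returns:
# padded tensors wrapped in a BatchEncoding instead of a plain dict.
batch = BatchEncoding({
    "input_ids": torch.tensor([[1, 2, 3, 0], [4, 5, 6, 7]]),
    "attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]]).bool(),
})

batch = batch.to("cpu")                # moves every tensor in one call (e.g. "cuda:0" on a GPU machine)
print(batch.input_ids.shape)           # attribute access -> torch.Size([2, 4])
print(batch["attention_mask"].dtype)   # still indexable like a plain dict -> torch.bool
# model(**batch) also still works, because BatchEncoding unpacks like a mapping.

The return-type annotation change in __call__ above documents exactly this: the batch holds the same padded data, only wrapped in a richer container.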