support BLOOM models

This commit is contained in:
hiyouga
2023-05-31 16:54:06 +08:00
parent a72492e649
commit 740a5daf56
16 changed files with 134 additions and 90 deletions

View File

@@ -1,7 +1,7 @@
import torch
from typing import Dict, Sequence, Union
from .data_collator import DataCollatorForLLaMA
from .data_collator import DynamicDataCollatorWithPadding
from .peft_trainer import PeftTrainer
@@ -10,7 +10,7 @@ from .other import get_logger
logger = get_logger(__name__)
class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
r"""
Data collator for pairwise data.
"""
@@ -26,7 +26,7 @@ class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
return super().__call__(features)
class PairwiseTrainerForLLaMA(PeftTrainer):
class PairwisePeftTrainer(PeftTrainer):
r"""
Inherits PeftTrainer to compute pairwise loss.
"""