mirror of https://github.com/hiyouga/LLaMA-Factory.git
support BLOOM models
@@ -1,7 +1,7 @@
 import torch
 
 from typing import Dict, Sequence, Union
 
-from .data_collator import DataCollatorForLLaMA
+from .data_collator import DynamicDataCollatorWithPadding
 
 from .peft_trainer import PeftTrainer
@@ -10,7 +10,7 @@ from .other import get_logger
 logger = get_logger(__name__)
 
 
-class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
+class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
     r"""
     Data collator for pairwise data.
     """
@@ -26,7 +26,7 @@ class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
         return super().__call__(features)
 
 
-class PairwiseTrainerForLLaMA(PeftTrainer):
+class PairwisePeftTrainer(PeftTrainer):
     r"""
     Inherits PeftTrainer to compute pairwise loss.
     """