This commit is contained in:
hiyouga
2024-03-09 02:01:26 +08:00
parent 516d0ddc66
commit e8dd38b7fd
7 changed files with 28 additions and 20 deletions

View File

@@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -49,6 +49,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.current_device = get_current_device() # patch for deepspeed training
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,