diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 2f9978a5..1c401938 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -407,7 +407,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             values = torch.transpose(values, 0, 1)
 
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
-        return rewards.to(torch.float32).detach().cpu()  # use fp32 type
+        return rewards.float().detach()  # use fp32 type
 
     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index f7160cfc..267e88e2 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -99,7 +99,7 @@ class PairwiseTrainer(Trainer):
         chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
         rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
         chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
-        loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
+        loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
         else:
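
Both hunks move the precision-sensitive math into fp32: the PPO hunk keeps the reward tensor on its device and upcasts it with `.float()`, and the RM hunk upcasts the scores before the pairwise logsigmoid loss. A minimal sketch of the effect, using hypothetical stand-in tensors rather than the trainers' actual variables:

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in tensors; names and shapes are illustrative only.

# PPO hunk: `.float()` is shorthand for `.to(torch.float32)`, and dropping
# `.cpu()` leaves the reward tensor on its original device for the PPO step.
rewards = torch.randn(4, 1, dtype=torch.bfloat16)
assert rewards.float().dtype == rewards.to(torch.float32).dtype == torch.float32

# RM hunk: upcasting the scores first makes the margin and the logsigmoid run
# in fp32 instead of the model's bf16/fp16 compute dtype, so the pairwise loss
# comes out as a float32 scalar.
chosen_scores = torch.randn(4, dtype=torch.bfloat16)
rejected_scores = torch.randn(4, dtype=torch.bfloat16)
loss = -F.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(loss.dtype)  # torch.float32
```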