Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 03:10:35 +08:00)
upcast logits
@@ -407,7 +407,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         values = torch.transpose(values, 0, 1)
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
-        return rewards.to(torch.float32).detach().cpu()  # use fp32 type
+        return rewards.float().detach()  # use fp32 type

     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
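The change swaps `.to(torch.float32).detach().cpu()` for `.float().detach()`: the reward tensor is still upcast to fp32, but is no longer forced onto the host. Below is a minimal sketch of what the surrounding lines compute; the shapes, mask values, and half-precision dtype are illustrative assumptions, not taken from the repository.

    import torch

    # Assumed setup (not from the repo): the value head emits one scalar per
    # token, transposed to shape (batch_size, seq_len), in half precision.
    batch_size, seq_len = 2, 5
    values = torch.randn(batch_size, seq_len, dtype=torch.float16)
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding

    # The last attended token of each sequence sits at index mask.sum() - 1,
    # so gather picks the value at the final non-padded position.
    last_index = attention_mask.sum(dim=-1, keepdim=True) - 1  # shape (batch_size, 1)
    rewards = values.gather(dim=-1, index=last_index)          # shape (batch_size, 1)

    # The commit's change: upcast to fp32 and detach, but (unlike .cpu())
    # leave the tensor on whatever device it already lives on.
    rewards = rewards.float().detach()
    print(rewards.dtype, rewards.shape)  # torch.float32 torch.Size([2, 1])

Dropping the `.cpu()` call presumably avoids a synchronous device-to-host transfer inside the reward computation; callers that actually need the rewards on the CPU can move them there at the point of use.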