allow non-packing pretraining

This commit is contained in:
hiyouga
2024-03-09 22:21:46 +08:00
parent 412c52e325
commit bdb496644c
22 changed files with 64 additions and 67 deletions

View File

@@ -292,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False,
return_logits: bool = False,
response_masks: Optional[torch.Tensor] = None,
):
r"""