add test cases

Former-commit-id: b27269bd2b
This commit is contained in:
hiyouga
2024-06-15 04:05:54 +08:00
parent d4ce280fbc
commit a3f4925c2c
9 changed files with 184 additions and 34 deletions

View File

@@ -135,8 +135,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
device_type = unwrapped_model.pretrained_model.device.type
self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
warnings.simplefilter("ignore") # remove gc warnings on ref model
if finetuning_args.reward_model_type == "full":