DeepSpeed ZeRO3 has inflight param error when calling model.eval()
This commit is contained in:
hiyouga
2024-06-13 02:25:50 +08:00
parent 2ed8270112
commit cf9f2d6c42
4 changed files with 12 additions and 17 deletions

View File

@@ -1,6 +1,7 @@
import math
import os
import sys
import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -136,6 +137,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
device_type = unwrapped_model.pretrained_model.device.type
self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
warnings.simplefilter("ignore") # remove gc warnings on ref model
if finetuning_args.reward_model_type == "full":
if self.is_deepspeed_enabled: