Former-commit-id: 0e86527d7fae9c9fe0df89d6fbd89035c9d83fe3
This commit is contained in:
hiyouga 2023-11-09 16:41:32 +08:00
parent 164559d01d
commit 4dbb52750f

View File

@ -1,5 +1,4 @@
import torch
import deepspeed # type: ignore
from copy import deepcopy
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
@ -76,6 +75,8 @@ class CustomDPOTrainer(DPOTrainer):
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
    config_kwargs["zero_optimization"]["stage"] = 0
# lazy load
import deepspeed # type: ignore
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model