mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
parent
164559d01d
commit
4dbb52750f
@ -1,5 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
import deepspeed # type: ignore
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||||
@ -76,6 +75,8 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||||
if config_kwargs["zero_optimization"]["stage"] != 3:
|
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||||
config_kwargs["zero_optimization"]["stage"] = 0
|
config_kwargs["zero_optimization"]["stage"] = 0
|
||||||
|
# lazy load
|
||||||
|
import deepspeed # type: ignore
|
||||||
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user