From f2e139f5cd0fb07db822592fcce755c8ca9299c9 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 9 Nov 2023 16:41:32 +0800 Subject: [PATCH] fix #1452 Former-commit-id: 4d16214467715df458e24d03bb7d303d62b8bdcd --- src/llmtuner/tuner/dpo/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py index 8a9f8dd6..647bcee2 100644 --- a/src/llmtuner/tuner/dpo/trainer.py +++ b/src/llmtuner/tuner/dpo/trainer.py @@ -1,5 +1,4 @@ import torch -import deepspeed # type: ignore from copy import deepcopy from collections import defaultdict from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union @@ -76,6 +75,8 @@ class CustomDPOTrainer(DPOTrainer): # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) if config_kwargs["zero_optimization"]["stage"] != 3: config_kwargs["zero_optimization"]["stage"] = 0 + # lazy load + import deepspeed # type: ignore model, *_ = deepspeed.initialize(model=model, config=config_kwargs) model.eval() return model