mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-13 09:30:34 +08:00
[misc] Add a PyTorch version warning for Conv3D. (#9715)
This commit is contained in:
@@ -25,10 +25,14 @@ from transformers import (
|
|||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
)
|
)
|
||||||
|
from packaging import version
|
||||||
|
from torch import nn
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
import warnings
|
||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||||
|
from ..extras.packages import _get_package_version
|
||||||
from .adapter import init_adapter
|
from .adapter import init_adapter
|
||||||
from .model_utils.ktransformers import load_kt_pretrained_model
|
from .model_utils.ktransformers import load_kt_pretrained_model
|
||||||
from .model_utils.liger_kernel import apply_liger_kernel
|
from .model_utils.liger_kernel import apply_liger_kernel
|
||||||
@@ -203,6 +207,17 @@ def load_model(
|
|||||||
model.load_state_dict(vhead_params, strict=False)
|
model.load_state_dict(vhead_params, strict=False)
|
||||||
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||||
|
|
||||||
|
# torch 2.9.x has a severe Conv3D performance regression (pytorch/pytorch#166122), so fail fast if the model contains Conv3d modules
|
||||||
|
torch_version = _get_package_version("torch")
|
||||||
|
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
|
||||||
|
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
|
||||||
|
"This combination is known to cause severe performance regression. "
|
||||||
|
"Please downgrade torch to <2.9 or remove Conv3D. "
|
||||||
|
"See https://github.com/pytorch/pytorch/issues/166122"
|
||||||
|
)
|
||||||
|
|
||||||
if not is_trainable:
|
if not is_trainable:
|
||||||
model.requires_grad_(False)
|
model.requires_grad_(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|||||||
Reference in New Issue
Block a user