[misc] Add a PyTorch version warning for Conv3D. (#9715)

This commit is contained in:
Xunpeng Xiao
2026-01-05 13:26:29 +08:00
committed by GitHub
parent f60a6e3d01
commit 68119e5522

View File

@@ -25,10 +25,14 @@ from transformers import (
AutoProcessor,
AutoTokenizer,
)
from packaging import version
from torch import nn
from trl import AutoModelForCausalLMWithValueHead
import warnings
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import _get_package_version
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
@@ -202,6 +206,17 @@ def load_model(
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
# Conv3D is not recommended when using torch 2.9.x
torch_version = _get_package_version("torch")
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
raise ValueError(
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
"This combination is known to cause severe performance regression. "
"Please downgrade torch to <2.9 or remove Conv3D. "
"See https://github.com/pytorch/pytorch/issues/166122"
)
if not is_trainable:
model.requires_grad_(False)