mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-12 17:10:36 +08:00
[misc] Add a PyTorch version warning for Conv3D. (#9715)
This commit is contained in:
@@ -25,10 +25,14 @@ from transformers import (
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
import warnings
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||
from ..extras.packages import _get_package_version
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.ktransformers import load_kt_pretrained_model
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
@@ -202,6 +206,17 @@ def load_model(
|
||||
if vhead_params is not None:
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||
|
||||
# torch 2.9.x has a known severe Conv3D performance regression, so refuse to load models containing nn.Conv3d on that version
|
||||
torch_version = _get_package_version("torch")
|
||||
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
|
||||
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
|
||||
raise ValueError(
|
||||
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
|
||||
"This combination is known to cause severe performance regression. "
|
||||
"Please downgrade torch to <2.9 or remove Conv3D. "
|
||||
"See https://github.com/pytorch/pytorch/issues/166122"
|
||||
)
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False)
|
||||
|
||||
Reference in New Issue
Block a user