mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-02 17:55:59 +08:00
[v1] add renderer ut (#9722)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -25,14 +26,11 @@ from transformers import (
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
import warnings
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||
from ..extras.packages import _get_package_version
|
||||
from ..extras.packages import is_torch_version_greater_than
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.ktransformers import load_kt_pretrained_model
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
@@ -206,11 +204,10 @@ def load_model(
|
||||
if vhead_params is not None:
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||
|
||||
|
||||
# Conv3D is not recommended when using torch 2.9.x
|
||||
torch_version = _get_package_version("torch")
|
||||
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
|
||||
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
|
||||
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
|
||||
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
|
||||
raise ValueError(
|
||||
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
|
||||
"This combination is known to cause severe performance regression. "
|
||||
|
||||
Reference in New Issue
Block a user