[v1] add renderer ut (#9722)

This commit is contained in:
Yaowei Zheng
2026-01-07 02:06:07 +08:00
committed by GitHub
parent ea0b4e2466
commit d22de0d4bf
13 changed files with 420 additions and 249 deletions

View File

@@ -15,6 +15,7 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -25,14 +26,11 @@ from transformers import (
AutoProcessor,
AutoTokenizer,
)
from packaging import version
from torch import nn
from trl import AutoModelForCausalLMWithValueHead
import warnings
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import _get_package_version
from ..extras.packages import is_torch_version_greater_than
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
@@ -206,11 +204,10 @@ def load_model(
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
# Conv3D is not recommended when using torch 2.9.x
torch_version = _get_package_version("torch")
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
raise ValueError(
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
"This combination is known to cause severe performance regression. "