mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git (synced 2025-09-18 05:22:48 +08:00)
51 lines · 1.4 KiB · Python · Executable File

import os

from omegaconf import OmegaConf
import torch
from torch import nn

from .utils.misc import instantiate_from_config
from ..utils import default, exists
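
# `default` and `exists` are assumed to follow the usual helper convention
# (a sketch; the real definitions live in ..utils and may differ):
#
#   def exists(v):
#       return v is not None
#
#   def default(v, d):
#       return v if exists(v) else d
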
def load_model():
    # Load the ShapeVAE config that ships alongside this module.
    model_config = OmegaConf.load(
        os.path.join(os.path.dirname(__file__), "shapevae-256.yaml")
    )
    # print(model_config)
    if hasattr(model_config, "model"):
        model_config = model_config.model

    ckpt_path = "./ckpt/shapevae-256.ckpt"

    model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
    # model = model.cuda()  # optionally move to GPU
    model = model.eval()

    return model

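# `instantiate_from_config` is assumed to follow the common OmegaConf
# "target"/"params" convention (a sketch; the real helper lives in
# .utils.misc and forwards extra kwargs such as ckpt_path to the model):
#
#   import importlib
#
#   def instantiate_from_config(config, **kwargs):
#       module, cls = config["target"].rsplit(".", 1)
#       return getattr(importlib.import_module(module), cls)(
#           **config.get("params", dict()), **kwargs
#       )
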
class ShapeConditioner(nn.Module):
    def __init__(
        self,
        *,
        dim_latent = None
    ):
        super().__init__()
        self.model = load_model()

        # Output feature width of the shape model; used as the default
        # latent dimension.
        self.dim_model_out = 768
        dim_latent = default(dim_latent, self.dim_model_out)
        self.dim_latent = dim_latent

    def forward(
        self,
        shape = None,
        shape_embed = None,
    ):
        # Exactly one of `shape` (raw encoder input) or a precomputed
        # `shape_embed` must be supplied.
        assert exists(shape) ^ exists(shape_embed)

        # No head token is produced when a precomputed embedding is passed
        # in; initialising it avoids an UnboundLocalError at the return.
        shape_head = None

        if not exists(shape_embed):
            point_feature = self.model.encode_latents(shape)
            shape_latents = self.model.to_shape_latents(point_feature[:, 1:])
            # The first token serves as a global head; the remaining tokens
            # are concatenated with their projected latents along the
            # feature dimension.
            shape_head = point_feature[:, 0:1]
            shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-1)
            # shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-2) # cat tmp

        return shape_head, shape_embed
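
# Usage sketch (not part of the original file; kept commented-out because the
# relative imports above require running this module inside its package).
# Assumes the YAML config and checkpoint exist at the hard-coded paths, and
# that `encode_latents` accepts a point-cloud tensor; the
# (batch, num_points, channels) layout is a guess for illustration only:
#
#   conditioner = ShapeConditioner()
#   points = torch.randn(1, 4096, 3)
#   shape_head, shape_embed = conditioner(shape=points)
#   print(shape_head.shape, shape_embed.shape)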