# -*- coding: utf-8 -*-

import torch
from torch import nn
from einops import rearrange
from transformers import CLIPModel

from .tsal_base import AlignedShapeAsLatentModule


class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):

    def __init__(self, *,
                 shape_model,
                 projection_dim=768):

        super().__init__()

        self.shape_model = shape_model
        # Learned projection mapping shape embeddings of width
        # `shape_model.width` into the `projection_dim`-dimensional CLIP space.
        self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, projection_dim))
        nn.init.normal_(self.shape_projection, std=projection_dim ** -0.5)

    def set_shape_model_only(self):
        # Drop the CLIP branch so that only the shape encoder is kept.
        self.clip_model = None
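
    # NOTE: `clip_model` is not created in `__init__`; as used here it is
    # expected to be attached externally before the CLIP-dependent methods
    # below are called, and `set_shape_model_only` clears it again.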

    def encode_shape_embed(self, surface, return_latents: bool = False):
        """
        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]
            return_latents (bool):

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
            shape_latents (torch.FloatTensor): [bs, m, d]
        """
        # Split the surface samples into xyz coordinates and the extra
        # per-point channels (e.g. normals).
        pc = surface[..., 0:3]
        feats = surface[..., 3:]

        shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
        # Project the pooled shape embedding into the CLIP embedding space.
        x = shape_embed @ self.shape_projection

        if return_latents:
            return x, shape_latents
        else:
            return x
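
    # The CLIP-dependent encoders below assume `clip_model` is a HuggingFace
    # `transformers.CLIPModel`, whose `get_image_features` / `get_text_features`
    # return [bs, projection_dim] tensors; with the default projection_dim of
    # 768 this matches e.g. CLIP ViT-L/14, so projected shape embeddings live
    # in the same space as the image and text features.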

    def encode_image_embed(self, image):
        """
        Args:
            image (torch.FloatTensor): [bs, 3, h, w]

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
        """
        x = self.clip_model.get_image_features(image)
        return x

    def encode_text_embed(self, text):
        x = self.clip_model.get_text_features(text)
        return x
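
    # Only the averaged text embedding is L2-normalized inside `forward`;
    # image and shape embeddings are returned as-is, so any remaining
    # normalization is presumably handled by the downstream contrastive loss.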

    def forward(self, surface, image, text):
        """
        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]
            image (torch.FloatTensor): [bs, 3, 224, 224]
            text (torch.LongTensor): [bs, num_templates, 77]

        Returns:
            embed_outputs (dict): the embedding outputs, containing:
                - image_embed (torch.FloatTensor)
                - text_embed (torch.FloatTensor)
                - shape_embed (torch.FloatTensor)
                - logit_scale (float)
            shape_latents (torch.FloatTensor): [bs, m, d]
        """
        # Text embedding: encode all caption templates for the whole batch in
        # one pass, average over templates per sample, then L2-normalize.
        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        # image embedding
        image_embed = self.encode_image_embed(image)

        # shape embedding
        shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)

        embed_outputs = {
            "image_embed": image_embed,
            "text_embed": text_embed,
            "shape_embed": shape_embed,
            "logit_scale": self.clip_model.logit_scale.exp()
        }

        return embed_outputs, shape_latents
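

if __name__ == "__main__":
    # Minimal smoke test, included as a hedged sketch: `_DummyShapeModel` is a
    # hypothetical stand-in (any module exposing `width` and `encode_latents`
    # works), not the real shape encoder shipped with the repo. It exercises
    # `encode_shape_embed` and the text-template averaging from `forward`
    # without downloading a CLIP checkpoint. Run with `python -m`, since this
    # module uses a relative import.
    class _DummyShapeModel(nn.Module):
        width = 512

        def encode_latents(self, pc, feats):
            bs = pc.shape[0]
            # Pooled embedding [bs, width] and per-token latents [bs, m, d].
            return torch.randn(bs, self.width), torch.randn(bs, 16, self.width)

    module = CLIPAlignedShapeAsLatentModule(shape_model=_DummyShapeModel())
    surface = torch.randn(2, 1024, 3 + 3)  # xyz + 3 extra feature channels
    x, latents = module.encode_shape_embed(surface, return_latents=True)
    print(x.shape, latents.shape)  # torch.Size([2, 768]) torch.Size([2, 16, 512])

    # Text-template averaging on random stand-in embeddings (b=2 samples,
    # t=4 templates, d=768 dims), mirroring the steps in `forward`.
    fake = torch.randn(2 * 4, 768)
    per_template = rearrange(fake, "(b t) d -> b t d", b=2)
    text_embed = per_template.mean(dim=1)
    text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
    print(text_embed.shape)  # torch.Size([2, 768])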