# Mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git
# File: primitive_anything/michelangelo/models/tsal/clip_asl_module.py
# -*- coding: utf-8 -*-

import torch
from torch import nn
from einops import rearrange
from transformers import CLIPModel

from .tsal_base import AlignedShapeAsLatentModule
class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):
    """Aligns a shape encoder's embedding with the CLIP image/text space.

    Wraps ``shape_model`` (a point-cloud encoder) and projects its pooled
    embedding into the CLIP joint embedding space through a learned linear
    projection, so shape, image, and text embeddings can be compared with a
    contrastive objective.

    NOTE(review): ``self.clip_model`` is read by ``encode_image_embed``,
    ``encode_text_embed``, and ``forward`` but is never created in this
    class — it is presumably attached by the base class or by external
    setup code before ``forward`` is called; verify against callers.
    """

    def __init__(self, *,
                 shape_model,
                 projection_dim: int = 768):
        """
        Args:
            shape_model: shape encoder exposing ``width`` (embedding size)
                and ``encode_latents(pc, feats)``.
            projection_dim (int): dimensionality of the shared CLIP
                embedding space (768 matches CLIP ViT-L/14).
        """
        super().__init__()

        self.shape_model = shape_model
        # Learned projection from the shape encoder's width into CLIP space,
        # initialized the same way CLIP initializes its own projections
        # (normal with std = projection_dim ** -0.5).
        self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, projection_dim))
        nn.init.normal_(self.shape_projection, std=projection_dim ** -0.5)

    def set_shape_model_only(self):
        """Drop the CLIP model so only the shape branch remains (e.g. for
        shape-only inference, freeing the CLIP weights)."""
        self.clip_model = None

    def encode_shape_embed(self, surface, return_latents: bool = False):
        """Encode a sampled surface into the CLIP-aligned embedding space.

        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]; xyz coordinates in
                the first three channels, per-point features in the rest.
            return_latents (bool): also return the per-token shape latents.

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
            shape_latents (torch.FloatTensor): [bs, m, d]
                (only when ``return_latents`` is True)
        """
        pc = surface[..., 0:3]
        feats = surface[..., 3:]

        shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
        # Project the pooled shape embedding into the CLIP joint space.
        x = shape_embed @ self.shape_projection

        if return_latents:
            return x, shape_latents
        else:
            return x

    def encode_image_embed(self, image):
        """Embed images with the attached CLIP vision tower.

        Args:
            image (torch.FloatTensor): [bs, 3, h, w]

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
        """
        x = self.clip_model.get_image_features(image)
        return x

    def encode_text_embed(self, text):
        """Embed tokenized text with the attached CLIP text tower.

        Args:
            text (torch.LongTensor): [bs, seq_len] token ids.

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
        """
        x = self.clip_model.get_text_features(text)
        return x

    def forward(self, surface, image, text):
        """Compute shape/image/text embeddings for contrastive alignment.

        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]
            image (torch.FloatTensor): [bs, 3, 224, 224]
            text (torch.LongTensor): [bs, num_templates, 77]

        Returns:
            embed_outputs (dict): the embedding outputs, containing:
                - image_embed (torch.FloatTensor): [bs, projection_dim]
                - text_embed (torch.FloatTensor): [bs, projection_dim],
                  averaged over templates and L2-normalized
                - shape_embed (torch.FloatTensor): [bs, projection_dim]
                - logit_scale (torch.FloatTensor): exp of CLIP's learned
                  temperature (scalar)
            shape_latents (torch.FloatTensor): [bs, m, d]
        """
        # Text embedding: flatten the template axis into the batch so all
        # (b * num_templates) prompts go through CLIP in one call, average
        # over templates, then L2-normalize the mean.
        # NOTE(review): per-template embeddings are averaged WITHOUT
        # normalizing each one first (unlike the classic CLIP prompt
        # ensembling recipe) — confirm this is intentional.
        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        # Image embedding (not normalized here).
        image_embed = self.encode_image_embed(image)

        # Shape embedding plus per-token latents for downstream decoding.
        shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)

        embed_outputs = {
            "image_embed": image_embed,
            "text_embed": text_embed,
            "shape_embed": shape_embed,
            "logit_scale": self.clip_model.logit_scale.exp()
        }

        return embed_outputs, shape_latents