mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2026-05-08 00:58:55 +08:00
init
89
primitive_anything/michelangelo/models/conditional_encoders/clip.py
Executable file
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-

import torch
import numpy as np
from PIL import Image
from dataclasses import dataclass
from torchvision.transforms import Normalize
from transformers import CLIPModel, CLIPTokenizer
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List


ImageType = Union[np.ndarray, torch.Tensor, Image.Image]


@dataclass
class CLIPEmbedOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    embeds: torch.FloatTensor = None


class CLIPEncoder(torch.nn.Module):

    def __init__(self, model_path="openai/clip-vit-base-patch32"):
        super().__init__()

        # Load the CLIP model and tokenizer
        self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
        # Channel normalization only (ImageNet statistics); images are expected to be
        # pre-resized float tensors of shape (N, 3, 224, 224) in [0, 1].
        self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Freeze CLIP: switch to eval mode and disable gradients for all parameters
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def encode_image(self, images: Iterable[Optional[ImageType]]):
        pixel_values = self.image_preprocess(images)

        vision_outputs = self.model.vision_model(pixel_values=pixel_values)

        pooler_output = vision_outputs[1]  # pooled_output
        image_features = self.model.visual_projection(pooler_output)

        visual_embeds = CLIPEmbedOutput(
            last_hidden_state=vision_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=image_features
        )

        return visual_embeds

    @torch.no_grad()
    def encode_text(self, texts: List[str]):
        text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")

        # The tokenizer returns a BatchEncoding; pass token ids and attention mask explicitly
        text_outputs = self.model.text_model(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask
        )

        pooler_output = text_outputs[1]  # pooled_output
        text_features = self.model.text_projection(pooler_output)

        text_embeds = CLIPEmbedOutput(
            last_hidden_state=text_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=text_features
        )

        return text_embeds

    def forward(self,
                images: Iterable[Optional[ImageType]],
                texts: List[str]):

        visual_embeds = self.encode_image(images)
        text_embeds = self.encode_text(texts)

        return visual_embeds, text_embeds
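A minimal usage sketch for the encoder added above (not part of the commit; the input file name and prompt are hypothetical). It assumes images are resized to 224x224 and converted to float tensors in [0, 1] beforehand, since CLIPEncoder.encode_image only applies channel normalization:

# Usage sketch; assumes torchvision is available and images are pre-resized to 224x224 RGB.
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor

from primitive_anything.michelangelo.models.conditional_encoders.clip import CLIPEncoder

encoder = CLIPEncoder()  # downloads openai/clip-vit-base-patch32 on first use

# Convert a PIL image to a (1, 3, 224, 224) float tensor in [0, 1];
# the encoder itself only applies channel normalization.
to_tensor = Compose([Resize(224), CenterCrop(224), ToTensor()])
image = Image.open("example.jpg").convert("RGB")  # hypothetical input file
pixel_batch = to_tensor(image).unsqueeze(0)

visual_embeds, text_embeds = encoder(pixel_batch, ["a photo of a chair"])
print(visual_embeds.embeds.shape)  # torch.Size([1, 512]) for the base patch32 model
print(text_embeds.embeds.shape)    # torch.Size([1, 512])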