Mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git
Synced 2025-12-28 11:00:33 +08:00

Commit: init
primitive_anything/michelangelo/models/conditional_encoders/__init__.py (new executable file, 3 lines)
@@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-

from .clip import CLIPEncoder
primitive_anything/michelangelo/models/conditional_encoders/clip.py (new executable file, 89 lines)
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-

import torch
import numpy as np
from PIL import Image
from dataclasses import dataclass
from torchvision.transforms import Normalize
from transformers import CLIPModel, CLIPTokenizer
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List


ImageType = Union[np.ndarray, torch.Tensor, Image.Image]


@dataclass
class CLIPEmbedOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    embeds: torch.FloatTensor = None


class CLIPEncoder(torch.nn.Module):

    def __init__(self, model_path="openai/clip-vit-base-patch32"):
        super().__init__()

        # Load the CLIP model and tokenizer
        self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
        self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Freeze CLIP: switch to eval mode and disable gradients
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def encode_image(self, images: Iterable[Optional[ImageType]]):
        pixel_values = self.image_preprocess(images)

        vision_outputs = self.model.vision_model(pixel_values=pixel_values)

        pooler_output = vision_outputs[1]  # pooled_output
        image_features = self.model.visual_projection(pooler_output)

        visual_embeds = CLIPEmbedOutput(
            last_hidden_state=vision_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=image_features
        )

        return visual_embeds

    @torch.no_grad()
    def encode_text(self, texts: List[str]):
        text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")

        # Pass the token ids and attention mask rather than the whole BatchEncoding
        text_outputs = self.model.text_model(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
        )

        pooler_output = text_outputs[1]  # pooled_output
        text_features = self.model.text_projection(pooler_output)

        text_embeds = CLIPEmbedOutput(
            last_hidden_state=text_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=text_features
        )

        return text_embeds

    def forward(self,
                images: Iterable[Optional[ImageType]],
                texts: List[str]):
        visual_embeds = self.encode_image(images)
        text_embeds = self.encode_text(texts)

        return visual_embeds, text_embeds

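For reference, a minimal usage sketch of the CLIPEncoder added above (not part of the commit; the dummy image batch, caption strings, and the printed shapes are illustrative assumptions for the default openai/clip-vit-base-patch32 checkpoint):

# Usage sketch only, not part of the commit. Assumes dummy 224x224 RGB tensors in [0, 1];
# real inputs would come from a dataloader.
import torch
from primitive_anything.michelangelo.models.conditional_encoders import CLIPEncoder

encoder = CLIPEncoder(model_path="openai/clip-vit-base-patch32")

images = torch.rand(2, 3, 224, 224)               # fake image batch
texts = ["a red cube", "a blue sphere"]

visual_embeds, text_embeds = encoder(images, texts)
print(visual_embeds.embeds.shape)                 # projected image embeddings, (2, 512) for this checkpoint
print(text_embeds.embeds.shape)                   # projected text embeddings, (2, 512) for this checkpoint
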
primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py (new executable file, 562 lines)
@@ -0,0 +1,562 @@
# -*- coding: utf-8 -*-
import os

import torch
import torch.nn as nn
from torchvision import transforms
from transformers import CLIPModel, CLIPTokenizer
from collections import OrderedDict

import clip  # OpenAI CLIP package (https://github.com/openai/CLIP); required by MoECLIPImageEncoder

from ...data.transforms import RandomResize


class AbstractEncoder(nn.Module):
    embedding_dim: int

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key="class"):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        c = self.embedding(c)
        return c


class FrozenCLIPTextEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        tokenizer_version=None,
        device="cuda",
        max_length=77,
        zero_embedding_radio: float = 0.1,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)

        self.device = device
        self.max_length = max_length
        self.zero_embedding_radio = zero_embedding_radio

        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        transformer = CLIPModel.from_pretrained(version).text_model

        for param in transformer.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = transformer

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def unconditional_embedding(self, batch_size):
        empty_text = [""] * batch_size
        empty_z = self.forward(empty_text)
        return empty_z

    def forward(self, text):
        self.move()

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.clip(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        batch_size = len(text)
        batch_mask = torch.rand((batch_size,))
        for i in range(batch_size):
            if batch_mask[i] < self.zero_embedding_radio:
                text[i] = ""

        return self(text)


class FrozenAlignedCLIPTextEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        tokenizer_version=None,
        device="cuda",
        max_length=77,
        zero_embedding_radio: float = 0.1,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)

        self.device = device
        self.max_length = max_length
        self.zero_embedding_radio = zero_embedding_radio

        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        transformer = CLIPModel.from_pretrained(version).text_model

        for param in transformer.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = transformer

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def unconditional_embedding(self, batch_size):
        empty_text = [""] * batch_size
        empty_z = self.forward(empty_text)
        return empty_z

    def forward(self, text):
        self.move()

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.clip(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        batch_size = len(text)
        batch_mask = torch.rand((batch_size,))
        for i in range(batch_size):
            if batch_mask[i] < self.zero_embedding_radio:
                text[i] = ""

        return self(text)


class FrozenCLIPImageEmbedder(AbstractEncoder):
    """Uses the frozen CLIP vision transformer (from Hugging Face) to embed images"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        zero_embedding_radio=0.1,
        normalize_embedding=True,
        num_projection_vector=0,
        linear_mapping_bias=True,
        reverse_visual_projection=False,
    ):
        super().__init__()

        self.device = device

        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        clip_model = CLIPModel.from_pretrained(version)
        clip_model.text_model = None
        clip_model.text_projection = None
        clip_model = clip_model.eval()
        # Freeze the CLIP weights
        for param in clip_model.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = clip_model

        self.transform = transforms.Compose(
            [
                transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.CenterCrop(224),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.zero_embedding_radio = zero_embedding_radio

        self.num_projection_vector = num_projection_vector
        self.reverse_visual_projection = reverse_visual_projection
        self.normalize_embedding = normalize_embedding

        embedding_dim = (
            clip_model.visual_projection.in_features
            if reverse_visual_projection
            else clip_model.visual_projection.out_features
        )
        self.embedding_dim = embedding_dim
        if self.num_projection_vector > 0:
            self.projection = nn.Linear(
                embedding_dim,
                clip_model.visual_projection.out_features * num_projection_vector,
                bias=linear_mapping_bias,
            )
            nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5)

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def unconditional_embedding(self, batch_size):
        zero = torch.zeros(
            batch_size,
            1,
            self.embedding_dim,
            device=self.device,
            dtype=self.clip.visual_projection.weight.dtype,
        )
        if self.num_projection_vector > 0:
            zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
        return zero

    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)

        if self.reverse_visual_projection:
            z = self.clip.vision_model(self.transform(image))[1]
        else:
            z = self.clip.get_image_features(self.transform(image))

        if self.normalize_embedding:
            z = z / z.norm(dim=-1, keepdim=True)
        if z.ndim == 2:
            z = z.unsqueeze(dim=-2)

        if zero_embedding_radio > 0:
            # Zero out each sample's embedding with probability zero_embedding_radio
            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
            z = z * mask.to(z)

        if self.num_projection_vector > 0:
            z = self.projection(z).view(len(image), self.num_projection_vector, -1)

        return z

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def encode(self, image):
        self.move()
        return self(image, zero_embedding_radio=self.zero_embedding_radio)


class FrozenCLIPImageGridEmbedder(AbstractEncoder):

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        zero_embedding_radio=0.1,
    ):
        super().__init__()

        self.device = device

        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        clip_model: CLIPModel = CLIPModel.from_pretrained(version)
        clip_model.text_model = None
        clip_model.text_projection = None
        clip_model = clip_model.eval()
        # Freeze the CLIP weights
        for param in clip_model.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = clip_model

        self.transform = transforms.Compose(
            [
                transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True),
                transforms.CenterCrop(224),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.zero_embedding_radio = zero_embedding_radio
        self.embedding_dim = clip_model.vision_embed_dim

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def unconditional_embedding(self, batch_size):
        zero = torch.zeros(
            batch_size,
            self.clip.vision_model.embeddings.num_positions,
            self.embedding_dim,
            device=self.device,
            dtype=self.clip.visual_projection.weight.dtype,
        )
        return zero

    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
        self.move()

        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)

        z = self.clip.vision_model(self.transform(image)).last_hidden_state

        if zero_embedding_radio > 0:
            # Zero out each sample's grid embedding with probability zero_embedding_radio
            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
            z = z * mask.to(z)

        return z

    def encode(self, image):
        return self(image, zero_embedding_radio=self.zero_embedding_radio)


class MoECLIPImageEncoder(nn.Module):
    def __init__(
        self,
        versions,
        hidden_state_dim,
        num_projection_vector=8,
        zero_embedding_radio=0.1,
        device="cuda",
        precision="fp16",
        normalize=False,
        clip_max=0,
        transform_type="base",
        argument_p=0.2,
    ):
        super().__init__()

        self.device = torch.device(device)
        self.hidden_state_dim = hidden_state_dim
        self.zero_embedding_radio = zero_embedding_radio
        self.num_projection_vector = num_projection_vector
        self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
        self.normalize = normalize
        self.clip_max = clip_max

        if transform_type == "base":
            self.transform = transforms.Compose(
                [
                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                    transforms.CenterCrop(224),  # crop a (224, 224) square
                    transforms.Normalize(
                        mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711],
                    ),
                ]
            )
        elif transform_type == "crop_blur_resize":
            self.transform = transforms.Compose(
                [
                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                    transforms.CenterCrop(224),  # crop a (224, 224) square
                    transforms.RandomApply(
                        transforms=[
                            transforms.RandomResizedCrop(
                                size=224,
                                scale=(0.8, 1.0),
                                ratio=(0.99, 1.01),
                                interpolation=transforms.InterpolationMode.BICUBIC,
                            ),
                        ],
                        p=argument_p,
                    ),
                    transforms.RandomApply(
                        transforms=[
                            transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
                        ],
                        p=argument_p,
                    ),
                    transforms.RandomApply(
                        transforms=[
                            RandomResize(size=224, resize_radio=(0.2, 1)),
                        ],
                        p=argument_p,
                    ),
                    transforms.Normalize(
                        mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711],
                    ),
                ]
            )
        else:
            raise ValueError(f"invalid {transform_type=}")

        if isinstance(versions, str):
            versions = (versions,)

        # If the CLIP models were registered directly as submodules of this class:
        # 1. the checkpoint would store several sets of redundant weights;
        # 2. PyTorch Lightning would call .to() on them, converting the layer_norm weights to fp16 as well.
        clips = OrderedDict()

        for v in versions:
            # Since the clips are not submodules, passing device="cuda" here would incorrectly
            # place every CLIP model's weights on cuda:0.
            clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
            delattr(clips[v], "transformer")
            clips[v].eval()
            clips[v].requires_grad_(False)

        self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)

        if self.num_projection_vector == 0:
            self.projection = nn.Identity()
        else:
            self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
            self.projection.to(dtype=self.dtype)
            nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)

        self.clips = clips

        self._move_flag = False

    def move(self):
        if self._move_flag:
            return

        def convert_weights(model: nn.Module):
            """Convert applicable model parameters to fp16"""

            def _convert_weights_to_fp16(l):
                if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                    l.weight.data = l.weight.data.type(self.dtype)
                    if l.bias is not None:
                        l.bias.data = l.bias.data.type(self.dtype)

                if isinstance(l, nn.MultiheadAttention):
                    for attr in [
                        *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                        "in_proj_bias",
                        "bias_k",
                        "bias_v",
                    ]:
                        tensor = getattr(l, attr)
                        if tensor is not None:
                            tensor.data = tensor.data.type(self.dtype)

                for name in ["text_projection", "proj"]:
                    if hasattr(l, name):
                        attr = getattr(l, name)
                        if attr is not None:
                            attr.data = attr.data.type(self.dtype)

            model.apply(_convert_weights_to_fp16)

        for k in self.clips:
            self.clips[k].to(self.device)
            convert_weights(self.clips[k])  # fp32 -> self.dtype
        self._move_flag = True

    def unconditional_embedding(self, batch_size=None):
        zero = torch.zeros(
            batch_size,
            self.clips_hidden_dim,
            device=self.device,
            dtype=self.dtype,
        )
        if self.num_projection_vector > 0:
            zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
        return zero

    def convert_embedding(self, z):
        if self.num_projection_vector > 0:
            z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
        return z

    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = self.transform(image)

        with torch.no_grad():
            embs = []
            for v in self.clips:
                x = self.clips[v].encode_image(image)
                if self.normalize:
                    x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
                    # clip_max only works with normalization
                    if self.clip_max > 0:
                        x = x.clamp(-self.clip_max, self.clip_max)
                embs.append(x)

            z = torch.cat(embs, dim=-1)
            if self.normalize:
                z /= z.size(-1) ** 0.5

        if zero_embedding_radio > 0:
            # Zero out each sample's concatenated embedding with probability zero_embedding_radio
            mask = torch.rand((len(image), 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
            z = z * mask.to(z)

        if self.num_projection_vector > 0:
            z = self.projection(z).view(len(image), self.num_projection_vector, -1)
        return z

    def encode(self, image):
        self.move()
        return self(image, zero_embedding_radio=self.zero_embedding_radio)
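And a minimal sketch of how one of the factory encoders above might be driven, here FrozenCLIPImageGridEmbedder (again not part of the commit; the batch size, the [-1, 1] input range, and the availability of a CUDA device are assumptions):

# Usage sketch only, not part of the commit. Assumes images in [-1, 1] (the default value_range)
# and an available CUDA device for the default device="cuda".
import torch
from primitive_anything.michelangelo.models.conditional_encoders.encoder_factory import FrozenCLIPImageGridEmbedder

embedder = FrozenCLIPImageGridEmbedder(version="openai/clip-vit-large-patch14", device="cuda")

images = torch.rand(2, 3, 256, 256) * 2 - 1       # fake batch in [-1, 1]; the transform resizes/crops to 224
cond = embedder.encode(images)                    # grid tokens, (2, 257, 1024) for ViT-L/14, with random dropout
uncond = embedder.unconditional_embedding(2)      # all-zero tokens of the same shape, for classifier-free guidance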