mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git (synced 2025-09-18 05:22:48 +08:00)
# -*- coding: utf-8 -*-
import os

import clip  # assumed dependency: the OpenAI CLIP package providing clip.load(), used by MoECLIPImageEncoder below
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import CLIPModel, CLIPTokenizer
from collections import OrderedDict

from ...data.transforms import RandomResize


class AbstractEncoder(nn.Module):
    """Common interface for conditioning encoders: subclasses set `embedding_dim` and implement `encode`."""

    embedding_dim: int

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key="class"):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        if key is None:
            key = self.key
        # Add a singleton token dimension so the embedding can be used as cross-attention context.
        c = batch[key][:, None]
        c = self.embedding(c)
        return c


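def _example_class_embedder():
    # Hedged usage sketch (added for illustration; not part of the original module).
    # Shows how ClassEmbedder turns integer labels into a single conditioning token.
    # The embedding width, class count, and batch size below are assumptions, not
    # values prescribed by this repository.
    embedder = ClassEmbedder(embed_dim=768, n_classes=1000, key="class")
    batch = {"class": torch.randint(0, 1000, (4,))}  # (B,) integer class labels
    context = embedder(batch)  # (B, 1, 768): one cross-attention token per sample
    return context

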
class FrozenCLIPTextEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        tokenizer_version=None,
        device="cuda",
        max_length=77,
        zero_embedding_radio: float = 0.1,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)

        self.device = device
        self.max_length = max_length
        self.zero_embedding_radio = zero_embedding_radio

        # The frozen CLIP text model is kept in a plain dict (not a registered submodule)
        # so its weights stay out of this module's state_dict.
        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        transformer = CLIPModel.from_pretrained(version).text_model

        for param in transformer.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = transformer

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def move(self):
        # Lazily move the frozen CLIP model to the target device on first use.
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def unconditional_embedding(self, batch_size):
        empty_text = [""] * batch_size
        empty_z = self.forward(empty_text)
        return empty_z

    def forward(self, text):
        self.move()

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.clip(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        # Classifier-free guidance dropout: replace each prompt with the empty string
        # with probability `zero_embedding_radio`.
        batch_size = len(text)
        batch_mask = torch.rand((batch_size,))
        for i in range(batch_size):
            if batch_mask[i] < self.zero_embedding_radio:
                text[i] = ""

        return self(text)


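def _example_frozen_clip_text_embedder():
    # Hedged usage sketch (added for illustration; not part of the original module).
    # Downloads the Hugging Face CLIP weights on first use; the device, dropout ratio,
    # and prompts below are assumptions.
    embedder = FrozenCLIPTextEmbedder(device="cpu", zero_embedding_radio=0.1)
    prompts = ["a photo of a chair", "a photo of a table"]
    z = embedder.encode(prompts)  # (B, 77, 768); ~10% of prompts are dropped to "" for CFG
    uncond = embedder.unconditional_embedding(batch_size=len(prompts))  # embeddings of empty prompts
    return z, uncond

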
class FrozenAlignedCLIPTextEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        tokenizer_version=None,
        device="cuda",
        max_length=77,
        zero_embedding_radio: float = 0.1,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)

        self.device = device
        self.max_length = max_length
        self.zero_embedding_radio = zero_embedding_radio

        # The frozen CLIP text model is kept in a plain dict (not a registered submodule)
        # so its weights stay out of this module's state_dict.
        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        transformer = CLIPModel.from_pretrained(version).text_model

        for param in transformer.parameters():
            param.requires_grad = False
        self.clip_dict[self.clip_name] = transformer

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def move(self):
        # Lazily move the frozen CLIP model to the target device on first use.
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def unconditional_embedding(self, batch_size):
        empty_text = [""] * batch_size
        empty_z = self.forward(empty_text)
        return empty_z

    def forward(self, text):
        self.move()

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.clip(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        # Classifier-free guidance dropout: replace each prompt with the empty string
        # with probability `zero_embedding_radio`.
        batch_size = len(text)
        batch_mask = torch.rand((batch_size,))
        for i in range(batch_size):
            if batch_mask[i] < self.zero_embedding_radio:
                text[i] = ""

        return self(text)


class FrozenCLIPImageEmbedder(AbstractEncoder):
    """Uses the CLIP vision encoder to produce one pooled embedding per image (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        zero_embedding_radio=0.1,
        normalize_embedding=True,
        num_projection_vector=0,
        linear_mapping_bias=True,
        reverse_visual_projection=False,
    ):
        super().__init__()

        self.device = device

        # The frozen CLIP model is kept in a plain dict (not a registered submodule)
        # so its weights stay out of this module's state_dict.
        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        clip_model = CLIPModel.from_pretrained(version)
        clip_model.text_model = None
        clip_model.text_projection = None
        clip_model = clip_model.eval()
        for param in clip_model.parameters():  # freeze the CLIP weights
            param.requires_grad = False
        self.clip_dict[self.clip_name] = clip_model

        self.transform = transforms.Compose(
            [
                transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.CenterCrop(224),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.zero_embedding_radio = zero_embedding_radio

        self.num_projection_vector = num_projection_vector
        self.reverse_visual_projection = reverse_visual_projection
        self.normalize_embedding = normalize_embedding

        embedding_dim = (
            clip_model.visual_projection.in_features
            if reverse_visual_projection
            else clip_model.visual_projection.out_features
        )
        self.embedding_dim = embedding_dim
        if self.num_projection_vector > 0:
            self.projection = nn.Linear(
                embedding_dim,
                clip_model.visual_projection.out_features * num_projection_vector,
                bias=linear_mapping_bias,
            )
            nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5)

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def unconditional_embedding(self, batch_size):
        zero = torch.zeros(
            batch_size,
            1,
            self.embedding_dim,
            device=self.device,
            dtype=self.clip.visual_projection.weight.dtype,
        )
        if self.num_projection_vector > 0:
            zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
        return zero

    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)

        if self.reverse_visual_projection:
            z = self.clip.vision_model(self.transform(image))[1]
        else:
            z = self.clip.get_image_features(self.transform(image))

        if self.normalize_embedding:
            z = z / z.norm(dim=-1, keepdim=True)
        if z.ndim == 2:
            z = z.unsqueeze(dim=-2)

        if zero_embedding_radio > 0:
            # Randomly zero whole embeddings with probability `zero_embedding_radio`
            # (classifier-free guidance dropout).
            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
            z = z * mask.to(z)

        if self.num_projection_vector > 0:
            z = self.projection(z).view(len(image), self.num_projection_vector, -1)

        return z

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def encode(self, image):
        self.move()
        return self(image, zero_embedding_radio=self.zero_embedding_radio)


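def _example_frozen_clip_image_embedder():
    # Hedged usage sketch (added for illustration; not part of the original module).
    # Images are expected in [-1, 1]; the batch size, resolution, and device below
    # are assumptions.
    embedder = FrozenCLIPImageEmbedder(device="cpu", zero_embedding_radio=0.1)
    images = torch.rand(2, 3, 256, 256) * 2 - 1  # (B, 3, H, W) in [-1, 1]
    z = embedder.encode(images)  # (B, 1, 768): one pooled CLIP token per image
    uncond = embedder.unconditional_embedding(batch_size=2)  # all-zero conditioning of matching shape
    return z, uncond

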
class FrozenCLIPImageGridEmbedder(AbstractEncoder):
    """Uses the CLIP vision encoder to produce a grid of patch embeddings per image (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        zero_embedding_radio=0.1,
    ):
        super().__init__()

        self.device = device

        # The frozen CLIP model is kept in a plain dict (not a registered submodule)
        # so its weights stay out of this module's state_dict.
        self.clip_dict = OrderedDict()
        self.clip_name = os.path.split(version)[-1]

        clip_model: CLIPModel = CLIPModel.from_pretrained(version)
        clip_model.text_model = None
        clip_model.text_projection = None
        clip_model = clip_model.eval()
        for param in clip_model.parameters():  # freeze the CLIP weights
            param.requires_grad = False
        self.clip_dict[self.clip_name] = clip_model

        self.transform = transforms.Compose(
            [
                transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True),
                transforms.CenterCrop(224),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.zero_embedding_radio = zero_embedding_radio
        self.embedding_dim = clip_model.vision_embed_dim

        self._move_flag = False

    @property
    def clip(self):
        return self.clip_dict[self.clip_name]

    def move(self):
        if self._move_flag:
            return

        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
        self._move_flag = True

    def unconditional_embedding(self, batch_size):
        zero = torch.zeros(
            batch_size,
            self.clip.vision_model.embeddings.num_positions,
            self.embedding_dim,
            device=self.device,
            dtype=self.clip.visual_projection.weight.dtype,
        )
        return zero

    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
        self.move()

        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)

        z = self.clip.vision_model(self.transform(image)).last_hidden_state

        if zero_embedding_radio > 0:
            # Randomly zero whole patch grids with probability `zero_embedding_radio`
            # (classifier-free guidance dropout).
            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
            z = z * mask.to(z)

        return z

    def encode(self, image):
        return self(image, zero_embedding_radio=self.zero_embedding_radio)


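def _example_frozen_clip_image_grid_embedder():
    # Hedged usage sketch (added for illustration; not part of the original module).
    # For the default ViT-L/14 weights the grid has 257 tokens (1 CLS + 16x16 patches)
    # of width 1024; the batch size, resolution, and device below are assumptions.
    embedder = FrozenCLIPImageGridEmbedder(device="cpu", zero_embedding_radio=0.1)
    images = torch.rand(2, 3, 256, 256) * 2 - 1  # (B, 3, H, W) in [-1, 1]
    z = embedder.encode(images)  # (B, 257, 1024): patch-grid conditioning
    return z

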
class MoECLIPImageEncoder(nn.Module):
    """Mixture of several frozen OpenAI CLIP image encoders whose features are concatenated and projected."""

    def __init__(
        self,
        versions,
        hidden_state_dim,
        num_projection_vector=8,
        zero_embedding_radio=0.1,
        device="cuda",
        precision="fp16",
        normalize=False,
        clip_max=0,
        transform_type="base",
        argument_p=0.2,
    ):
        super().__init__()

        self.device = torch.device(device)
        self.hidden_state_dim = hidden_state_dim
        self.zero_embedding_radio = zero_embedding_radio
        self.num_projection_vector = num_projection_vector
        self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
        self.normalize = normalize
        self.clip_max = clip_max

        if transform_type == "base":
            self.transform = transforms.Compose(
                [
                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                    transforms.CenterCrop(224),  # crop a (224, 224) square
                    transforms.Normalize(
                        mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711],
                    ),
                ]
            )
        elif transform_type == "crop_blur_resize":
            self.transform = transforms.Compose(
                [
                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                    transforms.CenterCrop(224),  # crop a (224, 224) square
                    transforms.RandomApply(
                        transforms=[
                            transforms.RandomResizedCrop(
                                size=224,
                                scale=(0.8, 1.0),
                                ratio=(0.99, 1.01),
                                interpolation=transforms.InterpolationMode.BICUBIC,
                            ),
                        ],
                        p=argument_p,
                    ),
                    transforms.RandomApply(
                        transforms=[
                            transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
                        ],
                        p=argument_p,
                    ),
                    transforms.RandomApply(
                        transforms=[
                            RandomResize(size=224, resize_radio=(0.2, 1)),
                        ],
                        p=argument_p,
                    ),
                    transforms.Normalize(
                        mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711],
                    ),
                ]
            )
        else:
            raise ValueError(f"invalid {transform_type=}")

        if isinstance(versions, str):
            versions = (versions,)

        # If the CLIP models were registered as submodules of this class, 1) the checkpoint would
        # store several redundant copies of their weights, and 2) PyTorch Lightning would call .to()
        # on them, which would also convert the layer-norm weights to fp16.
        clips = OrderedDict()

        for v in versions:
            # Because the clips are not submodules, loading with device="cuda" directly would
            # incorrectly place every CLIP model's weights on cuda:0.
            clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
            # Drop the text transformer; only the image tower is used.
            delattr(clips[v], "transformer")
            clips[v].eval()
            clips[v].requires_grad_(False)

        self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)

        if self.num_projection_vector == 0:
            self.projection = nn.Identity()
        else:
            self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
            self.projection.to(dtype=self.dtype)
            nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)

        self.clips = clips

        self._move_flag = False

    def move(self):
        if self._move_flag:
            return

        def convert_weights(model: nn.Module):
            """Convert applicable model parameters to fp16"""

            def _convert_weights_to_fp16(l):
                if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                    l.weight.data = l.weight.data.type(self.dtype)
                    if l.bias is not None:
                        l.bias.data = l.bias.data.type(self.dtype)

                if isinstance(l, nn.MultiheadAttention):
                    for attr in [
                        *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                        "in_proj_bias",
                        "bias_k",
                        "bias_v",
                    ]:
                        tensor = getattr(l, attr)
                        if tensor is not None:
                            tensor.data = tensor.data.type(self.dtype)

                for name in ["text_projection", "proj"]:
                    if hasattr(l, name):
                        attr = getattr(l, name)
                        if attr is not None:
                            attr.data = attr.data.type(self.dtype)

            model.apply(_convert_weights_to_fp16)

        for k in self.clips:
            self.clips[k].to(self.device)
            convert_weights(self.clips[k])  # fp32 -> self.dtype
        self._move_flag = True

    def unconditional_embedding(self, batch_size=None):
        zero = torch.zeros(
            batch_size,
            self.clips_hidden_dim,
            device=self.device,
            dtype=self.dtype,
        )
        if self.num_projection_vector > 0:
            zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
        return zero

    def convert_embedding(self, z):
        if self.num_projection_vector > 0:
            z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
        return z

    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = self.transform(image)

        with torch.no_grad():
            embs = []
            for v in self.clips:
                x = self.clips[v].encode_image(image)
                if self.normalize:
                    x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
                    # clip_max only works with normalization
                    if self.clip_max > 0:
                        x = x.clamp(-self.clip_max, self.clip_max)
                embs.append(x)

            z = torch.cat(embs, dim=-1)
            if self.normalize:
                z /= z.size(-1) ** 0.5

        if zero_embedding_radio > 0:
            # Randomly zero whole embeddings with probability `zero_embedding_radio`
            # (classifier-free guidance dropout).
            mask = torch.rand((len(image), 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
            z = z * mask.to(z)

        if self.num_projection_vector > 0:
            z = self.projection(z).view(len(image), self.num_projection_vector, -1)
        return z

    def encode(self, image):
        self.move()
        return self(image, zero_embedding_radio=self.zero_embedding_radio)
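

def _example_moe_clip_image_encoder():
    # Hedged usage sketch (added for illustration; not part of the original module).
    # Assumes the OpenAI CLIP package can download the named checkpoints; the versions,
    # dimensions, precision, and device below are illustrative, not values prescribed
    # by this repository.
    encoder = MoECLIPImageEncoder(
        versions=("ViT-L/14", "ViT-B/16"),
        hidden_state_dim=1024,
        num_projection_vector=8,
        device="cpu",
        precision="fp32",
    )
    images = torch.rand(2, 3, 256, 256) * 2 - 1  # (B, 3, H, W) in [-1, 1]
    z = encoder.encode(images)  # (B, 8, 1024) after concatenating both CLIP features and projecting
    uncond = encoder.unconditional_embedding(batch_size=2)  # matching all-zero conditioning
    return z, uncond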