mirror of
https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2026-03-01 18:05:58 +08:00
init
This commit is contained in:
1
primitive_anything/michelangelo/models/asl_diffusion/__init__.py
Executable file
1
primitive_anything/michelangelo/models/asl_diffusion/__init__.py
Executable file
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
483
primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py
Executable file
483
primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py
Executable file
@@ -0,0 +1,483 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from typing import List, Tuple, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import lr_scheduler
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from diffusers.schedulers import (
|
||||
DDPMScheduler,
|
||||
DDIMScheduler,
|
||||
KarrasVeScheduler,
|
||||
DPMSolverMultistepScheduler
|
||||
)
|
||||
|
||||
from ...utils import instantiate_from_config
|
||||
# from ..tsal.tsal_base import ShapeAsLatentPLModule
|
||||
from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
|
||||
from .inference_utils import ddim_sample
|
||||
|
||||
SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Bound onto frozen modules so that later ``.train()`` / ``.eval()`` calls
    can no longer flip their train/eval state. Always returns the module
    itself, mirroring the real ``train`` method's fluent return.
    """
    return self
|
||||
|
||||
|
||||
class ASLDiffuser(pl.LightningModule):
    """Latent diffusion over Aligned-Shape-as-Latent (ASL) encodings.

    A frozen first-stage model (``AlignedShapeAsLatentPLModule``) maps surfaces
    to latents; a denoiser network is trained to reverse a DDPM noising process
    on those latents, optionally conditioned on image / text / surface
    embeddings produced by the (also frozen) first-stage encoders.
    """

    first_stage_model: Optional[AlignedShapeAsLatentPLModule]
    model: nn.Module

    def __init__(self, *,
                 first_stage_config,
                 denoiser_cfg,
                 scheduler_cfg,
                 optimizer_cfg,
                 loss_cfg,
                 first_stage_key: str = "surface",
                 cond_stage_key: str = "image",
                 cond_stage_trainable: bool = True,
                 scale_by_std: bool = False,
                 z_scale_factor: float = 1.0,
                 ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str], List[str]] = ()):
        """
        Args:
            first_stage_config: config for the frozen shape VAE; the model is
                built lazily on first use (see ``forward`` / ``sample``).
            denoiser_cfg: config for the diffusion denoiser network.
            scheduler_cfg: holds ``noise`` (training) and ``denoise``
                (inference) scheduler configs plus sampling defaults
                (``num_inference_steps``, ``guidance_scale``).
            optimizer_cfg: optimizer / LR-scheduler config; a default AdamW is
                used when ``None``.
            loss_cfg: must provide ``loss_type`` ("l1", or "mse"/"l2").
            first_stage_key: batch key holding the surface input.
            cond_stage_key: batch key (and dispatch-table name) of the
                conditioning modality.
            cond_stage_trainable: kept for config compatibility; conditions
                here are computed by the frozen first-stage encoders.
            scale_by_std: if True, ``z_scale_factor`` becomes a buffer that is
                re-estimated as 1/std of the first training batch's latents.
            z_scale_factor: initial latent scaling factor.
            ckpt_path: optional checkpoint to restore from.
            ignore_keys: state_dict key prefixes to drop when restoring.
        """

        super().__init__()

        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.cond_stage_trainable = cond_stage_trainable

        # 1. First stage is instantiated lazily; note the condition encoders
        # are contained in the first stage model.
        self.first_stage_config = first_stage_config
        self.first_stage_model = None

        # 2. Conditional stage: a plain-dict dispatch table. Deliberately NOT
        # an nn.ModuleDict — the values are bound methods of self, not
        # trainable submodules.
        self.cond_stage_model = {
            "image": self.encode_image,
            "image_unconditional_embedding": self.empty_img_cond,
            "text": self.encode_text,
            "text_unconditional_embedding": self.empty_text_cond,
            "surface": self.encode_surface,
            "surface_unconditional_embedding": self.empty_surface_cond,
        }

        # 3. diffusion denoiser
        self.model = instantiate_from_config(
            denoiser_cfg, device=None, dtype=None
        )

        self.optimizer_cfg = optimizer_cfg

        # 4. scheduling strategy: DDPM for training-time noising, a faster
        # sampler (DDIM / Karras / DPM-Solver) for inference.
        self.scheduler_cfg = scheduler_cfg

        self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
        self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)

        # 5. loss configuration
        self.loss_cfg = loss_cfg

        self.scale_by_std = scale_by_std
        if scale_by_std:
            # Buffer so the (possibly re-estimated) factor is checkpointed.
            self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
        else:
            self.z_scale_factor = z_scale_factor

        self.ckpt_path = ckpt_path
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def instantiate_first_stage(self, config):
        """Build the first-stage model, freeze it, and move it to self.device."""
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        # Pin eval mode: Lightning calls .train() on children during fit.
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

        self.first_stage_model = self.first_stage_model.to(self.device)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Restore weights from ``path``, skipping keys matching ``ignore_keys`` prefixes."""
        state_dict = torch.load(path, map_location="cpu")["state_dict"]

        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del state_dict[k]

        # strict=False: the lazily-built first stage is absent at this point.
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    @property
    def zero_rank(self):
        """True on the (local) rank-0 process, or when no trainer is attached."""
        if self._trainer:
            zero_rank = self.trainer.local_rank == 0
        else:
            zero_rank = True

        return zero_rank

    def configure_optimizers(self) -> Tuple[List, List]:
        """Build optimizer(s)/scheduler(s); only the denoiser is trained."""

        lr = self.learning_rate

        trainable_parameters = list(self.model.parameters())

        if self.optimizer_cfg is None:
            optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
            schedulers = []
        else:
            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            optimizers = [optimizer]
            schedulers = [scheduler]

        return optimizers, schedulers

    @torch.no_grad()
    def encode_text(self, text):
        """Encode tokenized text [b, t, l] to a unit-norm embedding [b, d].

        Multiple captions per sample (the ``t`` axis) are mean-pooled before
        re-normalization.
        """

        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        return text_embed

    @torch.no_grad()
    def encode_image(self, img):
        """Encode an image batch to its conditioning embedding."""

        return self.first_stage_model.model.encode_image_embed(img)

    @torch.no_grad()
    def encode_surface(self, surface):
        """Encode a surface point cloud to its conditioning embedding."""

        return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)

    @torch.no_grad()
    def empty_text_cond(self, cond):
        """Zero tensor shaped like ``cond`` — unconditional text embedding."""

        return torch.zeros_like(cond)

    @torch.no_grad()
    def empty_img_cond(self, cond):
        """Zero tensor shaped like ``cond`` — unconditional image embedding."""

        return torch.zeros_like(cond)

    @torch.no_grad()
    def empty_surface_cond(self, cond):
        """Zero tensor shaped like ``cond`` — unconditional surface embedding."""

        return torch.zeros_like(cond)

    @torch.no_grad()
    def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
        """Encode surfaces to (scaled) first-stage latents."""

        z_q = self.first_stage_model.encode(surface, sample_posterior)
        z_q = self.z_scale_factor * z_q

        return z_q

    @torch.no_grad()
    def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
        """Un-scale latents and decode them through the first stage."""

        z_q = 1. / self.z_scale_factor * z_q
        latents = self.first_stage_model.decode(z_q, **kwargs)
        return latents

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx):
        # Only for the very first batch of a fresh run: re-estimate the latent
        # scale as 1/std of the encodings (LDM-style std-rescaling).
        # NOTE(review): @rank_zero_only means only rank 0 updates the buffer —
        # other ranks keep the initial value; confirm intended for multi-GPU.
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
                and batch_idx == 0 and self.ckpt_path is None:
            print("### USING STD-RESCALING ###")

            z_q = self.encode_first_stage(batch[self.first_stage_key])
            z = z_q.detach()

            del self.z_scale_factor
            self.register_buffer("z_scale_factor", 1. / z.flatten().std())
            print(f"setting self.z_scale_factor to {self.z_scale_factor}")

            print("### USING STD-RESCALING ###")

    def compute_loss(self, model_outputs, split):
        """Compute the diffusion regression loss.

        Args:
            model_outputs (dict):
                - x_0: clean latents (target for "sample" prediction)
                - noise: the added Gaussian noise (target for "epsilon")
                - pred: the denoiser output
            split (str): "train" or "val"; prefixes the logged keys.

        Returns:
            (total_loss, loss_dict) where loss_dict holds detached scalars.
        """

        pred = model_outputs["pred"]

        if self.noise_scheduler.prediction_type == "epsilon":
            target = model_outputs["noise"]
        elif self.noise_scheduler.prediction_type == "sample":
            target = model_outputs["x_0"]
        else:
            raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")

        if self.loss_cfg.loss_type == "l1":
            simple = F.l1_loss(pred, target, reduction="mean")
        elif self.loss_cfg.loss_type in ["mse", "l2"]:
            simple = F.mse_loss(pred, target, reduction="mean")
        else:
            raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")

        total_loss = simple

        loss_dict = {
            f"{split}/total_loss": total_loss.clone().detach(),
            f"{split}/simple": simple.detach(),
        }

        return total_loss, loss_dict

    def forward(self, batch):
        """Run one noising + denoising pass and collect tensors for the loss.

        Args:
            batch (dict): must contain ``self.first_stage_key`` (surfaces) and
                ``self.cond_stage_key`` (raw conditioning input).

        Returns:
            dict with "x_0" (clean latents), "noise", and "pred".
        """

        # Lazily build (and freeze) the first stage on first use.
        if self.first_stage_model is None:
            self.instantiate_first_stage(self.first_stage_config)

        latents = self.encode_first_stage(batch[self.first_stage_key])

        conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)

        # Classifier-free-guidance training: zero out ~10% of the conditions.
        mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
        conditions = conditions * mask.to(conditions)

        # Sample noise that we'll add to the latents
        # [batch_size, n_token, latent_dim]
        noise = torch.randn_like(latents)
        bs = latents.shape[0]
        # Sample a random timestep for each element of the batch.
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bs,),
            device=latents.device,
        )
        timesteps = timesteps.long()
        # Add noise to the latents according to the noise magnitude at each timestep
        noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # diffusion model forward
        noise_pred = self.model(noisy_z, timesteps, conditions)

        diffusion_outputs = {
            # Bug fix: "x_0" must be the *clean* latents — compute_loss uses it
            # as the regression target when prediction_type == "sample". The
            # previous code stored the already-noised input here.
            "x_0": latents,
            "noise": noise,
            "pred": noise_pred
        }

        return diffusion_outputs

    def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """One training step: forward pass, loss, rank-0 logging.

        Args:
            batch (dict): the batch sample, and it contains:
                - surface (torch.FloatTensor):
                - image (torch.FloatTensor): if provided, [bs, 3, h, w], range [0, 1]
                - depth (torch.FloatTensor): if provided, [bs, 1, h, w], range [-1, 1]
                - normal (torch.FloatTensor): if provided, [bs, 3, h, w], range [-1, 1]
                - text (list of str):
            batch_idx (int):
            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):
        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    def validation_step(self, batch: Dict[str, torch.FloatTensor],
                        batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """One validation step: forward pass, loss, rank-0 logging.

        Args:
            batch (dict): the batch sample, and it contains:
                - surface_pc (torch.FloatTensor): [n_pts, 4]
                - surface_feats (torch.FloatTensor): [n_pts, c]
                - text (list of str):
            batch_idx (int):
            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):
        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    @torch.no_grad()
    def sample(self,
               batch: Dict[str, Union[torch.FloatTensor, List[str]]],
               sample_times: int = 1,
               steps: Optional[int] = None,
               guidance_scale: Optional[float] = None,
               eta: float = 0.0,
               return_intermediates: bool = False, **kwargs):
        """Draw samples via the inference scheduler and decode them.

        Args:
            batch: must contain ``self.cond_stage_key``.
            sample_times: number of independent samples (or, with
                ``return_intermediates``, number of intermediate snapshots).
            steps: denoising steps; defaults to scheduler_cfg.num_inference_steps.
            guidance_scale: CFG scale; <= 0 disables classifier-free guidance.
            eta: DDIM eta.
            return_intermediates: if True, decode intermediate latents instead
                of only the final one of each run.

        Returns:
            list of decoded first-stage outputs.
        """

        if self.first_stage_model is None:
            self.instantiate_first_stage(self.first_stage_config)

        if steps is None:
            steps = self.scheduler_cfg.num_inference_steps

        if guidance_scale is None:
            guidance_scale = self.scheduler_cfg.guidance_scale
        do_classifier_free_guidance = guidance_scale > 0

        # conditional encode
        xc = batch[self.cond_stage_key]
        cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)

        if do_classifier_free_guidance:
            # Unconditional embedding is zeros_like(cond) (MDM-style) rather
            # than an encoded empty prompt (SAL-style).
            un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
            cond = torch.cat([un_cond, cond], dim=0)

        outputs = []
        latents = None

        if not return_intermediates:
            for _ in range(sample_times):
                sample_loop = ddim_sample(
                    self.denoise_scheduler,
                    self.model,
                    shape=self.first_stage_model.latent_shape,
                    cond=cond,
                    steps=steps,
                    guidance_scale=guidance_scale,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    device=self.device,
                    eta=eta,
                    disable_prog=not self.zero_rank
                )
                for sample, t in sample_loop:
                    latents = sample
                outputs.append(self.decode_first_stage(latents, **kwargs))
        else:
            sample_loop = ddim_sample(
                self.denoise_scheduler,
                self.model,
                shape=self.first_stage_model.latent_shape,
                cond=cond,
                steps=steps,
                guidance_scale=guidance_scale,
                do_classifier_free_guidance=do_classifier_free_guidance,
                device=self.device,
                eta=eta,
                disable_prog=not self.zero_rank
            )

            # Guard: sample_times > steps would make iter_size 0 and divide by
            # zero below; clamp so every step is then a snapshot.
            iter_size = max(steps // sample_times, 1)
            i = 0
            for sample, t in sample_loop:
                latents = sample
                if i % iter_size == 0 or i == steps - 1:
                    outputs.append(self.decode_first_stage(latents, **kwargs))
                i += 1

        return outputs
|
||||
104
primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py
Executable file
104
primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py
Executable file
@@ -0,0 +1,104 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
from diffusers.models.embeddings import Timesteps
|
||||
import math
|
||||
|
||||
from ..modules.transformer_blocks import MLP
|
||||
from ..modules.diffusion_transformer import UNetDiffusionTransformer
|
||||
|
||||
|
||||
class ConditionalASLUDTDenoiser(nn.Module):
    """U-Net-style diffusion transformer denoiser with timestep + context conditioning.

    The timestep embedding and projected context tokens are prepended to the
    projected data tokens; the backbone attends over the concatenated sequence
    and only the trailing data positions are projected back out.
    """

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 input_channels: int,
                 output_channels: int,
                 n_ctx: int,
                 width: int,
                 layers: int,
                 heads: int,
                 context_dim: int,
                 context_ln: bool = True,
                 skip_ln: bool = False,
                 init_scale: float = 0.25,
                 flip_sin_to_cos: bool = False,
                 use_checkpoint: bool = False):
        """
        Args:
            device / dtype: forwarded to every submodule.
            input_channels: channel width of the incoming data tokens.
            output_channels: channel width of the predicted sample.
            n_ctx: number of data tokens the backbone is configured for.
            width: transformer hidden width.
            layers: number of transformer layers.
            heads: attention heads per layer.
            context_dim: channel width of the raw context tokens.
            context_ln: if True, LayerNorm the context before projecting it.
            skip_ln: whether the backbone uses LayerNorm on skip connections.
            init_scale: weight-init scale (normalized by sqrt(width) below).
            flip_sin_to_cos: ordering of the sinusoidal timestep embedding.
            use_checkpoint: enable gradient checkpointing in the backbone.
        """
        super().__init__()

        self.use_checkpoint = use_checkpoint

        # Normalize the init scale by the hidden width.
        init_scale = init_scale * math.sqrt(1.0 / width)

        self.backbone = UNetDiffusionTransformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            layers=layers,
            heads=heads,
            skip_ln=skip_ln,
            init_scale=init_scale,
            use_checkpoint=use_checkpoint
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)

        # timestep embedding
        self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
        self.time_proj = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale
        )

        # Fix: the original unconditionally built a LayerNorm+Linear here and
        # then immediately rebuilt/overwrote self.context_embed in the if/else
        # below — dead construction (wasted params + extra RNG draws).
        # Construct it exactly once.
        if context_ln:
            self.context_embed = nn.Sequential(
                nn.LayerNorm(context_dim, device=device, dtype=dtype),
                nn.Linear(context_dim, width, device=device, dtype=dtype),
            )
        else:
            self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)

    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):

        r"""
        Args:
            model_input (torch.FloatTensor): [bs, n_data, c]
            timestep (torch.LongTensor): [bs,]
            context (torch.FloatTensor): [bs, context_tokens, c]

        Returns:
            sample (torch.FloatTensor): [bs, n_data, c]
        """

        _, n_data, _ = model_input.shape

        # 1. timestep embedding, as one extra token: [bs, 1, width]
        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)

        # 2. project conditions into the model width
        context = self.context_embed(context)

        # 3. denoise: [time | context | data] token sequence through the
        # backbone, then keep only the trailing data positions.
        x = self.input_proj(model_input)
        x = torch.cat([t_emb, context, x], dim=1)
        x = self.backbone(x)
        x = self.ln_post(x)
        x = x[:, -n_data:]
        sample = self.output_proj(x)

        return sample
|
||||
|
||||
|
||||
13
primitive_anything/michelangelo/models/asl_diffusion/base.py
Executable file
13
primitive_anything/michelangelo/models/asl_diffusion/base.py
Executable file
@@ -0,0 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class BaseDenoiser(nn.Module):
    """Abstract base class for diffusion denoisers.

    Concrete denoisers subclass this and implement ``forward`` to predict the
    denoised output from a noisy input, a timestep, and a conditioning context.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, t, context):
        """Denoise ``x`` at timestep ``t`` given ``context``; must be overridden."""
        raise NotImplementedError
|
||||
@@ -0,0 +1,393 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from typing import List, Tuple, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import lr_scheduler
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from diffusers.schedulers import (
|
||||
DDPMScheduler,
|
||||
DDIMScheduler,
|
||||
KarrasVeScheduler,
|
||||
DPMSolverMultistepScheduler
|
||||
)
|
||||
|
||||
from ...utils import instantiate_from_config
|
||||
from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
|
||||
from .inference_utils import ddim_sample
|
||||
|
||||
SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class ClipASLDiffuser(pl.LightningModule):
|
||||
first_stage_model: Optional[AlignedShapeAsLatentPLModule]
|
||||
cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
|
||||
model: nn.Module
|
||||
|
||||
def __init__(self, *,
|
||||
first_stage_config,
|
||||
cond_stage_config,
|
||||
denoiser_cfg,
|
||||
scheduler_cfg,
|
||||
optimizer_cfg,
|
||||
loss_cfg,
|
||||
first_stage_key: str = "surface",
|
||||
cond_stage_key: str = "image",
|
||||
scale_by_std: bool = False,
|
||||
z_scale_factor: float = 1.0,
|
||||
ckpt_path: Optional[str] = None,
|
||||
ignore_keys: Union[Tuple[str], List[str]] = ()):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.first_stage_key = first_stage_key
|
||||
self.cond_stage_key = cond_stage_key
|
||||
|
||||
# 1. lazy initialize first stage
|
||||
self.instantiate_first_stage(first_stage_config)
|
||||
|
||||
# 2. initialize conditional stage
|
||||
self.instantiate_cond_stage(cond_stage_config)
|
||||
|
||||
# 3. diffusion model
|
||||
self.model = instantiate_from_config(
|
||||
denoiser_cfg, device=None, dtype=None
|
||||
)
|
||||
|
||||
self.optimizer_cfg = optimizer_cfg
|
||||
|
||||
# 4. scheduling strategy
|
||||
self.scheduler_cfg = scheduler_cfg
|
||||
|
||||
self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
|
||||
self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
|
||||
|
||||
# 5. loss configures
|
||||
self.loss_cfg = loss_cfg
|
||||
|
||||
self.scale_by_std = scale_by_std
|
||||
if scale_by_std:
|
||||
self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
|
||||
else:
|
||||
self.z_scale_factor = z_scale_factor
|
||||
|
||||
self.ckpt_path = ckpt_path
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def instantiate_non_trainable_model(self, config):
|
||||
model = instantiate_from_config(config)
|
||||
model = model.eval()
|
||||
model.train = disabled_train
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return model
|
||||
|
||||
def instantiate_first_stage(self, first_stage_config):
|
||||
self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
|
||||
self.first_stage_model.set_shape_model_only()
|
||||
|
||||
def instantiate_cond_stage(self, cond_stage_config):
|
||||
self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=()):
|
||||
state_dict = torch.load(path, map_location="cpu")["state_dict"]
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del state_dict[k]
|
||||
|
||||
missing, unexpected = self.load_state_dict(state_dict, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
@property
|
||||
def zero_rank(self):
|
||||
if self._trainer:
|
||||
zero_rank = self.trainer.local_rank == 0
|
||||
else:
|
||||
zero_rank = True
|
||||
|
||||
return zero_rank
|
||||
|
||||
def configure_optimizers(self) -> Tuple[List, List]:
|
||||
|
||||
lr = self.learning_rate
|
||||
|
||||
trainable_parameters = list(self.model.parameters())
|
||||
if self.optimizer_cfg is None:
|
||||
optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
|
||||
schedulers = []
|
||||
else:
|
||||
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
|
||||
scheduler_func = instantiate_from_config(
|
||||
self.optimizer_cfg.scheduler,
|
||||
max_decay_steps=self.trainer.max_steps,
|
||||
lr_max=lr
|
||||
)
|
||||
scheduler = {
|
||||
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
|
||||
"interval": "step",
|
||||
"frequency": 1
|
||||
}
|
||||
optimizers = [optimizer]
|
||||
schedulers = [scheduler]
|
||||
|
||||
return optimizers, schedulers
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
|
||||
|
||||
z_q = self.first_stage_model.encode(surface, sample_posterior)
|
||||
z_q = self.z_scale_factor * z_q
|
||||
|
||||
return z_q
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
|
||||
|
||||
z_q = 1. / self.z_scale_factor * z_q
|
||||
latents = self.first_stage_model.decode(z_q, **kwargs)
|
||||
return latents
|
||||
|
||||
@rank_zero_only
|
||||
@torch.no_grad()
|
||||
def on_train_batch_start(self, batch, batch_idx):
|
||||
# only for very first batch
|
||||
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
|
||||
and batch_idx == 0 and self.ckpt_path is None:
|
||||
# set rescale weight to 1./std of encodings
|
||||
print("### USING STD-RESCALING ###")
|
||||
|
||||
z_q = self.encode_first_stage(batch[self.first_stage_key])
|
||||
z = z_q.detach()
|
||||
|
||||
del self.z_scale_factor
|
||||
self.register_buffer("z_scale_factor", 1. / z.flatten().std())
|
||||
print(f"setting self.z_scale_factor to {self.z_scale_factor}")
|
||||
|
||||
print("### USING STD-RESCALING ###")
|
||||
|
||||
def compute_loss(self, model_outputs, split):
|
||||
"""
|
||||
|
||||
Args:
|
||||
model_outputs (dict):
|
||||
- x_0:
|
||||
- noise:
|
||||
- noise_prior:
|
||||
- noise_pred:
|
||||
- noise_pred_prior:
|
||||
|
||||
split (str):
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
pred = model_outputs["pred"]
|
||||
|
||||
if self.noise_scheduler.prediction_type == "epsilon":
|
||||
target = model_outputs["noise"]
|
||||
elif self.noise_scheduler.prediction_type == "sample":
|
||||
target = model_outputs["x_0"]
|
||||
else:
|
||||
raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
|
||||
|
||||
if self.loss_cfg.loss_type == "l1":
|
||||
simple = F.l1_loss(pred, target, reduction="mean")
|
||||
elif self.loss_cfg.loss_type in ["mse", "l2"]:
|
||||
simple = F.mse_loss(pred, target, reduction="mean")
|
||||
else:
|
||||
raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
|
||||
|
||||
total_loss = simple
|
||||
|
||||
loss_dict = {
|
||||
f"{split}/total_loss": total_loss.clone().detach(),
|
||||
f"{split}/simple": simple.detach(),
|
||||
}
|
||||
|
||||
return total_loss, loss_dict
|
||||
|
||||
def forward(self, batch):
|
||||
"""
|
||||
|
||||
Args:
|
||||
batch:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
latents = self.encode_first_stage(batch[self.first_stage_key])
|
||||
conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
|
||||
|
||||
# Sample noise that we"ll add to the latents
|
||||
# [batch_size, n_token, latent_dim]
|
||||
noise = torch.randn_like(latents)
|
||||
bs = latents.shape[0]
|
||||
# Sample a random timestep for each motion
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
self.noise_scheduler.config.num_train_timesteps,
|
||||
(bs,),
|
||||
device=latents.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# diffusion model forward
|
||||
noise_pred = self.model(noisy_z, timesteps, conditions)
|
||||
|
||||
diffusion_outputs = {
|
||||
"x_0": noisy_z,
|
||||
"noise": noise,
|
||||
"pred": noise_pred
|
||||
}
|
||||
|
||||
return diffusion_outputs
|
||||
|
||||
def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
|
||||
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
batch (dict): the batch sample, and it contains:
|
||||
- surface (torch.FloatTensor):
|
||||
- image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1]
|
||||
- depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1]
|
||||
- normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1]
|
||||
- text (list of str):
|
||||
|
||||
batch_idx (int):
|
||||
|
||||
optimizer_idx (int):
|
||||
|
||||
Returns:
|
||||
loss (torch.FloatTensor):
|
||||
|
||||
"""
|
||||
|
||||
diffusion_outputs = self(batch)
|
||||
|
||||
loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
|
||||
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch: Dict[str, torch.FloatTensor],
|
||||
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
batch (dict): the batch sample, and it contains:
|
||||
- surface_pc (torch.FloatTensor): [n_pts, 4]
|
||||
- surface_feats (torch.FloatTensor): [n_pts, c]
|
||||
- text (list of str):
|
||||
|
||||
batch_idx (int):
|
||||
|
||||
optimizer_idx (int):
|
||||
|
||||
Returns:
|
||||
loss (torch.FloatTensor):
|
||||
|
||||
"""
|
||||
|
||||
diffusion_outputs = self(batch)
|
||||
|
||||
loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
|
||||
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
|
||||
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
def sample(self,
           batch: Dict[str, Union[torch.FloatTensor, List[str]]],
           sample_times: int = 1,
           steps: Optional[int] = None,
           guidance_scale: Optional[float] = None,
           eta: float = 0.0,
           return_intermediates: bool = False, **kwargs):
    """Draw shape latents with DDIM and decode them through the first stage.

    Args:
        batch (dict): sample dict; ``batch[self.cond_stage_key]`` supplies the
            conditioning input (e.g. image or text).
        sample_times (int): number of independent samples (or, with
            ``return_intermediates``, number of intermediate snapshots) to produce.
        steps (int, optional): DDIM step count; defaults to
            ``scheduler_cfg.num_inference_steps``.
        guidance_scale (float, optional): classifier-free guidance weight; defaults
            to ``scheduler_cfg.guidance_scale``. Any value > 0 enables guidance.
        eta (float): DDIM stochasticity parameter in [0, 1].
        return_intermediates (bool): if True, run a single chain and decode evenly
            spaced intermediate latents instead of independent final samples.
        **kwargs: forwarded to ``decode_first_stage``.

    Returns:
        list: decoded outputs, one entry per sample / snapshot.
    """
    if steps is None:
        steps = self.scheduler_cfg.num_inference_steps
    if guidance_scale is None:
        guidance_scale = self.scheduler_cfg.guidance_scale
    use_cfg = guidance_scale > 0

    # Encode the conditioning signal; under classifier-free guidance the
    # unconditional embedding is stacked in front of the conditional one.
    raw_cond = batch[self.cond_stage_key]
    cond = self.cond_stage_model(raw_cond)
    if use_cfg:
        uncond = self.cond_stage_model.unconditional_embedding(batch_size=len(raw_cond))
        cond = torch.cat([uncond, cond], dim=0)

    # Shared arguments for every ddim_sample invocation below.
    ddim_kwargs = dict(
        shape=self.first_stage_model.latent_shape,
        cond=cond,
        steps=steps,
        guidance_scale=guidance_scale,
        do_classifier_free_guidance=use_cfg,
        device=self.device,
        eta=eta,
        disable_prog=not self.zero_rank,
    )

    outputs = []
    latents = None

    if not return_intermediates:
        # Independent chains: exhaust each chain and decode only its final latent.
        for _ in range(sample_times):
            for latents, _t in ddim_sample(self.denoise_scheduler, self.model, **ddim_kwargs):
                pass
            outputs.append(self.decode_first_stage(latents, **kwargs))
    else:
        # Single chain: decode every `iter_size`-th latent plus the final one.
        iter_size = steps // sample_times
        for i, (latents, _t) in enumerate(
                ddim_sample(self.denoise_scheduler, self.model, **ddim_kwargs)):
            if i % iter_size == 0 or i == steps - 1:
                outputs.append(self.decode_first_stage(latents, **kwargs))

    return outputs
|
||||
80
primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py
Executable file
80
primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py
Executable file
@@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import Tuple, List, Union, Optional
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
|
||||
|
||||
__all__ = ["ddim_sample"]
|
||||
|
||||
|
||||
def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: torch.device = "cuda:0",
                disable_prog: bool = True):
    """Generator running the reverse DDIM chain; yields ``(latents, t)`` per step.

    Args:
        ddim_scheduler: diffusers DDIM scheduler driving the reverse process.
        diffusion_model: noise-prediction network called as ``model(x_t, t, cond)``.
        shape: per-sample latent shape (without the batch dimension).
        cond: conditioning embeddings; under classifier-free guidance this is the
            concatenation ``[uncond, cond]`` along the batch axis.
        steps: number of inference timesteps; must be positive.
        eta: DDIM stochasticity in [0, 1] (0 = deterministic).
        guidance_scale: classifier-free guidance weight.
        do_classifier_free_guidance: whether ``cond`` carries the stacked
            unconditional/conditional pair.
        generator: optional RNG for reproducible initial noise.
        device: device used for the timestep tensors.
        disable_prog: suppress the tqdm progress bar (e.g. on non-zero ranks).

    Yields:
        tuple: ``(latents, t)`` after each scheduler step; the last yield is
        the final denoised latent.
    """
    assert steps > 0, f"{steps} must > 0."

    # Under CFG the conditioning tensor stacks [uncond, cond], so the true
    # sample batch size is half of cond's leading dimension.
    batch_size = cond.shape[0] // 2 if do_classifier_free_guidance else cond.shape[0]

    # Draw initial noise and scale it to the scheduler's expected sigma.
    latents = torch.randn(
        (batch_size, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    latents = latents * ddim_scheduler.init_noise_sigma

    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)

    # eta (and generator) are extra kwargs specific to the DDIM step signature.
    extra_step_kwargs = {"eta": eta, "generator": generator}

    progress = tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)
    for t in progress:
        # Duplicate the latents so one model call covers both CFG branches.
        if do_classifier_free_guidance:
            model_input = torch.cat([latents] * 2)
        else:
            model_input = latents

        # Broadcast the scalar timestep across the (possibly doubled) batch.
        t_batch = torch.tensor([t], dtype=torch.long, device=device)
        t_batch = t_batch.expand(model_input.shape[0])
        noise_pred = diffusion_model.forward(model_input, t_batch, cond)

        # Classifier-free guidance: push the prediction away from unconditional.
        if do_classifier_free_guidance:
            pred_uncond, pred_cond = noise_pred.chunk(2)
            noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

        # x_t -> x_{t-1}
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t
|
||||
|
||||
|
||||
def karra_sample():
    """Karras-style (KarrasVeScheduler) sampling loop — not implemented.

    The module imports ``KarrasVeScheduler`` but this sampler was left as a
    stub; raising makes any accidental call fail loudly instead of silently
    returning ``None``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("karra_sample is not implemented yet; use ddim_sample instead.")
|
||||
Reference in New Issue
Block a user