commit 87c3ed5e40
Author: hyz317
Date:   2025-05-07 16:51:22 +08:00

54 changed files with 8014 additions and 0 deletions


@@ -0,0 +1 @@
# -*- coding: utf-8 -*-


@@ -0,0 +1,483 @@
# -*- coding: utf-8 -*-
from omegaconf import DictConfig
from typing import List, Tuple, Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from einops import rearrange
from diffusers.schedulers import (
DDPMScheduler,
DDIMScheduler,
KarrasVeScheduler,
DPMSolverMultistepScheduler
)
from ...utils import instantiate_from_config
# from ..tsal.tsal_base import ShapeAsLatentPLModule
from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
from .inference_utils import ddim_sample
SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class ASLDiffuser(pl.LightningModule):
first_stage_model: Optional[AlignedShapeAsLatentPLModule]
# cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
model: nn.Module
def __init__(self, *,
first_stage_config,
denoiser_cfg,
scheduler_cfg,
optimizer_cfg,
loss_cfg,
first_stage_key: str = "surface",
cond_stage_key: str = "image",
cond_stage_trainable: bool = True,
scale_by_std: bool = False,
z_scale_factor: float = 1.0,
ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str, ...], List[str]] = ()):
super().__init__()
self.first_stage_key = first_stage_key
self.cond_stage_key = cond_stage_key
self.cond_stage_trainable = cond_stage_trainable
# 1. initialize first stage.
        # Note: the condition model is contained in the first-stage model.
self.first_stage_config = first_stage_config
self.first_stage_model = None
# self.instantiate_first_stage(first_stage_config)
# 2. initialize conditional stage
# self.instantiate_cond_stage(cond_stage_config)
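        # dispatch table: maps each condition key to its encoder, and each
        # "<key>_unconditional_embedding" to the empty embedding used for
        # classifier-free guidance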
self.cond_stage_model = {
"image": self.encode_image,
"image_unconditional_embedding": self.empty_img_cond,
"text": self.encode_text,
"text_unconditional_embedding": self.empty_text_cond,
"surface": self.encode_surface,
"surface_unconditional_embedding": self.empty_surface_cond,
}
# 3. diffusion model
self.model = instantiate_from_config(
denoiser_cfg, device=None, dtype=None
)
self.optimizer_cfg = optimizer_cfg
# 4. scheduling strategy
self.scheduler_cfg = scheduler_cfg
self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
# 5. loss configures
self.loss_cfg = loss_cfg
self.scale_by_std = scale_by_std
if scale_by_std:
self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
else:
self.z_scale_factor = z_scale_factor
self.ckpt_path = ckpt_path
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
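        # monkey-patching .train keeps the first stage in eval mode even when
        # Lightning toggles train/eval on the whole module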
for param in self.first_stage_model.parameters():
param.requires_grad = False
self.first_stage_model = self.first_stage_model.to(self.device)
# def instantiate_cond_stage(self, config):
# if not self.cond_stage_trainable:
# if config == "__is_first_stage__":
# print("Using first stage also as cond stage.")
# self.cond_stage_model = self.first_stage_model
# elif config == "__is_unconditional__":
# print(f"Training {self.__class__.__name__} as an unconditional model.")
# self.cond_stage_model = None
# # self.be_unconditional = True
# else:
# model = instantiate_from_config(config)
# self.cond_stage_model = model.eval()
# self.cond_stage_model.train = disabled_train
# for param in self.cond_stage_model.parameters():
# param.requires_grad = False
# else:
# assert config != "__is_first_stage__"
# assert config != "__is_unconditional__"
# model = instantiate_from_config(config)
# self.cond_stage_model = model
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
@property
def zero_rank(self):
if self._trainer:
zero_rank = self.trainer.local_rank == 0
else:
zero_rank = True
return zero_rank
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
trainable_parameters = list(self.model.parameters())
# if the conditional encoder is trainable
# if self.cond_stage_trainable:
# conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
# trainable_parameters += conditioner_params
# print(f"number of trainable conditional parameters: {len(conditioner_params)}.")
if self.optimizer_cfg is None:
optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
schedulers = []
else:
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
optimizers = [optimizer]
schedulers = [scheduler]
return optimizers, schedulers
@torch.no_grad()
def encode_text(self, text):
b = text.shape[0]
text_tokens = rearrange(text, "b t l -> (b t) l")
text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
text_embed = text_embed.mean(dim=1)
text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
return text_embed
@torch.no_grad()
def encode_image(self, img):
return self.first_stage_model.model.encode_image_embed(img)
@torch.no_grad()
def encode_surface(self, surface):
return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)
@torch.no_grad()
def empty_text_cond(self, cond):
return torch.zeros_like(cond, device=cond.device)
@torch.no_grad()
def empty_img_cond(self, cond):
return torch.zeros_like(cond, device=cond.device)
@torch.no_grad()
def empty_surface_cond(self, cond):
return torch.zeros_like(cond, device=cond.device)
@torch.no_grad()
def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
z_q = self.first_stage_model.encode(surface, sample_posterior)
z_q = self.z_scale_factor * z_q
return z_q
@torch.no_grad()
def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
z_q = 1. / self.z_scale_factor * z_q
latents = self.first_stage_model.decode(z_q, **kwargs)
return latents
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx):
        # only for the very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
and batch_idx == 0 and self.ckpt_path is None:
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
z_q = self.encode_first_stage(batch[self.first_stage_key])
z = z_q.detach()
del self.z_scale_factor
self.register_buffer("z_scale_factor", 1. / z.flatten().std())
print(f"setting self.z_scale_factor to {self.z_scale_factor}")
print("### USING STD-RESCALING ###")
    def compute_loss(self, model_outputs, split):
        """
        Args:
            model_outputs (dict): the outputs of `forward`, containing:
                - x_0: the clean latents
                - noise: the Gaussian noise added to the latents
                - pred: the denoiser prediction
            split (str): logging prefix, e.g. "train" or "val".
        Returns:
            total_loss (torch.FloatTensor), loss_dict (dict)
        """
pred = model_outputs["pred"]
if self.noise_scheduler.prediction_type == "epsilon":
target = model_outputs["noise"]
elif self.noise_scheduler.prediction_type == "sample":
target = model_outputs["x_0"]
else:
raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
if self.loss_cfg.loss_type == "l1":
simple = F.l1_loss(pred, target, reduction="mean")
elif self.loss_cfg.loss_type in ["mse", "l2"]:
simple = F.mse_loss(pred, target, reduction="mean")
else:
raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
total_loss = simple
loss_dict = {
f"{split}/total_loss": total_loss.clone().detach(),
f"{split}/simple": simple.detach(),
}
return total_loss, loss_dict
def forward(self, batch):
"""
Args:
batch:
Returns:
"""
if self.first_stage_model is None:
self.instantiate_first_stage(self.first_stage_config)
latents = self.encode_first_stage(batch[self.first_stage_key])
# conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)
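        # randomly zero out ~10% of the condition embeddings so the denoiser also
        # learns the unconditional distribution (enables classifier-free guidance)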
mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
conditions = conditions * mask.to(conditions)
        # Sample noise that we'll add to the latents
# [batch_size, n_token, latent_dim]
noise = torch.randn_like(latents)
bs = latents.shape[0]
        # Sample a random timestep for each item in the batch
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bs,),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
# diffusion model forward
noise_pred = self.model(noisy_z, timesteps, conditions)
        diffusion_outputs = {
            "x_0": latents,  # clean latents: the target when prediction_type == "sample"
            "noise": noise,
            "pred": noise_pred
        }
return diffusion_outputs
def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
            batch (dict): the batch sample; it contains:
                - surface (torch.FloatTensor):
                - image (torch.FloatTensor): if provided, [bs, 3, h, w], values in [0, 1]
                - depth (torch.FloatTensor): if provided, [bs, 1, h, w], values in [-1, 1]
                - normal (torch.FloatTensor): if provided, [bs, 3, h, w], values in [-1, 1]
- text (list of str):
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
diffusion_outputs = self(batch)
loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
def validation_step(self, batch: Dict[str, torch.FloatTensor],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
            batch (dict): the batch sample; it contains:
- surface_pc (torch.FloatTensor): [n_pts, 4]
- surface_feats (torch.FloatTensor): [n_pts, c]
- text (list of str):
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
diffusion_outputs = self(batch)
loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
@torch.no_grad()
def sample(self,
batch: Dict[str, Union[torch.FloatTensor, List[str]]],
sample_times: int = 1,
steps: Optional[int] = None,
guidance_scale: Optional[float] = None,
eta: float = 0.0,
return_intermediates: bool = False, **kwargs):
if self.first_stage_model is None:
self.instantiate_first_stage(self.first_stage_config)
if steps is None:
steps = self.scheduler_cfg.num_inference_steps
if guidance_scale is None:
guidance_scale = self.scheduler_cfg.guidance_scale
do_classifier_free_guidance = guidance_scale > 0
# conditional encode
xc = batch[self.cond_stage_key]
# cond = self.cond_stage_model[self.cond_stage_key](xc)
cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)
if do_classifier_free_guidance:
"""
Note: There are two kinds of uncond for text.
1: using "" as uncond text; (in SAL diffusion)
2: zeros_like(cond) as uncond text; (in MDM)
"""
# un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
# un_cond = torch.zeros_like(cond, device=cond.device)
cond = torch.cat([un_cond, cond], dim=0)
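            # stacking [uncond; cond] doubles the batch; ddim_sample halves it
            # again (bsz // 2) and chunks the prediction to apply the guidance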
outputs = []
latents = None
if not return_intermediates:
for _ in range(sample_times):
sample_loop = ddim_sample(
self.denoise_scheduler,
self.model,
shape=self.first_stage_model.latent_shape,
cond=cond,
steps=steps,
guidance_scale=guidance_scale,
do_classifier_free_guidance=do_classifier_free_guidance,
device=self.device,
eta=eta,
disable_prog=not self.zero_rank
)
for sample, t in sample_loop:
latents = sample
outputs.append(self.decode_first_stage(latents, **kwargs))
else:
sample_loop = ddim_sample(
self.denoise_scheduler,
self.model,
shape=self.first_stage_model.latent_shape,
cond=cond,
steps=steps,
guidance_scale=guidance_scale,
do_classifier_free_guidance=do_classifier_free_guidance,
device=self.device,
eta=eta,
disable_prog=not self.zero_rank
)
iter_size = steps // sample_times
i = 0
for sample, t in sample_loop:
latents = sample
if i % iter_size == 0 or i == steps - 1:
outputs.append(self.decode_first_stage(latents, **kwargs))
i += 1
return outputs
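A minimal, hypothetical driver for ASLDiffuser at inference time; the config path, the `michelangelo.utils` import location, and the batch contents below are illustrative assumptions, not part of this commit:

import torch
from omegaconf import OmegaConf
from michelangelo.utils import instantiate_from_config  # assumed package path

cfg = OmegaConf.load("configs/image_asl_diffuser.yaml")   # hypothetical config
model = instantiate_from_config(cfg.model).eval().cuda()  # builds an ASLDiffuser

# cond_stage_key defaults to "image": a [bs, 3, h, w] tensor with values in [0, 1]
batch = {"image": torch.rand(1, 3, 224, 224, device="cuda")}
with torch.no_grad():
    outputs = model.sample(batch, sample_times=1, steps=50, guidance_scale=7.5)
# outputs[-1] is the first stage's decoding of the final denoised latents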


@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from typing import Optional
from diffusers.models.embeddings import Timesteps
import math
from ..modules.transformer_blocks import MLP
from ..modules.diffusion_transformer import UNetDiffusionTransformer
class ConditionalASLUDTDenoiser(nn.Module):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
input_channels: int,
output_channels: int,
n_ctx: int,
width: int,
layers: int,
heads: int,
context_dim: int,
context_ln: bool = True,
skip_ln: bool = False,
init_scale: float = 0.25,
flip_sin_to_cos: bool = False,
use_checkpoint: bool = False):
super().__init__()
self.use_checkpoint = use_checkpoint
init_scale = init_scale * math.sqrt(1.0 / width)
self.backbone = UNetDiffusionTransformer(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
layers=layers,
heads=heads,
skip_ln=skip_ln,
init_scale=init_scale,
use_checkpoint=use_checkpoint
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
# timestep embedding
self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
self.time_proj = MLP(
device=device, dtype=dtype, width=width, init_scale=init_scale
)
        if context_ln:
            self.context_embed = nn.Sequential(
                nn.LayerNorm(context_dim, device=device, dtype=dtype),
                nn.Linear(context_dim, width, device=device, dtype=dtype),
            )
        else:
            self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
def forward(self,
model_input: torch.FloatTensor,
timestep: torch.LongTensor,
context: torch.FloatTensor):
r"""
Args:
model_input (torch.FloatTensor): [bs, n_data, c]
timestep (torch.LongTensor): [bs,]
context (torch.FloatTensor): [bs, context_tokens, c]
Returns:
sample (torch.FloatTensor): [bs, n_data, c]
"""
_, n_data, _ = model_input.shape
# 1. time
t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
# 2. conditions projector
context = self.context_embed(context)
# 3. denoiser
x = self.input_proj(model_input)
x = torch.cat([t_emb, context, x], dim=1)
x = self.backbone(x)
x = self.ln_post(x)
x = x[:, -n_data:]
sample = self.output_proj(x)
return sample
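As a sanity check on the token layout: the forward pass prepends one timestep token and the projected context tokens to the projected data tokens, runs the transformer, and keeps only the trailing n_data positions. A sketch with illustrative sizes (all constructor values are assumptions, and it presumes UNetDiffusionTransformer accepts the concatenated sequence length):

import torch

denoiser = ConditionalASLUDTDenoiser(
    device=None, dtype=None,
    input_channels=64, output_channels=64,
    n_ctx=256, width=768, layers=6, heads=12,
    context_dim=1024,
)
x = torch.randn(2, 256, 64)       # [bs, n_data, c]
t = torch.randint(0, 1000, (2,))  # [bs]
ctx = torch.randn(2, 1, 1024)     # [bs, context_tokens, context_dim]
out = denoiser(x, t, ctx)         # internal sequence is 1 + 1 + 256 tokens
assert out.shape == (2, 256, 64)  # only the last n_data tokens are projected out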


@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
class BaseDenoiser(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, t, context):
raise NotImplementedError


@@ -0,0 +1,393 @@
# -*- coding: utf-8 -*-
from omegaconf import DictConfig
from typing import List, Tuple, Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from diffusers.schedulers import (
DDPMScheduler,
DDIMScheduler,
KarrasVeScheduler,
DPMSolverMultistepScheduler
)
from ...utils import instantiate_from_config
from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
from .inference_utils import ddim_sample
SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class ClipASLDiffuser(pl.LightningModule):
first_stage_model: Optional[AlignedShapeAsLatentPLModule]
cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
model: nn.Module
def __init__(self, *,
first_stage_config,
cond_stage_config,
denoiser_cfg,
scheduler_cfg,
optimizer_cfg,
loss_cfg,
first_stage_key: str = "surface",
cond_stage_key: str = "image",
scale_by_std: bool = False,
z_scale_factor: float = 1.0,
ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str, ...], List[str]] = ()):
super().__init__()
self.first_stage_key = first_stage_key
self.cond_stage_key = cond_stage_key
# 1. lazy initialize first stage
self.instantiate_first_stage(first_stage_config)
# 2. initialize conditional stage
self.instantiate_cond_stage(cond_stage_config)
# 3. diffusion model
self.model = instantiate_from_config(
denoiser_cfg, device=None, dtype=None
)
self.optimizer_cfg = optimizer_cfg
# 4. scheduling strategy
self.scheduler_cfg = scheduler_cfg
self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
# 5. loss configures
self.loss_cfg = loss_cfg
self.scale_by_std = scale_by_std
if scale_by_std:
self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
else:
self.z_scale_factor = z_scale_factor
self.ckpt_path = ckpt_path
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def instantiate_non_trainable_model(self, config):
model = instantiate_from_config(config)
model = model.eval()
model.train = disabled_train
for param in model.parameters():
param.requires_grad = False
return model
def instantiate_first_stage(self, first_stage_config):
self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
self.first_stage_model.set_shape_model_only()
def instantiate_cond_stage(self, cond_stage_config):
self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
@property
def zero_rank(self):
if self._trainer:
zero_rank = self.trainer.local_rank == 0
else:
zero_rank = True
return zero_rank
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
trainable_parameters = list(self.model.parameters())
if self.optimizer_cfg is None:
optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
schedulers = []
else:
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
optimizers = [optimizer]
schedulers = [scheduler]
return optimizers, schedulers
@torch.no_grad()
def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
z_q = self.first_stage_model.encode(surface, sample_posterior)
z_q = self.z_scale_factor * z_q
return z_q
@torch.no_grad()
def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
z_q = 1. / self.z_scale_factor * z_q
latents = self.first_stage_model.decode(z_q, **kwargs)
return latents
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx):
        # only for the very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
and batch_idx == 0 and self.ckpt_path is None:
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
z_q = self.encode_first_stage(batch[self.first_stage_key])
z = z_q.detach()
del self.z_scale_factor
self.register_buffer("z_scale_factor", 1. / z.flatten().std())
print(f"setting self.z_scale_factor to {self.z_scale_factor}")
print("### USING STD-RESCALING ###")
    def compute_loss(self, model_outputs, split):
        """
        Args:
            model_outputs (dict): the outputs of `forward`, containing:
                - x_0: the clean latents
                - noise: the Gaussian noise added to the latents
                - pred: the denoiser prediction
            split (str): logging prefix, e.g. "train" or "val".
        Returns:
            total_loss (torch.FloatTensor), loss_dict (dict)
        """
pred = model_outputs["pred"]
if self.noise_scheduler.prediction_type == "epsilon":
target = model_outputs["noise"]
elif self.noise_scheduler.prediction_type == "sample":
target = model_outputs["x_0"]
else:
raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
if self.loss_cfg.loss_type == "l1":
simple = F.l1_loss(pred, target, reduction="mean")
elif self.loss_cfg.loss_type in ["mse", "l2"]:
simple = F.mse_loss(pred, target, reduction="mean")
else:
raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
total_loss = simple
loss_dict = {
f"{split}/total_loss": total_loss.clone().detach(),
f"{split}/simple": simple.detach(),
}
return total_loss, loss_dict
def forward(self, batch):
"""
Args:
batch:
Returns:
"""
latents = self.encode_first_stage(batch[self.first_stage_key])
conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
        # Sample noise that we'll add to the latents
# [batch_size, n_token, latent_dim]
noise = torch.randn_like(latents)
bs = latents.shape[0]
        # Sample a random timestep for each item in the batch
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bs,),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
# diffusion model forward
noise_pred = self.model(noisy_z, timesteps, conditions)
        diffusion_outputs = {
            "x_0": latents,  # clean latents: the target when prediction_type == "sample"
            "noise": noise,
            "pred": noise_pred
        }
return diffusion_outputs
def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
            batch (dict): the batch sample; it contains:
                - surface (torch.FloatTensor):
                - image (torch.FloatTensor): if provided, [bs, 3, h, w], values in [0, 1]
                - depth (torch.FloatTensor): if provided, [bs, 1, h, w], values in [-1, 1]
                - normal (torch.FloatTensor): if provided, [bs, 3, h, w], values in [-1, 1]
- text (list of str):
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
diffusion_outputs = self(batch)
loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
def validation_step(self, batch: Dict[str, torch.FloatTensor],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
            batch (dict): the batch sample; it contains:
- surface_pc (torch.FloatTensor): [n_pts, 4]
- surface_feats (torch.FloatTensor): [n_pts, c]
- text (list of str):
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
diffusion_outputs = self(batch)
loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
@torch.no_grad()
def sample(self,
batch: Dict[str, Union[torch.FloatTensor, List[str]]],
sample_times: int = 1,
steps: Optional[int] = None,
guidance_scale: Optional[float] = None,
eta: float = 0.0,
return_intermediates: bool = False, **kwargs):
if steps is None:
steps = self.scheduler_cfg.num_inference_steps
if guidance_scale is None:
guidance_scale = self.scheduler_cfg.guidance_scale
do_classifier_free_guidance = guidance_scale > 0
# conditional encode
xc = batch[self.cond_stage_key]
# print(self.first_stage_model.device, self.cond_stage_model.device, self.device)
cond = self.cond_stage_model(xc)
if do_classifier_free_guidance:
un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
cond = torch.cat([un_cond, cond], dim=0)
outputs = []
latents = None
if not return_intermediates:
for _ in range(sample_times):
sample_loop = ddim_sample(
self.denoise_scheduler,
self.model,
shape=self.first_stage_model.latent_shape,
cond=cond,
steps=steps,
guidance_scale=guidance_scale,
do_classifier_free_guidance=do_classifier_free_guidance,
device=self.device,
eta=eta,
disable_prog=not self.zero_rank
)
for sample, t in sample_loop:
latents = sample
outputs.append(self.decode_first_stage(latents, **kwargs))
else:
sample_loop = ddim_sample(
self.denoise_scheduler,
self.model,
shape=self.first_stage_model.latent_shape,
cond=cond,
steps=steps,
guidance_scale=guidance_scale,
do_classifier_free_guidance=do_classifier_free_guidance,
device=self.device,
eta=eta,
disable_prog=not self.zero_rank
)
iter_size = steps // sample_times
i = 0
for sample, t in sample_loop:
latents = sample
if i % iter_size == 0 or i == steps - 1:
outputs.append(self.decode_first_stage(latents, **kwargs))
i += 1
return outputs


@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
import torch
from tqdm import tqdm
from typing import Tuple, List, Union, Optional
from diffusers.schedulers import DDIMScheduler
__all__ = ["ddim_sample"]
def ddim_sample(ddim_scheduler: DDIMScheduler,
diffusion_model: torch.nn.Module,
shape: Union[List[int], Tuple[int]],
cond: torch.FloatTensor,
steps: int,
eta: float = 0.0,
guidance_scale: float = 3.0,
do_classifier_free_guidance: bool = True,
generator: Optional[torch.Generator] = None,
device: torch.device = "cuda:0",
disable_prog: bool = True):
    assert steps > 0, f"steps must be > 0, got {steps}."
# init latents
bsz = cond.shape[0]
if do_classifier_free_guidance:
bsz = bsz // 2
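        # cond arrives stacked as [uncond; cond], so only half the rows
        # correspond to distinct samples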
latents = torch.randn(
(bsz, *shape),
generator=generator,
device=cond.device,
dtype=cond.dtype,
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * ddim_scheduler.init_noise_sigma
# set timesteps
ddim_scheduler.set_timesteps(steps)
timesteps = ddim_scheduler.timesteps.to(device)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler and should be in [0, 1]
extra_step_kwargs = {
"eta": eta,
"generator": generator
}
# reverse
for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2)
if do_classifier_free_guidance
else latents
)
# latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
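            # i.e. eps = eps_uncond + s * (eps_cond - eps_uncond);
            # s = 1 recovers the purely conditional prediction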
# text_embeddings_for_guidance = encoder_hidden_states.chunk(
# 2)[1] if do_classifier_free_guidance else encoder_hidden_states
# compute the previous noisy sample x_t -> x_t-1
latents = ddim_scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
yield latents, t
def karra_sample():
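    # placeholder for a Karras-scheduler sampling loop; not implemented in this commit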
pass
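Since ddim_sample is a generator that yields (latents, t) after every scheduler step, callers iterate it to completion and keep the final latents. A stand-in sketch (`denoiser` is any module with the three-argument forward shown above; all shapes are illustrative):

import torch
from diffusers.schedulers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
cond = torch.randn(2, 1, 1024)  # already stacked as [uncond; cond]
latents = None
for sample, t in ddim_sample(scheduler, denoiser, shape=(256, 64), cond=cond,
                             steps=50, guidance_scale=7.5,
                             do_classifier_free_guidance=True,
                             device="cpu", disable_prog=False):
    latents = sample  # the last iterate is the final x_0 estimate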