mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2026-05-08 00:58:55 +08:00
init
primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py (Executable file, 80 lines added)
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-

import torch
from tqdm import tqdm
from typing import Tuple, List, Union, Optional

from diffusers.schedulers import DDIMScheduler

__all__ = ["ddim_sample"]


def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: torch.device = "cuda:0",
                disable_prog: bool = True):
    """DDIM reverse-process sampler.

    Generator that runs `steps` denoising iterations of `diffusion_model` under the
    given conditioning and yields `(latents, t)` after every scheduler step. With
    classifier-free guidance enabled, `cond` is expected to stack the unconditional
    and conditional embeddings along the batch dimension, so only half of its batch
    size worth of samples is generated.
    """
    assert steps > 0, f"steps must be > 0, got {steps}."

    # init latents
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        bsz = bsz // 2

    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * ddim_scheduler.init_noise_sigma
    # set timesteps
    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, and should be in [0, 1]
    extra_step_kwargs = {
        "eta": eta,
        "generator": generator
    }

    # reverse diffusion loop
    for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
        # expand the latents if we are doing classifier-free guidance
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )
        # latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            # text_embeddings_for_guidance = encoder_hidden_states.chunk(
            #     2)[1] if do_classifier_free_guidance else encoder_hidden_states

        # compute the previous noisy sample x_t -> x_t-1
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t


def karra_sample():
    pass
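For reference, below is a minimal sketch of how this generator might be driven. It assumes a stand-in denoiser with the same `(latents, timesteps, cond)` call signature; `DummyDenoiser`, the tensor shapes, and the randomly built `cond` tensor are illustrative placeholders, not the repo's actual model or conditioning pipeline.

# Hypothetical driver for ddim_sample (sketch only). DummyDenoiser and the shapes
# below are placeholders standing in for the repo's real model and data.
import torch
from diffusers.schedulers import DDIMScheduler

from primitive_anything.michelangelo.models.asl_diffusion.inference_utils import ddim_sample


class DummyDenoiser(torch.nn.Module):
    """Stand-in for diffusion_model: maps (latents, timesteps, cond) -> noise prediction."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x, t, cond):
        # A real denoiser would condition on t and cond; here we only preserve x's shape.
        return self.proj(x)


device = "cpu"  # ddim_sample defaults to "cuda:0"; any torch device works
scheduler = DDIMScheduler(num_train_timesteps=1000)
model = DummyDenoiser(dim=64).to(device)

# With classifier-free guidance the cond batch stacks [unconditional, conditional]
# halves, so its batch size is twice the number of samples actually generated.
num_samples, tokens, dim = 2, 8, 64
uncond_emb = torch.zeros(num_samples, tokens, dim, device=device)
cond_emb = torch.randn(num_samples, tokens, dim, device=device)
cond = torch.cat([uncond_emb, cond_emb], dim=0)

latents = None
for latents, t in ddim_sample(
    scheduler,
    model,
    shape=(tokens, dim),  # per-sample latent shape; latents come out as (num_samples, tokens, dim)
    cond=cond,
    steps=25,
    guidance_scale=3.0,
    do_classifier_free_guidance=True,
    device=device,
    disable_prog=True,
):
    pass  # the generator yields every intermediate (latents, t); keep the last pair

print(latents.shape)  # torch.Size([2, 8, 64])

Because ddim_sample is a generator, nothing runs until it is iterated; callers that only want the final sample can simply keep the last yielded latents, as above.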