mirror of
https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2025-09-18 05:22:48 +08:00
81 lines · 2.8 KiB · Python · Executable File
# -*- coding: utf-8 -*-
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
from typing import Tuple, List, Union, Optional
|
|
from diffusers.schedulers import DDIMScheduler
|
|
|
|
|
|
__all__ = ["ddim_sample"]
|
|
|
|
|
|
def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: torch.device = "cuda:0",
                disable_prog: bool = True):
    """Run DDIM reverse diffusion, yielding the latents after every step.

    Args:
        ddim_scheduler: configured ``DDIMScheduler`` driving the reverse process.
        diffusion_model: noise-prediction network, called as
            ``forward(latents, timesteps, cond)``.
        shape: per-sample latent shape (batch dimension excluded).
        cond: conditioning tensor. When ``do_classifier_free_guidance`` is True
            it must contain the unconditional and conditional halves
            concatenated along dim 0 (batch size = 2 x sample count).
        steps: number of DDIM denoising steps; must be positive.
        eta: DDIM stochasticity in [0, 1]; 0.0 gives deterministic sampling.
        guidance_scale: classifier-free guidance weight.
        do_classifier_free_guidance: split the noise prediction into
            uncond/cond halves and extrapolate between them.
        generator: optional RNG for reproducible initial noise.
        device: device the timestep tensors are placed on.
        disable_prog: suppress the tqdm progress bar when True.

    Yields:
        ``(latents, t)`` after each scheduler step.

    Raises:
        ValueError: if ``steps`` is not positive.
    """
    # Was `assert steps > 0`: asserts are stripped under `python -O`,
    # so validate explicitly instead.
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}.")

    # init latents
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        # cond carries [uncond; cond] halves, so the true sample count is half
        bsz = bsz // 2

    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * ddim_scheduler.init_noise_sigma
    # set timesteps
    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)
    # prepare extra kwargs for the scheduler step, since not all schedulers
    # have the same signature; eta (η) is only used with the DDIMScheduler,
    # and must lie in [0, 1]
    extra_step_kwargs = {
        "eta": eta,
        "generator": generator
    }

    # reverse
    for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )
        # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # predict the noise residual; every sample in the batch shares this t
        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)

        # perform guidance: extrapolate from the uncond toward the cond prediction
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        # text_embeddings_for_guidance = encoder_hidden_states.chunk(
        # 2)[1] if do_classifier_free_guidance else encoder_hidden_states
        # compute the previous noisy sample x_t -> x_t-1
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t
|
|
|
|
|
|
def karra_sample():
    """Stub for a Karras-schedule sampler; intentionally a no-op for now."""
|