mirror of
https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2026-05-08 00:58:55 +08:00
init
This commit is contained in:
218
primitive_anything/michelangelo/models/modules/diffusion_transformer.py
Executable file
218
primitive_anything/michelangelo/models/modules/diffusion_transformer.py
Executable file
@@ -0,0 +1,218 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
from .checkpoint import checkpoint
|
||||
from .transformer_blocks import (
|
||||
init_linear,
|
||||
MLP,
|
||||
MultiheadCrossAttention,
|
||||
MultiheadAttention,
|
||||
ResidualAttentionBlock
|
||||
)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
    """Adaptive LayerNorm (AdaLN): a parameter-free LayerNorm whose scale and
    shift are predicted per-token from a timestep embedding.

    Args:
        device: device on which to allocate the linear/layernorm parameters.
        dtype: dtype of the linear/layernorm parameters.
        width: channel width of both the input and the timestep embedding.
    """

    def __init__(self,
                 device: torch.device,
                 dtype: torch.dtype,
                 width: int):

        super().__init__()

        # NOTE(review): `silu` is constructed but never applied in forward().
        # The common AdaLN formulation is `self.linear(self.silu(timestep))` —
        # confirm whether that was the intent. It is kept (not applied) here to
        # preserve the numerics of existing checkpoints; SiLU is parameter-free,
        # so keeping it does not affect the state dict.
        self.silu = nn.SiLU(inplace=True)
        # Predicts the concatenated (scale, shift), each of size `width`.
        self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
        # Affine-free norm: the learned modulation below supplies scale/shift.
        self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)

    def forward(self, x, timestep):
        """Normalize `x` and modulate it with a timestep-conditioned affine map.

        Args:
            x: (..., width) activations, channels on the last axis.
            timestep: (..., width) timestep embedding, broadcastable to `x`.

        Returns:
            Tensor of the same shape as `x`.
        """
        emb = self.linear(timestep)
        # Split along the channel axis. The original used dim=2 (3-D inputs
        # only); dim=-1 is identical for that case and generalizes to any rank.
        scale, shift = torch.chunk(emb, 2, dim=-1)
        x = self.layernorm(x) * (1 + scale) + shift
        return x
|
||||
|
||||
|
||||
class DitBlock(nn.Module):
    """One DiT transformer block: AdaLN-modulated self-attention, optional
    cross-attention over a conditioning context, and an MLP — each applied
    as a residual update.
    """

    def __init__(
            self,
            *,
            device: torch.device,
            dtype: torch.dtype,
            n_ctx: int,
            width: int,
            heads: int,
            context_dim: int,
            qkv_bias: bool = False,
            init_scale: float = 1.0,
            use_checkpoint: bool = False
    ):
        super().__init__()

        self.use_checkpoint = use_checkpoint

        # Configuration shared by both attention flavours.
        attn_kwargs = dict(
            device=device,
            dtype=dtype,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
        )

        self.attn = MultiheadAttention(n_ctx=n_ctx, **attn_kwargs)
        self.ln_1 = AdaLayerNorm(device, dtype, width)

        # Cross-attention sub-layer is only built when a context width is given.
        if context_dim is not None:
            self.ln_2 = AdaLayerNorm(device, dtype, width)
            self.cross_attn = MultiheadCrossAttention(data_width=context_dim, **attn_kwargs)

        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_3 = AdaLayerNorm(device, dtype, width)

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
        """Run the block, optionally under activation checkpointing."""
        return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)

    def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
        # Each sub-layer is a residual update of the AdaLN-modulated input.
        h = x + self.attn(self.ln_1(x, t))
        if context is not None:
            h = h + self.cross_attn(self.ln_2(h, t), context)
        h = h + self.mlp(self.ln_3(h, t))
        return h
|
||||
|
||||
|
||||
class DiT(nn.Module):
    """A plain stack of `layers` identically-configured DitBlocks."""

    def __init__(
            self,
            *,
            device: Optional[torch.device],
            dtype: Optional[torch.dtype],
            n_ctx: int,
            width: int,
            layers: int,
            heads: int,
            context_dim: int,
            init_scale: float = 0.25,
            qkv_bias: bool = False,
            use_checkpoint: bool = False
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        # Build one block per layer; every block shares the same configuration.
        blocks = []
        for _ in range(layers):
            blocks.append(
                DitBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    context_dim=context_dim,
                    qkv_bias=qkv_bias,
                    init_scale=init_scale,
                    use_checkpoint=use_checkpoint
                )
            )
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
        """Apply every block in order, threading through the timestep
        embedding `t` and the optional conditioning `context`."""
        for resblock in self.resblocks:
            x = resblock(x, t, context)
        return x
|
||||
|
||||
|
||||
class UNetDiffusionTransformer(nn.Module):
    """U-Net-shaped transformer: `layers` encoder blocks whose outputs are
    saved as skip connections, a middle block, and `layers` decoder stages
    that each concatenate the matching skip (deepest first), project back
    down to `width`, optionally LayerNorm, then attend.
    """

    def __init__(
            self,
            *,
            device: Optional[torch.device],
            dtype: Optional[torch.dtype],
            n_ctx: int,
            width: int,
            layers: int,
            heads: int,
            init_scale: float = 0.25,
            qkv_bias: bool = False,
            skip_ln: bool = False,
            use_checkpoint: bool = False
    ):
        super().__init__()

        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        # Every attention block (encoder, middle, decoder) shares this
        # configuration; hoisted so the three call sites cannot drift apart.
        block_kwargs = dict(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            use_checkpoint=use_checkpoint,
        )

        self.encoder = nn.ModuleList()
        for _ in range(layers):
            self.encoder.append(ResidualAttentionBlock(**block_kwargs))

        self.middle_block = ResidualAttentionBlock(**block_kwargs)

        self.decoder = nn.ModuleList()
        for _ in range(layers):
            resblock = ResidualAttentionBlock(**block_kwargs)
            # Fuses the concatenated [skip, x] (2*width) back down to width.
            linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
            init_linear(linear, init_scale)

            # Optional normalization after skip fusion; `None` entries are
            # accepted by nn.ModuleList and skipped in forward().
            layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None

            self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))

    def forward(self, x: torch.Tensor):
        """Run the U-Net pass.

        Args:
            x: (..., width) token activations.

        Returns:
            Tensor of the same shape as `x`.
        """
        # Record each encoder output as a skip connection.
        enc_outputs = []
        for block in self.encoder:
            x = block(x)
            enc_outputs.append(x)

        x = self.middle_block(x)

        # Consume skips deepest-first (LIFO), U-Net style.
        # Fix: the original bound an `enumerate` index that was never used.
        for resblock, linear, layer_norm in self.decoder:
            x = torch.cat([enc_outputs.pop(), x], dim=-1)
            x = linear(x)

            if layer_norm is not None:
                x = layer_norm(x)

            x = resblock(x)

        return x
|
||||
Reference in New Issue
Block a user