# Mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git
# (synced 2025-09-18 05:22:48 +08:00; 58 lines, 1.0 KiB, Python, executable file)
from environs import Env
|
|
|
|
from torch import Tensor
|
|
|
|
from beartype import beartype
|
|
from beartype.door import is_bearable
|
|
|
|
from jaxtyping import (
|
|
Float,
|
|
Int,
|
|
Bool,
|
|
jaxtyped
|
|
)
|
|
|
|
# environment
|
|
|
|
env = Env()
# pull in variables from a .env file (if one is found) before any lookups below
env.read_env()
|
|
|
|
# function
|
|
|
|
def always(value):
    """Return a callable that ignores its arguments and returns ``value``.

    The returned function accepts any positional and keyword arguments, so it
    can stand in for a callable of any signature.
    """
    def inner(*args, **kwargs):
        return value

    return inner
|
|
|
|
def identity(t):
    """Return ``t`` unchanged (used as a no-op stand-in for a decorator)."""
    return t
|
|
|
|
# jaxtyping is a misnomer, works for pytorch
|
|
|
|
class TorchTyping:
    """Subscript helper that fixes the array type of an abstract dtype to ``Tensor``.

    Wraps an object that is subscripted as ``abstract_dtype[ArrayType, shapes]``
    so callers only need to write ``wrapper[shapes]``.
    """

    def __init__(self, abstract_dtype):
        """Store the dtype object that will be subscripted in ``__getitem__``."""
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        """Return ``abstract_dtype[Tensor, shapes]`` for the given shape string."""
        return self.abstract_dtype[Tensor, shapes]
|
|
|
|
# rebind the imported jaxtyping dtypes to Tensor-bound wrappers,
# so annotations only need the shape string, e.g. Float['b n']
Float = TorchTyping(Float)
Int = TorchTyping(Int)
Bool = TorchTyping(Bool)
|
|
|
|
# use env variable TYPECHECK to control whether to use beartype + jaxtyping

# evaluated once at import time; falls back to False when TYPECHECK is not set
should_typecheck = env.bool('TYPECHECK', False)

# decorator: jaxtyped + beartype when enabled, otherwise a passthrough that returns its argument
typecheck = jaxtyped(typechecker = beartype) if should_typecheck else identity

# predicate: beartype's is_bearable when enabled, otherwise a stub that always returns True
beartype_isinstance = is_bearable if should_typecheck else always(True)
|
|
|
|
# public API: entries must be the *names* of the exports as strings;
# listing the objects themselves makes `from <module> import *` raise TypeError
__all__ = [
    'Float',
    'Int',
    'Bool',
    'typecheck',
    'beartype_isinstance'
]