mirror of
https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2026-05-08 00:58:55 +08:00
init
This commit is contained in:
57
primitive_anything/utils/typing.py
Executable file
57
primitive_anything/utils/typing.py
Executable file
@@ -0,0 +1,57 @@
|
||||
from environs import Env
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from beartype import beartype
|
||||
from beartype.door import is_bearable
|
||||
|
||||
from jaxtyping import (
|
||||
Float,
|
||||
Int,
|
||||
Bool,
|
||||
jaxtyped
|
||||
)
|
||||
|
||||
# environment
|
||||
|
||||
env = Env()
|
||||
env.read_env()
|
||||
|
||||
# function
|
||||
|
||||
def always(value):
|
||||
def inner(*args, **kwargs):
|
||||
return value
|
||||
return inner
|
||||
|
||||
def identity(t):
|
||||
return t
|
||||
|
||||
# jaxtyping is a misnomer, works for pytorch
|
||||
|
||||
class TorchTyping:
|
||||
def __init__(self, abstract_dtype):
|
||||
self.abstract_dtype = abstract_dtype
|
||||
|
||||
def __getitem__(self, shapes: str):
|
||||
return self.abstract_dtype[Tensor, shapes]
|
||||
|
||||
Float = TorchTyping(Float)
|
||||
Int = TorchTyping(Int)
|
||||
Bool = TorchTyping(Bool)
|
||||
|
||||
# use env variable TYPECHECK to control whether to use beartype + jaxtyping
|
||||
|
||||
should_typecheck = env.bool('TYPECHECK', False)
|
||||
|
||||
typecheck = jaxtyped(typechecker = beartype) if should_typecheck else identity
|
||||
|
||||
beartype_isinstance = is_bearable if should_typecheck else always(True)
|
||||
|
||||
__all__ = [
|
||||
Float,
|
||||
Int,
|
||||
Bool,
|
||||
typecheck,
|
||||
beartype_isinstance
|
||||
]
|
||||
Reference in New Issue
Block a user