mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 20:32:51 +08:00
Type safety fixes
Summary: Pyre expects Mapping for ** operator. Reviewed By: bottler Differential Revision: D35288632 fbshipit-source-id: 34d6f26ad912b3a5046f440922bb6ed2fd86f533
This commit is contained in:
parent
24260130ce
commit
a999fc22ee
@ -16,11 +16,13 @@ from dataclasses import dataclass, field, fields
|
|||||||
from itertools import islice
|
from itertools import islice
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
@ -42,7 +44,7 @@ from . import types
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FrameData:
|
class FrameData(Mapping[str, Any]):
|
||||||
"""
|
"""
|
||||||
A type of the elements returned by indexing the dataset object.
|
A type of the elements returned by indexing the dataset object.
|
||||||
It can represent both individual frames and batches of thereof;
|
It can represent both individual frames and batches of thereof;
|
||||||
@ -137,13 +139,16 @@ class FrameData:
|
|||||||
return self.to(device=torch.device("cuda"))
|
return self.to(device=torch.device("cuda"))
|
||||||
|
|
||||||
# the following functions make sure **frame_data can be passed to functions
|
# the following functions make sure **frame_data can be passed to functions
|
||||||
def keys(self):
|
def __iter__(self):
|
||||||
for f in fields(self):
|
for f in fields(self):
|
||||||
yield f.name
|
yield f.name
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(fields(self))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def collate(cls, batch):
|
def collate(cls, batch):
|
||||||
"""
|
"""
|
||||||
|
@ -10,6 +10,7 @@ from typing import NamedTuple, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from pytorch3d.structures import Pointclouds
|
||||||
|
|
||||||
from .rasterize_points import rasterize_points
|
from .rasterize_points import rasterize_points
|
||||||
|
|
||||||
@ -75,7 +76,7 @@ class PointsRasterizer(nn.Module):
|
|||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
self.raster_settings = raster_settings
|
self.raster_settings = raster_settings
|
||||||
|
|
||||||
def transform(self, point_clouds, **kwargs) -> torch.Tensor:
|
def transform(self, point_clouds, **kwargs) -> Pointclouds:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
point_clouds: a set of point clouds
|
point_clouds: a set of point clouds
|
||||||
|
Loading…
x
Reference in New Issue
Block a user