mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
omegaconf 2.2.2 compatibility
Summary: OmegaConf 2.2.2 doesn't like heterogenous tuples or Sequence or Set members. Workaround this. Reviewed By: shapovalov Differential Revision: D37278736 fbshipit-source-id: 123e6657947f5b27514910e4074c92086a457a2a
This commit is contained in:
parent
5c1ca757bb
commit
879495d38f
@ -5,7 +5,7 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Sequence
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
@ -86,7 +86,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
|||||||
num_workers: int = 0
|
num_workers: int = 0
|
||||||
dataset_len: int = 1000
|
dataset_len: int = 1000
|
||||||
dataset_len_val: int = 1
|
dataset_len_val: int = 1
|
||||||
images_per_seq_options: Sequence[int] = (2,)
|
images_per_seq_options: Tuple[int, ...] = (2,)
|
||||||
sample_consecutive_frames: bool = False
|
sample_consecutive_frames: bool = False
|
||||||
consecutive_frames_max_gap: int = 0
|
consecutive_frames_max_gap: int = 0
|
||||||
consecutive_frames_max_gap_seconds: float = 0.1
|
consecutive_frames_max_gap_seconds: float = 0.1
|
||||||
|
@ -122,9 +122,9 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
subsets: Optional[List[str]] = None
|
subsets: Optional[List[str]] = None
|
||||||
limit_to: int = 0
|
limit_to: int = 0
|
||||||
limit_sequences_to: int = 0
|
limit_sequences_to: int = 0
|
||||||
pick_sequence: Sequence[str] = ()
|
pick_sequence: Tuple[str, ...] = ()
|
||||||
exclude_sequence: Sequence[str] = ()
|
exclude_sequence: Tuple[str, ...] = ()
|
||||||
limit_category_to: Sequence[int] = ()
|
limit_category_to: Tuple[int, ...] = ()
|
||||||
dataset_root: str = ""
|
dataset_root: str = ""
|
||||||
load_images: bool = True
|
load_images: bool = True
|
||||||
load_depths: bool = True
|
load_depths: bool = True
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Sequence, Tuple, Type
|
from typing import Dict, List, Tuple, Type
|
||||||
|
|
||||||
from omegaconf import DictConfig, open_dict
|
from omegaconf import DictConfig, open_dict
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@ -98,7 +98,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
dataset_root: str = _CO3D_DATASET_ROOT
|
dataset_root: str = _CO3D_DATASET_ROOT
|
||||||
n_frames_per_sequence: int = -1
|
n_frames_per_sequence: int = -1
|
||||||
test_on_train: bool = False
|
test_on_train: bool = False
|
||||||
restrict_sequence_name: Sequence[str] = ()
|
restrict_sequence_name: Tuple[str, ...] = ()
|
||||||
test_restrict_sequence_id: int = -1
|
test_restrict_sequence_id: int = -1
|
||||||
assert_single_seq: bool = False
|
assert_single_seq: bool = False
|
||||||
only_test_set: bool = False
|
only_test_set: bool = False
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
# implicit_differentiable_renderer.py
|
# implicit_differentiable_renderer.py
|
||||||
# Copyright (c) 2020 Lior Yariv
|
# Copyright (c) 2020 Lior Yariv
|
||||||
import math
|
import math
|
||||||
from typing import Sequence
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.config import registry
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
@ -53,10 +53,10 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
feature_vector_size: int = 3
|
feature_vector_size: int = 3
|
||||||
d_in: int = 3
|
d_in: int = 3
|
||||||
d_out: int = 1
|
d_out: int = 1
|
||||||
dims: Sequence[int] = (512, 512, 512, 512, 512, 512, 512, 512)
|
dims: Tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512, 512)
|
||||||
geometric_init: bool = True
|
geometric_init: bool = True
|
||||||
bias: float = 1.0
|
bias: float = 1.0
|
||||||
skip_in: Sequence[int] = ()
|
skip_in: Tuple[int, ...] = ()
|
||||||
weight_norm: bool = True
|
weight_norm: bool = True
|
||||||
n_harmonic_functions_xyz: int = 0
|
n_harmonic_functions_xyz: int = 0
|
||||||
pooled_feature_dim: int = 0
|
pooled_feature_dim: int = 0
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional, Sequence, Union
|
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -176,7 +176,7 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
|||||||
the stack of source-view-specific features to a single feature.
|
the stack of source-view-specific features to a single feature.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
reduction_functions: Sequence[ReductionFunction] = (
|
reduction_functions: Tuple[ReductionFunction, ...] = (
|
||||||
ReductionFunction.AVG,
|
ReductionFunction.AVG,
|
||||||
ReductionFunction.STD,
|
ReductionFunction.STD,
|
||||||
)
|
)
|
||||||
@ -269,7 +269,7 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
|
|||||||
used when calculating the angle-based aggregation weights.
|
used when calculating the angle-based aggregation weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
reduction_functions: Sequence[ReductionFunction] = (
|
reduction_functions: Tuple[ReductionFunction, ...] = (
|
||||||
ReductionFunction.AVG,
|
ReductionFunction.AVG,
|
||||||
ReductionFunction.STD,
|
ReductionFunction.STD,
|
||||||
)
|
)
|
||||||
|
@ -9,7 +9,7 @@ import textwrap
|
|||||||
import unittest
|
import unittest
|
||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@ -484,16 +484,14 @@ class TestConfig(unittest.TestCase):
|
|||||||
none_field: Optional[int] = None
|
none_field: Optional[int] = None
|
||||||
float_field: float = 9.3
|
float_field: float = 9.3
|
||||||
bool_field: bool = True
|
bool_field: bool = True
|
||||||
tuple_field: tuple = (3, True, "j")
|
tuple_field: Tuple[int, ...] = (3,)
|
||||||
|
|
||||||
class SimpleClass:
|
class SimpleClass:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tuple_member_: Tuple[int, int] = (3, 4),
|
tuple_member_: Tuple[int, int] = (3, 4),
|
||||||
set_member_: Set[int] = {2}, # noqa
|
|
||||||
):
|
):
|
||||||
self.tuple_member = tuple_member_
|
self.tuple_member = tuple_member_
|
||||||
self.set_member = set_member_
|
|
||||||
|
|
||||||
def get_tuple(self):
|
def get_tuple(self):
|
||||||
return self.tuple_member
|
return self.tuple_member
|
||||||
@ -524,11 +522,9 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertEqual(simple.get_tuple(), [3, 4])
|
self.assertEqual(simple.get_tuple(), [3, 4])
|
||||||
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
|
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
|
||||||
# get_default_args converts sets to ListConfigs (which act like lists).
|
# get_default_args converts sets to ListConfigs (which act like lists).
|
||||||
self.assertEqual(simple.set_member, [2])
|
|
||||||
self.assertTrue(isinstance(simple.set_member, ListConfig))
|
|
||||||
self.assertEqual(c.a_tuple, [4.0, 3.0])
|
self.assertEqual(c.a_tuple, [4.0, 3.0])
|
||||||
self.assertTrue(isinstance(c.a_tuple, ListConfig))
|
self.assertTrue(isinstance(c.a_tuple, ListConfig))
|
||||||
self.assertEqual(mydata.tuple_field, (3, True, "j"))
|
self.assertEqual(mydata.tuple_field, (3,))
|
||||||
self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
|
self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
|
||||||
f(**c.f_args)
|
f(**c.f_args)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user