omegaconf 2.2.2 compatibility

Summary: OmegaConf 2.2.2 doesn't like heterogenous tuples or Sequence or Set members. Workaround this.

Reviewed By: shapovalov

Differential Revision: D37278736

fbshipit-source-id: 123e6657947f5b27514910e4074c92086a457a2a
This commit is contained in:
Jeremy Reizenstein 2022-06-24 04:18:01 -07:00 committed by Facebook GitHub Bot
parent 5c1ca757bb
commit 879495d38f
6 changed files with 16 additions and 20 deletions

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
@ -86,7 +86,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
num_workers: int = 0
dataset_len: int = 1000
dataset_len_val: int = 1
images_per_seq_options: Sequence[int] = (2,)
images_per_seq_options: Tuple[int, ...] = (2,)
sample_consecutive_frames: bool = False
consecutive_frames_max_gap: int = 0
consecutive_frames_max_gap_seconds: float = 0.1

View File

@ -122,9 +122,9 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
subsets: Optional[List[str]] = None
limit_to: int = 0
limit_sequences_to: int = 0
pick_sequence: Sequence[str] = ()
exclude_sequence: Sequence[str] = ()
limit_category_to: Sequence[int] = ()
pick_sequence: Tuple[str, ...] = ()
exclude_sequence: Tuple[str, ...] = ()
limit_category_to: Tuple[int, ...] = ()
dataset_root: str = ""
load_images: bool = True
load_depths: bool = True

View File

@ -7,7 +7,7 @@
import json
import os
from typing import Dict, List, Sequence, Tuple, Type
from typing import Dict, List, Tuple, Type
from omegaconf import DictConfig, open_dict
from pytorch3d.implicitron.tools.config import (
@ -98,7 +98,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
dataset_root: str = _CO3D_DATASET_ROOT
n_frames_per_sequence: int = -1
test_on_train: bool = False
restrict_sequence_name: Sequence[str] = ()
restrict_sequence_name: Tuple[str, ...] = ()
test_restrict_sequence_id: int = -1
assert_single_seq: bool = False
only_test_set: bool = False

View File

@ -3,7 +3,7 @@
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import math
from typing import Sequence
from typing import Tuple
import torch
from pytorch3d.implicitron.tools.config import registry
@ -53,10 +53,10 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
feature_vector_size: int = 3
d_in: int = 3
d_out: int = 1
dims: Sequence[int] = (512, 512, 512, 512, 512, 512, 512, 512)
dims: Tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512, 512)
geometric_init: bool = True
bias: float = 1.0
skip_in: Sequence[int] = ()
skip_in: Tuple[int, ...] = ()
weight_norm: bool = True
n_harmonic_functions_xyz: int = 0
pooled_feature_dim: int = 0

View File

@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
@ -176,7 +176,7 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
the stack of source-view-specific features to a single feature.
"""
reduction_functions: Sequence[ReductionFunction] = (
reduction_functions: Tuple[ReductionFunction, ...] = (
ReductionFunction.AVG,
ReductionFunction.STD,
)
@ -269,7 +269,7 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
used when calculating the angle-based aggregation weights.
"""
reduction_functions: Sequence[ReductionFunction] = (
reduction_functions: Tuple[ReductionFunction, ...] = (
ReductionFunction.AVG,
ReductionFunction.STD,
)

View File

@ -9,7 +9,7 @@ import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
@ -484,16 +484,14 @@ class TestConfig(unittest.TestCase):
none_field: Optional[int] = None
float_field: float = 9.3
bool_field: bool = True
tuple_field: tuple = (3, True, "j")
tuple_field: Tuple[int, ...] = (3,)
class SimpleClass:
def __init__(
self,
tuple_member_: Tuple[int, int] = (3, 4),
set_member_: Set[int] = {2}, # noqa
):
self.tuple_member = tuple_member_
self.set_member = set_member_
def get_tuple(self):
return self.tuple_member
@ -524,11 +522,9 @@ class TestConfig(unittest.TestCase):
self.assertEqual(simple.get_tuple(), [3, 4])
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
# get_default_args converts sets to ListConfigs (which act like lists).
self.assertEqual(simple.set_member, [2])
self.assertTrue(isinstance(simple.set_member, ListConfig))
self.assertEqual(c.a_tuple, [4.0, 3.0])
self.assertTrue(isinstance(c.a_tuple, ListConfig))
self.assertEqual(mydata.tuple_field, (3, True, "j"))
self.assertEqual(mydata.tuple_field, (3,))
self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
f(**c.f_args)