mirror of
https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2026-05-08 00:58:55 +08:00
init
This commit is contained in:
87
primitive_anything/primitive_dataset.py
Normal file
87
primitive_anything/primitive_dataset.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from scipy.linalg import polar
|
||||
from scipy.spatial.transform import Rotation
|
||||
import open3d as o3d
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .utils import exists
|
||||
from .utils.logger import print_log
|
||||
|
||||
|
||||
def create_dataset(cfg_dataset):
|
||||
kwargs = cfg_dataset
|
||||
name = kwargs.pop('name')
|
||||
dataset = get_dataset(name)(**kwargs)
|
||||
print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
|
||||
return dataset
|
||||
|
||||
def get_dataset(name):
|
||||
return {
|
||||
'base': PrimitiveDataset,
|
||||
}[name]
|
||||
|
||||
|
||||
SHAPE_CODE = {
|
||||
'CubeBevel': 0,
|
||||
'SphereSharp': 1,
|
||||
'CylinderSharp': 2,
|
||||
}
|
||||
|
||||
|
||||
class PrimitiveDataset(Dataset):
|
||||
def __init__(self,
|
||||
pc_dir,
|
||||
bs_dir,
|
||||
max_length=144,
|
||||
range_scale=[0, 1],
|
||||
range_rotation=[-180, 180],
|
||||
range_translation=[-1, 1],
|
||||
rotation_type='euler',
|
||||
pc_format='pc',
|
||||
):
|
||||
self.data_filename = os.listdir(pc_dir)
|
||||
|
||||
self.pc_dir = pc_dir
|
||||
self.max_length = max_length
|
||||
self.range_scale = range_scale
|
||||
self.range_rotation = range_rotation
|
||||
self.range_translation = range_translation
|
||||
self.rotation_type = rotation_type
|
||||
self.pc_format = pc_format
|
||||
|
||||
with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
|
||||
basic_shapes = json.load(f)
|
||||
|
||||
self.typeid_map = {
|
||||
1101002001034001: 'CubeBevel',
|
||||
1101002001034010: 'SphereSharp',
|
||||
1101002001034002: 'CylinderSharp',
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_filename)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
|
||||
pc = o3d.io.read_point_cloud(pc_file)
|
||||
|
||||
model_data = {}
|
||||
|
||||
points = torch.from_numpy(np.asarray(pc.points)).float()
|
||||
colors = torch.from_numpy(np.asarray(pc.colors)).float()
|
||||
normals = torch.from_numpy(np.asarray(pc.normals)).float()
|
||||
if self.pc_format == 'pc':
|
||||
model_data['pc'] = torch.concatenate([points, colors], dim=-1).T
|
||||
elif self.pc_format == 'pn':
|
||||
model_data['pc'] = torch.concatenate([points, normals], dim=-1)
|
||||
elif self.pc_format == 'pcn':
|
||||
model_data['pc'] = torch.concatenate([points, colors, normals], dim=-1)
|
||||
else:
|
||||
raise ValueError(f'invalid pc_format: {self.pc_format}')
|
||||
|
||||
return model_data
|
||||
Reference in New Issue
Block a user