# Mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git
# (synced 2025-09-18 05:22:48 +08:00; 173 lines, 5.8 KiB, Python, executable file)
import argparse
|
|
import glob
|
|
import json
|
|
import yaml
|
|
from pathlib import Path
|
|
import os
|
|
import re
|
|
import numpy as np
|
|
|
|
from scipy.spatial.transform import Rotation
|
|
from tqdm import tqdm
|
|
import torch
|
|
import trimesh
|
|
|
|
from primitive_anything.primitive_dataset import create_dataset
|
|
from primitive_anything.utils import torch_to, count_parameters
|
|
from primitive_anything.utils.logger import create_logger, print_log
|
|
|
|
# Mesh file names of the basic shapes, listed in primitive type-code order.
_BASIC_SHAPE_FILES = (
    'SM_GR_BS_CubeBevel_001.ply',
    'SM_GR_BS_SphereSharp_001.ply',
    'SM_GR_BS_CylinderSharp_001.ply',
)

# Primitive type code (as found in primitives['type_code']) -> mesh file name.
CODE_SHAPE = dict(enumerate(_BASIC_SHAPE_FILES))

# Mesh file name -> numeric 'type_id' written into the exported JSON blocks.
shapename_map = dict(zip(
    _BASIC_SHAPE_FILES,
    (1101002001034001, 1101002001034010, 1101002001034002),
))
|
|
|
|
# Directory scanned at import time for the basic-shape meshes.
bs_dir = 'data/basic_shapes_norm'

# Mesh file name -> loaded mesh whose visuals were converted to vertex colours.
mesh_bs = {}
_bs_pattern = os.path.join(bs_dir, '*.ply')
for bs_path in glob.glob(_bs_pattern):
    bs = trimesh.load(bs_path)
    # Clamp UV coordinates into [0, 1] before converting the visual to colours.
    clipped_uv = np.clip(bs.visual.uv, 0, 1)
    bs.visual.uv = clipped_uv
    bs.visual = bs.visual.to_color()
    bs_name = os.path.basename(bs_path)
    mesh_bs[bs_name] = bs
|
|
|
|
|
|
def create_model(cfg_model):
    """Instantiate the model described by a config mapping.

    Args:
        cfg_model: Mapping with a ``'name'`` entry selecting the model class
            (resolved through ``get_model``); every other entry is passed to
            the class constructor as a keyword argument.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: If ``'name'`` is missing or not a registered model name.
    """
    # Work on a shallow copy: popping 'name' from the caller's dict would
    # silently strip it from the loaded config and break a second call.
    kwargs = dict(cfg_model)
    name = kwargs.pop('name')
    model = get_model(name)(**kwargs)
    print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
    return model
|
|
|
|
|
|
from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete
|
|
def get_model(name):
    """Return the model class registered under ``name``.

    Raises:
        KeyError: If ``name`` is not a registered model name.
    """
    registry = {'discrete': PrimitiveTransformerDiscrete}
    return registry[name]
|
|
|
|
|
|
def euler_to_quat(euler):
    """Convert 'XYZ' Euler angles given in degrees to a quaternion.

    The quaternion is whatever ``Rotation.as_quat()`` returns (scalar-last
    ``[x, y, z, w]`` by SciPy's default).
    """
    rot = Rotation.from_euler('XYZ', euler, degrees=True)
    quat = rot.as_quat()
    return quat
|
|
|
|
def rotvec_to_quat(rotvec):
    """Convert a rotation vector whose magnitude is in degrees to a quaternion.

    The quaternion is whatever ``Rotation.as_quat()`` returns (scalar-last
    ``[x, y, z, w]`` by SciPy's default).
    """
    rot = Rotation.from_rotvec(rotvec, degrees=True)
    quat = rot.as_quat()
    return quat
|
|
|
|
def SRT_quat_to_matrix(scale, quat, translation):
    """Build a 4x4 homogeneous transform from scale, quaternion and translation.

    The upper-left 3x3 block is the rotation matrix multiplied element-wise
    (with broadcasting) by ``scale``; for a length-3 ``scale`` this scales
    column ``j`` of the rotation by ``scale[j]``. The last column holds
    ``translation``.
    """
    transform = np.identity(4)
    rotation = Rotation.from_quat(quat)
    transform[:3, 3] = translation
    transform[:3, :3] = rotation.as_matrix() * scale
    return transform
|
|
|
|
# Preview colour (hex) used in the exported GLB for each primitive type id.
_TYPE_ID_COLORS = {
    1101002001034001: "#2FA9FF",
    1101002001034002: "#FFC203",
    1101002001034010: "#FF8A9C",
}
# Colour used when a type id has no entry above (previously this left
# ``cur_color`` unbound or reused the colour of the previous primitive).
_FALLBACK_COLOR = "#808080"


def _hex_to_rgb(hex_color):
    """Convert a '#RRGGBB' string into a uint8 array of shape (1, 3)."""
    digits = hex_color.lstrip('#')
    channels = [int(digits[i:i + 2], 16) for i in (0, 2, 4)]  # R, G, B
    return np.array(channels, dtype=np.uint8)[None]


def write_json(primitives, shapename_map, out_path):
    """Export generated primitives as a JSON description and a GLB preview.

    Args:
        primitives: Mapping with tensor entries ``'scale'``, ``'rotation'``,
            ``'translation'`` and ``'type_code'``. Each entry is squeezed and
            iterated one primitive at a time; iteration stops at the first
            ``type_code`` equal to ``-1``. ``'rotation'`` is interpreted as
            'XYZ' Euler angles in degrees (via ``euler_to_quat``).
        shapename_map: Mapping from basic-shape mesh file name to the
            ``type_id`` stored in the JSON.
        out_path: Destination path of the JSON file (a ``str``). The GLB is
            written next to it with the extension replaced by ``.glb``.

    Returns:
        Tuple ``(glb_path, out_json)``: the GLB path and the dict that was
        serialised to ``out_path``.
    """
    out_json = {
        'operation': 0,
        'type': 1,
        'scene_id': None,
    }

    # A bare squeeze() also drops the primitive axis when exactly one
    # primitive is present (type_code becomes 0-d and zip() raises
    # TypeError). Detect that case and restore the axis everywhere.
    type_codes = np.atleast_1d(primitives['type_code'].squeeze().cpu().numpy())
    single_primitive = type_codes.shape[0] == 1

    def per_primitive(key):
        values = primitives[key].squeeze().cpu().numpy()
        return values[None] if single_primitive else values

    new_group = []
    model_scene = trimesh.Scene()
    for scale, rotation, translation, type_code in zip(
        per_primitive('scale'),
        per_primitive('rotation'),
        per_primitive('translation'),
        type_codes,
    ):
        if type_code == -1:
            # -1 marks the end of the valid primitives.
            break

        bs_name = CODE_SHAPE[int(type_code)]
        type_id = shapename_map[bs_name]
        quat = euler_to_quat(rotation)  # computed once, used for JSON and mesh

        new_group.append({
            'type_id': type_id,
            'data': {
                'location': translation.tolist(),
                'rotation': quat.tolist(),
                'scale': scale.tolist(),
                'color': ['808080'],
            },
        })

        # Place a copy of the template mesh and paint it with the type colour.
        trans = SRT_quat_to_matrix(scale, quat, translation)
        bs = mesh_bs[bs_name].copy().apply_transform(trans)
        rgb = _hex_to_rgb(_TYPE_ID_COLORS.get(type_id, _FALLBACK_COLOR))
        new_vertex_colors = np.repeat(rgb, bs.visual.vertex_colors.shape[0], axis=0)
        bs.visual.vertex_colors[:, :3] = new_vertex_colors

        # Axis swap for the exported preview: y <- z, z <- -y.
        # NOTE(review): presumably converts a Z-up layout to glTF's Y-up
        # convention -- confirm against the data's coordinate system.
        vertices = bs.vertices.copy()
        vertices[:, 1] = bs.vertices[:, 2]
        vertices[:, 2] = -bs.vertices[:, 1]
        bs.vertices = vertices
        model_scene.add_geometry(bs)

    out_json['group'] = new_group

    with open(out_path, 'w') as json_file:
        json.dump(out_json, json_file, indent=4)

    # Swap only the final extension. str.replace('.json', '.glb') returned
    # out_path unchanged when it had no '.json' suffix, so the GLB export
    # would have overwritten the JSON that was just written.
    glb_path = os.path.splitext(out_path)[0] + '.glb'
    model_scene.export(glb_path)

    return glb_path, out_json
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the auto-regressive model over every dataset
    # item and write one JSON (+ GLB) result per item into
    # <output>/JsonResults.
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='./configs/infer.yml', help='Config file path')
    parser.add_argument('-ck', '--AR_ckpt', type=str, default='./ckpt/mesh-transformer.ckpt.60.pt')
    parser.add_argument('-o', '--output', type=str, default='./results/infer')
    parser.add_argument('--bs_dir', type=str, default='data/basic_shapes_norm')
    parser.add_argument('--temperature', type=float, default=0.0)
    args = parser.parse_args()

    # NOTE(review): this list is not used below, and the meshes consumed by
    # write_json are loaded at import time from the module-level `bs_dir`,
    # so --bs_dir currently has no effect on the exported geometry.
    bs_names = [
        os.path.basename(bs_path)
        for bs_path in glob.glob(os.path.join(args.bs_dir, '*.ply'))
    ]

    with open(args.config, mode='r') as fp:
        # NOTE(review): FullLoader accepts more than plain YAML scalars and
        # collections; only point this at trusted config files (consider
        # yaml.safe_load if the configs are plain YAML).
        cfg = yaml.load(fp, Loader=yaml.FullLoader)

    # Load onto CPU first so a checkpoint saved from a GPU run can still be
    # read on a CPU-only machine; the model is moved to the GPU below.
    AR_checkpoint = torch.load(args.AR_ckpt, map_location='cpu')

    os.makedirs(args.output, exist_ok=True)
    json_result_folder = os.path.join(args.output, 'JsonResults')
    os.makedirs(json_result_folder, exist_ok=True)

    create_logger(Path(args.output))

    dataset = create_dataset(cfg['dataset'])

    transformer = create_model(cfg['model'])
    transformer.load_state_dict(AR_checkpoint)

    # Move the model to the GPU once instead of on every loop iteration.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        transformer = transformer.cuda()

    # Wrapping the dataset (not the enumerate object) lets tqdm pick up the
    # dataset length for the progress bar when it defines one.
    for item_i, item in enumerate(tqdm(dataset)):
        pc = item.pop('pc')

        item_filename = dataset.data_filename[item_i]
        if use_cuda:
            pc = pc.cuda()
            item = torch_to(item, torch.device('cuda'))

        # Inference only: no autograd graph is needed for generation.
        with torch.no_grad():
            recon_primitives, mask = transformer.generate(pc=pc.unsqueeze(0), temperature=args.temperature)

        # Build '<stem>.json' from the final extension only, so inputs that
        # do not end in '.ply' still get a '.json' result name.
        stem = os.path.splitext(os.path.basename(item_filename))[0]
        out_path = os.path.join(json_result_folder, stem + '.json')
        write_json(recon_primitives, shapename_map, out_path)
|