【三维几何学习】网格上低分辨率的分割结果到高分辨率的投影与可视化
创始人
2024-05-31 18:33:24

网格上低分辨率的分割结果到高分辨率的投影与可视化

  • 引言
  • 一、到高分辨率的投影
    • 1.1 准确率
    • 1.2 主要代码
    • 1.3 投影核心代码
  • 二、可视化代码

引言

三角网格的结构特性决定了其仅用少量三角形即可表示一个完整的3D模型。增加其分辨率可以展示更多模型的形状细节。对于网格分割来说,并不需要很多模型细节,只需要知晓其数据元素所属部分(类别)即可。
在这里插入图片描述

  • 上图分别为低分辨率分割结果、高分辨率投影结果以及Ground truth

在简化网格上进行预测,然后投影到高分辨率网格上一个可行的方案。例如:

MeshWalker1使用的的边界平滑
A Spectral Segmentation Method for Large Meshes2的feature-aware的网格简化

一、到高分辨率的投影

1.1 准确率

以面标签版本的COSEG外星人数据集为例,可参考三角网格(Triangular Mesh)分割数据集
在这里插入图片描述
简化网格上的准确率:96.94 到高分辨率网格投影:95.53
时间上也会快很多,毕竟计算高分辨率网格的输入特征较为费时

1.2 主要代码

部分代码来自3:MeshCNN
TriTransNet是对简化三角网格进行分割的网络,可替换为其它神经网络

import potpourri3d as pp3d
import numpy as np
import os
import pickle
from scipy.spatial import cKDTree
import time
import torch
from config.config import Config
from network.TriTransNet import TriTransNet
from postprocessing.mesh_project import get_faces_BorderPointsdef is_mesh_file(filename):return any(filename.endswith(extension) for extension in ['.obj', 'off'])def fix_vertices(vs):z = vs[:, 2].copy()vs[:, 2] = vs[:, 1]vs[:, 1] = zmax_range = 0for i in range(3):min_value = np.min(vs[:, i])max_value = np.max(vs[:, i])max_range = max(max_range, max_value - min_value)vs[:, i] -= min_valuescale_by = max_rangevs /= scale_byreturn vsdef get_seg_files(paths, seg_dir, seg_ext='.eseg'):segs = []for path in paths:segfile = os.path.join(seg_dir, os.path.splitext(os.path.basename(path))[0] + seg_ext)assert (os.path.isfile(segfile))segs.append(segfile)return segsdef make_dataset(path):meshes = []assert os.path.isdir(path), '%s is not a valid directory' % pathfor root, _, fnames in sorted(os.walk(path)):for fname in fnames:if is_mesh_file(fname):path = os.path.join(root, fname)meshes.append(path)return meshesif __name__ == '__main__':# 简化网格sim_root = '../../../datasets/face_label/coseg_aliens'sim_paths = make_dataset(os.path.join(sim_root, 'test'))# sim_labels = get_seg_files(sim_paths, seg_dir=os.path.join(sim_root, 'seg'))# 原始网格org_root = '../../../datasets/aliens'  # '../../datasets/vases'org_paths = make_dataset(os.path.join(org_root, 'test'))   # shapes  or segorg_labels = get_seg_files(org_paths, seg_dir=os.path.join(org_root, 'seg'), seg_ext='.seg')# 网络读取cfg = Config()cfg.class_n = 4cfg.mode = 'seg'net = TriTransNet(cfg)state_dict = torch.load('../../../results/aliens_1500/model/latest_xyz_net.pth')  # latest_xyz_net 95.53432if hasattr(state_dict, '_metadata'):del state_dict._metadatanet.load_state_dict(state_dict)net.eval()# 准确率统计all_acc = 0sim_acc = 0are_acc = 0for i in range(len(sim_paths)):# 获取网格数据sim_name = sim_paths[i]filename, _ = os.path.splitext(sim_name)prefix = os.path.basename(filename)cache = os.path.join('../../../results/aliens_1500/cache/', prefix + '.pkl')with open(cache, 'rb') as f:   # 不再计算 读取缓存meta = pickle.load(f)# 获取网格数据sim_mesh = meta['mesh']sim_label = meta['label']vs = fix_vertices(sim_mesh.vs)# 获取预测标签with torch.no_grad():face_features = np.concatenate([sim_mesh.face_features, sim_mesh.xyz], axis = 0)  # sim_mesh.hks[0:3]face_features = torch.from_numpy(face_features).float().unsqueeze(0)out = net(face_features, [sim_mesh])label = out.data.max(1)[1]sim_correct = label.eq(torch.from_numpy(sim_label).long()).sum().float() / sim_mesh.faces_numsim_acc += sim_correct# 面积# idex = label.eq(torch.from_numpy(sim_label).long()).numpy().reshape(-1)# face_area = sim_mesh.face_features[6, :]# sum_area = face_area.sum()# are_acc += face_area[idex].sum() / sum_area# 时间t = time.time()# 投影准备label = label.numpy().reshape(-1)BorderPoints_xyz, BorderPoints_label = get_faces_BorderPoints(vs, sim_mesh.faces, label, border_k=0.01, border_num=10)# 0.01 10 95.53432# 0.5 1  退化成最简单的最近邻  94.02kdt = cKDTree(BorderPoints_xyz)# 读取高分辨率网格org_vs, org_faces = pp3d.read_mesh(org_paths[i])org_vs = fix_vertices(org_vs)org_label = np.loadtxt(open(org_labels[i], 'r'), dtype='float64') -1# 原始网格中心点mean_vs = org_vs[org_faces]mean_vs = mean_vs.sum(axis=1) / 3.0dist, indices = kdt.query(mean_vs, workers=-1)# 准确率计算org_prolabels = BorderPoints_label[indices].reshape(-1)pro_cnt = np.equal(org_prolabels, org_label).sum()pro_acc = pro_cnt / len(org_label)all_acc += pro_accprint(filename, ':', pro_acc, ' time:', time.time()-t)print(all_acc / len(sim_paths))print(sim_acc / len(sim_paths))# print(are_acc / len(sim_paths))

1.3 投影核心代码

def get_faces_BorderPoints(vs, faces, labels, border_k=0.1, border_num=1):"""border_k:   远离边的系数border_num: 每条边的边缘点数首先 默认简化是不会过分破坏分割边界 简化后的网格和原网格基本对齐1.简化后的面更大 以一个面为例 均匀采样其边界部分形成边缘点 边缘点的标签赋值为面的标签2.赋值原网格面标签为 距离其重心最近的简化网格边缘点标签"""BorderPoints_xyz = -np.ones((len(faces) * 3 * border_num, 3), np.float64)BorderPoints_label = -np.ones((len(faces) * 3 * border_num, 1), np.int32)cnt = 0for face_id in range(len(faces)):face = faces[face_id]label = labels[face_id]for i in range(3):if border_num > 1:p1, p2, p = vs[face[i]], vs[face[(i + 1) % 3]], vs[face[(i + 2) % 3]]for j in range(border_num):center_p = p1 + (p2 - p1) / (border_num + 1) * (j + 1)border_p = center_p + (p - center_p) * border_kBorderPoints_xyz[cnt] = border_pBorderPoints_label[cnt] = labelcnt = cnt + 1else:p1, p2, p = vs[face[i]], vs[face[(i + 1) % 3]], vs[face[(i + 2) % 3]]center_p = (p1 + p2) / 2border_p = center_p + (p - center_p) * border_kBorderPoints_xyz[cnt] = border_pBorderPoints_label[cnt] = labelcnt = cnt + 1return BorderPoints_xyz, BorderPoints_label

二、可视化代码

减小可视化网格边的边长,查看模型细节:
在这里插入图片描述

import potpourri3d as pp3d
import numpy as np
import os
import pickle
from scipy.spatial import cKDTree
import time
import pylab as pl
import torch
from config.config import Config
from network.TriTransNet import TriTransNet
from postprocessing.mesh_project import get_faces_BorderPoints
import mpl_toolkits.mplot3d as a3
import matplotlib.colors as colors
from scipy import linalgdef rot_vs_axis_z(vs, radian, scale):bias = np.mean(vs)vs = vs - biasvs *= scalerot_matrix = linalg.expm(np.cross(np.eye(3), [0, 0, 1] / linalg.norm([0, 0, 1]) * radian))vs = np.dot(rot_matrix, vs.T)vs = vs.T + biasreturn vsdef init_ax(ax):# hide axis, thank to# https://stackoverflow.com/questions/29041326/3d-plot-with-matplotlib-hide-axes-but-keep-axis-labels/ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))# Get rid of the spinesax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))# Get rid of the ticksax.set_xticks([])ax.set_yticks([])ax.set_zticks([])return axdef is_mesh_file(filename):return any(filename.endswith(extension) for extension in ['.obj', 'off'])def fix_vertices(vs):z = vs[:, 2].copy()vs[:, 2] = vs[:, 1]vs[:, 1] = zmax_range = 0for i in range(3):min_value = np.min(vs[:, i])max_value = np.max(vs[:, i])max_range = max(max_range, max_value - min_value)vs[:, i] -= min_valuescale_by = max_rangevs /= scale_byreturn vsdef get_seg_files(paths, seg_dir, seg_ext='.eseg'):segs = []for path in paths:segfile = os.path.join(seg_dir, os.path.splitext(os.path.basename(path))[0] + seg_ext)assert (os.path.isfile(segfile))segs.append(segfile)return segsdef make_dataset(path):meshes = []assert os.path.isdir(path), '%s is not a valid directory' % pathfor root, _, fnames in sorted(os.walk(path)):for fname in fnames:if is_mesh_file(fname):path = os.path.join(root, fname)meshes.append(path)return meshesif __name__ == '__main__':# 简化网格sim_root = '../../../datasets/face_label/coseg_aliens'sim_paths = make_dataset(os.path.join(sim_root, 'test'))# sim_labels = get_seg_files(sim_paths, seg_dir=os.path.join(sim_root, 'seg'))# 原始网格org_root = '../../../datasets/aliens'  # '../../datasets/vases'org_paths = make_dataset(os.path.join(org_root, 'test'))   # shapes  or segorg_labels = get_seg_files(org_paths, seg_dir=os.path.join(org_root, 'seg'), seg_ext='.seg')# 网络读取cfg = Config()cfg.class_n = 4cfg.mode = 'seg'net = TriTransNet(cfg)state_dict = torch.load('../../../results/aliens_1500/model/latest_xyz_net.pth')  # latest_xyz_net 95.53432if hasattr(state_dict, '_metadata'):del state_dict._metadatanet.load_state_dict(state_dict)net.eval()# 准确率统计all_acc = 0sim_acc = 0are_acc = 0for i in range(len(sim_paths)):# 获取网格数据sim_name = sim_paths[i]filename, _ = os.path.splitext(sim_name)prefix = os.path.basename(filename)# 选择某一个网格可视化#if prefix != '132':#    continueif i != 3:continuecache = os.path.join('../../../results/aliens_1500/cache/', prefix + '.pkl')with open(cache, 'rb') as f:   # 不再计算 读取缓存meta = pickle.load(f)# 获取网格数据sim_mesh = meta['mesh']sim_label = meta['label']vs = fix_vertices(sim_mesh.vs)# 获取预测标签with torch.no_grad():face_features = np.concatenate([sim_mesh.face_features, sim_mesh.xyz], axis = 0)  # sim_mesh.hks[0:3]face_features = torch.from_numpy(face_features).float().unsqueeze(0)out = net(face_features, [sim_mesh])label = out.data.max(1)[1]sim_correct = label.eq(torch.from_numpy(sim_label).long()).sum().float() / sim_mesh.faces_numsim_acc += sim_correct# 时间t = time.time()# 投影准备label = label.numpy().reshape(-1)BorderPoints_xyz, BorderPoints_label = get_faces_BorderPoints(vs, sim_mesh.faces, label, border_k=0.01, border_num=10)# 0.01 10 95.53432# 0.5 1  退化成最简单的最近邻  94.02kdt = cKDTree(BorderPoints_xyz)# 读取高分辨率网格org_vs, org_faces = pp3d.read_mesh(org_paths[i])org_vs = fix_vertices(org_vs)org_label = np.loadtxt(open(org_labels[i], 'r'), dtype='float64') -1# 原始网格中心点mean_vs = org_vs[org_faces]mean_vs = mean_vs.sum(axis=1) / 3.0dist, indices = kdt.query(mean_vs, workers=-1)# 准确率计算org_prolabels = BorderPoints_label[indices].reshape(-1)pro_cnt = np.equal(org_prolabels, org_label).sum()pro_acc = pro_cnt / len(org_label)all_acc += pro_accprint(filename, ':', pro_acc, ' time:', time.time()-t)# 可视化f = pl.figure()ax = f.add_subplot(1, 1, 1, projection='3d')ax = init_ax(ax)r2h = lambda x: colors.rgb2hex(tuple(map(lambda y: y / 255., x)))f_colors = [r2h((0, 0, 255)), r2h((0, 255, 255)), r2h((255, 0, 255)), r2h((0, 255, 0))]vis_bias = 0.3 ## 简化网格faces_color = []for l in label:faces_color.append(f_colors[l - 1])vs = rot_vs_axis_z(vs, 0.95, 1)tri = a3.art3d.Poly3DCollection(vs[sim_mesh.faces],facecolors=faces_color,edgecolors=r2h((0, 0, 0)),linewidths=0.1,  # 0.1# linestyles='dashdot',alpha=1)ax.add_collection3d(tri)# 高分辨率网格org_vs = rot_vs_axis_z(org_vs, 0.95, 1)faces_color = []for l in org_prolabels.astype(int):faces_color.append(f_colors[l - 1])org_vs[:, 0] += vs[:, 0].max()/2 + vis_biastri1 = a3.art3d.Poly3DCollection(org_vs[org_faces],facecolors=faces_color,edgecolors=r2h((0, 0, 0)),linewidths=0.1,# linestyles='dashdot',alpha=1)ax.add_collection3d(tri1)max_x = org_vs[:, 0].max()# 高分辨率网格Ground truthfaces_color = []for l in org_label.astype(int):faces_color.append(f_colors[l - 1])org_vs[:, 0] += vs[:, 0].max() / 2 + vis_biastri2 = a3.art3d.Poly3DCollection(org_vs[org_faces],facecolors=faces_color,edgecolors=r2h((0, 0, 0)),linewidths=0.1,# linestyles='dashdot',alpha=1)ax.add_collection3d(tri2)max_x = org_vs[:, 0].max()ax.auto_scale_xyz([0, max_x], [0, max_x], [0, max_x])ax.view_init(0, -90)pl.tight_layout()pl.savefig('corr.png', dpi=1000)  # ipl.show()breakprint(all_acc / len(sim_paths))print(sim_acc / len(sim_paths))

  1. MeshWalker: Deep Mesh Understanding by Random Walks ↩︎

  2. A Spectral Segmentation Method for Large Meshes ↩︎

  3. MeshCNN ↩︎

相关内容

热门资讯

苗族的传统节日 贵州苗族节日有... 【岜沙苗族芦笙节】岜沙,苗语叫“分送”,距从江县城7.5公里,是世界上最崇拜树木并以树为神的枪手部落...
北京的名胜古迹 北京最著名的景... 北京从元代开始,逐渐走上帝国首都的道路,先是成为大辽朝五大首都之一的南京城,随着金灭辽,金代从海陵王...
应用未安装解决办法 平板应用未... ---IT小技术,每天Get一个小技能!一、前言描述苹果IPad2居然不能安装怎么办?与此IPad不...
长白山自助游攻略 吉林长白山游... 昨天介绍了西坡的景点详细请看链接:一个人的旅行,据说能看到长白山天池全凭运气,您的运气如何?今日介绍...