三角网格的结构特性决定了其仅用少量三角形即可表示一个完整的3D模型。增加其分辨率可以展示更多模型的形状细节。对于网格分割来说,并不需要很多模型细节,只需要知晓其数据元素所属部分(类别)即可。

在简化网格上进行预测，然后将结果投影到高分辨率网格上，是一个可行的方案。相关做法例如：
MeshWalker[1] 使用的边界平滑
A Spectral Segmentation Method for Large Meshes[2] 中的 feature-aware 网格简化
以面标签版本的COSEG外星人数据集为例,可参考三角网格(Triangular Mesh)分割数据集

简化网格上的准确率：96.94%；投影到高分辨率网格后的准确率：95.53%。
时间上也会快很多，毕竟计算高分辨率网格的输入特征较为费时。
部分代码来自 MeshCNN[3]。
TriTransNet是对简化三角网格进行分割的网络,可替换为其它神经网络
import potpourri3d as pp3d
import numpy as np
import os
import pickle
from scipy.spatial import cKDTree
import time
import torch
from config.config import Config
from network.TriTransNet import TriTransNet
from postprocessing.mesh_project import get_faces_BorderPoints


def is_mesh_file(filename):
    """Return True if *filename* ends with a supported mesh extension (.obj or .off)."""
    # The extension list previously had 'off' without the dot, which also matched
    # unrelated names such as 'handoff'.
    return filename.endswith(('.obj', '.off'))


def fix_vertices(vs):
    """Swap the y/z axes and normalize vertices into the unit cube, in place.

    Each axis is shifted so its minimum becomes 0, then every axis is divided by
    the largest per-axis extent, so the aspect ratio of the shape is preserved.

    Args:
        vs: (V, 3) float array of vertex positions. It is modified in place.

    Returns:
        The same array object, for convenience.
    """
    z = vs[:, 2].copy()
    vs[:, 2] = vs[:, 1]
    vs[:, 1] = z
    max_range = 0
    for i in range(3):
        min_value = np.min(vs[:, i])
        max_value = np.max(vs[:, i])
        max_range = max(max_range, max_value - min_value)
        vs[:, i] -= min_value
    # A degenerate (single-point) mesh has zero extent; don't divide by zero.
    if max_range > 0:
        vs /= max_range
    return vs


def get_seg_files(paths, seg_dir, seg_ext='.eseg'):
    """Return the label file in *seg_dir* for every mesh path (same stem, *seg_ext* suffix).

    Raises:
        FileNotFoundError: if a label file does not exist.
    """
    segs = []
    for path in paths:
        stem = os.path.splitext(os.path.basename(path))[0]
        segfile = os.path.join(seg_dir, stem + seg_ext)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(segfile):
            raise FileNotFoundError(segfile)
        segs.append(segfile)
    return segs


def make_dataset(path):
    """Return all mesh files below *path* (recursively) in a deterministic order.

    The main script pairs simplified and original meshes by list index, so the
    order must be reproducible: directories and the file names inside each
    directory are both sorted (os.walk itself yields entries in arbitrary order).

    Raises:
        NotADirectoryError: if *path* is not a directory.
    """
    if not os.path.isdir(path):
        raise NotADirectoryError('%s is not a valid directory' % path)
    meshes = []
    for root, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_mesh_file(fname):
                meshes.append(os.path.join(root, fname))
    return meshes


if __name__ == '__main__':
    # Simplified meshes (network input); features and labels come from a cache below.
    sim_root = '../../../datasets/face_label/coseg_aliens'
    sim_paths = make_dataset(os.path.join(sim_root, 'test'))

    # Original high-resolution meshes and their per-face ground-truth labels.
    org_root = '../../../datasets/aliens'  # '../../datasets/vases'
    org_paths = make_dataset(os.path.join(org_root, 'test'))  # shapes or seg
    org_labels = get_seg_files(org_paths, seg_dir=os.path.join(org_root, 'seg'), seg_ext='.seg')

    # Meshes are paired by index below, so both lists must line up one-to-one.
    if not sim_paths or len(sim_paths) != len(org_paths):
        raise ValueError(
            'expected the same non-zero number of simplified and original meshes, got %d and %d'
            % (len(sim_paths), len(org_paths)))

    # Load the segmentation network.
    cfg = Config()
    cfg.class_n = 4
    cfg.mode = 'seg'
    net = TriTransNet(cfg)
    # Inference runs on the CPU (`label.numpy()` below), so map GPU-saved weights to CPU.
    state_dict = torch.load('../../../results/aliens_1500/model/latest_xyz_net.pth',
                            map_location='cpu')  # latest_xyz_net 95.53432
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    net.load_state_dict(state_dict)
    net.eval()

    all_acc = 0.0  # sum of per-mesh accuracies after projection onto the original meshes
    sim_acc = 0.0  # sum of per-mesh accuracies on the simplified meshes
    for sim_name, org_path, org_label_path in zip(sim_paths, org_paths, org_labels):
        filename, _ = os.path.splitext(sim_name)
        prefix = os.path.basename(filename)
        cache = os.path.join('../../../results/aliens_1500/cache/', prefix + '.pkl')
        # Read the cached simplified mesh instead of recomputing its features.
        # NOTE: pickle.load can execute arbitrary code; only load caches produced by this project.
        with open(cache, 'rb') as f:
            meta = pickle.load(f)
        sim_mesh = meta['mesh']
        sim_label = meta['label']
        vs = fix_vertices(sim_mesh.vs)

        # Predict per-face labels on the simplified mesh.
        with torch.no_grad():
            face_features = np.concatenate([sim_mesh.face_features, sim_mesh.xyz], axis=0)  # sim_mesh.hks[0:3]
            face_features = torch.from_numpy(face_features).float().unsqueeze(0)
            out = net(face_features, [sim_mesh])
            label = out.data.max(1)[1]
        sim_correct = label.eq(torch.from_numpy(sim_label).long()).sum().float() / sim_mesh.faces_num
        sim_acc += sim_correct.item()

        # Time only the projection step (border sampling + nearest-neighbour lookup).
        t = time.time()
        label = label.numpy().reshape(-1)
        # border_k=0.01, border_num=10 -> 95.53432
        # border_k=0.5, border_num=1 degenerates to the simplest nearest-neighbour scheme -> 94.02
        BorderPoints_xyz, BorderPoints_label = get_faces_BorderPoints(
            vs, sim_mesh.faces, label, border_k=0.01, border_num=10)
        kdt = cKDTree(BorderPoints_xyz)

        # Load the high-resolution mesh, normalized the same way as the simplified one.
        org_vs, org_faces = pp3d.read_mesh(org_path)
        org_vs = fix_vertices(org_vs)
        # The stored labels are shifted by -1 to match the network's argmax classes.
        org_label = np.loadtxt(org_label_path, dtype='float64') - 1

        # Each original face takes the label of the border point nearest to its centroid.
        mean_vs = org_vs[org_faces]
        mean_vs = mean_vs.sum(axis=1) / 3.0
        _, indices = kdt.query(mean_vs, workers=-1)

        org_prolabels = BorderPoints_label[indices].reshape(-1)
        pro_cnt = np.equal(org_prolabels, org_label).sum()
        pro_acc = pro_cnt / len(org_label)
        all_acc += pro_acc
        print(filename, ':', pro_acc, ' time:', time.time() - t)

    print(all_acc / len(sim_paths))
    print(sim_acc / len(sim_paths))
def get_faces_BorderPoints(vs, faces, labels, border_k=0.1, border_num=1):
    """Sample labelled "border points" just inside every edge of a (simplified) triangle mesh.

    Assumes simplification does not badly damage the segmentation boundaries, i.e. the
    simplified mesh stays roughly aligned with the original one. Then:

    1. Faces of the simplified mesh are large. On each of a face's three edges,
       ``border_num`` evenly spaced points are placed and moved toward the opposite
       vertex by the fraction ``border_k``; every point inherits the face's label.
    2. A face of the original mesh can then be labelled with the label of the border
       point nearest to its centroid.

    Args:
        vs: (V, 3) array of vertex positions.
        faces: (F, 3) vertex indices per face (only the first three corners are used).
        labels: per-face labels; ``labels[f]`` is the label of face ``f``.
        border_k: how far each point is moved from its edge toward the opposite vertex
            (0 keeps it on the edge, 1 moves it onto the opposite vertex).
        border_num: number of points per edge; must be >= 1. With 1 the edge midpoint
            is used.

    Returns:
        ``(xyz, label)``. ``xyz`` is a float64 array of shape (F * 3 * border_num, 3).
        ``label`` is an int32 array of shape (F * 3 * border_num, 1). Points are ordered
        by face, then edge (corner 0->1, 1->2, 2->0), then position along the edge.

    Raises:
        ValueError: if ``border_num < 1`` or there are fewer labels than faces.
    """
    # border_num <= 0 used to fail deep inside the loop with an IndexError
    # (or a negative-dimension error); reject it up front.
    if border_num < 1:
        raise ValueError('border_num must be >= 1, got %r' % (border_num,))

    vs = np.asarray(vs)
    faces = np.asarray(faces, dtype=np.intp)
    n_faces = len(faces)
    if n_faces == 0:
        return np.empty((0, 3), np.float64), np.empty((0, 1), np.int32)
    faces = faces[:, :3]

    flat_labels = np.asarray(labels).reshape(-1)
    if len(flat_labels) < n_faces:
        raise ValueError('expected at least %d labels, got %d' % (n_faces, len(flat_labels)))

    # Edge i of a face runs from corner i to corner i+1; corner i+2 is opposite to it.
    p1 = vs[faces]                  # (F, 3, 3): start point of each edge
    p2 = vs[faces[:, [1, 2, 0]]]    # end point of each edge
    opposite = vs[faces[:, [2, 0, 1]]][:, :, None, :]

    if border_num > 1:
        # Interior points p1 + (p2 - p1) / (n + 1) * (j + 1) for j = 0..n-1, computed with
        # the same operation order (and dtype) as a scalar loop would use.
        steps = np.arange(1, border_num + 1, dtype=p1.dtype).reshape(1, 1, -1, 1)
        center_p = p1[:, :, None, :] + (p2 - p1)[:, :, None, :] / (border_num + 1) * steps
    else:
        # A single point per edge: its midpoint.
        center_p = ((p1 + p2) / 2)[:, :, None, :]
    border_p = center_p + (opposite - center_p) * border_k

    BorderPoints_xyz = border_p.reshape(-1, 3).astype(np.float64)
    BorderPoints_label = np.repeat(flat_labels[:n_faces], 3 * border_num).astype(np.int32).reshape(-1, 1)
    return BorderPoints_xyz, BorderPoints_label
可视化时减小网格边线的宽度（linewidths），以便查看模型细节：

import potpourri3d as pp3d
import numpy as np
import os
import pickle
from scipy.spatial import cKDTree
import time
import pylab as pl
import torch
from config.config import Config
from network.TriTransNet import TriTransNet
from postprocessing.mesh_project import get_faces_BorderPoints
import mpl_toolkits.mplot3d as a3
import matplotlib.colors as colors
from scipy import linalg


def rot_vs_axis_z(vs, radian, scale):
    """Scale vertices and rotate them by *radian* about the z axis; returns a new array.

    NOTE(review): ``np.mean(vs)`` averages over *all* coordinates, so the pivot is the
    point (b, b, b) rather than the per-axis centroid. The shape is still rotated
    rigidly, only its final position differs -- confirm this is intended.
    """
    bias = np.mean(vs)
    vs = vs - bias
    vs *= scale
    axis = np.array([0, 0, 1]) / linalg.norm([0, 0, 1])
    # The exponential of the skew-symmetric cross-product matrix of (axis * angle)
    # is the corresponding rotation matrix.
    rot_matrix = linalg.expm(np.cross(np.eye(3), axis * radian))
    vs = np.dot(rot_matrix, vs.T)
    return vs.T + bias


def init_ax(ax):
    """Hide the panes, axis lines and ticks of a 3D axes so only the meshes are drawn.

    See https://stackoverflow.com/questions/29041326/3d-plot-with-matplotlib-hide-axes-but-keep-axis-labels/
    """
    transparent = (1.0, 1.0, 1.0, 0.0)
    # `ax.w_xaxis` etc. were deprecated aliases of `ax.xaxis` (Matplotlib 3.1) and
    # were removed in Matplotlib 3.8.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(transparent)
        axis.line.set_color(transparent)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    return ax


def is_mesh_file(filename):
    """Return True if *filename* ends with a supported mesh extension (.obj or .off)."""
    # The extension list previously had 'off' without the dot, which also matched
    # unrelated names such as 'handoff'.
    return filename.endswith(('.obj', '.off'))


def fix_vertices(vs):
    """Swap the y/z axes and normalize vertices into the unit cube, in place.

    Each axis is shifted so its minimum becomes 0, then every axis is divided by
    the largest per-axis extent, so the aspect ratio of the shape is preserved.

    Args:
        vs: (V, 3) float array of vertex positions. It is modified in place.

    Returns:
        The same array object, for convenience.
    """
    z = vs[:, 2].copy()
    vs[:, 2] = vs[:, 1]
    vs[:, 1] = z
    max_range = 0
    for i in range(3):
        min_value = np.min(vs[:, i])
        max_value = np.max(vs[:, i])
        max_range = max(max_range, max_value - min_value)
        vs[:, i] -= min_value
    # A degenerate (single-point) mesh has zero extent; don't divide by zero.
    if max_range > 0:
        vs /= max_range
    return vs


def get_seg_files(paths, seg_dir, seg_ext='.eseg'):
    """Return the label file in *seg_dir* for every mesh path (same stem, *seg_ext* suffix).

    Raises:
        FileNotFoundError: if a label file does not exist.
    """
    segs = []
    for path in paths:
        stem = os.path.splitext(os.path.basename(path))[0]
        segfile = os.path.join(seg_dir, stem + seg_ext)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(segfile):
            raise FileNotFoundError(segfile)
        segs.append(segfile)
    return segs


def make_dataset(path):
    """Return all mesh files below *path* (recursively) in a deterministic order.

    The main script pairs simplified and original meshes by list index, so the
    order must be reproducible: directories and the file names inside each
    directory are both sorted (os.walk itself yields entries in arbitrary order).

    Raises:
        NotADirectoryError: if *path* is not a directory.
    """
    if not os.path.isdir(path):
        raise NotADirectoryError('%s is not a valid directory' % path)
    meshes = []
    for root, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_mesh_file(fname):
                meshes.append(os.path.join(root, fname))
    return meshes


if __name__ == '__main__':
    # Index (in sorted order) of the single test mesh to evaluate and visualize.
    vis_index = 3

    # Simplified meshes (network input); features and labels come from a cache below.
    sim_root = '../../../datasets/face_label/coseg_aliens'
    sim_paths = make_dataset(os.path.join(sim_root, 'test'))

    # Original high-resolution meshes and their per-face ground-truth labels.
    org_root = '../../../datasets/aliens'  # '../../datasets/vases'
    org_paths = make_dataset(os.path.join(org_root, 'test'))  # shapes or seg
    org_labels = get_seg_files(org_paths, seg_dir=os.path.join(org_root, 'seg'), seg_ext='.seg')

    # Meshes are paired by index below, so both lists must line up one-to-one.
    if not sim_paths or len(sim_paths) != len(org_paths):
        raise ValueError(
            'expected the same non-zero number of simplified and original meshes, got %d and %d'
            % (len(sim_paths), len(org_paths)))

    # Load the segmentation network.
    cfg = Config()
    cfg.class_n = 4
    cfg.mode = 'seg'
    net = TriTransNet(cfg)
    # Inference runs on the CPU (`label.numpy()` below), so map GPU-saved weights to CPU.
    state_dict = torch.load('../../../results/aliens_1500/model/latest_xyz_net.pth',
                            map_location='cpu')  # latest_xyz_net 95.53432
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    net.load_state_dict(state_dict)
    net.eval()

    all_acc = 0.0  # sum of per-mesh accuracies after projection onto the original meshes
    sim_acc = 0.0  # sum of per-mesh accuracies on the simplified meshes
    n_evaluated = 0
    for i, (sim_name, org_path, org_label_path) in enumerate(zip(sim_paths, org_paths, org_labels)):
        if i != vis_index:
            continue
        filename, _ = os.path.splitext(sim_name)
        prefix = os.path.basename(filename)
        cache = os.path.join('../../../results/aliens_1500/cache/', prefix + '.pkl')
        # Read the cached simplified mesh instead of recomputing its features.
        # NOTE: pickle.load can execute arbitrary code; only load caches produced by this project.
        with open(cache, 'rb') as f:
            meta = pickle.load(f)
        sim_mesh = meta['mesh']
        sim_label = meta['label']
        vs = fix_vertices(sim_mesh.vs)

        # Predict per-face labels on the simplified mesh.
        with torch.no_grad():
            face_features = np.concatenate([sim_mesh.face_features, sim_mesh.xyz], axis=0)  # sim_mesh.hks[0:3]
            face_features = torch.from_numpy(face_features).float().unsqueeze(0)
            out = net(face_features, [sim_mesh])
            label = out.data.max(1)[1]
        sim_correct = label.eq(torch.from_numpy(sim_label).long()).sum().float() / sim_mesh.faces_num
        sim_acc += sim_correct.item()

        # Time only the projection step (border sampling + nearest-neighbour lookup).
        t = time.time()
        label = label.numpy().reshape(-1)
        # border_k=0.01, border_num=10 -> 95.53432
        # border_k=0.5, border_num=1 degenerates to the simplest nearest-neighbour scheme -> 94.02
        BorderPoints_xyz, BorderPoints_label = get_faces_BorderPoints(
            vs, sim_mesh.faces, label, border_k=0.01, border_num=10)
        kdt = cKDTree(BorderPoints_xyz)

        # Load the high-resolution mesh, normalized the same way as the simplified one.
        org_vs, org_faces = pp3d.read_mesh(org_path)
        org_vs = fix_vertices(org_vs)
        # The stored labels are shifted by -1 to match the network's argmax classes.
        org_label = np.loadtxt(org_label_path, dtype='float64') - 1

        # Each original face takes the label of the border point nearest to its centroid.
        mean_vs = org_vs[org_faces]
        mean_vs = mean_vs.sum(axis=1) / 3.0
        _, indices = kdt.query(mean_vs, workers=-1)

        org_prolabels = BorderPoints_label[indices].reshape(-1)
        pro_cnt = np.equal(org_prolabels, org_label).sum()
        pro_acc = pro_cnt / len(org_label)
        all_acc += pro_acc
        n_evaluated += 1
        print(filename, ':', pro_acc, ' time:', time.time() - t)

        # Visualization, left to right: simplified prediction | projected prediction | ground truth.
        f = pl.figure()
        ax = f.add_subplot(1, 1, 1, projection='3d')
        ax = init_ax(ax)

        def r2h(rgb):
            """Convert a 0-255 RGB triple to a hex colour string."""
            return colors.rgb2hex(tuple(c / 255. for c in rgb))

        f_colors = [r2h((0, 0, 255)), r2h((0, 255, 255)), r2h((255, 0, 255)), r2h((0, 255, 0))]
        edge_color = r2h((0, 0, 0))
        vis_bias = 0.3  # horizontal gap between neighbouring meshes
        # Colours are looked up with `l - 1` (class 0 -> last colour); kept so images
        # match earlier renders. Every class still gets a distinct colour.

        # Simplified mesh coloured by the predicted labels.
        vs = rot_vs_axis_z(vs, 0.95, 1)
        tri = a3.art3d.Poly3DCollection(vs[sim_mesh.faces],
                                        facecolors=[f_colors[l - 1] for l in label],
                                        edgecolors=edge_color,
                                        linewidths=0.1,
                                        alpha=1)
        ax.add_collection3d(tri)

        # High-resolution mesh coloured by the projected labels, shifted to the right.
        shift = vs[:, 0].max() / 2 + vis_bias
        org_vs = rot_vs_axis_z(org_vs, 0.95, 1)
        org_vs[:, 0] += shift
        # `org_vs[org_faces]` copies, so shifting org_vs again later does not move this mesh.
        tri1 = a3.art3d.Poly3DCollection(org_vs[org_faces],
                                         facecolors=[f_colors[l - 1] for l in org_prolabels.astype(int)],
                                         edgecolors=edge_color,
                                         linewidths=0.1,
                                         alpha=1)
        ax.add_collection3d(tri1)

        # High-resolution ground truth, shifted right once more.
        org_vs[:, 0] += shift
        tri2 = a3.art3d.Poly3DCollection(org_vs[org_faces],
                                         facecolors=[f_colors[l - 1] for l in org_label.astype(int)],
                                         edgecolors=edge_color,
                                         linewidths=0.1,
                                         alpha=1)
        ax.add_collection3d(tri2)

        max_x = org_vs[:, 0].max()
        ax.auto_scale_xyz([0, max_x], [0, max_x], [0, max_x])
        ax.view_init(0, -90)
        pl.tight_layout()
        pl.savefig('corr.png', dpi=1000)
        pl.show()
        break

    # Average only over the meshes actually evaluated: the loop stops after the
    # visualized mesh, so dividing by len(sim_paths) would understate the accuracy.
    if n_evaluated:
        print(all_acc / n_evaluated)
        print(sim_acc / n_evaluated)
    else:
        print('no mesh evaluated: vis_index %d is out of range (%d meshes)' % (vis_index, len(sim_paths)))
[1] MeshWalker: Deep Mesh Understanding by Random Walks
[2] A Spectral Segmentation Method for Large Meshes
[3] MeshCNN