diff --git a/hy3dshape/tools/evaluation/chamfer_distance.py b/hy3dshape/tools/evaluation/chamfer_distance.py index 658b108..1c0aab4 100644 --- a/hy3dshape/tools/evaluation/chamfer_distance.py +++ b/hy3dshape/tools/evaluation/chamfer_distance.py @@ -9,6 +9,30 @@ except Exception: except Exception: raise ImportError("Requires scipy.spatial.cKDTree or sklearn.neighbors. Install scipy or scikit-learn.") +def normalize_point_cloud_dimension(points): + """ + 将点云数据按维度独立归一化到[-1, 1]范围。 + + 参数: + points (np.ndarray): 输入点云,形状为(n, 3)。 + + 返回: + np.ndarray: 归一化后的点云。 + tuple: 每个维度的最小值,用于反归一化。 + tuple: 每个维度的最大值,用于反归一化。 + """ + # 计算每个维度(列)的最小值和最大值 + min_vals = np.min(points, axis=0) + max_vals = np.max(points, axis=0) + + # 计算每个维度的范围,避免除以零 + ranges = max_vals - min_vals + ranges[ranges == 0] = 1e-8 # 如果某个维度值全相同,则范围设为一个小值 + + normalized_points = (points - min_vals) / ranges # 先归一化到[0,1] + normalized_points = normalized_points * 2 - 1 # 再映射到[-1,1] + + return normalized_points, min_vals, max_vals def sample_points_from_mesh(vertices: np.ndarray, faces: np.ndarray, n_samples: int) -> np.ndarray: """ @@ -21,6 +45,7 @@ def sample_points_from_mesh(vertices: np.ndarray, faces: np.ndarray, n_samples: Returns: (n_samples, 3) sampled points (float32) """ v = np.asarray(vertices).reshape(-1, 3).astype(np.float64) + v, _, _ = normalize_point_cloud_dimension(v) f = np.asarray(faces).reshape(-1, 3).astype(np.int64) v0 = v[f[:, 0], :] @@ -116,15 +141,8 @@ def chamfer_distance_from_meshes(pred_vertices: np.ndarray, B_to_A_l2 = float(np.mean(d_gt_to_pred)) cd_l2 = 0.5 * (A_to_B_l2 + B_to_A_l2) - A_to_B_l2_sq = float(np.mean(d_pred_to_gt ** 2)) - B_to_A_l2_sq = float(np.mean(d_gt_to_pred ** 2)) - cd_l2_sq = 0.5 * (A_to_B_l2_sq + B_to_A_l2_sq) - metrics = { - 'cd_l2_sq': cd_l2_sq, 'cd_l2': cd_l2, - 'A_to_B_l2_sq': A_to_B_l2_sq, - 'B_to_A_l2_sq': B_to_A_l2_sq, 'A_to_B_l2': A_to_B_l2, 'B_to_A_l2': B_to_A_l2, 'n_samples_per_mesh': n_samples, @@ -148,5 +166,5 @@ if __name__ == "__main__": gt_faces = pred_faces.copy() metrics = chamfer_distance_from_meshes(pred_verts, pred_faces, gt_verts, gt_faces, - n_samples=20000) + n_samples=100000) print("Chamfer metrics:", metrics)