Update chamfer_distance.py
This commit is contained in:
@@ -9,6 +9,30 @@ except Exception:
|
||||
except Exception:
|
||||
raise ImportError("Requires scipy.spatial.cKDTree or sklearn.neighbors. Install scipy or scikit-learn.")
|
||||
|
||||
def normalize_point_cloud_dimension(points):
|
||||
"""
|
||||
将点云数据按维度独立归一化到[-1, 1]范围。
|
||||
|
||||
参数:
|
||||
points (np.ndarray): 输入点云,形状为(n, 3)。
|
||||
|
||||
返回:
|
||||
np.ndarray: 归一化后的点云。
|
||||
tuple: 每个维度的最小值,用于反归一化。
|
||||
tuple: 每个维度的最大值,用于反归一化。
|
||||
"""
|
||||
# 计算每个维度(列)的最小值和最大值
|
||||
min_vals = np.min(points, axis=0)
|
||||
max_vals = np.max(points, axis=0)
|
||||
|
||||
# 计算每个维度的范围,避免除以零
|
||||
ranges = max_vals - min_vals
|
||||
ranges[ranges == 0] = 1e-8 # 如果某个维度值全相同,则范围设为一个小值
|
||||
|
||||
normalized_points = (points - min_vals) / ranges # 先归一化到[0,1]
|
||||
normalized_points = normalized_points * 2 - 1 # 再映射到[-1,1]
|
||||
|
||||
return normalized_points, min_vals, max_vals
|
||||
|
||||
def sample_points_from_mesh(vertices: np.ndarray, faces: np.ndarray, n_samples: int) -> np.ndarray:
|
||||
"""
|
||||
@@ -21,6 +45,7 @@ def sample_points_from_mesh(vertices: np.ndarray, faces: np.ndarray, n_samples:
|
||||
Returns: (n_samples, 3) sampled points (float32)
|
||||
"""
|
||||
v = np.asarray(vertices).reshape(-1, 3).astype(np.float64)
|
||||
v, _, _ = normalize_point_cloud_dimension(v)
|
||||
f = np.asarray(faces).reshape(-1, 3).astype(np.int64)
|
||||
|
||||
v0 = v[f[:, 0], :]
|
||||
@@ -116,15 +141,8 @@ def chamfer_distance_from_meshes(pred_vertices: np.ndarray,
|
||||
B_to_A_l2 = float(np.mean(d_gt_to_pred))
|
||||
cd_l2 = 0.5 * (A_to_B_l2 + B_to_A_l2)
|
||||
|
||||
A_to_B_l2_sq = float(np.mean(d_pred_to_gt ** 2))
|
||||
B_to_A_l2_sq = float(np.mean(d_gt_to_pred ** 2))
|
||||
cd_l2_sq = 0.5 * (A_to_B_l2_sq + B_to_A_l2_sq)
|
||||
|
||||
metrics = {
|
||||
'cd_l2_sq': cd_l2_sq,
|
||||
'cd_l2': cd_l2,
|
||||
'A_to_B_l2_sq': A_to_B_l2_sq,
|
||||
'B_to_A_l2_sq': B_to_A_l2_sq,
|
||||
'A_to_B_l2': A_to_B_l2,
|
||||
'B_to_A_l2': B_to_A_l2,
|
||||
'n_samples_per_mesh': n_samples,
|
||||
@@ -148,5 +166,5 @@ if __name__ == "__main__":
|
||||
gt_faces = pred_faces.copy()
|
||||
|
||||
metrics = chamfer_distance_from_meshes(pred_verts, pred_faces, gt_verts, gt_faces,
|
||||
n_samples=20000)
|
||||
n_samples=100000)
|
||||
print("Chamfer metrics:", metrics)
|
||||
|
||||
Reference in New Issue
Block a user