Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| "LiftFeat: 3D Geometry-Aware Local Feature Matching" | |
| MegaDepth data handling was adapted from | |
| LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py | |
| """ | |
| import torch | |
| from kornia.utils import create_meshgrid | |
| import matplotlib.pyplot as plt | |
| import pdb | |
| import cv2 | |
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
    """Warp kpts0 from I0 to I1 using depth, intrinsics and relative pose.

    Each keypoint is rounded to the nearest pixel to look up its depth in
    ``depth0``, unprojected with ``K0``, moved by ``T_0to1`` and projected
    with ``K1``. Only the depth lookup is validated here: no covisibility or
    depth-consistency check against ``depth1`` is performed, so callers that
    need one (e.g. a forward/backward cycle check) must do it themselves.

    The input depth maps are not modified.

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
        depth0 (torch.Tensor): [N, H, W],
        depth1 (torch.Tensor): [N, H, W] (currently unused, kept for the
            call signature),
        T_0to1 (torch.Tensor): [N, 3, 4],
        K0 (torch.Tensor): [N, 3, 3],
        K1 (torch.Tensor): [N, 3, 3],
    Returns:
        valid_mask (torch.Tensor): [N, L] bool, True where the keypoint falls
            on a pixel of ``depth0`` with depth > 0. Keypoints outside the
            depth map, or rounding onto its first row / first column, are
            always False.
        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y0_hat> in I1
            pixel coordinates. Entries whose mask is False were warped with
            depth 0 and carry no geometric meaning.
    """
    n_batch = kpts0.shape[0]
    height, width = depth0.shape[-2:]

    kpts0_long = kpts0.round().long()
    xs, ys = kpts0_long[..., 0], kpts0_long[..., 1]

    # Row 0 and column 0 are deliberately treated as depth-less (lower bound
    # is strict), while the upper bound follows the real depth-map size
    # instead of a fixed 2000-pixel padding.
    samplable = (xs > 0) & (xs <= width - 1) & (ys > 0) & (ys <= height - 1)

    # Clamp only to keep the gather in range; clamped entries are masked out.
    xs_safe = xs.clamp(0, width - 1)
    ys_safe = ys.clamp(0, height - 1)
    batch_idx = torch.arange(n_batch, device=kpts0_long.device)[:, None]
    sampled = depth0[batch_idx, ys_safe, xs_safe]  # (N, L)
    kpts0_depth = torch.where(samplable, sampled, torch.zeros_like(sampled))

    nonzero_mask = kpts0_depth > 0

    # Unproject
    ones = torch.ones_like(kpts0[:, :, [0]])
    kpts0_h = torch.cat([kpts0, ones], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Rigid Transform
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)

    # Project
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    # +1e-5 avoids a division by zero for points with zero projected depth
    w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-5)  # (N, L, 2)

    return nonzero_mask, w_kpts0
def spvs_coarse(data, scale = 8):
    """
        Supervise corresp with dense depth & camera poses

        A coarse grid (one point per ``scale`` x ``scale`` cell) is laid over
        image1, warped into image0 with ``warp_kpts`` and warped back again.
        Grid points that return within 1.5 px of where they started, and had
        a valid depth in both directions, are kept as correspondences.

        Args:
            data (dict): reads 'image0', 'image1' ([N, C, H, W], used for
                device and size), 'depth0', 'depth1', 'T_0to1', 'T_1to0',
                'K0', 'K1', and optionally 'scale0' / 'scale1' (per-sample
                <x, y> resize factors; treated as 1 when absent).
            scale (int): side length, in pixels, of a coarse cell.
        Returns:
            list of N tensors, each [M_i, 4] float32 holding
            <x0, y0, x1, y1> in coarse-grid units, with at most one
            correspondence per coarse cell of either image.
    """
    # 1. misc
    device = data['image0'].device
    N, _, H0, W0 = data['image0'].shape
    _, _, H1, W1 = data['image1'].shape
    h0, w0, h1, w1 = (dim // scale for dim in (H0, W0, H1, W1))

    has_scale0 = 'scale0' in data
    has_scale1 = 'scale1' in data
    scale1 = scale * data['scale1'][:, None] if has_scale1 else scale

    def cells_inside(cells, height, width):
        # True for <x, y> integer cells that index a (height, width) grid.
        return (
            (cells[:, 0] >= 0) & (cells[:, 0] < width)
            & (cells[:, 1] >= 0) & (cells[:, 1] < height)
        )

    # 2. warp grids
    # create kpts in meshgrid and resize them to image resolution
    grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1 * w1, 2).repeat(N, 1, 1)  # [N, hw, 2]
    grid_pt1_i = scale1 * grid_pt1_c

    # warp kpts bi-directionally and check reproj error
    nonzero_m1, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'],
                                    data['T_1to0'], data['K1'], data['K0'])
    nonzero_m2, w_pt1_og = warp_kpts(w_pt1_i, data['depth0'], data['depth1'],
                                     data['T_0to1'], data['K0'], data['K1'])
    dist = torch.linalg.norm(grid_pt1_i - w_pt1_og, dim=-1)
    mask_mutual = (dist < 1.5) & nonzero_m1 & nonzero_m2

    # Remove repeated correspondences - this is important for network convergence
    corrs = []
    for i in range(N):
        keep = mask_mutual[i]
        pts0 = w_pt1_i[i, keep]
        pts1 = grid_pt1_i[i, keep]
        if has_scale0:
            pts0 = pts0 / data['scale0'][i]
        if has_scale1:
            pts1 = pts1 / data['scale1'][i]

        # Indexed assignment needs source and destination dtypes to agree.
        pairs = torch.cat([pts0 / scale, pts1 / scale], dim=1).to(torch.float32)

        # Source -> target: one slot per coarse cell of image0. When several
        # pairs land in the same cell only one survives (which one is left
        # to PyTorch's duplicate-index assignment).
        lut_mat12 = torch.full((h0, w0, 4), -1.0, device=device, dtype=torch.float32)
        src_cells = pairs[:, :2].long()
        # Cells outside the grid (e.g. image size not divisible by `scale`)
        # are dropped rather than raising.
        inside = cells_inside(src_cells, h0, w0)
        src_cells, pairs = src_cells[inside], pairs[inside]
        lut_mat12[src_cells[:, 1], src_cells[:, 0]] = pairs
        points = lut_mat12[torch.all(lut_mat12 >= 0, dim=-1)]

        # Target -> source check: one slot per coarse cell of image1.
        lut_mat21 = torch.full((h1, w1, 4), -1.0, device=device, dtype=torch.float32)
        tgt_cells = points[:, 2:].long()
        inside = cells_inside(tgt_cells, h1, w1)
        tgt_cells, points = tgt_cells[inside], points[inside]
        lut_mat21[tgt_cells[:, 1], tgt_cells[:, 0]] = points
        points = lut_mat21[torch.all(lut_mat21 >= 0, dim=-1)]

        corrs.append(points)

    return corrs
def get_correspondences(pts2, data, idx, scale = 8):
    """Warp coarse points of image1 into image0 for one sample of a batch.

    Args:
        pts2 (torch.Tensor): points in image1 at coarse resolution, expected
            as [L, 2] <x, y>; they are multiplied by ``scale`` (and by
            ``data['scale1'][idx]`` when present) to reach the resolution
            that the depth maps and intrinsics are expressed in.
        data (dict): reads 'depth0', 'depth1', 'T_1to0', 'K0', 'K1' and,
            optionally, 'scale0' / 'scale1' (treated as 1 when absent).
        idx (int): index of the sample inside the batch.
        scale (int): coarse-to-image factor applied to ``pts2``.
    Returns:
        torch.Tensor: [L, 4] rows of <x0, y0, x1, y1>, each side divided by
        its per-sample scale factor. Rows are not filtered by the validity
        mask of ``warp_kpts``, so points without depth are still returned.
    """
    has_scale0 = 'scale0' in data
    has_scale1 = 'scale1' in data

    pts2 = pts2[None, ...]
    scale1 = data['scale1'][idx, None][None, ...] if has_scale1 else 1
    pts2 = scale1 * pts2 * scale

    # warp image1 points into image0; the validity mask is intentionally unused
    _, pts1 = warp_kpts(
        pts2,
        data['depth1'][idx][None, ...],
        data['depth0'][idx][None, ...],
        data['T_1to0'][idx][None, ...],
        data['K1'][idx][None, ...],
        data['K0'][idx][None, ...],
    )

    pts1_out = pts1[0, :]
    pts2_out = pts2[0, :]
    if has_scale0:
        pts1_out = pts1_out / data['scale0'][idx]
    if has_scale1:
        pts2_out = pts2_out / data['scale1'][idx]

    return torch.cat([pts1_out, pts2_out], dim=-1)