Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Implementation of the line matching methods. | |
| """ | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import torch.nn.functional as F | |
| from ..misc.geometry_utils import keypoints_to_grid | |
| class WunschLineMatcher(object): | |
| """ Class matching two sets of line segments | |
| with the Needleman-Wunsch algorithm. """ | |
| def __init__(self, cross_check=True, num_samples=10, min_dist_pts=8, | |
| top_k_candidates=10, grid_size=8, sampling="regular", | |
| line_score=False): | |
| self.cross_check = cross_check | |
| self.num_samples = num_samples | |
| self.min_dist_pts = min_dist_pts | |
| self.top_k_candidates = top_k_candidates | |
| self.grid_size = grid_size | |
| self.line_score = line_score # True to compute saliency on a line | |
| self.sampling_mode = sampling | |
| if sampling not in ["regular", "d2_net", "asl_feat"]: | |
| raise ValueError("Wrong sampling mode: " + sampling) | |
| def forward(self, line_seg1, line_seg2, desc1, desc2): | |
| """ | |
| Find the best matches between two sets of line segments | |
| and their corresponding descriptors. | |
| """ | |
| img_size1 = (desc1.shape[2] * self.grid_size, | |
| desc1.shape[3] * self.grid_size) | |
| img_size2 = (desc2.shape[2] * self.grid_size, | |
| desc2.shape[3] * self.grid_size) | |
| device = desc1.device | |
| # Default case when an image has no lines | |
| if len(line_seg1) == 0: | |
| return np.empty((0), dtype=int) | |
| if len(line_seg2) == 0: | |
| return -np.ones(len(line_seg1), dtype=int) | |
| # Sample points regularly along each line | |
| if self.sampling_mode == "regular": | |
| line_points1, valid_points1 = self.sample_line_points(line_seg1) | |
| line_points2, valid_points2 = self.sample_line_points(line_seg2) | |
| else: | |
| line_points1, valid_points1 = self.sample_salient_points( | |
| line_seg1, desc1, img_size1, self.sampling_mode) | |
| line_points2, valid_points2 = self.sample_salient_points( | |
| line_seg2, desc2, img_size2, self.sampling_mode) | |
| line_points1 = torch.tensor(line_points1.reshape(-1, 2), | |
| dtype=torch.float, device=device) | |
| line_points2 = torch.tensor(line_points2.reshape(-1, 2), | |
| dtype=torch.float, device=device) | |
| # Extract the descriptors for each point | |
| grid1 = keypoints_to_grid(line_points1, img_size1) | |
| grid2 = keypoints_to_grid(line_points2, img_size2) | |
| desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0) | |
| desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0) | |
| # Precompute the distance between line points for every pair of lines | |
| # Assign a score of -1 for unvalid points | |
| scores = desc1.t() @ desc2 | |
| scores[~valid_points1.flatten()] = -1 | |
| scores[:, ~valid_points2.flatten()] = -1 | |
| scores = scores.reshape(len(line_seg1), self.num_samples, | |
| len(line_seg2), self.num_samples) | |
| scores = scores.permute(0, 2, 1, 3) | |
| # scores.shape = (n_lines1, n_lines2, num_samples, num_samples) | |
| # Pre-filter the line candidates and find the best match for each line | |
| matches = self.filter_and_match_lines(scores) | |
| # [Optionally] filter matches with mutual nearest neighbor filtering | |
| if self.cross_check: | |
| matches2 = self.filter_and_match_lines( | |
| scores.permute(1, 0, 3, 2)) | |
| mutual = matches2[matches] == np.arange(len(line_seg1)) | |
| matches[~mutual] = -1 | |
| return matches | |
| def d2_net_saliency_score(self, desc): | |
| """ Compute the D2-Net saliency score | |
| on a 3D or 4D descriptor. """ | |
| is_3d = len(desc.shape) == 3 | |
| b_size = len(desc) | |
| feat = F.relu(desc) | |
| # Compute the soft local max | |
| exp = torch.exp(feat) | |
| if is_3d: | |
| sum_exp = 3 * F.avg_pool1d(exp, kernel_size=3, stride=1, | |
| padding=1) | |
| else: | |
| sum_exp = 9 * F.avg_pool2d(exp, kernel_size=3, stride=1, | |
| padding=1) | |
| soft_local_max = exp / sum_exp | |
| # Compute the depth-wise maximum | |
| depth_wise_max = torch.max(feat, dim=1)[0] | |
| depth_wise_max = feat / depth_wise_max.unsqueeze(1) | |
| # Total saliency score | |
| score = torch.max(soft_local_max * depth_wise_max, dim=1)[0] | |
| normalization = torch.sum(score.reshape(b_size, -1), dim=1) | |
| if is_3d: | |
| normalization = normalization.reshape(b_size, 1) | |
| else: | |
| normalization = normalization.reshape(b_size, 1, 1) | |
| score = score / normalization | |
| return score | |
| def asl_feat_saliency_score(self, desc): | |
| """ Compute the ASLFeat saliency score on a 3D or 4D descriptor. """ | |
| is_3d = len(desc.shape) == 3 | |
| b_size = len(desc) | |
| # Compute the soft local peakiness | |
| if is_3d: | |
| local_avg = F.avg_pool1d(desc, kernel_size=3, stride=1, padding=1) | |
| else: | |
| local_avg = F.avg_pool2d(desc, kernel_size=3, stride=1, padding=1) | |
| soft_local_score = F.softplus(desc - local_avg) | |
| # Compute the depth-wise peakiness | |
| depth_wise_mean = torch.mean(desc, dim=1).unsqueeze(1) | |
| depth_wise_score = F.softplus(desc - depth_wise_mean) | |
| # Total saliency score | |
| score = torch.max(soft_local_score * depth_wise_score, dim=1)[0] | |
| normalization = torch.sum(score.reshape(b_size, -1), dim=1) | |
| if is_3d: | |
| normalization = normalization.reshape(b_size, 1) | |
| else: | |
| normalization = normalization.reshape(b_size, 1, 1) | |
| score = score / normalization | |
| return score | |
| def sample_salient_points(self, line_seg, desc, img_size, | |
| saliency_type='d2_net'): | |
| """ | |
| Sample the most salient points along each line segments, with a | |
| minimal distance between each point. Pad the remaining points. | |
| Inputs: | |
| line_seg: an Nx2x2 torch.Tensor. | |
| desc: a NxDxHxW torch.Tensor. | |
| image_size: the original image size. | |
| saliency_type: 'd2_net' or 'asl_feat'. | |
| Outputs: | |
| line_points: an Nxnum_samplesx2 np.array. | |
| valid_points: a boolean Nxnum_samples np.array. | |
| """ | |
| device = desc.device | |
| if not self.line_score: | |
| # Compute the score map | |
| if saliency_type == "d2_net": | |
| score = self.d2_net_saliency_score(desc) | |
| else: | |
| score = self.asl_feat_saliency_score(desc) | |
| num_lines = len(line_seg) | |
| line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1) | |
| # The number of samples depends on the length of the line | |
| num_samples_lst = np.clip(line_lengths // self.min_dist_pts, | |
| 2, self.num_samples) | |
| line_points = np.empty((num_lines, self.num_samples, 2), dtype=float) | |
| valid_points = np.empty((num_lines, self.num_samples), dtype=bool) | |
| # Sample the score on a fixed number of points of each line | |
| n_samples_per_region = 4 | |
| for n in np.arange(2, self.num_samples + 1): | |
| sample_rate = n * n_samples_per_region | |
| # Consider all lines where we can fit up to n points | |
| cur_mask = num_samples_lst == n | |
| cur_line_seg = line_seg[cur_mask] | |
| cur_num_lines = len(cur_line_seg) | |
| if cur_num_lines == 0: | |
| continue | |
| line_points_x = np.linspace(cur_line_seg[:, 0, 0], | |
| cur_line_seg[:, 1, 0], | |
| sample_rate, axis=-1) | |
| line_points_y = np.linspace(cur_line_seg[:, 0, 1], | |
| cur_line_seg[:, 1, 1], | |
| sample_rate, axis=-1) | |
| cur_line_points = np.stack([line_points_x, line_points_y], | |
| axis=-1).reshape(-1, 2) | |
| # cur_line_points is of shape (n_cur_lines * sample_rate, 2) | |
| cur_line_points = torch.tensor(cur_line_points, dtype=torch.float, | |
| device=device) | |
| grid_points = keypoints_to_grid(cur_line_points, img_size) | |
| if self.line_score: | |
| # The saliency score is high when the activation are locally | |
| # maximal along the line (and not in a square neigborhood) | |
| line_desc = F.grid_sample(desc, grid_points).squeeze() | |
| line_desc = line_desc.reshape(-1, cur_num_lines, sample_rate) | |
| line_desc = line_desc.permute(1, 0, 2) | |
| if saliency_type == "d2_net": | |
| scores = self.d2_net_saliency_score(line_desc) | |
| else: | |
| scores = self.asl_feat_saliency_score(line_desc) | |
| else: | |
| scores = F.grid_sample(score.unsqueeze(1), | |
| grid_points).squeeze() | |
| # Take the most salient point in n distinct regions | |
| scores = scores.reshape(-1, n, n_samples_per_region) | |
| best = torch.max(scores, dim=2, keepdim=True)[1].cpu().numpy() | |
| cur_line_points = cur_line_points.reshape(-1, n, | |
| n_samples_per_region, 2) | |
| cur_line_points = np.take_along_axis( | |
| cur_line_points, best[..., None], axis=2)[:, :, 0] | |
| # Pad | |
| cur_valid_points = np.ones((cur_num_lines, self.num_samples), | |
| dtype=bool) | |
| cur_valid_points[:, n:] = False | |
| cur_line_points = np.concatenate([ | |
| cur_line_points, | |
| np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)], | |
| axis=1) | |
| line_points[cur_mask] = cur_line_points | |
| valid_points[cur_mask] = cur_valid_points | |
| return line_points, valid_points | |
| def sample_line_points(self, line_seg): | |
| """ | |
| Regularly sample points along each line segments, with a minimal | |
| distance between each point. Pad the remaining points. | |
| Inputs: | |
| line_seg: an Nx2x2 torch.Tensor. | |
| Outputs: | |
| line_points: an Nxnum_samplesx2 np.array. | |
| valid_points: a boolean Nxnum_samples np.array. | |
| """ | |
| num_lines = len(line_seg) | |
| line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1) | |
| # Sample the points separated by at least min_dist_pts along each line | |
| # The number of samples depends on the length of the line | |
| num_samples_lst = np.clip(line_lengths // self.min_dist_pts, | |
| 2, self.num_samples) | |
| line_points = np.empty((num_lines, self.num_samples, 2), dtype=float) | |
| valid_points = np.empty((num_lines, self.num_samples), dtype=bool) | |
| for n in np.arange(2, self.num_samples + 1): | |
| # Consider all lines where we can fit up to n points | |
| cur_mask = num_samples_lst == n | |
| cur_line_seg = line_seg[cur_mask] | |
| line_points_x = np.linspace(cur_line_seg[:, 0, 0], | |
| cur_line_seg[:, 1, 0], | |
| n, axis=-1) | |
| line_points_y = np.linspace(cur_line_seg[:, 0, 1], | |
| cur_line_seg[:, 1, 1], | |
| n, axis=-1) | |
| cur_line_points = np.stack([line_points_x, line_points_y], axis=-1) | |
| # Pad | |
| cur_num_lines = len(cur_line_seg) | |
| cur_valid_points = np.ones((cur_num_lines, self.num_samples), | |
| dtype=bool) | |
| cur_valid_points[:, n:] = False | |
| cur_line_points = np.concatenate([ | |
| cur_line_points, | |
| np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)], | |
| axis=1) | |
| line_points[cur_mask] = cur_line_points | |
| valid_points[cur_mask] = cur_valid_points | |
| return line_points, valid_points | |
| def filter_and_match_lines(self, scores): | |
| """ | |
| Use the scores to keep the top k best lines, compute the Needleman- | |
| Wunsch algorithm on each candidate pairs, and keep the highest score. | |
| Inputs: | |
| scores: a (N, M, n, n) torch.Tensor containing the pairwise scores | |
| of the elements to match. | |
| Outputs: | |
| matches: a (N) np.array containing the indices of the best match | |
| """ | |
| # Pre-filter the pairs and keep the top k best candidate lines | |
| line_scores1 = scores.max(3)[0] | |
| valid_scores1 = line_scores1 != -1 | |
| line_scores1 = ((line_scores1 * valid_scores1).sum(2) | |
| / valid_scores1.sum(2)) | |
| line_scores2 = scores.max(2)[0] | |
| valid_scores2 = line_scores2 != -1 | |
| line_scores2 = ((line_scores2 * valid_scores2).sum(2) | |
| / valid_scores2.sum(2)) | |
| line_scores = (line_scores1 + line_scores2) / 2 | |
| topk_lines = torch.argsort(line_scores, | |
| dim=1)[:, -self.top_k_candidates:] | |
| scores, topk_lines = scores.cpu().numpy(), topk_lines.cpu().numpy() | |
| # topk_lines.shape = (n_lines1, top_k_candidates) | |
| top_scores = np.take_along_axis(scores, topk_lines[:, :, None, None], | |
| axis=1) | |
| # Consider the reversed line segments as well | |
| top_scores = np.concatenate([top_scores, top_scores[..., ::-1]], | |
| axis=1) | |
| # Compute the line distance matrix with Needleman-Wunsch algo and | |
| # retrieve the closest line neighbor | |
| n_lines1, top2k, n, m = top_scores.shape | |
| top_scores = top_scores.reshape(n_lines1 * top2k, n, m) | |
| nw_scores = self.needleman_wunsch(top_scores) | |
| nw_scores = nw_scores.reshape(n_lines1, top2k) | |
| matches = np.mod(np.argmax(nw_scores, axis=1), top2k // 2) | |
| matches = topk_lines[np.arange(n_lines1), matches] | |
| return matches | |
| def needleman_wunsch(self, scores): | |
| """ | |
| Batched implementation of the Needleman-Wunsch algorithm. | |
| The cost of the InDel operation is set to 0 by subtracting the gap | |
| penalty to the scores. | |
| Inputs: | |
| scores: a (B, N, M) np.array containing the pairwise scores | |
| of the elements to match. | |
| """ | |
| b, n, m = scores.shape | |
| # Recalibrate the scores to get a gap score of 0 | |
| gap = 0.1 | |
| nw_scores = scores - gap | |
| # Run the dynamic programming algorithm | |
| nw_grid = np.zeros((b, n + 1, m + 1), dtype=float) | |
| for i in range(n): | |
| for j in range(m): | |
| nw_grid[:, i + 1, j + 1] = np.maximum( | |
| np.maximum(nw_grid[:, i + 1, j], nw_grid[:, i, j + 1]), | |
| nw_grid[:, i, j] + nw_scores[:, i, j]) | |
| return nw_grid[:, -1, -1] | |
| def get_pairwise_distance(self, line_seg1, line_seg2, desc1, desc2): | |
| """ | |
| Compute the OPPOSITE of the NW score for pairs of line segments | |
| and their corresponding descriptors. | |
| """ | |
| num_lines = len(line_seg1) | |
| assert num_lines == len(line_seg2), "The same number of lines is required in pairwise score." | |
| img_size1 = (desc1.shape[2] * self.grid_size, | |
| desc1.shape[3] * self.grid_size) | |
| img_size2 = (desc2.shape[2] * self.grid_size, | |
| desc2.shape[3] * self.grid_size) | |
| device = desc1.device | |
| # Sample points regularly along each line | |
| line_points1, valid_points1 = self.sample_line_points(line_seg1) | |
| line_points2, valid_points2 = self.sample_line_points(line_seg2) | |
| line_points1 = torch.tensor(line_points1.reshape(-1, 2), | |
| dtype=torch.float, device=device) | |
| line_points2 = torch.tensor(line_points2.reshape(-1, 2), | |
| dtype=torch.float, device=device) | |
| # Extract the descriptors for each point | |
| grid1 = keypoints_to_grid(line_points1, img_size1) | |
| grid2 = keypoints_to_grid(line_points2, img_size2) | |
| desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0) | |
| desc1 = desc1.reshape(-1, num_lines, self.num_samples) | |
| desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0) | |
| desc2 = desc2.reshape(-1, num_lines, self.num_samples) | |
| # Compute the distance between line points for every pair of lines | |
| # Assign a score of -1 for unvalid points | |
| scores = torch.einsum('dns,dnt->nst', desc1, desc2).cpu().numpy() | |
| scores = scores.reshape(num_lines * self.num_samples, | |
| self.num_samples) | |
| scores[~valid_points1.flatten()] = -1 | |
| scores = scores.reshape(num_lines, self.num_samples, self.num_samples) | |
| scores = scores.transpose(1, 0, 2).reshape(self.num_samples, -1) | |
| scores[:, ~valid_points2.flatten()] = -1 | |
| scores = scores.reshape(self.num_samples, num_lines, self.num_samples) | |
| scores = scores.transpose(1, 0, 2) | |
| # scores.shape = (num_lines, num_samples, num_samples) | |
| # Compute the NW score for each pair of lines | |
| pairwise_scores = np.array([self.needleman_wunsch(s) for s in scores]) | |
| return -pairwise_scores | |