Spaces:
Runtime error
Runtime error
| """ | |
| A two-view sparse feature matching pipeline. | |
| This model contains sub-models for each step: | |
| feature extraction, feature matching, outlier filtering, pose estimation. | |
| Each step is optional, and the features or matches can be provided as input. | |
| Default: SuperPoint with nearest neighbor matching. | |
| Convention for the matches: m0[i] is the index of the keypoint in image 1 | |
| that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched. | |
| """ | |
| import numpy as np | |
| import torch | |
| from gluestick import get_model | |
| from gluestick.models.base_model import BaseModel | |
| from line_matching.wireframe import SPWireframeDescriptor | |
def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
    """Keep only the keypoints lying in one randomly chosen image quadrant.

    The quadrant is drawn with ``np.random``. The indexing below assumes a
    single image (leading batch dimension of 1) -- NOTE(review): shapes
    inferred from the code: keypoints ``(1, N, 2)`` as (x, y), scores
    ``(1, N)``, descriptors ``(1, D, N)``.

    Args:
        keypoints: keypoint coordinates, x in ``[..., 0]`` and y in ``[..., 1]``.
        scores: per-keypoint scores, same leading shape as ``keypoints[..., 0]``.
        descs: descriptors with the keypoint axis last.
        h: image height in pixels.
        w: image width in pixels.

    Returns:
        ``(keypoints, scores, descs)`` restricted to the chosen quadrant, each
        re-wrapped with a leading batch dimension of 1.
    """
    h2, w2 = h // 2, w // 2
    w_x = np.random.choice([0, w2])
    w_y = np.random.choice([0, h2])
    # The right/bottom quadrants run up to the image border; bounding them at
    # ``2 * w2`` / ``2 * h2`` would drop the last column/row for odd sizes.
    x_end = w if w_x == w2 else w2
    y_end = h if w_y == h2 else h2
    x = keypoints[..., 0]
    y = keypoints[..., 1]
    valid_mask = (x >= w_x) & (x < x_end) & (y >= w_y) & (y < y_end)
    keypoints = keypoints[valid_mask][None]
    scores = scores[valid_mask][None]
    # Move the keypoint axis next to the batch axis so the mask applies, then
    # restore the (batch, dim, keypoints) layout.
    descs = descs.permute(0, 2, 1)[valid_mask].t()[None]
    return keypoints, scores, descs
def keep_random_kp_subset(keypoints, scores, descs, num_selected):
    """Keep a random subset of at most ``num_selected`` keypoints.

    One permutation from ``torch.randperm`` is drawn and shared by the whole
    batch, so the same keypoint indices are kept for every batch element.
    Keypoints and scores are indexed along dim 1, descriptors along dim 2.

    Returns:
        ``(keypoints, scores, descs)`` restricted to the chosen indices, in
        the permuted (random) order.
    """
    chosen = torch.randperm(keypoints.shape[1])[:num_selected]
    return keypoints[:, chosen], scores[:, chosen], descs[:, :, chosen]
def keep_best_kp_subset(keypoints, scores, descs, num_selected):
    """Keep the ``num_selected`` highest-scoring keypoints.

    Selection is done independently for each batch element. The kept
    keypoints are returned in ascending score order (``torch.sort`` default).
    Keypoints are expected as ``(B, N, 2)``, scores as ``(B, N)`` and
    descriptors as ``(B, D, N)``, as implied by the gather indices below.

    Args:
        keypoints: keypoint coordinates.
        scores: per-keypoint scores used for ranking.
        descs: descriptors with the keypoint axis last.
        num_selected: number of keypoints to keep. 0 keeps none; values
            larger than N keep all of them.

    Returns:
        ``(keypoints, scores, descs)`` restricted to the selected keypoints.
    """
    num_kp = scores.shape[1]
    # Use an explicit start index: ``[:, -num_selected:]`` would keep every
    # keypoint when ``num_selected == 0`` because ``-0 == 0``.
    start = max(num_kp - num_selected, 0)
    sorted_indices = torch.sort(scores, dim=1)[1]
    selected_kp = sorted_indices[:, start:]
    keypoints = torch.gather(keypoints, 1,
                             selected_kp[:, :, None].repeat(1, 1, 2))
    scores = torch.gather(scores, 1, selected_kp)
    descs = torch.gather(descs, 2,
                         selected_kp[:, None].repeat(1, descs.shape[1], 1))
    return keypoints, scores, descs
class TwoViewPipeline(BaseModel):
    """Two-view sparse matching pipeline assembled from optional sub-models.

    Stages run in order: feature extraction (either one ``extractor``, or a
    ``detector`` followed by a ``descriptor``), then ``matcher``, ``filter``
    and ``solver``. A stage is only built and run when its ``name`` is set in
    the config. Otherwise precomputed values are read from the input data when
    present, and ``matches0`` becomes a required key if no matcher is
    configured.
    """
    default_conf = {
        'extractor': {
            'name': 'superpoint',
            'trainable': False,
        },
        # NOTE(review): the three flags below are not read anywhere in this
        # class; confirm whether BaseModel or callers rely on them.
        'use_lines': False,
        'use_points': True,
        'randomize_num_kp': False,
        'detector': {'name': None},
        'descriptor': {'name': None},
        'matcher': {'name': 'nearest_neighbor_matcher'},
        'filter': {'name': None},
        'solver': {'name': None},
        'ground_truth': {
            'from_pose_depth': False,
            'from_homography': False,
            'th_positive': 3,
            'th_negative': 5,
            'reward_positive': 1,
            'reward_negative': -0.25,
            'is_likelihood_soft': True,
            'p_random_occluders': 0,
            'n_line_sampled_pts': 50,
            'line_perp_dist_th': 5,
            'overlap_th': 0.2,
            'min_visibility_th': 0.5
        },
    }
    required_data_keys = ['image0', 'image1']
    strict_conf = False  # need to pass new confs to children models
    components = [
        'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver']

    def _init(self, conf):
        """Instantiate the sub-models whose ``name`` is set in ``conf``."""
        # Work on a per-instance copy: ``+=`` on the inherited class attribute
        # extends the list shared by every TwoViewPipeline instance.
        self.required_data_keys = list(self.required_data_keys)
        if conf.extractor.name:
            # NOTE(review): the extractor name does not select the model; any
            # non-empty name builds SPWireframeDescriptor.
            self.extractor = SPWireframeDescriptor(conf.extractor)
        else:
            # _forward only uses the detector/descriptor when no extractor is
            # configured, so they are only built in that case.
            if conf.detector.name:
                self.detector = get_model(conf.detector.name)(conf.detector)
            if conf.descriptor.name:
                self.descriptor = get_model(conf.descriptor.name)(
                    conf.descriptor)
        if conf.matcher.name:
            self.matcher = get_model(conf.matcher.name)(conf.matcher)
        else:
            self.required_data_keys.append('matches0')
        if conf.filter.name:
            self.filter = get_model(conf.filter.name)(conf.filter)
        if conf.solver.name:
            self.solver = get_model(conf.solver.name)(conf.solver)

    def _forward(self, data):
        """Run feature extraction on both images, then the pair-wise stages.

        Per-image predictions are merged back into one dict with a '0' / '1'
        suffix appended to each key. Each later stage receives
        ``{**data, **pred}`` and its outputs are merged into ``pred``.
        """
        def process_siamese(data, i):
            # Select the entries of image ``i`` (keys ending in '0' or '1')
            # and strip that suffix, e.g. 'image0' -> 'image'.
            data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i}
            if self.conf.extractor.name:
                pred_i = self.extractor(data_i)
            else:
                pred_i = {}
                if self.conf.detector.name:
                    pred_i = self.detector(data_i)
                else:
                    # No detector: pass through precomputed features, if any.
                    for k in ['keypoints', 'keypoint_scores', 'descriptors',
                              'lines', 'line_scores', 'line_descriptors',
                              'valid_lines']:
                        if k in data_i:
                            pred_i[k] = data_i[k]
                if self.conf.descriptor.name:
                    pred_i = {
                        **pred_i, **self.descriptor({**data_i, **pred_i})}
            return pred_i

        pred0 = process_siamese(data, '0')
        pred1 = process_siamese(data, '1')
        pred = {**{k + '0': v for k, v in pred0.items()},
                **{k + '1': v for k, v in pred1.items()}}
        if self.conf.matcher.name:
            pred = {**pred, **self.matcher({**data, **pred})}
        if self.conf.filter.name:
            pred = {**pred, **self.filter({**data, **pred})}
        if self.conf.solver.name:
            pred = {**pred, **self.solver({**data, **pred})}
        return pred

    def _built_components(self):
        """Yield each configured sub-model that was actually instantiated."""
        for k in self.components:
            if not self.conf[k].name:
                continue
            model = getattr(self, k, None)
            # A named component may be unbuilt (e.g. a detector name while an
            # extractor is configured); it is unused in _forward, so skip it.
            if model is not None:
                yield model

    def loss(self, pred, data):
        """Sum the losses of all built sub-models that implement ``loss``.

        Sub-models that raise NotImplementedError are skipped. Returns the
        merged per-model loss dict with ``'total'`` set to the summed totals.
        """
        losses = {}
        total = 0
        for model in self._built_components():
            try:
                losses_ = model.loss(pred, {**pred, **data})
            except NotImplementedError:
                continue
            losses = {**losses, **losses_}
            total = losses_['total'] + total
        return {**losses, 'total': total}

    def metrics(self, pred, data):
        """Merge the metrics of all built sub-models that implement them.

        Sub-models that raise NotImplementedError are skipped.
        """
        metrics = {}
        for model in self._built_components():
            try:
                metrics_ = model.metrics(pred, {**pred, **data})
            except NotImplementedError:
                continue
            metrics = {**metrics, **metrics_}
        return metrics