""" Model Factory for MapAnything """ import importlib.util import logging import warnings import numpy as np from omegaconf import DictConfig, OmegaConf # Core models that are always available from mapanything.models.mapanything import ( MapAnything, MapAnythingAblations, ModularDUSt3R, ) # Suppress DINOv2 warnings logging.getLogger("dinov2").setLevel(logging.WARNING) warnings.filterwarnings("ignore", message="xFormers is available", category=UserWarning) warnings.filterwarnings( "ignore", message="xFormers is not available", category=UserWarning ) def resolve_special_float(value): if value == "inf": return np.inf elif value == "-inf": return -np.inf else: raise ValueError(f"Unknown special float value: {value}") def init_model( model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False ): """ Initialize a model using OmegaConf configuration. Args: model_str (str): Name of the model class to create. model_config (DictConfig): OmegaConf model configuration. torch_hub_force_reload (bool): Whether to force reload relevant parts of the model from torch hub. """ if not OmegaConf.has_resolver("special_float"): OmegaConf.register_new_resolver("special_float", resolve_special_float) model_dict = OmegaConf.to_container(model_config, resolve=True) model = model_factory( model_str, torch_hub_force_reload=torch_hub_force_reload, **model_dict ) return model # Define model configurations with import paths MODEL_CONFIGS = { # Core models "mapanything": { "class": MapAnything, }, "mapanything_ablations": { "class": MapAnythingAblations, }, "modular_dust3r": { "class": ModularDUSt3R, }, # External models "anycalib": { "module": "mapanything.models.external.anycalib", "class_name": "AnyCalibWrapper", }, "dust3r": { "module": "mapanything.models.external.dust3r", "class_name": "DUSt3RBAWrapper", }, "mast3r": { "module": "mapanything.models.external.mast3r", "class_name": "MASt3RSGAWrapper", }, "moge": { "module": "mapanything.models.external.moge", "class_name": "MoGeWrapper", }, "must3r": { "module": "mapanything.models.external.must3r", "class_name": "MUSt3RWrapper", }, "pi3": { "module": "mapanything.models.external.pi3", "class_name": "Pi3Wrapper", }, "pow3r": { "module": "mapanything.models.external.pow3r", "class_name": "Pow3RWrapper", }, "pow3r_ba": { "module": "mapanything.models.external.pow3r", "class_name": "Pow3RBAWrapper", }, "vggt": { "module": "mapanything.models.external.vggt", "class_name": "VGGTWrapper", }, # Add other model classes here } def check_module_exists(module_path): """ Check if a module can be imported without actually importing it. Args: module_path (str): The path to the module to check. Returns: bool: True if the module can be imported, False otherwise. """ return importlib.util.find_spec(module_path) is not None def model_factory(model_str: str, **kwargs): """ Model factory for MapAnything. Args: model_str (str): Name of the model to create. **kwargs: Additional keyword arguments to pass to the model constructor. Returns: nn.Module: An instance of the specified model. """ if model_str not in MODEL_CONFIGS: raise ValueError( f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}" ) model_config = MODEL_CONFIGS[model_str] # Handle core models directly if "class" in model_config: model_class = model_config["class"] # Handle external models with dynamic imports elif "module" in model_config: module_path = model_config["module"] class_name = model_config["class_name"] # Check if the module can be imported if not check_module_exists(module_path): raise ImportError( f"Model '{model_str}' requires module '{module_path}' which is not installed. " f"Please install the corresponding submodule or package." ) # Dynamically import the module and get the class try: module = importlib.import_module(module_path) model_class = getattr(module, class_name) except (ImportError, AttributeError) as e: raise ImportError( f"Failed to import {class_name} from {module_path}: {str(e)}" ) else: raise ValueError(f"Invalid model configuration for {model_str}") print(f"Initializing {model_class} with kwargs: {kwargs}") if model_str != "org_dust3r": return model_class(**kwargs) else: eval_str = kwargs.get("model_eval_str", None) return eval(eval_str) def get_available_models() -> list: """ Get a list of available models in MapAnything. Returns: list: A list of available model names. """ return list(MODEL_CONFIGS.keys()) __all__ = ["model_factory", "get_available_models"]