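# OpenCLIP-backed vision tower. `open_clip` and `torchvision` are optional imports: if they
# are missing, `open_clip` is set to None here and loading a tower will fail later.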
import torch
import torch.nn as nn
from transformers import CLIPImageProcessor

try:
    import open_clip
    import torchvision
    from open_clip.transformer import _expand_token
except ImportError:
    print("OpenCLIP not installed")
    open_clip = None

HIDDEN_SIZE_DICT = {
    "ViT-H-14-378-quickgelu": 1280,
}
class OpenCLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.model_name = vision_tower.replace("open_clip_hub:", "")
        self.pretrained = args.vision_tower_pretrained
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        if not delay_load:
            print(f"Loading vision tower: {vision_tower}")
            self.load_model()
        elif getattr(args, "unfreeze_mm_vision_tower", False):
            # TODO: better detector is needed.
            print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
            print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
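
    # load_model instantiates the OpenCLIP backbone, derives the input resolution and patch
    # size from its preprocessing transforms, builds an equivalent CLIPImageProcessor, and
    # freezes the visual encoder.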
    def load_model(self, device_map="auto"):
        print(f"Loading OpenCLIP model: {self.model_name}")
        print(f"Pretrained: {self.pretrained}")
        vision_tower, _, image_processor = open_clip.create_model_and_transforms(
            model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda"
        )

        resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
        normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
        self.resize_transform_size = resize_transform.size  # 224 or 384
        self.patch_size = vision_tower.visual.conv1.kernel_size[0]  # 14 or 16

        self.image_processor = CLIPImageProcessor.from_pretrained(
            "openai/clip-vit-large-patch14",
            crop_size=resize_transform.size,
            size={"shortest_edge": resize_transform.size},
            image_mean=list(normalize_transform.mean),
            image_std=list(normalize_transform.std),
        )
        print(f"Loaded image processor: {self.image_processor}")

        self.vision_tower = vision_tower.visual
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True
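
    # feature_select picks the requested layer from the per-layer outputs and slices it
    # according to `select_feature`: "patch" drops the leading class token, "cls_patch"
    # keeps everything, "conv_flatten" flattens the spatial dimensions.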
    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs[self.select_layer]
        if self.select_feature == "patch":
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        elif self.select_feature == "conv_flatten":
            image_features = image_features.flatten(2).transpose(1, 2)
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features
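
    # forward_visual returns the per-layer hidden states. timm-based towers expose
    # `trunk._intermediate_layers`; for a plain OpenCLIP VisionTransformer the patch
    # embedding, class/positional embeddings, and residual blocks are replayed manually so
    # that every block's output can be collected.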
    def forward_visual(self, x, output_hidden_states=False):
        if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"):
            return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer))
        else:

            def forward_openclip(self, x: torch.Tensor):
                features = []
                x = self.conv1(x)  # shape = [*, width, grid, grid]
                x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
                x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

                # class embeddings and positional embeddings
                x = torch.cat(
                    [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
                    dim=1,
                )
                # shape = [*, grid ** 2 + 1, width]
                x = x + self.positional_embedding.to(x.dtype)

                x = self.patch_dropout(x)
                x = self.ln_pre(x)
                x = x.permute(1, 0, 2)  # NLD -> LND
                for r in self.transformer.resblocks:
                    x = r(x, attn_mask=None)
                    features.append(x)
                return features

            return forward_openclip(self.vision_tower, x)
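
    # forward accepts either a batched image tensor or a list of single images (e.g. of
    # varying resolution); selected features are cast back to the input dtype.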
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.forward_visual(images.to(self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features
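
    # The remaining accessors are read-only properties, so callers can use `self.dtype`,
    # `self.device`, `self.hidden_size`, etc. as plain attributes (as `forward` above does).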
    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        if hasattr(self.vision_tower, "conv1"):
            return self.vision_tower.conv1.weight.dtype
        if hasattr(self.vision_tower, "trunk"):
            return self.vision_tower.trunk.patch_embed.proj.weight.dtype
        raise NotImplementedError

    @property
    def device(self):
        if hasattr(self.vision_tower, "conv1"):
            return self.vision_tower.conv1.weight.device
        if hasattr(self.vision_tower, "trunk"):
            return self.vision_tower.trunk.patch_embed.proj.weight.device
        raise NotImplementedError

    @property
    def config(self):
        return None

    @property
    def hidden_size(self):
        if self.model_name in HIDDEN_SIZE_DICT:
            return HIDDEN_SIZE_DICT[self.model_name]
        else:
            raise NotImplementedError

    @property
    def num_patches(self):
        image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0]
        _num_patches = (image_size // self.patch_size) ** 2
        if "cls_patch" in self.select_feature:
            _num_patches += 1
        return _num_patches

    @property
    def image_size(self):
        return self.resize_transform_size

    @property
    def num_patches_per_side(self):
        return self.resize_transform_size // self.patch_size
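
# Minimal usage sketch (not part of the original module; the pretrained tag and the `args`
# fields below are illustrative assumptions). `load_model` places the encoder on "cuda",
# so this only runs on a GPU machine with open_clip and torchvision installed.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        vision_tower_pretrained="dfn5b",   # assumed OpenCLIP pretrained tag for this model
        mm_vision_select_layer=-2,         # take features from the second-to-last block
        mm_vision_select_feature="patch",  # drop the leading class token
    )
    tower = OpenCLIPVisionTower("open_clip_hub:ViT-H-14-378-quickgelu", args, delay_load=True)
    tower.load_model()

    # Feed one random image at the tower's native resolution and inspect the output shape.
    side = tower.image_size if isinstance(tower.image_size, int) else tower.image_size[0]
    dummy = torch.randn(1, 3, side, side, device=tower.device, dtype=tower.dtype)
    with torch.no_grad():
        features = tower(dummy)
    print(features.shape)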