Registed model and added config
Browse files- register_model.py +12 -0
- sagvit_config.py +28 -0
register_model.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoConfig, AutoModel
|
| 2 |
+
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
| 3 |
+
from transformers.models.auto.modeling_auto import MODEL_MAPPING
|
| 4 |
+
|
| 5 |
+
from sagvit_config import SAGViTConfig
|
| 6 |
+
from sag_vit_model import SAGViTClassifier
|
| 7 |
+
|
| 8 |
+
# Register the configuration
|
| 9 |
+
CONFIG_MAPPING.register("sagvit", SAGViTConfig)
|
| 10 |
+
|
| 11 |
+
# Register the model
|
| 12 |
+
MODEL_MAPPING.register(SAGViTConfig, SAGViTClassifier)
|
sagvit_config.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class SAGViTConfig(PretrainedConfig):
|
| 4 |
+
model_type = "sagvit"
|
| 5 |
+
|
| 6 |
+
def __init__(self,
|
| 7 |
+
d_model=64,
|
| 8 |
+
dim_feedforward=64,
|
| 9 |
+
gcn_hidden=128,
|
| 10 |
+
gcn_out=64,
|
| 11 |
+
hidden_mlp_features=64,
|
| 12 |
+
in_channels=2560,
|
| 13 |
+
nhead=4,
|
| 14 |
+
num_classes=10,
|
| 15 |
+
num_layers=2,
|
| 16 |
+
patch_size=(4, 4),
|
| 17 |
+
**kwargs):
|
| 18 |
+
super().__init__(**kwargs)
|
| 19 |
+
self.d_model = d_model
|
| 20 |
+
self.dim_feedforward = dim_feedforward
|
| 21 |
+
self.gcn_hidden = gcn_hidden
|
| 22 |
+
self.gcn_out = gcn_out
|
| 23 |
+
self.hidden_mlp_features = hidden_mlp_features
|
| 24 |
+
self.in_channels = in_channels
|
| 25 |
+
self.nhead = nhead
|
| 26 |
+
self.num_classes = num_classes
|
| 27 |
+
self.num_layers = num_layers
|
| 28 |
+
self.patch_size = patch_size
|