shravvvv committed on
Commit
31a17dc
·
1 Parent(s): 039647a

Registered model and added config

Browse files
Files changed (2) hide show
  1. register_model.py +12 -0
  2. sagvit_config.py +28 -0
register_model.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModel
2
+ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
3
+ from transformers.models.auto.modeling_auto import MODEL_MAPPING
4
+
5
+ from sagvit_config import SAGViTConfig
6
+ from sag_vit_model import SAGViTClassifier
7
+
8
# Register the configuration and model with the Transformers Auto* factories.
# AutoConfig.register / AutoModel.register are the public registration API
# (imported above but previously unused); they populate the underlying
# CONFIG_MAPPING / MODEL_MAPPING and additionally validate that the config
# class's `model_type` matches the registered key — unlike writing to the
# private `transformers.models.auto.*` mapping objects directly, which can
# break across transformers releases.

# Register the configuration
AutoConfig.register("sagvit", SAGViTConfig)

# Register the model
AutoModel.register(SAGViTConfig, SAGViTClassifier)
sagvit_config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class SAGViTConfig(PretrainedConfig):
    """Configuration class for the SAG-ViT model (``model_type = "sagvit"``).

    Stores the hyperparameters presumably consumed by the SAG-ViT model
    implementation (see ``sag_vit_model``) and forwards any extra keyword
    arguments to :class:`~transformers.PretrainedConfig` so standard
    Transformers config machinery (serialization, ``from_pretrained``)
    keeps working.
    """

    model_type = "sagvit"

    def __init__(self,
                 d_model=64,
                 dim_feedforward=64,
                 gcn_hidden=128,
                 gcn_out=64,
                 hidden_mlp_features=64,
                 in_channels=2560,
                 nhead=4,
                 num_classes=10,
                 num_layers=2,
                 patch_size=(4, 4),
                 **kwargs):
        # Let the Transformers base class consume its own kwargs first.
        super().__init__(**kwargs)

        # Input / patch settings — semantics inferred from the names;
        # verify against sag_vit_model before relying on them.
        self.in_channels = in_channels
        self.patch_size = patch_size

        # Transformer-encoder hyperparameters.
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        self.dim_feedforward = dim_feedforward

        # Graph-convolution branch sizes.
        self.gcn_hidden = gcn_hidden
        self.gcn_out = gcn_out

        # Classification-head sizes.
        self.hidden_mlp_features = hidden_mlp_features
        self.num_classes = num_classes