sergio-sanz-rodriguez committed on
Commit
f1dd2a9
·
1 Parent(s): aeaa073

updated app, with the new algorithm using two ViTs

Browse files
Files changed (3) hide show
  1. app.py +110 -39
  2. class_names.txt +2 -2
  3. food_descriptions.json +2 -2
app.py CHANGED
@@ -13,10 +13,13 @@ from torchvision.transforms import v2
13
  # Specify class names
14
  food_vision_class_names_path = "class_names.txt"
15
  with open(food_vision_class_names_path, "r") as f:
16
- class_names = f.read().splitlines()
 
 
17
 
18
  # Specify number of classes
19
- num_classes = len(class_names) - 1 # 101, "unknown" to be discarded
 
20
 
21
  # Load the food description file
22
  food_descriptions_json = "food_descriptions.json"
@@ -32,18 +35,26 @@ effnetb0_model = create_effnetb0(
32
  compile=True
33
  )
34
 
35
- # Load the ViT-Base/16 transformer with input image of 384x384 pixels
36
- vitbase_model = create_vitbase_model(
 
 
 
 
 
 
 
 
37
  model_weights_dir=".",
38
  model_weights_name="vitbase16_2_2024-12-31.pth",
39
  img_size=384,
40
- num_classes=num_classes,
41
  compile=True
42
  )
43
 
44
  # Specify manual transforms for model_2
45
  transforms = v2.Compose([
46
- v2.Resize(384), #v2.Resize((384, 384)),
47
  v2.CenterCrop((384, 384)),
48
  v2.ToImage(),
49
  v2.ToDtype(torch.float32, scale=True),
@@ -51,66 +62,126 @@ transforms = v2.Compose([
51
  std=[0.229, 0.224, 0.225])
52
  ])
53
 
 
54
  # Put models into evaluation mode and turn on inference mode
55
  effnetb0_model.eval()
56
- vitbase_model.eval()
 
57
 
58
  # Set thresdholds
59
  BINARY_CLASSIF_THR = 0.9989122152328491
60
  MULTICLASS_CLASSIF_THR = 0.5
61
  ENTROPY_THR = 2.6
62
 
63
- # Predict function
64
- def predict(image) -> Tuple[Dict, str, str]:
 
 
 
 
 
 
 
65
 
66
  """Transforms and performs a prediction on image and returns prediction and time taken.
67
  """
68
  try:
69
  # Start the timer
70
  start_time = timer()
71
-
72
  # Transform the target image and add a batch dimension
73
  image = transforms(image).unsqueeze(0)
74
 
75
  # Make prediction...
76
  with torch.inference_mode():
77
-
78
  # If the picture is food
79
  if effnetb0_model(image)[:,1].cpu() >= BINARY_CLASSIF_THR:
80
 
81
- # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
82
- pred_probs = torch.softmax(vitbase_model(image), dim=1) # 101 classes
83
 
84
- # Calculate entropy
85
- entropy = -torch.sum(pred_probs * torch.log(pred_probs), dim=1).item()
 
86
 
87
- # Create a prediction label and prediction probability dictionary for each prediction class
88
- pred_classes_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(num_classes)}
89
- pred_classes_and_probs["unknown"] = 0.0
90
 
91
- # Get the top predicted class
92
- top_class = max(pred_classes_and_probs, key=pred_classes_and_probs.get)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- # If the image is likely to be an unknown category
95
- if pred_probs[0][class_names.index(top_class)] <= MULTICLASS_CLASSIF_THR and entropy > ENTROPY_THR:
96
 
97
- # Create prediction label and prediction probability for class unknown and rescale the rest of predictions
98
- pred_classes_and_probs["unknown"] = pred_probs.max() * 1.25
99
- prob_sum = sum(pred_classes_and_probs.values())
100
- pred_classes_and_probs = {key: value / prob_sum for key, value in pred_classes_and_probs.items()}
 
 
 
 
 
 
 
 
101
 
102
  # Get the top predicted class
103
- top_class = "unknown"
 
 
 
104
 
 
 
 
 
 
 
 
 
105
  # Otherwise
106
  else:
107
 
108
  # Set all probabilites to zero except class unknown
109
- pred_classes_and_probs = {class_names[i]: 0.0 for i in range(num_classes)}
110
  pred_classes_and_probs["unknown"] = 1.0
111
 
112
  # Get the top predicted class
113
  top_class = "unknown"
 
114
 
115
  # Get the description of the top predicted class
116
  top_class_description = food_descriptions.get(top_class, "Description not available.")
@@ -133,22 +204,21 @@ description = f"""
133
  A cutting-edge Vision Transformer (ViT) model to classify 101 delicious food types. Discover the power of AI in culinary recognition.
134
 
135
  ### Supported Food Types
136
- {', '.join(class_names[:-1])}.
137
  """
138
 
139
  # Configure the upload image area
140
  upload_input = gr.Image(type="pil", label="Upload Image", sources=['upload'], show_label=True, mirror_webcam=False)
141
 
142
  # Configure the dropdown option
143
- #model_dropdown = gr.Dropdown(
144
- # choices=["Vision Transformer - 384x384 pixels (higher accuracy, slower predictions)",
145
- # "Vision Transformer - 224x224 pixels (lower accuracy, faster predictions)"],
146
- # value="Vision Transformer - 384x384 pixels (higher accuracy, slower predictions)",
147
- # label="Select Model:"
148
- #)
149
 
150
  # Configure the sample image area
151
- food_vision_examples = [["examples/" + example] for example in os.listdir("examples")]
152
 
153
  # Author
154
  article = "Created by Sergio Sanz."
@@ -159,15 +229,16 @@ article = "Created by Sergio Sanz."
159
 
160
  # Create the Gradio demo
161
  demo = gr.Interface(fn=predict, # mapping function from input to outputs
162
- inputs=upload_input, # inputs #[upload_input, model_dropdown]
163
  outputs=[gr.Label(num_top_classes=3, label="Prediction"),
164
  gr.Textbox(label="Prediction time:"),
165
  gr.Textbox(label="Food Description:")], # outputs
166
- examples=food_vision_examples, # Create examples list from "examples/" directory
167
- cache_examples=True, # Cache the examples
168
  title=title, # Title of the app
169
  description=description, # Brief description of the app
170
  article=article, # Created by...
 
171
  theme="ocean") # Theme
172
 
173
  # Launch the demo!
 
13
  # Specify class names
14
  food_vision_class_names_path = "class_names.txt"
15
  with open(food_vision_class_names_path, "r") as f:
16
+ class_names_102 = f.read().splitlines()
17
+ class_names_101 = class_names_102.copy()
18
+ class_names_101.remove("unknown")
19
 
20
  # Specify number of classes
21
+ num_classes_102 = len(class_names_102) # 101 + unknown
22
+ num_classes_101 = len(class_names_101) # 101
23
 
24
  # Load the food description file
25
  food_descriptions_json = "food_descriptions.json"
 
35
  compile=True
36
  )
37
 
38
+ # Load the ViT-Base/16 transformer with input image of 384x384 pixels and 101 + unknown classes
39
+ vitbase_model_102 = create_vitbase_model(
40
+ model_weights_dir=".",
41
+ model_weights_name="vitbase16_102_2025-01-07.pth",
42
+ img_size=384,
43
+ num_classes=num_classes_102,
44
+ compile=True
45
+ )
46
+
47
+ vitbase_model_101 = create_vitbase_model(
48
  model_weights_dir=".",
49
  model_weights_name="vitbase16_2_2024-12-31.pth",
50
  img_size=384,
51
+ num_classes=num_classes_101,
52
  compile=True
53
  )
54
 
55
  # Specify manual transforms for model_2
56
  transforms = v2.Compose([
57
+ v2.Resize((384)), #v2.Resize((384, 384)),
58
  v2.CenterCrop((384, 384)),
59
  v2.ToImage(),
60
  v2.ToDtype(torch.float32, scale=True),
 
62
  std=[0.229, 0.224, 0.225])
63
  ])
64
 
65
+
66
  # Put models into evaluation mode and turn on inference mode
67
  effnetb0_model.eval()
68
+ vitbase_model_102.eval()
69
+ vitbase_model_101.eval()
70
 
71
  # Set thresdholds
72
  BINARY_CLASSIF_THR = 0.9989122152328491
73
  MULTICLASS_CLASSIF_THR = 0.5
74
  ENTROPY_THR = 2.6
75
 
76
+ # Set model names
77
+ lite_model = "⚡ Lite (faster, less accurate)"
78
+ pro_model = "💎 Pro (slower, more accurate)"
79
+
80
+ # Set allow flagging
81
+ allow_flagging = "never" # "manual"
82
+
83
+ # Predict method
84
+ def predict(image, model=pro_model) -> Tuple[Dict, str, str]:
85
 
86
  """Transforms and performs a prediction on image and returns prediction and time taken.
87
  """
88
  try:
89
  # Start the timer
90
  start_time = timer()
91
+
92
  # Transform the target image and add a batch dimension
93
  image = transforms(image).unsqueeze(0)
94
 
95
  # Make prediction...
96
  with torch.inference_mode():
97
+
98
  # If the picture is food
99
  if effnetb0_model(image)[:,1].cpu() >= BINARY_CLASSIF_THR:
100
 
101
+ # If Pro
102
+ if model == pro_model:
103
 
104
+ # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
105
+ pred_probs_102 = torch.softmax(vitbase_model_102(image), dim=1)
106
+ pred_probs_101 = torch.softmax(vitbase_model_101(image), dim=1)
107
 
108
+ # Calculate entropy
109
+ entropy = -torch.sum(pred_probs_101 * torch.log(pred_probs_101), dim=1).item()
 
110
 
111
+ # Create a prediction label and prediction probability dictionary for each prediction class
112
+ pred_classes_and_probs_102 = {class_names_102[i]: float(pred_probs_102[0][i]) for i in range(num_classes_102)}
113
+ pred_classes_and_probs_101 = {class_names_101[i]: float(pred_probs_101[0][i]) for i in range(num_classes_101)}
114
+ pred_classes_and_probs_101["unknown"] = 0.0
115
+
116
+ # Get the top predicted class
117
+ top_class_102 = max(pred_classes_and_probs_102, key=pred_classes_and_probs_102.get)
118
+ sec_class_102 = sorted(pred_classes_and_probs_102.items(), key=lambda x: x[1], reverse=True)[1][0]
119
+ top_class_101 = max(pred_classes_and_probs_101, key=pred_classes_and_probs_101.get)
120
+
121
+ # If the image is likely to be an unknown category
122
+ if pred_probs_101[0][class_names_101.index(top_class_101)] <= MULTICLASS_CLASSIF_THR and entropy > ENTROPY_THR:
123
+
124
+ # Create prediction label and prediction probability for class unknown and rescale the rest of predictions
125
+ pred_classes_and_probs_101["unknown"] = pred_probs_101.max() * 1.25
126
+ prob_sum = sum(pred_classes_and_probs_101.values())
127
+ pred_classes_and_probs = {key: value / prob_sum for key, value in pred_classes_and_probs_101.items()}
128
+
129
+ # Get the top predicted class
130
+ top_class = "unknown"
131
+
132
+ elif ((top_class_101 == sec_class_102) and (top_class_102 == "unknown")) or (top_class_101 == top_class_102):
133
+
134
+ # Get the probability vector
135
+ pred_classes_and_probs = pred_classes_and_probs_101
136
+
137
+ # Get the top predicted class
138
+ top_class = top_class_101
139
+
140
+ else:
141
+
142
+ # Get the probability vector
143
+ pred_classes_and_probs = pred_classes_and_probs_102
144
 
145
+ # Get the top predicted class
146
+ top_class = top_class_102
147
 
148
+ # Otherwise
149
+ else:
150
+
151
+ # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
152
+ pred_probs = torch.softmax(vitbase_model_101(image), dim=1) # 101 classes
153
+
154
+ # Calculate entropy
155
+ entropy = -torch.sum(pred_probs * torch.log(pred_probs), dim=1).item()
156
+
157
+ # Create a prediction label and prediction probability dictionary for each prediction class
158
+ pred_classes_and_probs = {class_names_101[i]: float(pred_probs[0][i]) for i in range(num_classes_101)}
159
+ pred_classes_and_probs["unknown"] = 0.0
160
 
161
  # Get the top predicted class
162
+ top_class = max(pred_classes_and_probs, key=pred_classes_and_probs.get)
163
+
164
+ # If the image is likely to be an unknown category
165
+ if pred_probs[0][class_names_101.index(top_class)] <= MULTICLASS_CLASSIF_THR and entropy > ENTROPY_THR:
166
 
167
+ # Create prediction label and prediction probability for class unknown and rescale the rest of predictions
168
+ pred_classes_and_probs["unknown"] = pred_probs.max() * 1.25
169
+ prob_sum = sum(pred_classes_and_probs.values())
170
+ pred_classes_and_probs = {key: value / prob_sum for key, value in pred_classes_and_probs.items()}
171
+
172
+ # Get the top predicted class
173
+ top_class = "unknown"
174
+
175
  # Otherwise
176
  else:
177
 
178
  # Set all probabilites to zero except class unknown
179
+ pred_classes_and_probs = {class_names_101[i]: 0.0 for i in range(num_classes_101)}
180
  pred_classes_and_probs["unknown"] = 1.0
181
 
182
  # Get the top predicted class
183
  top_class = "unknown"
184
+
185
 
186
  # Get the description of the top predicted class
187
  top_class_description = food_descriptions.get(top_class, "Description not available.")
 
204
  A cutting-edge Vision Transformer (ViT) model to classify 101 delicious food types. Discover the power of AI in culinary recognition.
205
 
206
  ### Supported Food Types
207
+ {', '.join(class_names_102[:-1])}.
208
  """
209
 
210
  # Configure the upload image area
211
  upload_input = gr.Image(type="pil", label="Upload Image", sources=['upload'], show_label=True, mirror_webcam=False)
212
 
213
  # Configure the dropdown option
214
+ model_dropdown = gr.Dropdown(
215
+ choices=[lite_model, pro_model],
216
+ value=pro_model,
217
+ label="Select ViT Model:"
218
+ )
 
219
 
220
  # Configure the sample image area
221
+ # food_vision_examples = [["examples/" + example] for example in os.listdir("examples")]
222
 
223
  # Author
224
  article = "Created by Sergio Sanz."
 
229
 
230
  # Create the Gradio demo
231
  demo = gr.Interface(fn=predict, # mapping function from input to outputs
232
+ inputs=[upload_input, model_dropdown], # inputs
233
  outputs=[gr.Label(num_top_classes=3, label="Prediction"),
234
  gr.Textbox(label="Prediction time:"),
235
  gr.Textbox(label="Food Description:")], # outputs
236
+ #examples=food_vision_examples, # Create examples list from "examples/" directory
237
+ #cache_examples=True, # Cache the examples
238
  title=title, # Title of the app
239
  description=description, # Brief description of the app
240
  article=article, # Created by...
241
+ allow_flagging=allow_flagging, # Only For debugging
242
  theme="ocean") # Theme
243
 
244
  # Launch the demo!
class_names.txt CHANGED
@@ -98,5 +98,5 @@ tacos
98
  takoyaki
99
  tiramisu
100
  tuna tartare
101
- waffles
102
- unknown
 
98
  takoyaki
99
  tiramisu
100
  tuna tartare
101
+ unknown
102
+ waffles
food_descriptions.json CHANGED
@@ -99,6 +99,6 @@
99
  "takoyaki": "Japanese snack made from batter, octopus, and tempura bits, served with takoyaki sauce.",
100
  "tiramisu": "Italian dessert with coffee-soaked ladyfingers, mascarpone cheese, and cocoa powder.",
101
  "tuna tartare": "Finely diced raw tuna, often mixed with soy sauce and served as an appetizer.",
102
- "waffles": "Batter-based dish cooked in a grid pattern, served with syrup, fruit, or whipped cream.",
103
- "unknown": "No sufficient confidence to classify the image."
104
  }
 
99
  "takoyaki": "Japanese snack made from batter, octopus, and tempura bits, served with takoyaki sauce.",
100
  "tiramisu": "Italian dessert with coffee-soaked ladyfingers, mascarpone cheese, and cocoa powder.",
101
  "tuna tartare": "Finely diced raw tuna, often mixed with soy sauce and served as an appetizer.",
102
+ "unknown": "No sufficient confidence to classify the image.",
103
+ "waffles": "Batter-based dish cooked in a grid pattern, served with syrup, fruit, or whipped cream."
104
  }