Cherie Ho
commited on
Commit
·
283b3f6
1
Parent(s):
b684d11
adjust threshold for viz
Browse files- mapper/utils/viz_2d.py +3 -2
mapper/utils/viz_2d.py
CHANGED
|
@@ -83,8 +83,9 @@ def one_hot_argmax_to_rgb(y, num_class):
|
|
| 83 |
class_colors = class_colors.values()
|
| 84 |
class_colors = [torch.tensor(x).float() for x in class_colors]
|
| 85 |
|
| 86 |
-
|
| 87 |
-
argmaxed
|
|
|
|
| 88 |
# print(argmaxed.shape)
|
| 89 |
|
| 90 |
seg_rgb = torch.ones(
|
|
|
|
| 83 |
class_colors = class_colors.values()
|
| 84 |
class_colors = [torch.tensor(x).float() for x in class_colors]
|
| 85 |
|
| 86 |
+
threshold = 0.25
|
| 87 |
+
argmaxed = torch.argmax((y > threshold).float(), dim=1) # Take argmax
|
| 88 |
+
argmaxed[torch.all(y <= threshold, dim=1)] = num_class
|
| 89 |
# print(argmaxed.shape)
|
| 90 |
|
| 91 |
seg_rgb = torch.ones(
|