| import matplotlib.pyplot as plt | |
| from generic_utils import generate_visualization | |
def do_gradcam(transform, image, class_index=None):
    """Plot an image next to its Grad-CAM visualization.

    A one-row, two-column figure is created. The left panel shows ``image``
    as passed in; the right panel shows whatever ``generate_visualization``
    returns for ``transform(image)`` when called with ``method="gradcam"``
    and ``LRP=False``. Axis decorations are turned off on both panels.

    Args:
        transform: Callable applied to ``image`` to produce the input handed
            to ``generate_visualization``.
        image: Image displayed in the left panel and passed to ``transform``.
        class_index: Forwarded unchanged to ``generate_visualization``.
            Defaults to ``None``.

    Returns:
        The matplotlib figure containing both panels.
    """
    figure, (source_ax, cam_ax) = plt.subplots(1, 2)

    # Left panel: the untransformed input image.
    source_ax.imshow(image)
    source_ax.axis("off")

    # Right panel: visualization computed from the transformed image.
    model_input = transform(image)
    heatmap = generate_visualization(
        model_input, class_index=class_index, method="gradcam", LRP=False
    )
    cam_ax.imshow(heatmap)
    cam_ax.axis("off")

    return figure