Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import Image | |
| from collections import OrderedDict | |
| import torch | |
| from models.model import GLPDepth | |
| from PIL import Image | |
| from torchvision import transforms | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
# Device used for the model and its input tensors; this demo is pinned to CPU.
DEVICE = 'cpu'
def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Checkpoints saved from a model wrapped in ``torch.nn.DataParallel`` store
    every parameter as ``module.<name>``, while a bare model expects
    ``<name>``. Only keys that actually start with *prefix* are rewritten, so
    names that merely contain the substring (e.g. ``encoder.module_x``) are
    left untouched. Key order is preserved.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def load_mde_model(path, max_depth=700.0, device=None):
    """Build a GLPDepth model in eval mode and load weights from a checkpoint.

    Args:
        path: Checkpoint file produced by ``torch.save``. It must be a mapping
            with a ``'model_state_dict'`` entry.
        max_depth: Forwarded to ``GLPDepth``. Defaults to the 700.0 this demo
            was written with.
        device: Device the model is moved to. Defaults to the module-level
            ``DEVICE``.

    Returns:
        The GLPDepth model with the checkpoint weights loaded, in ``eval()``
        mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: From ``load_state_dict`` when parameter names or shapes
            do not match the model.
    """
    if device is None:
        device = DEVICE
    model = GLPDepth(max_depth=max_depth, is_train=False).to(device)
    # Deserialize on CPU; load_state_dict copies tensors onto the model's
    # device. NOTE(review): torch.load unpickles the file -- only load
    # trusted checkpoints.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    # Strip the DataParallel prefix per key instead of guessing from the
    # first key and chopping 7 characters off every name.
    state_dict = strip_module_prefix(checkpoint['model_state_dict'])
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Build the network once at import time so every request reuses it.
model = load_mde_model(path='best_model.ckpt')

# Input pipeline: resize to the network's 512x512 input size, then turn the
# PIL image into a float tensor (ToTensor scales 8-bit pixels to [0, 1]).
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
    ]
)
def predict(input_image):
    """Estimate a depth/DTM map for one image and return it as a rendered plot.

    Args:
        input_image: Image array supplied by the Gradio ``Image`` input, or
            ``None`` when the user submitted without uploading anything.
            Presumably ``(H, W, 3)`` RGB; greyscale and RGBA arrays are also
            converted to RGB below.

    Returns:
        ``(H, W, 3)`` uint8 array of a matplotlib figure that shows the
        prediction with the ``jet`` colormap and a colorbar.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # convert('RGB') normalises greyscale/RGBA arrays to the 3 channels the
    # model expects; fromarray(..., 'RGB') would misread them.
    pil_image = Image.fromarray(input_image.astype('uint8')).convert('RGB')
    # Resize + to-tensor, then add a leading batch dimension.
    torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0)

    with torch.no_grad():
        output_patch = model(torch_img)
    predicted_image = output_patch['pred_d'].squeeze().detach().cpu().numpy()

    fig, ax = plt.subplots()
    try:
        im = ax.imshow(
            predicted_image, cmap='jet', vmin=0, vmax=np.max(predicted_image)
        )
        fig.colorbar(im, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
        # removed in 3.10). It already carries the rendered (H, W, 4) pixel
        # shape, so there is no manual reshape via get_width_height(), which
        # disagrees with the pixel size on HiDPI canvases. Drop alpha and copy
        # so the result is contiguous and independent of the canvas.
        data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request leaks a figure.
        plt.close(fig)
    return data
# Example inputs listed under the demo: every regular, non-hidden file in
# demo_imgs/, in sorted (stable) order. A missing folder means "no examples"
# rather than a crash at startup.
_EXAMPLE_DIR = 'demo_imgs'
if os.path.isdir(_EXAMPLE_DIR):
    examples = [
        [os.path.join(_EXAMPLE_DIR, name)]
        for name in sorted(os.listdir(_EXAMPLE_DIR))
        if not name.startswith('.')
        and os.path.isfile(os.path.join(_EXAMPLE_DIR, name))
    ]
else:
    examples = []

# gr.Image no longer accepts `shape` (removed in Gradio 4, where passing it
# raises TypeError at startup). Resizing to 512x512 is already done by
# `preprocess`, so the components need no size argument.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='numpy'),
    outputs=[gr.Image(type='numpy')],
    examples=examples or None,
    title="DTM Estimation",
    description="This demo predicts a DTM using GLP Depth model. It will scale input image to 512x512 and at the end it will apply a colormap to better visualize the output."
)
iface.launch()