import streamlit as st
import torch
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from tqdm.auto import tqdm
import io
import zipfile

# Assuming you have these functions defined elsewhere
from your_module import preprocess, best_model, DEVICE
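# Rough sketch of what `your_module` is expected to provide, inferred from how the
# names are used below (this interface is an assumption, not part of the original code):
#   preprocess — an Albumentations-style transform: preprocess(image=ndarray)['image']
#                returns a torch.Tensor of shape (C, H, W) ready for the model
#   best_model — a trained binary-segmentation model that returns one logit map per tile
#   DEVICE     — a torch.device, e.g. torch.device("cuda" if torch.cuda.is_available() else "cpu")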
def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4):
    """Slide a window over the raster, run the model on batches of tiles, and
    return [mask, metadata] pairs for every tile that contains detections."""
    tiles = []
    with rasterio.open(map_file) as src:
        height = src.height
        width = src.width
        effective_tile_size = tile_size - overlap

        # The inner batch loop advances the row offset itself, so step the outer
        # loop by batch_size rows to avoid processing the same tiles repeatedly.
        for y in tqdm(range(0, height, effective_tile_size * batch_size)):
            for x in range(0, width, effective_tile_size):
                batch_images = []
                batch_metas = []

                for i in range(batch_size):
                    curr_y = y + (i * effective_tile_size)
                    if curr_y >= height:
                        break

                    window = Window(x, curr_y, tile_size, tile_size)
                    # boundless=True zero-pads tiles at the raster edge so every
                    # tile matches the tile_size x tile_size metadata written below.
                    out_image = src.read(window=window, boundless=True, fill_value=0)

                    if out_image.shape[0] == 1:
                        out_image = np.repeat(out_image, 3, axis=0)
                    elif out_image.shape[0] != 3:
                        raise ValueError("The number of channels in the image is not supported")

                    # (bands, rows, cols) -> (rows, cols, bands) for the preprocessing transform
                    out_image = np.transpose(out_image, (1, 2, 0))
                    tile_image = Image.fromarray(out_image.astype(np.uint8))

                    out_meta = src.meta.copy()
                    out_meta.update({
                        "driver": "GTiff",
                        "height": tile_size,
                        "width": tile_size,
                        "count": 1,           # the saved tile is a single-band mask
                        "dtype": "float32",   # matches the mask array written in main()
                        "transform": rasterio.windows.transform(window, src.transform)
                    })

                    tile_image = np.array(tile_image)
                    preprocessed_tile = preprocess(image=tile_image)['image']
                    batch_images.append(preprocessed_tile)
                    batch_metas.append(out_meta)

                if not batch_images:
                    break

                # Stack the preprocessed tiles into a single batch tensor on the target device
                batch_tensor = torch.stack(batch_images).to(DEVICE)

                # Perform inference on the batch
                with torch.no_grad():
                    batch_masks = model(batch_tensor)
                    batch_masks = torch.sigmoid(batch_masks)
                    batch_masks = (batch_masks > 0.6).float()

                # Process each mask in the batch
                for j, mask_tensor in enumerate(batch_masks):
                    mask_resized = torch.nn.functional.interpolate(
                        mask_tensor.unsqueeze(0),
                        size=(tile_size, tile_size),
                        mode='bilinear',
                        align_corners=False
                    ).squeeze(0)
                    mask_array = mask_resized.squeeze().cpu().numpy()

                    # Keep only tiles where the model detected something
                    if mask_array.any():
                        tiles.append([mask_array, batch_metas[j]])

    return tiles
def main():
    st.title("TIF File Processor")

    uploaded_file = st.file_uploader("Choose a TIF file", type="tif")

    if uploaded_file is not None:
        st.write("File uploaded successfully!")

        # Process button
        if st.button("Process File"):
            st.write("Processing...")

            # Save the uploaded file temporarily
            with open("temp.tif", "wb") as f:
                f.write(uploaded_file.getbuffer())

            # Process the file
            best_model.float()
            tiles = extract_tiles("temp.tif", best_model, tile_size=512, overlap=15, batch_size=4)

            st.write("Processing complete!")

            # Prepare zip file for download
            zip_buffer = io.BytesIO()
            with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
                for i, (mask_array, meta) in enumerate(tiles):
                    # Save each tile as a separate single-band GeoTIFF
                    with rasterio.open(f"tile_{i}.tif", 'w', **meta) as dst:
                        dst.write(mask_array.astype(np.float32), 1)

                    # Add the tile to the zip file
                    zip_file.write(f"tile_{i}.tif")

            # Offer the zip file for download
            st.download_button(
                label="Download processed tiles",
                data=zip_buffer.getvalue(),
                file_name="processed_tiles.zip",
                mime="application/zip"
            )


if __name__ == "__main__":
    main()
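# To try the app locally (assuming this script is saved as app.py and `your_module`
# is importable from the same environment), run:
#   streamlit run app.py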