import torch
import os
import streamlit as st
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import plotly.express as px
import pickle
import random
from utils import check_password
from PIL import Image
from transformers import YolosFeatureExtractor, YolosForObjectDetection
from torchvision.transforms import ToTensor, ToPILImage
from annotated_text import annotated_text
from st_pages import add_indentation

#add_indentation()

st.set_page_config(layout="wide")

def load_model(feature_extractor_url, model_url):
    """Loads the YOLOS feature extractor and object detection model from the Hugging Face Hub."""
    feature_extractor_ = YolosFeatureExtractor.from_pretrained(feature_extractor_url)
    model_ = YolosForObjectDetection.from_pretrained(model_url)
    return feature_extractor_, model_
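# Reloading the model on every Streamlit rerun is slow. A minimal caching
# sketch (our assumption, not part of the original app) would be:
# @st.cache_resource
# def load_model_cached(feature_extractor_url, model_url):
#     return load_model(feature_extractor_url, model_url)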

def rgb_to_hex(rgb):
    """Converts an RGB tuple to an HTML-style hex string."""
    hex_color = "#{:02x}{:02x}{:02x}".format(int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
    return hex_color
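# Worked example: the first entry of COLORS below, [0.000, 0.447, 0.741],
# gives int(0.447 * 255) = 113 -> "71" and int(0.741 * 255) = 188 -> "bc",
# so rgb_to_hex([0.000, 0.447, 0.741]) returns "#0071bc".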

## CODE TO CLEAN IMAGES
def fix_channels(t):
    """Normalizes an image tensor to 3 channels and returns it as a PIL image."""
    if len(t.shape) == 2:
        # Grayscale image with no channel dimension: replicate it across 3 channels
        return ToPILImage()(torch.stack([t, t, t]))
    if t.shape[0] == 4:
        # RGBA image: drop the alpha channel
        return ToPILImage()(t[:3])
    if t.shape[0] == 1:
        # Single-channel image: replicate the channel 3 times
        return ToPILImage()(torch.stack([t[0], t[0], t[0]]))
    return ToPILImage()(t)
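# Usage sketch (hypothetical file name): fix_channels(ToTensor()(Image.open("look.jpg")))
# yields an RGB PIL image whether the source was grayscale, single-channel, RGBA, or RGB.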

## CODE FOR PLOTS WITH BOUNDING BOXES
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

def idx_to_text(i):
    """Maps a predicted class index to its label, or returns False if the class was not selected."""
    if i.item() in dict_cats_final:
        return dict_cats_final[i.item()]
    else:
        return False

# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Converts boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max)."""
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)
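# Worked example (illustrative values): a box centered at (0.5, 0.5) with width 0.2
# and height 0.4 becomes (0.4, 0.3, 0.6, 0.7):
# box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.2, 0.4]]))
# -> tensor([[0.4000, 0.3000, 0.6000, 0.7000]])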

def rescale_bboxes(out_bbox, size):
    """Rescales normalized [0, 1] boxes to pixel coordinates for an image of the given size."""
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b
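# Continuing the example above for an 800x600 image: the corner-format box
# (0.4, 0.3, 0.6, 0.7) becomes (320.0, 180.0, 480.0, 420.0) in pixels.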

def plot_results(pil_img, prob, boxes):
    """Draws labeled bounding boxes on the image, displays it, and returns the hex colors used."""
    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    colors_used = []
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        cl = p.argmax()
        # Skip detections whose class is not among the selected categories
        if idx_to_text(cl) is False:
            continue
        colors_used.append(rgb_to_hex(c))
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        ax.text(xmin, ymin, f"{idx_to_text(cl)}", fontsize=10,
                bbox=dict(facecolor=c, alpha=0.8))
    plt.axis('off')
    plt.savefig("results_od.png", bbox_inches="tight")
    plt.close(fig)
    st.image("results_od.png")
    return colors_used

def return_probas(outputs, threshold):
    """Returns per-query class probabilities and a mask of detections above the threshold."""
    # Softmax over classes, dropping the last "no-object" column
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    # Keep only the columns of the selected categories
    probas = probas[:, list(dict_cats_final.keys())]
    keep = probas.max(-1).values > threshold
    return probas, keep
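# Shape walk-through (assuming standard YOLOS outputs): outputs.logits has shape
# (1, num_queries, num_classes + 1), where the extra last column is the "no-object"
# class. Slicing with [0, :, :-1] leaves one probability vector per query, and
# `keep` flags the queries whose best class clears the threshold.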

def visualize_probas(probas, threshold, colors):
    """Plots the average probability score of each detected item as a bar chart."""
    label_df = pd.DataFrame({"label": probas.max(-1).indices.detach().numpy(),
                             "proba": probas.max(-1).values.detach().numpy()})
    cats_dict = dict(zip(np.arange(0, len(cats)), cats))
    label_df["label"] = label_df["label"].map(cats_dict)
    top_label_df = label_df.loc[label_df["proba"] > threshold].round(2)
    top_label_df["colors"] = colors
    top_label_df.sort_values(by=["proba"], ascending=False, inplace=True)
    #st.dataframe(top_label_df.drop(columns=["colors"]))
    # Average the score of items detected more than once; keep each item's most common box color
    mode_func = lambda x: x.mode().iloc[0]
    top_label_df_agg = top_label_df.groupby("label").agg({"proba": "mean", "colors": mode_func})
    top_label_df_agg = top_label_df_agg.reset_index().sort_values(by=["proba"], ascending=False)
    top_label_df_agg.columns = ["Item", "Score", "Colors"]
    # Color each bar to match the item's bounding box color
    color_map = dict(zip(top_label_df_agg["Item"].to_list(),
                         top_label_df_agg["Colors"].to_list()))
    fig = px.bar(top_label_df_agg, y='Item', x='Score',
                 color="Item", color_discrete_map=color_map, title="Probability scores")
    st.plotly_chart(fig, use_container_width=True)
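# Aggregation example (illustrative values): if "shoe" is detected twice with
# scores 0.80 and 0.90, the chart shows a single "shoe" bar with Score = 0.85.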

cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit',
        'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar',
        'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']

######################################################################################################################################

if check_password():
    st.markdown("# Object Detection 📹")

    st.markdown("### What is Object Detection?")
    #st.markdown("""Object detection involves **identifying** and **locating objects** within an image or video frame through bounding boxes. """)
    st.info("""Object Detection is a computer vision task in which the goal is to **detect** and **locate objects** of interest in an image or video.
            The task involves identifying the position and boundaries of objects (or **bounding boxes**) in an image, and classifying the objects into different categories.""")
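    # Illustration (made-up values, not from the app): a single detection is the
    # combination of a box, a class, and a confidence score, e.g.
    # {"box": (x_min, y_min, x_max, y_max), "label": "dress", "score": 0.91}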
| st.markdown("Here is an example of Object Detection for Traffic Analysis.") | |
| #image_od = Image.open('images/od_2.png') | |
| #st.image(image_od, width=600) | |
| st.video(data='https://www.youtube.com/watch?v=PVCGDoTZHaI') | |
| st.markdown(" ") | |
| st.markdown("""Common applications of Object Detection include: | |
| - **Autonomous Vehicles** :car: : Object detection is crucial for self-driving cars to track pedestrians, cyclists, other vehicles, and obstacles on the road. | |
| - **Retail** 🏬 : Implementing smart shelves and checkout systems that use object detection to track inventory and monitor stock levels. | |
| - **Healthcare** 👨⚕️: Detecting and tracking anomalies in medical images, such as tumors or abnormalities, for diagnostic purposes or prevention. | |
| - **Manufacturing** 🏭: Quality control on production lines by detecting defects or irregularities in manufactured products. Ensuring workplace safety by monitoring the movement of workers and equipment. | |
| """) | |

    ############################# USE CASE #############################

    st.markdown(" ")
    st.divider()
    st.markdown("# Fashion Object Detection 👗")

    # st.info("""This use case showcases the application of **Object detection** to detect clothing items/features on images. <br>
    # The images used were gathered from Dior's""")
    st.info("""**Object detection models** can be very valuable for fashion retailers wishing to improve customer experience. They can power, for example, **product recognition**, **visual search**
            and even **virtual try-ons**.""")
    st.markdown("In this use case, we showcase an object detection model that is able to identify and locate different articles of clothing in fashion show images.")

    st.markdown(" ")
    st.markdown(" ")

    # images_dior = [os.path.join("data/dior_show/images", url) for url in os.listdir("data/dior_show/images") if url != "results"]
    # columns_img = st.columns(4)
    # for img, col in zip(images_dior, columns_img):
    #     with col:
    #         st.image(img)

    _, col, _ = st.columns([0.1, 0.8, 0.1])
    with col:
        st.image("images/fashion_od2.png")

    st.markdown(" ")
    st.markdown(" ")
| st.markdown("### About the model 📚") | |
| st.markdown("""The object detection model was trained to **detect specific clothing items** on images. <br> | |
| Below is a list of the <b>46</b> different types of clothing items the model can identify and locate.""", unsafe_allow_html=True) | |
| colors = ["#8ef", "#faa", "#afa", "#fea", "#8ef","#afa"]*7 + ["#8ef", "#faa", "#afa", "#fea"] | |
| cats_annotated = [(g,"","#afa") for g in cats] | |
| annotated_text([cats_annotated]) | |
| # st.markdown("""**Here are the 'objects' the model is able to detect**: <br> | |
| # 'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', | |
| # 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', | |
| # 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', | |
| # 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', | |
| # 'ruffle', 'sequin', 'tassel'""", unsafe_allow_html=True) | |
| st.markdown("Credits for the model: https://huggingface.co/valentinafeve/yolos-fashionpedia") | |
| st.markdown("") | |
| st.markdown("") | |

    ############## SELECT AN IMAGE ###############

    st.markdown("### Select an image 🖼️")
    st.markdown("""The images provided were taken from **Dior's Fall 2020 Women's Fashion Show**.""")

    image_ = None
    fashion_images_path = r"data/dior_show/images"
    list_images = os.listdir(fashion_images_path)
    image_name = st.selectbox("Choose an image", list_images)
    image_ = os.path.join(fashion_images_path, image_name)
    st.image(image_, width=300)

    # image_ = None
    # select_image_box = st.radio(
    #     "**Select the image you wish to run the model on**",
    #     ["Choose an existing image", "Load your own image"],
    #     index=None,)  # label_visibility="collapsed")
    # if select_image_box == "Choose an existing image":
    #     fashion_images_path = r"data/dior_show/images"
    #     list_images = os.listdir(fashion_images_path)
    #     image_ = st.selectbox("", list_images, label_visibility="collapsed")
    #     if image_ is not None:
    #         image_ = os.path.join(fashion_images_path, image_)
    #         st.markdown("You've selected the following image:")
    #         st.image(image_, width=300)
    # elif select_image_box == "Load your own image":
    #     image_ = st.file_uploader("Load an image here",
    #                               key="OD_dior", type=['jpg', 'jpeg', 'png'], label_visibility="collapsed")
    #     st.warning("""**Note**: The model tends to perform better with images of people/clothing items facing forward.
    #     Choose this type of image if you want optimal results.""")
    #     st.warning("""**Note:** The model was trained to detect clothing items on a single person.
    #     If your image contains more than one person, the model won't detect the items of the other persons.""")
    #     if image_ is not None:
    #         st.image(Image.open(image_), width=300)

    st.markdown(" ")
    st.markdown(" ")

    ########## SELECT AN ELEMENT TO DETECT ##################

    dict_cats = dict(zip(np.arange(len(cats)), cats))

    # st.markdown("#### Choose the elements you want to detect 👉")
    # # Select one or more elements to detect
    # container = st.container()
    # selected_options = None
    # all = st.checkbox("Select all")
    # if all:
    #     selected_options = container.multiselect("**Select one or more items**", cats, cats)
    # else:
    #     selected_options = container.multiselect("**Select one or more items**", cats)
    # cats = selected_options

    selected_options = cats
    dict_cats_final = {key: value for (key, value) in dict_cats.items() if value in selected_options}
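    # For instance, if selected_options were ["dress", "hat"], dict_cats_final
    # would be {10: "dress", 14: "hat"} (indices taken from the `cats` list above).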
| # st.markdown(" ") | |
| # st.markdown(" ") | |

    ############## SELECT A THRESHOLD ###############

    st.markdown("### Define a threshold for predictions 🔎")
    st.markdown("""In this section, you can select a threshold for the model's final predictions. <br>
                Objects that are given a lower score than the chosen threshold will be ignored in the final results.""", unsafe_allow_html=True)
    st.info("""**Note**: Object detection models detect objects using bounding boxes and assign each object to a class.
            Each object is given a class based on a probability score computed by the model. A high probability signals that the model is confident in its prediction,
            whereas a lower probability score signals a level of uncertainty.""")
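    # For instance, with a threshold of 0.75, a detection scored 0.80 is kept while
    # one scored 0.60 is dropped (this is the `keep` mask computed in return_probas above).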
| st.markdown(" ") | |
| #st.markdown("The images below are examples of probability scores given by object detection models for each element detected.") | |
| _, col, _ = st.columns([0.2,0.6,0.2]) | |
| with col: | |
| st.image("images/probability_od.png", | |
| caption="Examples of object detection with bounding boses and probability scores") | |
| st.markdown(" ") | |
| st.markdown("**Select a threshold** ") | |
| # st.warning("""**Note**: The threshold helps you decide how confident you want your model to be with its predictions. | |
| # Elements that are identified with a lower probability than the given threshold will be ignored in the final results.""") | |
| threshold = st.slider('**Select a threshold**', min_value=0.5, step=0.05, max_value=1.0, value=0.75, label_visibility="collapsed") | |
| # if threshold < 0.6: | |
| # st.error("""**Warning**: Selecting a low threshold (below 0.6) could lead the model to make errors and detect too many objects.""") | |
| st.write("You've selected a threshold at", threshold) | |
| st.markdown(" ") | |
| pickle_file_path = r"data/dior_show/results" | |

    ############# RUN MODEL ################

    run_model = st.button("**Run the model**", type="primary")

    if run_model:
        if image_ is not None and selected_options is not None and threshold is not None:
            with st.spinner('Wait for it...'):
                ## SELECT IMAGE
                image = Image.open(image_)
                image = fix_channels(ToTensor()(image))

                ## LOAD OBJECT DETECTION MODEL
                FEATURE_EXTRACTOR_PATH = "hustvl/yolos-small"
                MODEL_PATH = "valentinafeve/yolos-fashionpedia"
                # feature_extractor, model = load_model(FEATURE_EXTRACTOR_PATH, MODEL_PATH)

                # # RUN MODEL ON IMAGE
                # inputs = feature_extractor(images=image, return_tensors="pt")
                # outputs = model(**inputs)

                # # Save results
                # pickle_file_path = r"data/dior_show/results"
                # image_name = image_.split('\\')[1][:5]
                # with open(os.path.join(pickle_file_path, f"{image_name}_results.pkl"), 'wb') as file:
                #     pickle.dump(outputs, file)

                # Load the precomputed model outputs for the selected image
                image_name = image_name[:5]
                path_load_pickle = os.path.join(pickle_file_path, f"{image_name}_results.pkl")
                with open(path_load_pickle, 'rb') as pickle_file:
                    outputs = pickle.load(pickle_file)
                probas, keep = return_probas(outputs, threshold)

                st.markdown("#### See the results ☑️")

                # PLOT BOUNDING BOXES AND PROBABILITY BARS
                col1, col2 = st.columns(2)
                with col1:
                    st.markdown(" ")
                    st.markdown("##### 1. Bounding box results")
                    bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)
                    colors_used = plot_results(image, probas[keep], bboxes_scaled)
                with col2:
                    #st.markdown("**Probability scores**")
                    if not any(keep.tolist()):
                        st.error("""No objects were detected on the image.
                        Decrease your threshold or choose different items to detect.""")
                    else:
                        st.markdown(" ")
                        st.markdown("##### 2. Probability score of each object")
                        st.info("""**Note**: Some items might have been detected more than once on the image.
                        For these items, we've computed the average probability score across all detections.""")
                        visualize_probas(probas, threshold, colors_used)
        else:
            st.error("You must select an **image**, **elements to detect**, and a **threshold** to run the model!")