saptarshineilsinha commited on
Commit
c9ca5b4
·
1 Parent(s): 954ea89

Application

Browse files
Files changed (3) hide show
  1. app.py +81 -0
  2. exp.py +28 -0
  3. requirements.txt +102 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import gradio as gr
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import spaces
8
+ from yolox.exp import get_exp
9
+ from yolox.utils import fuse_model, postprocess, vis
10
+ from yolox.data.data_augment import preproc
11
+ from PIL import Image
12
+ from pathlib import Path
13
+
14
+ MODEL_PATH = "models/yolox-tiny.pth" # Path to trained model
15
+ EXP_FILE = "exp.py" # Path to your experiment file
16
+ CONF_THRESHOLD = 0.4 # Confidence threshold
17
+ NMS_THRESHOLD = 0.65 # Non-max suppression threshold
18
+
19
+ # Load experiment
20
+ @spaces.GPU
21
+ def process_frame(frame):
22
+ exp = get_exp(EXP_FILE, None)
23
+ model = exp.get_model()
24
+ model.eval()
25
+
26
+ ckpt = torch.load(Path(MODEL_PATH), map_location='cpu')
27
+ model.load_state_dict(ckpt["model"])
28
+ model = fuse_model(model)
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ model.to(device)
31
+
32
+ # pil_image = PIL.Image.open('Image.jpg').convert('RGB')
33
+ open_cv_image = np.array(frame)
34
+ # Convert RGB to BGR
35
+ img = open_cv_image[:, :, ::-1].copy()
36
+
37
+ img_input, ratio = preproc(img, exp.test_size)
38
+ img_input = torch.from_numpy(img_input).unsqueeze(0).float().to(device)
39
+
40
+ with torch.no_grad():
41
+ outputs = model(img_input)
42
+ outputs = postprocess(outputs, exp.num_classes, CONF_THRESHOLD, NMS_THRESHOLD)
43
+
44
+ if outputs[0] is not None:
45
+ dets = outputs[0].cpu().numpy()
46
+ bboxes = dets[:, :4].astype(int)
47
+ scores = dets[:, 4] # Achte darauf, dass der Index für Scores korrekt ist
48
+ cls_ids = dets[:, 5].astype(int)
49
+
50
+ result_img = vis(img, bboxes, scores, cls_ids, class_names=exp.class_names, conf=CONF_THRESHOLD)
51
+ else:
52
+ result_img = img # No detections, return original frame
53
+
54
+ return result_img
55
+
56
+ def get_default_image_paths(folder_path):
57
+ image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')
58
+ image_paths = [[os.path.join(folder_path, file)] for file in os.listdir(folder_path)
59
+ if file.lower().endswith(image_extensions)]
60
+ return image_paths
61
+
62
+ default_images = get_default_image_paths(Path("examples/"))
63
+
64
+ def process_input(file_input):
65
+ processed_img = process_frame(file_input)
66
+ processed_img = processed_img[:, :, ::-1].copy()
67
+ return Image.fromarray(processed_img) # Return the processed image directly
68
+
69
+ # Create Gradio Interface with title and description
70
+ iface = gr.Interface(
71
+ fn=process_input,
72
+ inputs=[
73
+ gr.Image(label="Upload Image", type="pil"), # File input as PIL Image
74
+ ],
75
+ outputs=gr.Image(type="pil", label="Output (Image)"), # Show output as an image
76
+ examples=default_images,
77
+ title="Strawberry Disease Detection",
78
+ description="This application detects diseases in strawberries using a trained YOLOX model. Upload an image, video, or use your webcam for analysis."
79
+ )
80
+
81
+ iface.launch(share=True)
exp.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from yolox.exp import Exp as MyExp # type: ignore
3
+ #from yolox.data import COCODataset # type: ignore
4
+
5
+ class Exp(MyExp):
6
+ def __init__(self):
7
+ super(Exp, self).__init__()
8
+ self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
9
+ self.num_classes = 7
10
+ self.class_names=["Angular Leafspot", "Leaf Spot", "Anthracnose Fruit Rot", "Blossom Blight", "Gray Mold", "Powdery Mildew Fruit", "Powdery Mildew Leaf"]
11
+
12
+ # small
13
+ # self.depth = 0.33
14
+ # self.width = 0.50
15
+
16
+ # tiny
17
+ self.depth = 0.33
18
+ self.width = 0.375
19
+ self.input_size = (416, 416)
20
+ self.mosaic_scale = (0.5, 1.5)
21
+ self.random_size = (10, 20)
22
+ self.test_size = (416, 416)
23
+ self.enable_mixup = False
24
+
25
+ self.data_dir = "coco_dataset"
26
+ self.train_ann = "instances_train.json"
27
+ self.val_ann = "instances_val.json"
28
+ self.test_ann = "instances_test.json"
requirements.txt ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.3.1
2
+ aiofiles==23.2.1
3
+ annotated-types==0.7.0
4
+ anyio==4.5.2
5
+ cachetools==5.5.2
6
+ certifi==2025.7.14
7
+ charset-normalizer==3.4.2
8
+ click==8.1.8
9
+ contourpy==1.1.1
10
+ cycler==0.12.1
11
+ exceptiongroup==1.3.0
12
+ fastapi==0.116.1
13
+ ffmpy==0.5.0
14
+ filelock==3.16.1
15
+ flatbuffers==25.2.10
16
+ fonttools==4.57.0
17
+ fsspec==2025.3.0
18
+ google-auth==2.40.3
19
+ google-auth-oauthlib==1.0.0
20
+ gradio==4.44.1
21
+ gradio_client==1.3.0
22
+ grpcio==1.70.0
23
+ h11==0.16.0
24
+ hf-xet==1.1.5
25
+ httpcore==1.0.9
26
+ httpx==0.28.1
27
+ huggingface-hub==0.33.4
28
+ idna==3.10
29
+ imageio==2.35.1
30
+ importlib_metadata==8.5.0
31
+ importlib_resources==6.4.5
32
+ Jinja2==3.1.6
33
+ kiwisolver==1.4.7
34
+ lazy_loader==0.4
35
+ loguru==0.7.3
36
+ Markdown==3.7
37
+ markdown-it-py==3.0.0
38
+ MarkupSafe==2.1.5
39
+ matplotlib==3.7.5
40
+ mdurl==0.1.2
41
+ networkx==3.1
42
+ ninja==1.11.1.4
43
+ numpy==1.24.4
44
+ nvidia-pyindex==1.0.9
45
+ oauthlib==3.3.1
46
+ onnx==1.8.1
47
+ onnx-simplifier==0.3.5
48
+ onnxoptimizer==0.3.13
49
+ onnxruntime==1.8.0
50
+ opencv-python==4.5.5.64
51
+ orjson==3.10.15
52
+ packaging==25.0
53
+ pandas==2.0.3
54
+ pillow==10.4.0
55
+ protobuf==5.29.5
56
+ psutil==5.9.8
57
+ pyasn1==0.6.1
58
+ pyasn1_modules==0.4.2
59
+ pycocotools==2.0.7
60
+ pydantic==2.10.6
61
+ pydantic_core==2.27.2
62
+ pydub==0.25.1
63
+ Pygments==2.19.2
64
+ pyparsing==3.1.4
65
+ python-dateutil==2.9.0.post0
66
+ python-multipart==0.0.20
67
+ pytz==2025.2
68
+ PyWavelets==1.4.1
69
+ PyYAML==6.0.2
70
+ requests==2.32.4
71
+ requests-oauthlib==2.0.0
72
+ rich==14.0.0
73
+ rsa==4.9.1
74
+ ruff==0.12.3
75
+ scikit-image==0.21.0
76
+ scipy==1.10.1
77
+ semantic-version==2.10.0
78
+ shellingham==1.5.4
79
+ shiboken2==5.15.2.1
80
+ six==1.17.0
81
+ sniffio==1.3.1
82
+ spaces==0.37.1
83
+ starlette==0.44.0
84
+ tabulate==0.9.0
85
+ tensorboard==2.14.0
86
+ tensorboard-data-server==0.7.2
87
+ thop==0.1.1.post2209072238
88
+ tifffile==2023.7.10
89
+ tomlkit==0.12.0
90
+ torch==1.13.1+cu116
91
+ torchaudio==0.13.1+cu116
92
+ torchvision==0.14.1+cu116
93
+ tqdm==4.67.1
94
+ typer==0.16.0
95
+ typing_extensions==4.13.2
96
+ tzdata==2025.2
97
+ urllib3==2.2.3
98
+ uvicorn==0.33.0
99
+ websockets==12.0
100
+ Werkzeug==3.0.6
101
+ yolox==0.3.0
102
+ zipp==3.20.2