"""
Demonstrates integrating Rerun visualization with Gradio and HF ZeroGPU.
"""

import uuid

import gradio as gr
import rerun as rr
import rerun.blueprint as rrb
import spaces
import torch
from gradio_rerun import Rerun
from transformers import DetrForObjectDetection, DetrImageProcessor
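
# Load the DETR detector and its processor once at startup so every request
# reuses the same weights.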
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
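

# Recordings are identified by id: every call reuses the same id, so repeated
# clicks stream into the same recording instead of resetting the viewer.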
def get_recording(recording_id: str) -> rr.RecordingStream:
    return rr.RecordingStream(
        application_id="rerun_example_gradio", recording_id=recording_id
    )
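

# On HF ZeroGPU, `@spaces.GPU` requests a GPU for the duration of each call.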
@spaces.GPU
def streaming_object_detection(recording_id: str, img):
    rec = get_recording(recording_id)
    # binary_stream() buffers encoded log data so we can yield it incrementally.
    stream = rec.binary_stream()

    if img is None:
        raise gr.Error("Must provide an image to run object detection on.")
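
    # The blueprint defines the viewer layout: a single 2D view rooted at "image".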
    blueprint = rrb.Blueprint(
        rrb.Horizontal(
            rrb.Spatial2DView(origin="image"),
        ),
        collapse_panels=True,
    )
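
    # Send the layout and the input image, then flush so the image shows up in
    # the viewer before inference finishes.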
    rec.send_blueprint(blueprint)
    rec.set_time("iteration", sequence=0)
    rec.log("image", rr.Image(img))
    yield stream.read()
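
    # Run DETR; inference_mode disables autograd tracking during the forward pass.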
    with torch.inference_mode():
        inputs = processor(images=img, return_tensors="pt")
        outputs = model(**inputs)
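
    # Map the normalized predictions back to pixel coordinates, keeping only
    # detections with confidence above 0.85.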
    height, width = img.shape[:2]
    target_sizes = torch.tensor([[height, width]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.85
    )[0]
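
    # Log the detections as labeled 2D boxes nested under the image entity.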
    rec.log(
        "image/objects",
        rr.Boxes2D(
            array=results["boxes"],
            array_format=rr.Box2DFormat.XYXY,
            labels=[model.config.id2label[label.item()] for label in results["labels"]],
            # Derive a deterministic color from each class id so a given class
            # is always drawn in the same color.
            colors=[
                (
                    label.item() * 50 % 255,
                    (label.item() * 80 + 40) % 255,
                    (label.item() * 120 + 100) % 255,
                )
                for label in results["labels"]
            ],
        ),
    )
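
    # Flush any buffered data and emit the final chunk to the viewer.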
    stream.flush()
    yield stream.read()
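

# Two-column layout: image input and trigger button on the left, the embedded
# Rerun viewer on the right.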
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Your image", open=True):
                img = gr.Image(interactive=True, label="Image")
            detect_objects = gr.Button("Detect objects")

        with gr.Column(scale=4):
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
                height=700,
            )
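
    # The recording id is generated once at startup and kept in session state.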
    recording_id = gr.State(str(uuid.uuid4()))
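
    # Because the viewer was created with streaming=True, each yielded chunk
    # from the generator updates it live.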
    detect_objects.click(
        streaming_object_detection,
        inputs=[recording_id, img],
        outputs=[viewer],
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)