playing around with gradio
Browse files
app.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
import functools
|
| 2 |
-
import matplotlib
|
| 3 |
-
matplotlib.use('Agg')
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
|
| 6 |
import gradio as gr
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
|
|
|
| 9 |
BENCHMARK_DATA = {
|
| 10 |
"Greedy Search": {
|
| 11 |
"DistilGPT2": {
|
|
@@ -29,9 +29,9 @@ BENCHMARK_DATA = {
|
|
| 29 |
"A100": [],
|
| 30 |
},
|
| 31 |
"T5 Small": {
|
| 32 |
-
"T4": [
|
| 33 |
-
"3090": [],
|
| 34 |
-
"A100": [],
|
| 35 |
},
|
| 36 |
"T5 Base": {
|
| 37 |
"T4": [],
|
|
@@ -137,10 +137,19 @@ BENCHMARK_DATA = {
|
|
| 137 |
|
| 138 |
|
| 139 |
def get_plot(model_name, generate_type):
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
demo = gr.Blocks()
|
| 146 |
|
|
|
|
| 1 |
import functools
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import gradio as gr
|
| 4 |
+
import seaborn as sns
|
| 5 |
+
import pandas as pd
|
| 6 |
|
| 7 |
|
| 8 |
+
# benchmark order: pytorch, tf eager, tf xla; units = ms
|
| 9 |
BENCHMARK_DATA = {
|
| 10 |
"Greedy Search": {
|
| 11 |
"DistilGPT2": {
|
|
|
|
| 29 |
"A100": [],
|
| 30 |
},
|
| 31 |
"T5 Small": {
|
| 32 |
+
"T4": [99.88, 1527.73, 18.78],
|
| 33 |
+
"3090": [55.09, 665.70, 9.25],
|
| 34 |
+
"A100": [124.91, 1642.07, 13.72],
|
| 35 |
},
|
| 36 |
"T5 Base": {
|
| 37 |
"T4": [],
|
|
|
|
| 137 |
|
| 138 |
|
| 139 |
def get_plot(model_name, generate_type):
|
| 140 |
+
df = pd.DataFrame(BENCHMARK_DATA[generate_type][model_name])
|
| 141 |
+
df["framework"] = ["PyTorch", "TF (Eager Execition)", "TF (XLA)"]
|
| 142 |
+
df = pd.melt(df, id_vars=["framework"], value_vars=["T4", "3090", "A100"])
|
| 143 |
+
|
| 144 |
+
g = sns.catplot(
|
| 145 |
+
data=df, kind="bar",
|
| 146 |
+
x="variable", y="value", hue="framework",
|
| 147 |
+
ci="sd", palette="dark", alpha=.6, height=6
|
| 148 |
+
)
|
| 149 |
+
g.despine(left=True)
|
| 150 |
+
# g.set_axis_labels("", "Body mass (g)")
|
| 151 |
+
# g.legend.set_title("")
|
| 152 |
+
return g.gcf()
|
| 153 |
|
| 154 |
demo = gr.Blocks()
|
| 155 |
|