Nickeik committed on
Commit
8f757dd
·
verified ·
1 Parent(s): 803a4d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -29
app.py CHANGED
@@ -31,7 +31,7 @@ def predict(task, model, text):
31
  if model in ["ChatGPT", "GPT4"]:
32
  # OpenAI API request
33
  response = openai.ChatCompletion.create(
34
- model=MODELS[model],
35
  messages=[{"role": "user", "content": text}]
36
  )
37
  return response['choices'][0]['message']['content']
@@ -43,48 +43,51 @@ def predict(task, model, text):
43
  print(f"Error in prediction: {e}")
44
  return {"error": str(e)}
45
 
46
-
47
  # Function to benchmark Hugging Face models and OpenAI models
48
  def benchmark(task, model, file):
49
- data = pd.read_csv(file.name)
50
- texts = data['query'].tolist()
51
- true_labels = data['answer'].tolist()
52
-
53
- if model in ["ChatGPT", "GPT-4"]:
54
  predictions = []
55
- for text in texts:
56
- response = openai.ChatCompletion.create(
57
- model="gpt-4" if model == "GPT-4" else "gpt-3.5-turbo",
58
- messages=[{"role": "user", "content": text}]
59
- )
60
- predictions.append(response.choices[0].message['content'].strip())
61
- else:
62
- selected_pipeline = load_pipeline(task, model)
63
- predictions = [selected_pipeline(text)[0]['label'] for text in texts]
64
-
65
- accuracy = accuracy_score(true_labels, predictions)
66
- precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='macro')
67
-
68
- return {
69
- "Accuracy": accuracy,
70
- "Precision": precision,
71
- "Recall": recall,
72
- "F1 Score": f1
73
- }
 
 
 
 
74
 
75
  # Define the Gradio interface
76
  with gr.Blocks() as demo:
77
  with gr.Row():
78
  task_input = gr.Dropdown(TASKS, label="Task")
79
- model_input = gr.Dropdown(list(MODELS.keys()) + ["ChatGPT", "GPT-4"], label="Model")
80
-
81
  with gr.Tab("Predict"):
82
  with gr.Row():
83
  text_input = gr.Textbox(lines=2, placeholder="Enter text here...", label="Text")
84
  predict_button = gr.Button("Predict")
85
  predict_output = gr.JSON(label="Prediction Output")
86
  predict_button.click(predict, inputs=[task_input, model_input, text_input], outputs=predict_output)
87
-
88
  with gr.Tab("Benchmark"):
89
  with gr.Row():
90
  file_input = gr.File(label="Upload CSV for Benchmarking")
 
31
  if model in ["ChatGPT", "GPT4"]:
32
  # OpenAI API request
33
  response = openai.ChatCompletion.create(
34
+ model="gpt-4" if model == "GPT4" else "gpt-3.5-turbo",
35
  messages=[{"role": "user", "content": text}]
36
  )
37
  return response['choices'][0]['message']['content']
 
43
  print(f"Error in prediction: {e}")
44
  return {"error": str(e)}
45
 
 
46
  # Function to benchmark Hugging Face models and OpenAI models
47
def benchmark(task, model, file):
    """Benchmark a model on a labelled CSV and report classification metrics.

    The CSV (uploaded via Gradio) must contain a 'query' column (inputs) and an
    'answer' column (gold labels). OpenAI chat models are queried one text at a
    time; any other model name is resolved through load_pipeline.
    Returns a dict with Accuracy, Precision, Recall and F1 Score (macro
    average), or {"error": ...} if anything fails — mirroring predict's
    error contract so the Gradio JSON output always gets a dict.
    """
    try:
        frame = pd.read_csv(file.name)
        queries = frame['query'].tolist()
        gold = frame['answer'].tolist()

        outputs = []
        if model in ("ChatGPT", "GPT4"):
            for query in queries:
                # One API round-trip per row; the raw chat reply is used as the label.
                reply = openai.ChatCompletion.create(
                    model="gpt-4" if model == "GPT4" else "gpt-3.5-turbo",
                    messages=[{"role": "user", "content": query}],
                )
                outputs.append(reply['choices'][0]['message']['content'].strip())
        else:
            pipe = load_pipeline(task, model)
            outputs = [pipe(query)[0]['label'] for query in queries]

        accuracy = accuracy_score(gold, outputs)
        precision, recall, f1, _ = precision_recall_fscore_support(gold, outputs, average='macro')

        return {
            "Accuracy": accuracy,
            "Precision": precision,
            "Recall": recall,
            "F1 Score": f1
        }
    except Exception as e:
        # Surface the failure to the UI rather than crashing the app.
        print(f"Error in benchmarking: {e}")
        return {"error": str(e)}
77
 
78
  # Define the Gradio interface
79
  with gr.Blocks() as demo:
80
  with gr.Row():
81
  task_input = gr.Dropdown(TASKS, label="Task")
82
+ model_input = gr.Dropdown(list(MODELS.keys()) + ["ChatGPT", "GPT4"], label="Model")
83
+
84
  with gr.Tab("Predict"):
85
  with gr.Row():
86
  text_input = gr.Textbox(lines=2, placeholder="Enter text here...", label="Text")
87
  predict_button = gr.Button("Predict")
88
  predict_output = gr.JSON(label="Prediction Output")
89
  predict_button.click(predict, inputs=[task_input, model_input, text_input], outputs=predict_output)
90
+
91
  with gr.Tab("Benchmark"):
92
  with gr.Row():
93
  file_input = gr.File(label="Upload CSV for Benchmarking")