innafomina committed on
Commit
a8bf136
·
1 Parent(s): f8e8c9b

added more functionality

Browse files
Files changed (3) hide show
  1. app.py +45 -32
  2. chess.py +44 -0
  3. tools.py +53 -3
app.py CHANGED
@@ -3,12 +3,12 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
6
- from smolagents import CodeAgent, OpenAIServerModel, DuckDuckGoSearchTool, WikipediaSearchTool, HfApiModel, GoogleSearchTool
7
  from dotenv import find_dotenv, load_dotenv
8
- from tools import WikipediaSearch, ExcelReader, download_files, get_images, FileReader, AudioTransciber, YouTubeTranscipt, YouTubeVideoUnderstanding
9
  from pathlib import Path
10
  from PIL import Image
11
-
12
  # (Keep Constants as is)
13
  # --- Constants ---
14
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
@@ -16,18 +16,24 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
16
  # --- Basic Agent Definition ---
17
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
18
  class BasicAgent:
19
- def __init__(self):
20
  load_dotenv(find_dotenv())
21
  os.environ["SERPER_API_KEY"] = os.getenv('SERPER_API_KEY')
22
- model = OpenAIServerModel(model_id="gpt-4o",
23
- api_key=os.getenv("OPEN_AI_KEY"))
 
 
24
  #model=HfApiModel(api_key=os.getenv('HUGGING_FACE_API_KEY'))
25
  # Instantiate the agent
26
  self.agent = CodeAgent(
27
  tools=[
28
  GoogleSearchTool(provider="serper"),
 
29
  #DuckDuckGoSearchTool(),
30
- WikipediaSearch(), ExcelReader(), FileReader(), AudioTransciber(), YouTubeTranscipt(),
 
 
 
31
  YouTubeVideoUnderstanding()],
32
  model=model,
33
  add_base_tools=True,
@@ -43,7 +49,8 @@ class BasicAgent:
43
  digits in plain text unless specified otherwise. Never include currency symbols in the response.
44
  If you are asked for a comma separated list, apply the above rules depending of whether the element
45
  to be put in the list is a number or a string. For question that contain phrases like `what is the number` or
46
- `what is the highest number` return just the number, e.g., 2.
 
47
  """
48
  self.agent.prompt_templates["system_prompt"] = self.agent.prompt_templates["system_prompt"] + SYSTEM_PROMPT
49
  def __call__(self, question: str) -> str:
@@ -91,7 +98,11 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
91
  return f"An unexpected error occurred fetching questions: {e}", None
92
  # 1. Instantiate Agent ( modify this part to create your agent)
93
  try:
94
- agent = BasicAgent()
 
 
 
 
95
  except Exception as e:
96
  print(f"Error instantiating agent: {e}")
97
  return f"Error initializing agent: {e}", None
@@ -105,35 +116,37 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
105
  results_log = []
106
  answers_payload = []
107
  print(f"Running agent on {len(questions_data)} questions...")
108
- #questions_data = [i for i in questions_data if i['task_id'] == 'f918266a-b3e0-4914-865d-4faa564f1aef']
109
  images = []
110
  #added limit for testing
111
- for item in questions_data:
112
- task_id = item.get("task_id")
113
- question_text = item.get("question") + 'You can use wikipedia.'
114
- file_name = item.get('file_name')
115
- if file_name:
116
- file_path = download_files(task_id, file_name)
117
- file_format = file_name.split('.')[-1]
118
- question_text = question_text + f"This question has an associated file at path: {file_path}. The file is in the {file_format} format"
119
- if not task_id or question_text is None:
120
- print(f"Skipping item with missing task_id or question: {item}")
121
- continue
122
- try:
123
- print(images)
124
- submitted_answer = agent(question_text)
125
- print(submitted_answer)
126
- answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
127
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
128
- except Exception as e:
129
- print(f"Error running agent on task {task_id}: {e}")
130
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
 
 
131
 
132
  if not answers_payload:
133
  print("Agent did not produce any answers to submit.")
134
  return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
135
 
136
- # 4. Prepare Submission
137
  submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
138
  status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
139
  print(status_update)
@@ -235,7 +248,7 @@ if __name__ == "__main__":
235
  print("-"*(60 + len(" App Starting ")) + "\n")
236
 
237
  print("Launching Gradio Interface for Basic Agent Evaluation...")
238
- demo.launch(debug=True, share=False)
239
 
240
  if __name__ == "__main__":
241
  run_and_submit_all()
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from smolagents import CodeAgent, OpenAIServerModel, DuckDuckGoSearchTool, WikipediaSearchTool, HfApiModel, GoogleSearchTool, LiteLLMModel
7
  from dotenv import find_dotenv, load_dotenv
8
+ from tools import WikipediaSearch, ExcelReader, ChessSolver, download_files, get_images, FileReader, AudioTransciber, YouTubeTranscipt, YouTubeVideoUnderstanding, VegetableFruitClassification
9
  from pathlib import Path
10
  from PIL import Image
11
+ import time
12
  # (Keep Constants as is)
13
  # --- Constants ---
14
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
16
  # --- Basic Agent Definition ---
17
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
18
  class BasicAgent:
19
+ def __init__(self, model):
20
  load_dotenv(find_dotenv())
21
  os.environ["SERPER_API_KEY"] = os.getenv('SERPER_API_KEY')
22
+ # model = OpenAIServerModel(model_id="gpt-4o",
23
+ # api_key=os.getenv("OPEN_AI_KEY"))
24
+ # model = LiteLLMModel(model_id="gemini/gemini-2.0-flash",
25
+ # api_key=os.getenv("GEMINI_API_KEY"))
26
  #model=HfApiModel(api_key=os.getenv('HUGGING_FACE_API_KEY'))
27
  # Instantiate the agent
28
  self.agent = CodeAgent(
29
  tools=[
30
  GoogleSearchTool(provider="serper"),
31
+ ChessSolver(),
32
  #DuckDuckGoSearchTool(),
33
+ VegetableFruitClassification(),
34
+ WikipediaSearch(),
35
+ ExcelReader(), FileReader(),
36
+ AudioTransciber(), YouTubeTranscipt(),
37
  YouTubeVideoUnderstanding()],
38
  model=model,
39
  add_base_tools=True,
 
49
  digits in plain text unless specified otherwise. Never include currency symbols in the response.
50
  If you are asked for a comma separated list, apply the above rules depending of whether the element
51
  to be put in the list is a number or a string. For question that contain phrases like `what is the number` or
52
+ `what is the highest number` return just the number, e.g., 2. For questions around currency,
53
+ include just the number, not the currency sign.
54
  """
55
  self.agent.prompt_templates["system_prompt"] = self.agent.prompt_templates["system_prompt"] + SYSTEM_PROMPT
56
  def __call__(self, question: str) -> str:
 
98
  return f"An unexpected error occurred fetching questions: {e}", None
99
  # 1. Instantiate Agent ( modify this part to create your agent)
100
  try:
101
+ model = LiteLLMModel(model_id= "gemini/gemini-2.0-flash",
102
+ api_key=os.getenv("GEMINI_API_KEY"))
103
+ # model = OpenAIServerModel(model_id="gpt-4o",
104
+ # api_key=os.getenv("OPEN_AI_KEY"))
105
+ agent = BasicAgent(model=model)
106
  except Exception as e:
107
  print(f"Error instantiating agent: {e}")
108
  return f"Error initializing agent: {e}", None
 
116
  results_log = []
117
  answers_payload = []
118
  print(f"Running agent on {len(questions_data)} questions...")
119
+ #questions_data = [i for i in questions_data if i['task_id'] == '5a0c1adf-205e-4841-a666-7c3ef95def9d']
120
  images = []
121
  #added limit for testing
122
+ for req_num, item in enumerate(questions_data):
123
+ if req_num % 2 == 0:
124
+ time.sleep(30)
125
+ else:
126
+ task_id = item.get("task_id")
127
+ question_text = item.get("question") + ' You can use wikipedia. Do not change original names or omit name parts.'
128
+ file_name = item.get('file_name')
129
+ if file_name:
130
+ file_path = download_files(task_id, file_name)
131
+ file_format = file_name.split('.')[-1]
132
+ question_text = question_text + f"This question has an associated file at path: {file_path}. The file is in the {file_format} format"
133
+ if not task_id or question_text is None:
134
+ print(f"Skipping item with missing task_id or question: {item}")
135
+ continue
136
+ try:
137
+ submitted_answer = agent(question_text)
138
+ print(submitted_answer)
139
+ answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
140
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
141
+ except Exception as e:
142
+ print(f"Error running agent on task {task_id}: {e}")
143
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
144
 
145
  if not answers_payload:
146
  print("Agent did not produce any answers to submit.")
147
  return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
148
 
149
+ #4. Prepare Submission
150
  submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
151
  status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
152
  print(status_update)
 
248
  print("-"*(60 + len(" App Starting ")) + "\n")
249
 
250
  print("Launching Gradio Interface for Basic Agent Evaluation...")
251
+ demo.launch(debug=True, share=True)
252
 
253
  if __name__ == "__main__":
254
  run_and_submit_all()
chess.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import base64
4
+
5
def get_base64(file_path: str) -> str:
    """Return the contents of the file at *file_path* as base64 text.

    Args:
        file_path: Path of the file to read; it is opened in binary mode.

    Returns:
        The base64 encoding of the file's bytes, decoded to a UTF-8 string.
    """
    with open(file_path, "rb") as source:
        raw_bytes = source.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
9
+
10
def fen_notation(image_path:str, current_player:str)->str:
    """Send a chess board image to chessvision.ai and return the predicted position.

    The image file is base64-encoded, wrapped in a ``data:image/png`` URI and
    posted as JSON to the prediction endpoint together with the side to move.

    Args:
        image_path: Path to the image file showing the chess board.
        current_player: Side whose turn it is; must be 'black' or 'white'.

    Returns:
        The ``result`` field of the JSON response with every underscore
        replaced by a space (presumably a FEN string, going by the function
        name -- confirm against the service's documentation).

    Raises:
        ValueError: If ``current_player`` is neither 'black' nor 'white'.
        Exception: If the service answers with a status code other than 200.
    """
    # Validate first so a bad argument fails before any file or network I/O.
    # The caller's value is honoured; it is no longer overwritten with 'black'.
    if current_player not in ("black", "white"):
        raise ValueError("current_player must be 'black' or 'white'")

    chessvisionai_url = "http://app.chessvision.ai/predict"
    base64_image = get_base64(image_path)
    payload = {
        "board_orientation": "predict",
        "cropped": False,
        "current_player": current_player,
        "image": f"data:image/png;base64,{base64_image}",
        "predict_turn": False,
    }
    response = requests.post(chessvisionai_url, json=payload)
    if response.status_code != 200:
        # status_code is an int: format it instead of concatenating it to a
        # str, which raised TypeError and hid the real HTTP failure.
        raise Exception(
            f'Error retrieving fen {response.status_code} {response.text}'
        )
    return response.json()['result'].replace('_', ' ')
31
+
32
def chess_analysis(fen_notation:str)->str:
    """Post a position to chess-api.com and return the move it reports.

    Args:
        fen_notation: Position string sent to the API under the ``fen`` key.

    Returns:
        The value of the ``san`` key in the JSON response, which callers
        treat as the best move; ``None`` if the key is absent.

    Raises:
        Exception: If the API answers with a status code other than 200.
    """
    reply = requests.post("https://chess-api.com/v1", json={"fen": fen_notation})
    if reply.status_code != 200:
        raise Exception(f'Error occurred {reply.status_code} {reply.text}')
    return reply.json().get('san')
44
+
tools.py CHANGED
@@ -15,7 +15,7 @@ from openai import OpenAI
15
  from llama_index.readers.youtube_transcript import YoutubeTranscriptReader
16
  from google import genai
17
  from google.genai import types
18
-
19
  class WikipediaSearch(Tool):
20
  name = "wikipedia_search"
21
  description = "Fetches wikipedia pages."
@@ -143,7 +143,7 @@ class YouTubeVideoUnderstanding(Tool):
143
  load_dotenv(find_dotenv())
144
  client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
145
  response = client.models.generate_content(
146
- model='models/gemini-2.5-flash-preview-04-17',
147
  contents=types.Content(
148
  parts=[
149
  types.Part(
@@ -153,4 +153,54 @@ class YouTubeVideoUnderstanding(Tool):
153
  ]
154
  )
155
  )
156
- return response.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  from llama_index.readers.youtube_transcript import YoutubeTranscriptReader
16
  from google import genai
17
  from google.genai import types
18
+ import chess
19
  class WikipediaSearch(Tool):
20
  name = "wikipedia_search"
21
  description = "Fetches wikipedia pages."
 
143
  load_dotenv(find_dotenv())
144
  client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
145
  response = client.models.generate_content(
146
+ model='models/gemini-2.0-flash',
147
  contents=types.Content(
148
  parts=[
149
  types.Part(
 
153
  ]
154
  )
155
  )
156
+ return response.text
157
+
158
class VegetableFruitClassification(Tool):
    """Tool that forwards a fruit/vegetable question to a Gemini model.

    The user prompt is extended with a short note on the botanical
    distinction between fruits and vegetables before it is sent.
    """

    # Metadata attributes read by the Tool framework.
    name = 'vegetable_fruit_classificaiton'
    description = "a tool that can help classify fruits and vegetables"
    inputs = {
        "prompt": {
            "type": "string",
            "description": "user prompt about fruits or vegetables",
        },
    }
    output_type = "string"

    def forward(self, prompt:str)->str:
        """Send *prompt* plus botanical context to Gemini and return its text.

        Args:
            prompt: User prompt about fruits or vegetables.

        Returns:
            The ``text`` attribute of the model response.
        """
        # The API key is read from the environment after loading any .env file.
        load_dotenv(find_dotenv())
        gemini = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        additional_context = """
        The botanical distinction between fruits and vegetables is anatomical of the plant in question.
        For example, a tomato has seeds, which would result in reproduction. Rhubarb is the stalk of a plant, and has no means of proliferation after consumption.
        A tomato is a botanical fruit and rhubarb is botanically a vegetable. """
        request_content = types.Content(
            parts=[types.Part(text=prompt + additional_context)]
        )
        answer = gemini.models.generate_content(
            model='models/gemini-2.5-flash-preview-05-20',
            contents=request_content,
        )
        return answer.text
187
+
188
class ChessSolver(Tool):
    """Tool that turns a chess board image into a suggested next move.

    It chains the two helpers of the ``chess`` module imported by this file:
    ``fen_notation`` on the image, then ``chess_analysis`` on its result.
    """

    # Metadata attributes read by the Tool framework.
    name = "chess_analysis_tool"
    description = "analyzes the chess board to determine the best next move."
    inputs = {
        "image_path": {
            "type": "string",
            "description": "path to the image showing a chess board.",
        },
        "current_player": {
            "type": "string",
            "description": "player whose turn it is. Acceptable inputs are 'black' or 'white'",
        },
    }
    output_type = "string"

    def forward(self, image_path:str, current_player:str)->str:
        """Return the move reported for the position shown in *image_path*.

        Args:
            image_path: Path to the image showing a chess board.
            current_player: Player whose turn it is, 'black' or 'white'.

        Returns:
            Whatever ``chess.chess_analysis`` returns for the recognised position.
        """
        position = chess.fen_notation(image_path, current_player)
        return chess.chess_analysis(position)