kikikita commited on
Commit
eccd64e
·
1 Parent(s): 64ad372

Revert "feat: refactor API key management and update client usage in audio and image generation"

Browse files

This reverts commit 64ad372380e1c14b443720d14b377e55352f60f9.

src/agent/llm.py CHANGED
@@ -4,31 +4,17 @@ import logging
4
  from langchain_google_genai import ChatGoogleGenerativeAI
5
 
6
  from config import settings
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
- _API_KEYS: list[str] = []
11
- _current_key_idx = 0
12
  MODEL_NAME = "gemini-2.5-flash-preview-05-20"
13
 
14
 
15
  def _get_api_key() -> str:
16
- """Return an API key using round-robin selection."""
17
- global _API_KEYS, _current_key_idx
18
-
19
- if not _API_KEYS:
20
- keys_str = settings.gemini_api_key.get_secret_value()
21
- if keys_str:
22
- _API_KEYS = [k.strip() for k in keys_str.split(",") if k.strip()]
23
- if not _API_KEYS:
24
- msg = "Google API keys are not configured or invalid"
25
- logger.error(msg)
26
- raise ValueError(msg)
27
-
28
- key = _API_KEYS[_current_key_idx]
29
- _current_key_idx = (_current_key_idx + 1) % len(_API_KEYS)
30
- logger.debug("Using Google API key index %s", _current_key_idx)
31
- return key
32
 
33
 
34
  def create_llm(
 
4
  from langchain_google_genai import ChatGoogleGenerativeAI
5
 
6
  from config import settings
7
+ from services.google import ApiKeyPool
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
+ _pool = ApiKeyPool()
 
12
  MODEL_NAME = "gemini-2.5-flash-preview-05-20"
13
 
14
 
15
  def _get_api_key() -> str:
16
+ """Return an API key using round-robin selection in a thread-safe way."""
17
+ return _pool.get_key_sync()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def create_llm(
src/audio/audio_generator.py CHANGED
@@ -1,7 +1,5 @@
1
  import asyncio
2
- from google import genai
3
  from google.genai import types
4
- from config import settings
5
  import wave
6
  import queue
7
  import logging
@@ -10,33 +8,40 @@ import time
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
- client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
14
 
15
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
16
  if user_hash in sessions:
17
  return
18
- async with (
19
- client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
20
- asyncio.TaskGroup() as tg,
21
- ):
22
- # Set up task to receive server messages.
23
- tg.create_task(receive_audio(session, user_hash))
 
24
 
25
- # Send initial prompts and config
26
- await session.set_weighted_prompts(
27
- prompts=[
28
- types.WeightedPrompt(text=music_tone, weight=1.0),
29
- ]
30
- )
31
- await session.set_music_generation_config(
32
- config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
33
- )
34
- await session.play()
35
- logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
36
- sessions[user_hash] = {
37
- 'session': session,
38
- 'queue': queue.Queue()
39
- }
 
 
 
 
 
 
40
 
41
  async def change_music_tone(user_hash: str, new_tone):
42
  logger.info(f"Changing music tone to {new_tone}")
@@ -44,8 +49,11 @@ async def change_music_tone(user_hash: str, new_tone):
44
  if not session:
45
  logger.error(f"No session found for user hash {user_hash}")
46
  return
47
- await session.set_weighted_prompts(
48
- prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
 
 
 
49
  )
50
 
51
 
@@ -78,8 +86,8 @@ async def cleanup_music_session(user_hash: str):
78
  if user_hash in sessions:
79
  logger.info(f"Cleaning up music session for user hash {user_hash}")
80
  session = sessions[user_hash]['session']
81
- await session.stop()
82
- await session.close()
83
  del sessions[user_hash]
84
 
85
 
@@ -117,4 +125,4 @@ def update_audio(user_hash):
117
  wf.setframerate(SAMPLE_RATE)
118
  wf.writeframes(pcm_data)
119
  wav_bytes = wav_buffer.getvalue()
120
- yield wav_bytes
 
1
  import asyncio
 
2
  from google.genai import types
 
3
  import wave
4
  import queue
5
  import logging
 
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
+ from services.google import GoogleClientFactory
12
 
13
  async def generate_music(user_hash: str, music_tone: str, receive_audio):
14
  if user_hash in sessions:
15
  return
16
+ async with GoogleClientFactory.audio() as client:
17
+ async with (
18
+ client.live.music.connect(model='models/lyria-realtime-exp') as session,
19
+ asyncio.TaskGroup() as tg,
20
+ ):
21
+ # Set up task to receive server messages.
22
+ tg.create_task(receive_audio(session, user_hash))
23
 
24
+ # Send initial prompts and config
25
+ await asyncio.wait_for(
26
+ session.set_weighted_prompts(
27
+ prompts=[types.WeightedPrompt(text=music_tone, weight=1.0)]
28
+ ),
29
+ 40,
30
+ )
31
+ await asyncio.wait_for(
32
+ session.set_music_generation_config(
33
+ config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
34
+ ),
35
+ 40,
36
+ )
37
+ await asyncio.wait_for(session.play(), 40)
38
+ logger.info(
39
+ f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
40
+ )
41
+ sessions[user_hash] = {
42
+ 'session': session,
43
+ 'queue': queue.Queue()
44
+ }
45
 
46
  async def change_music_tone(user_hash: str, new_tone):
47
  logger.info(f"Changing music tone to {new_tone}")
 
49
  if not session:
50
  logger.error(f"No session found for user hash {user_hash}")
51
  return
52
+ await asyncio.wait_for(
53
+ session.set_weighted_prompts(
54
+ prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
55
+ ),
56
+ 40,
57
  )
58
 
59
 
 
86
  if user_hash in sessions:
87
  logger.info(f"Cleaning up music session for user hash {user_hash}")
88
  session = sessions[user_hash]['session']
89
+ await asyncio.wait_for(session.stop(), 40)
90
+ await asyncio.wait_for(session.close(), 40)
91
  del sessions[user_hash]
92
 
93
 
 
125
  wf.setframerate(SAMPLE_RATE)
126
  wf.writeframes(pcm_data)
127
  wav_bytes = wav_buffer.getvalue()
128
+ yield wav_bytes
src/images/image_generator.py CHANGED
@@ -1,17 +1,15 @@
1
- from google import genai
2
  from google.genai import types
3
  import os
4
  from PIL import Image
5
  from io import BytesIO
6
  from datetime import datetime
7
- from config import settings
8
  import logging
9
  import asyncio
10
  import gradio as gr
11
 
12
- logger = logging.getLogger(__name__)
13
 
14
- client = genai.Client(api_key=settings.gemini_api_key.get_secret_value()).aio
15
 
16
  safety_settings = [
17
  types.SafetySetting(
@@ -50,14 +48,18 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
50
  logger.info(f"Generating image with prompt: {prompt}")
51
 
52
  try:
53
- response = await client.models.generate_content(
54
- model="gemini-2.0-flash-preview-image-generation",
55
- contents=prompt,
56
- config=types.GenerateContentConfig(
57
- response_modalities=["TEXT", "IMAGE"],
58
- safety_settings=safety_settings,
59
- ),
60
- )
 
 
 
 
61
 
62
  # Process the response parts
63
  image_saved = False
@@ -108,23 +110,23 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
108
  logger.error(f"Error: Image file not found at {image_path}")
109
  return None
110
 
111
- key = settings.gemini_api_key.get_secret_value()
112
-
113
- client = genai.Client(api_key=key).aio
114
-
115
  try:
116
- # Load the input image
117
- input_image = Image.open(image_path)
118
-
119
- # Make the API call with both text and image
120
- response = await client.models.generate_content(
121
- model="gemini-2.0-flash-preview-image-generation",
122
- contents=[modification_prompt, input_image],
123
- config=types.GenerateContentConfig(
124
- response_modalities=["TEXT", "IMAGE"],
125
- safety_settings=safety_settings,
126
- ),
127
- )
 
 
 
 
128
 
129
  # Process the response parts
130
  image_saved = False
 
 
1
  from google.genai import types
2
  import os
3
  from PIL import Image
4
  from io import BytesIO
5
  from datetime import datetime
 
6
  import logging
7
  import asyncio
8
  import gradio as gr
9
 
10
+ from services.google import GoogleClientFactory
11
 
12
+ logger = logging.getLogger(__name__)
13
 
14
  safety_settings = [
15
  types.SafetySetting(
 
48
  logger.info(f"Generating image with prompt: {prompt}")
49
 
50
  try:
51
+ async with GoogleClientFactory.image() as client:
52
+ response = await asyncio.wait_for(
53
+ client.models.generate_content(
54
+ model="gemini-2.0-flash-preview-image-generation",
55
+ contents=prompt,
56
+ config=types.GenerateContentConfig(
57
+ response_modalities=["TEXT", "IMAGE"],
58
+ safety_settings=safety_settings,
59
+ ),
60
+ ),
61
+ 40,
62
+ )
63
 
64
  # Process the response parts
65
  image_saved = False
 
110
  logger.error(f"Error: Image file not found at {image_path}")
111
  return None
112
 
 
 
 
 
113
  try:
114
+ async with GoogleClientFactory.image() as client:
115
+ # Load the input image
116
+ input_image = Image.open(image_path)
117
+
118
+ # Make the API call with both text and image
119
+ response = await asyncio.wait_for(
120
+ client.models.generate_content(
121
+ model="gemini-2.0-flash-preview-image-generation",
122
+ contents=[modification_prompt, input_image],
123
+ config=types.GenerateContentConfig(
124
+ response_modalities=["TEXT", "IMAGE"],
125
+ safety_settings=safety_settings,
126
+ ),
127
+ ),
128
+ 40,
129
+ )
130
 
131
  # Process the response parts
132
  image_saved = False
src/main.py CHANGED
@@ -366,4 +366,5 @@ with gr.Blocks(
366
  outputs=[audio_out],
367
  )
368
 
 
369
  demo.launch(ssr_mode=False)
 
366
  outputs=[audio_out],
367
  )
368
 
369
+ demo.queue()
370
  demo.launch(ssr_mode=False)
src/services/google.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from contextlib import asynccontextmanager
4
+ from google import genai
5
+ import threading
6
+
7
+ from config import settings
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class ApiKeyPool:
13
+ """Manage Google API keys with round-robin selection."""
14
+
15
+ def __init__(self) -> None:
16
+ self._keys: list[str] | None = None
17
+ self._index = 0
18
+ self._lock = asyncio.Lock()
19
+ self._sync_lock = threading.Lock()
20
+
21
+ def _load_keys(self) -> None:
22
+ keys_raw = (
23
+ getattr(settings, "gemini_api_keys", None) or settings.gemini_api_key
24
+ )
25
+ keys_str = keys_raw.get_secret_value()
26
+ keys = [k.strip() for k in keys_str.split(',') if k.strip()] if keys_str else []
27
+ if not keys:
28
+ msg = "Google API keys are not configured or invalid"
29
+ logger.error(msg)
30
+ raise ValueError(msg)
31
+ self._keys = keys
32
+
33
+ async def get_key(self) -> str:
34
+ async with self._lock:
35
+ if self._keys is None:
36
+ self._load_keys()
37
+ key = self._keys[self._index]
38
+ self._index = (self._index + 1) % len(self._keys)
39
+ logger.debug("Using Google API key index %s", self._index)
40
+ return key
41
+
42
+ def get_key_sync(self) -> str:
43
+ """Synchronous helper for environments without an event loop."""
44
+ with self._sync_lock:
45
+ if self._keys is None:
46
+ self._load_keys()
47
+ key = self._keys[self._index]
48
+ self._index = (self._index + 1) % len(self._keys)
49
+ logger.debug("Using Google API key index %s", self._index)
50
+ return key
51
+
52
+
53
+ class GoogleClientFactory:
54
+ """Factory for thread-safe creation of Google GenAI clients."""
55
+
56
+ _pool = ApiKeyPool()
57
+
58
+ @classmethod
59
+ @asynccontextmanager
60
+ async def image(cls):
61
+ key = await cls._pool.get_key()
62
+ client = genai.Client(api_key=key)
63
+ try:
64
+ yield client.aio
65
+ finally:
66
+ pass
67
+
68
+ @classmethod
69
+ @asynccontextmanager
70
+ async def audio(cls):
71
+ key = await cls._pool.get_key()
72
+ client = genai.Client(api_key=key, http_options={"api_version": "v1alpha"})
73
+ try:
74
+ yield client.aio
75
+ finally:
76
+ pass