kshitijthakkar commited on
Commit
97b162d
·
1 Parent(s): 0b8bed8

feat: Integrate prompt template generation into synthetic data workflow

Browse files

- Modified synthetic data screen to generate both dataset AND prompt template
- Added new state variable for storing prompt template data
- Enhanced UI with tabs to show dataset and prompt template previews
- Support for agent_type='both': generates templates for BOTH tool and code agents
- Updated on_generate_synthetic_data to call generate_prompt_template MCP tool
- Modified on_push_to_hub to include prompt template in dataset card
- Prompt template automatically included when pushing to HuggingFace Hub
- Complete end-to-end workflow: Generate → Review → Push (with template)

Files changed (1) hide show
  1. app.py +126 -26
app.py CHANGED
@@ -2112,8 +2112,9 @@ with gr.Blocks(title="TraceMind-AI", theme=theme) as app:
2112
 
2113
  gr.Markdown("---")
2114
 
2115
- # Store generated dataset in component state
2116
  generated_dataset_state = gr.State(None)
 
2117
 
2118
  # Step 1: Generate Dataset
2119
  with gr.Group():
@@ -2161,14 +2162,29 @@ with gr.Blocks(title="TraceMind-AI", theme=theme) as app:
2161
 
2162
  # Step 2: Review Dataset
2163
  with gr.Group():
2164
- gr.Markdown("### 🔍 Step 2: Review Generated Dataset")
2165
 
2166
- dataset_preview = gr.JSON(
2167
- label="Generated Dataset (Preview)",
2168
- visible=False
2169
- )
 
2170
 
2171
- dataset_stats = gr.Markdown("", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2172
 
2173
  # Step 3: Push to Hub
2174
  with gr.Group():
@@ -3010,15 +3026,17 @@ No historical data available for **{model}**.
3010
 
3011
  # Synthetic Data Generator Callbacks
3012
  def on_generate_synthetic_data(domain, tools, num_tasks, difficulty, agent_type):
3013
- """Generate synthetic dataset using MCP server"""
3014
  try:
3015
  from gradio_client import Client
 
3016
 
3017
  # Connect to MCP server
3018
  client = Client("https://mcp-1st-birthday-tracemind-mcp-server.hf.space/")
3019
 
3020
- # Call the synthetic data generation endpoint
3021
- result = client.predict(
 
3022
  domain=domain,
3023
  tools=tools,
3024
  num_tasks=int(num_tasks),
@@ -3027,15 +3045,82 @@ No historical data available for **{model}**.
3027
  api_name="/run_generate_synthetic"
3028
  )
3029
 
3030
- # Parse the result
3031
- import json
3032
- if isinstance(result, str):
3033
  try:
3034
- dataset = json.loads(result)
3035
  except:
3036
- dataset = {"raw_result": result}
3037
  else:
3038
- dataset = result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3039
 
3040
  # Generate stats
3041
  task_count = len(dataset.get('tasks', [])) if isinstance(dataset.get('tasks'), list) else 0
@@ -3046,27 +3131,29 @@ No historical data available for **{model}**.
3046
  suggested_repo_name = f"{default_username}/smoltrace-{domain_clean}-tasks"
3047
 
3048
  stats_md = f"""
3049
- ### ✅ Dataset Generated Successfully!
3050
 
3051
  - **Total Tasks**: {task_count}
3052
  - **Domain**: {dataset.get('domain', domain)}
3053
  - **Difficulty**: {dataset.get('difficulty', difficulty)}
3054
  - **Agent Type**: {dataset.get('agent_type', agent_type)}
3055
  - **Tools Available**: {len(tools.split(','))}
 
3056
 
3057
- Review the dataset below and push to HuggingFace Hub when ready.
3058
 
3059
  **Suggested repo name**: `{suggested_repo_name}`
3060
 
3061
- 💡 **Tip**: Using environment HF token? Keep the default username.
3062
- Want to push to your own profile? Update repo name to `your-username/smoltrace-{domain_clean}-tasks` and provide your HF token.
3063
  """
3064
 
3065
  return {
3066
  generated_dataset_state: dataset,
 
3067
  dataset_preview: gr.update(value=dataset, visible=True),
3068
  dataset_stats: gr.update(value=stats_md, visible=True),
3069
- generation_status: "✅ Dataset generated successfully! Review below.",
 
3070
  push_btn: gr.update(visible=True),
3071
  repo_name_input: gr.update(value=suggested_repo_name)
3072
  }
@@ -3079,15 +3166,17 @@ No historical data available for **{model}**.
3079
 
3080
  return {
3081
  generated_dataset_state: None,
 
3082
  dataset_preview: gr.update(visible=False),
3083
  dataset_stats: gr.update(visible=False),
 
3084
  generation_status: error_msg,
3085
  push_btn: gr.update(visible=False),
3086
  repo_name_input: gr.update(value="")
3087
  }
3088
 
3089
- def on_push_to_hub(dataset, repo_name, hf_token, private):
3090
- """Push dataset to HuggingFace Hub"""
3091
  try:
3092
  from gradio_client import Client
3093
  import os
@@ -3100,6 +3189,16 @@ No historical data available for **{model}**.
3100
  if not repo_name:
3101
  return "❌ Please provide a repository name."
3102
 
 
 
 
 
 
 
 
 
 
 
3103
  # Determine which HF token to use (user-provided or environment)
3104
  if hf_token and hf_token.strip():
3105
  # User provided a token
@@ -3152,12 +3251,13 @@ No historical data available for **{model}**.
3152
  print(f"[INFO] Private: {private}")
3153
  print(f"[INFO] Passing HF token to MCP server (source: {token_source})")
3154
 
3155
- # Call the push dataset endpoint with the token
3156
  result = client.predict(
3157
  dataset_json=dataset_json,
3158
  repo_name=repo_name,
3159
  hf_token=token_to_use, # Token from user input OR environment
3160
  private=private,
 
3161
  api_name="/run_push_dataset"
3162
  )
3163
 
@@ -3394,12 +3494,12 @@ Result: {result}
3394
  generate_btn.click(
3395
  fn=on_generate_synthetic_data,
3396
  inputs=[domain_input, tools_input, num_tasks_input, difficulty_input, agent_type_input],
3397
- outputs=[generated_dataset_state, dataset_preview, dataset_stats, generation_status, push_btn, repo_name_input]
3398
  )
3399
 
3400
  push_btn.click(
3401
  fn=on_push_to_hub,
3402
- inputs=[generated_dataset_state, repo_name_input, hf_token_input, private_checkbox],
3403
  outputs=[push_status]
3404
  )
3405
 
 
2112
 
2113
  gr.Markdown("---")
2114
 
2115
+ # Store generated dataset and prompt template in component state
2116
  generated_dataset_state = gr.State(None)
2117
+ generated_prompt_template_state = gr.State(None)
2118
 
2119
  # Step 1: Generate Dataset
2120
  with gr.Group():
 
2162
 
2163
  # Step 2: Review Dataset
2164
  with gr.Group():
2165
+ gr.Markdown("### 🔍 Step 2: Review Generated Dataset & Prompt Template")
2166
 
2167
+ with gr.Tab("📊 Dataset Preview"):
2168
+ dataset_preview = gr.JSON(
2169
+ label="Generated Dataset",
2170
+ visible=False
2171
+ )
2172
 
2173
+ dataset_stats = gr.Markdown("", visible=False)
2174
+
2175
+ with gr.Tab("📝 Prompt Template"):
2176
+ gr.Markdown("""
2177
+ **AI-Generated Prompt Template**
2178
+
2179
+ This customized prompt template is based on smolagents templates and adapted for your domain and tools.
2180
+ It will be automatically included in your dataset card when you push to HuggingFace Hub.
2181
+ """)
2182
+
2183
+ prompt_template_preview = gr.Code(
2184
+ label="Customized Prompt Template (YAML)",
2185
+ language="yaml",
2186
+ visible=False
2187
+ )
2188
 
2189
  # Step 3: Push to Hub
2190
  with gr.Group():
 
3026
 
3027
  # Synthetic Data Generator Callbacks
3028
  def on_generate_synthetic_data(domain, tools, num_tasks, difficulty, agent_type):
3029
+ """Generate synthetic dataset AND prompt template using MCP server"""
3030
  try:
3031
  from gradio_client import Client
3032
+ import json
3033
 
3034
  # Connect to MCP server
3035
  client = Client("https://mcp-1st-birthday-tracemind-mcp-server.hf.space/")
3036
 
3037
+ # ===== STEP 1: Generate Dataset =====
3038
+ print(f"[INFO] Generating synthetic dataset for domain: {domain}")
3039
+ dataset_result = client.predict(
3040
  domain=domain,
3041
  tools=tools,
3042
  num_tasks=int(num_tasks),
 
3045
  api_name="/run_generate_synthetic"
3046
  )
3047
 
3048
+ # Parse the dataset result
3049
+ if isinstance(dataset_result, str):
 
3050
  try:
3051
+ dataset = json.loads(dataset_result)
3052
  except:
3053
+ dataset = {"raw_result": dataset_result}
3054
  else:
3055
+ dataset = dataset_result
3056
+
3057
+ # ===== STEP 2: Generate Prompt Template(s) =====
3058
+ # When agent_type="both", generate templates for both tool and code agents
3059
+ agent_types_to_generate = ["tool", "code"] if agent_type == "both" else [agent_type]
3060
+ print(f"[INFO] Generating prompt template(s) for: {agent_types_to_generate}")
3061
+
3062
+ prompt_templates = {}
3063
+ try:
3064
+ for current_agent_type in agent_types_to_generate:
3065
+ print(f"[INFO] Generating {current_agent_type} agent template for domain: {domain}")
3066
+
3067
+ template_result = client.predict(
3068
+ domain=domain,
3069
+ tools=tools,
3070
+ agent_type=current_agent_type,
3071
+ api_name="/run_generate_prompt_template"
3072
+ )
3073
+
3074
+ # Parse the template result
3075
+ if isinstance(template_result, dict):
3076
+ prompt_template_data = template_result
3077
+ elif isinstance(template_result, str):
3078
+ try:
3079
+ prompt_template_data = json.loads(template_result)
3080
+ except:
3081
+ prompt_template_data = {"error": "Failed to parse template response"}
3082
+ else:
3083
+ prompt_template_data = {"error": "Unexpected template response format"}
3084
+
3085
+ # Extract the YAML template
3086
+ if "prompt_template" in prompt_template_data:
3087
+ prompt_templates[current_agent_type] = prompt_template_data["prompt_template"]
3088
+ print(f"[INFO] {current_agent_type} agent template generated successfully")
3089
+ elif "error" in prompt_template_data:
3090
+ prompt_templates[current_agent_type] = f"# Error generating template:\n# {prompt_template_data['error']}"
3091
+ print(f"[WARNING] {current_agent_type} template generation error: {prompt_template_data['error']}")
3092
+ else:
3093
+ prompt_templates[current_agent_type] = "# Template format not recognized"
3094
+ print(f"[WARNING] Unexpected template format for {current_agent_type}")
3095
+
3096
+ # Combine templates for display
3097
+ if agent_type == "both":
3098
+ prompt_template = f"""# ========================================
3099
+ # TOOL AGENT TEMPLATE (ToolCallingAgent)
3100
+ # ========================================
3101
+
3102
+ {prompt_templates.get('tool', '# Failed to generate tool agent template')}
3103
+
3104
+ # ========================================
3105
+ # CODE AGENT TEMPLATE (CodeAgent)
3106
+ # ========================================
3107
+
3108
+ {prompt_templates.get('code', '# Failed to generate code agent template')}
3109
+ """
3110
+ else:
3111
+ prompt_template = prompt_templates.get(agent_type, "# Template not generated")
3112
+
3113
+ # Store all templates in data for push_to_hub
3114
+ prompt_template_data = {
3115
+ "agent_type": agent_type,
3116
+ "templates": prompt_templates,
3117
+ "combined": prompt_template
3118
+ }
3119
+
3120
+ except Exception as template_error:
3121
+ print(f"[WARNING] Failed to generate prompt template: {template_error}")
3122
+ prompt_template = f"# Failed to generate template: {str(template_error)}"
3123
+ prompt_template_data = {"error": str(template_error)}
3124
 
3125
  # Generate stats
3126
  task_count = len(dataset.get('tasks', [])) if isinstance(dataset.get('tasks'), list) else 0
 
3131
  suggested_repo_name = f"{default_username}/smoltrace-{domain_clean}-tasks"
3132
 
3133
  stats_md = f"""
3134
+ ### ✅ Dataset & Prompt Template Generated Successfully!
3135
 
3136
  - **Total Tasks**: {task_count}
3137
  - **Domain**: {dataset.get('domain', domain)}
3138
  - **Difficulty**: {dataset.get('difficulty', difficulty)}
3139
  - **Agent Type**: {dataset.get('agent_type', agent_type)}
3140
  - **Tools Available**: {len(tools.split(','))}
3141
+ - **Prompt Template**: ✅ AI-customized for your domain
3142
 
3143
+ Review both the dataset and prompt template in the tabs above, then push to HuggingFace Hub when ready.
3144
 
3145
  **Suggested repo name**: `{suggested_repo_name}`
3146
 
3147
+ 💡 **Tip**: The prompt template will be automatically included in your dataset card!
 
3148
  """
3149
 
3150
  return {
3151
  generated_dataset_state: dataset,
3152
+ generated_prompt_template_state: prompt_template_data,
3153
  dataset_preview: gr.update(value=dataset, visible=True),
3154
  dataset_stats: gr.update(value=stats_md, visible=True),
3155
+ prompt_template_preview: gr.update(value=prompt_template, visible=True),
3156
+ generation_status: "✅ Dataset & prompt template generated! Review in tabs above.",
3157
  push_btn: gr.update(visible=True),
3158
  repo_name_input: gr.update(value=suggested_repo_name)
3159
  }
 
3166
 
3167
  return {
3168
  generated_dataset_state: None,
3169
+ generated_prompt_template_state: None,
3170
  dataset_preview: gr.update(visible=False),
3171
  dataset_stats: gr.update(visible=False),
3172
+ prompt_template_preview: gr.update(visible=False),
3173
  generation_status: error_msg,
3174
  push_btn: gr.update(visible=False),
3175
  repo_name_input: gr.update(value="")
3176
  }
3177
 
3178
+ def on_push_to_hub(dataset, prompt_template_data, repo_name, hf_token, private):
3179
+ """Push dataset AND prompt template to HuggingFace Hub"""
3180
  try:
3181
  from gradio_client import Client
3182
  import os
 
3189
  if not repo_name:
3190
  return "❌ Please provide a repository name."
3191
 
3192
+ # Extract prompt template for pushing
3193
+ prompt_template_to_push = None
3194
+ if prompt_template_data and isinstance(prompt_template_data, dict):
3195
+ if "combined" in prompt_template_data:
3196
+ prompt_template_to_push = prompt_template_data["combined"]
3197
+ elif "prompt_template" in prompt_template_data:
3198
+ prompt_template_to_push = prompt_template_data["prompt_template"]
3199
+
3200
+ print(f"[INFO] Prompt template will {'be included' if prompt_template_to_push else 'NOT be included'} in dataset card")
3201
+
3202
  # Determine which HF token to use (user-provided or environment)
3203
  if hf_token and hf_token.strip():
3204
  # User provided a token
 
3251
  print(f"[INFO] Private: {private}")
3252
  print(f"[INFO] Passing HF token to MCP server (source: {token_source})")
3253
 
3254
+ # Call the push dataset endpoint with the token and prompt template
3255
  result = client.predict(
3256
  dataset_json=dataset_json,
3257
  repo_name=repo_name,
3258
  hf_token=token_to_use, # Token from user input OR environment
3259
  private=private,
3260
+ prompt_template=prompt_template_to_push if prompt_template_to_push else "", # Include template if available
3261
  api_name="/run_push_dataset"
3262
  )
3263
 
 
3494
  generate_btn.click(
3495
  fn=on_generate_synthetic_data,
3496
  inputs=[domain_input, tools_input, num_tasks_input, difficulty_input, agent_type_input],
3497
+ outputs=[generated_dataset_state, generated_prompt_template_state, dataset_preview, dataset_stats, prompt_template_preview, generation_status, push_btn, repo_name_input]
3498
  )
3499
 
3500
  push_btn.click(
3501
  fn=on_push_to_hub,
3502
+ inputs=[generated_dataset_state, generated_prompt_template_state, repo_name_input, hf_token_input, private_checkbox],
3503
  outputs=[push_status]
3504
  )
3505