# moveit / app.py
# Hugging Face Space by Emilyxml; latest commit cb8b5e8 ("Update app.py", verified).
import gradio as gr
import os
import random
import uuid
import csv
from datetime import datetime
from pathlib import Path
from PIL import Image # 引入 PIL 用于处理图片
from huggingface_hub import CommitScheduler, snapshot_download
# --- 1. Configuration ---
# Dataset repo on the Hugging Face Hub: the study data is downloaded from it
# and the local log folder is synced back to it.
DATASET_REPO_ID = "Emilyxml/moveit"
# Local directory that holds the downloaded images and instruction files.
DATA_FOLDER = "data"
# Local directory for the per-user CSV answer logs; created at import time.
LOG_FOLDER = Path("logs")
LOG_FOLDER.mkdir(exist_ok=True, parents=True)
# Hub access token taken from the environment (None when HF_TOKEN is unset).
TOKEN = os.getenv("HF_TOKEN")
# --- 2. Automatic data download ---
# Fetch the dataset only when the local data folder is missing or empty.
if not (os.path.exists(DATA_FOLDER) and os.listdir(DATA_FOLDER)):
    try:
        print("🚀 正在从 Dataset 下载数据...")
        snapshot_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            token=TOKEN,
            local_dir=DATA_FOLDER,
            # Restrict the download to images and instruction text files.
            allow_patterns=["*.jpg", "*.png", "*.jpeg", "*.webp", "*.txt"],
        )
        print("✅ 数据下载完成!")
    except Exception as err:
        # Best effort: report the failure and let the app continue starting.
        print(f"⚠️ 下载失败: {err}")
# --- 3. Start the background sync scheduler ---
# Periodically commits the contents of LOG_FOLDER to the "logs" directory of
# the dataset repo. Per the huggingface_hub docs, `every` is the interval
# between commits in minutes.
scheduler = CommitScheduler(
    repo_id=DATASET_REPO_ID,
    repo_type="dataset",
    token=TOKEN,
    folder_path=LOG_FOLDER,
    path_in_repo="logs",
    every=1,
)
# --- 4. Data loading ---
def load_data(data_folder=None):
    """Scan the data folder and group its files into evaluation tasks.

    Files are grouped by the first five characters of their file name. Within
    a group, an image whose name contains "_origin" is the reference image,
    every other image is a candidate, and a ``.txt`` file provides the
    instruction text (read as UTF-8, falling back to GBK).

    Args:
        data_folder: Directory to scan. Defaults to the module-level
            ``DATA_FOLDER`` when omitted.

    Returns:
        A ``(valid_groups, group_ids)`` tuple. ``valid_groups`` maps a group
        prefix to ``{"origin", "candidates", "instruction"}`` and only contains
        groups with at least one image; ``group_ids`` lists the same keys in
        random order. Returns ``({}, [])`` when the folder does not exist.
    """
    if data_folder is None:
        data_folder = DATA_FOLDER
    if not os.path.exists(data_folder):
        return {}, []

    image_suffixes = ('.png', '.jpg', '.jpeg', '.webp')
    groups = {}
    for filename in os.listdir(data_folder):
        if filename.startswith('.'):
            continue
        file_path = os.path.join(data_folder, filename)
        # Sub-directories cannot be images or instructions; skipping them also
        # avoids trying to open a directory whose name ends in ".txt".
        if not os.path.isfile(file_path):
            continue

        group = groups.setdefault(
            filename[:5],
            {"origin": None, "candidates": [], "instruction": "暂无说明"},
        )
        lowered = filename.lower()
        if lowered.endswith(image_suffixes):
            if "_origin" in lowered:
                group["origin"] = file_path
            else:
                group["candidates"].append(file_path)
        elif lowered.endswith('.txt'):
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    group["instruction"] = f.read()
            except UnicodeDecodeError:
                # Only a decoding failure justifies retrying with GBK; any
                # other error (permissions, I/O) should surface unchanged.
                with open(file_path, "r", encoding="gbk") as f:
                    group["instruction"] = f.read()

    # Keep only groups that have something to show.
    valid_groups = {
        key: group
        for key, group in groups.items()
        if group["origin"] is not None or group["candidates"]
    }
    group_ids = list(valid_groups)
    random.shuffle(group_ids)
    print(f"Loaded {len(group_ids)} groups.")
    return valid_groups, group_ids
# Load all groups once at import time; the question order in ALL_GROUP_IDS is
# module-level and therefore the same for every session of this process.
ALL_GROUPS, ALL_GROUP_IDS = load_data()
# --- NEW: image optimisation helper (key to faster loading) ---
def optimize_image(image_path, max_width=800):
    """Open an image and downscale it to reduce transfer time.

    Args:
        image_path: Path of the image file. A falsy value yields ``None``.
        max_width: Maximum width in pixels. Wider images are resized to this
            width with the aspect ratio preserved; narrower images are
            returned at their original size.

    Returns:
        A PIL image, or ``None`` when no path was given or the file could not
        be opened or processed (the error is printed).
    """
    if not image_path:
        return None
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # guarantees the handle is released on every path, including errors.
        with Image.open(image_path) as img:
            if img.width > max_width:
                ratio = max_width / img.width
                # Clamp so an extremely wide image never gets a zero height.
                new_height = max(1, int(img.height * ratio))
                return img.resize((max_width, new_height), Image.LANCZOS)
            # Read the pixel data now so the image stays usable after the
            # underlying file has been closed.
            img.load()
            return img
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None
# --- 5. Core logic ---
def get_next_question(user_state):
    """Build the UI updates for the question at the user's current index.

    Args:
        user_state: Per-session dict; its "index" entry is the position in
            ALL_GROUP_IDS of the question to show.

    Returns:
        A 9-tuple in the order of the event outputs: updates for the reference
        image, candidate gallery, checkbox group, instruction markdown, submit
        button, "none" button and end message, followed by the user state and
        a list of ``{"label", "path"}`` dicts describing the displayed
        candidates. When every question has been answered, all widgets are
        hidden except the end message and the candidate list is empty.
    """
    idx = user_state["index"]
    total = len(ALL_GROUP_IDS)
    if idx >= total:
        return (
            *(gr.update(visible=False) for _ in range(6)),
            gr.update(value="## 🎉 测试结束!感谢您的参与。", visible=True),
            user_state,
            [],
        )

    group_data = ALL_GROUPS[ALL_GROUP_IDS[idx]]

    # 1. Downscale the reference image (a PIL object is returned, not a path).
    origin_img = optimize_image(group_data["origin"], max_width=600)

    # 2. Downscale the candidates, shown in random order.
    candidates = group_data["candidates"].copy()
    random.shuffle(candidates)

    # optimize_image returns None for unreadable files. Drop those before
    # assigning labels so every "Option X" corresponds to a displayed image.
    loaded = []
    for path in candidates:
        image = optimize_image(path, max_width=600)
        if image is not None:
            loaded.append((path, image))

    gallery_items = []
    choices = []
    candidates_info = []
    for i, (path, image) in enumerate(loaded):
        label = f"Option {chr(65+i)}"
        gallery_items.append((image, label))
        choices.append(label)
        candidates_info.append({"label": label, "path": path})

    instruction = f"### 任务 ({idx + 1} / {total})\n\n{group_data['instruction']}"
    return (
        gr.update(value=origin_img, visible=origin_img is not None),
        gr.update(value=gallery_items, visible=True),
        gr.update(choices=choices, value=[], visible=True),
        gr.update(value=instruction, visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
        user_state,
        candidates_info,
    )
def save_and_next(user_state, candidates_info, selected_options, is_none=False):
    """Record the answer to the current question and advance to the next one.

    Args:
        user_state: Per-session dict with "user_id" and "index"; "index" is
            incremented in place after the answer is written.
        candidates_info: List of ``{"label", "path"}`` dicts for the candidates
            currently displayed.
        selected_options: Labels ticked by the user.
        is_none: True when the user rejected all candidates; the selection is
            then ignored.

    Returns:
        The tuple produced by ``get_next_question`` for the following question.

    Raises:
        gr.Error: If ``is_none`` is False and nothing was selected.
    """
    current_idx = user_state["index"]
    if current_idx >= len(ALL_GROUP_IDS):
        # Nothing left to answer (e.g. a repeated click after the last
        # question); show the end screen instead of raising IndexError.
        return get_next_question(user_state)
    group_id = ALL_GROUP_IDS[current_idx]

    if is_none:
        choice_str = "Rejected All"
        method_str = "None_Satisfied"
    else:
        if not selected_options:
            raise gr.Error("请至少勾选一个选项,或点击“都不满意”")
        choice_str = "; ".join(selected_options)

        # First entry wins for a given label, as in a first-match search.
        path_by_label = {}
        for info in candidates_info:
            path_by_label.setdefault(info["label"], info["path"])

        selected_methods = []
        for opt in selected_options:
            if opt not in path_by_label:
                continue
            # The method name is the part of the file stem after the first
            # underscore; a stem without an underscore is used as-is.
            stem = os.path.splitext(os.path.basename(path_by_label[opt]))[0]
            _, sep, method = stem.partition('_')
            selected_methods.append(method if sep else stem)
        method_str = "; ".join(selected_methods)

    user_file = LOG_FOLDER / f"user_{user_state['user_id']}.csv"
    # Hold the scheduler lock while writing so a background commit does not
    # run in the middle of the append.
    with scheduler.lock:
        exists = user_file.exists()
        with open(user_file, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if not exists:
                # New file: write the header row first.
                writer.writerow(["user_id", "timestamp", "group_id", "choices", "methods"])
            writer.writerow([
                user_state["user_id"],
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                group_id,
                choice_str,
                method_str,
            ])

    user_state["index"] += 1
    return get_next_question(user_state)
# --- 6. UI construction ---
# NOTE(review): the container nesting below was reconstructed from a source
# whose indentation had been lost — confirm the layout against the deployed app.
with gr.Blocks(title="User Study") as demo:
    # Per-session state: a short random user id and the current question index.
    state_user = gr.State(lambda: {"user_id": str(uuid.uuid4())[:8], "index": 0})
    # Label/path pairs of the candidates currently on screen.
    state_candidates_info = gr.State([])

    with gr.Row():
        md_instruction = gr.Markdown("Loading...")

    with gr.Row():
        with gr.Column(scale=1):
            # format is set to jpeg to further reduce the payload size.
            img_origin = gr.Image(
                label="Reference (参考原图)",
                interactive=False,
                height=400,
                format="jpeg",
            )
        with gr.Column(scale=2):
            gallery_candidates = gr.Gallery(
                label="Candidates (候选结果)",
                columns=[2],
                height="auto",
                object_fit="contain",
                interactive=False,
                format="jpeg",  # force JPEG output
            )

    gr.Markdown("👇 **请在下方勾选您认为最好的结果(可多选):**")
    checkbox_options = gr.CheckboxGroup(
        choices=[],
        label="您的选择",
        info="对应上方图片的标签 (Option A, B...)",
    )

    with gr.Row():
        btn_submit = gr.Button("🚀 提交 (Submit)", variant="primary")
        btn_none = gr.Button("🚫 都不满意 (None)", variant="stop")

    md_end = gr.Markdown(visible=False)

    # Every event refreshes the same widgets, in the order of the tuple
    # returned by get_next_question.
    event_outputs = [
        img_origin,
        gallery_candidates,
        checkbox_options,
        md_instruction,
        btn_submit,
        btn_none,
        md_end,
        state_user,
        state_candidates_info,
    ]
    answer_inputs = [state_user, state_candidates_info, checkbox_options]

    # Show the first question when the page loads.
    demo.load(fn=get_next_question, inputs=[state_user], outputs=event_outputs)

    # "Submit" records the ticked options; "None" records a rejection of all.
    btn_submit.click(
        fn=lambda state, info, picked: save_and_next(state, info, picked, is_none=False),
        inputs=answer_inputs,
        outputs=event_outputs,
    )
    btn_none.click(
        fn=lambda state, info, picked: save_and_next(state, info, picked, is_none=True),
        inputs=answer_inputs,
        outputs=event_outputs,
    )

if __name__ == "__main__":
    demo.launch()