|
|
import gradio as gr |
|
|
import os |
|
|
import random |
|
|
import uuid |
|
|
import csv |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
from huggingface_hub import CommitScheduler, snapshot_download |
|
|
|
|
|
|
|
|
# Hugging Face dataset repository used both as the image source and as the
# destination for uploaded logs.
DATASET_REPO_ID = "Emilyxml/moveit"

# Local directory the dataset files are downloaded into and read from.
DATA_FOLDER = "data"

# Access token taken from the environment; None when HF_TOKEN is not set.
TOKEN = os.getenv("HF_TOKEN")

# Local directory holding the per-user CSV logs; created at import time if
# it does not exist yet.
LOG_FOLDER = Path("logs")
LOG_FOLDER.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
|
|
# Download the dataset only when the local data folder is missing or empty.
# A failed download is reported on stdout and otherwise ignored, so the app
# still starts.
if not (os.path.exists(DATA_FOLDER) and os.listdir(DATA_FOLDER)):
    try:
        print("🚀 正在从 Dataset 下载数据...")
        snapshot_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            token=TOKEN,
            local_dir=DATA_FOLDER,
            # Only images and instruction text files are fetched.
            allow_patterns=["*.jpg", "*.png", "*.jpeg", "*.webp", "*.txt"],
        )
        print("✅ 数据下载完成!")
    except Exception as err:
        print(f"⚠️ 下载失败: {err}")
|
|
|
|
|
|
|
|
# Scheduler that uploads the contents of LOG_FOLDER to the "logs" directory
# of the dataset repository. Its lock is taken while log files are written.
scheduler = CommitScheduler(
    repo_id=DATASET_REPO_ID,
    repo_type="dataset",
    token=TOKEN,
    folder_path=LOG_FOLDER,
    path_in_repo="logs",
    every=1,  # interval between commits (minutes per huggingface_hub docs)
)
|
|
|
|
|
|
|
|
def load_data(data_folder=None):
    """Scan the data folder and group its files by 5-character name prefix.

    Every regular, non-hidden file is assigned to the group named by the
    first five characters of its file name. Within a group, an image whose
    name contains ``_origin`` becomes the reference image, every other image
    is a candidate, and a ``.txt`` file supplies the instruction text.
    Groups without any image are dropped.

    Args:
        data_folder: Directory to scan. When omitted, the module-level
            ``DATA_FOLDER`` is used.

    Returns:
        A ``(groups, group_ids)`` tuple. ``groups`` maps each prefix to a
        dict with the keys ``"origin"`` (path or ``None``), ``"candidates"``
        (list of paths) and ``"instruction"`` (str). ``group_ids`` holds the
        prefixes in random order. ``({}, [])`` is returned when the folder
        does not exist.
    """
    if data_folder is None:
        data_folder = DATA_FOLDER
    if not os.path.isdir(data_folder):
        return {}, []

    image_exts = ('.png', '.jpg', '.jpeg', '.webp')
    groups = {}

    for filename in os.listdir(data_folder):
        if filename.startswith('.'):
            continue
        file_path = os.path.join(data_folder, filename)
        # Sub-directories cannot be opened as images or text; skip them.
        if not os.path.isfile(file_path):
            continue

        group = groups.setdefault(
            filename[:5],
            {"origin": None, "candidates": [], "instruction": "暂无说明"},
        )
        lowered = filename.lower()

        if lowered.endswith(image_exts):
            if "_origin" in lowered:
                group["origin"] = file_path
            else:
                group["candidates"].append(file_path)
        elif lowered.endswith('.txt'):
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    text = f.read()
            except UnicodeDecodeError:
                # Not valid UTF-8: retry as GBK, replacing undecodable bytes
                # so one odd file cannot abort the whole scan.
                with open(file_path, "r", encoding="gbk", errors="replace") as f:
                    text = f.read()
            group["instruction"] = text

    # Keep only the groups that have at least one image to show.
    valid_groups = {
        prefix: group
        for prefix, group in groups.items()
        if group["origin"] is not None or group["candidates"]
    }

    group_ids = list(valid_groups)
    random.shuffle(group_ids)
    print(f"Loaded {len(group_ids)} groups.")
    return valid_groups, group_ids
|
|
|
|
|
# Load every group once at import time; ALL_GROUP_IDS keeps the order
# returned by load_data() for the lifetime of the process.
(ALL_GROUPS, ALL_GROUP_IDS) = load_data()
|
|
|
|
|
|
|
|
def optimize_image(image_path, max_width=800):
    """Load an image and shrink it to at most ``max_width`` pixels wide.

    The aspect ratio is preserved. Images that are already narrow enough
    are returned at their original size. Smaller images cut transfer time.

    Args:
        image_path: Path of the image file. A falsy value yields ``None``.
        max_width: Maximum width, in pixels, of the returned image.

    Returns:
        The decoded (and possibly resized) PIL image, or ``None`` when no
        path was given or the file could not be loaded.
    """
    if not image_path:
        return None
    try:
        img = Image.open(image_path)
        # Image.open only reads the header; decode now so that a corrupt or
        # truncated file fails inside this try block instead of later.
        img.load()

        if img.width > max_width:
            ratio = max_width / img.width
            # Clamp to 1 so an extremely wide, flat image never asks for a
            # zero-height resize.
            new_height = max(1, int(img.height * ratio))
            img = img.resize((max_width, new_height), Image.LANCZOS)
        return img
    except Exception as e:
        # Best effort: a broken image must not take the whole page down.
        print(f"Error loading image {image_path}: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
def get_next_question(user_state):
    """Build the UI updates for the question at ``user_state["index"]``.

    Args:
        user_state: Per-user dict; its ``"index"`` entry selects the group
            in ``ALL_GROUP_IDS`` to show.

    Returns:
        A 9-tuple in the order of the event outputs: reference image,
        candidate gallery, checkbox group, instruction markdown, submit
        button, "none" button, end-of-study markdown, the user state, and a
        list of ``{"label", "path"}`` dicts for the displayed candidates.
        Once the index is past the last group, all question widgets are
        hidden and only the end-of-study message is shown.
    """
    idx = user_state["index"]
    total = len(ALL_GROUP_IDS)

    if idx >= total:
        # Six question widgets are hidden; each gets its own update object.
        hidden = [gr.update(visible=False) for _ in range(6)]
        return (
            *hidden,
            gr.update(value="## 🎉 测试结束!感谢您的参与。", visible=True),
            user_state,
            [],
        )

    group_data = ALL_GROUPS[ALL_GROUP_IDS[idx]]
    origin_img = optimize_image(group_data["origin"], max_width=600)

    # Shuffle a copy so the stored candidate order is left untouched and the
    # option letters do not reveal which method produced which image.
    candidates = group_data["candidates"].copy()
    random.shuffle(candidates)

    gallery_items = []
    choices = []
    candidates_info = []

    for path in candidates:
        optimized_img = optimize_image(path, max_width=600)
        if optimized_img is None:
            # The image could not be loaded: it cannot be displayed, and
            # offering its label would let users pick an option they never
            # saw. Skip it and keep the letters contiguous.
            continue
        label = f"Option {chr(65 + len(choices))}"
        gallery_items.append((optimized_img, label))
        choices.append(label)
        candidates_info.append({"label": label, "path": path})

    instruction = f"### 任务 ({idx + 1} / {total})\n\n{group_data['instruction']}"

    return (
        gr.update(value=origin_img, visible=origin_img is not None),
        gr.update(value=gallery_items, visible=True),
        gr.update(choices=choices, value=[], visible=True),
        gr.update(value=instruction, visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
        user_state,
        candidates_info,
    )
|
|
|
|
|
def save_and_next(user_state, candidates_info, selected_options, is_none=False):
    """Record the answer to the current question and move to the next one.

    One row is appended to ``user_<user_id>.csv`` inside ``LOG_FOLDER``
    (a header row is written first when the file is new), then the user's
    index is incremented.

    Args:
        user_state: Per-user dict with the keys ``"user_id"`` and ``"index"``.
        candidates_info: List of ``{"label", "path"}`` dicts describing the
            options currently displayed.
        selected_options: Labels ticked by the user.
        is_none: When true, the answer is logged as a rejection of all
            candidates and ``selected_options`` is ignored.

    Returns:
        The tuple produced by ``get_next_question`` for the updated state.

    Raises:
        gr.Error: If ``is_none`` is false and no option was selected.
    """
    current_idx = user_state["index"]
    if current_idx >= len(ALL_GROUP_IDS):
        # Stale or repeated click after the last question: there is no group
        # left to record, so just re-render the end screen.
        return get_next_question(user_state)
    group_id = ALL_GROUP_IDS[current_idx]

    if is_none:
        choice_str = "Rejected All"
        method_str = "None_Satisfied"
    else:
        if not selected_options:
            raise gr.Error("请至少勾选一个选项,或点击“都不满意”")

        choice_str = "; ".join(selected_options)

        path_by_label = {info["label"]: info["path"] for info in candidates_info}
        selected_methods = []
        for opt in selected_options:
            if opt not in path_by_label:
                continue
            # The method name is whatever follows the first underscore of
            # the file stem; without an underscore the whole stem is used.
            stem = os.path.splitext(os.path.basename(path_by_label[opt]))[0]
            _, sep, rest = stem.partition('_')
            selected_methods.append(rest if sep else stem)
        method_str = "; ".join(selected_methods)

    user_file = LOG_FOLDER / f"user_{user_state['user_id']}.csv"
    # Hold the scheduler's lock while writing, as the scheduler uploads this
    # folder.
    with scheduler.lock:
        exists = user_file.exists()
        with open(user_file, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if not exists:
                writer.writerow(["user_id", "timestamp", "group_id", "choices", "methods"])
            writer.writerow([
                user_state["user_id"],
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                group_id,
                choice_str,
                method_str,
            ])

    user_state["index"] += 1
    return get_next_question(user_state)
|
|
|
|
|
|
|
|
# Page layout and event wiring for the user study.
with gr.Blocks(title="User Study") as demo:
    # State values: the user dict is built by a callable that yields a random
    # 8-character id and a question index of 0; the second state holds the
    # label/path list of the candidates currently on screen.
    state_user = gr.State(lambda: {"user_id": str(uuid.uuid4())[:8], "index": 0})
    state_candidates_info = gr.State([])

    # Instruction text for the current question.
    with gr.Row():
        md_instruction = gr.Markdown("Loading...")

    # Reference image on the left, candidate gallery on the right.
    with gr.Row():
        with gr.Column(scale=1):
            img_origin = gr.Image(
                label="Reference (参考原图)",
                interactive=False,
                height=400,
                format="jpeg",
            )
        with gr.Column(scale=2):
            gallery_candidates = gr.Gallery(
                label="Candidates (候选结果)",
                columns=[2],
                height="auto",
                object_fit="contain",
                interactive=False,
                format="jpeg",
            )

    gr.Markdown("👇 **请在下方勾选您认为最好的结果(可多选):**")

    checkbox_options = gr.CheckboxGroup(
        choices=[],
        label="您的选择",
        info="对应上方图片的标签 (Option A, B...)",
    )

    with gr.Row():
        btn_submit = gr.Button("🚀 提交 (Submit)", variant="primary")
        btn_none = gr.Button("🚫 都不满意 (None)", variant="stop")

    # Shown only after the last question.
    md_end = gr.Markdown(visible=False)

    # All three events refresh the same widgets, in the order of the tuple
    # returned by get_next_question.
    question_outputs = [
        img_origin,
        gallery_candidates,
        checkbox_options,
        md_instruction,
        btn_submit,
        btn_none,
        md_end,
        state_user,
        state_candidates_info,
    ]
    answer_inputs = [state_user, state_candidates_info, checkbox_options]

    # First question on page load.
    demo.load(fn=get_next_question, inputs=[state_user], outputs=question_outputs)

    # Submit records the ticked options; "None" records a rejection of all.
    btn_submit.click(
        fn=lambda state, info, picked: save_and_next(state, info, picked, is_none=False),
        inputs=answer_inputs,
        outputs=question_outputs,
    )
    btn_none.click(
        fn=lambda state, info, picked: save_and_next(state, info, picked, is_none=True),
        inputs=answer_inputs,
        outputs=question_outputs,
    )


if __name__ == "__main__":
    demo.launch()