# Uploaded by HayatoHongo — "Create app.py" (commit 6968eda, verified)
# app.py
# -----------------------------------------
# NanoChat (ONNX) を Streamlit チャット UI に移植
# 変更条件:
# - デバイスは CPU(ONNX Runtime の CPUExecutionProvider)
# - 量子化は行わない(そのままの ONNX を読み込む)
#
# 依存:
# pip install streamlit transformers optimum[onnxruntime]
# 実行:
# streamlit run app.py
# -----------------------------------------
import threading
from typing import List, Dict
import streamlit as st
from transformers import AutoTokenizer, TextIteratorStreamer
from optimum.onnxruntime import ORTModelForCausalLM
# Hugging Face repo id of the ONNX-exported nanochat checkpoint to load.
MODEL_ID = "onnx-community/nanochat-d32-ONNX"
# -----------------------------
# 初期化(モデルは一度だけロード)
# -----------------------------
@st.cache_resource(show_spinner=True)
def load_model_and_tokenizer():
    """Load the ONNX model and tokenizer once per Streamlit process.

    The model is pinned to the CPU execution provider and loaded without
    quantization; IO binding is disabled to keep the standard (easier to
    debug) input/output path.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        provider="CPUExecutionProvider",  # explicit CPU, no quantization
        use_io_binding=False,             # standard I/O path
    )
    return model, tokenizer
# -----------------------------
# セッションステート初期化
# -----------------------------
def init_session():
    """Seed st.session_state with first-run defaults (idempotent)."""
    defaults = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."}
        ],
        "temperature": 0.05,
        "top_k": 5,
        "loading_done": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
init_session()
# -----------------------------
# レイアウト(サイドバー)
# -----------------------------
st.set_page_config(page_title="NanoChat (ONNX, CPU)", page_icon="⚡", layout="centered")
with st.sidebar:
    st.markdown("## ⚡ nanochat (ONNX, CPU)")
    st.caption("デバイス: **CPU**, 量子化: **なし**")
    st.divider()
    st.markdown("#### Generation Settings")
    # Sliders write straight back into session_state so the sidebar and
    # the /temperature - /topk slash commands stay in sync.
    st.session_state.temperature = st.slider(
        "Temperature", 0.0, 2.0, float(st.session_state.temperature), 0.01
    )
    st.session_state.top_k = st.slider(
        "Top-k", 1, 200, int(st.session_state.top_k), 1
    )
    if st.button("🆕 New Conversation", use_container_width=True):
        st.session_state.messages = [
            {"role": "system", "content": "You are a helpful assistant."}
        ]
        # st.experimental_rerun() was deprecated and removed in modern
        # Streamlit; st.rerun() is the supported replacement and exists in
        # every release that also provides st.chat_input/st.chat_message,
        # which this app already requires.
        st.rerun()
# -----------------------------
# モデル読み込み(スピナー表示)
# -----------------------------
# Load model + tokenizer (cached across reruns); spinner shows on first run.
with st.spinner("Loading model on CPU... (no quantization)"):
    model, tokenizer = load_model_and_tokenizer()
    # Flag consulted by st.chat_input's `disabled=` below.
    st.session_state.loading_done = True
# -----------------------------
# メイン: チャット履歴の描画
# -----------------------------
st.markdown("### nanochat web (Streamlit)")
st.caption("Enter を押すと送信、Shift+Enter で改行")
# Replay every visible turn; the system prompt stays in history but is
# never rendered.
for entry in st.session_state.messages:
    role = entry["role"]
    if role == "system":
        continue
    display_role = "assistant" if role == "assistant" else "user"
    with st.chat_message(display_role):
        st.write(entry["content"])
# -----------------------------
# ユーティリティ
# -----------------------------
def handle_slash_command(text: str) -> bool:
    """Handle the /temperature, /topk, /clear and /help chat commands.

    Any textual response is appended to the chat history as an assistant
    message. Returns True when *text* was a recognized command, False
    otherwise (the caller then reports an unknown command).
    """
    parts = text.strip().split()
    # Guard: whitespace-only input would make parts[0] raise IndexError.
    if not parts:
        return False
    cmd = parts[0].lower()
    def console_reply(content: str):
        # Command feedback is shown as a regular assistant turn.
        st.session_state.messages.append({"role": "assistant", "content": content})
    if cmd == "/temperature":
        if len(parts) == 1:
            console_reply(f"Current temperature: {st.session_state.temperature}")
        else:
            try:
                val = float(parts[1])
                if val < 0 or val > 2:
                    console_reply("Invalid temperature. Must be between 0.0 and 2.0")
                else:
                    st.session_state.temperature = val
                    console_reply(f"Temperature set to {val}")
            except ValueError:
                console_reply("Invalid value. Usage: /temperature <0.0-2.0>")
        return True
    if cmd == "/topk":
        if len(parts) == 1:
            console_reply(f"Current top-k: {st.session_state.top_k}")
        else:
            try:
                val = int(parts[1])
                if val < 1 or val > 200:
                    console_reply("Invalid top-k. Must be between 1 and 200")
                else:
                    st.session_state.top_k = val
                    console_reply(f"Top-k set to {val}")
            except ValueError:
                console_reply("Invalid value. Usage: /topk <1-200>")
        return True
    if cmd == "/clear":
        # Reset history to just the system prompt (no feedback message).
        st.session_state.messages = [
            {"role": "system", "content": "You are a helpful assistant."}
        ]
        return True
    if cmd == "/help":
        console_reply(
            "Available commands:\n"
            "/temperature - Show current temperature\n"
            "/temperature <value> - Set temperature (0.0-2.0)\n"
            "/topk - Show current top-k\n"
            "/topk <value> - Set top-k (1-200)\n"
            "/clear - Clear conversation\n"
            "/help - Show this help message"
        )
        return True
    return False
def build_chat_prompt(messages: List[Dict[str, str]]) -> Dict[str, "torch.Tensor"]:
    """Tokenize *messages* into model inputs (a dict of "pt" tensors).

    Uses the tokenizer's chat template when one is configured; otherwise
    falls back to a simple <|role|> tag format.
    """
    # NOTE: hasattr(tokenizer, "apply_chat_template") is always True on
    # transformers tokenizers, so it never selected the fallback and the
    # call raised when no template was configured. Test the chat_template
    # attribute itself instead.
    if getattr(tokenizer, "chat_template", None):
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = tokenizer(text, return_tensors="pt")
    else:
        # Simple fallback prompt format.
        prompt = ""
        for m in messages:
            if m["role"] == "system":
                prompt += f"<|system|>\n{m['content']}\n"
            elif m["role"] == "user":
                prompt += f"<|user|>\n{m['content']}\n"
            elif m["role"] == "assistant":
                prompt += f"<|assistant|>\n{m['content']}\n"
        # Open an assistant turn for the model to complete.
        prompt += "<|assistant|>\n"
        inputs = tokenizer(prompt, return_tensors="pt")
    return inputs
def generate_stream(messages: List[Dict[str, str]]):
    """Yield generated text chunks for *messages* as they are produced.

    model.generate() runs in a background thread while the
    TextIteratorStreamer is consumed on the caller's thread. Exceptions
    raised inside the thread are captured and re-raised here — the
    original version swallowed them, which silently truncated replies.
    """
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        skip_prompt=True
    )
    do_sample = st.session_state.temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=512,
        do_sample=do_sample,
        repetition_penalty=1.2,
        streamer=streamer
    )
    # Only pass sampling knobs when sampling: temperature=0.0 together
    # with do_sample=False makes transformers emit warnings (and newer
    # versions reject temperature == 0 outright).
    if do_sample:
        gen_kwargs["temperature"] = float(st.session_state.temperature)
        gen_kwargs["top_k"] = int(st.session_state.top_k)
    inputs = build_chat_prompt(messages)
    errors = []
    def _thread_target():
        try:
            model.generate(**inputs, **gen_kwargs)
        except Exception as exc:  # boundary: forward to the main thread
            errors.append(exc)
            streamer.end()  # unblock the consuming for-loop
    thread = threading.Thread(target=_thread_target, daemon=True)
    thread.start()
    for token in streamer:
        yield token
    thread.join()
    if errors:
        raise errors[0]
# -----------------------------
# 入力欄(送信 → 生成)
# -----------------------------
# Chat input stays disabled until the model has finished loading.
user_input = st.chat_input("Ask anything", disabled=not st.session_state.loading_done)
if user_input:
    stripped = user_input.strip()
    # Slash commands never reach the model.
    if stripped.startswith("/"):
        if handle_slash_command(stripped):
            # st.experimental_rerun() was removed in modern Streamlit;
            # st.rerun() is the supported replacement.
            st.rerun()
        else:
            st.session_state.messages.append(
                {"role": "assistant", "content": f"Unknown command: {stripped}"}
            )
            st.rerun()
    else:
        # Record and echo the user turn.
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.write(user_input)
        # Stream the assistant reply into a placeholder.
        with st.chat_message("assistant"):
            placeholder = st.empty()
            full = ""
            try:
                for chunk in generate_stream(st.session_state.messages):
                    full += chunk
                    placeholder.write(full)
            except Exception as e:
                st.error(f"Error: {e}")
                full = ""
            finally:
                # Always append a turn (empty string on error) so the
                # history keeps alternating user/assistant roles.
                st.session_state.messages.append({"role": "assistant", "content": full})
                st.rerun()