# app.py
# -----------------------------------------
# NanoChat (ONNX) ported to a Streamlit chat UI.
# Constraints:
#   - Device: CPU (ONNX Runtime CPUExecutionProvider)
#   - No quantization (the ONNX model is loaded as-is)
#
# Dependencies:
#   pip install streamlit transformers optimum[onnxruntime]
# Run:
#   streamlit run app.py
# -----------------------------------------
import threading
from typing import Dict, List

import streamlit as st
from transformers import AutoTokenizer, TextIteratorStreamer
from optimum.onnxruntime import ORTModelForCausalLM

MODEL_ID = "onnx-community/nanochat-d32-ONNX"


def _fresh_history() -> List[Dict[str, str]]:
    """Return a brand-new conversation history (fresh dicts, no aliasing)."""
    return [{"role": "system", "content": "You are a helpful assistant."}]


def _rerun() -> None:
    """Trigger a script rerun, working across Streamlit versions.

    ``st.experimental_rerun`` was deprecated in Streamlit 1.27 and removed
    in 1.37; prefer ``st.rerun`` when it exists and fall back otherwise.
    """
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()


# -----------------------------
# Initialization (model is loaded only once via the resource cache)
# -----------------------------
@st.cache_resource(show_spinner=True)
def load_model_and_tokenizer():
    """Load the ONNX causal LM and its tokenizer on CPU, without quantization."""
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        provider="CPUExecutionProvider",  # explicitly use the CPU
        use_io_binding=False,             # standard I/O path (easier to debug)
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    return model, tokenizer


# -----------------------------
# Session-state initialization
# -----------------------------
def init_session() -> None:
    """Populate st.session_state with defaults on the first run only."""
    if "messages" not in st.session_state:
        st.session_state.messages = _fresh_history()
    if "temperature" not in st.session_state:
        st.session_state.temperature = 0.05
    if "top_k" not in st.session_state:
        st.session_state.top_k = 5
    if "loading_done" not in st.session_state:
        st.session_state.loading_done = False


init_session()

# -----------------------------
# Layout (sidebar)
# -----------------------------
st.set_page_config(page_title="NanoChat (ONNX, CPU)", page_icon="⚡", layout="centered")

with st.sidebar:
    st.markdown("## ⚡ nanochat (ONNX, CPU)")
    st.caption("デバイス: **CPU**, 量子化: **なし**")
    st.divider()
    st.markdown("#### Generation Settings")
    st.session_state.temperature = st.slider(
        "Temperature", 0.0, 2.0, float(st.session_state.temperature), 0.01
    )
    st.session_state.top_k = st.slider(
        "Top-k", 1, 200, int(st.session_state.top_k), 1
    )
    if st.button("🆕 New Conversation", use_container_width=True):
        st.session_state.messages = _fresh_history()
        _rerun()

# -----------------------------
# Model loading (with spinner)
# -----------------------------
with st.spinner("Loading model on CPU... (no quantization)"):
    model, tokenizer = load_model_and_tokenizer()
    st.session_state.loading_done = True

# -----------------------------
# Main: render the chat history
# -----------------------------
st.markdown("### nanochat web (Streamlit)")
st.caption("Enter を押すと送信、Shift+Enter で改行")

for msg in st.session_state.messages:
    if msg["role"] == "system":
        continue  # never display the system prompt
    with st.chat_message("assistant" if msg["role"] == "assistant" else "user"):
        st.write(msg["content"])


# -----------------------------
# Utilities
# -----------------------------
def handle_slash_command(text: str) -> bool:
    """Handle the /temperature, /topk, /clear and /help commands.

    Returns True when *text* was a recognized command (the caller should
    rerun the app), False otherwise.
    """
    parts = text.strip().split()
    cmd = parts[0].lower()

    def console_reply(content: str) -> None:
        # Surface command output as a regular assistant message.
        st.session_state.messages.append({"role": "assistant", "content": content})

    if cmd == "/temperature":
        if len(parts) == 1:
            console_reply(f"Current temperature: {st.session_state.temperature}")
        else:
            try:
                val = float(parts[1])
                if val < 0 or val > 2:
                    console_reply("Invalid temperature. Must be between 0.0 and 2.0")
                else:
                    st.session_state.temperature = val
                    console_reply(f"Temperature set to {val}")
            except ValueError:
                console_reply("Invalid value. Usage: /temperature <0.0-2.0>")
        return True

    if cmd == "/topk":
        if len(parts) == 1:
            console_reply(f"Current top-k: {st.session_state.top_k}")
        else:
            try:
                val = int(parts[1])
                if val < 1 or val > 200:
                    console_reply("Invalid top-k. Must be between 1 and 200")
                else:
                    st.session_state.top_k = val
                    console_reply(f"Top-k set to {val}")
            except ValueError:
                console_reply("Invalid value. Usage: /topk <1-200>")
        return True

    if cmd == "/clear":
        st.session_state.messages = _fresh_history()
        return True

    if cmd == "/help":
        # BUGFIX: the "set" lines previously duplicated the "show" lines
        # verbatim; include the <value> placeholder so the usage is clear.
        console_reply(
            "Available commands:\n"
            "/temperature - Show current temperature\n"
            "/temperature <value> - Set temperature (0.0-2.0)\n"
            "/topk - Show current top-k\n"
            "/topk <value> - Set top-k (1-200)\n"
            "/clear - Clear conversation\n"
            "/help - Show this help message"
        )
        return True

    return False


def build_chat_prompt(messages: List[Dict[str, str]]) -> Dict[str, "torch.Tensor"]:
    """Tokenize *messages* into model inputs.

    Uses the tokenizer's chat template when available; otherwise falls back
    to a simple <|role|>-tagged plain-text format.
    """
    if hasattr(tokenizer, "apply_chat_template"):
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return tokenizer(text, return_tensors="pt")

    # Plain-text fallback when no chat template is defined.
    prompt = ""
    for m in messages:
        if m["role"] == "system":
            prompt += f"<|system|>\n{m['content']}\n"
        elif m["role"] == "user":
            prompt += f"<|user|>\n{m['content']}\n"
        elif m["role"] == "assistant":
            prompt += f"<|assistant|>\n{m['content']}\n"
    prompt += "<|assistant|>\n"
    return tokenizer(prompt, return_tensors="pt")


def generate_stream(messages: List[Dict[str, str]]):
    """Yield decoded text chunks as the model generates them.

    Generation runs in a worker thread while the TextIteratorStreamer is
    consumed on the main (Streamlit) thread. Exceptions raised inside the
    worker are re-raised here instead of being silently lost.
    """
    streamer = TextIteratorStreamer(
        tokenizer, skip_special_tokens=True, skip_prompt=True
    )

    do_sample = st.session_state.temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=512,
        do_sample=do_sample,
        repetition_penalty=1.2,
        streamer=streamer,
    )
    if do_sample:
        # Only pass sampling knobs when actually sampling: temperature=0.0
        # with do_sample=True is invalid, and passing temperature/top_k with
        # do_sample=False triggers a transformers warning (greedy search
        # ignores them anyway, so the generated output is unchanged).
        gen_kwargs["temperature"] = float(st.session_state.temperature)
        gen_kwargs["top_k"] = int(st.session_state.top_k)

    inputs = build_chat_prompt(messages)

    worker_errors: List[BaseException] = []

    def _thread_target() -> None:
        try:
            model.generate(**inputs, **gen_kwargs)
        except BaseException as exc:  # re-raised on the main thread below
            worker_errors.append(exc)
            streamer.end()  # push the stop signal so the consumer loop exits

    # daemon=True so a hung generation cannot block interpreter shutdown.
    thread = threading.Thread(target=_thread_target, daemon=True)
    thread.start()
    for token in streamer:
        yield token
    thread.join()
    if worker_errors:
        raise worker_errors[0]


# -----------------------------
# Input box (submit -> generate)
# -----------------------------
user_input = st.chat_input("Ask anything", disabled=not st.session_state.loading_done)

if user_input:
    # Slash-command dispatch
    if user_input.strip().startswith("/"):
        if handle_slash_command(user_input.strip()):
            _rerun()
        else:
            st.session_state.messages.append(
                {"role": "assistant", "content": f"Unknown command: {user_input.strip()}"}
            )
            _rerun()
    else:
        # Append the user message to the history and display it.
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.write(user_input)

        # Stream the assistant reply into a single placeholder.
        with st.chat_message("assistant"):
            placeholder = st.empty()
            full = ""
            try:
                for chunk in generate_stream(st.session_state.messages):
                    full += chunk
                    placeholder.write(full)
            except Exception as e:
                st.error(f"Error: {e}")
                full = ""
            finally:
                # Record the reply even on error (empty string) so the
                # history stays consistent, then rerun to redraw.
                st.session_state.messages.append({"role": "assistant", "content": full})
                _rerun()