# Spaces:
# Build error
# Build error
# (The three lines above are Hugging Face Spaces UI residue captured with
#  the source, not part of the program.)
# app.py
# -----------------------------------------
# NanoChat (ONNX) ported to a Streamlit chat UI
# Constraints:
#  - Device: CPU (ONNX Runtime CPUExecutionProvider)
#  - No quantization (the plain ONNX model is loaded as-is)
#
# Dependencies:
#   pip install streamlit transformers optimum[onnxruntime]
# Run:
#   streamlit run app.py
# -----------------------------------------
| import threading | |
| from typing import List, Dict | |
| import streamlit as st | |
| from transformers import AutoTokenizer, TextIteratorStreamer | |
| from optimum.onnxruntime import ORTModelForCausalLM | |
| MODEL_ID = "onnx-community/nanochat-d32-ONNX" | |
| # ----------------------------- | |
| # 初期化(モデルは一度だけロード) | |
| # ----------------------------- | |
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer():
    """Load the ONNX model and tokenizer once per server process.

    Uses the ONNX Runtime CPUExecutionProvider with no quantization.

    Fix: the surrounding comments promise "load the model only once", but
    without @st.cache_resource Streamlit re-executes this on every rerun
    (every widget interaction), reloading the full model from disk each
    time. Caching the returned (model, tokenizer) pair makes the one-time
    load actually one-time.

    Returns:
        tuple: (ORTModelForCausalLM, PreTrainedTokenizer)
    """
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        provider="CPUExecutionProvider",  # explicit CPU (no quantization)
        use_io_binding=False,  # standard I/O path (easier to debug)
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    return model, tokenizer
| # ----------------------------- | |
| # セッションステート初期化 | |
| # ----------------------------- | |
def init_session():
    """Seed st.session_state with default values on the first run.

    Keys already present (i.e. on reruns) are left untouched, so user
    changes to temperature/top_k and the chat history survive reruns.
    """
    defaults = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."}
        ],
        "temperature": 0.05,
        "top_k": 5,
        "loading_done": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


init_session()
| # ----------------------------- | |
| # レイアウト(サイドバー) | |
| # ----------------------------- | |
# -----------------------------
# Page config + sidebar (generation settings, new-conversation button)
# -----------------------------
st.set_page_config(page_title="NanoChat (ONNX, CPU)", page_icon="⚡", layout="centered")
with st.sidebar:
    st.markdown("## ⚡ nanochat (ONNX, CPU)")
    st.caption("デバイス: **CPU**, 量子化: **なし**")
    st.divider()
    st.markdown("#### Generation Settings")
    # Sliders write straight back into session state so the /temperature
    # and /topk slash commands and the sliders stay in sync.
    st.session_state.temperature = st.slider(
        "Temperature", 0.0, 2.0, float(st.session_state.temperature), 0.01
    )
    st.session_state.top_k = st.slider(
        "Top-k", 1, 200, int(st.session_state.top_k), 1
    )
    if st.button("🆕 New Conversation", use_container_width=True):
        st.session_state.messages = [
            {"role": "system", "content": "You are a helpful assistant."}
        ]
        # Fix: st.experimental_rerun() was removed in modern Streamlit
        # (deprecated since 1.27, removed later) and raises
        # AttributeError at runtime; st.rerun() is the supported API.
        st.rerun()
| # ----------------------------- | |
| # モデル読み込み(スピナー表示) | |
| # ----------------------------- | |
# -----------------------------
# Model load, wrapped in a spinner so the UI shows progress
# -----------------------------
with st.spinner("Loading model on CPU... (no quantization)"):
    model, tokenizer = load_model_and_tokenizer()
    # Flag used to enable the chat input box once the model is ready.
    st.session_state["loading_done"] = True
| # ----------------------------- | |
| # メイン: チャット履歴の描画 | |
| # ----------------------------- | |
# -----------------------------
# Main area: replay the chat history (system prompt stays hidden)
# -----------------------------
st.markdown("### nanochat web (Streamlit)")
st.caption("Enter を押すと送信、Shift+Enter で改行")
for entry in st.session_state.messages:
    role = entry["role"]
    if role == "system":
        # The system prompt is context for the model, not UI content.
        continue
    display_role = "assistant" if role == "assistant" else "user"
    with st.chat_message(display_role):
        st.write(entry["content"])
| # ----------------------------- | |
| # ユーティリティ | |
| # ----------------------------- | |
def handle_slash_command(text: str) -> bool:
    """Handle the /temperature, /topk, /clear and /help commands.

    Recognized commands append their reply to the chat history (as an
    assistant message) and return True; unknown commands return False so
    the caller can report them.
    """
    tokens = text.strip().split()
    command = tokens[0].lower()

    def console_reply(content: str):
        # Command output is shown as a regular assistant turn.
        st.session_state.messages.append({"role": "assistant", "content": content})

    if command == "/temperature":
        if len(tokens) < 2:
            console_reply(f"Current temperature: {st.session_state.temperature}")
            return True
        try:
            value = float(tokens[1])
        except ValueError:
            console_reply("Invalid value. Usage: /temperature <0.0-2.0>")
            return True
        if 0 <= value <= 2:
            st.session_state.temperature = value
            console_reply(f"Temperature set to {value}")
        else:
            console_reply("Invalid temperature. Must be between 0.0 and 2.0")
        return True

    if command == "/topk":
        if len(tokens) < 2:
            console_reply(f"Current top-k: {st.session_state.top_k}")
            return True
        try:
            value = int(tokens[1])
        except ValueError:
            console_reply("Invalid value. Usage: /topk <1-200>")
            return True
        if 1 <= value <= 200:
            st.session_state.top_k = value
            console_reply(f"Top-k set to {value}")
        else:
            console_reply("Invalid top-k. Must be between 1 and 200")
        return True

    if command == "/clear":
        st.session_state.messages = [
            {"role": "system", "content": "You are a helpful assistant."}
        ]
        return True

    if command == "/help":
        console_reply(
            "Available commands:\n"
            "/temperature - Show current temperature\n"
            "/temperature <value> - Set temperature (0.0-2.0)\n"
            "/topk - Show current top-k\n"
            "/topk <value> - Set top-k (1-200)\n"
            "/clear - Clear conversation\n"
            "/help - Show this help message"
        )
        return True

    return False
def build_chat_prompt(messages: List[Dict[str, str]]):
    """Tokenize the conversation into model inputs (PyTorch tensors).

    Uses the tokenizer's chat template when one is configured; otherwise
    falls back to a simple ``<|role|>`` plain-text layout ending with an
    assistant header so the model continues as the assistant.

    Fix: the original gated on ``hasattr(tokenizer, "apply_chat_template")``,
    which is True for every modern tokenizer even when no chat template is
    configured — in that case ``apply_chat_template`` raises and the
    fallback branch is dead code. Checking the ``chat_template`` attribute
    itself makes the fallback reachable.

    Returns:
        A BatchEncoding (dict-like) of input tensors ready for
        ``model.generate(**inputs, ...)``.
    """
    if getattr(tokenizer, "chat_template", None):
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return tokenizer(text, return_tensors="pt")

    # Simple fallback format for tokenizers without a chat template.
    prompt = ""
    for m in messages:
        if m["role"] == "system":
            prompt += f"<|system|>\n{m['content']}\n"
        elif m["role"] == "user":
            prompt += f"<|user|>\n{m['content']}\n"
        elif m["role"] == "assistant":
            prompt += f"<|assistant|>\n{m['content']}\n"
    prompt += "<|assistant|>\n"  # generation prompt: model speaks next
    return tokenizer(prompt, return_tensors="pt")
def generate_stream(messages: List[Dict[str, str]]):
    """Yield decoded text chunks as the model generates them.

    Generation runs in a worker thread while the TextIteratorStreamer is
    consumed here on the main thread.

    Fixes over the original:
    - An exception inside ``model.generate`` used to die silently in the
      worker thread, leaving the user with quietly truncated (or empty)
      output. Worker exceptions are now captured and re-raised to the
      caller once the stream drains.
    - ``temperature`` is only passed when sampling is enabled; recent
      transformers versions reject/warn on ``temperature=0.0``.
    """
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        skip_prompt=True
    )
    do_sample = st.session_state.temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=512,
        do_sample=do_sample,
        top_k=int(st.session_state.top_k),
        repetition_penalty=1.2,
        streamer=streamer,
    )
    if do_sample:
        gen_kwargs["temperature"] = float(st.session_state.temperature)

    inputs = build_chat_prompt(messages)

    errors = []  # transports at most one exception out of the worker

    def _worker():
        try:
            model.generate(**inputs, **gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Signal end-of-stream so the consumer loop below unblocks.
            streamer.end()

    thread = threading.Thread(target=_worker, daemon=True)
    thread.start()
    for chunk in streamer:
        yield chunk
    thread.join()
    if errors:
        raise errors[0]
| # ----------------------------- | |
| # 入力欄(送信 → 生成) | |
| # ----------------------------- | |
# -----------------------------
# Input box (submit -> slash command or generation)
# -----------------------------
user_input = st.chat_input("Ask anything", disabled=not st.session_state.loading_done)
if user_input:
    stripped = user_input.strip()
    if stripped.startswith("/"):
        # Slash commands mutate session state and rerun to redraw history.
        if not handle_slash_command(stripped):
            st.session_state.messages.append(
                {"role": "assistant", "content": f"Unknown command: {stripped}"}
            )
        # Fix: st.experimental_rerun() was removed in modern Streamlit and
        # raises AttributeError; st.rerun() is the supported replacement.
        st.rerun()
    else:
        # Record and display the user turn.
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.write(user_input)
        # Stream the assistant turn into a placeholder as tokens arrive.
        with st.chat_message("assistant"):
            placeholder = st.empty()
            full = ""
            try:
                for chunk in generate_stream(st.session_state.messages):
                    full += chunk
                    placeholder.write(full)
            except Exception as e:
                st.error(f"Error: {e}")
                full = ""
            finally:
                # Record the turn even on error (empty string) so the
                # history stays role-alternating.
                st.session_state.messages.append({"role": "assistant", "content": full})
            st.rerun()