Spaces:
Running
Running
feat: Add Hugging Face model support
Browse files- .env.example +2 -0
- app/api/chat/route.ts +12 -4
- components/chat.tsx +12 -2
- components/message.tsx +1 -1
- components/model-picker.tsx +22 -13
- lib/hf-client.ts +6 -0
- lib/models.ts +22 -3
.env.example
CHANGED
|
@@ -9,3 +9,5 @@ OPENAI_BASE_URL="" # optional — leave blank to use api.openai.com
|
|
| 9 |
# When pointing to Azure: OPENAI_BASE_URL=https://<resource>.openai.azure.com/openai/deployments/<deployment-name>
|
| 10 |
# Optional extra headers as JSON, eg: {"api-key":"abc","organization":"org_xyz"}
|
| 11 |
OPENAI_EXTRA_HEADERS=""
|
|
|
|
|
|
|
|
|
| 9 |
# When pointing to Azure: OPENAI_BASE_URL=https://<resource>.openai.azure.com/openai/deployments/<deployment-name>
|
| 10 |
# Optional extra headers as JSON, eg: {"api-key":"abc","organization":"org_xyz"}
|
| 11 |
OPENAI_EXTRA_HEADERS=""
|
| 12 |
+
|
| 13 |
+
HF_TOKEN=
|
app/api/chat/route.ts
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import { openai } from "@/lib/openai-client";
|
| 2 |
-
import
|
|
|
|
| 3 |
import { saveChat } from "@/lib/chat-store";
|
| 4 |
import { nanoid } from "nanoid";
|
| 5 |
import { db } from "@/lib/db";
|
|
@@ -105,13 +106,20 @@ export async function POST(req: Request) {
|
|
| 105 |
|
| 106 |
const { tools, cleanup } = await initializeMCPClients(mcpServers, req.signal);
|
| 107 |
|
| 108 |
-
const
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
{
|
| 110 |
model: selectedModel,
|
| 111 |
stream: true,
|
| 112 |
messages,
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
| 115 |
},
|
| 116 |
{ signal: req.signal }
|
| 117 |
);
|
|
|
|
| 1 |
import { openai } from "@/lib/openai-client";
|
| 2 |
+
import { hf } from "@/lib/hf-client";
|
| 3 |
+
import { getModels, type ModelID } from "@/lib/models";
|
| 4 |
import { saveChat } from "@/lib/chat-store";
|
| 5 |
import { nanoid } from "nanoid";
|
| 6 |
import { db } from "@/lib/db";
|
|
|
|
| 106 |
|
| 107 |
const { tools, cleanup } = await initializeMCPClients(mcpServers, req.signal);
|
| 108 |
|
| 109 |
+
const hfModels = await getModels();
|
| 110 |
+
const client = hfModels.includes(selectedModel) ? hf : openai;
|
| 111 |
+
|
| 112 |
+
const openAITools = mcpToolsToOpenAITools(tools);
|
| 113 |
+
|
| 114 |
+
const completion = await client.chat.completions.create(
|
| 115 |
{
|
| 116 |
model: selectedModel,
|
| 117 |
stream: true,
|
| 118 |
messages,
|
| 119 |
+
...(openAITools.length > 0 && {
|
| 120 |
+
tools: openAITools,
|
| 121 |
+
tool_choice: "auto",
|
| 122 |
+
}),
|
| 123 |
},
|
| 124 |
{ signal: req.signal }
|
| 125 |
);
|
components/chat.tsx
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"use client";
|
| 2 |
|
| 3 |
-
import {
|
| 4 |
import { Message, useChat } from "@ai-sdk/react";
|
| 5 |
import { useState, useEffect, useMemo, useCallback } from "react";
|
| 6 |
import { Textarea } from "./textarea";
|
|
@@ -30,7 +30,7 @@ export default function Chat() {
|
|
| 30 |
const chatId = params?.id as string | undefined;
|
| 31 |
const queryClient = useQueryClient();
|
| 32 |
|
| 33 |
-
const [selectedModel, setSelectedModel] = useLocalStorage<ModelID>("selectedModel",
|
| 34 |
const [userId, setUserId] = useState<string>('');
|
| 35 |
const [generatedChatId, setGeneratedChatId] = useState<string>('');
|
| 36 |
|
|
@@ -41,6 +41,16 @@ export default function Chat() {
|
|
| 41 |
useEffect(() => {
|
| 42 |
setUserId(getUserId());
|
| 43 |
}, []);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
// Generate a chat ID if needed
|
| 46 |
useEffect(() => {
|
|
|
|
| 1 |
"use client";
|
| 2 |
|
| 3 |
+
import { getDefaultModel, type ModelID } from "@/lib/models";
|
| 4 |
import { Message, useChat } from "@ai-sdk/react";
|
| 5 |
import { useState, useEffect, useMemo, useCallback } from "react";
|
| 6 |
import { Textarea } from "./textarea";
|
|
|
|
| 30 |
const chatId = params?.id as string | undefined;
|
| 31 |
const queryClient = useQueryClient();
|
| 32 |
|
| 33 |
+
const [selectedModel, setSelectedModel] = useLocalStorage<ModelID>("selectedModel", "");
|
| 34 |
const [userId, setUserId] = useState<string>('');
|
| 35 |
const [generatedChatId, setGeneratedChatId] = useState<string>('');
|
| 36 |
|
|
|
|
| 41 |
useEffect(() => {
|
| 42 |
setUserId(getUserId());
|
| 43 |
}, []);
|
| 44 |
+
|
| 45 |
+
useEffect(() => {
|
| 46 |
+
const fetchDefaultModel = async () => {
|
| 47 |
+
const defaultModel = await getDefaultModel();
|
| 48 |
+
if (!selectedModel) {
|
| 49 |
+
setSelectedModel(defaultModel);
|
| 50 |
+
}
|
| 51 |
+
};
|
| 52 |
+
fetchDefaultModel();
|
| 53 |
+
}, [selectedModel, setSelectedModel]);
|
| 54 |
|
| 55 |
// Generate a chat ID if needed
|
| 56 |
useEffect(() => {
|
components/message.tsx
CHANGED
|
@@ -163,7 +163,7 @@ const PurePreviewMessage = ({
|
|
| 163 |
>
|
| 164 |
<div
|
| 165 |
className={cn("flex flex-col gap-3 w-full", {
|
| 166 |
-
"bg-secondary text-secondary-foreground px-4 py-
|
| 167 |
message.role === "user",
|
| 168 |
})}
|
| 169 |
>
|
|
|
|
| 163 |
>
|
| 164 |
<div
|
| 165 |
className={cn("flex flex-col gap-3 w-full", {
|
| 166 |
+
"bg-secondary text-secondary-foreground px-4 py-1.5 rounded-2xl":
|
| 167 |
message.role === "user",
|
| 168 |
})}
|
| 169 |
>
|
components/model-picker.tsx
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"use client";
|
| 2 |
-
import {
|
| 3 |
import {
|
| 4 |
Select,
|
| 5 |
SelectContent,
|
|
@@ -10,7 +10,7 @@ import {
|
|
| 10 |
} from "./ui/select";
|
| 11 |
import { cn } from "@/lib/utils";
|
| 12 |
import { Bot } from "lucide-react";
|
| 13 |
-
import { useEffect } from "react";
|
| 14 |
|
| 15 |
interface ModelPickerProps {
|
| 16 |
selectedModel: ModelID;
|
|
@@ -18,19 +18,28 @@ interface ModelPickerProps {
|
|
| 18 |
}
|
| 19 |
|
| 20 |
export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProps) => {
|
| 21 |
-
|
| 22 |
-
const validModelId =
|
| 23 |
-
|
| 24 |
-
// If the selected model is invalid, update it to the default
|
| 25 |
useEffect(() => {
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
// Handle model change
|
| 32 |
const handleModelChange = (modelId: string) => {
|
| 33 |
-
if (
|
| 34 |
setSelectedModel(modelId as ModelID);
|
| 35 |
}
|
| 36 |
};
|
|
@@ -43,7 +52,7 @@ export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProp
|
|
| 43 |
defaultValue={validModelId}
|
| 44 |
>
|
| 45 |
<SelectTrigger
|
| 46 |
-
className="max-w-[200px] sm:max-w-fit sm:w-
|
| 47 |
>
|
| 48 |
<SelectValue
|
| 49 |
placeholder="Select model"
|
|
@@ -60,7 +69,7 @@ export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProp
|
|
| 60 |
className="bg-background/95 dark:bg-muted/95 backdrop-blur-sm border-border/80 rounded-lg overflow-hidden p-0 w-[280px]"
|
| 61 |
>
|
| 62 |
<SelectGroup className="space-y-1 p-1">
|
| 63 |
-
{
|
| 64 |
<SelectItem
|
| 65 |
key={id}
|
| 66 |
value={id}
|
|
|
|
| 1 |
"use client";
|
| 2 |
+
import { getModels, getDefaultModel, ModelID } from "@/lib/models";
|
| 3 |
import {
|
| 4 |
Select,
|
| 5 |
SelectContent,
|
|
|
|
| 10 |
} from "./ui/select";
|
| 11 |
import { cn } from "@/lib/utils";
|
| 12 |
import { Bot } from "lucide-react";
|
| 13 |
+
import { useEffect, useState } from "react";
|
| 14 |
|
| 15 |
interface ModelPickerProps {
|
| 16 |
selectedModel: ModelID;
|
|
|
|
| 18 |
}
|
| 19 |
|
| 20 |
export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProps) => {
|
| 21 |
+
const [models, setModels] = useState<ModelID[]>([]);
|
| 22 |
+
const [validModelId, setValidModelId] = useState<ModelID>("");
|
| 23 |
+
|
|
|
|
| 24 |
useEffect(() => {
|
| 25 |
+
const fetchModels = async () => {
|
| 26 |
+
const availableModels = await getModels();
|
| 27 |
+
setModels(availableModels);
|
| 28 |
+
const defaultModel = await getDefaultModel();
|
| 29 |
+
const currentModel = selectedModel || defaultModel;
|
| 30 |
+
const isValid = availableModels.includes(currentModel);
|
| 31 |
+
const newValidModelId = isValid ? currentModel : defaultModel;
|
| 32 |
+
setValidModelId(newValidModelId);
|
| 33 |
+
if (selectedModel !== newValidModelId) {
|
| 34 |
+
setSelectedModel(newValidModelId);
|
| 35 |
+
}
|
| 36 |
+
};
|
| 37 |
+
fetchModels();
|
| 38 |
+
}, [selectedModel, setSelectedModel]);
|
| 39 |
|
| 40 |
// Handle model change
|
| 41 |
const handleModelChange = (modelId: string) => {
|
| 42 |
+
if (models.includes(modelId as ModelID)) {
|
| 43 |
setSelectedModel(modelId as ModelID);
|
| 44 |
}
|
| 45 |
};
|
|
|
|
| 52 |
defaultValue={validModelId}
|
| 53 |
>
|
| 54 |
<SelectTrigger
|
| 55 |
+
className="max-w-[200px] sm:max-w-fit sm:w-80 px-2 sm:px-3 h-8 sm:h-9 rounded-full group border-primary/20 bg-primary/5 hover:bg-primary/10 dark:bg-primary/10 dark:hover:bg-primary/20 transition-all duration-200 ring-offset-background focus:ring-2 focus:ring-primary/30 focus:ring-offset-2"
|
| 56 |
>
|
| 57 |
<SelectValue
|
| 58 |
placeholder="Select model"
|
|
|
|
| 69 |
className="bg-background/95 dark:bg-muted/95 backdrop-blur-sm border-border/80 rounded-lg overflow-hidden p-0 w-[280px]"
|
| 70 |
>
|
| 71 |
<SelectGroup className="space-y-1 p-1">
|
| 72 |
+
{models.map((id) => (
|
| 73 |
<SelectItem
|
| 74 |
key={id}
|
| 75 |
value={id}
|
lib/hf-client.ts
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// OpenAI-compatible client pointed at the Hugging Face inference router.
import OpenAI from "openai";

// Shared Hugging Face client. The router exposes the OpenAI chat-completions
// wire format, so the stock OpenAI SDK works with only a base URL override.
// Auth comes from the HF_TOKEN env var (declared in .env.example); if it is
// unset the client is constructed anyway and requests will fail at call time.
export const hf = new OpenAI({
  apiKey: process.env.HF_TOKEN,
  baseURL: "https://router.huggingface.co/v1",
});
|
lib/models.ts
CHANGED
|
@@ -2,8 +2,27 @@
|
|
| 2 |
* List here only the model IDs your endpoint exposes.
|
| 3 |
 * Add/remove freely — nothing else in the codebase cares.
|
| 4 |
*/
|
| 5 |
-
export
|
| 6 |
|
| 7 |
-
|
| 8 |
|
| 9 |
-
export
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
* List here only the model IDs your endpoint exposes.
|
| 3 |
 * Add/remove freely — nothing else in the codebase cares.
|
| 4 |
*/
|
| 5 |
+
export type ModelID = string;
|
| 6 |
|
| 7 |
+
let modelsCache: string[] | null = null;
|
| 8 |
|
| 9 |
+
export async function getModels(): Promise<string[]> {
|
| 10 |
+
if (modelsCache) {
|
| 11 |
+
return modelsCache;
|
| 12 |
+
}
|
| 13 |
+
try {
|
| 14 |
+
const response = await fetch("https://router.huggingface.co/v1/models");
|
| 15 |
+
const data = await response.json();
|
| 16 |
+
const modelIds = data.data.slice(0, 5).map((model: any) => model.id);
|
| 17 |
+
modelsCache = modelIds;
|
| 18 |
+
return modelIds;
|
| 19 |
+
} catch (e) {
|
| 20 |
+
console.error(e);
|
| 21 |
+
return [];
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
export async function getDefaultModel(): Promise<ModelID> {
|
| 26 |
+
const models = await getModels();
|
| 27 |
+
return models[0] ?? "";
|
| 28 |
+
}
|