# main.py — FastAPI backend serving Qwen2.5-0.5B
# (originally uploaded to a Hugging Face Space, commit 8a093ba; the web-UI
# residue lines "raw / history / blame / 1.25 kB" were not part of the code)
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.openapi.utils import get_openapi
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# FastAPI application instance; title/version feed the generated OpenAPI schema.
app = FastAPI(
    title="Harshal AI Backend",
    version="1.0.0",
)

# Hugging Face model id, resolved (and downloaded on first run) at import time.
MODEL_NAME = "Qwen/Qwen2.5-0.5B"

# Tokenizer and model are loaded once at module import and shared by all
# requests — NOTE(review): module import therefore blocks until the weights
# are available locally.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # full precision; float16 gives no benefit on CPU
    device_map="cpu",           # pin to CPU explicitly (no GPU assumed)
)
class ChatMessage(BaseModel):
    """Request body for POST /chat.

    ``messages`` is a chat history; only the last entry is consumed, and it
    must carry a ``"content"`` key (see the /chat handler).  Entries
    presumably also carry a ``"role"`` key, OpenAI-style — TODO confirm
    with callers.  The bare ``list`` annotation means pydantic performs no
    per-item validation.
    """

    # Expected shape: [{"role": str, "content": str}, ...] — not enforced here.
    messages: list
@app.get("/")
def home():
    """Root health-check endpoint: confirms the service is up and serving."""
    status_message = "Harshal AI backend running with Qwen 0.5B!"
    return {"message": status_message}
@app.post("/chat")
def chat(body: ChatMessage):
    """Generate one assistant reply for the latest message in the history.

    Only the last element of ``body.messages`` is used; it must contain a
    ``"content"`` key.  Returns ``{"reply": <generated text>}``, or
    ``{"reply": "", "error": ...}`` when the history is empty (previously
    this crashed with an IndexError → HTTP 500).
    """
    if not body.messages:
        # Guard: body.messages[-1] below would raise IndexError.
        return {"reply": "", "error": "messages list is empty"}

    user_msg = body.messages[-1]["content"]
    prompt = f"User: {user_msg}\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt")

    # Inference only: no_grad skips autograd bookkeeping during generation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=120,
            # temperature is ignored (with a warning) under the default greedy
            # decoding; sampling must be enabled for it to take effect.
            do_sample=True,
            temperature=0.4,
            pad_token_id=tokenizer.eos_token_id,
        )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only the model's continuation
    # after the final "Assistant:" marker.
    reply = text.split("Assistant:")[-1].strip()
    return {"reply": reply}
@app.get("/openapi.json")
def openapi_json():
    """Serve the OpenAPI schema built from the app's registered routes.

    NOTE(review): FastAPI already serves /openapi.json itself; this explicit
    route rebuilds the schema on every request with the same title/version.
    """
    schema_args = {
        "title": "Harshal AI Backend",
        "version": "1.0.0",
        "routes": app.routes,
    }
    return get_openapi(**schema_args)