Harsh123007 committed on
Commit
8a093ba
·
verified ·
1 Parent(s): de18933

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +51 -0
main.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.openapi.utils import get_openapi
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# FastAPI application exposing a small chat endpoint backed by a local LLM.
app = FastAPI(
    title="Harshal AI Backend",
    version="1.0.0",
)

# Hugging Face model id for a ~0.5B-parameter causal LM (small enough for CPU).
MODEL_NAME = "Qwen/Qwen2.5-0.5B"

# Loaded once at import time; the first run downloads the weights from the Hub.
# NOTE(review): module-level load blocks app startup until the model is ready.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # full precision; safe default for CPU inference
    device_map="cpu",           # force CPU placement (no GPU assumed)
)
class ChatMessage(BaseModel):
    """Request body for POST /chat.

    Holds the conversation history; the /chat handler reads only the
    last entry's "content" field.
    """

    # Untyped list: items are indexed as message[-1]["content"] by chat(),
    # so entries are presumably {"role": ..., "content": ...} dicts — TODO confirm.
    messages: list
@app.get("/")
def home():
    """Health-check route confirming the backend is up."""
    payload = {"message": "Harshal AI backend running with Qwen 0.5B!"}
    return payload
@app.post("/chat")
def chat(body: ChatMessage):
    """Generate a reply to the most recent user message.

    Wraps the last message's "content" in a simple "User: ...\nAssistant:"
    prompt, runs the causal LM, and returns the text generated after the
    final "Assistant:" marker.

    Args:
        body: ChatMessage whose `messages` list holds dicts with a
            "content" key; only the last entry is used.

    Returns:
        dict: {"reply": <generated text>}; {"reply": ""} when `messages`
        is empty.
    """
    # Guard: an empty conversation would raise IndexError on messages[-1].
    if not body.messages:
        return {"reply": ""}

    user_msg = body.messages[-1]["content"]
    prompt = f"User: {user_msg}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt")
    # inference_mode: skip autograd bookkeeping for pure inference.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=120,
            do_sample=True,  # required — `temperature` is ignored under greedy decoding
            temperature=0.4,
            pad_token_id=tokenizer.eos_token_id,
        )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text includes the prompt; keep only what follows the
    # final "Assistant:" marker.
    reply = text.split("Assistant:")[-1].strip()
    return {"reply": reply}
@app.get("/openapi.json")
def openapi_json():
    """Serve the OpenAPI schema explicitly, mirroring the app metadata."""
    schema = get_openapi(
        title="Harshal AI Backend",
        version="1.0.0",
        routes=app.routes,
    )
    return schema