New changes
- .gitignore +1 -0
- app.py +14 -9
- requirements.txt +2 -1
.gitignore
ADDED
@@ -0,0 +1 @@
+.env
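Note: ignoring .env keeps the Groq credential out of the repository. The file itself would hold something like the following (hypothetical value, shown only for illustration; it is never committed):

# .env (example only, not part of this commit)
GROQ_API_KEY=your-groq-key-here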
app.py
CHANGED
@@ -22,6 +22,11 @@ import traceback
 import shutil
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from tqdm import tqdm
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
 embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
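The new block reads the key from .env (or the Space's environment) at startup. A slightly more defensive variant, sketched here as an assumption rather than part of the commit, would fail fast when the variable is missing:

import os
import sys
from dotenv import load_dotenv

load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    # stop early with a clear message instead of failing later inside the Groq client
    sys.exit("GROQ_API_KEY is not set; add it to .env or the Space secrets.")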
@@ -29,13 +34,11 @@ def build_index_and_dataset(domain, subsets, embedder_type="sentence-transformer
     dataset_path = f"{domain}_dataset"
     index_path = f"{domain}_index/faiss.index"
 
-
-
-
-    if os.path.exists(index_path):
-        os.remove(index_path)
+    if os.path.exists(dataset_path) and os.path.exists(index_path):
+        print(f"✅ Using cached dataset and index for domain: {domain}")
+        return Dataset.load_from_disk(dataset_path), faiss.read_index(index_path)
 
-    print(f"🔄 Building dataset and index for domain: {domain}")
+    print(f"🔄 Building dataset and index for domain: {domain}")
 
     all_docs = []
     for subset in subsets:
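The early return only pays off if the build path persists both artifacts. Presumably the rest of build_index_and_dataset ends with something like the following sketch (the names dataset and index are assumptions; the save step is not shown in the hunk):

os.makedirs(f"{domain}_index", exist_ok=True)   # faiss.write_index does not create directories
dataset.save_to_disk(dataset_path)              # pairs with Dataset.load_from_disk above
faiss.write_index(index, index_path)            # pairs with faiss.read_index above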
@@ -107,6 +110,8 @@ gk_dataset = load_dataset("rungalileo/ragbench", "hotpotqa", split="test")
 cs_dataset = load_dataset("rungalileo/ragbench", "emanual", split="test")
 fin_dataset = load_dataset("rungalileo/ragbench", "finqa", split="test")
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 # Load BGE reranker
 reranker = CrossEncoder("BAAI/bge-reranker-base", max_length=512)
 
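The new device variable is not visibly consumed in the hunks shown. If the intent is GPU inference, sentence-transformers accepts a device argument on both models; a sketch of how it could be wired in (an assumption, not part of the commit, and it presumes torch is already imported in app.py):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=str(device))
reranker = CrossEncoder("BAAI/bge-reranker-base", max_length=512, device=str(device))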
@@ -173,7 +178,7 @@ def retrieve_top_c(query, domain, embedder, k=5):
 
 
 client = Groq(
-    api_key= '
+    api_key= 'GROQ_API_KEY',
 )
 
 
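As written, the added line passes the literal string 'GROQ_API_KEY' to the client rather than the value loaded with os.getenv earlier in the file. The intended call is presumably the unquoted variable:

client = Groq(
    api_key=GROQ_API_KEY,   # the variable set via load_dotenv()/os.getenv, not the quoted name
)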
@@ -584,7 +589,7 @@ def evaluate_rag_pipeline(domain, q_indices):
         result["AUC-ROC (Adherence)"] = round(roc_auc_score(gt_adherence, pred_adherence), 4)
     else:
         result["Adherence"] = compute_rmse(gt_adherence, pred_adherence)
-        result["AUC-ROC (Adherence)"] = "N/A - one class only"
+        #result["AUC-ROC (Adherence)"] = "N/A - one class only"
 
     return result
 
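Context for the commented-out line: scikit-learn's roc_auc_score raises an error when the ground truth contains only one class, so the else branch exists to cover that case. The surrounding guard presumably looks roughly like this (the condition is an assumption; only the branch bodies appear in the hunk):

if len(set(gt_adherence)) > 1:
    result["AUC-ROC (Adherence)"] = round(roc_auc_score(gt_adherence, pred_adherence), 4)
else:
    result["Adherence"] = compute_rmse(gt_adherence, pred_adherence)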
@@ -627,4 +632,4 @@ iface = gr.Interface(
 )
 
 # Launch app
-iface.launch(server_name="0.0.0.0", server_port=7860, debug=True)
+iface.launch(server_name="0.0.0.0", server_port=7860, debug=True)
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ datasets
 scikit-learn
 groq
 langchain
-tqdm
+tqdm
+python-dotenv