import os
import re

import faiss
import numpy as np
from datasets import load_dataset
from langchain_text_splitters import RecursiveCharacterTextSplitter
from openai import OpenAI
from sentence_transformers import SentenceTransformer

# ===============================
# Load Embedder
# ===============================
def load_embedder():
    """Return the sentence-embedding model used for both chunks and queries."""
    return SentenceTransformer("all-MiniLM-L6-v2")


embed_model = load_embedder()


# ===============================
# Clean Text
# ===============================
def clean_text(text, min_len=50):
    """Normalize raw dataset text for indexing.

    Collapses whitespace, strips URLs, HTML tags, ellipsis runs, and any
    character outside a small allowed set. Returns "" when the cleaned
    text is shorter than *min_len* characters, so trivial fragments are
    dropped by callers that test truthiness.
    """
    text = re.sub(r'\s+', ' ', text).strip()
    text = re.sub(r'http\S+|www\.\S+', '', text)
    text = re.sub(r'<.*?>', '', text)
    text = re.sub(r'\.{2,}', '.', text)
    text = re.sub(r'[^\w\s\.\,\-\/\(\)\%]', '', text)
    return text if len(text) >= min_len else ""


# ===============================
# Load & Prepare Dataset
# ===============================
def load_data():
    """Download PubMedQA (labeled config) and flatten each record to one string.

    Each record becomes "Question: ... Context: ... Answer: ..."; records
    whose cleaned text is empty are skipped.
    """
    dataset = load_dataset("qiaojin/PubMedQA", "pqa_labeled")
    texts = []
    for x in dataset["train"]:
        # "contexts" may be absent from the record — fall back to an empty
        # context rather than raising KeyError.
        context_text = " ".join(x["context"]["contexts"]) if "contexts" in x["context"] else ""
        text = f"Question: {x['question']} Context: {context_text} Answer: {x['final_decision']}"
        cleaned = clean_text(text)
        if cleaned:
            texts.append(cleaned)
    return texts


texts = load_data()


# ===============================
# Chunking
# ===============================
def chunk_texts(texts):
    """Split documents into overlapping ~500-char chunks.

    Returns (all_chunks, sources) where sources[i] is the index of the
    source document that chunk i came from.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        # FIX: the sentence separator was corrupted by a stray newline
        # embedded in the string literal (". \n"); it should be ". ".
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    all_chunks = []
    sources = []
    for i, text in enumerate(texts):
        for chunk in splitter.split_text(text):
            if len(chunk) > 50:  # drop fragments too short to be useful
                all_chunks.append(chunk)
                sources.append(i)
    return all_chunks, sources


all_chunks, chunk_sources = chunk_texts(texts)


# ===============================
# Embeddings + FAISS
# ===============================
def build_faiss(chunks):
    """Embed *chunks* and index them in an exact (brute-force) L2 FAISS index."""
    embeddings = embed_model.encode(chunks, show_progress_bar=False)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    # FAISS requires float32 input; asarray avoids the extra copy that
    # np.array(...).astype(...) made.
    index.add(np.asarray(embeddings, dtype="float32"))
    return index


index = build_faiss(all_chunks)


# ===============================
# Retriever
# ===============================
def retrieve(query, k=5):
    """Return up to *k* chunks nearest to *query*.

    Each result dict carries the chunk text, its L2 distance, a
    1/(1+distance) relevance score, its rank, and the source document id.
    """
    q_embed = embed_model.encode([query])
    distances, indices = index.search(np.asarray(q_embed, dtype="float32"), k)
    results = []
    for i, idx in enumerate(indices[0]):
        # FIX: FAISS pads indices with -1 when fewer than k vectors exist;
        # previously this wrongly indexed all_chunks[-1] (the last chunk).
        if idx < 0:
            continue
        results.append({
            "text": all_chunks[idx],
            "distance": float(distances[0][i]),
            "relevance": float(1 / (1 + distances[0][i])),
            "rank": i + 1,
            "source_doc": chunk_sources[idx],
        })
    return results


# ===============================
# LLM Client
# ===============================
# NOTE(review): HF_TOKEN is None when the env var is unset — the client is
# still constructed, but API calls will fail with an auth error.
HF_TOKEN = os.getenv("HF_TOKEN")
client = OpenAI(
    base_url="https://router.huggingface.co/v1",
    api_key=HF_TOKEN,
)


# ===============================
# Answer Generator
# ===============================
def generate_answer(context, question):
    """Ask the LLM to answer *question* grounded ONLY in *context*.

    The context is truncated to 2500 characters to bound prompt size.
    Returns the model's answer (with a disclaimer appended unless the
    model declined), or an error string on failure.
    """
    prompt = f"""
You are a medical research assistant. Answer the question using ONLY the research context provided.

INSTRUCTIONS:
1. Use ONLY the context
2. Do NOT hallucinate
3. If insufficient info, say: "I don't have sufficient research information to answer this question. Please consult a healthcare professional."

Research Context:
{context[:2500]}

Question: {question}

Answer in ONE clear paragraph:
"""
    try:
        completion = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=300,
            temperature=0.3,
        )
        answer = completion.choices[0].message.content.strip()
        # Append the disclaimer unless the model already declined to answer.
        if "don't have sufficient" not in answer.lower():
            answer += "\n\n⚠️ Research-based information only. Consult healthcare professionals."
        return answer
    except Exception as e:
        return f"❌ LLM Error: {e}"