Files changed (1) hide show
  1. app.py +171 -149
app.py CHANGED
@@ -1,150 +1,172 @@
1
- """
2
- FastAPI Application for Multimodal RAG System
3
- US Army Medical Research Papers Q&A
4
- """
5
-
6
- import os
7
- import logging
8
- from typing import List, Dict, Optional, Union
9
- from contextlib import asynccontextmanager
10
-
11
- from fastapi import FastAPI, HTTPException
12
- from fastapi.middleware.cors import CORSMiddleware
13
- from fastapi.responses import FileResponse
14
- from fastapi.staticfiles import StaticFiles
15
- from pydantic import BaseModel, Field
16
-
17
- # Import from query_index (standalone)
18
- from query_index import MultimodalRAGSystem
19
-
20
- # Configure logging
21
- logging.basicConfig(
22
- level=logging.INFO,
23
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
24
- )
25
- logger = logging.getLogger(__name__)
26
-
27
- # Global variables
28
- rag_system: Optional[MultimodalRAGSystem] = None
29
-
30
- # Lifecycle management
31
- @asynccontextmanager
32
- async def lifespan(app: FastAPI):
33
- """Initialize and cleanup RAG system"""
34
- global rag_system
35
-
36
- logger.info("Starting RAG system initialization...")
37
- try:
38
- rag_system = MultimodalRAGSystem()
39
- logger.info("RAG system initialized successfully!")
40
- except Exception as e:
41
- logger.error(f"Error during initialization: {str(e)}")
42
- rag_system = None
43
-
44
- yield
45
-
46
- logger.info("Shutting down RAG system...")
47
- rag_system = None
48
-
49
- # Create FastAPI app
50
- app = FastAPI(
51
- title="Multimodal RAG API",
52
- description="Q&A system for US Army medical research papers (Text + Images)",
53
- version="2.0.0",
54
- lifespan=lifespan
55
- )
56
-
57
- # CORS middleware
58
- app.add_middleware(
59
- CORSMiddleware,
60
- allow_origins=["*"],
61
- allow_credentials=True,
62
- allow_methods=["*"],
63
- allow_headers=["*"],
64
- )
65
-
66
- # Mount static files
67
- app.mount("/static", StaticFiles(directory="static"), name="static")
68
-
69
- # Mount extracted images
70
- # This allows the frontend to load images via /extracted_images/filename.jpg
71
- if os.path.exists("extracted_images"):
72
- app.mount("/extracted_images", StaticFiles(directory="extracted_images"), name="images")
73
-
74
- # Mount PDF documents
75
- if os.path.exists("WHEC_Documents"):
76
- app.mount("/documents", StaticFiles(directory="WHEC_Documents"), name="documents")
77
-
78
- # Pydantic models
79
- class QueryRequest(BaseModel):
80
- question: str = Field(..., min_length=1, max_length=1000, description="Question to ask")
81
-
82
- class ImageSource(BaseModel):
83
- path: Optional[str]
84
- filename: Optional[str]
85
- score: Optional[float]
86
- page: Optional[Union[str, int]] # could be int or str depending on metadata
87
- file: Optional[str]
88
- link: Optional[str] = None
89
-
90
- class TextSource(BaseModel):
91
- text: str
92
- score: float
93
- page: Optional[Union[str, int]]
94
- file: Optional[str]
95
- link: Optional[str] = None
96
-
97
- class QueryResponse(BaseModel):
98
- answer: str
99
- images: List[ImageSource]
100
- texts: List[TextSource]
101
- question: str
102
-
103
- class HealthResponse(BaseModel):
104
- status: str
105
- rag_initialized: bool
106
-
107
- # API Endpoints
108
-
109
- @app.get("/", tags=["Root"])
110
- async def root():
111
- """Serve the frontend application"""
112
- return FileResponse('static/index.html')
113
-
114
- @app.get("/health", response_model=HealthResponse, tags=["Health"])
115
- async def health_check():
116
- """Health check endpoint"""
117
- return HealthResponse(
118
- status="healthy",
119
- rag_initialized=rag_system is not None
120
- )
121
-
122
- @app.post("/query", response_model=QueryResponse, tags=["Query"])
123
- async def query_rag(request: QueryRequest):
124
- """
125
- Query the RAG system
126
- """
127
- if not rag_system:
128
- raise HTTPException(
129
- status_code=503,
130
- detail="RAG system not initialized. Check logs for errors."
131
- )
132
-
133
- try:
134
- # Get answer
135
- result = rag_system.ask(request.question)
136
-
137
- return QueryResponse(
138
- answer=result['answer'],
139
- images=result['images'],
140
- texts=result['texts'],
141
- question=request.question
142
- )
143
-
144
- except Exception as e:
145
- logger.error(f"Error processing query: {str(e)}")
146
- raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
147
-
148
- if __name__ == "__main__":
149
- import uvicorn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ """
2
+ FastAPI Application for Multimodal RAG System
3
+ US Army Medical Research Papers Q&A
4
+ """
5
+
6
+ import os
7
+ import logging
8
+ from typing import List, Dict, Optional, Union
9
+ from contextlib import asynccontextmanager
10
+
11
+ from fastapi import FastAPI, HTTPException
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from fastapi.responses import FileResponse
14
+ from fastapi.staticfiles import StaticFiles
15
+ from pydantic import BaseModel, Field
16
+
17
+ # Import from query_index (standalone)
18
+ from query_index import MultimodalRAGSystem
19
+
20
+ # Configure logging
21
+ logging.basicConfig(
22
+ level=logging.INFO,
23
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # Global variables
28
+ rag_system: Optional[MultimodalRAGSystem] = None
29
+
30
+ # Store last question-answer pair for simple follow-up
31
+ last_qa_context: Optional[str] = None
32
+
33
+ # Lifecycle management
34
+ @asynccontextmanager
35
+ async def lifespan(app: FastAPI):
36
+ """Initialize and cleanup RAG system"""
37
+ global rag_system
38
+
39
+ logger.info("Starting RAG system initialization...")
40
+ try:
41
+ rag_system = MultimodalRAGSystem()
42
+ logger.info("RAG system initialized successfully!")
43
+ except Exception as e:
44
+ logger.error(f"Error during initialization: {str(e)}")
45
+ rag_system = None
46
+
47
+ yield
48
+
49
+ logger.info("Shutting down RAG system...")
50
+ rag_system = None
51
+
52
+ # Create FastAPI app
53
+ app = FastAPI(
54
+ title="Multimodal RAG API",
55
+ description="Q&A system for US Army medical research papers (Text + Images)",
56
+ version="2.0.0",
57
+ lifespan=lifespan
58
+ )
59
+
60
+ # CORS middleware
61
+ app.add_middleware(
62
+ CORSMiddleware,
63
+ allow_origins=["*"],
64
+ allow_credentials=True,
65
+ allow_methods=["*"],
66
+ allow_headers=["*"],
67
+ )
68
+
69
+ # Mount static files
70
+ app.mount("/static", StaticFiles(directory="static"), name="static")
71
+
72
+ # Mount extracted images
73
+ # This allows the frontend to load images via /extracted_images/filename.jpg
74
+ if os.path.exists("extracted_images"):
75
+ app.mount("/extracted_images", StaticFiles(directory="extracted_images"), name="images")
76
+
77
+ # Mount PDF documents
78
+ if os.path.exists("WHEC_Documents"):
79
+ app.mount("/documents", StaticFiles(directory="WHEC_Documents"), name="documents")
80
+
81
+ # Pydantic models
82
+ class QueryRequest(BaseModel):
83
+ question: str = Field(..., min_length=1, max_length=1000, description="Question to ask")
84
+
85
+ class ImageSource(BaseModel):
86
+ path: Optional[str]
87
+ filename: Optional[str]
88
+ score: Optional[float]
89
+ page: Optional[Union[str, int]] # could be int or str depending on metadata
90
+ file: Optional[str]
91
+ link: Optional[str] = None
92
+
93
+ class TextSource(BaseModel):
94
+ text: str
95
+ score: float
96
+ page: Optional[Union[str, int]]
97
+ file: Optional[str]
98
+ link: Optional[str] = None
99
+
100
+ class QueryResponse(BaseModel):
101
+ answer: str
102
+ images: List[ImageSource]
103
+ texts: List[TextSource]
104
+ question: str
105
+
106
+ class HealthResponse(BaseModel):
107
+ status: str
108
+ rag_initialized: bool
109
+
110
+ # API Endpoints
111
+
112
+ @app.get("/", tags=["Root"])
113
+ async def root():
114
+ """Serve the frontend application"""
115
+ return FileResponse('static/index.html')
116
+
117
+ @app.get("/health", response_model=HealthResponse, tags=["Health"])
118
+ async def health_check():
119
+ """Health check endpoint"""
120
+ return HealthResponse(
121
+ status="healthy",
122
+ rag_initialized=rag_system is not None
123
+ )
124
+
125
+ @app.post("/query", response_model=QueryResponse, tags=["Query"])
126
+ async def query_rag(request: QueryRequest):
127
+ """
128
+ Query the RAG system
129
+ """
130
+ global last_qa_context
131
+
132
+ if not rag_system:
133
+ raise HTTPException(
134
+ status_code=503,
135
+ detail="RAG system not initialized. Check logs for errors."
136
+ )
137
+
138
+ try:
139
+ # Build prompt using previous Q/A if available
140
+ if last_qa_context:
141
+ prompt = (
142
+ f"Previous question and answer:\n"
143
+ f"{last_qa_context}\n\n"
144
+ f"Follow up question:\n"
145
+ f"{request.question}"
146
+ )
147
+ else:
148
+ prompt = request.question
149
+
150
+ # Query RAG system
151
+ result = rag_system.ask(prompt)
152
+
153
+ # Save current Q/A as context for next turn
154
+ last_qa_context = (
155
+ f"Question: {request.question}\n"
156
+ f"Answer: {result['answer']}"
157
+ )
158
+
159
+ return QueryResponse(
160
+ answer=result['answer'],
161
+ images=result['images'],
162
+ texts=result['texts'],
163
+ question=request.question
164
+ )
165
+
166
+ except Exception as e:
167
+ logger.error(f"Error processing query: {str(e)}")
168
+ raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
169
+
170
+ if __name__ == "__main__":
171
+ import uvicorn
172
  uvicorn.run(app, host="0.0.0.0", port=7860)