import os
import json
import logging
import requests
import xmltodict
import time
import streamlit as st
from openai import OpenAI
from typing import List, Dict
from io import StringIO

# Configure logging for progress tracking and debugging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize an OpenAI-compatible client for the AIML API (used to call the DeepSeek model)
client = OpenAI(
    base_url="https://api.aimlapi.com/v1",
    api_key="api-key",  # Replace with your AIML API key
)

# Define constants for the PubMed E-utilities API
BASE_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/"
SEARCH_URL = f"{BASE_URL}esearch.fcgi"
FETCH_URL = f"{BASE_URL}efetch.fcgi"
class KnowledgeBaseLoader:
    """
    Loads schizophrenia research documents from a JSON file.
    """

    def __init__(self, filepath: str):
        self.filepath = filepath

    def load_data(self) -> List[Dict]:
        """Loads and returns data from the JSON file."""
        try:
            with open(self.filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            logger.info(f"Successfully loaded {len(data)} records from '{self.filepath}'.")
            return data
        except Exception as e:
            logger.error(f"Error loading knowledge base: {e}")
            return []
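
# A minimal sketch of the JSON shape the loader expects, inferred from the fields the agent
# reads ("title", "abstract") and from what fetch_pubmed_papers writes below; the exact keys
# of your own parsed_data.json may differ.
#
# [
#     {
#         "pmid": "12345678",
#         "title": "Example schizophrenia study",
#         "abstract": "Abstract text (may also be a list of strings).",
#         "journal": "Example Journal",
#         "year": "2021"
#     }
# ]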
class SchizophreniaAgent:
    """
    An agent that answers schizophrenia-related questions using a domain-specific knowledge base.
    """

    def __init__(self, knowledge_base: List[Dict]):
        self.knowledge_base = knowledge_base

    def process_query(self, query: str) -> str:
        """
        Process the incoming query by searching for matching documents in the knowledge base.

        Args:
            query: A string containing the user's query.

        Returns:
            A response string summarizing how many documents matched, with some sample titles.
        """
        if not self.knowledge_base:
            logger.warning("Knowledge base is empty. Cannot process query.")
            return "No knowledge base available."

        # Simple matching: collect documents whose abstract contains the query text
        matching_docs = []
        for doc in self.knowledge_base:
            abstract = doc.get("abstract", "")
            # Abstracts parsed from PubMed XML may be a string, a list of strings, or a dict;
            # normalize everything to a single plain string before matching.
            if isinstance(abstract, list):
                abstract = " ".join(str(item) for item in abstract if isinstance(item, str)).strip()
            elif not isinstance(abstract, str):
                abstract = str(abstract)
            if query.lower() in abstract.lower():
                matching_docs.append(doc)

        logger.info(f"Query '{query}' matched {len(matching_docs)} documents.")

        # For a more robust agent, integrate with an LLM or retrieval system here.
        if len(matching_docs) > 0:
            response = (
                f"Found {len(matching_docs)} documents matching your query. "
                "Examples: " +
                ", ".join(f"'{doc.get('title', 'No Title')}'" for doc in matching_docs[:3]) +
                "."
            )
        else:
            response = "No relevant documents found for your query."

        # Now ask the AIML model (DeepSeek) to generate more user-friendly guidance
        aiml_response = self.query_deepseek(query)
        return response + "\n\nAI-Suggested Guidance:\n" + aiml_response

    def query_deepseek(self, query: str) -> str:
        """Query DeepSeek (via the AIML API) for an additional AI-generated response."""
        try:
            response = client.chat.completions.create(
                model="deepseek/deepseek-r1",
                messages=[
                    {"role": "system", "content": "You are an AI assistant who knows everything about schizophrenia."},
                    {"role": "user", "content": query},
                ],
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"DeepSeek query failed: {e}")
            return "AI guidance is currently unavailable."
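
# Example standalone usage of the agent (outside Streamlit), assuming a parsed_data.json
# file exists alongside this script and the AIML API key is configured:
#     agent = SchizophreniaAgent(KnowledgeBaseLoader("parsed_data.json").load_data())
#     print(agent.process_query("hallucinations"))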
def fetch_pubmed_papers(query: str, max_results: int = 10):
    """
    Fetch PubMed papers related to the query (e.g., "schizophrenia").

    Args:
        query (str): The search term to look for in PubMed.
        max_results (int): The maximum number of results to fetch (default is 10).

    Returns:
        A list of dictionaries containing paper details (title, abstract, journal, year).
    """
    # Step 1: Search PubMed for articles related to the query
    search_params = {
        'db': 'pubmed',
        'term': query,
        'retmax': max_results,
        'retmode': 'xml'
    }
    search_response = requests.get(SEARCH_URL, params=search_params)
    if search_response.status_code != 200:
        logger.error("Unable to fetch search results from PubMed.")
        return []
    search_data = xmltodict.parse(search_response.text)

    # Step 2: Extract PubMed IDs (PMIDs) from the search results
    try:
        pmids = search_data['eSearchResult']['IdList']['Id']
    except KeyError:
        logger.error("No PubMed IDs found in search results.")
        return []
    # xmltodict returns a plain string when there is only a single ID
    if isinstance(pmids, str):
        pmids = [pmids]

    # Step 3: Fetch the details of each paper using its PMID
    papers = []
    for pmid in pmids:
        fetch_params = {
            'db': 'pubmed',
            'id': pmid,
            'retmode': 'xml',
            'rettype': 'abstract'
        }
        fetch_response = requests.get(FETCH_URL, params=fetch_params)
        if fetch_response.status_code != 200:
            logger.error(f"Unable to fetch details for PMID {pmid}")
            continue
        fetch_data = xmltodict.parse(fetch_response.text)

        # Extract the relevant details for each paper
        try:
            paper = fetch_data['PubmedArticleSet']['PubmedArticle']
            article = paper['MedlineCitation']['Article']
            title = article['ArticleTitle']
            abstract = article.get('Abstract', {}).get('AbstractText', 'No abstract available.')
            journal = article['Journal']['Title']
            # PubDate may contain a 'Year' or a free-form 'MedlineDate'
            pub_date = article['Journal']['JournalIssue'].get('PubDate', {})
            year = pub_date.get('Year', pub_date.get('MedlineDate', 'Unknown'))
            # Store paper details in a dictionary
            papers.append({
                'pmid': pmid,
                'title': title,
                'abstract': abstract,
                'journal': journal,
                'year': year
            })
        except KeyError:
            logger.error(f"Error parsing paper details for PMID {pmid}")
            continue

        # Add a delay between requests to avoid hitting PubMed's rate limits
        time.sleep(1)
    return papers
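
# Example standalone usage (outside Streamlit), assuming network access to PubMed:
#     papers = fetch_pubmed_papers("schizophrenia treatment", max_results=5)
#     print(papers[0]["title"] if papers else "No results")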
# Streamlit User Interface
def main():
    # Set configuration: path to the parsed knowledge base file
    data_file = os.getenv("SCHIZ_DATA_FILE", "parsed_data.json")

    # Initialize and load the knowledge base
    loader = KnowledgeBaseLoader(data_file)
    kb_data = loader.load_data()

    # Initialize the schizophrenia agent with the loaded data
    agent = SchizophreniaAgent(knowledge_base=kb_data)

    # Streamlit UI setup
    st.set_page_config(page_title="Schizophrenia Assistant", page_icon="🧠", layout="wide")
    st.title("Schizophrenia Episode Management Assistant")
    st.markdown(
        """
        This tool helps you manage schizophrenia episodes. You can search PubMed for research papers
        or provide details about a patient's episode, and the assistant will provide recommendations and guidance.
        """
    )

    # **Part 1: Fetch and Download PubMed Papers**
    st.header("Fetch and Download PubMed Papers")
    query = st.text_input("Enter search query (e.g., schizophrenia):", value="schizophrenia")
    if st.button("Fetch PubMed Papers"):
        with st.spinner("Fetching papers..."):
            papers = fetch_pubmed_papers(query, max_results=10)
        if papers:
            # Save papers to JSON and provide a download link
            json_data = json.dumps(papers, ensure_ascii=False, indent=4)
            st.download_button("Download JSON", data=json_data, file_name="pubmed_papers.json", mime="application/json")
            st.success(f"Successfully fetched {len(papers)} papers related to '{query}'")
        else:
            st.error("No papers found. Please try another query.")

    # **Part 2: Upload and Use JSON File**
    st.header("Upload and Use JSON File for Schizophrenia Assistant")
    uploaded_file = st.file_uploader("Upload PubMed JSON file", type=["json"])
    if uploaded_file is not None:
        try:
            file_data = json.load(uploaded_file)
        except json.JSONDecodeError:
            st.error("The uploaded file is not valid JSON.")
            file_data = []
        if file_data:
            st.write("File uploaded successfully. You can now query the assistant.")
            agent = SchizophreniaAgent(knowledge_base=file_data)

    # User input for a query
    user_input = st.text_area("Enter the patient's condition or episode details:", height=200)
    if st.button("Get Response"):
        if user_input.strip():
            with st.spinner("Processing your request..."):
                answer = agent.process_query(user_input.strip())
            st.subheader("Response")
            st.write(answer)
        else:
            st.error("Please enter a valid query to get a response.")


# Run the Streamlit app
if __name__ == "__main__":
    main()
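
# A minimal sketch of how to run this app locally (assuming the file is saved as app.py;
# the actual filename and environment may differ):
#     pip install streamlit openai requests xmltodict
#     streamlit run app.py
# Optionally set SCHIZ_DATA_FILE to point at a previously downloaded pubmed_papers.json.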