import os
import re
import pickle
import logging
import asyncio

import numpy as np
import torch
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import fitz  # PyMuPDF
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from config import (
    ALL_FILES,
    MATH_FILES,
    SCIENCE_FILES,
    DATA_DIR,
    DOCUMENTS_PATH,
    FAISS_INDEX_PATH,
    HUGGINGFACE_TOKEN,
    MODEL_ID,
)

app = FastAPI(title="Swahili Content Generation API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class PromptRequest(BaseModel):
    prompt: str


class ContentRequest(BaseModel):
    grade: int
    subject: str
    topic: str
    style: str = "normal"

# Maps each source file to the keyword(s) used to match a requested topic.
TOPIC_KEYWORDS = {
    # Science, grade 3 (PDF sources)
    'mazingira g3.pdf': ['mazingira'],
    'nishati g3.pdf': ['nishati'],
    'maada g3.pdf': ['maada'],
    'mawasiliano g3.pdf': ['mawasiliano'],
    'usafi g3.pdf': ['usafi'],
    'vipimo g3.pdf': ['vipimo-s'],
    'mlo g3.pdf': ['mlo'],
    'mfumo g3.pdf': ['mfumo'],
    'maambukizi g3.pdf': ['maambukizi'],
    'huduma g3.pdf': ['huduma'],
    'vifaa g3.pdf': ['vifaa'],

    # Science, grade 4 (text sources)
    'kinga ya mwili g4.txt': ['kinga'],
    'magonjwa g4.txt': ['magonjwa'],
    'majaribio ya kisayansi g4.txt': ['majaribio'],
    'maji g4.txt': ['maji'],
    'ukimwi g4.txt': ['ukimwi'],
    'huduma g4.txt': ['huduma-g4'],
    'mazingira g4.txt': ['mazingira-g4'],
    'matumizi ya nishati g4.txt': ['matumizi-nishati-g4'],
    'nishati g4.txt': ['nishati-g4'],
    'mfumo g4.txt': ['mfumo-g4'],
    'mawasiliano g4.txt': ['mawasiliano-g4'],

    # Math, grade 3
    'namba g3.txt': ['namba'],
    'mpangilio g3.txt': ['mpangilio'],
    'matendo katika namba g3.txt': ['matendo'],
    'kutambua sehemu g3.txt': ['sehemu'],
    'kutambua maumbo g3.txt': ['maumbo'],
    'vipimo g3.txt': ['vipimo'],
    'fedha g3.txt': ['fedha'],
    'takwimu kwa picha g3.txt': ['takwimu'],

    # Math, grade 4
    'kugawanya namba g4.txt': ['kugawanya'],
    'kujumlisha namba g4.txt': ['kujumlisha'],
    'kuzidisha namba g4.txt': ['kuzidisha'],
    'namba nzima g4.txt': ['namba-g4'],
    'namba za kirumi g4.txt': ['kirumi'],
    'wakati g4.txt': ['wakati'],
    'mpangilio g4.txt': ['mpangilio-g4'],
    'vipimo g4.txt': ['vipimo-g4'],
    'takwimu g4.txt': ['takwimu-g4'],
    'kutoa namba g4.txt': ['kutoa'],
    'fedha g4.txt': ['fedha-g4'],
    'sehemu g4.txt': ['sehemu-g4'],
    'maumbo g4.txt': ['maumbo-g4'],
}

def preprocess_pdf_text(text):
    words_to_remove = ['FOR', 'ONLINE', 'USE', 'ONLY', 'DO', 'NOT', 'DUPLICATE', 'SAYANSI', 'STD', 'PM']
    pattern = r'\b(?:' + '|'.join(map(re.escape, words_to_remove)) + r')\b'
    text = re.sub(pattern, '', text, flags=re.IGNORECASE)

    text = ' '.join(text.split())
    text = re.sub(r'[^\w\s\.\,\?\!\'\"àèìòùÀÈÌÒÙáéíóúÁÉÍÓÚâêîôûÂÊÎÔÛãẽĩõũÃẼĨÕŨ]', ' ', text)
    text = ' '.join(text.split())
    return text

def extract_text_from_file(file_path):
    if file_path.lower().endswith('.pdf'):
        return extract_text_from_pdf(file_path)
    elif file_path.lower().endswith('.txt'):
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                text = file.read()
            return text.strip()
        except Exception as e:
            logging.error(f"Error reading text file {file_path}: {str(e)}")
            return ""
    else:
        logging.error(f"Unsupported file type for {file_path}")
        return ""

def extract_text_from_pdf(pdf_path):
    text = ""
    with fitz.open(pdf_path) as doc:
        for page_num, page in enumerate(doc):
            try:
                blocks = page.get_text("blocks")
                page_text = "\n".join(block[4] for block in blocks)
                cleaned_text = preprocess_pdf_text(page_text)
                text += cleaned_text + "\n"
            except Exception as e:
                logging.error(f"Error processing page {page_num + 1}: {str(e)}")
                continue

    return text.strip()

def split_text_into_chunks(text, source_file, chunk_size=500, overlap=50):
    """Split text into overlapping, sentence-aligned chunks tagged with their source file."""
    # Normalise whitespace and drop embedded newlines.
    text = ' '.join(text.strip().split())

    filename = os.path.basename(source_file)
    keywords = TOPIC_KEYWORDS.get(filename, [])

    sentences = nltk.sent_tokenize(text)
    chunks = []
    current_chunk = []
    current_size = 0

    for sentence in sentences:
        sentence_words = len(sentence.split())

        if current_size + sentence_words > chunk_size:
            if current_chunk:
                chunk_text = ' '.join(current_chunk)
                chunk_info = {
                    'text': chunk_text,
                    'source': filename,
                    'keywords': keywords
                }
                chunks.append(chunk_info)

                # Carry the last `overlap` words of the finished chunk into the next one.
                overlap_size = 0
                overlap_chunk = []
                for s in reversed(current_chunk):
                    if overlap_size + len(s.split()) <= overlap:
                        overlap_chunk.insert(0, s)
                        overlap_size += len(s.split())
                    else:
                        break

                current_chunk = overlap_chunk
                current_size = overlap_size

        current_chunk.append(sentence)
        current_size += sentence_words

    if current_chunk:
        chunk_text = ' '.join(current_chunk)
        chunks.append({
            'text': chunk_text,
            'source': filename,
            'keywords': keywords
        })

    return chunks

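# Illustrative shape of a chunk record produced by split_text_into_chunks
# (the filename shown is one of the TOPIC_KEYWORDS entries; the values are examples):
#
#   {'text': '<up to ~500 words of consecutive sentences>',
#    'source': 'nishati g3.pdf',
#    'keywords': ['nishati']}
#
# Consecutive chunks from the same file share roughly `overlap` words of text.
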
def create_faiss_index(texts, embedding_model):
    doc_embeddings = embedding_model.encode(texts)
    index = faiss.IndexFlatL2(doc_embeddings.shape[1])
    index.add(np.array(doc_embeddings))
    return index

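# Minimal sketch (comments only, not executed) of querying the flat L2 index built
# above; retrieve_documents() below relies on this (distances, indices) shape.
# The 384-dimensional embedding size assumes the all-MiniLM-L6-v2 encoder used here.
#
#   query_vec = embedding_model.encode(["nishati"])            # shape (1, 384)
#   distances, indices = index.search(np.array(query_vec), 5)  # each shape (1, 5)
#   top_chunks = [documents[i] for i in indices[0]]
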
def retrieve_documents(query, index, embedding_model, documents, top_k=5):
    """Return up to top_k unique chunks from the file whose primary keyword matches the query."""
    query_lower = query.lower()
    target_topic = None

    # Map the query onto a source file via its primary keyword.
    for filename, keywords in TOPIC_KEYWORDS.items():
        if keywords[0] == query_lower:
            target_topic = filename
            break

    # Over-fetch from FAISS, then keep only chunks from the matched file.
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(query_embedding, top_k * 3)

    topic_docs = []
    for idx in indices[0]:
        doc = documents[idx]
        if doc['source'] == target_topic:
            # Skip duplicate chunk texts.
            if not any(existing.get('text', '') == doc['text'] for existing in topic_docs):
                topic_docs.append(doc)
        if len(topic_docs) >= top_k:
            break

    final_content = "\n\n".join(doc['text'] for doc in topic_docs[:top_k])
    logger.info(f"Retrieved content from: {target_topic}")
    return final_content

def calculate_bleu(reference, candidate):
    """Calculate the BLEU score between reference and candidate texts."""
    if isinstance(reference, list):
        reference = " ".join(reference)
    if isinstance(candidate, list):
        candidate = " ".join(candidate)

    reference_tokens = [reference.split()]
    candidate_tokens = candidate.split()
    smoothing = SmoothingFunction().method1
    return sentence_bleu(reference_tokens, candidate_tokens, smoothing_function=smoothing)

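# Illustrative usage (the sentences are made-up examples, not curriculum text):
#
#   calculate_bleu("mtoto anakunywa maji", "mtoto anakunywa maji safi")
#
# The score is a float in [0, 1]; identical token sequences score 1.0, and the
# method1 smoothing keeps short candidates with missing n-grams from scoring 0.
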
def get_topic_files(grade: int, subject: str, topic: str):
    """Return the source files for the given grade whose primary keyword matches the topic."""
    topic_lower = topic.lower()

    file_list = MATH_FILES if subject.lower() == "math" else SCIENCE_FILES

    matching_files = []
    for file in file_list:
        if f"g{grade}" in file.lower():
            filename = os.path.basename(file)
            if filename in TOPIC_KEYWORDS:
                keywords = TOPIC_KEYWORDS[filename]
                if topic_lower == keywords[0]:
                    matching_files.append(file)

    return matching_files

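# Illustrative call (assumes config.SCIENCE_FILES includes a path ending in
# 'mazingira g3.pdf'; adjust to whatever config.py actually lists):
#
#   get_topic_files(3, "science", "mazingira")
#
# keeps files whose name contains "g3", is a TOPIC_KEYWORDS key, and whose first
# keyword equals the requested topic, so it would return that single path.
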
def generate_response_with_rag(prompt, index, embedding_model, documents, settings):
    """Retrieve topic context, build the system prompt, and generate structured Swahili content."""
    retrieved_context = retrieve_documents(prompt, index, embedding_model, documents)

    logger.info("Context sent to model:")
    logger.info("-" * 50)
    logger.info(retrieved_context)
    logger.info("-" * 50)

    style_instructions = {
        "simple": "Provide clear and easy-to-understand answers using common words and short sentences. Explain concepts as if talking to a young student.",
        "creative": "Give creative and engaging answers, using real-life examples and illustrations to make the content interesting and memorable.",
        "normal": ""
    }

    instruction = style_instructions.get(settings.get("style", "normal"), "")

    system_prompt = f"""
    Explain the topic of "{settings['topic']}" in detail following this structure:
    1. Summary: Briefly explain what the student will learn in this topic (5-6 sentences).
    2. Introduction to the topic: Provide background information about the topic before breaking it down into subtopics.
    3. Subtopics: Explain each subtopic in detail, providing real-life examples where necessary. For each subtopic, describe images that could help explain the topic, using text descriptions instead of actual images.
    Use this format: [Picture: Image description]. Do not provide more than 3 [Picture: Image description] entries.
    4. Activities: After each subtopic, provide small exercises or activities that the student can do to enhance understanding (Activities).
    5. Practice questions: Provide 6-8 questions related to the topic to reinforce the student's understanding.

    **Respond to all questions and instructions in Swahili.**

    {instruction}

    Context:
    {retrieved_context}
    """

messages = [{"role": "system", "content": system_prompt}] |
|
outputs = app.state.pipe(messages, max_new_tokens=2000) |
|
|
|
try: |
|
|
|
if not outputs or len(outputs) == 0: |
|
logger.error("No output generated") |
|
return { |
|
"content": "Failed to generate response", |
|
"context": retrieved_context |
|
} |
|
|
|
|
|
generated_messages = outputs[0]['generated_text'] |
|
if isinstance(generated_messages, list): |
|
|
|
for message in generated_messages: |
|
if message.get('role') == 'assistant': |
|
response_content = message.get('content', '') |
|
break |
|
else: |
|
logger.error("No assistant response found in messages") |
|
return { |
|
"content": "Failed to generate response", |
|
"context": retrieved_context |
|
} |
|
else: |
|
response_content = generated_messages |
|
|
|
if not response_content: |
|
logger.error("Empty response content") |
|
return { |
|
"content": "Failed to generate response", |
|
"context": retrieved_context |
|
} |
|
|
|
|
|
response_content = response_content.strip() |
|
|
|
|
|
paragraphs = [p.strip() for p in response_content.split('\n\n') if p.strip()] |
|
|
|
|
|
formatted_paragraphs = [] |
|
for paragraph in paragraphs: |
|
|
|
|
|
if len(paragraph) > 100 and '\n' not in paragraph: |
|
sentences = [s.strip() for s in nltk.sent_tokenize(paragraph)] |
|
formatted_paragraphs.append('\n'.join(sentences)) |
|
else: |
|
formatted_paragraphs.append(paragraph) |
|
|
|
|
|
response_content = '\n\n'.join(formatted_paragraphs) |
|
response_content = response_content.replace('\n', '<br>') |
|
|
|
return { |
|
"content": response_content, |
|
"context": retrieved_context |
|
} |
|
|
|
except Exception as e: |
|
logger.error(f"Error processing response: {e}") |
|
logger.error(f"Raw output: {outputs}") |
|
return { |
|
"content": "Error processing response", |
|
"context": retrieved_context |
|
} |
|
|
|
async def load_or_create_index():
    """Load the FAISS index and chunked documents from disk, building and saving them if absent."""
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(os.path.dirname(FAISS_INDEX_PATH), exist_ok=True)

    try:
        with open(DOCUMENTS_PATH, 'rb') as f:
            documents = pickle.load(f)
        index = faiss.read_index(FAISS_INDEX_PATH)
        logger.info("FAISS index and documents loaded successfully.")
        return index, documents, embedding_model
    except FileNotFoundError:
        logger.info("Index and documents not found. Proceeding to create them.")
        documents = []

    files_found = False
    for file_path in ALL_FILES:
        if not os.path.exists(file_path):
            logger.warning(f"File not found: {file_path}")
            continue

        filename = os.path.basename(file_path)
        logging.info(f"Processing {filename}")
        text = extract_text_from_file(file_path)

        if text:
            files_found = True
            chunks = split_text_into_chunks(text, filename)
            documents.extend(chunks)
        # Yield control so the event loop stays responsive during indexing.
        await asyncio.sleep(0)

    if not files_found:
        raise Exception("No valid text or PDF files found in the specified paths")

    texts = [doc['text'] for doc in documents]
    index = create_faiss_index(texts, embedding_model)

    os.makedirs(os.path.dirname(DOCUMENTS_PATH), exist_ok=True)

    with open(DOCUMENTS_PATH, 'wb') as f:
        pickle.dump(documents, f)
    faiss.write_index(index, FAISS_INDEX_PATH)
    logger.info("FAISS index and documents saved successfully.")

    return index, documents, embedding_model

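# Note: the cached index is reused whenever DOCUMENTS_PATH and FAISS_INDEX_PATH
# exist; delete both files to force a rebuild after the source PDFs/texts change.
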
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

logger = logging.getLogger(__name__)

@app.on_event("startup") |
|
async def startup_event(): |
|
"""Initialize the application on startup.""" |
|
logger = logging.getLogger(__name__) |
|
logger.info("Starting application initialization...") |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
logger.info(f"Using device: {device}") |
|
|
|
if device == "cpu": |
|
logger.warning("GPU not detected. Model will run slower on CPU.") |
|
|
|
|
|
nltk_data_dir = os.environ.get('NLTK_DATA', os.path.join(os.path.expanduser('~'), 'nltk_data')) |
|
os.makedirs(nltk_data_dir, exist_ok=True) |
|
|
|
|
|
logger.info("Downloading NLTK data...") |
|
try: |
|
|
|
import nltk.data |
|
try: |
|
nltk.data.find('tokenizers/punkt', paths=[nltk_data_dir]) |
|
logger.info("NLTK punkt already downloaded") |
|
except LookupError: |
|
await asyncio.to_thread(nltk.download, 'punkt', download_dir=nltk_data_dir, quiet=True) |
|
|
|
try: |
|
nltk.data.find('tokenizers/punkt_tab', paths=[nltk_data_dir]) |
|
logger.info("NLTK punkt_tab already downloaded") |
|
except LookupError: |
|
await asyncio.to_thread(nltk.download, 'punkt_tab', download_dir=nltk_data_dir, quiet=True) |
|
except Exception as e: |
|
logger.error(f"Error handling NLTK data: {str(e)}") |
|
raise Exception(f"Failed to initialize application: {str(e)}") |
|
|
|
|
|
try: |
|
app.state.pipe = pipeline( |
|
"text-generation", |
|
model=MODEL_ID, |
|
trust_remote_code=True, |
|
token=HUGGINGFACE_TOKEN, |
|
device_map="auto", |
|
torch_dtype=torch.float16 if device == "cuda" else torch.float32 |
|
) |
|
|
|
faiss_index, documents, embedding_model = await load_or_create_index() |
|
|
|
|
|
app.state.faiss_index = faiss_index |
|
app.state.documents = documents |
|
app.state.embedding_model = embedding_model |
|
|
|
logger.info("Application initialization completed successfully") |
|
except Exception as e: |
|
logger.error(f"Error initializing application: {str(e)}") |
|
raise Exception(f"Failed to initialize application: {str(e)}") |
|
|
|
@app.post("/generate") |
|
async def generate_content(request: ContentRequest): |
|
try: |
|
logger.info(f"Generating content for grade {request.grade}, subject {request.subject}, topic {request.topic}") |
|
|
|
|
|
if request.grade not in [3, 4]: |
|
raise HTTPException(status_code=400, detail="Invalid grade level. Must be 3 or 4") |
|
|
|
if request.subject.lower() not in ["math", "science"]: |
|
raise HTTPException(status_code=400, detail="Invalid subject. Must be 'math' or 'science'") |
|
|
|
if request.style not in ["normal", "simple", "creative"]: |
|
raise HTTPException(status_code=400, detail="Invalid style. Must be 'normal', 'simple', or 'creative'") |
|
|
|
|
|
topic_files = get_topic_files(request.grade, request.subject, request.topic) |
|
if not topic_files: |
|
raise HTTPException(status_code=404, detail="Topic not found for specified grade and subject") |
|
|
|
|
|
settings = { |
|
"style": request.style, |
|
"topic": request.topic, |
|
"grade": request.grade, |
|
"subject": request.subject |
|
} |
|
|
|
response = generate_response_with_rag( |
|
request.topic, |
|
app.state.faiss_index, |
|
app.state.embedding_model, |
|
app.state.documents, |
|
settings |
|
) |
|
|
|
logger.info("Content generated successfully") |
|
return {"response": response['content']} |
|
|
|
except Exception as e: |
|
logger.error(f"Error generating response: {str(e)}") |
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
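# Illustrative request (the topic string must equal the first keyword of a
# TOPIC_KEYWORDS entry for the chosen grade and subject):
#
#   POST /generate
#   {"grade": 3, "subject": "science", "topic": "nishati", "style": "simple"}
#
# On success the endpoint returns {"response": "<Swahili content with <br> breaks>"}.
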
@app.get("/health") |
|
async def health_check(): |
|
try: |
|
|
|
if not hasattr(app.state, "pipe"): |
|
return {"status": "starting", "message": "Model is still loading"} |
|
return {"status": "healthy"} |
|
except Exception as e: |
|
logger.error(f"Health check failed: {str(e)}") |
|
raise HTTPException(status_code=500, detail="Internal server error") |
|
|
|
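# GET /health returns {"status": "starting", ...} until startup_event has finished
# loading the model pipeline, and {"status": "healthy"} afterwards.
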
if __name__ == "__main__":
    try:
        logger.info("Starting FastAPI server...")
        uvicorn.run(app, host="0.0.0.0", port=8080, log_level="info")
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise