# insuranceai / app.py
# prithvirajpawar - commit 140b902: "addition of intro_msg"
from fastapi import FastAPI, Request, Depends, HTTPException, Header, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from helpmate_ai import get_system_msg, retreive_results, rerank_with_cross_encoder, generate_response, intro_message
import google.generativeai as genai
import os
from dotenv import load_dotenv
import re
import speech_recognition as sr
from io import BytesIO
import wave
import google.generativeai as genai
# Pull variables from a local .env file (if present) into the process environment.
load_dotenv()

# Shared secret that clients must send in the x-api-key header (checked by verify_api_key).
API_KEY = os.environ.get("API_KEY")

# Credentials for the Gemini SDK.
gemini_api_key = os.environ.get("GEMINI_API_KEY")
genai.configure(api_key=gemini_api_key)

# The ASGI application every route below is registered on.
app = FastAPI()
# # Enable CORS
# app.add_middleware(
# CORSMiddleware,
# allow_origins=["*"],
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
# Pydantic models for request/response validation
class Message(BaseModel):
    # One turn of the chat history; the handlers in this file use role "user" or "bot".
    role: str
    content: str  # text of the turn
class ChatRequest(BaseModel):
    # Request body for POST /chat.
    message: str  # text the user sent
class ChatResponse(BaseModel):
    # Response body for GET /init and POST /chat.
    response: str  # the bot's latest reply
    conversation: List[Message]  # the whole (shared) chat history after this reply
class Report(BaseModel):
    # Request body for POST /report (feedback from the client).
    # NOTE(review): presumably `response` is the bot reply being reported and
    # `message` the user's text - confirm against the frontend.
    response: str
    message: str
    timestamp: str  # free-form string; no date format is validated here
# Chat history served by /init, /chat and /reset. It is a single module-level
# list, so every client shares the same history.
conversation_bot: List[Message] = []

# Despite its name, this holds the system prompt: it is passed to the model as
# its system instruction rather than used as a message history.
conversation = get_system_msg()
model = genai.GenerativeModel(
    "gemini-1.5-flash",
    system_instruction=conversation,
)

# Speech-to-text engine used by /process-voice.
recognizer = sr.Recognizer()
# Dependency to check the API key
async def verify_api_key(x_api_key: str = Header(...)):
    """FastAPI dependency that rejects requests lacking the correct API key.

    Args:
        x_api_key: Value of the required ``x-api-key`` request header.

    Raises:
        HTTPException: 403 when the server has no (non-empty) ``API_KEY``
            configured, or when the header does not match it.
    """
    import secrets

    # Fail closed when API_KEY is unset or empty, so a missing or blank env var
    # can never turn into "an empty header is accepted".
    # compare_digest runs in constant time, so response timing does not leak
    # how much of a guessed key was correct. Compare bytes, because
    # compare_digest rejects non-ASCII str arguments.
    if not API_KEY or not secrets.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Unauthorized")
def get_gemini_completions(conversation: str) -> str:
    """Send a prompt to the module-level Gemini model and return its text reply.

    Args:
        conversation: Prompt content handed straight to ``model.generate_content``.

    Returns:
        The ``.text`` of the generated response.
    """
    return model.generate_content(conversation).text
# @app.get("/secure-endpoint", dependencies=[Depends(verify_api_key)])
# async def secure_endpoint():
# return {"message": "Access granted!"}
# Initialize conversation endpoint
@app.get("/init", response_model=ChatResponse, dependencies=[Depends(verify_api_key)])
async def initialize_chat():
    """Start a fresh chat whose only turn is the bot's introduction.

    Replaces the shared history with a single bot message containing
    ``intro_message`` and returns that message together with the history.
    """
    global conversation_bot
    greeting = Message(role="bot", content=intro_message)
    conversation_bot = [greeting]
    return ChatResponse(conversation=conversation_bot, response=intro_message)
# Chat endpoint
@app.post("/chat", response_model=ChatResponse, dependencies=[Depends(verify_api_key)])
async def chat(request: ChatRequest):
    """Answer a user message with the RAG pipeline and record the exchange.

    Runs retrieval, cross-encoder reranking, prompt construction and the Gemini
    call, then appends the user turn and the bot turn to the shared history.

    Args:
        request: Body carrying the user's message text.

    Returns:
        ChatResponse with the assistant's reply and the full history.

    Raises:
        Whatever the pipeline raises; the history is left unchanged in that case.
    """
    import asyncio

    def _run_pipeline(query: str) -> str:
        # Retrieve candidates, rerank them, build the prompt, then ask Gemini.
        results_df = retreive_results(query)
        top_docs = rerank_with_cross_encoder(query, results_df)
        messages = generate_response(query, top_docs)
        return get_gemini_completions(messages)

    # These calls are synchronous. Running them directly inside this coroutine
    # would block the event loop and stall every other request, so run them
    # in the loop's default thread pool instead.
    loop = asyncio.get_running_loop()
    response_assistant = await loop.run_in_executor(None, _run_pipeline, request.message)

    # Append both turns together, only after generation succeeded. A failed
    # request leaves no orphaned user turn, and because there is no await
    # between the two appends, concurrent requests cannot interleave them.
    # NOTE(review): conversation_bot is one module-level list shared by all
    # clients; per-session history would need a session key.
    conversation_bot.append(Message(role="user", content=request.message))
    conversation_bot.append(Message(role="bot", content=response_assistant))

    return ChatResponse(response=response_assistant, conversation=conversation_bot)
# Voice processing endpoint
@app.post("/process-voice", dependencies=[Depends(verify_api_key)])
async def process_voice(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio clip and answer it with the RAG pipeline.

    The API-key check is declared in the route decorator. Previously
    ``dependencies=[...]`` sat in the function signature, where FastAPI treats
    it as an ordinary parameter and never runs ``verify_api_key``.

    Args:
        audio_file: Uploaded audio. Per the SpeechRecognition docs,
            ``sr.AudioFile`` expects WAV, AIFF or FLAC data; nothing here
            converts other formats.

    Returns:
        ``{"transcribed_text": ..., "response": ...}`` on success, or
        ``{"error": "Error processing voice input: ..."}`` if any step fails.
        The error dict is returned as a normal response, not an HTTP error.
    """
    import asyncio

    def _transcribe_and_answer(raw: bytes) -> dict:
        # Speech recognition (including the network call to Google), retrieval,
        # reranking and Gemini are all synchronous, so this whole function runs
        # in a worker thread rather than on the event loop.
        with sr.AudioFile(BytesIO(raw)) as source:
            audio = recognizer.record(source)
        text = recognizer.recognize_google(audio)

        results_df = retreive_results(text)
        top_docs = rerank_with_cross_encoder(text, results_df)
        messages = generate_response(text, top_docs)
        response_assistant = get_gemini_completions(messages)
        return {
            "transcribed_text": text,
            "response": response_assistant,
        }

    try:
        contents = await audio_file.read()
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _transcribe_and_answer, contents)
    except Exception as e:
        # Endpoint boundary: keep the original contract of reporting failures
        # in the response body instead of raising.
        return {"error": f"Error processing voice input: {str(e)}"}
@app.post("/report", dependencies=[Depends(verify_api_key)])
async def handle_feedback(request: Report):
    """Accept a feedback report about a bot reply.

    The API-key check is declared in the route decorator. Previously
    ``dependencies=[...]`` was a function parameter, which FastAPI treats as
    ordinary input, so the key was never verified.

    Args:
        request: The validated feedback payload.

    Returns:
        ``{"status": "success"}``. The report is not stored anywhere yet.
    """
    # TODO: persist `request` (e.g. to a database); it is currently discarded.
    return {"status": "success"}
# Reset conversation endpoint
@app.post("/reset", dependencies=[Depends(verify_api_key)])
async def reset_conversation():
    """Throw away the shared chat history and reseed it with the bot's intro.

    Returns:
        A status dict confirming the reset.
    """
    global conversation_bot
    conversation_bot = [Message(role="bot", content=intro_message)]
    return {"status": "success", "message": "Conversation reset"}
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    uvicorn.run(
        app,
        port=8000,
        host="0.0.0.0",
    )