import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
from duckduckgo_search import DDGS
import time
import torch
from datetime import datetime
import os
import subprocess
import numpy as np
from typing import List, Dict, Tuple, Any, Optional, Generator
# Install required dependencies for Kokoro with better error handling
try:
    subprocess.run(['git', 'lfs', 'install'], check=True)
    if not os.path.exists('Kokoro-82M'):
        subprocess.run(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'], check=True)
    # Try installing espeak with proper package manager commands
    try:
        # Update package list first
        subprocess.run(['apt-get', 'update'], check=True)
        # Try installing espeak first (more widely available)
        subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
    except subprocess.CalledProcessError:
        print("Warning: Could not install espeak. Attempting espeak-ng...")
        try:
            subprocess.run(['apt-get', 'install', '-y', 'espeak-ng'], check=True)
        except subprocess.CalledProcessError:
            print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")
except Exception as e:
    print(f"Warning: Initial setup error: {str(e)}")
    print("Continuing with limited functionality...")
# --- Initialization (Do this ONCE) ---
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Initialize DeepSeek model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    offload_folder="offload",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16
)
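# float16 halves the memory footprint of the 8B model, and device_map="auto"
# lets accelerate place layers across GPU/CPU, spilling weights into the
# "offload" folder when neither has room.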
# Initialize Kokoro TTS (with error handling)
VOICE_CHOICES = {
    '🇺🇸 Female (Default)': 'af',
    '🇺🇸 Bella': 'af_bella',
    '🇺🇸 Sarah': 'af_sarah',
    '🇺🇸 Nicole': 'af_nicole'
}
TTS_ENABLED = False
TTS_MODEL = None
VOICEPACK = None
try:
    if os.path.exists('Kokoro-82M'):
        import sys
        sys.path.append('Kokoro-82M')
        from models import build_model  # type: ignore
        from kokoro import generate  # type: ignore
        device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Correct device handling
        TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
        # Load default voice
        try:
            VOICEPACK = torch.load('Kokoro-82M/voices/af.pt', map_location=device, weights_only=True)
        except Exception as e:
            print(f"Warning: Could not load default voice: {e}")
            raise
        TTS_ENABLED = True
    else:
        print("Warning: Kokoro-82M directory not found. TTS disabled.")
except Exception as e:
    print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
    TTS_ENABLED = False
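# TTS_MODEL and VOICEPACK are bound as default argument values of
# generate_speech_with_gpu below, so they are captured once at function
# definition time; the function reloads the voicepack itself whenever a
# non-default voice is requested.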
def get_web_results(query: str, max_results: int = 5) -> List[Dict[str, str]]:
    """Get web search results using DuckDuckGo"""
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
        return [{
            "title": result.get("title", ""),
            "snippet": result.get("body", ""),
            "url": result.get("href", ""),
            "date": result.get("published", "")
        } for result in results]
    except Exception as e:
        print(f"Error in web search: {e}")
        return []
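# DDGS text results are plain dicts; 'title', 'body' and 'href' are the usual
# keys, so every lookup above falls back to an empty string in case a backend
# omits a field ('published' in particular is frequently absent).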
def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
    """Format the prompt with web context"""
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for res in context])
    return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
Current Time: {current_time}
Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.
Query: {query}
Web Context:
{context_lines}
Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.
Answer:"""
def format_sources(web_results: List[Dict[str, str]]) -> str:
    """Format sources as an HTML list styled by the CSS classes defined below"""
    if not web_results:
        return "<div class='no-sources'>No sources available</div>"
    sources_html = "<div class='sources-container'>"
    for i, res in enumerate(web_results, 1):
        title = res["title"] or "Source"
        date = f"<span class='source-date'>{res['date']}</span>" if res['date'] else ""
        sources_html += f"""
        <div class='source-item'>
            <div class='source-number'>[{i}]</div>
            <div class='source-content'>
                <a href="{res['url']}" target="_blank" class='source-title'>{title}</a>
                {date}
                <div class='source-snippet'>{res['snippet'][:150]}...</div>
            </div>
        </div>
        """
    sources_html += "</div>"
    return sources_html
@spaces.GPU(duration=30)
def generate_answer(prompt: str) -> str:
    """Generate answer using the DeepSeek model"""
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        return_attention_mask=True
    ).to(model.device)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
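# Note: decode() returns the prompt plus the completion, which is why
# process_query below strips everything up to the final "Answer:" marker.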
@spaces.GPU(duration=30)
def generate_speech_with_gpu(text: str, voice_name: str = 'af', tts_model=TTS_MODEL, voicepack=VOICEPACK) -> Optional[Tuple[int, np.ndarray]]:
    """Generate speech from text using the Kokoro TTS model."""
    if not TTS_ENABLED or tts_model is None:
        print("TTS is not enabled or model is not loaded.")
        return None
    try:
        # Load the voicepack if it hasn't been loaded or a different voice is requested
        if voice_name != 'af' or voicepack is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            voicepack = torch.load(f'Kokoro-82M/voices/{voice_name}.pt', map_location=device, weights_only=True)
        # Clean the text: drop markdown headings and citation/emphasis markers
        clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
        clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
        # Split long text into sentence-aligned chunks
        max_chars = 1000
        chunks = []
        if len(clean_text) > max_chars:
            sentences = clean_text.split('.')
            current_chunk = ""
            for sentence in sentences:
                if len(current_chunk) + len(sentence) + 1 < max_chars:  # +1 for the dot
                    current_chunk += sentence + "."
                else:
                    chunks.append(current_chunk.strip())
                    current_chunk = sentence + "."
            if current_chunk:  # Add the last chunk
                chunks.append(current_chunk.strip())
        else:
            chunks = [clean_text]
        # Generate audio for each chunk
        audio_chunks = []
        for chunk in chunks:
            if chunk.strip():  # Only process non-empty chunks
                chunk_audio, _ = generate(tts_model, chunk, voicepack, lang='a')
                if isinstance(chunk_audio, torch.Tensor):
                    chunk_audio = chunk_audio.cpu().numpy()
                audio_chunks.append(chunk_audio)
        # Concatenate chunks
        if audio_chunks:
            final_audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
            return (24000, final_audio)
        return None
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
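# Gradio's Audio component accepts a (sample_rate, numpy_array) tuple, which
# matches the (24000, final_audio) return value above; Kokoro produces
# 24 kHz audio.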
def process_query(query: str, history: List[List[str]], selected_voice: str = 'af') -> Generator[Dict[Any, Any], None, None]:
    """Process a user query, streaming intermediate status updates"""
    sources_html = ""  # Defined up front so the error handler can always reference it
    try:
        if history is None:
            history = []
        # Get web results first
        web_results = get_web_results(query)
        sources_html = format_sources(web_results)
        current_history = history + [[query, "*Searching...*"]]
        yield {
            answer_output: gr.Markdown("*Searching & Thinking...*"),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Searching...", interactive=False),
            chat_history_display: current_history,
            audio_output: None
        }
        # Generate answer
        prompt = format_prompt(query, web_results)
        answer = generate_answer(prompt)
        final_answer = answer.split("Answer:")[-1].strip()
        # Update history *before* TTS (important for correct display)
        updated_history = history + [[query, final_answer]]
        # Generate speech from the answer (only if enabled)
        if TTS_ENABLED:
            yield {  # Intermediate update before TTS
                answer_output: gr.Markdown(final_answer),
                sources_output: gr.HTML(sources_html),
                search_btn: gr.Button("Generating audio...", interactive=False),
                chat_history_display: updated_history,
                audio_output: None
            }
            try:
                audio = generate_speech_with_gpu(final_answer, selected_voice)
            except Exception as e:
                print(f"Error during TTS: {e}")
                audio = None
        else:
            audio = None
        yield {
            answer_output: gr.Markdown(final_answer),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: updated_history,
            chat_history: updated_history,  # Persist history in the gr.State across queries
            audio_output: audio if audio is not None else gr.Audio(value=None)  # Ensure valid audio output
        }
    except Exception as e:
        error_message = str(e)
        if "GPU quota" in error_message:
            error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
        yield {
            answer_output: gr.Markdown(f"Error: {error_message}"),
            sources_output: gr.HTML(sources_html),  # Still show sources on error
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: (history or []) + [[query, f"*Error: {error_message}*"]],
            audio_output: None
        }
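# Yielding dicts keyed by output components lets each update touch only the
# components it names; Gradio leaves any omitted outputs unchanged, which is
# how the intermediate status updates avoid clobbering the audio player.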
# Update the CSS for better contrast and readability
css = """
.gradio-container {
max-width: 1200px !important;
background-color: #f7f7f8 !important;
}
#header {
text-align: center;
margin-bottom: 2rem;
padding: 2rem 0;
background: #1a1b1e;
border-radius: 12px;
color: white;
}
#header h1 {
color: white;
font-size: 2.5rem;
margin-bottom: 0.5rem;
}
#header h3 {
color: #a8a9ab;
}
.search-container {
background: #1a1b1e;
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
padding: 1rem;
margin-bottom: 1rem;
}
.search-box {
padding: 1rem;
background: #2c2d30;
border-radius: 8px;
margin-bottom: 1rem;
}
/* Style the input textbox */
.search-box input[type="text"] {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: white !important;
border-radius: 8px !important;
}
.search-box input[type="text"]::placeholder {
color: #a8a9ab !important;
}
/* Style the search button */
.search-box button {
background: #2563eb !important;
border: none !important;
}
/* Results area styling */
.results-container {
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.answer-box {
background: #3a3b3e;
border-radius: 8px;
padding: 1.5rem;
color: white;
margin-bottom: 1rem;
}
.answer-box p {
color: #e5e7eb;
line-height: 1.6;
}
.sources-container {
margin-top: 1rem;
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
}
.source-item {
display: flex;
padding: 12px;
margin: 8px 0;
background: #3a3b3e;
border-radius: 8px;
transition: all 0.2s;
}
.source-item:hover {
background: #4a4b4e;
}
.source-number {
font-weight: bold;
margin-right: 12px;
color: #60a5fa;
}
.source-content {
flex: 1;
}
.source-title {
color: #60a5fa;
font-weight: 500;
text-decoration: none;
display: block;
margin-bottom: 4px;
}
.source-date {
color: #a8a9ab;
font-size: 0.9em;
margin-left: 8px;
}
.source-snippet {
color: #e5e7eb;
font-size: 0.9em;
line-height: 1.4;
}
.chat-history {
max-height: 400px;
overflow-y: auto;
padding: 1rem;
background: #2c2d30;
border-radius: 8px;
margin-top: 1rem;
}
.examples-container {
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.examples-container button {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: #e5e7eb !important;
}
/* Markdown content styling */
.markdown-content {
color: #e5e7eb !important;
}
.markdown-content h1, .markdown-content h2, .markdown-content h3 {
color: white !important;
}
.markdown-content a {
color: #60a5fa !important;
}
/* Accordion styling */
.accordion {
background: #2c2d30 !important;
border-radius: 8px !important;
margin-top: 1rem !important;
}
.voice-selector {
margin-top: 1rem;
background: #2c2d30;
border-radius: 8px;
padding: 0.5rem;
}
.voice-selector select {
background: #3a3b3e !important;
color: white !important;
border: 1px solid #4a4b4e !important;
}
"""
# Update the Gradio interface layout
with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
    chat_history = gr.State([])
    with gr.Column(elem_id="header"):
        gr.Markdown("# 🔍 AI Search Assistant")
        gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
    with gr.Column(elem_classes="search-container"):
        with gr.Row(elem_classes="search-box"):
            search_input = gr.Textbox(
                label="",
                placeholder="Ask anything...",
                scale=5,
                container=False
            )
            search_btn = gr.Button("Search", variant="primary", scale=1)
        voice_select = gr.Dropdown(
            choices=list(VOICE_CHOICES.items()),
            value='af',
            label="Select Voice",
            elem_classes="voice-selector"
        )
    with gr.Row(elem_classes="results-container"):
        with gr.Column(scale=2):
            with gr.Column(elem_classes="answer-box"):
                answer_output = gr.Markdown(elem_classes="markdown-content")
            with gr.Row():
                audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player")
            with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
                chat_history_display = gr.Chatbot(elem_classes="chat-history")
        with gr.Column(scale=1):
            with gr.Column(elem_classes="sources-box"):
                gr.Markdown("### Sources")
                sources_output = gr.HTML()
    with gr.Row(elem_classes="examples-container"):
        gr.Examples(
            examples=[
                "musk explores blockchain for doge",
                "nvidia to launch new gaming card",
                "What are the best practices for sustainable living?",
                "tesla mistaken for asteroid"
            ],
            inputs=search_input,
            label="Try these examples"
        )
    # Handle interactions
    search_btn.click(
        fn=process_query,
        inputs=[search_input, chat_history, voice_select],
        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output, chat_history]
    )
    # Also trigger search on Enter key
    search_input.submit(
        fn=process_query,
        inputs=[search_input, chat_history, voice_select],
        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output, chat_history]
    )
if __name__ == "__main__":
    demo.launch(share=True)