from flask import Flask, request, jsonify, render_template_string
import os
import requests
import logging
from typing import Dict, Any, List
import time

app = Flask(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Configuration
OLLAMA_API_URL = os.getenv('OLLAMA_API_URL', 'https://huggingface.co/spaces/tommytracx/ollama-api')
# DEFAULT_MODEL is parsed into a list of fallback model names from a comma-separated string.
DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'llama2,llama2:13b,llama2:70b,codellama,neural-chat,gemma-3-270m').split(',')
MAX_TOKENS = int(os.getenv('MAX_TOKENS', '2048'))
TEMPERATURE = float(os.getenv('TEMPERATURE', '0.7'))
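
# Example environment setup (illustrative values only; every variable is
# optional and falls back to the defaults above):
#   export OLLAMA_API_URL="https://my-ollama-proxy.example.com"
#   export DEFAULT_MODEL="llama2,codellama"
#   export MAX_TOKENS="1024"
#   export TEMPERATURE="0.5"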

class OllamaClient:
    def __init__(self, api_url: str):
        self.api_url = api_url.rstrip('/')
        self.available_models = DEFAULT_MODEL  # Initialize with default models
        self.refresh_models()

    def refresh_models(self) -> None:
        """Refresh the list of available models from the API, falling back to defaults on failure."""
        try:
            response = requests.get(f"{self.api_url}/api/models", timeout=10)
            response.raise_for_status()
            data = response.json()
            if data.get('status') == 'success' and isinstance(data.get('models'), list):
                self.available_models = data['models']
                logging.info(f"Successfully fetched models: {self.available_models}")
            else:
                logging.warning(f"Invalid response format from API: {data}")
                self.available_models = DEFAULT_MODEL
        except Exception as e:
            logging.error(f"Error refreshing models: {e}")
            self.available_models = DEFAULT_MODEL

    def list_models(self) -> List[str]:
        """Return the list of available models."""
        return self.available_models

    def generate(self, model_name: str, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate text using a model."""
        if model_name not in self.available_models:
            return {"status": "error", "message": f"Model {model_name} not available"}
        try:
            payload = {
                "model": model_name,
                "prompt": prompt,
                "stream": False,
                **kwargs
            }
            response = requests.post(f"{self.api_url}/api/generate", json=payload, timeout=120)
            response.raise_for_status()
            data = response.json()
            if data.get('status') == 'success':
                return {
                    "status": "success",
                    "response": data.get('response', ''),
                    "model": model_name,
                    "usage": data.get('usage', {})
                }
            return {"status": "error", "message": data.get('message', 'Unknown error')}
        except Exception as e:
            logging.error(f"Error generating response: {e}")
            return {"status": "error", "message": str(e)}

    def health_check(self) -> Dict[str, Any]:
        """Check the health of the Ollama API."""
        try:
            response = requests.get(f"{self.api_url}/health", timeout=10)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logging.error(f"Health check failed: {e}")
            return {"status": "unhealthy", "error": str(e)}

# Initialize Ollama client
ollama_client = OllamaClient(OLLAMA_API_URL)

# HTML template for the chat interface
HTML_TEMPLATE = '''
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>OpenWebUI - Ollama Chat</title>
    <style>
        /* [Previous CSS unchanged] */
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1>🤖 OpenWebUI</h1>
            <p>Chat with your local Ollama models through Hugging Face Spaces</p>
        </div>
        <div class="controls">
            <div class="control-group">
                <label for="model-select">Model:</label>
                <select id="model-select">
                    <option value="">Select a model...</option>
                </select>
            </div>
            <div class="control-group">
                <label for="temperature">Temperature:</label>
                <input type="range" id="temperature" min="0" max="2" step="0.1" value="0.7">
                <span id="temp-value">0.7</span>
            </div>
            <div class="control-group">
                <label for="max-tokens">Max Tokens:</label>
                <input type="number" id="max-tokens" min="1" max="4096" value="2048">
            </div>
        </div>
        <div class="chat-container" id="chat-container">
            <div class="message assistant">
                <div class="message-avatar">AI</div>
                <div class="message-content">
                    Hello! I'm your AI assistant powered by Ollama. How can I help you today?
                </div>
            </div>
        </div>
        <div class="typing-indicator" id="typing-indicator">
            AI is thinking...
        </div>
        <div class="input-container">
            <form class="input-form" id="chat-form">
                <textarea
                    class="input-field"
                    id="message-input"
                    placeholder="Type your message here..."
                    rows="1"
                ></textarea>
                <button type="submit" class="send-button" id="send-button">
                    Send
                </button>
            </form>
        </div>
        <div class="status" id="status"></div>
    </div>
    <script>
        // Kept client-side for display only; the history is not currently sent to the backend.
        let conversationHistory = [];

        document.addEventListener('DOMContentLoaded', function() {
            loadModels();
            setupEventListeners();
            autoResizeTextarea();
        });

        async function loadModels() {
            const modelSelect = document.getElementById('model-select');
            modelSelect.innerHTML = '<option value="">Loading models...</option>';
            try {
                const response = await fetch('/api/models');
                const data = await response.json();
                modelSelect.innerHTML = '<option value="">Select a model...</option>';
                if (data.status === 'success' && data.models.length > 0) {
                    data.models.forEach(model => {
                        const option = document.createElement('option');
                        option.value = model;
                        option.textContent = model;
                        // '{{ default_model[0] }}' is filled in by Jinja when the template is rendered.
                        if (model === '{{ default_model[0] }}') {
                            option.selected = true;
                        }
                        modelSelect.appendChild(option);
                    });
                    showStatus('Models loaded successfully', 'success');
                } else {
                    modelSelect.innerHTML = '<option value="">No models available</option>';
                    showStatus('No models available from API', 'error');
                }
            } catch (error) {
                console.error('Error loading models:', error);
                modelSelect.innerHTML = '<option value="">No models available</option>';
                showStatus('Failed to load models: ' + error.message, 'error');
            }
        }

        function setupEventListeners() {
            document.getElementById('chat-form').addEventListener('submit', handleSubmit);
            document.getElementById('temperature').addEventListener('input', function() {
                document.getElementById('temp-value').textContent = this.value;
            });
            document.getElementById('message-input').addEventListener('input', autoResizeTextarea);
        }

        function autoResizeTextarea() {
            const textarea = document.getElementById('message-input');
            textarea.style.height = 'auto';
            textarea.style.height = Math.min(textarea.scrollHeight, 120) + 'px';
        }

        async function handleSubmit(e) {
            e.preventDefault();
            const messageInput = document.getElementById('message-input');
            const message = messageInput.value.trim();
            if (!message) return;
            const model = document.getElementById('model-select').value;
            const temperature = parseFloat(document.getElementById('temperature').value);
            const maxTokens = parseInt(document.getElementById('max-tokens').value);
            if (!model) {
                showStatus('Please select a model', 'error');
                return;
            }
            addMessage(message, 'user');
            messageInput.value = '';
            autoResizeTextarea();
            showTypingIndicator(true);
            try {
                const response = await fetch('/api/chat', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ model, message, temperature, max_tokens: maxTokens })
                });
                const data = await response.json();
                showTypingIndicator(false);
                if (data.status === 'success') {
                    addMessage(data.response, 'assistant');
                    showStatus(`Response generated using ${model}`, 'success');
                } else {
                    addMessage('Sorry, I encountered an error while processing your request.', 'assistant');
                    showStatus(`Error: ${data.message}`, 'error');
                }
            } catch (error) {
                showTypingIndicator(false);
                addMessage('Sorry, I encountered a network error.', 'assistant');
                showStatus('Network error: ' + error.message, 'error');
            }
        }

        function addMessage(content, sender) {
            const chatContainer = document.getElementById('chat-container');
            const messageDiv = document.createElement('div');
            messageDiv.className = `message ${sender}`;
            const avatar = document.createElement('div');
            avatar.className = 'message-avatar';
            avatar.textContent = sender === 'user' ? 'U' : 'AI';
            const messageContent = document.createElement('div');
            messageContent.className = 'message-content';
            messageContent.textContent = content;
            messageDiv.appendChild(avatar);
            messageDiv.appendChild(messageContent);
            chatContainer.appendChild(messageDiv);
            chatContainer.scrollTop = chatContainer.scrollHeight;
            conversationHistory.push({ role: sender, content: content });
        }

        function showTypingIndicator(show) {
            const indicator = document.getElementById('typing-indicator');
            indicator.style.display = show ? 'block' : 'none';
            if (show) {
                const chatContainer = document.getElementById('chat-container');
                chatContainer.scrollTop = chatContainer.scrollHeight;
            }
        }

        function showStatus(message, type = '') {
            const statusDiv = document.getElementById('status');
            statusDiv.textContent = message;
            statusDiv.className = `status ${type}`;
            setTimeout(() => {
                statusDiv.textContent = '';
                statusDiv.className = 'status';
            }, 5000);
        }
    </script>
</body>
</html>
'''

@app.route('/')
def home():
    """Main chat interface."""
    return render_template_string(HTML_TEMPLATE, ollama_api_url=OLLAMA_API_URL, default_model=DEFAULT_MODEL)

@app.route('/api/chat', methods=['POST'])
def chat():
    """Chat API endpoint."""
    try:
        data = request.get_json()
        if not data or 'message' not in data or 'model' not in data:
            return jsonify({"status": "error", "message": "Message and model are required"}), 400
        message = data['message']
        model = data['model']
        temperature = data.get('temperature', TEMPERATURE)
        max_tokens = data.get('max_tokens', MAX_TOKENS)
        result = ollama_client.generate(model, message, temperature=temperature, max_tokens=max_tokens)
        return jsonify(result), 200 if result["status"] == "success" else 500
    except Exception as e:
        logging.error(f"Chat endpoint error: {e}")
        return jsonify({"status": "error", "message": str(e)}), 500

@app.route('/api/models')
def get_models():
    """Get available models."""
    try:
        models = ollama_client.list_models()
        return jsonify({
            "status": "success",
            "models": models,
            "count": len(models)
        })
    except Exception as e:
        logging.error(f"Models endpoint error: {e}")
        return jsonify({"status": "error", "message": str(e)}), 500

@app.route('/health')
def health():
    """Health check endpoint; reports this app's status plus the upstream Ollama API's."""
    try:
        ollama_health = ollama_client.health_check()
        return jsonify({
            "status": "healthy",
            "ollama_api": ollama_health,
            "timestamp": time.time()
        })
    except Exception as e:
        logging.error(f"Health check endpoint error: {e}")
        return jsonify({
            "status": "unhealthy",
            "error": str(e),
            "timestamp": time.time()
        }), 500

if __name__ == '__main__':
    # Port 7860 is the port Hugging Face Spaces expects apps to listen on.
    app.run(host='0.0.0.0', port=7860, debug=False)
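
# Example requests against a local run (assumes the server is listening on
# port 7860; the model name and payload values are illustrative):
#   curl http://localhost:7860/api/models
#   curl -X POST http://localhost:7860/api/chat \
#        -H 'Content-Type: application/json' \
#        -d '{"model": "llama2", "message": "Hi there", "temperature": 0.7, "max_tokens": 256}'
#   curl http://localhost:7860/health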