Rkruemmel's picture
Update app.py
bf435bf verified
import logging
import json
import os
import zipfile
from difflib import get_close_matches
import base64
from io import BytesIO
from PIL import Image
import random
import numpy as np
import gradio as gr
# Konfiguration des Loggers
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Globale Variablen
category_nodes = []
questions = []
initialized = False
# Dateipfade
ZIP_FILE = "DRLCogNet.zip"
MODEL_FILE = "model.DRLCogNet" # Datei ist direkt in der ZIP
def extract_model():
"""
Entpackt die ZIP-Datei in das aktuelle Verzeichnis, wenn die Modelldatei nicht existiert.
"""
if not os.path.exists(MODEL_FILE):
logging.info("Entpacke Modelldatei...")
with zipfile.ZipFile(ZIP_FILE, 'r') as zip_ref:
zip_ref.extractall(".") # Entpackt alle Dateien ins aktuelle Verzeichnis
logging.info("Modell erfolgreich entpackt.")
# Überprüfe, ob die Dateien entpackt wurden
extracted_files = os.listdir(".")
logging.info(f"Entpackte Dateien: {extracted_files}")
# Prüfe, ob die Modelldatei tatsächlich vorhanden ist
if MODEL_FILE in extracted_files:
logging.info(f"Die Datei {MODEL_FILE} wurde erfolgreich extrahiert.")
else:
logging.error(f"Die Datei {MODEL_FILE} wurde NICHT gefunden.")
else:
logging.info("Modell ist bereits entpackt.")
def load_model_with_questions_and_answers():
"""
Lädt das Modell mit Fragen und Antworten aus einer JSON-Datei.
Returns:
tuple: Die Liste der Kategorie-Knoten und die Liste der Fragen.
"""
global initialized
if initialized:
logging.info("Modell bereits initialisiert.")
return None, None
extract_model()
if not os.path.exists(MODEL_FILE):
logging.error(f"Modelldatei {MODEL_FILE} nicht gefunden. Überprüfe die ZIP-Datei.")
return None, None
try:
with open(MODEL_FILE, "r", encoding="utf-8") as file:
model_data = json.load(file)
nodes_dict = {node_data["label"]: Node(node_data["label"]) for node_data in model_data["nodes"]}
for node_data in model_data["nodes"]:
node = nodes_dict[node_data["label"]]
node.activation = node_data.get("activation", 0.0)
node.decay_rate = float(node_data.get("decay_rate", 0.0)) # Laden der decay_rate
for conn_state in node_data["connections"]:
target_node = nodes_dict.get(conn_state["target"])
if target_node:
node.add_connection(target_node, conn_state["weight"])
global questions
questions = model_data.get("questions", [])
logging.info(f"Modell geladen mit {len(nodes_dict)} Knoten und {len(questions)} Fragen")
initialized = True
return list(nodes_dict.values()), questions
except json.JSONDecodeError as e:
logging.error(f"Fehler beim Parsen der JSON-Datei: {e}")
return None, None
def find_similar_question(questions, query):
"""
Findet die ähnlichste Frage basierend auf einfachen Ähnlichkeitsmetriken.
Args:
questions (list): Liste aller Fragen.
query (str): Die Abfrage, nach der gesucht werden soll.
Returns:
dict: Die ähnlichste Frage.
"""
question_texts = [q['question'] for q in questions]
closest_matches = get_close_matches(query, question_texts, n=1, cutoff=0.3)
if closest_matches:
matched_question = next((q for q in questions if q['question'] == closest_matches[0]), None)
return matched_question
else:
return {"question": "Keine passende Frage gefunden", "category": "Unbekannt"}
def find_best_answer(category_nodes, questions, query):
"""
Findet die beste Antwort auf eine Abfrage.
Args:
category_nodes (list): Liste der Kategorie-Knoten.
questions (list): Liste der Fragen.
query (str): Die Abfrage.
Returns:
str: Die beste Antwort.
float: Die Aktivierung des Kategorie-Knotens.
"""
matched_question = find_similar_question(questions, query)
if matched_question:
logging.info(f"Gefundene Frage: {matched_question['question']} -> Kategorie: {matched_question['category']}")
answer = matched_question.get("answer", "Keine Antwort verfügbar")
logging.info(f"Antwort: {answer}")
activation = 0.8 # Beispielhafte Aktivierung
return answer, activation
else:
logging.warning("Keine passende Frage gefunden.")
return None, None
class Node:
"""
Ein Knoten im Netzwerk.
"""
def __init__(self, label):
self.label = label
self.connections = []
self.activation = 0.0
self.decay_rate = 0.0
def add_connection(self, target_node, weight=None):
"""
Fügt eine Verbindung zu einem Zielknoten hinzu.
Args:
target_node (Node): Der Zielknoten.
weight (float): Das Gewicht der Verbindung.
"""
self.connections.append(Connection(target_node, weight))
class Connection:
"""
Eine Verbindung zwischen zwei Knoten im Netzwerk.
"""
def __init__(self, target_node, weight=None):
self.target_node = target_node
self.weight = weight if weight is not None else random.uniform(0.1, 1.0)
def get_answer(query):
answer, activation = find_best_answer(category_nodes, questions, query)
if answer:
# Dekodieren der Base64-Bilddaten
try:
image_data = base64.b64decode(answer)
image = Image.open(BytesIO(image_data))
return image, activation, f"{activation:.2f}"
except Exception as e:
logging.error(f"Fehler beim Dekodieren des Bildes: {e}")
return None, None, "0.00"
else:
return None, None, "0.00"
# Lade das Modell und die Fragen
category_nodes, questions = load_model_with_questions_and_answers()
if category_nodes is None or questions is None:
logging.error("Fehler beim Laden des Modells. Stelle sicher, dass die ZIP-Datei vorhanden ist.")
else:
# Erstelle die Gradio-App
with gr.Blocks(title="DRL-CogNet Dog Finder: wrote a number behind the dog name") as demo:
gr.Markdown("# Frage an das Modell (wrote a numer behind the dog name: ein golden retriever 2")
with gr.Row():
with gr.Column():
question_input = gr.Textbox(label="Frage")
submit_button = gr.Button("Antwort abrufen")
with gr.Column():
image_output = gr.Image(label="Antwortbild")
activation_output = gr.Number(label="Aktivierung")
weight_output = gr.Text(label="Gewichtung")
submit_button.click(fn=get_answer, inputs=question_input, outputs=[image_output, activation_output, weight_output])
demo.launch()