Spaces:
Sleeping
Sleeping
import logging | |
import json | |
import os | |
import zipfile | |
from difflib import get_close_matches | |
import base64 | |
from io import BytesIO | |
from PIL import Image | |
import random | |
import numpy as np | |
import gradio as gr | |
# Konfiguration des Loggers | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
# Globale Variablen | |
category_nodes = [] | |
questions = [] | |
initialized = False | |
# Dateipfade | |
ZIP_FILE = "DRLCogNet.zip" | |
MODEL_FILE = "model.DRLCogNet" # Datei ist direkt in der ZIP | |
def extract_model(): | |
""" | |
Entpackt die ZIP-Datei in das aktuelle Verzeichnis, wenn die Modelldatei nicht existiert. | |
""" | |
if not os.path.exists(MODEL_FILE): | |
logging.info("Entpacke Modelldatei...") | |
with zipfile.ZipFile(ZIP_FILE, 'r') as zip_ref: | |
zip_ref.extractall(".") # Entpackt alle Dateien ins aktuelle Verzeichnis | |
logging.info("Modell erfolgreich entpackt.") | |
# Überprüfe, ob die Dateien entpackt wurden | |
extracted_files = os.listdir(".") | |
logging.info(f"Entpackte Dateien: {extracted_files}") | |
# Prüfe, ob die Modelldatei tatsächlich vorhanden ist | |
if MODEL_FILE in extracted_files: | |
logging.info(f"Die Datei {MODEL_FILE} wurde erfolgreich extrahiert.") | |
else: | |
logging.error(f"Die Datei {MODEL_FILE} wurde NICHT gefunden.") | |
else: | |
logging.info("Modell ist bereits entpackt.") | |
def load_model_with_questions_and_answers(): | |
""" | |
Lädt das Modell mit Fragen und Antworten aus einer JSON-Datei. | |
Returns: | |
tuple: Die Liste der Kategorie-Knoten und die Liste der Fragen. | |
""" | |
global initialized | |
if initialized: | |
logging.info("Modell bereits initialisiert.") | |
return None, None | |
extract_model() | |
if not os.path.exists(MODEL_FILE): | |
logging.error(f"Modelldatei {MODEL_FILE} nicht gefunden. Überprüfe die ZIP-Datei.") | |
return None, None | |
try: | |
with open(MODEL_FILE, "r", encoding="utf-8") as file: | |
model_data = json.load(file) | |
nodes_dict = {node_data["label"]: Node(node_data["label"]) for node_data in model_data["nodes"]} | |
for node_data in model_data["nodes"]: | |
node = nodes_dict[node_data["label"]] | |
node.activation = node_data.get("activation", 0.0) | |
node.decay_rate = float(node_data.get("decay_rate", 0.0)) # Laden der decay_rate | |
for conn_state in node_data["connections"]: | |
target_node = nodes_dict.get(conn_state["target"]) | |
if target_node: | |
node.add_connection(target_node, conn_state["weight"]) | |
global questions | |
questions = model_data.get("questions", []) | |
logging.info(f"Modell geladen mit {len(nodes_dict)} Knoten und {len(questions)} Fragen") | |
initialized = True | |
return list(nodes_dict.values()), questions | |
except json.JSONDecodeError as e: | |
logging.error(f"Fehler beim Parsen der JSON-Datei: {e}") | |
return None, None | |
def find_similar_question(questions, query): | |
""" | |
Findet die ähnlichste Frage basierend auf einfachen Ähnlichkeitsmetriken. | |
Args: | |
questions (list): Liste aller Fragen. | |
query (str): Die Abfrage, nach der gesucht werden soll. | |
Returns: | |
dict: Die ähnlichste Frage. | |
""" | |
question_texts = [q['question'] for q in questions] | |
closest_matches = get_close_matches(query, question_texts, n=1, cutoff=0.3) | |
if closest_matches: | |
matched_question = next((q for q in questions if q['question'] == closest_matches[0]), None) | |
return matched_question | |
else: | |
return {"question": "Keine passende Frage gefunden", "category": "Unbekannt"} | |
def find_best_answer(category_nodes, questions, query): | |
""" | |
Findet die beste Antwort auf eine Abfrage. | |
Args: | |
category_nodes (list): Liste der Kategorie-Knoten. | |
questions (list): Liste der Fragen. | |
query (str): Die Abfrage. | |
Returns: | |
str: Die beste Antwort. | |
float: Die Aktivierung des Kategorie-Knotens. | |
""" | |
matched_question = find_similar_question(questions, query) | |
if matched_question: | |
logging.info(f"Gefundene Frage: {matched_question['question']} -> Kategorie: {matched_question['category']}") | |
answer = matched_question.get("answer", "Keine Antwort verfügbar") | |
logging.info(f"Antwort: {answer}") | |
activation = 0.8 # Beispielhafte Aktivierung | |
return answer, activation | |
else: | |
logging.warning("Keine passende Frage gefunden.") | |
return None, None | |
class Node: | |
""" | |
Ein Knoten im Netzwerk. | |
""" | |
def __init__(self, label): | |
self.label = label | |
self.connections = [] | |
self.activation = 0.0 | |
self.decay_rate = 0.0 | |
def add_connection(self, target_node, weight=None): | |
""" | |
Fügt eine Verbindung zu einem Zielknoten hinzu. | |
Args: | |
target_node (Node): Der Zielknoten. | |
weight (float): Das Gewicht der Verbindung. | |
""" | |
self.connections.append(Connection(target_node, weight)) | |
class Connection: | |
""" | |
Eine Verbindung zwischen zwei Knoten im Netzwerk. | |
""" | |
def __init__(self, target_node, weight=None): | |
self.target_node = target_node | |
self.weight = weight if weight is not None else random.uniform(0.1, 1.0) | |
def get_answer(query): | |
answer, activation = find_best_answer(category_nodes, questions, query) | |
if answer: | |
# Dekodieren der Base64-Bilddaten | |
try: | |
image_data = base64.b64decode(answer) | |
image = Image.open(BytesIO(image_data)) | |
return image, activation, f"{activation:.2f}" | |
except Exception as e: | |
logging.error(f"Fehler beim Dekodieren des Bildes: {e}") | |
return None, None, "0.00" | |
else: | |
return None, None, "0.00" | |
# Lade das Modell und die Fragen | |
category_nodes, questions = load_model_with_questions_and_answers() | |
if category_nodes is None or questions is None: | |
logging.error("Fehler beim Laden des Modells. Stelle sicher, dass die ZIP-Datei vorhanden ist.") | |
else: | |
# Erstelle die Gradio-App | |
with gr.Blocks(title="DRL-CogNet Dog Finder: wrote a number behind the dog name") as demo: | |
gr.Markdown("# Frage an das Modell (wrote a numer behind the dog name: ein golden retriever 2") | |
with gr.Row(): | |
with gr.Column(): | |
question_input = gr.Textbox(label="Frage") | |
submit_button = gr.Button("Antwort abrufen") | |
with gr.Column(): | |
image_output = gr.Image(label="Antwortbild") | |
activation_output = gr.Number(label="Aktivierung") | |
weight_output = gr.Text(label="Gewichtung") | |
submit_button.click(fn=get_answer, inputs=question_input, outputs=[image_output, activation_output, weight_output]) | |
demo.launch() | |