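"""Streamlit app that classifies the topic of a web page.

Given a URL, the app downloads the page, extracts the main text with
trafilatura, tokenizes it against a BERT-style vocabulary, runs a TensorFlow
Lite classifier, and shows the five highest-scoring topics from
taxonomy_v2.csv.

Run with: streamlit run <this_file>.py
"""
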
import streamlit as st
import trafilatura
import numpy as np
import pandas as pd
import tensorflow as tf  # tf.lite.Interpreter is the public TFLite API
import requests

# File paths
MODEL_PATH = "./model.tflite"
VOCAB_PATH = "./vocab.txt"
LABELS_PATH = "./taxonomy_v2.csv"
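# model.tflite is assumed to be a BERT-style classifier exported to TFLite
# (three int32 inputs: word ids, attention mask, token type ids); vocab.txt is
# its WordPiece vocabulary; taxonomy_v2.csv maps integer topic IDs to names.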

@st.cache_resource
def load_vocab():
    # Map each token to its row index for O(1) lookups during tokenization
    with open(VOCAB_PATH, "r", encoding="utf-8") as f:
        return {line.strip(): idx for idx, line in enumerate(f)}

@st.cache_resource
def load_labels():
    # Load labels from the CSV file
    taxonomy = pd.read_csv(LABELS_PATH)
    taxonomy["ID"] = taxonomy["ID"].astype(int)
    labels_dict = taxonomy.set_index("ID")["Topic"].to_dict()
    return labels_dict

@st.cache_resource
def load_model():
    try:
        # Use the TensorFlow Lite interpreter via the public tf.lite API
        interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        return interpreter, input_details, output_details
    except Exception as e:
        st.error(f"Failed to load the model: {e}")
        raise

def preprocess_text(text, vocab, max_length=128):
    # Whitespace tokenization looked up against the vocabulary. This is a
    # rough stand-in for BERT's WordPiece tokenizer: out-of-vocabulary words
    # map to [UNK] instead of being split into subwords.
    unk_id = vocab.get("[UNK]", 0)  # fall back to id 0 if [UNK] is absent
    words = text.split()[:max_length]  # split and truncate
    token_ids = [vocab.get(word, unk_id) for word in words]
    token_ids = np.array(token_ids + [0] * (max_length - len(token_ids)), dtype=np.int32)  # pad (0 is [PAD] in standard BERT vocabs)
    attention_mask = np.array([1] * len(words) + [0] * (max_length - len(words)), dtype=np.int32)
    token_type_ids = np.zeros_like(attention_mask, dtype=np.int32)
    # Add a leading batch dimension of 1 for the interpreter
    return token_ids[np.newaxis, :], attention_mask[np.newaxis, :], token_type_ids[np.newaxis, :]

def classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids):
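    # Assumes the converted model exposes its three inputs in the order
    # (word ids, attention mask, token type ids); if a model orders them
    # differently, match tensors by the "name" field of input_details instead.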
    interpreter.set_tensor(input_details[0]["index"], input_word_ids)
    interpreter.set_tensor(input_details[1]["index"], input_mask)
    interpreter.set_tensor(input_details[2]["index"], input_type_ids)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]["index"])
    return output[0]
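
# A minimal sketch, assuming the model emits raw logits: applying a softmax
# makes the top-5 scores readable as probabilities. Unnecessary if the
# exported model already ends in a softmax/sigmoid layer.
def softmax(logits):
    exps = np.exp(logits - np.max(logits))  # shift by the max for numerical stability
    return exps / exps.sum()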

def fetch_url_content(url):
    # A browser-like User-Agent avoids trivial bot blocking. Accept-Encoding is
    # left to requests: advertising "br" manually can yield an undecodable body
    # when the brotli package is not installed.
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36",
        "Accept-Language": "en-US,en;q=0.9",
    }
    try:
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code == 200:
            return response.text
        st.error(f"Failed to fetch content. Status code: {response.status_code}")
        return None
    except requests.RequestException as e:
        st.error(f"Error fetching content: {e}")
        return None
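
# Note: trafilatura also ships its own downloader (trafilatura.fetch_url),
# which could replace fetch_url_content when custom headers are not needed.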

# Streamlit app
st.title("Topic Classification from URL")

url = st.text_input("Enter a URL:", "")
if url:
    st.write("Extracting content from the URL...")
    raw_content = fetch_url_content(url)
    if raw_content:
        content = trafilatura.extract(raw_content)
        if content:
            st.write("Content extracted successfully!")
            st.write(content[:500])  # Display a snippet of the content

            # Load resources
            vocab = load_vocab()
            labels_dict = load_labels()
            interpreter, input_details, output_details = load_model()

            # Preprocess content and classify
            input_word_ids, input_mask, input_type_ids = preprocess_text(content, vocab)
            predictions = classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids)

            # Display classification
            st.write("Topic Classification:")
            sorted_indices = np.argsort(predictions)[::-1][:5]  # top 5 scores, highest first
            for idx in sorted_indices:
                topic = labels_dict.get(int(idx), "Unknown Topic")  # int() so numpy indices match the dict keys
                st.write(f"ID: {idx} - Topic: {topic} - Score: {predictions[idx]:.4f}")
        else:
            st.error("Unable to extract content from the fetched HTML.")
    else:
        st.error("Failed to fetch the URL.")