File size: 6,227 Bytes
b95ce7f 0ab27bf c6fec7c 0ab27bf b95ce7f 0ab27bf 2107f88 ec27495 c23a873 b62773b 1c12941 b95ce7f 1c12941 b95ce7f c6fec7c 0ab27bf a0781f2 4073582 6905374 4073582 a0781f2 d4cc7c1 a0781f2 d4cc7c1 a0781f2 6795b2f d4cc7c1 a0781f2 cb29d9b c6fec7c 1c12941 ec27495 b95ce7f ec27495 c6fec7c 0ab27bf c6fec7c ec27495 b95ce7f 1c12941 ff66593 c23a873 c6fec7c 0ab27bf 2da39c5 0ab27bf ec27495 0ab27bf ec27495 0ab27bf ec27495 c6fec7c d4cc7c1 c6fec7c 0ab27bf c6fec7c 0ab27bf a0781f2 0ab27bf c6fec7c b95ce7f 754d7eb 27ed114 b62773b 1c12941 a0781f2 6795b2f d4cc7c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
# -*- coding: utf-8 -*-
"""
UseCase_with_Streamlit.py
This basic Streamlit app fetches Reddit posts from a few subreddits over the past 14 days,
computes sentiment scores using a PyTorch model, forecasts a 7-day sentiment trend using a pre-trained forecast model,
and displays the forecast plot.
Note: No extra logging or scheduling is included.
"""
import os, re, datetime, io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline
import matplotlib.font_manager as fm
import joblib
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import praw
import streamlit as st
# -------------------------------
# Inject custom CSS for Afacada font styling
# -------------------------------
# Page-wide CSS: background colour, the 'Afacada' font face, and the styling of
# the title and of Streamlit buttons (normal, hover, active/focus states).
# NOTE(review): the @font-face URL is relative -- confirm the hosting setup
# really serves the .ttf at that path; otherwise the rules below fall back to
# the generic sans-serif family they list.
_CUSTOM_CSS = """
<style>
body {
background-color: #fffff2;
}
@font-face {
font-family: 'Afacada';
src: url('AfacadFlux-VariableFont_slnt,wght[1].ttf') format('truetype');
font-weight: normal;
font-style: normal;
}
/* Title styling */
h1 {
font-family: 'Afacada', sans-serif;
color: #244B48;
}
/* Button styling */
.stButton>button {
font-family: 'Afacada', sans-serif;
font-size: 20px;
padding: 0.75rem 1.5rem;
background-color: #244B48;
color: white;
border: none;
border-radius: 4px;
}
.stButton>button:hover {
background-color: #1f3e38;
color: white;
}
.stButton>button:active, .stButton>button:focus {
background-color: #244B48;
color: white;
outline: none;
}
</style>
"""

# unsafe_allow_html is required so the <style> element is emitted as markup.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# -------------------------------
# Load Models and Tokenizer
# -------------------------------
# Locations of the two pre-trained artefacts used by the app.
FORECAST_MODEL_PATH = 'sentiment_forecast_model.pkl'
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"

# Forecaster whose .predict() is fed the daily sentiment history further down.
# NOTE: joblib.load unpickles the file, so it must come from a trusted source.
sentiment_model = joblib.load(FORECAST_MODEL_PATH)

# Tokenizer that turns post text into the token ids consumed by the scorer.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
class ScorePredictor(nn.Module):
    """Embedding -> LSTM -> linear head whose output is squashed by a sigmoid.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of values produced per input sequence.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super().__init__()
        # Token id 0 is registered as the padding index of the embedding.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        # batch_first=True: inputs and outputs are laid out (batch, seq, feature).
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return one sigmoid-activated score vector per row of ``input_ids``.

        ``attention_mask`` is accepted for interface compatibility but is not
        used: the score is read from the LSTM output at the last position of
        every row, whatever token sits there.
        """
        sequence_states, _ = self.lstm(self.embedding(input_ids))
        last_step = sequence_states[:, -1, :]
        return self.sigmoid(self.fc(last_step))
# Build the scorer with an embedding table sized to the tokenizer vocabulary,
# then restore its trained weights from disk.
score_model = ScorePredictor(tokenizer.vocab_size)
# map_location="cpu": the app only ever feeds CPU tensors to the model, and
# without it a checkpoint saved from a GPU cannot be deserialized on a
# CPU-only host.
score_model.load_state_dict(torch.load("score_predictor.pth", map_location="cpu"))
# Inference only: put the module in evaluation mode.
score_model.eval()
# -------------------------------
# Set up Reddit API Client
# -------------------------------
# Credentials come from the environment; an unset variable is passed as None.
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
    user_agent='MyAPI/0.0.1',
    check_for_async=False,
)
# -------------------------------
# Helper Functions
# -------------------------------
def fetch_posts(subreddit_name, start_time, limit=100):
    """Collect recent post titles from one subreddit.

    Walks the subreddit's "new" listing (at most ``limit`` submissions) and
    keeps the submissions created at or after ``start_time``.

    Args:
        subreddit_name: Name handed to ``reddit.subreddit``.
        start_time: Cut-off ``datetime``. A naive value is interpreted as UTC
            (what the caller in this file passes); a timezone-aware value is
            converted to UTC instead of failing on comparison.
        limit: Maximum number of submissions requested from the listing.

    Returns:
        A list of dicts with the keys ``"date"`` (creation time in UTC,
        formatted ``'%Y-%m-%d %H:%M:%S'``) and ``"post_text"`` (the title).
    """
    utc = datetime.timezone.utc
    # Normalise the cut-off to aware UTC so it can be compared with the aware
    # timestamps built below.
    if start_time.tzinfo is None:
        cutoff = start_time.replace(tzinfo=utc)
    else:
        cutoff = start_time.astimezone(utc)

    posts = []
    for post in reddit.subreddit(subreddit_name).new(limit=limit):
        # fromtimestamp(..., tz=utc) replaces the deprecated utcfromtimestamp();
        # the formatted string is the same UTC wall-clock time.
        post_time = datetime.datetime.fromtimestamp(post.created_utc, tz=utc)
        if post_time >= cutoff:
            posts.append({
                "date": post_time.strftime('%Y-%m-%d %H:%M:%S'),
                "post_text": post.title,
            })
    return posts
def predict_score(text):
    """Return the sentiment score of ``text`` as a float.

    The whole text is encoded as ONE sequence. Passing ``text.split()`` to the
    tokenizer would encode every word as a separate batch row, and taking row
    ``[0]`` of the model output would then score only the first word.

    Args:
        text: Post text to score.

    Returns:
        The model's sigmoid output for the text, or ``0.0`` when ``text`` is
        empty, ``None`` or contains only whitespace (nothing to tokenize).
    """
    if not text or not text.strip():
        return 0.0
    # Single sequence, so no padding is needed; truncation keeps overly long
    # inputs within the tokenizer's length limit.
    encoded = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        score = score_model(encoded["input_ids"], encoded["attention_mask"])[0].item()
    return score
# -------------------------------
# Streamlit Interface
# -------------------------------
st.title("7-Day Sentiment Forecast")
if st.button("Run Analysis"):
    subreddits = ["ohio", "libertarian", "centrist"]
    history_days = 14  # the forecaster is fed a single row of 14 daily means
    forecast_days = 7  # number of forecast points that are plotted

    # Naive UTC "now" (utcnow() is deprecated). Kept naive because the post
    # timestamps it is compared with are UTC wall-clock times.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    start_time = now_utc - datetime.timedelta(days=history_days)

    all_posts = []
    for sub in subreddits:
        all_posts.extend(fetch_posts(sub, start_time))

    if not all_posts:
        st.error("No posts fetched.")
    else:
        df = pd.DataFrame(all_posts)
        df['date'] = pd.to_datetime(df['date'])
        df['date_only'] = df['date'].dt.date
        df['sentiment_score'] = df['post_text'].apply(predict_score)
        daily_sentiment = df.groupby('date_only')['sentiment_score'].mean()

        # Lay the daily means onto the last `history_days` calendar days
        # (oldest first, ending today in UTC). This guarantees exactly
        # `history_days` features -- a 14x24h look-back can touch 15 distinct
        # dates -- and keeps every observed day in its own slot. Days without
        # posts are filled with the mean of the observed daily values.
        observed = daily_sentiment.to_dict()
        fill_value = daily_sentiment.mean()
        window = [
            now_utc.date() - datetime.timedelta(days=back)
            for back in range(history_days - 1, -1, -1)
        ]
        history = [observed.get(day, fill_value) for day in window]
        model_input = np.asarray(history, dtype=float).reshape(1, -1)
        forecast = sentiment_model.predict(model_input)[0]

        # Use the bundled font when present; otherwise fall back to
        # Matplotlib's default so a missing file does not break rendering.
        font_path = "AfacadFlux-VariableFont_slnt,wght[1].ttf"
        if os.path.exists(font_path):
            custom_font = fm.FontProperties(fname=font_path)
        else:
            custom_font = fm.FontProperties()

        today = datetime.date.today()
        days_str = [
            (today + datetime.timedelta(days=i)).strftime('%a %m/%d')
            for i in range(forecast_days)
        ]

        # Cubic spline through the forecast points, for a smooth curve.
        day_idx = np.arange(forecast_days)
        xnew = np.linspace(0, forecast_days - 1, 300)
        smooth_forecast = make_interp_spline(day_idx, forecast, k=3)(xnew)

        buf = io.BytesIO()
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            ax.fill_between(xnew, smooth_forecast, color='#244B48', alpha=0.4)
            ax.plot(xnew, smooth_forecast, color='#244B48', lw=3)
            ax.scatter(day_idx, forecast, color='#244B48', s=50)
            ax.set_title("7-Day Sentiment Forecast", fontproperties=custom_font, fontsize=20)
            ax.set_xlabel("Day", fontproperties=custom_font, fontsize=14)
            ax.set_ylabel("Sentiment", fontproperties=custom_font, fontsize=14)
            ax.set_xticks(day_idx)
            ax.set_xticklabels(days_str, fontproperties=custom_font, fontsize=12)
            fig.tight_layout()
            fig.savefig(buf, format='png')
        finally:
            # pyplot keeps every figure alive until it is closed; the script
            # creates a new one on each run, so release it explicitly.
            plt.close(fig)
        buf.seek(0)
        st.image(buf, caption="Forecast Plot")
|