# -*- coding: utf-8 -*-
"""
UseCase_with_Streamlit.py

This basic Streamlit app fetches Reddit posts from a few subreddits over the
past 14 days, computes sentiment scores using a PyTorch model, forecasts a
7-day sentiment trend using a pre-trained forecast model, and displays the
forecast plot.

Note: No extra logging or scheduling is included.
"""
import datetime
import io
import os
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline
import matplotlib.font_manager as fm
import joblib
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import praw
import streamlit as st

# -------------------------------
# Configuration
# -------------------------------
SUBREDDITS = ["ohio", "libertarian", "centrist"]
# Number of daily sentiment values fed to the forecast model. The original code
# padded the series to exactly 14 values, so the model presumably expects 14
# input features.
LOOKBACK_DAYS = 14
FONT_PATH = "AfacadFlux-VariableFont_slnt,wght[1].ttf"

# -------------------------------
# Inject custom CSS for Afacad font styling
# -------------------------------
# NOTE(review): the CSS payload below is empty, so this call currently has no
# visible effect; restore the intended <style> block or drop the call.
st.markdown(
    """ """,
    unsafe_allow_html=True
)

# -------------------------------
# Load Models and Tokenizer
# -------------------------------
# NOTE: Streamlit re-executes this whole script on every interaction, so these
# loads repeat on each rerun. Consider st.cache_resource if start-up latency
# matters (check thread-safety before sharing the tokenizer across sessions).
sentiment_model = joblib.load('sentiment_forecast_model.pkl')

MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
tokenizer = AutoTokenizer.from_pretrained(MODEL)


class ScorePredictor(nn.Module):
    """Embedding -> LSTM -> linear head that yields one sigmoid score per sequence."""

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super().__init__()
        # NOTE(review): padding_idx=0 — verify it matches tokenizer.pad_token_id
        # used during training (XLM-R style tokenizers typically pad with id 1).
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        # attention_mask is accepted so tokenizer output can be passed straight
        # through, but it is unused: the LSTM output at the LAST position is
        # taken, which for a right-padded batch would be a padding position.
        # predict_score() therefore feeds one unpadded sequence at a time.
        embedded = self.embedding(input_ids)
        lstm_out, _ = self.lstm(embedded)
        final_hidden_state = lstm_out[:, -1, :]
        output = self.fc(final_hidden_state)
        return self.sigmoid(output)


score_model = ScorePredictor(tokenizer.vocab_size)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
score_model.load_state_dict(torch.load("score_predictor.pth", map_location="cpu"))
score_model.eval()

# -------------------------------
# Set up Reddit API Client
# -------------------------------
reddit = praw.Reddit(
    client_id=os.environ.get("REDDIT_CLIENT_ID"),
    client_secret=os.environ.get("REDDIT_CLIENT_SECRET"),
    user_agent='MyAPI/0.0.1',
    check_for_async=False
)


# -------------------------------
# Helper Functions
# -------------------------------
def fetch_posts(subreddit_name, start_time, limit=100):
    """Return titles of recent posts in *subreddit_name* created at/after *start_time*.

    Only the newest *limit* submissions are inspected, so a busy subreddit may
    not reach all the way back to *start_time*.

    Args:
        subreddit_name: Subreddit to read, without the ``r/`` prefix.
        start_time: Earliest creation time to keep. A naive datetime is
            interpreted as UTC (matching the previous behaviour).
        limit: Maximum number of submissions requested from the ``new`` listing.

    Returns:
        A list of dicts with keys ``"date"`` (UTC, ``YYYY-mm-dd HH:MM:SS``) and
        ``"post_text"`` (the post title).
    """
    if start_time.tzinfo is None:
        start_time = start_time.replace(tzinfo=datetime.timezone.utc)

    posts = []
    for post in reddit.subreddit(subreddit_name).new(limit=limit):
        # Timezone-aware replacement for the deprecated utcfromtimestamp().
        post_time = datetime.datetime.fromtimestamp(
            post.created_utc, tz=datetime.timezone.utc
        )
        if post_time >= start_time:
            posts.append({
                "date": post_time.strftime('%Y-%m-%d %H:%M:%S'),
                "post_text": post.title,
            })
    return posts


def predict_score(text):
    """Score *text* with ``score_model``.

    Returns:
        The model's sigmoid output as a float in (0, 1), or 0.0 when *text* is
        empty, ``None`` or whitespace-only.
    """
    if not text or not text.strip():
        return 0.0
    # Tokenize the whole text as ONE sequence. Passing text.split() instead
    # builds a batch with one sequence per word, and [0] would then score only
    # the first word of the post.
    encoded = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        score = score_model(encoded["input_ids"], encoded["attention_mask"])[0].item()
    return score


def _daily_sentiment_window(df, end_date, days=LOOKBACK_DAYS):
    """Return mean sentiment per calendar day for the *days* days ending on *end_date*.

    Days without posts are filled with the mean over all observed days, so the
    result always holds exactly *days* values in chronological order (oldest
    first). Expects ``date_only`` (``datetime.date``) and ``sentiment_score``
    columns in *df*.
    """
    daily = df.groupby('date_only')['sentiment_score'].mean()
    # Computed before reindexing so it is never NaN while df has any rows.
    fill_value = daily.mean()
    window = pd.date_range(end=pd.Timestamp(end_date), periods=days, freq='D').date
    return daily.reindex(window).fillna(fill_value)


def _render_forecast_plot(forecast):
    """Draw the smoothed forecast curve and return it as an in-memory PNG buffer."""
    if os.path.exists(FONT_PATH):
        custom_font = fm.FontProperties(fname=FONT_PATH)
    else:
        # Fall back to matplotlib's default font instead of failing at render time.
        custom_font = fm.FontProperties()

    horizon = len(forecast)
    x = np.arange(horizon)
    # NOTE(review): labels start at today's local date; confirm whether the
    # model's first forecast step is today or tomorrow.
    today = datetime.date.today()
    # Two-line tick labels: weekday above month/day.
    days_str = [
        (today + datetime.timedelta(days=i)).strftime('%a\n%m/%d')
        for i in range(horizon)
    ]

    xnew = np.linspace(0, horizon - 1, 300)
    # Cubic spline needs at least 4 points; degrade gracefully for shorter output.
    spline = make_interp_spline(x, forecast, k=min(3, horizon - 1))
    smooth_forecast = spline(xnew)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.fill_between(xnew, smooth_forecast, color='#244B48', alpha=0.4)
    ax.plot(xnew, smooth_forecast, color='#244B48', lw=3)
    ax.scatter(x, forecast, color='#244B48', s=50)
    ax.set_title("7-Day Sentiment Forecast", fontproperties=custom_font, fontsize=20)
    ax.set_xlabel("Day", fontproperties=custom_font, fontsize=14)
    ax.set_ylabel("Sentiment", fontproperties=custom_font, fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(days_str, fontproperties=custom_font, fontsize=12)
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    # pyplot keeps every figure alive until closed; without this each button
    # click would leak a figure.
    plt.close(fig)
    buf.seek(0)
    return buf


# -------------------------------
# Streamlit Interface
# -------------------------------
st.title("7-Day Sentiment Forecast")

if st.button("Run Analysis"):
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    start_time = now_utc - datetime.timedelta(days=LOOKBACK_DAYS)

    all_posts = []
    for sub in SUBREDDITS:
        all_posts.extend(fetch_posts(sub, start_time))

    if not all_posts:
        st.error("No posts fetched.")
    else:
        df = pd.DataFrame(all_posts)
        df['date'] = pd.to_datetime(df['date'])
        df['date_only'] = df['date'].dt.date
        df['sentiment_score'] = df['post_text'].apply(predict_score)

        # The fetch window (now - 14 days .. now) touches 15 calendar dates,
        # so align to exactly the last LOOKBACK_DAYS UTC days; this also keeps
        # empty days in place instead of shifting the timeline.
        daily_sentiment = _daily_sentiment_window(df, end_date=now_utc.date())

        forecast = sentiment_model.predict(daily_sentiment.to_numpy().reshape(1, -1))[0]

        st.image(_render_forecast_plot(forecast), caption="Forecast Plot")