AlainDeLong committed on
Commit fa7aa9f · 1 Parent(s): cc4fd89

create app

.gitignore ADDED
@@ -0,0 +1,183 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Streamlit Secrets
+ .streamlit/
+
+ # Youtube Links
+ src/link.txt
+
+ # Test files
+ src/test_plotly_script.py
requirements.txt CHANGED
@@ -1,3 +1,8 @@
- altair
- pandas
- streamlit
+ altair==5.5.0
+ pandas==2.2.2
+ streamlit==1.45.0
+ torch==2.5.1
+ transformers==4.46.2
+ regex==2024.11.6
+ plotly==6.0.1
+ google-api-python-client==2.169.0
src/fine_tuned_model/config.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "architectures": [
+     "RobertaForSequenceClassification"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "id2label": {
+     "0": "negative",
+     "1": "neutral",
+     "2": "positive"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "label2id": {
+     "negative": 0,
+     "neutral": 1,
+     "positive": 2
+   },
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 514,
+   "model_type": "roberta",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 1,
+   "position_embedding_type": "absolute",
+   "problem_type": "single_label_classification",
+   "torch_dtype": "float32",
+   "transformers_version": "4.51.3",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 50265
+ }
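
A quick sanity check of this config, sketched under the assumption that the repo is checked out with this layout and the pinned transformers package is installed:

    # Minimal sketch: confirm the 3-class head that predict.py relies on.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("src/fine_tuned_model")
    print(config.id2label)    # expected: {0: 'negative', 1: 'neutral', 2: 'positive'}
    print(config.num_labels)  # expected: 3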
src/fine_tuned_model/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
src/fine_tuned_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff26327999e09e76218bd59e2f78b1445a2720ea58fb27c15f47ae3f1e6cd42e
+ size 498615900
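
The three lines above are a Git LFS pointer, not the model weights themselves; the actual ~499 MB safetensors file is fetched with "git lfs pull" after cloning. A minimal sketch for catching the still-a-pointer case before loading (the path assumes this repo layout):

    import os

    path = "src/fine_tuned_model/model.safetensors"
    # An LFS pointer file is only ~130 bytes; the real weights are ~499 MB per the pointer's size field.
    if os.path.getsize(path) < 1024:
        raise RuntimeError("model.safetensors is still an LFS pointer; run 'git lfs pull' first.")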
src/fine_tuned_model/special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "bos_token": "<s>",
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "mask_token": {
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "unk_token": "<unk>"
+ }
src/fine_tuned_model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
src/fine_tuned_model/tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "50264": {
+       "content": "<mask>",
+       "lstrip": true,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "errors": "replace",
+   "extra_special_tokens": {},
+   "mask_token": "<mask>",
+   "model_max_length": 512,
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "tokenizer_class": "RobertaTokenizer",
+   "trim_offsets": true,
+   "unk_token": "<unk>"
+ }
src/fine_tuned_model/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
src/predict.py ADDED
@@ -0,0 +1,473 @@
+ # src/predict.py
+
+ import os  # To help build file paths correctly
+ import torch  # PyTorch library, for tensors and model operations
+ from transformers import (
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+ )  # Hugging Face classes for loading the model and tokenizer
+
+
+ # --- Configuration ---
+ # This is where our fine-tuned model and tokenizer files are stored
+ # Assuming 'fine_tuned_model' directory is inside 'src/' and next to this predict.py file
+ _SCRIPT_DIR = os.path.dirname(
+     os.path.abspath(__file__)
+ )  # Gets the directory where this script is
+ MODEL_PATH = os.path.join(
+     _SCRIPT_DIR, "fine_tuned_model"
+ )  # User confirmed this variable name and directory
+
+ print(f"DEBUG (predict.py): Model path set to: {MODEL_PATH}")  # For checking the path
+
+ # --- Device Setup ---
+ # Check if a GPU is available, otherwise use CPU
+ # Using a GPU makes predictions much faster!
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+     # Trying to get the name of the GPU, just for information
+     try:
+         gpu_name = torch.cuda.get_device_name(0)
+         print(f"INFO (predict.py): GPU is available ({gpu_name}), using CUDA.")
+     except Exception as e:
+         print(
+             f"INFO (predict.py): GPU is available, using CUDA. (Could not get GPU name: {e})"
+         )
+ else:
+     device = torch.device("cpu")
+     print(
+         "INFO (predict.py): GPU not available, using CPU. Predictions might be slower."
+     )
+
+ # --- Load Model and Tokenizer ---
+ # We load these once when the script (or module) is first loaded.
+ # This is much better than loading them every time we want to predict.
+ model = None
+ tokenizer = None
+ id2label_mapping = {0: "negative", 1: "neutral", 2: "positive"}  # Default mapping
+
+ try:
+     print(f"INFO (predict.py): Loading model from {MODEL_PATH}...")
+     # Load the pre-trained model for sequence classification
+     # This should be the PyTorch RoBERTa model we fine-tuned
+     model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
+     model.to(device)  # Move the model to the GPU (or CPU if no GPU)
+     model.eval()  # Set the model to evaluation mode (important for layers like Dropout)
+     print("INFO (predict.py): Model loaded successfully and set to evaluation mode.")
+
+     print(f"INFO (predict.py): Loading tokenizer from {MODEL_PATH}...")
+     # Load the tokenizer that matches the model
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+     print("INFO (predict.py): Tokenizer loaded successfully.")
+
+     # Get the label mapping from the model's configuration
+     # This was saved during fine-tuning
+     if hasattr(model.config, "id2label") and model.config.id2label:
+         id2label_mapping = model.config.id2label
+         # Convert string keys from config.json to int if necessary
+         id2label_mapping = {int(k): v for k, v in id2label_mapping.items()}
+         print(
+             f"INFO (predict.py): Loaded id2label mapping from model config: {id2label_mapping}"
+         )
+     else:
+         print(
+             "WARN (predict.py): id2label not found in model config, using default mapping."
+         )
+
+ except FileNotFoundError:
+     print("--- CRITICAL ERROR (predict.py) ---")
+     print(f"Model or Tokenizer files NOT FOUND at the specified path: {MODEL_PATH}")
+     print(
+         f"Please ensure the '{os.path.basename(MODEL_PATH)}' directory exists at '{_SCRIPT_DIR}' and contains all necessary model files (pytorch_model.bin/model.safetensors, config.json, tokenizer files, etc.)."
+     )
+     # Keep model and tokenizer as None, so predict_sentiments can handle it
+ except Exception as e:
+     print("--- ERROR (predict.py) ---")
+     print(f"An unexpected error occurred loading model or tokenizer: {e}")
+     # Keep model and tokenizer as None
+
+
+ # --- Preprocessing Function ---
+ # Same function we used for training data to make sure inputs are consistent
+ def preprocess_tweet(text):
+     """Replaces @user mentions and http links with placeholders."""
+     preprocessed_text = []
+     if text is None:
+         return ""  # Handle None input
+     # Split text into parts by space
+     for t in text.split(" "):
+         if len(t) > 0:  # Avoid processing empty parts from multiple spaces
+             t = "@user" if t.startswith("@") else t  # Replace mentions
+             t = "http" if t.startswith("http") else t  # Replace links
+             preprocessed_text.append(t)
+     return " ".join(preprocessed_text)  # Put the parts back together
+
+
+ # --- Prediction Function (UPDATED to return probabilities) ---
+ def predict_sentiments(comment_list: list):
+     """
+     Predicts sentiments for a list of comment strings.
+     Returns a list of dictionaries, each containing the predicted label
+     and the probabilities (scores) for each class.
+     e.g., [{'label': 'positive', 'scores': {'negative': 0.1, 'neutral': 0.2, 'positive': 0.7}}, ...]
+     """
+     # Check if model and tokenizer are ready
+     if model is None or tokenizer is None:
+         print(
+             "ERROR (predict.py - predict_sentiments): Model or Tokenizer not loaded. Cannot predict."
+         )
+         # Return an error structure
+         return [{"label": "Error: Model not loaded", "scores": {}}] * len(comment_list)
+
+     if not comment_list:  # Handle empty input list
+         return []
+
+     """
+     # Preprocess comments first
+     processed_comments = [preprocess_tweet(comment) for comment in comment_list]
+
+     # Tokenize the batch
+     print(
+         f"DEBUG (predict.py): Tokenizing {len(processed_comments)} comments for prediction..."
+     )
+     inputs = tokenizer(
+         processed_comments,
+         padding=True,
+         truncation=True,
+         return_tensors="pt",  # PyTorch tensors
+         max_length=(
+             tokenizer.model_max_length
+             if hasattr(tokenizer, "model_max_length") and tokenizer.model_max_length
+             else 512
+         ),
+     )
+
+     # Move inputs to the correct device
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     results_list = []  # To store the dictionaries
+     try:
+         # Perform inference without calculating gradients
+         with torch.no_grad():
+             outputs = model(**inputs)
+             logits = outputs.logits
+
+         # Apply Softmax to convert logits to probabilities
+         # dim=-1 applies softmax across the last dimension (the classes)
+         probabilities = torch.softmax(logits, dim=-1)
+
+         # Get the predicted class IDs (index of the highest probability)
+         predicted_class_ids = torch.argmax(probabilities, dim=-1)
+
+         # Move results to CPU and convert to Python lists/numpy for easier handling
+         probs_list = (
+             probabilities.cpu().numpy().tolist()
+         )  # List of lists of probabilities
+         ids_list = predicted_class_ids.cpu().numpy().tolist()  # List of predicted IDs
+
+         print(
+             f"DEBUG (predict.py): Probabilities and IDs calculated. Batch size: {len(ids_list)}"
+         )
+
+         # Format the output: list of dictionaries
+         for i in range(len(ids_list)):
+             pred_id = ids_list[i]
+             # Map predicted ID to label string using the mapping from model config
+             pred_label = id2label_mapping.get(pred_id, "Unknown")
+
+             # Create the dictionary of scores {label_name: probability}
+             pred_scores = {
+                 label_name: probs_list[i][label_id]
+                 for label_id, label_name in id2label_mapping.items()
+                 # Ensure index is within bounds, just in case
+                 if 0 <= label_id < probabilities.shape[-1]
+             }
+
+             # Append the result for this comment
+             results_list.append({"label": pred_label, "scores": pred_scores})
+
+     except Exception as e:
+         print(f"--- ERROR (predict.py - predict_sentiments) ---")
+         print(f"Error during sentiment prediction inference or formatting: {e}")
+         import traceback
+
+         traceback.print_exc()  # Print full traceback for debugging
+         # Return error structure for each comment
+         results_list = [
+             {"label": "Error: Prediction failed", "scores": {}} for _ in comment_list
+         ]
+
+     return results_list  # Return the list of dictionaries
+     """
+
+     inference_batch_size = 64  # You can adjust this number based on performance/memory
+     print(
+         f"INFO (predict.py): Predicting sentiments for {len(comment_list)} comments in batches of {inference_batch_size}..."
+     )
+
+     all_results_list = []  # We'll collect results for all batches here
+
+     # --- Loop through the comment list in batches ---
+     try:
+         total_comments = len(comment_list)
+         # This loop goes from 0 to total_comments, jumping by inference_batch_size each time
+         for i in range(0, total_comments, inference_batch_size):
+             # Get the current slice of comments for this batch
+             batch_comments = comment_list[i : i + inference_batch_size]
+
+             # Just printing progress for long lists
+             current_batch_num = i // inference_batch_size + 1
+             total_batches = (
+                 total_comments + inference_batch_size - 1
+             ) // inference_batch_size
+             print(
+                 f"DEBUG (predict.py): Processing batch {current_batch_num}/{total_batches}..."
+             )
+
+             # --- Process ONLY the current batch ---
+             # 1. Preprocess this specific batch
+             processed_batch = [preprocess_tweet(comment) for comment in batch_comments]
+
+             # 2. Tokenize this batch
+             # Tokenizer handles padding within this smaller batch
+             inputs = tokenizer(
+                 processed_batch,
+                 padding=True,
+                 truncation=True,
+                 return_tensors="pt",
+                 max_length=(
+                     tokenizer.model_max_length
+                     if hasattr(tokenizer, "model_max_length")
+                     and tokenizer.model_max_length
+                     else 512
+                 ),
+             )
+
+             # 3. Move this batch's inputs to the device (GPU/CPU)
+             inputs = {k: v.to(device) for k, v in inputs.items()}
+
+             # 4. Make prediction for this batch - no need for gradients
+             with torch.no_grad():
+                 outputs = model(**inputs)
+                 logits = outputs.logits  # Raw scores from the model for this batch
+
+             # 5. Calculate probabilities and get predicted class IDs for this batch
+             probabilities_batch = torch.softmax(logits, dim=-1)
+             predicted_class_ids_batch = torch.argmax(probabilities_batch, dim=-1)
+
+             # 6. Move results back to CPU, convert to lists for easier looping
+             probs_list_batch = probabilities_batch.cpu().numpy().tolist()
+             ids_list_batch = predicted_class_ids_batch.cpu().numpy().tolist()
+
+             # 7. Format results for each comment in THIS batch
+             batch_results = []
+             for j in range(len(ids_list_batch)):
+                 pred_id = ids_list_batch[j]
+                 pred_label = id2label_mapping.get(
+                     pred_id, "Unknown"
+                 )  # Map ID to label name
+                 # Create the scores dictionary for this comment
+                 pred_scores = {
+                     label_name: probs_list_batch[j][label_id]
+                     for label_id, label_name in id2label_mapping.items()
+                     if 0
+                     <= label_id
+                     < probabilities_batch.shape[-1]  # Safety check for index
+                 }
+                 # Add the result for this comment
+                 batch_results.append({"label": pred_label, "scores": pred_scores})
+
+             # Add the results from this completed batch to our main list
+             all_results_list.extend(batch_results)
+             # --- Finished processing current batch ---
+
+         print(
+             f"INFO (predict.py): Finished processing all {len(all_results_list)} comments."
+         )
+
+     except Exception as e:
+         # Catch errors that might happen during the loop
+         print("--- ERROR (predict.py - predict_sentiments loop) ---")
+         print(
+             f"An error occurred during batch prediction (around comment index {i}): {e}"
+         )
+         import traceback
+
+         traceback.print_exc()  # Print full error details to console
+         # Try to return results for processed batches + error messages for the rest
+         num_processed = len(all_results_list)
+         num_remaining = len(comment_list) - num_processed
+         # Add error indicators for comments that couldn't be processed
+         all_results_list.extend(
+             [{"label": "Error: Batch failed", "scores": {}}] * num_remaining
+         )
+
+     # Return the list containing results for all comments
+     return all_results_list
+
+
+ # --- Main block for testing this script directly (UPDATED to show scores) ---
+ if __name__ == "__main__":
+     print("\n--- Testing predict.py Script Directly ---")
+     if model and tokenizer:
+         sample_comments_for_testing = [
+             "This is an amazing movie, I loved it!",
+             "I'm not sure how I feel about this, it was okay.",
+             "Worst experience ever, would not recommend.",
+             "The food was alright, but the service was slow.",
+             "What a fantastic day! #blessed",
+             "I hate waiting in long lines.",
+             "@user Check out http this is cool.",
+             "Just a normal sentence, nothing special here.",
+             "",
+             "This new update is absolutely terrible and full of bugs.",
+         ]
+
+         print("\nInput Comments for Direct Test:")
+         for i, c in enumerate(sample_comments_for_testing):
+             print(f"{i+1}. '{c}'")
+
+         # Get predictions (now a list of dictionaries)
+         prediction_results = predict_sentiments(sample_comments_for_testing)
+
+         print("\nPredicted Sentiments and Scores (Direct Test):")
+         # Loop through the results list
+         for i, (comment, result) in enumerate(
+             zip(sample_comments_for_testing, prediction_results)
+         ):
+             print(f"{i+1}. Comment: '{comment}'")
+             # Format scores nicely for printing
+             scores_dict = result.get("scores", {})
+             formatted_scores = ", ".join(
+                 [f"{name}: {score:.3f}" for name, score in scores_dict.items()]
+             )
+             print(f"   -> Predicted Label: {result.get('label', 'N/A')}")
+             # Also print the raw scores dictionary
+             print(f"   -> Scores: {{{formatted_scores}}}")
+         print("--- Direct Test Finished ---")
+     else:
+         print("ERROR (predict.py - main test): Model and/or tokenizer not loaded.")
+         print(
+             f"Please check the MODEL_PATH ('{MODEL_PATH}') and ensure model files are present."
+         )
+
+
+ # # --- Prediction Function ---
+ # def predict_sentiments(comment_list: list):
+ #     """
+ #     Predicts sentiments for a list of comment strings.
+ #     Returns a list of sentiment labels (e.g., "positive", "neutral", "negative").
+ #     """
+ #     # Check if model and tokenizer were loaded properly
+ #     if model is None or tokenizer is None:
+ #         print(
+ #             "ERROR (predict.py - predict_sentiments): Model or Tokenizer not loaded. Cannot make predictions."
+ #         )
+ #         # Return an error message for each comment if model isn't ready
+ #         return ["Error: Model not loaded"] * len(comment_list)
+
+ #     if not comment_list:  # If the input list is empty
+ #         return []
+
+ #     # First, preprocess all comments like we did for training data
+ #     processed_comments = [preprocess_tweet(comment) for comment in comment_list]
+
+ #     # Tokenize the processed comments
+ #     # This turns text into numbers (input IDs, attention mask) for the model
+ #     # padding=True: make all sequences in the batch the same length
+ #     # truncation=True: cut off sequences longer than the model can handle
+ #     # return_tensors="pt": return PyTorch tensors
+ #     # max_length: ensure we don't exceed model's limit (e.g., 512 for RoBERTa)
+ #     print(f"DEBUG (predict.py): Tokenizing {len(processed_comments)} comments...")
+ #     inputs = tokenizer(
+ #         processed_comments,
+ #         padding=True,
+ #         truncation=True,
+ #         return_tensors="pt",
+ #         max_length=(
+ #             tokenizer.model_max_length
+ #             if hasattr(tokenizer, "model_max_length") and tokenizer.model_max_length
+ #             else 512
+ #         ),
+ #     )
+
+ #     # Move the tokenized inputs to the same device as the model (GPU or CPU)
+ #     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+ #     sentiment_labels_as_strings = []
+ #     try:
+ #         # Make predictions
+ #         # torch.no_grad() is important for inference:
+ #         # it tells PyTorch not to calculate gradients, saving memory and speeding things up.
+ #         with torch.no_grad():
+ #             outputs = model(**inputs)  # Get model outputs
+ #             logits = outputs.logits  # These are the raw scores from the final layer
+
+ #         # Get the predicted class ID by finding the index with the highest score (logit)
+ #         # logits shape is (batch_size, num_labels)
+ #         predicted_class_ids = torch.argmax(
+ #             logits, dim=-1
+ #         )  # dim=-1 means find max along the last dimension
+
+ #         # Convert the predicted class IDs (numbers) to actual sentiment labels (strings)
+ #         # using the id2label_mapping we got from the model's config
+ #         # .item() gets the Python number from a 0-dim PyTorch tensor
+ #         sentiment_labels_as_strings = [
+ #             id2label_mapping.get(class_id.item(), "Unknown")
+ #             for class_id in predicted_class_ids
+ #         ]
+ #         print(
+ #             f"DEBUG (predict.py): Predictions made. Example: {sentiment_labels_as_strings[:3] if sentiment_labels_as_strings else 'N/A'}"
+ #         )
+
+ #     except Exception as e:
+ #         print(f"--- ERROR (predict.py - predict_sentiments) ---")
+ #         print(f"Error during sentiment prediction inference: {e}")
+ #         # Return an error message for each comment if prediction fails
+ #         sentiment_labels_as_strings = ["Error: Prediction failed"] * len(comment_list)
+
+ #     return sentiment_labels_as_strings
+
+
+ # # --- Main block for testing this script directly ---
+ # # This part only runs if you execute 'python src/predict.py' from the terminal
+ # # It won't run when app.py imports this file.
+ # if __name__ == "__main__":
+ #     print("\n--- Testing predict.py Script Directly ---")
+ #     # Check if model was loaded, otherwise can't test
+ #     if model and tokenizer:
+ #         sample_comments_for_testing = [
+ #             "This is an amazing movie, I loved it!",  # Expected: positive
+ #             "I'm not sure how I feel about this, it was okay.",  # Expected: neutral
+ #             "Worst experience ever, would not recommend.",  # Expected: negative
+ #             "The food was alright, but the service was slow.",  # Expected: neutral or negative
+ #             "What a fantastic day! #blessed",  # Expected: positive
+ #             "I hate waiting in long lines.",  # Expected: negative
+ #             "@user Check out http this is cool.",  # Test preprocessing, Expected: positive or neutral
+ #             "Just a normal sentence, nothing special here.",  # Expected: neutral
+ #             "",  # Empty string test
+ #             "This new update is absolutely terrible and full of bugs.",  # Expected: negative
+ #         ]
+
+ #         print("\nInput Comments for Direct Test:")
+ #         for i, c in enumerate(sample_comments_for_testing):
+ #             print(f"{i + 1}. '{c}'")
+
+ #         # Get predictions using our main function
+ #         predicted_sentiments = predict_sentiments(sample_comments_for_testing)
+
+ #         print("\nPredicted Sentiments (Direct Test):")
+ #         for i, (comment, sentiment) in enumerate(
+ #             zip(sample_comments_for_testing, predicted_sentiments)
+ #         ):
+ #             print(
+ #                 f"{i + 1}. Comment: '{comment}'\n   -> Predicted Sentiment: {sentiment}"
+ #             )
+ #         print("--- Direct Test Finished ---")
+ #     else:
+ #         print(
+ #             "ERROR (predict.py - main test): Model and/or tokenizer not loaded. Cannot run direct test."
+ #         )
+ #         print(
+ #             f"Please check the MODEL_PATH ('{MODEL_PATH}') and ensure model files are present."
+ #         )
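
For reference, a hypothetical usage sketch of the function this file exposes (run from inside src/ so the bare import resolves; the comments passed in are illustrative):

    from predict import predict_sentiments

    results = predict_sentiments(["I loved this video!", "Meh, it was fine."])
    for r in results:
        # Each entry is {"label": ..., "scores": {"negative": ..., "neutral": ..., "positive": ...}}
        print(r["label"], r["scores"])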
src/streamlit_app.py CHANGED
@@ -1,40 +1,709 @@
- import altair as alt
- import numpy as np
- import pandas as pd
- import streamlit as st
-
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ # src/streamlit_app.py
+
+ import streamlit as st
+ import pandas as pd
+ import re  # For robust YouTube video ID extraction
+
+ # Try to import Plotly; if it's not available, we'll use Streamlit's basic charts
+ try:
+     import plotly.express as px
+
+     PLOTLY_AVAILABLE = True
+ except ImportError:
+     PLOTLY_AVAILABLE = False
+     st.sidebar.warning(
+         "Plotly not installed. Charts will be basic. Consider 'pip install plotly'."
+     )  # Optional warning
+
+ # Import our custom modules from the src directory
+ try:
+     from predict import (
+         predict_sentiments,
+     )  # Returns a list of dicts: [{"label": ..., "scores": {...}}, ...]
+     from youtube import (
+         get_video_comments,
+     )  # Returns a dict with the video title and a list of comment strings
+ except ImportError as e:
+     st.error(
+         f"Failed to import necessary modules (predict.py, youtube.py). Ensure they are in the 'src' directory. Error: {e}"
+     )
+     # Stop the app if core modules are missing
+     st.stop()
+
+
+ def extract_video_id(url_or_id: str) -> str | None:
+     """
+     Tries to get the YouTube video ID from different common URL types.
+     Also handles if the input is just the ID itself.
+     A bit of regex to find the ID part in common URLs.
+     """
+     if not url_or_id:
+         return None
+
+     # Patterns for various YouTube URL formats
+     # Order matters: more specific patterns should come first if overlap exists
+     patterns = [
+         r"watch\?v=([a-zA-Z0-9_-]{11})",  # Standard watch URL
+         r"youtu\.be/([a-zA-Z0-9_-]{11})",  # Shortened URL
+         r"embed/([a-zA-Z0-9_-]{11})",  # Embed URL
+         r"shorts/([a-zA-Z0-9_-]{11})",  # Shorts URL
+     ]
+
+     for pattern in patterns:
+         match = re.search(pattern, url_or_id)
+         if match:
+             return match.group(1)  # The first capturing group is the ID
+
+     # If no pattern matches, check if the input itself is a valid 11-char ID
+     # Basic check: 11 chars, none of the URL punctuation characters
+     if len(url_or_id) == 11 and not (
+         "/" in url_or_id or "?" in url_or_id or "=" in url_or_id or "." in url_or_id
+     ):
+         return url_or_id  # Assume it's a direct ID
+
+     return None  # Return None if no ID found
+
+
+ def analyze_youtube_video(video_url_or_id: str):
+     """
+     Main function for the YouTube analysis part.
+     It gets comments, then predicts their sentiments,
+     then summarizes the results.
+     """
+     video_id = extract_video_id(video_url_or_id)
+     if not video_id:
+         # Give a more helpful error message to the user
+         st.error(
+             "Oops! That doesn't look like a valid YouTube URL or Video ID. Please check and try again. Example: Z9kGRMglw-I or youtu.be/Z9kGRMglw-I"
+         )
+         return None  # Stop if no valid ID
+
+     summary_data = {}  # Initialize
+     # comments_with_sentiments = []  # Initialize
+
+     try:
+         with st.spinner(f"Fetching comments & title for video ID: {video_id}..."):
+             video_data = get_video_comments(video_id)
+             comments_text_list = video_data.get("comments", [])
+             video_title = video_data.get("title", "Video Title Not Found")
+             print(
+                 f"DEBUG (streamlit_app.py): Received title from youtube.py: '{video_title}'"
+             )
+
+         # Check if we actually got any comments
+         if not comments_text_list:
+             st.warning(
+                 "Hmm, no comments found for this video. Are comments enabled? Or is it a very new video?"
+             )
+             # Provide a default empty summary structure
+             summary_data = {
+                 "num_comments_fetched": 0,
+                 "num_comments_analyzed": 0,
+                 "positive": 0,
+                 "neutral": 0,
+                 "negative": 0,
+                 "positive_percentage": 0,
+                 "neutral_percentage": 0,
+                 "negative_percentage": 0,
+                 "num_valid_predictions": 0,
+             }
+             return {"summary": summary_data, "comments_data": []}
+
+         st.info(
+             f"Great! Found {len(comments_text_list)} comments. Now thinking about their feelings (sentiments)..."
+         )
+         # Another spinner for the prediction part, as this can be slow on CPU
+         with st.spinner("Analyzing sentiments with the model... Please wait."):
+             # This calls predict_sentiments from predict.py
+             # Expected to return a list of dicts: [{"label": ..., "scores": {...}}, ...]
+             prediction_results = predict_sentiments(comments_text_list)
+
+         positive_count = 0
+         negative_count = 0
+         neutral_count = 0
+         error_count = 0
+
+         for result in prediction_results:
+             label = result.get("label")
+             if label == "positive":
+                 positive_count += 1
+             elif label == "negative":
+                 negative_count += 1
+             elif label == "neutral":
+                 neutral_count += 1
+             else:
+                 error_count += 1
+
+         num_valid_predictions = positive_count + negative_count + neutral_count
+         total_comments_processed = len(prediction_results)
+         if error_count > 0:
+             st.warning(
+                 f"Could not predict sentiment properly for {error_count} comments."
+             )
+
+         summary_data = {
+             "video_title": video_title,
+             "num_comments_fetched": len(comments_text_list),
+             "num_comments_analyzed": total_comments_processed,
+             "num_valid_predictions": num_valid_predictions,
+             "positive": positive_count,
+             "negative": negative_count,
+             "neutral": neutral_count,
+             "positive_percentage": (
+                 (positive_count / num_valid_predictions) * 100
+                 if num_valid_predictions > 0
+                 else 0
+             ),
+             "neutral_percentage": (
+                 (neutral_count / num_valid_predictions) * 100
+                 if num_valid_predictions > 0
+                 else 0
+             ),
+             "negative_percentage": (
+                 (negative_count / num_valid_predictions) * 100
+                 if num_valid_predictions > 0
+                 else 0
+             ),
+         }
+
+         comments_data_for_df = []
+         for i in range(len(comments_text_list)):
+             comment_text = comments_text_list[i]
+             result = prediction_results[i]
+             label = result.get("label", "Error")
+             scores = result.get("scores", {})
+             confidence = max(scores.values()) if scores else 0.0
+
+             comments_data_for_df.append(
+                 {
+                     "Comment Text": comment_text,
+                     "Predicted Sentiment": label,
+                     "Confidence": confidence,
+                     # "All Scores": scores
+                 }
+             )
+
+         return {"summary": summary_data, "comments_data": comments_data_for_df}
+
+     except Exception as e:
+         # Show a general error if anything unexpected happens
+         st.error(f"Uh oh! An error popped up during analysis: {str(e)}")
+         # Also print to console for more detailed debugging when running locally
+         print(f"Full error in analyze_youtube_video: {e}")
+         import traceback
+
+         traceback.print_exc()  # Print full traceback to console
+         return None  # Return None on error
+
+
+ # --- Streamlit App UI ---
+
+ # Page configuration: set to centered layout (default) instead of "wide"
+ st.set_page_config(page_title="Social Sentiment Analysis", layout="centered")
+
+ st.title("📊 SOCIAL SENTIMENT ANALYSIS")
+ # A little description for the user
+ st.write(
+     """
+     Welcome to the **Social Sentiment Analyzer!** 👋
+
+     This application uses a fine-tuned RoBERTa model to predict the sentiment (Positive, Neutral, or Negative) expressed in text.
+
+     Use the tabs below to choose your input method:
+     * **Analyze Text Input:** Paste or type any English text directly.
+     * **YouTube Analysis:** Enter a YouTube video URL or ID to analyze its comments.
+     * **Twitter/X Analysis:** Support for analyzing Twitter/X posts is coming soon!
+
+     Select a tab to begin!
+     """
+ )
+
+ # Tabs for different platforms; makes it easy to add Twitter later
+ tab_text_input, tab_youtube, tab_twitter = st.tabs(
+     ["Analyze Text Input", "YouTube Analysis", "Twitter/X Analysis (Coming Soon!)"]
+ )
+
+ with tab_text_input:
+     # Header for this tab
+     st.header("Analyze Sentiment of Your Text")
+     st.write(
+         "Enter a sentence or a short paragraph below to see its predicted sentiment distribution."
+     )
+
+     # Use text_area for potentially longer input
+     # Giving it a unique key helps maintain state if needed
+     user_text = st.text_area(
+         "Enter text here:",
+         key="text_input_area_key",
+         height=100,
+         placeholder="Type or paste your text...",
+     )
+
+     # Button to trigger the analysis
+     if st.button("Analyze Text", key="text_input_analyze_btn"):
+         # Check if the user actually entered something (not just whitespace)
+         if user_text and not user_text.isspace():
+             # Show a spinner while processing
+             with st.spinner("Analyzing your text..."):
+                 try:
+                     # Call the prediction function from predict.py
+                     # Pass the input text as a list with one element
+                     prediction_results = predict_sentiments([user_text])
+
+                     # Check if prediction was successful and returned expected format
+                     if (
+                         prediction_results
+                         and isinstance(prediction_results, list)
+                         and len(prediction_results) > 0
+                     ):
+                         # Get the result dictionary for the single input text
+                         result = prediction_results[0]
+                         predicted_label = result.get("label")
+                         scores = result.get(
+                             "scores"
+                         )  # This should be a dict like {'negative': 0.1, ...}
+
+                         # Make sure we got a valid label and scores dictionary
+                         if (
+                             predicted_label
+                             and scores
+                             and isinstance(scores, dict)
+                             and predicted_label != "Error"
+                         ):
+
+                             # Display the top predicted sentiment
+                             st.subheader("Predicted Sentiment:")
+                             # Using Streamlit's built-in status elements for color
+                             if predicted_label == "positive":
+                                 st.success(
+                                     f"The model thinks the sentiment is: **{predicted_label.capitalize()}** 👍"
+                                 )
+                             elif predicted_label == "negative":
+                                 st.error(
+                                     f"The model thinks the sentiment is: **{predicted_label.capitalize()}** 👎"
+                                 )
+                             else:  # Neutral or potentially "Unknown" if mapping failed
+                                 st.info(
+                                     f"The model thinks the sentiment is: **{predicted_label.capitalize()}** 😐"
+                                 )
+
+                             st.write("---")  # Adding a small separator
+                             st.subheader(
+                                 "Detailed Probabilities:"
+                             )  # Subheader for this section
+                             if scores and isinstance(scores, dict):
+                                 # Using columns here helps align the probabilities nicely
+                                 prob_col_neg, prob_col_neu, prob_col_pos = st.columns(3)
+
+                                 # Helper to get a score safely
+                                 def get_score(sentiment_name):
+                                     return scores.get(
+                                         sentiment_name.lower(), 0.0
+                                     )  # Use lowercase to be safe
+
+                                 value_font_size = "22px"
+                                 value_font_weight = "bold"
+
+                                 with prob_col_neg:
+                                     neg_prob = get_score("negative")
+                                     # Display label "Negative"
+                                     st.markdown("**Negative 👎:**")
+                                     # Display the probability, larger font, red color
+                                     st.markdown(
+                                         f"<p style='font-size: {value_font_size}; font-weight: {value_font_weight}; color:red;'>{neg_prob:.1%}</p>",
+                                         unsafe_allow_html=True,
+                                     )
+
+                                 with prob_col_neu:
+                                     neu_prob = get_score("neutral")
+                                     # Display label "Neutral"
+                                     st.markdown("**Neutral 😐:**")
+                                     # Display the probability, larger font, grey color
+                                     st.markdown(
+                                         f"<p style='font-size: {value_font_size}; font-weight: {value_font_weight}; color:grey;'>{neu_prob:.1%}</p>",
+                                         unsafe_allow_html=True,
+                                     )
+
+                                 with prob_col_pos:
+                                     pos_prob = get_score("positive")
+                                     # Display label "Positive"
+                                     st.markdown("**Positive 👍:**")
+                                     # Display the probability, larger font, green color
+                                     st.markdown(
+                                         f"<p style='font-size: {value_font_size}; font-weight: {value_font_weight}; color:green;'>{pos_prob:.1%}</p>",
+                                         unsafe_allow_html=True,
+                                     )
+
+                             else:
+                                 # If scores dict is missing or invalid
+                                 st.write("Could not retrieve probability scores.")
+                             st.write("---")  # Another separator before the chart
+
+                             # --- Display Pie Chart of Probabilities ---
+                             st.subheader("Sentiment Probabilities:")
+                             if PLOTLY_AVAILABLE:
+                                 # Convert the scores dictionary to a DataFrame suitable for Plotly
+                                 # Assuming scores keys are 'negative', 'neutral', 'positive'
+                                 score_items = list(scores.items())
+                                 if score_items:  # Check if scores dict is not empty
+                                     df_scores = pd.DataFrame(
+                                         score_items,
+                                         columns=["Sentiment", "Probability"],
+                                     )
+                                     # Convert Probability to numeric just in case
+                                     df_scores["Probability"] = pd.to_numeric(
+                                         df_scores["Probability"]
+                                     )
+
+                                     # Define colors (keys must match the Sentiment names' case)
+                                     color_map = {
+                                         "positive": "green",
+                                         "neutral": "grey",
+                                         "negative": "red",
+                                     }
+                                     # Capitalize for display; keep lowercase keys for robust mapping
+                                     df_scores["Sentiment"] = df_scores[
+                                         "Sentiment"
+                                     ].str.capitalize()
+                                     df_scores["Sentiment_Lower"] = df_scores[
+                                         "Sentiment"
+                                     ].str.lower()
+                                     color_map_lower = {
+                                         k.lower(): v for k, v in color_map.items()
+                                     }
+
+                                     # Debug print for the dataframe fed to Plotly
+                                     # st.write("DEBUG: DataFrame for text input pie chart:")
+                                     # st.dataframe(df_scores)
+
+                                     try:
+                                         # Create the pie chart
+                                         fig_pie_text = px.pie(
+                                             df_scores,
+                                             values="Probability",  # Use the probability column
+                                             names="Sentiment",  # Labels for the slices
+                                             title="Probability Distribution per Class",
+                                             color="Sentiment_Lower",  # Use lowercase for mapping
+                                             color_discrete_map=color_map_lower,
+                                         )  # Map colors
+
+                                         # Update how text is shown on slices
+                                         fig_pie_text.update_traces(
+                                             textposition="inside",
+                                             textinfo="percent+label",
+                                             hovertemplate="Sentiment: %{label}<br>Probability: %{percent}",
+                                         )
+                                         # Improve text fitting
+                                         fig_pie_text.update_layout(
+                                             uniformtext_minsize=16,
+                                             uniformtext_mode="hide",
+                                         )
+
+                                         st.plotly_chart(
+                                             fig_pie_text, use_container_width=True
+                                         )
+
+                                     except Exception as plot_e:
+                                         st.error(
+                                             f"Sorry, couldn't create the probability pie chart: {str(plot_e)}"
+                                         )
+                                         print(
+                                             f"Full error during text input Plotly chart generation: {plot_e}"
+                                         )
+                                         import traceback
+
+                                         traceback.print_exc()
+                                         st.write(
+                                             "Raw scores:", scores
+                                         )  # Show raw scores as fallback
+
+                                 else:  # If scores dictionary was empty
+                                     st.warning(
+                                         "Received empty scores, cannot plot chart."
+                                     )
+
+                             elif not PLOTLY_AVAILABLE:
+                                 st.warning(
+                                     "Plotly not installed, cannot display pie chart. Showing raw scores instead."
+                                 )
+                                 st.json(
+                                     scores
+                                 )  # Display raw scores as JSON if no Plotly
+                             else:
+                                 # This case should be covered by the check above, but for safety
+                                 st.write("No valid score data available to plot.")
+                             # --- End Pie Chart ---
+
+                         else:
+                             # This handles cases where predict_sentiments returned an error label
+                             st.error(
+                                 f"Sentiment analysis failed for the input text. Result: {result}"
+                             )
+
+                     else:
+                         # This handles cases where predict_sentiments returned None or an empty list
+                         st.error(
+                             "Received no valid result from the prediction function."
+                         )
+
+                 except Exception as analysis_e:
+                     # Catch-all for other errors during analysis for this tab
+                     st.error(
+                         f"An error occurred during text analysis: {str(analysis_e)}"
+                     )
+                     print(f"Full error during text input analysis: {analysis_e}")
+                     import traceback
+
+                     traceback.print_exc()
+
+         else:
+             # If user clicks the button without entering text
+             st.warning("Please enter some text in the text area first!")
+
+ with tab_youtube:
+     st.header("YouTube Comment Sentiment Analyzer")
+     # Input field for URL or ID
+     video_url_input = st.text_input(
+         "Enter YouTube Video URL or Video ID:",
+         key="youtube_url_input_key",  # Giving it a unique key
+         placeholder="e.g., Z9kGRMglw-I or full URL",
+     )
+
+     # Button to trigger analysis
+     if st.button("Analyze YouTube Comments", key="youtube_analyze_button_key"):
+         if video_url_input:  # Check if user actually entered something
+             # analyze_youtube_video handles spinners internally now
+             analysis_results = analyze_youtube_video(video_url_input)
+
+             if (
+                 analysis_results and analysis_results["summary"]
+             ):  # Check if we got valid results
+                 summary = analysis_results["summary"]
+                 comments_data = analysis_results["comments_data"]
+                 video_title_display = summary.get(
+                     "video_title", "Video Title Not Available"
+                 )
+
+                 st.markdown("---")
+                 # Displaying the video title using markdown for potential formatting later
+                 st.markdown(f"### Analyzing Video: **{video_title_display}**")
+                 st.markdown("---")
+
+                 st.subheader("📊 Sentiment Summary")
+
+                 # Define desired font sizes (you can adjust these)
+                 # label_font_size = (
+                 #     "24px"  # Font size for the label text like "Comments Fetched"
+                 # )
+                 label_font_size = "24px"
+                 value_font_size = "28px"  # Font size for the actual count like "137"
+                 value_font_weight = "bold"  # Make the count bold
+
+                 # Define colors for the sentiment counts
+                 positive_color = "green"
+                 neutral_color = "grey"
+                 negative_color = "red"
+
+                 # Using 5 columns
+                 col_fetched, col_analyzed, col_pos, col_neu, col_neg = st.columns(5)
+
+                 # Metric 1: Comments Fetched
+                 with col_fetched:
+                     # Label for fetched comments
+                     st.markdown(
+                         f"<p style='font-size: {label_font_size}; margin-bottom: 0px;'>Comments Fetched</p>",
+                         unsafe_allow_html=True,
+                     )
+                     # The number of fetched comments
+                     st.markdown(
+                         f"<p style='font-size: {value_font_size}; font-weight: {value_font_weight}; margin-top: 0px;'>{summary.get('num_comments_fetched', 0)}</p>",
+                         unsafe_allow_html=True,
+                     )
+
+                 # Metric 2: Comments Analyzed
+                 with col_analyzed:
+                     # Label for analyzed comments
+                     st.markdown(
+                         f"<p style='font-size: {label_font_size}; margin-bottom: 0px;'>Comments Analyzed</p>",
+                         unsafe_allow_html=True,
+                     )
+                     # The number of analyzed comments
+                     st.markdown(
+                         f"<p style='font-size: {value_font_size}; font-weight: {value_font_weight}; margin-top: 0px;'>{summary.get('num_comments_analyzed', 0)}</p>",
+                         unsafe_allow_html=True,
+                     )
+
+                 # Metric 3: Positive
+                 with col_pos:
+                     # Label for positive comments, with emoji
+                     st.markdown(
+                         f"<p style='font-size: {label_font_size}; margin-bottom: 0px;'>Positive 👍</p>",
+                         unsafe_allow_html=True,
+                     )
+                     # The count of positive comments, green and bold
+                     st.markdown(
+                         f"<p style='font-size: {value_font_size}; font-weight: {value_font_weight}; color:{positive_color}; margin-top: 0px;'>{summary.get('positive', 0)}</p>",
+                         unsafe_allow_html=True,
+                     )
+
+                 # Metric 4: Neutral
+                 with col_neu:
+                     # Label for neutral comments
+                     st.markdown(
+                         f"<p style='font-size: {label_font_size}; margin-bottom: 0px;'>Neutral 😐</p>",
+                         unsafe_allow_html=True,
+                     )
+                     # The count of neutral comments, grey and bold
+                     st.markdown(
+                         f"<p style='font-size: {value_font_size}; font-weight: {value_font_weight}; color:{neutral_color}; margin-top: 0px;'>{summary.get('neutral', 0)}</p>",
+                         unsafe_allow_html=True,
+                     )
+
+                 # Metric 5: Negative
+                 with col_neg:
+                     # Label for negative comments
+                     st.markdown(
+                         f"<p style='font-size: {label_font_size}; margin-bottom: 0px;'>Negative 👎</p>",
+                         unsafe_allow_html=True,
+                     )
+                     # The count of negative comments, red and bold
+                     st.markdown(
+                         f"<p style='font-size: {value_font_size}; font-weight: {value_font_weight}; color:{negative_color}; margin-top: 0px;'>{summary.get('negative', 0)}</p>",
+                         unsafe_allow_html=True,
+                     )
+
+                 # Add a visual separator before charts
+                 st.markdown("---")
+
+                 # Data for charts - make sure it has counts > 0
+                 if summary.get("num_valid_predictions", 0) > 0:
+                     # Prepare DataFrame for Plotly charts
+                     sentiment_data_for_plot = [
+                         {"Sentiment": "Positive", "Count": summary.get("positive", 0)},
+                         {"Sentiment": "Neutral", "Count": summary.get("neutral", 0)},
+                         {"Sentiment": "Negative", "Count": summary.get("negative", 0)},
+                     ]
+                     sentiment_counts_df = pd.DataFrame(sentiment_data_for_plot)
+                     # Filter out rows where Count is 0 for cleaner charts
+                     sentiment_counts_df_for_plot = sentiment_counts_df[
+                         sentiment_counts_df["Count"] > 0
+                     ].copy()
+
+                     # Define the color map for charts
+                     # Keys should match the 'Sentiment' column values
+                     color_map = {
+                         "Positive": "green",
+                         "Neutral": "grey",
+                         "Negative": "red",
+                     }
+
+                     if not sentiment_counts_df_for_plot.empty:
+                         st.subheader("📈 Sentiment Distribution Charts")
+                         # Try to use Plotly for richer charts
+                         if PLOTLY_AVAILABLE:
+                             try:
+                                 # Pie chart: Plotly expects a DataFrame where one column holds values, another the names
+                                 fig_pie = px.pie(
+                                     sentiment_counts_df_for_plot,  # Use the filtered DataFrame
+                                     values="Count",  # Column for pie slice values
+                                     names="Sentiment",  # Column for pie slice names
+                                     title="Pie Chart: Comment Sentiments",
+                                     color="Sentiment",  # Color slices based on the 'Sentiment' category
+                                     color_discrete_map=color_map,
+                                 )  # Apply custom colors
+
+                                 fig_pie.update_traces(
+                                     textposition="inside",
+                                     textinfo="percent+label",
+                                     hovertemplate="Sentiment: %{label}<br>Count: %{value}<br>Percentage: %{percent}",
+                                 )
+
+                                 fig_pie.update_layout(
+                                     uniformtext_minsize=16, uniformtext_mode="hide"
+                                 )
+
+                                 st.plotly_chart(fig_pie, use_container_width=True)
+
+                                 # Bar chart (using Plotly for consistent coloring)
+                                 fig_bar = px.bar(
+                                     sentiment_counts_df_for_plot,  # Use the filtered DataFrame
+                                     x="Sentiment",  # Categories on X-axis
+                                     y="Count",  # Values on Y-axis
+                                     title="Bar Chart: Comment Sentiments",
+                                     color="Sentiment",  # Color bars based on 'Sentiment'
+                                     color_discrete_map=color_map,  # Apply custom colors
+                                     labels={
+                                         "Count": "Number of Comments",
+                                         "Sentiment": "Sentiment Category",
+                                     },
+                                 )  # Custom labels
+                                 st.plotly_chart(fig_bar, use_container_width=True)
+
+                             except Exception as plot_e:
+                                 # Fallback if Plotly fails for some reason other than import
+                                 st.error(
+                                     f"Sorry, couldn't create Plotly charts: {plot_e}"
+                                 )
+                                 st.write(
+                                     "Displaying basic bar chart instead (default colors):"
+                                 )
+                                 st.bar_chart(
+                                     sentiment_counts_df.set_index("Sentiment")
+                                 )  # Fallback with original (unfiltered for bar)
+                         else:
+                             # Fallback to Streamlit's basic bar chart if Plotly is not installed
+                             st.write(
+                                 "Displaying basic bar chart (Plotly not installed):"
+                             )
+                             st.bar_chart(
+                                 sentiment_counts_df.set_index("Sentiment")
+                             )  # Basic bar chart
+                     else:
+                         # This message shows if all sentiment counts are zero
+                         st.write(
+                             "No sentiment data (Positive, Neutral, Negative all zero) to display in charts."
+                         )
+                 else:
+                     # This message shows if no comments were analyzed successfully
+                     st.write(
+                         "Not enough valid sentiment data to display distribution charts."
+                     )
+
+                 # Display comments and their sentiments
+                 if comments_data:
+                     st.subheader(
+                         f"🔍 Analyzed Comments (showing first {len(comments_data)} results)"
+                     )
+                     comments_display_df = pd.DataFrame(comments_data)
+
+                     if "Confidence" in comments_display_df.columns:
+                         try:
+                             # Format as percentage with 1 decimal place
+                             comments_display_df["Confidence"] = comments_display_df[
+                                 "Confidence"
+                             ].map("{:.1%}".format)
+                         except (TypeError, ValueError):
+                             st.warning(
+                                 "Could not format confidence scores."
+                             )  # Handle potential errors if confidence is not numeric
+
+                     st.dataframe(
+                         comments_display_df, use_container_width=True, height=400
+                     )
+                 else:
+                     st.write("No comments were analyzed to display.")
+             # else:  # analyze_youtube_video already handles its own errors by showing st.error
+             #     st.info("Could not complete analysis. Please check the URL or try again.")
+         else:
+             # If user clicks the button without entering a URL
+             st.warning("Please enter a YouTube URL or Video ID first!")
+
+ with tab_twitter:
+     st.header("Twitter/X Post Analysis")
+     st.info("This feature is currently under construction. Please check back later!")
+     # Placeholder for future Twitter input
+     # twitter_url_input = st.text_input("Enter Twitter/X Post URL:", key="twitter_url_input_key")
+     # if st.button("Analyze Tweets", key="twitter_analyze_button_key"):
+     #     st.write("Imagine amazing Twitter analysis happening here... Tweet tweet!")
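
As a quick illustration of the URL parsing in extract_video_id, here is one of its four patterns exercised standalone (the pattern is copied from the list above; the video ID is the app's own placeholder):

    import re

    pattern = r"watch\?v=([a-zA-Z0-9_-]{11})"
    m = re.search(pattern, "https://www.youtube.com/watch?v=Z9kGRMglw-I")
    print(m.group(1))  # -> Z9kGRMglw-I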
src/youtube.py ADDED
@@ -0,0 +1,77 @@
+ import os
+ import googleapiclient.discovery
+ import googleapiclient.errors
+
+ # from dotenv import load_dotenv
+ import streamlit as st
+
+ # load_dotenv()
+ # api_key = os.getenv("API_KEY")
+ api_key = st.secrets["API_KEY"]
+
+
+ def get_comments(youtube, **kwargs):
+     comments = []
+     results = youtube.commentThreads().list(**kwargs).execute()
+
+     while results:
+         for item in results["items"]:
+             comment = item["snippet"]["topLevelComment"]["snippet"]["textDisplay"]
+             comments.append(comment)
+
+         # Check if there are more pages of comments
+         if "nextPageToken" in results:
+             kwargs["pageToken"] = results["nextPageToken"]
+             results = youtube.commentThreads().list(**kwargs).execute()
+         else:
+             break
+
+     return comments
+
+
+ def main(video_id, api_key):
+     # Disable OAuthlib's HTTPS verification when running locally.
+     os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
+
+     youtube = googleapiclient.discovery.build("youtube", "v3", developerKey=api_key)
+
+     video_title = "N/A"  # Provide a default title
+
+     try:
+         # Get video details using the videos().list endpoint
+         print(f"DEBUG (youtube.py): Fetching video details for ID: {video_id}")
+         video_response = (
+             youtube.videos()
+             .list(
+                 part="snippet",  # 'snippet' contains title, description, channel, etc.
+                 id=video_id,  # The ID of the video we want info for
+             )
+             .execute()
+         )
+
+         # Extract the title from the response
+         # It's nested, so it's good to check that 'items' exists first
+         if video_response.get("items"):
+             video_title = video_response["items"][0]["snippet"]["title"]
+             print(f"DEBUG (youtube.py): Found title: '{video_title}'")  # Just a check
+         else:
+             print(f"WARN (youtube.py): No video items found for ID: {video_id}")
+             video_title = "Video Not Found or Private"  # More informative default
+
+     except Exception as e:
+         print(
+             f"ERROR (youtube.py): Failed to fetch video title for ID {video_id}. Error: {e}"
+         )
+         video_title = "Error Fetching Title"  # Error-specific default
+         # Depending on requirements, maybe we still want to proceed to get comments?
+
+     comments = get_comments(
+         youtube, part="snippet", videoId=video_id, textFormat="plainText"
+     )
+     # return comments
+     # Return a dictionary containing both title and comments
+     return {"title": video_title, "comments": comments}
+
+
+ def get_video_comments(video_id):
+     return main(video_id, api_key)
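
Since the module reads st.secrets["API_KEY"] (and the .gitignore above excludes .streamlit/), the key is expected in a local .streamlit/secrets.toml with a line of the form API_KEY = "...". Assuming a valid YouTube Data API v3 key is configured, a hypothetical direct call looks like:

    from youtube import get_video_comments  # run from inside src/

    data = get_video_comments("Z9kGRMglw-I")  # ID taken from the app's placeholder text
    print(data["title"], "-", len(data["comments"]), "top-level comments")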