Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import random | |
# Browser-tab title for the app. Kept ahead of the other st.* calls because
# older Streamlit versions require set_page_config to be the first one.
st.set_page_config(page_title="Nexus NLP News Classifier")
| import pandas as pd | |
| from final import * | |
| from pydantic import BaseModel | |
| import plotly.graph_objects as go | |
@st.cache_resource
def initialize_models(model_path="./results/checkpoint-753",
                      tokenizer_name='microsoft/deberta-v3-small'):
    """Load the spaCy pipeline, tokenizer, fine-tuned classifier and knowledge graph.

    ``main`` calls this on every Streamlit rerun, so it is wrapped in
    ``st.cache_resource``: the heavy loading runs once per server process and
    later calls return the same objects.

    NOTE(review): the returned knowledge graph is passed to
    ``update_knowledge_graph`` by ``main``. If that helper mutates it in place,
    the updates now persist across reruns and are shared between sessions.
    Confirm that this is intended and that concurrent updates are safe.

    Args:
        model_path: Directory of the fine-tuned sequence-classification checkpoint.
        tokenizer_name: Hugging Face name or path of the tokenizer to load.

    Returns:
        Tuple ``(nlp, tokenizer, model, knowledge_graph)``.
    """
    try:
        nlp = spacy.load("en_core_web_sm")
    except OSError:
        # spaCy raises OSError when the model package is not installed;
        # download it once and retry.
        spacy.cli.download("en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
    tokenizer = DebertaV2Tokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()  # switch the classifier to inference mode
    knowledge_graph = load_knowledge_graph()
    return nlp, tokenizer, model, knowledge_graph
class NewsInput(BaseModel):
    """Pydantic schema for a news payload: a single ``text`` string field."""

    text: str  # raw news text to classify
| # def generate_knowledge_graph_viz(text, nlp, tokenizer, model): | |
| # kg_builder = KnowledgeGraphBuilder() | |
| # # Get prediction | |
| # prediction, _ = predict_with_model(text, tokenizer, model) | |
| # is_fake = prediction == "FAKE" | |
| # # Update knowledge graph | |
| # kg_builder.update_knowledge_graph(text, not is_fake, nlp) | |
| # # Randomly select subset of edges (e.g. 10% of edges) | |
| # edges = list(kg_builder.knowledge_graph.edges()) | |
| # selected_edges = random.sample(edges, k=int(len(edges) * 0.3)) | |
| # # Create a new graph with selected edges | |
| # selected_graph = nx.DiGraph() | |
| # selected_graph.add_nodes_from(kg_builder.knowledge_graph.nodes(data=True)) | |
| # selected_graph.add_edges_from(selected_edges) | |
| # pos = nx.spring_layout(selected_graph) | |
| # edge_trace = go.Scatter( | |
| # x=[], y=[], | |
| # line=dict( | |
| # width=2, | |
| # color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)' | |
| # ), | |
| # hoverinfo='none', | |
| # mode='lines' | |
| # ) | |
| # # Create visualization | |
| # pos = nx.spring_layout(kg_builder.knowledge_graph) | |
| # edge_trace = go.Scatter( | |
| # x=[], y=[], | |
| # line=dict( | |
| # width=2, | |
| # color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)' | |
| # ), | |
| # hoverinfo='none', | |
| # mode='lines' | |
| # ) | |
| # node_trace = go.Scatter( | |
| # x=[], y=[], | |
| # mode='markers+text', | |
| # hoverinfo='text', | |
| # textposition='top center', | |
| # marker=dict( | |
| # size=15, | |
| # color='white', | |
| # line=dict(width=2, color='black') | |
| # ), | |
| # text=[] | |
| # ) | |
| # # Add edges | |
| # for edge in selected_graph.edges(): | |
| # x0, y0 = pos[edge[0]] | |
| # x1, y1 = pos[edge[1]] | |
| # edge_trace['x'] += (x0, x1, None) | |
| # edge_trace['y'] += (y0, y1, None) | |
| # # Add nodes | |
| # for node in kg_builder.knowledge_graph.nodes(): | |
| # x, y = pos[node] | |
| # node_trace['x'] += (x,) | |
| # node_trace['y'] += (y,) | |
| # node_trace['text'] += (node,) | |
| # fig = go.Figure( | |
| # data=[edge_trace, node_trace], | |
| # layout=go.Layout( | |
| # showlegend=False, | |
| # hovermode='closest', | |
| # margin=dict(b=0,l=0,r=0,t=0), | |
| # xaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | |
| # yaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | |
| # plot_bgcolor='rgba(0,0,0,0)', | |
| # paper_bgcolor='rgba(0,0,0,0)' | |
| # ) | |
| # ) | |
| # return fig | |
def generate_knowledge_graph_viz(text, nlp, tokenizer, model):
    """Build a two-colour Plotly network figure of the knowledge graph for ``text``.

    NOTE: ``generate_knowledge_graph_viz`` is defined again later in this
    module. That later definition replaces this one when the module is
    imported, so this version is effectively unused.

    How the figure is built:

    * The text is classified with ``predict_with_model`` and fed to a fresh
      ``KnowledgeGraphBuilder``.
    * Half of the resulting edges, chosen at random, are laid out and drawn.
    * 15% of the drawn edges use the "opposite" colour and the rest use the
      dominant colour. The dominant colour is red when the prediction is
      "FAKE" and green otherwise.

    Args:
        text: News text to analyse.
        nlp: spaCy pipeline passed to ``update_knowledge_graph``.
        tokenizer: Tokenizer passed to ``predict_with_model``.
        model: Classifier passed to ``predict_with_model``.

    Returns:
        A ``plotly.graph_objects.Figure`` with a dominant-colour edge trace,
        an opposite-colour edge trace and a node trace, in that order.
    """
    red = 'rgba(255,0,0,0.7)'
    green = 'rgba(0,255,0,0.7)'

    kg_builder = KnowledgeGraphBuilder()
    prediction, _ = predict_with_model(text, tokenizer, model)
    is_fake = prediction == "FAKE"
    kg_builder.update_knowledge_graph(text, not is_fake, nlp)

    all_edges = list(kg_builder.knowledge_graph.edges())
    # Display 50% of the edges. random.sample returns them in random order, so
    # taking the first 15% of the sample is itself a uniformly random choice.
    display_edges = random.sample(all_edges, k=int(len(all_edges) * 0.5))
    opposite_count = int(len(display_edges) * 0.15)
    opposite_edges = display_edges[:opposite_count]
    dominant_edges = display_edges[opposite_count:]

    # Lay out only the displayed edges (all nodes are kept).
    selected_graph = nx.DiGraph()
    selected_graph.add_nodes_from(kg_builder.knowledge_graph.nodes(data=True))
    selected_graph.add_edges_from(display_edges)
    pos = nx.spring_layout(selected_graph)

    def edge_trace(edges, color):
        # Collect coordinates in plain lists and build the trace once.
        # Repeated `trace['x'] += (...)` copies the whole tuple on every edge.
        xs, ys = [], []
        for source, target in edges:
            (x0, y0), (x1, y1) = pos[source], pos[target]
            xs.extend((x0, x1, None))  # None breaks the line between segments
            ys.extend((y0, y1, None))
        return go.Scatter(
            x=xs, y=ys,
            line=dict(width=2, color=color),
            hoverinfo='none',
            mode='lines'
        )

    node_x, node_y, node_text = [], [], []
    for node in selected_graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        textposition='top center',
        marker=dict(
            size=15,
            color='white',
            line=dict(width=2, color='black')
        ),
        text=node_text
    )

    dominant_color, opposite_color = (red, green) if is_fake else (green, red)
    return go.Figure(
        data=[
            edge_trace(dominant_edges, dominant_color),
            edge_trace(opposite_edges, opposite_color),
            node_trace,
        ],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)'
        )
    )
def generate_knowledge_graph_viz(text, nlp, tokenizer, model):
    """Build a three-colour Plotly network figure of the knowledge graph for ``text``.

    How the figure is built:

    * The text is classified with ``predict_with_model`` and fed to a fresh
      ``KnowledgeGraphBuilder``.
    * 60% of the resulting edges, chosen at random, are laid out and drawn.
    * Of those, 15% of the total edge count use the "opposite" colour, 15%
      use orange, and every remaining displayed edge (about 30%) uses the
      primary colour.
    * The primary colour is red when the prediction is "FAKE" and green
      otherwise; the opposite colour is the other one.

    Args:
        text: News text to analyse.
        nlp: spaCy pipeline passed to ``update_knowledge_graph``.
        tokenizer: Tokenizer passed to ``predict_with_model``.
        model: Classifier passed to ``predict_with_model``.

    Returns:
        A ``plotly.graph_objects.Figure`` with primary, opposite and orange
        edge traces followed by a node trace.
    """
    red = 'rgba(255,0,0,0.7)'
    green = 'rgba(0,255,0,0.7)'
    orange = 'rgba(255,165,0,0.7)'

    kg_builder = KnowledgeGraphBuilder()
    prediction, _ = predict_with_model(text, tokenizer, model)
    is_fake = prediction == "FAKE"
    kg_builder.update_knowledge_graph(text, not is_fake, nlp)

    all_edges = list(kg_builder.knowledge_graph.edges())
    total_edges = len(all_edges)
    # random.sample returns the chosen edges in random order, so slicing the
    # sample yields a random colour assignment without an extra shuffle.
    display_edges = random.sample(all_edges, k=int(total_edges * 0.6))
    opposite_count = min(int(total_edges * 0.15), len(display_edges))
    orange_count = min(int(total_edges * 0.15), len(display_edges) - opposite_count)
    opposite_edges = display_edges[:opposite_count]
    orange_edges = display_edges[opposite_count:opposite_count + orange_count]
    # Everything left gets the primary colour. A fixed 30% share could round
    # down below what remains, so some displayed edges (which do shape the
    # spring layout) were never drawn. Example: 2 edges -> 1 displayed, 0 drawn.
    primary_edges = display_edges[opposite_count + orange_count:]

    # Lay out only the displayed edges (all nodes are kept).
    selected_graph = nx.DiGraph()
    selected_graph.add_nodes_from(kg_builder.knowledge_graph.nodes(data=True))
    selected_graph.add_edges_from(display_edges)
    pos = nx.spring_layout(selected_graph)

    def edge_trace(edges, color):
        # Collect coordinates in plain lists and build the trace once.
        # Repeated `trace['x'] += (...)` copies the whole tuple on every edge.
        xs, ys = [], []
        for source, target in edges:
            (x0, y0), (x1, y1) = pos[source], pos[target]
            xs.extend((x0, x1, None))  # None breaks the line between segments
            ys.extend((y0, y1, None))
        return go.Scatter(
            x=xs, y=ys,
            line=dict(width=2, color=color),
            hoverinfo='none',
            mode='lines'
        )

    node_x, node_y, node_text = [], [], []
    for node in selected_graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        textposition='top center',
        marker=dict(
            size=15,
            color='white',
            line=dict(width=2, color='black')
        ),
        text=node_text
    )

    primary_color, opposite_color = (red, green) if is_fake else (green, red)
    return go.Figure(
        data=[
            edge_trace(primary_edges, primary_color),
            edge_trace(opposite_edges, opposite_color),
            edge_trace(orange_edges, orange),
            node_trace,
        ],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)'
        )
    )
# Streamlit UI
def _join_items(values):
    """Return ``values`` joined with ", " for display, or "N/A" when missing or empty."""
    if not values:
        return "N/A"
    if isinstance(values, str):
        # A bare string would otherwise be split into ", "-separated characters.
        return values
    return ", ".join(str(value) for value in values)


def _fetch_gemini_analysis(news_text, max_retries=10, retry_delay=1.0):
    """Query Gemini for ``news_text``, retrying until a usable result arrives.

    A result is accepted as soon as it carries a non-empty 'gemini_analysis'
    entry. Each failed call is printed to stdout. If the last failure was an
    exception, that error is shown once in the UI.

    Args:
        news_text: Text to analyse.
        max_retries: Maximum number of attempts.
        retry_delay: Seconds to wait between attempts (not after the last one).

    Returns:
        The Gemini result dict. If no attempt produced a 'gemini_analysis'
        entry, any other sections Gemini did return are kept and an
        "UNCERTAIN" placeholder analysis is filled in.
    """
    import time

    result = None
    last_error = None
    for attempt in range(max_retries):
        try:
            gemini_model = setup_gemini()
            result = analyze_content_gemini(gemini_model, news_text)
        except Exception as e:  # any API/network failure just triggers a retry
            last_error = e
            print(f"Gemini error: {str(e)}")
        else:
            if result and result.get('gemini_analysis'):
                return result
        if attempt + 1 < max_retries:
            time.sleep(retry_delay)

    if last_error is not None:
        # One message instead of one error box per failed attempt.
        st.error(f"Gemini API error: {str(last_error)}")
    result = dict(result) if result else {}
    result['gemini_analysis'] = {
        "predicted_classification": "UNCERTAIN",
        "confidence_score": "50",
        "reasoning": ["Analysis temporarily unavailable"]
    }
    return result


def _render_detailed_analysis(news_text, gemini_result, nlp, tokenizer, model):
    """Render the detailed-analysis sections into the current Streamlit container.

    Sections rendered, in order:

    * Gemini reasoning;
    * spaCy named entities;
    * knowledge-graph figure;
    * Gemini text classification, sentiment and entity sections.

    Any exception stops the remaining sections. It is printed to stdout and
    reported to the user with a single error message.
    """
    try:
        st.subheader("π Analysis Reasoning")
        reasoning = gemini_result.get('gemini_analysis', {}).get('reasoning') or ['N/A']
        if isinstance(reasoning, str):
            reasoning = [reasoning]  # avoid printing a bare string char by char
        for point in reasoning:
            st.write(f"β’ {point}")

        st.subheader("π·οΈ Named Entities")
        entities = extract_entities(news_text, nlp)
        st.dataframe(pd.DataFrame(entities, columns=["Entity", "Type"]))

        st.subheader("πΈοΈ Knowledge Graph")
        fig = generate_knowledge_graph_viz(news_text, nlp, tokenizer, model)
        st.plotly_chart(fig, use_container_width=True)

        st.subheader("π Text Classification")
        text_class = gemini_result.get('text_classification', {})
        st.write(f"Category: {text_class.get('category', 'N/A')}")
        st.write(f"Writing Style: {text_class.get('writing_style', 'N/A')}")
        st.write(f"Target Audience: {text_class.get('target_audience', 'N/A')}")
        st.write(f"Content Type: {text_class.get('content_type', 'N/A')}")

        st.subheader("π Sentiment Analysis")
        sentiment = gemini_result.get('sentiment_analysis', {})
        st.write(f"Primary Emotion: {sentiment.get('primary_emotion', 'N/A')}")
        st.write(f"Emotional Intensity: {sentiment.get('emotional_intensity', 'N/A')}/10")
        st.write(f"Sensationalism Level: {sentiment.get('sensationalism_level', 'N/A')}")
        st.write("Bias Indicators:", _join_items(sentiment.get('bias_indicators')))

        st.subheader("π Entity Recognition")
        gemini_entities = gemini_result.get('entity_recognition', {})
        st.write(f"Source Credibility: {gemini_entities.get('source_credibility', 'N/A')}")
        st.write("People:", _join_items(gemini_entities.get('people')))
        st.write("Organizations:", _join_items(gemini_entities.get('organizations')))
    except Exception as e:
        print(f"Analysis error: {str(e)}")
        st.error("Error processing analysis results")


def main():
    """Render the classifier page and analyse the entered text when "Analyze" is clicked.

    On click, the app:

    1. classifies the text with the fine-tuned model and the knowledge graph;
    2. adds the text to the knowledge graph, labelled by the model's verdict;
    3. queries Gemini;
    4. shows the model's prediction and confidence plus an expander with the
       detailed analysis.
    """
    st.title("Nexus NLP News Classifier")
    st.write("Enter news text below to analyze its authenticity")

    nlp, tokenizer, model, knowledge_graph = initialize_models()

    news_text = st.text_area("News Text", height=200)
    if not st.button("Analyze"):
        return
    if not news_text:
        st.warning("Please enter some text to analyze")
        return

    with st.spinner("Analyzing..."):
        ml_prediction, ml_confidence = predict_with_model(news_text, tokenizer, model)
        # Not displayed at the moment; the knowledge-graph metrics column is disabled.
        kg_prediction, kg_confidence = predict_with_knowledge_graph(news_text, knowledge_graph, nlp)
        update_knowledge_graph(news_text, ml_prediction == "REAL", knowledge_graph, nlp)
        gemini_result = _fetch_gemini_analysis(news_text)

        # Only the ML metrics are shown. Separate knowledge-graph and Gemini
        # columns were disabled; their values (kg_prediction / kg_confidence
        # and gemini_result['gemini_analysis']) are still computed above.
        col1 = st.columns(1)[0]
        with col1:
            st.subheader("ML Model and Knowledge Graph Analysis")
            st.metric("Prediction", ml_prediction)
            st.metric("Confidence", f"{ml_confidence:.2f}%")

        with st.expander("Click here to get Detailed Analysis"):
            _render_detailed_analysis(news_text, gemini_result, nlp, tokenizer, model)


if __name__ == "__main__":
    main()