"""Streamlit app that translates emoji-laden text and flags offensive content.

A fine-tuned causal LM rewrites the input (typed, or OCR'd from a screenshot)
into plain text, then a user-selectable text-classification model labels it.
Every classification is appended to ``st.session_state.history`` and
summarised on the "Text Analysis" page.
"""

import hashlib

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from PIL import Image
import pytesseract
import pandas as pd
import plotly.express as px

# Page config must be the first Streamlit command in the script, so it runs
# before the cached model loaders below (which render a loading spinner).
st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")


@st.cache_resource
def _load_emoji_model(model_id: str):
    """Load the emoji translation tokenizer and model once per process.

    Streamlit re-executes this script on every widget interaction; caching
    prevents the model from being reloaded from disk on every rerun.

    Returns:
        Tuple ``(tokenizer, model)`` with the model on GPU if available,
        in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    return tokenizer, model


@st.cache_resource
def _load_classifier(model_id: str):
    """Build the text-classification pipeline for ``model_id`` (cached per id)."""
    return pipeline("text-classification", model=model_id, device=0 if torch.cuda.is_available() else -1)


# Step 1: emoji translation model (the project's own fine-tuned model).
emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
emoji_tokenizer, emoji_model = _load_emoji_model(emoji_model_id)

# Step 2: selectable offensive-text classification models (display name -> HF id).
model_options = {
    "Toxic-BERT": "unitary/toxic-bert",
    "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
    "BERT Emotion": "bhadresh-savani/bert-base-go-emotion",
}

# Sidebar navigation. `classifier` is only bound in the moderation section,
# which is the only section that calls classify_emoji_text().
with st.sidebar:
    st.header("🧠 Navigation")
    section = st.radio("Select Mode:", ["📍 Text Moderation", "📊 Text Analysis"])

    if section == "📍 Text Moderation":
        selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
        selected_model_id = model_options[selected_model]
        classifier = _load_classifier(selected_model_id)
    elif section == "📊 Text Analysis":
        st.markdown("You can view the violation distribution chart and editing suggestions.")

if "history" not in st.session_state:
    st.session_state.history = []


def classify_emoji_text(text: str):
    """Translate emoji text to plain text, classify it, and record the result.

    Relies on the module-level ``classifier`` chosen in the sidebar.

    Args:
        text: Raw user text, possibly containing emojis.

    Returns:
        Tuple ``(translated_text, label, score, reasoning)``.

    Side effects:
        Appends a record dict to ``st.session_state.history``.
    """
    prompt = f"输入:{text}\n输出:"
    inputs = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
    with torch.no_grad():
        output_ids = emoji_model.generate(**inputs, max_new_tokens=64, do_sample=False)

    # Decode only the newly generated tokens, so the prompt text can never
    # leak into the translation (no string-splitting on the prompt marker).
    prompt_len = inputs["input_ids"].shape[-1]
    translated_text = emoji_tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
    if not translated_text:
        # Nothing generated: classify the original input rather than "".
        translated_text = text

    # truncation=True keeps long (e.g. OCR'd) text within the classifier's
    # maximum sequence length instead of raising.
    result = classifier(translated_text, truncation=True)[0]
    label = result["label"]
    score = result["score"]
    reasoning = f"The sentence was flagged as '{label}' due to potentially offensive phrases. Consider replacing emotionally charged, ambiguous, or abusive terms."

    st.session_state.history.append(
        {"text": text, "translated": translated_text, "label": label, "score": score, "reason": reasoning}
    )
    return translated_text, label, score, reasoning


def _render_result(translated: str, label: str, score: float, reason: str) -> None:
    """Display one classification result: translation, label, score, explanation."""
    st.markdown("### 🔄 Translated sentence:")
    st.code(translated, language="text")
    st.markdown(f"### 🎯 Prediction: {label}")
    st.markdown(f"### 📊 Confidence Score: {score:.2%}")
    st.markdown("### 🧠 Model Explanation:")
    st.info(reason)


# Section logic
if section == "📍 Text Moderation":
    st.title("📍 Offensive Text Classification")
    st.markdown("### ✍️ Input your sentence:")
    default_text = "你是🐷"
    text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)

    if st.button("🚦 Analyze"):
        if not text.strip():
            st.warning("⚠️ Please enter some text to analyze.")
        else:
            with st.spinner("🔍 Processing..."):
                try:
                    outcome = classify_emoji_text(text)
                except Exception as e:  # UI boundary: surface any model error to the user
                    st.error(f"❌ An error occurred during processing:\n\n{e}")
                else:
                    _render_result(*outcome)

    st.markdown("---")
    st.markdown("### 🖼️ Or upload a screenshot of bullet comments:")
    uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Screenshot", use_column_width=True)

        # Streamlit reruns the whole script on every interaction (e.g. clicking
        # Analyze). Cache OCR + classification per (model, file content) so the
        # same screenshot is not re-processed and re-appended to history.
        upload_key = (selected_model_id, hashlib.sha256(uploaded_file.getvalue()).hexdigest())
        cached = st.session_state.get("ocr_result")
        if cached is None or cached["key"] != upload_key:
            cached = None
            try:
                with st.spinner("🧠 Extracting text via OCR..."):
                    ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
                outcome = classify_emoji_text(ocr_text) if ocr_text else None
            except Exception as e:  # UI boundary: OCR/model failures shown, not raised
                st.error(f"❌ An error occurred during processing:\n\n{e}")
            else:
                cached = {"key": upload_key, "ocr_text": ocr_text, "result": outcome}
                st.session_state.ocr_result = cached

        if cached is not None:
            st.markdown("#### 📋 Extracted Text:")
            st.code(cached["ocr_text"])
            if cached["result"] is None:
                st.warning("⚠️ No text was detected in the uploaded image.")
            else:
                _render_result(*cached["result"])

elif section == "📊 Text Analysis":
    st.title("📊 Violation Analysis Dashboard")

    if st.session_state.history:
        df = pd.DataFrame(st.session_state.history)

        st.markdown("### 🧾 Offensive Terms & Suggestions")
        for item in st.session_state.history:
            st.markdown(f"- 🔹 **Input:** {item['text']}")
            st.markdown(f"  - ✨ **Translated:** {item['translated']}")
            st.markdown(f"  - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
            st.markdown(f"  - 🔧 **Suggestion:** {item['reason']}")

        # Radar of mean confidence per predicted label, computed from the
        # recorded history (this previously plotted hard-coded placeholder scores).
        radar_df = (
            df.groupby("label", as_index=False)["score"]
            .mean()
            .rename(columns={"label": "Category", "score": "Score"})
        )
        radar_fig = px.line_polar(
            radar_df,
            r="Score",
            theta="Category",
            line_close=True,
            range_r=[0, 1],
            title="⚠️ Risk Radar by Predicted Label",
        )
        radar_fig.update_traces(line_color="black")  # draw the radar outline in black
        st.plotly_chart(radar_fig)
    else:
        st.info("⚠️ No classification data available yet.")