import os, re
import numpy as np
import pandas as pd
import gradio as gr
import faiss
import torch
from typing import List
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ---- Config ----
FLAN_PRIMARY = os.getenv("FLAN_PRIMARY", "google/flan-t5-large")
EMBED_NAME = "sentence-transformers/all-mpnet-base-v2"
RERANK_NAME = "cross-encoder/stsb-roberta-base"
NUM_SLOGAN_SAMPLES = int(os.getenv("NUM_SLOGAN_SAMPLES", "16"))
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ASSETS_DIR = "assets"
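# FLAN_PRIMARY and NUM_SLOGAN_SAMPLES can be overridden via environment
# variables; an illustrative launch (assuming this file is saved as app.py):
#   FLAN_PRIMARY=google/flan-t5-base NUM_SLOGAN_SAMPLES=8 python app.py
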
# ---- Lazy models ----
_GEN_TOK = None
_GEN_MODEL = None
_EMBED_MODEL = None
_RERANKER = None

def _ensure_models():
    global _GEN_TOK, _GEN_MODEL, _EMBED_MODEL, _RERANKER
    if _EMBED_MODEL is None:
        _EMBED_MODEL = SentenceTransformer(EMBED_NAME)
    if _RERANKER is None:
        _RERANKER = CrossEncoder(RERANK_NAME)
    if _GEN_MODEL is None:
        tok = AutoTokenizer.from_pretrained(FLAN_PRIMARY)
        mdl = AutoModelForSeq2SeqLM.from_pretrained(FLAN_PRIMARY)
        _GEN_TOK, _GEN_MODEL = tok, mdl.to(DEVICE)
        print(f"[INFO] Loaded generator: {FLAN_PRIMARY}")

# ---- Data & pre-built FAISS index (loaded from the assets folder) ----
_DATA_DF = None
_INDEX = None
_EMBEDDINGS = None

def _ensure_index():
    global _DATA_DF, _INDEX, _EMBEDDINGS
    if _INDEX is not None:
        return
    # Load pre-built assets from the assets directory.
    try:
        data_path = os.path.join(ASSETS_DIR, "data.parquet")
        index_path = os.path.join(ASSETS_DIR, "faiss.index")
        emb_path = os.path.join(ASSETS_DIR, "embeddings.npy")
        _DATA_DF = pd.read_parquet(data_path)
        _INDEX = faiss.read_index(index_path)
        _EMBEDDINGS = np.load(emb_path)
        print(f"[INFO] Loaded pre-built FAISS index. rows={len(_DATA_DF)}, dim={_INDEX.d}")
    except (FileNotFoundError, RuntimeError):
        # faiss.read_index raises RuntimeError rather than FileNotFoundError when
        # its file is missing, so catch both and fall back to a tiny demo index.
        print("[WARN] Pre-built assets not found; building a tiny in-memory demo index instead.")
        _DATA_DF = pd.DataFrame({
            "name": ["HowDidIDo", "Museotainment", "Movitr"],
            "tagline": ["Online evaluation platform", "PacMan & Louvre meet", "Crowdsourced video translation"],
            "description": [
                "Public speaking, Presentation skills and interview practice",
                "Interactive AR museum tours",
                "Video translation with voice and subtitles",
            ],
        })
        _ensure_models()
        vecs = _EMBED_MODEL.encode(
            _DATA_DF["description"].astype(str).tolist(),
            normalize_embeddings=True,
        ).astype(np.float32)
        _INDEX = faiss.IndexFlatIP(vecs.shape[1])
        _INDEX.add(vecs)

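# The loader above expects assets produced offline with the same embedding
# model and an inner-product index over L2-normalized vectors. A minimal
# sketch of such a build step, under those assumptions (`build_assets` is
# illustrative and is never called by this app):
def build_assets(df: pd.DataFrame, assets_dir: str = ASSETS_DIR) -> None:
    """Embed df["description"], build an IndexFlatIP, and write all three assets."""
    _ensure_models()
    os.makedirs(assets_dir, exist_ok=True)
    # Normalized embeddings make inner product equal to cosine similarity.
    vecs = _EMBED_MODEL.encode(
        df["description"].astype(str).tolist(),
        normalize_embeddings=True,
    ).astype(np.float32)
    index = faiss.IndexFlatIP(vecs.shape[1])
    index.add(vecs)
    df.to_parquet(os.path.join(assets_dir, "data.parquet"))
    faiss.write_index(index, os.path.join(assets_dir, "faiss.index"))
    np.save(os.path.join(assets_dir, "embeddings.npy"), vecs)
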
def recommend(query_text: str, top_k: int = 3) -> pd.DataFrame:
    _ensure_index()
    _ensure_models()
    q_vec = _EMBED_MODEL.encode([query_text], normalize_embeddings=True).astype("float32")
    scores, idxs = _INDEX.search(q_vec, top_k)
    out = _DATA_DF.iloc[idxs[0]].copy()
    out["score"] = scores[0]
    return out[["name", "tagline", "description", "score"]]

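# Illustrative call (scores are inner products, i.e. cosine similarity when the
# indexed vectors were normalized):
#   recommend("AI coach for improving public speaking skills", top_k=3)
#   -> DataFrame with columns [name, tagline, description, score]
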
# ---- Refined v2 slogan generator (unchanged logic) ----
BLOCK_PATTERNS = [
    r"^[A-Z][a-z]+ [A-Z][a-z]+ (Platform|Solution|System|Application|Marketplace)$",
    r"^[A-Z][a-z]+ [A-Z][a-z]+$",
    r"^[A-Z][a-z]+$",
]
HARD_BLOCK_WORDS = {
    "platform", "solution", "system", "application", "marketplace",
    "ai-powered", "ai powered", "empower", "empowering",
    "artificial intelligence", "machine learning", "augmented reality", "virtual reality",
}
GENERIC_WORDS = {"app", "assistant", "smart", "ai", "ml", "ar", "vr", "decentralized", "blockchain"}
MARKETING_VERBS = {"build", "grow", "simplify", "discover", "create", "connect", "transform", "unlock", "boost", "learn", "move", "clarify"}
BENEFIT_WORDS = {"faster", "smarter", "easier", "better", "safer", "clearer", "stronger", "together", "confidently", "simply", "instantly"}
GOOD_SLOGANS_TO_AVOID_DUP = {
    "smarter care, faster decisions",
    "checkout built for small brands",
    "less guessing. more healing.",
    "built to grow with your cart.",
    "stand tall. feel better.",
    "train your brain to win.",
    "your body. your algorithm.",
    "play smarter. grow brighter.",
    "style that thinks with you.",
}

def _tokens(s: str) -> List[str]:
    return re.findall(r"[a-z0-9]{3,}", s.lower())

def _jaccard(a: List[str], b: List[str]) -> float:
    A, B = set(a), set(b)
    return 0.0 if not A or not B else len(A & B) / len(A | B)

def _titlecase_soft(s: str) -> str:
    # Capitalize each word, but leave all-caps tokens (e.g. acronyms) untouched.
    out = []
    for w in s.split():
        out.append(w if w.isupper() else w.capitalize())
    return " ".join(out)

def _is_blocked_slogan(s: str) -> bool:
    if not s:
        return True
    s_strip = s.strip()
    for pat in BLOCK_PATTERNS:
        if re.match(pat, s_strip):
            return True
    s_low = s_strip.lower()
    for w in HARD_BLOCK_WORDS:
        if w in s_low:
            return True
    if s_low in GOOD_SLOGANS_TO_AVOID_DUP:
        return True
    return False

def _generic_penalty(s: str) -> float:
    hits = sum(1 for w in GENERIC_WORDS if w in s.lower())
    return min(1.0, 0.25 * hits)

def _for_penalty(s: str) -> float:
    # Small penalty for slogans containing the word "for".
    return 0.3 if re.search(r"\bfor\b", s.lower()) else 0.0

def _neighbor_context(neighbors_df: pd.DataFrame) -> str:
    if neighbors_df is None or neighbors_df.empty:
        return ""
    examples = []
    for _, row in neighbors_df.head(3).iterrows():
        tg = str(row.get("tagline", "")).strip()
        if 5 <= len(tg) <= 70:
            examples.append(f"- {tg}")
    return "\n".join(examples)

def _copies_neighbor(s: str, neighbors_df: pd.DataFrame) -> bool:
    if neighbors_df is None or neighbors_df.empty:
        return False
    s_low = s.lower()
    s_toks = _tokens(s_low)
    # First pass: exact matches or high token overlap against neighbor taglines.
    for _, row in neighbors_df.iterrows():
        t = str(row.get("tagline", "")).strip()
        if not t:
            continue
        t_low = t.lower()
        if s_low == t_low:
            return True
        if _jaccard(s_toks, _tokens(t_low)) >= 0.7:
            return True
    # Second pass: embedding similarity against the top-3 neighbor taglines.
    try:
        _ensure_models()
        s_vec = _EMBED_MODEL.encode([s])[0]
        s_vec = s_vec / np.linalg.norm(s_vec)
        for _, row in neighbors_df.head(3).iterrows():
            t = str(row.get("tagline", "")).strip()
            if not t:
                continue
            t_vec = _EMBED_MODEL.encode([t])[0]
            t_vec = t_vec / np.linalg.norm(t_vec)
            if float(np.dot(s_vec, t_vec)) >= 0.85:
                return True
    except Exception:
        pass
    return False

def _clean_slogan(text: str, max_words: int = 8) -> str:
    text = text.strip().split("\n")[0]
    text = re.sub(r"[\"“”‘’]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    text = re.sub(r"^\W+|\W+$", "", text)
    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words])
    return text

def _score_candidates(query: str, cands: List[str], neighbors_df: pd.DataFrame) -> List[tuple]:
    if not cands:
        return []
    _ensure_models()
    # Cross-encoder relevance for (query, slogan) pairs, down-scaled by 5.
    ce_scores = np.asarray(_RERANKER.predict([(query, s) for s in cands]), dtype=np.float32) / 5.0
    q_toks = _tokens(query)
    results = []
    # Pre-compute normalized neighbor-tagline embeddings for the novelty penalty.
    neighbor_vecs = []
    if neighbors_df is not None and not neighbors_df.empty:
        for _, row in neighbors_df.head(3).iterrows():
            t = str(row.get("tagline", "")).strip()
            if t:
                v = _EMBED_MODEL.encode([t])[0]
                neighbor_vecs.append(v / np.linalg.norm(v))
    for i, s in enumerate(cands):
        words = s.split()
        # Prefer slogans around five words long.
        brevity = 1.0 - min(1.0, abs(len(words) - 5) / 5.0)
        wl = set(w.lower() for w in words)
        m_hits = len(wl & MARKETING_VERBS)
        b_hits = len(wl & BENEFIT_WORDS)
        marketing = min(1.0, 0.2 * m_hits + 0.2 * b_hits)
        g_pen = _generic_penalty(s)
        f_pen = _for_penalty(s)
        # Penalty for being too similar to the nearest-neighbor taglines.
        n_pen = 0.0
        if neighbor_vecs:
            try:
                s_vec = _EMBED_MODEL.encode([s])[0]
                s_vec = s_vec / np.linalg.norm(s_vec)
                n_pen = max(float(np.dot(s_vec, nv)) for nv in neighbor_vecs)
            except Exception:
                n_pen = 0.0
        # Reward slogans that do not parrot the query description.
        overlap = _jaccard(q_toks, _tokens(s))
        anti_copy = 1.0 - overlap
        score = (
            0.55 * float(ce_scores[i])
            + 0.20 * brevity
            + 0.15 * marketing
            + 0.03 * anti_copy
            - 0.07 * g_pen
            - 0.03 * f_pen
            - 0.10 * n_pen
        )
        results.append((s, float(score)))
    return results

def generate_slogan(query_text: str, neighbors_df: pd.DataFrame = None, n_samples: int = NUM_SLOGAN_SAMPLES) -> str:
    _ensure_models()
    ctx = _neighbor_context(neighbors_df)
    prompt = (
        "You are a creative brand copywriter. Write short, original, memorable startup slogans (max 8 words).\n"
        "Forbidden words: app, assistant, platform, solution, system, marketplace, AI, machine learning, augmented reality, virtual reality, decentralized, empower.\n"
        "Focus on clear benefits and vivid verbs. Do not copy the description. Return ONLY a list, one slogan per line.\n\n"
        "Good Examples:\n"
        "Description: AI assistant for doctors to prioritize patient cases\n"
        "Slogan: Less Guessing. More Healing.\n\n"
        "Description: Payments for small online stores\n"
        "Slogan: Built to Grow with Your Cart.\n\n"
        "Description: Neurotech headset to boost focus\n"
        "Slogan: Train Your Brain to Win.\n\n"
        "Description: Interior design suggestions with AI\n"
        "Slogan: Style That Thinks With You.\n\n"
        "Bad Examples (avoid these): Innovative AI Platform / Smart App for Everyone / Empowering Small Businesses\n\n"
    )
    if ctx:
        prompt += f"Similar taglines (style only):\n{ctx}\n\n"
    prompt += f"Description: {query_text}\nSlogans:"
    input_ids = _GEN_TOK(prompt, return_tensors="pt").input_ids.to(DEVICE)
    outputs = _GEN_MODEL.generate(
        input_ids,
        max_new_tokens=24,
        do_sample=True,
        top_k=60,
        top_p=0.92,
        temperature=1.2,
        num_return_sequences=n_samples,
        repetition_penalty=1.08,
    )
    raw_cands = [_GEN_TOK.decode(o, skip_special_tokens=True) for o in outputs]
    # Clean, filter, and de-duplicate the sampled candidates.
    cand_set = set()
    for txt in raw_cands:
        for line in txt.split("\n"):
            s = _clean_slogan(line)
            if not s:
                continue
            if len(s.split()) < 2 or len(s.split()) > 8:
                continue
            if _is_blocked_slogan(s):
                continue
            if _copies_neighbor(s, neighbors_df):
                continue
            cand_set.add(_titlecase_soft(s))
    # If every candidate was filtered out, fall back to the first raw sample.
    if not cand_set:
        return _clean_slogan(_GEN_TOK.decode(outputs[0], skip_special_tokens=True))
    scored = _score_candidates(query_text, sorted(cand_set), neighbors_df)
    if not scored:
        return _clean_slogan(_GEN_TOK.decode(outputs[0], skip_special_tokens=True))
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored[0][0]

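# Illustrative smoke test (sampling is stochastic, so output varies per run):
#   generate_slogan("Voice-controlled task manager for remote teams")
#   -> a short, title-cased slogan string of 2-8 words
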
# ---- Gradio UI ----
EXAMPLES = [
    "AI coach for improving public speaking skills",
    "Augmented reality app for interactive museum tours",
    "Voice-controlled task manager for remote teams",
    "Machine learning system for predicting crop yields",
    "Platform for AI-assisted interior design suggestions",
]

def pipeline(user_input: str):
    recs = recommend(user_input, top_k=3)
    slogan = generate_slogan(user_input, neighbors_df=recs, n_samples=NUM_SLOGAN_SAMPLES)
    recs = recs.reset_index(drop=True)
    # Append the generated slogan as an extra row beneath the top-3 matches.
    recs.loc[len(recs)] = {"name": "Synthetic Example", "tagline": slogan, "description": user_input, "score": np.nan}
    return recs[["name", "tagline", "description", "score"]], slogan

with gr.Blocks(title="SloganAI — Recommendations + Slogan Generator") as demo:
    gr.Markdown("## SloganAI — Top-3 Recommendations + A High-Quality Generated Slogan")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Textbox(label="Enter a startup description", lines=3, placeholder="e.g., AI coach for improving public speaking skills")
            gr.Examples(EXAMPLES, inputs=inp, label="One-click examples")
            btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=2):
            out_df = gr.Dataframe(headers=["Name", "Tagline", "Description", "Score"], label="Top 3 + Generated")
            out_sg = gr.Textbox(label="Generated Slogan", interactive=False)
    btn.click(fn=pipeline, inputs=inp, outputs=[out_df, out_sg])

if __name__ == "__main__":
    # Warm up the models and index before launching the UI.
    _ensure_models()
    _ensure_index()
    demo.queue().launch()