File size: 3,245 Bytes
0fa3012
 
 
 
 
 
 
 
 
a253b5f
0fa3012
 
 
 
a253b5f
 
0fa3012
 
 
 
 
 
 
 
 
 
 
 
 
a253b5f
0fa3012
a253b5f
 
 
 
 
0fa3012
 
a253b5f
0fa3012
a253b5f
 
0fa3012
 
 
 
 
 
 
 
 
9a2368c
0fa3012
 
 
a253b5f
0fa3012
a253b5f
0fa3012
 
 
 
 
 
a253b5f
0fa3012
 
 
9a2368c
0fa3012
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import requests
import gradio as gr
from huggingface_hub import InferenceClient
from dotenv import load_dotenv

# Copy any variables from a local .env file into os.environ before the
# settings below are read.
load_dotenv()

# --- Configuration ---------------------------------------------------------
# Hugging Face API token; None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Reddit listing queried by get_trends() (requests up to five posts).
REDDIT_API = "https://www.reddit.com/r/trending/top.json?limit=5"
# Topics offered when the Reddit request fails.
DEFAULT_TRENDS = [
    "Celebrity gossip",
    "Tech news",
    "Movie drama",
    "Gaming leaks",
]
# Choices for the "Visual Style" dropdown.
STYLES = [
    "Realistic",
    "Cartoon",
    "Anime",
    "3D Render",
]
# Model IDs passed per call to the Hugging Face InferenceClient.
IMAGE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # text_to_image
# NOTE(review): confirm the HF Inference API actually serves this model for
# text_generation; otherwise caption refinement will fail at runtime.
TEXT_MODEL = "deepseek-ai/Janus-Pro-7B"  # caption refinement

# A single shared client; each call selects its model explicitly.
client = InferenceClient(token=HF_TOKEN)

def get_trends():
    """Return trending topic titles, falling back to DEFAULT_TRENDS.

    Fetches the listing at REDDIT_API and extracts each post's title. A
    network error, HTTP error status, malformed payload, or empty listing
    all yield a fresh copy of DEFAULT_TRENDS, so the dropdown is never empty.

    Returns:
        list[str]: Post titles, or the default topics on failure.
    """
    try:
        res = requests.get(
            REDDIT_API, headers={"User-Agent": "MemeBot/1.0"}, timeout=3
        )
        # Treat HTTP errors (e.g. rate limiting) as failures explicitly instead
        # of relying on the error body happening not to parse as a listing.
        res.raise_for_status()
        titles = [post["data"]["title"] for post in res.json()["data"]["children"]]
    except (requests.RequestException, ValueError, KeyError, TypeError) as e:
        # Best-effort: the UI still works with the static defaults. Narrow
        # except so KeyboardInterrupt/SystemExit are no longer swallowed.
        print(f"Trend fetch failed, using defaults: {e}")
        titles = []
    # Return a copy so callers cannot mutate the module-level defaults.
    return titles or list(DEFAULT_TRENDS)

def _enhance_caption(prompt: str, style: str) -> str:
    """Return TEXT_MODEL's rewrite of *prompt*, or *prompt* itself on failure.

    Caption refinement is optional polish. If the text model raises or
    returns only whitespace, the raw prompt is used so image generation can
    still proceed.
    """
    try:
        enhanced = client.text_generation(
            model=TEXT_MODEL,
            prompt=f"Improve this meme caption about {prompt} in {style} style:",
            max_new_tokens=100,
        )
    except Exception as e:  # remote API with varied failure modes; stay best-effort
        print(f"Caption enhancement failed, using raw prompt: {e}")
        return prompt
    enhanced = (enhanced or "").strip()
    return enhanced or prompt


def generate_meme(prompt: str, style: str):
    """Generate a meme image from a user caption and a visual style.

    The caption is first refined by TEXT_MODEL (falling back to the raw
    caption if that call fails), then used to prompt IMAGE_MODEL.

    Args:
        prompt: The user's caption text; surrounding whitespace is ignored.
        style: Visual style name (the UI offers the values in STYLES).

    Returns:
        The result of ``client.text_to_image`` (presumably a PIL image;
        confirm against the installed huggingface_hub), or None when the
        caption is empty or image generation fails.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        # Nothing to caption; skip both remote API calls.
        print("Generation skipped: empty prompt")
        return None
    try:
        caption = _enhance_caption(prompt, style)
        # "text" in the negative prompt discourages the model from drawing
        # lettering; the caption only steers the image and is not overlaid.
        # NOTE(review): SDXL base is usually run at 1024x1024; confirm that
        # 512x512 is intentional.
        return client.text_to_image(
            model=IMAGE_MODEL,
            prompt=f"Trending meme template ({style} style): {caption}",
            negative_prompt="low quality, text, watermark",
            height=512,
            width=512,
        )
    except Exception as e:
        print(f"Generation error: {e}")
        return None

# Build the Gradio UI.
with gr.Blocks(title="🚀 Viral Meme Generator", css="static/style.css") as demo:
    gr.Markdown("# <center>🔥 Create Viral Memes in Seconds</center>")

    with gr.Row():
        # Left column: branding and inputs.
        with gr.Column(scale=3):
            # NOTE(review): whether "file/static/..." is served depends on the
            # Gradio version and launch(allowed_paths=...); verify the logo loads.
            gr.HTML('<img src="file/static/assets/logo.png" style="height: 200px">')
            # Choices are fetched once, when this module is imported.
            trend_select = gr.Dropdown(get_trends(), label="Trending Topics")
            style_select = gr.Dropdown(STYLES, label="Visual Style", value="Realistic")
            text_input = gr.Textbox(label="Your Message", placeholder="Add funny text...")
            generate_btn = gr.Button("Generate Now", variant="primary")

        # Right column: result and monetization.
        with gr.Column(scale=2):
            output_img = gr.Image(label="Your Meme", height=512, width=512)
            # Gradio does not execute <script> tags inside gr.HTML, so the
            # Ko-fi JS widget never rendered; a plain styled link needs no JS.
            gr.HTML("""
            <div class="monetization">
                <a href="https://ko-fi.com/K3K8L7L7S" target="_blank" rel="noopener noreferrer"
                   style="display:inline-block;background:#FF5F5F;color:#fff;padding:8px 16px;border-radius:6px;text-decoration:none;">Support Us</a>
            </div>
            """)
            # TODO: no click handler is wired to this button yet; it currently does nothing.
            download_btn = gr.Button("🔒 Unlock HD Download ($0.99)")

    # Event handling
    generate_btn.click(fn=generate_meme, inputs=[text_input, style_select], outputs=output_img)
    # Selecting a trend copies it into the message box, replacing typed text.
    trend_select.change(fn=lambda x: x, inputs=trend_select, outputs=text_input)

if __name__ == "__main__":
    # Honor a platform-assigned PORT; otherwise use Gradio's usual 7860.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_port=port)