import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import threading
import time

MODEL_NAME = "UnfilteredAI/Promt-generator"

# Global variables for model and tokenizer, populated by the background loader.
model = None
tokenizer = None
model_loaded = False
# Exception raised by load_model(), if any. Lets the UI report a failed load
# instead of telling the user the model is "still loading" forever.
load_error = None


def load_model():
    """Load the model and tokenizer into the module-level globals.

    Runs on a background thread (started below). On success sets
    ``model_loaded`` to True. On failure prints the error, stores the
    exception in ``load_error`` and leaves ``model_loaded`` False.
    """
    global model, tokenizer, model_loaded, load_error
    load_error = None
    try:
        print("Loading Prompt Generator model...")
        use_cuda = torch.cuda.is_available()
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            # Half precision only on GPU; CPU kernels generally expect float32.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        ).to("cuda" if use_cuda else "cpu")
        model_loaded = True
        print("Prompt Generator model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        load_error = e
        model_loaded = False


def generate_prompt(input_text, max_length, temperature, top_p, num_return_sequences):
    """Generate enhanced prompts that continue ``input_text``.

    Args:
        input_text: Seed text entered by the user.
        max_length: Total token budget per sequence (prompt + continuation),
            forwarded to ``model.generate(max_length=...)``.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling cumulative probability cutoff.
        num_return_sequences: Number of independent samples to produce.

    Returns:
        The generated prompts joined by a ``---`` separator, or a
        user-facing status / error message string.
    """
    if not model_loaded:
        if load_error is not None:
            return f"模型加载失败: {load_error}"
        return "模型尚未加载完成,请稍等..."

    # Guard against None as well as blank text; this check sits outside the
    # try block, so it must not raise.
    if not input_text or not input_text.strip():
        return "请输入一些文本作为提示词的起始内容。"

    try:
        # Gradio sliders can deliver floats; generate() needs ints for these.
        max_length = int(max_length)
        num_return_sequences = int(num_return_sequences)

        # Tokenize and place tensors on the same device the model lives on.
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        # max_length counts the prompt tokens too: if the prompt already fills
        # the budget there is nothing left to generate.
        prompt_len = inputs["input_ids"].shape[-1]
        if prompt_len >= max_length:
            return f"输入文本过长({prompt_len} 个 token),请增大“最大长度”或缩短输入。"

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                # The tokenizer may define no pad token; reuse EOS so that
                # generate() can pad the sampled sequences.
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        generated_prompts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return "\n\n---\n\n".join(generated_prompts)
    except Exception as e:
        return f"生成提示词时出错: {str(e)}"


def clear_output():
    """Clear the output textbox."""
    return ""


# Load the model in the background so the UI can start immediately.
# daemon=True so an in-progress download does not keep the process alive
# after the app itself has exited.
loading_thread = threading.Thread(target=load_model, daemon=True)
loading_thread.start()

# Create Gradio interface
with gr.Blocks(title="AI Prompt Generator") as demo:
    gr.Markdown("# 🎨 AI Prompt Generator")
    gr.Markdown("基于 UnfilteredAI/Promt-generator 模型的智能提示词生成器")

    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="输入起始文本",
                placeholder="例如: a red car, beautiful landscape, futuristic city...",
                lines=3,
            )
            with gr.Row():
                generate_btn = gr.Button("生成提示词", variant="primary", scale=2)
                clear_btn = gr.Button("清空", scale=1)
            output_text = gr.Textbox(
                label="生成的提示词",
                lines=10,
                max_lines=20,
                show_copy_button=True,
            )

        with gr.Column(scale=1):
            gr.Markdown("### 生成参数")
            max_length = gr.Slider(
                minimum=50, maximum=500, value=150, step=10, label="最大长度"
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (创造性)"
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (多样性)"
            )
            num_return_sequences = gr.Slider(
                minimum=1, maximum=5, value=3, step=1, label="生成数量"
            )

            gr.Markdown("### 使用说明")
            # One bullet per line so Markdown renders a real list.
            gr.Markdown(
                """- **输入起始文本**: 描述你想要的内容主题
- **Temperature**: 控制生成的随机性,越高越有创意
- **Top-p**: 控制词汇选择的多样性
- **生成数量**: 一次生成多个不同的提示词"""
            )

    # Event handlers: the button and pressing Enter in the textbox both generate.
    generate_btn.click(
        generate_prompt,
        inputs=[input_text, max_length, temperature, top_p, num_return_sequences],
        outputs=output_text,
    )
    input_text.submit(
        generate_prompt,
        inputs=[input_text, max_length, temperature, top_p, num_return_sequences],
        outputs=output_text,
    )
    clear_btn.click(clear_output, outputs=output_text)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861,
        share=False,
        show_error=True,
    )