"""Gradio app: text description -> line-drawing sketch -> realistic portrait.

Stage 1 renders a sketch from the user's description with a text-to-image
Stable Diffusion pipeline. Stage 2 feeds that sketch into an image-to-image
Stable Diffusion pipeline to produce the final portrait.
"""

import logging

import gradio as gr
import torch
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline

logger = logging.getLogger(__name__)

# The original "runwayml/stable-diffusion-v1-5" Hub repository was removed;
# this is the maintained mirror of the same SD 1.5 weights.
BASE_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# NOTE(review): the repo name suggests this LoRA was trained for FLUX.1 Kontext
# [dev], not SD 1.5, so it may not load into an SD 1.5 pipeline. It is treated
# as optional below — confirm the intended adapter/base-model pairing.
SKETCH_TO_IMAGE_LORA_ID = "gokaygokay/Sketch-to-Image-Kontext-Dev-LoRA"
SKETCH_TO_IMAGE_LORA_WEIGHTS = "model.safetensors"

# Load models once at start-up so each request doesn't pay the (very slow)
# model-loading cost.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; float16 is poorly supported by CPU kernels.
dtype = torch.float16 if device == "cuda" else torch.float32


def _load_sketch_pipeline():
    """Return the stage-1 text-to-image pipeline on `device`, or None on failure."""
    try:
        pipeline = StableDiffusionPipeline.from_pretrained(BASE_MODEL_ID, torch_dtype=dtype)
        return pipeline.to(device)
    except Exception:
        logger.exception("Error loading sketch pipeline")
        return None


def _load_image_pipeline():
    """Return the stage-2 image-to-image pipeline on `device`, or None on failure.

    The LoRA adapter is best-effort: if it cannot be loaded, the error is
    logged and the base img2img pipeline is still returned.
    """
    try:
        # Img2img (not text-to-image) is required: stage 2 conditions on the sketch.
        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(BASE_MODEL_ID, torch_dtype=dtype)
        try:
            pipeline.load_lora_weights(
                SKETCH_TO_IMAGE_LORA_ID, weight_name=SKETCH_TO_IMAGE_LORA_WEIGHTS
            )
        except Exception:
            logger.exception(
                "Error loading LoRA %s; continuing without it", SKETCH_TO_IMAGE_LORA_ID
            )
        return pipeline.to(device)
    except Exception:
        logger.exception("Error loading image pipeline")
        return None


# Stage 1: Text-to-Sketch model (base SD with a "line drawing" prompt).
sketch_pipeline = _load_sketch_pipeline()
# Stage 2: Sketch-to-Image model (SD img2img, optionally with a LoRA).
image_pipeline = _load_image_pipeline()


def generate_full_image(text_prompt):
    """Generate a sketch from *text_prompt*, then a realistic portrait from it.

    Args:
        text_prompt: Free-text description of the person to draw.

    Returns:
        A ``(sketch, final_image)`` tuple of PIL images.

    Raises:
        gr.Error: If a model failed to load or the description is empty;
            Gradio displays the message to the user.
    """
    if sketch_pipeline is None or image_pipeline is None:
        raise gr.Error("The models failed to load; check the server logs for details.")

    description = (text_prompt or "").strip()
    if not description:
        raise gr.Error("Please enter a description of the person.")

    # Step 1: steer the base model toward a line-drawing style. No article is
    # prepended because descriptions typically start with one ("A middle-aged man...").
    sketch = sketch_pipeline(f"line drawing of {description}").images[0]

    # Step 2: img2img re-noises the sketch, so the description is repeated in the
    # prompt; otherwise details (scars, hair, ...) get replaced by a generic face.
    final_image = image_pipeline(
        prompt=f"a realistic human portrait of {description}",
        image=sketch,
    ).images[0]
    return sketch, final_image


# Define the Gradio UI using Blocks for a custom layout.
with gr.Blocks(title="Sketch-to-Image Pipeline") as demo:
    gr.Markdown("# Text-to-Sketch-to-Portrait")
    gr.Markdown(
        "Enter a description to generate a sketch, which is then converted "
        "into a realistic human portrait."
    )
    with gr.Row():
        text_input = gr.Textbox(
            label="Person Description",
            placeholder="e.g., A middle-aged man with a scar on his right cheek and shaggy hair",
        )
        generate_button = gr.Button("Generate Portrait")
    with gr.Row():
        sketch_output = gr.Image(label="Generated Sketch", type="pil")
        final_image_output = gr.Image(label="Generated Portrait", type="pil")

    # Connect the UI components to the Python function.
    generate_button.click(
        fn=generate_full_image,
        inputs=text_input,
        outputs=[sketch_output, final_image_output],
    )


if __name__ == "__main__":
    demo.launch()