import re
import os
import requests
import gradio as gr
from datasets import load_dataset
from PIL import Image
from io import BytesIO
import torch
from torch import autocast

from transformers import pipeline, set_seed
from diffusers import DiffusionPipeline, StableDiffusionPipeline


# Config
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is available

# GPT2
def get_gpt2_pipeline():
  generator = pipeline('text-generation', model='gpt2')
  set_seed(42)

  # generator("Hello world, I'm vizard,", max_length=50, num_return_sequences=3)
  
  return generator

# SD v1.4
def get_stable_diffusion_v14_pipeline():
  model_id = "CompVis/stable-diffusion-v1-4"
  pipe = StableDiffusionPipeline.from_pretrained(model_id)
  # pipeline = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
  pipe = pipe.to(DEVICE)
  torch.backends.cudnn.benchmark = True
  return pipe
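
# Example usage of the v1.4 pipeline (a minimal sketch; the prompt and filename below are
# illustrative, not part of the original script):
# pipe = get_stable_diffusion_v14_pipeline()
# image = pipe("an astronaut riding a horse on mars").images[0]
# image.save("astronaut.png")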

# SD v1.5
def get_stable_diffusion_v15_pipeline():
  model_id = "runwayml/stable-diffusion-v1-5"
  pipe = DiffusionPipeline.from_pretrained(model_id)
  pipe = pipe.to(DEVICE)
  return pipe

def get_image(url):
  response = requests.get(url)
  image = Image.open(BytesIO(response.content)).convert("RGB")
  resized_image = image.resize((768, 512))
  return resized_image
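
# Example usage of get_image (illustrative URL, not from the original script):
# init_image = get_image("https://example.com/sample.jpg")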

# main
def main():
  prompt = "Hello world, I'm vizard,"
  
  pipe = get_gpt2_pipeline()
  def greet(prompt):
    # Join the generated sequences so the Textbox shows readable text instead of a raw list.
    outputs = pipe(prompt, max_length=50, num_return_sequences=3)
    return "\n\n".join(o["generated_text"] for o in outputs)
  
  ui = gr.Interface(
    fn=greet,
    inputs=gr.Textbox(lines=2, placeholder="Enter some text here..."),
    outputs="text"
  )
  ui.launch()  # Note: when run as a script, launch() blocks until the UI is closed, so the code below runs afterwards.
  
  pipe2 = get_stable_diffusion_v15_pipeline()
  images = pipe2(prompt).images
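  # The generated images are not used further in the original script; a hedged example of
  # persisting the first one (illustrative filename):
  # images[0].save("sd_v15_output.png")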

if __name__ == "__main__":
  main()