File size: 2,216 Bytes
650ec6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3368fa1
650ec6e
 
 
 
3368fa1
650ec6e
 
 
 
 
 
 
 
 
 
 
 
 
4556b47
650ec6e
edfe9df
4556b47
 
650ec6e
 
 
 
 
 
 
 
3368fa1
 
650ec6e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import html
import random

import gradio as gr
import torch
from gradio.mix import Series
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.pipelines import generate_images

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only requested on CUDA. Depending on the PyTorch version,
# many CPU kernels have no float16 implementation (or are very slow), so forcing
# fp16 on the CPU fallback would defeat the purpose of having that fallback.
dalle = get_rudalle_model(
    "Malevich",
    pretrained=True,
    fp16=device.type == "cuda",
    device=device,
)
tokenizer = get_tokenizer()
# The VAE is placed on the same device as the transformer model above.
vae = get_vae().to(device)

def dalle_wrapper(prompt: str):
    """Generate a single ruDALL-E image for *prompt*.

    Each call picks one (top_k, top_p) sampling pair at random from a small
    fixed set. Pressing Submit repeatedly with the same prompt can therefore
    produce different images.

    Args:
        prompt: Text passed to ``generate_images``. In this app it is the
            output of the English-to-Russian translator chained in front of
            this interface.

    Returns:
        A ``(title, image)`` pair: an HTML snippet showing the (escaped) prompt
        in bold, and the first image returned by ``generate_images``.
    """
    top_k, top_p = random.choice([
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ])

    images, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )
    # The prompt is user-controlled and is rendered by an HTML output
    # component, so escape it to prevent markup/script injection.
    title = f"<b>{html.escape(prompt)}</b>"
    return title, images[0]


# Stage 1: an English -> Russian translation model loaded from the Hugging Face
# hub, fed from a single free-text box.
translator = gr.Interface.load(
    "huggingface/facebook/wmt19-en-ru",
    inputs=[gr.inputs.Textbox(label="What would you like to see?")],
)

# Stage 2: the ruDALL-E wrapper, whose two return values (HTML title, image)
# map onto these two unlabeled output components.
outputs = [gr.outputs.HTML(label=""), gr.outputs.Image(label="")]
generator = gr.Interface(
    fn=dalle_wrapper,
    inputs="text",
    outputs=outputs,
)


# Blurb shown with the demo.
description = " ".join([
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom).",
    "This demo uses an English-Russian translation model to adapt the prompts.",
    "Try pressing [Submit] multiple times to generate new images!",
])
# Centered footer linking the project repository and the announcement article.
article = "<p style='text-align: center'>%s</p>" % " | ".join([
    "<a href='https://github.com/sberbank-ai/ru-dalle'>GitHub</a>",
    "<a href='https://habr.com/ru/company/sberbank/blog/586926/'>Article (in Russian)</a>",
])
# One single-field row per example prompt; one of them is already in Russian.
examples = [
    [text]
    for text in (
        "A still life of grapes and a bottle of wine",
        "Город в стиле киберпанк",
        "A colorful photo of a coral reef",
        "A white cat sitting in a cardboard box",
    )
]

# Chain translator -> generator into a single app and start serving it.
series = Series(
    translator,
    generator,
    title='Kinda-English ruDALL-E',
    description=description,
    article=article,
    examples=examples,
    layout='horizontal',
    theme='huggingface',
    live=False,
    allow_flagging=False,
    enable_queue=True,
)
series.launch()