import os
import tempfile

import gradio
import numpy
from PIL import Image
from towhee import pipeline, FileManagerConfig, FileManager

# User-facing text shown by the gradio interface (title, description, footer).
title = 'Towhee AnimeGanV2 Pipeline'
description = (
    'An end-to-end pipeline for AnimeGanV2 using the Towhee framework. '
    'Take a look at `app.py` to see how few steps it takes and try it for '
    'yourself using `pip install towhee`.\n'
    'Quick Note: First run after reboot may be slow due to caching pipeline operators.'
)
# ``text-align`` has no effect on an inline <a>, so the centering style lives
# on a wrapping block element. ``rel`` guards the new tab opened by
# ``target="_blank"`` from accessing this page via ``window.opener``.
article = (
    '<p style="text-align:center">'
    '<a href="https://github.com/towhee-io/towhee" target="_blank" '
    'rel="noopener noreferrer">Check out the Towhee GitHub</a>'
    '</p>'
)


# Bounding box (width, height) that input images are shrunk to fit inside,
# preserving aspect ratio, before they are stylized.
size = (512, 512)

# Configure the caching location. The config object is otherwise never passed
# anywhere, so it is handed to FileManager before any pipeline is loaded.
# NOTE(review): assumes FileManager(fmc) is how towhee registers a custom
# cache config (FileManager was imported but unused) — confirm against the
# installed towhee version.
fmc = FileManagerConfig()
fmc.update_default_cache('./')
FileManager(fmc)

# Hub repository providing the AnimeGanV2 pipeline; the tags below match the
# style choices offered in the UI.
ANIMEGAN_REPO = 'filip-halt/style-transfer-animegan'

# All pipelines loaded in at start. These pipelines all share operators for reduced memory overhead.
celeba = pipeline(ANIMEGAN_REPO, tag='celeba')
facepaintv1 = pipeline(ANIMEGAN_REPO, tag='facepaintv1')
facepaintv2 = pipeline(ANIMEGAN_REPO, tag='facepaintv2')
hayao = pipeline(ANIMEGAN_REPO, tag='hayao')
paprika = pipeline(ANIMEGAN_REPO, tag='paprika')
shinkai = pipeline(ANIMEGAN_REPO, tag='shinkai')

# Maps each style name offered in the UI to its preloaded pipeline.
STYLE_PIPELINES = {
    'celeba': celeba,
    'facepaintv1': facepaintv1,
    'facepaintv2': facepaintv2,
    'hayao': hayao,
    'paprika': paprika,
    'shinkai': shinkai,
}


def operation(Input, Version):
    """Stylize ``Input`` with the AnimeGanV2 pipeline selected by ``Version``.

    Args:
        Input: PIL image received from the gradio image input.
        Version: Style name; one of the keys of ``STYLE_PIPELINES``.

    Returns:
        The stylized image as a PIL ``Image``.

    Raises:
        ValueError: If ``Version`` is not a known style (for instance ``None``
            when no style was chosen).
    """
    # Fail fast with a clear message instead of an UnboundLocalError later.
    try:
        style_pipeline = STYLE_PIPELINES[Version]
    except KeyError:
        raise ValueError(
            f'Unknown style {Version!r}; expected one of: {", ".join(STYLE_PIPELINES)}'
        ) from None

    # Work on an RGB copy: JPEG cannot store alpha/palette modes, and copying
    # leaves the caller's image untouched. thumbnail() resizes in place while
    # keeping the aspect ratio. LANCZOS is the filter formerly aliased as
    # ANTIALIAS, which Pillow 10 removed.
    image = Input.convert('RGB')
    image.thumbnail(size, Image.LANCZOS)

    # The pipeline takes a file path. A per-call temporary directory keeps
    # concurrent requests from overwriting each other's input and leaves no
    # stray file in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        input_path = os.path.join(tmp_dir, 'input.jpg')
        image.save(input_path)
        result = style_pipeline(input_path)

    # Converting from channel-first, [0,1] value RGB, numpy array to PIL image.
    hwc = numpy.transpose(result[0][0], (1, 2, 0))
    # Clip first: values outside [0, 1] do not fit in uint8 after scaling and
    # the cast would wrap them around into wrong colours.
    pixels = (numpy.clip(hwc, 0.0, 1.0) * 255).astype(numpy.uint8)
    return Image.fromarray(pixels)

# Build the web UI and serve it; guarded so importing this module has no
# side effects. The queue serializes requests to ``operation``.
if __name__ == '__main__':
    gradio.Interface(
        operation,
        [
            gradio.inputs.Image(type="pil"),
            # Preselect a style so ``operation`` never receives an empty choice.
            gradio.inputs.Radio(
                ["celeba", "facepaintv1", "facepaintv2", "hayao", "paprika", 'shinkai'],
                default="celeba",
            ),
        ],
        gradio.outputs.Image(type="pil"),
        allow_flagging=False,
        allow_screenshot=False,
        title=title,
        article=article,
        description=description,
    ).launch(enable_queue=True)