# from transformers import pipeline
# import gradio as gr
# import spaces
# Initialize the text generation pipeline
# pipe = pipeline("text-generation", model="SakanaAI/EvoLLM-JP-v1-7B")


# Define a function to generate text based on user input
# @spaces.GPU
# def generate_text(prompt):
#    result = pipe(prompt, max_length=50, num_return_sequences=1)
#    return result[0]['generated_text']

# Create a Gradio interface with batching enabled
# iface = gr.Interface(
#    fn=generate_text, 
#    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."), 
#    outputs=gr.Textbox(label="生成されたテキスト"),
#    title="Text Generation with SakanaAI/EvoLLM-JP-v1-7B",
#    description="Enter a prompt and the model will generate a continuation of the text.",
#    batch=True,
#    max_batch_size=4
# )

# Launch the interface
# if __name__ == "__main__":
#    iface.launch()


import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    GemmaTokenizerFast,
    TextIteratorStreamer,
    pipeline, AutoTokenizer
)


# Japanese text-generation model to use. This id is passed to
# AutoTokenizer.from_pretrained and transformers.pipeline elsewhere in this file.
model_name = "SakanaAI/EvoLLM-JP-v1-7B"
from spaces import GPU

@GPU
def generate():
    """Load a tokenizer and a text-generation pipeline for ``model_name``.

    Both objects are bound only to local names, so they are discarded when
    the function returns; the function itself returns ``None``. Nothing in
    the visible source calls it.

    NOTE(review): the function is wrapped with ``spaces.GPU`` (presumably to
    request GPU hardware for the call — confirm), yet the pipeline below is
    created with ``device=-1``, which keeps it on the CPU.
    """
    # Set up the tokenizer and pipeline. device=-1 selects the CPU; a
    # non-negative index such as 0 would select a CUDA device instead.
    _pipe = pipeline(
        'text-generation',
        model=model_name,
        tokenizer=AutoTokenizer.from_pretrained(model_name),
        device=-1,
    )

# Text-generation pipeline shared by every request; built on first use by
# _get_generator() so importing this module does not load the model.
_generator = None


def _get_generator():
    """Return the shared text-generation pipeline, building it on first call.

    The original code referenced a global ``generator`` that was never
    defined (it only existed as a local inside ``generate()``), so every
    request raised ``NameError``. The pipeline is now created once here and
    reused by later requests instead of being reloaded each time.
    """
    global _generator
    if _generator is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # device=-1 runs the pipeline on the CPU (same setting as before).
        _generator = pipeline(
            'text-generation', model=model_name, tokenizer=tokenizer, device=-1
        )
    return _generator


def generate_text(prompt, max_length):
    """Continue *prompt* with the model and return the generated text.

    Args:
        prompt: Text to continue.
        max_length: Forwarded to the pipeline as ``max_length`` (in
            transformers this bounds prompt plus generated tokens). The
            Gradio slider may deliver a float depending on the Gradio
            version, so it is converted to ``int`` first.

    Returns:
        The ``generated_text`` string of the single returned sequence.
    """
    generator = _get_generator()
    result = generator(prompt, max_length=int(max_length), num_return_sequences=1)
    return result[0]['generated_text']

# Gradio UI: a prompt textbox ("Prompt"; placeholder "Please enter a
# Japanese prompt here") and a max-length slider ("Max length"), wired to
# generate_text; the result goes to a textbox labelled "Generated text".
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="プロンプト", placeholder="ここに日本語のプロンプトを入力してください"),
        gr.Slider(minimum=10, maximum=200, value=50, step=1, label="最大長"),
    ],
    outputs=gr.Textbox(label="生成されたテキスト"),
)

# Start the server only when executed as a script. launch() blocks the main
# thread, so running it at import time would hang any module importing this one.
if __name__ == "__main__":
    iface.launch()