import os
from typing import Union

from fastapi import FastAPI
from pydantic import BaseModel
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint, HuggingFacePipeline

# HuggingFacePipeline and CustomLLM are only needed by the commented-out
# alternatives further down; they can be dropped if those stay unused.
from custom_llm import CustomLLM


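# Request bodies accepted by the endpoints below.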
class ConversationPost(BaseModel):
    tenant: Union[str, None] = None
    module: Union[str, None] = None
    question: str

class InferencePost(BaseModel):
    question: str
    with_template: Union[str, None] = None

class LLMPost(BaseModel):
    model: str
    question: str

# Read the Hugging Face token from HF_API_KEY and re-export it under the
# variable name that langchain_huggingface looks up by default.
API_TOKEN = os.environ["HF_API_KEY"]
os.environ["HUGGINGFACEHUB_API_TOKEN"] = API_TOKEN

app = FastAPI()
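
# ChatML-style prompt for the Qwen chat models; the system message is in
# Indonesian and instructs the model to answer only in Bahasa Indonesia.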
prompt_qwen = PromptTemplate.from_template("""<|im_start|>system
Kamu adalah Asisten AI yang dikembangkan oleh Jonthan Jordan. Answer strictly in Bahasa Indonesia<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
""")

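# Llama 3 chat prompt using its header/eot special tokens, with the same
# Indonesian system instruction.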
prompt_llama = PromptTemplate.from_template("""<|start_header_id|>system<|end_header_id|>

Kamu adalah Asisten AI yang dikembangkan oleh Jonthan Jordan. Answer strictly in Bahasa Indonesia<|eot_id|><|start_header_id|>user<|end_header_id|>

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
""")
# Alternative: run Qwen locally through a transformers pipeline instead of the
# hosted Inference API.
# llm = prompt_qwen | HuggingFacePipeline.from_model_id(
#     model_id="Qwen/Qwen2-1.5B-Instruct",
#     task="text-generation",
#     pipeline_kwargs={
#         "max_new_tokens": 150,
#         "return_full_text": False,
#     },
# )

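# Serverless Hugging Face Inference API endpoints. do_sample=False keeps
# decoding greedy, and max_new_tokens caps the length of each completion.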
llama = HuggingFaceEndpoint(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
    task="text-generation",
    max_new_tokens=4096,
    do_sample=False,
)

qwen = HuggingFaceEndpoint(
    repo_id="Qwen/Qwen1.5-4B-Chat",
    task="text-generation",
    max_new_tokens=4096,
    do_sample=False,
)

qwen2 = HuggingFaceEndpoint(
    repo_id="Qwen/Qwen2-1.5B-Instruct",
    task="text-generation",
    max_new_tokens=4096,
    do_sample=False,
)

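# Prompt | model chains: each pairs a chat template with a matching endpoint.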
llm = prompt_qwen | qwen

llm2 = prompt_llama | llama

llm3 = prompt_qwen | qwen2
# Alternative: route through the CustomLLM wrapper instead.
# llm = prompt_qwen | CustomLLM(repo_id="Qwen/Qwen-VL-Chat", model_type='text-generation', api_token=API_TOKEN, max_new_tokens=150).bind(stop=['<|im_end|>'])


@app.get("/")
def greet_json():
    return {"Hello": "World!"}



@app.post("/chat")
async def chat(data: LLMPost):
    if data.model == 'llama':
        return {"data":llama.invoke(data.question)}
    elif data.model == 'qwen':
        return {"data":qwen.invoke(data.question)}
    else:
        return {"data":qwen2.invoke(data.question)}




@app.post("/conversation")
async def conversation(data : ConversationPost):
    return {"output":llm.invoke({"question":data.question})}


@app.post("/conversation2")
async def conversation2(data : ConversationPost):
    return {"output":llm2.invoke({"question":data.question})}

@app.post("/conversation3")
async def conversation3(data : ConversationPost):
    return {"output":llm3.invoke({"question":data.question})}


@app.post("/inference")
async def inference(data : InferencePost):
    if data.with_template == 'llama':
        out = llm2.invoke(data.question)
    elif data.with_template == 'qwen':
        out = llm.invoke(data.question)
    elif data.with_template == 'qwen2':
        out = llm3.invoke(data.question)
    else:
        out = llama.invoke(data.question)
        
    return {"output":out}