# purchasing_api / app.py
from fastapi import FastAPI
import os
from typing import Union
from custom_llm import CustomLLM  # only used by the commented-out CustomLLM chain below
from pydantic import BaseModel
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEndpoint
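
# Request body schemas for the API endpoints below.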
class ConversationPost(BaseModel):
tenant: Union[str, None] = None
module: Union[str, None] = None
question: str
class InferencePost(BaseModel):
question: str
with_template: Union[str, None] = None
class LLMPost(BaseModel):
model: str
question: str
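
# The Hugging Face token is read from the HF_API_KEY environment variable and
# re-exported as HUGGINGFACEHUB_API_TOKEN, the name langchain_huggingface expects.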
API_TOKEN = os.environ['HF_API_KEY']
os.environ["HUGGINGFACEHUB_API_TOKEN"] = API_TOKEN
app = FastAPI()
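
# Single-turn prompt templates. Qwen models use the ChatML format (<|im_start|>/<|im_end|>);
# Llama 3 uses its <|start_header_id|>/<|eot_id|> chat tokens. The Indonesian system prompt
# says: "You are an AI assistant developed by Jonathan Jordan. Answer strictly in Bahasa Indonesia."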
prompt_qwen = PromptTemplate.from_template("""<|im_start|>system
Kamu adalah Asisten AI yang dikembangkan oleh Jonathan Jordan. Answer strictly in Bahasa Indonesia<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
""")
prompt_llama = PromptTemplate.from_template("""<|start_header_id|>system<|end_header_id|>
Kamu adalah Asisten AI yang dikembangkan oleh Jonathan Jordan. Answer strictly in Bahasa Indonesia<|eot_id|><|start_header_id|>user<|end_header_id|>
{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
""")
# llm = prompt | HuggingFacePipeline.from_model_id(
# model_id="Qwen/Qwen2-1.5B-Instruct",
# task="text-generation",
# pipeline_kwargs={
# "max_new_tokens": 150,
# "return_full_text":False
# },
# )
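
# Remote text-generation endpoints on the Hugging Face serverless Inference API
# (authenticated via HUGGINGFACEHUB_API_TOKEN). do_sample=False gives deterministic,
# greedy decoding.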
llama = HuggingFaceEndpoint(
repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
task="text-generation",
max_new_tokens=4096,
do_sample=False,
)
qwen = HuggingFaceEndpoint(
repo_id="Qwen/Qwen1.5-4B-Chat",
task="text-generation",
max_new_tokens=4096,
do_sample=False,
)
qwen2 = HuggingFaceEndpoint(
repo_id="Qwen/Qwen2-1.5B-Instruct",
task="text-generation",
max_new_tokens=4096,
do_sample=False,
)
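
# LCEL chains: each prompt template is piped into an endpoint, so invoking a chain
# with {"question": ...} renders the model-specific chat template before generation.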
llm = prompt_qwen | qwen
llm2 = prompt_llama | llama
llm3 = prompt_qwen | qwen2
# llm = prompt | CustomLLM(repo_id="Qwen/Qwen-VL-Chat", model_type='text-generation', api_token=API_TOKEN, max_new_tokens=150).bind(stop=['<|im_end|>'])
@app.get("/")
def greet_json():
return {"Hello": "World!"}
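
# POST /chat: run the raw endpoint selected by `model` ("llama", "qwen", anything else
# falls back to qwen2) on the question as-is, without applying a prompt template.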
@app.post("/chat")
async def chat(data: LLMPost):
if data.model == 'llama':
return {"data":llama.invoke(data.question)}
elif data.model == 'qwen':
return {"data":qwen.invoke(data.question)}
else:
return {"data":qwen2.invoke(data.question)}
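
# POST /conversation, /conversation2, /conversation3: run the templated chains
# (Qwen 1.5, Llama 3, and Qwen 2 respectively). The `tenant` and `module` fields
# are accepted in the request body but not used yet.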
@app.post("/conversation")
async def conversation(data : ConversationPost):
return {"output":llm.invoke({"question":data.question})}
@app.post("/conversation2")
async def conversation2(data : ConversationPost):
return {"output":llm2.invoke({"question":data.question})}
@app.post("/conversation3")
async def conversation3(data : ConversationPost):
return {"output":llm3.invoke({"question":data.question})}
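
# POST /inference: choose a templated chain via `with_template` ("llama", "qwen",
# "qwen2"); any other value (or none) sends the question to the raw Llama endpoint.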
@app.post("/inference")
async def inference(data : InferencePost):
if data.with_template == 'llama':
out = llm2.invoke(data.question)
elif data.with_template == 'qwen':
out = llm.invoke(data.question)
elif data.with_template == 'qwen2':
out = llm3.invoke(data.question)
else:
out = llama.invoke(data.question)
return {"output":out}
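
# Example request (a sketch, assuming the app is served locally with e.g. `uvicorn app:app --port 8000`):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"model": "qwen", "question": "Apa itu FastAPI?"}'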