ae / app.py
sureshnam9's picture
Update app.py
dc255e5 verified
raw
history blame
1.53 kB
import gradio as gr
import os
import argparse
import concurrent.futures
import json
import logging
import math
import time
from itertools import cycle
from pathlib import Path
import requests
import json
import torch
import gradio as gr
#from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
#url = os.environ["TGI_GAUDI_ENDPOINT_URL"]
#myport = os.environ["myport"]
# Host and port of the text-generation server this app talks to.
URL = "198.175.88.52"
#URL = "100.81.119.213"
myport = "8080"
# Full address of the server's /generate route, assembled from host and port.
gaudi_device_url = "http://{}:{}/generate".format(URL, myport)
# This assumes TGI is already running on Gaudi, so no pipeline needs to be defined here; the app just sends an HTTP request, much like a curl command.
def text_gen(url, prompt):
    """Send ``prompt`` to a TGI ``/generate`` endpoint and return the generated text.

    Args:
        url: Full URL of the endpoint's ``/generate`` route.
        prompt: Text the model should continue.

    Returns:
        The ``generated_text`` field of the endpoint's JSON reply.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    # ``requests.post`` accepts no ``prompt`` keyword (it raised TypeError);
    # the payload goes in a JSON body whose ``inputs`` field holds the prompt.
    resp = requests.post(url, json={"inputs": prompt}, timeout=60)
    # Surface HTTP errors instead of handing an error page to the UI.
    resp.raise_for_status()
    # Return plain text so the Gradio "text" output can display it directly.
    return resp.json()["generated_text"]
def text_gen_cpu(prompt):
    """Run ``prompt`` through a locally built GPT-2 text-generation pipeline on CPU.

    A new pipeline is constructed on every call with ``model="gpt2"``,
    ``tokenizer="gpt2"``, ``device="cpu"`` and ``torch.bfloat16`` weights.

    NOTE(review): ``pipeline`` is not imported in the visible part of this
    file (the ``transformers`` import is commented out) -- confirm it is in
    scope before calling this function.

    Args:
        prompt: Text to be continued.

    Returns:
        Whatever the pipeline call yields for one requested sequence with
        ``max_length=100``.
    """
    generator = pipeline(
        task="text-generation",
        model="gpt2",
        tokenizer="gpt2",
        device="cpu",
        torch_dtype=torch.bfloat16,
    )
    return generator(prompt, max_length=100, num_return_sequences=1)
def _generate_on_gaudi(prompt):
    """Forward the user's prompt to ``text_gen`` using the configured endpoint URL."""
    return text_gen(gaudi_device_url, prompt)


# Entries of ``inputs`` must be Gradio components or their shortcut names.
# The endpoint URL is neither, so it is bound in the wrapper above and only
# the prompt textbox is exposed to the user.
demo = gr.Interface(
    fn=_generate_on_gaudi,
    inputs=["text"],
    outputs=["text"],
)
demo.launch()
#url = gr.Textbox(label='url', value=URL, visible=False)
# This is some demo code for using the HuggingFaceEndpoint client instead (kept for reference).
#llm = HuggingFaceEndpoint(
# endpoint_url=url,
# max_new_tokens=1024,
# top_k=10,
# top_p=0.95,
# typical_p=0.95,
# temperature=0.01,
# repetition_penalty=1.03,
# streaming=True,
# )
#result = llm.invoke("Why is the sky blue?")
#print(result)