Spaces:
Running
Running
import gradio as gr | |
from transformers import AutoTokenizer, AutoProcessor, VisionEncoderDecoderModel, TrOCRProcessor | |
from vllm import LLM, SamplingParams | |
from PIL import Image | |
# --- Text-generation model -------------------------------------------------
# Hugging Face checkpoint shared by the tokenizer and the vLLM engine.
model_name = "facebook/opt-125m"
# NOTE(review): `tokenizer` is not referenced anywhere else in this file;
# vLLM loads its own tokenizer. Kept so module-level behavior is unchanged.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# vLLM engine on CPU with no tensor parallelism (tensor_parallel_size=1).
# NOTE(review): whether `LLM(...)` accepts a `device` keyword depends on the
# installed vLLM release — confirm against the pinned version.
vllm_model = LLM(
    model=model_name,
    device="cpu",
    tensor_parallel_size=1,
)

# --- OCR model -------------------------------------------------------------
# TrOCR checkpoint for handwritten text: an encoder-decoder model plus the
# processor that turns images into pixel tensors and decodes output token ids.
ocr_model_name = "microsoft/trocr-small-handwritten"
ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name)
ocr_processor = TrOCRProcessor.from_pretrained(ocr_model_name)
def generate_response(prompt, max_tokens, temperature, top_p):
    """Generate a completion for ``prompt`` with the vLLM-backed model.

    Args:
        prompt: Raw prompt text from the UI textbox, passed to vLLM as-is.
        max_tokens: Upper bound on the number of generated tokens. Gradio
            sliders may deliver numeric values as floats, so this is rounded
            and coerced to ``int`` before being handed to vLLM.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The text of the first completion generated for ``prompt``.
    """
    sampling_params = SamplingParams(
        # vLLM's max_tokens is a token count; never pass a float through.
        max_tokens=int(round(max_tokens)),
        temperature=float(temperature),
        top_p=float(top_p),
    )
    # generate() returns one result per prompt; with a single string prompt,
    # index 0 is this prompt's result and .outputs[0] its first completion.
    results = vllm_model.generate(prompt, sampling_params)
    return results[0].outputs[0].text
def ocr_image(image_path):
    """Extract handwritten text from an image file with the TrOCR model.

    Args:
        image_path: Filesystem path of the uploaded image, as delivered by
            the Gradio ``Image`` component configured with
            ``type="filepath"``; ``None`` when nothing has been uploaded.

    Returns:
        The decoded text for the image.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an internal traceback.
    """
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    # Close the file handle once convert() has loaded the pixel data into a
    # new in-memory image.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # Preprocess into the pixel tensor the OCR model expects.
    pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
    generated_ids = ocr_model.generate(pixel_values)
    # batch_decode returns one string per sequence; only one image was passed.
    return ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Gradio UI: one tab drives vLLM text generation, the other drives TrOCR OCR.
with gr.Blocks() as demo:
    # NOTE(review): the "π" in this heading looks like a mis-decoded emoji
    # from the original source; kept verbatim — confirm the intended glyph.
    gr.Markdown("# π Hugging Face Integration with vLLM and OCR (CPU)")
    gr.Markdown("Upload an image to extract text using OCR or generate text using the vLLM integration.")

    with gr.Tab("Text Generation"):
        with gr.Row():
            # Left column: prompt plus sampling controls.
            with gr.Column():
                prompt_input = gr.Textbox(
                    lines=3,
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                )
                max_tokens = gr.Slider(
                    minimum=10, maximum=500, step=10, value=100,
                    label="Max Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.0, step=0.1, value=0.7,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, step=0.1, value=0.9,
                    label="Top P",
                )
                submit_button = gr.Button("Generate")
            # Right column: read-only generation result.
            with gr.Column():
                output_text = gr.Textbox(
                    interactive=False,
                    lines=10,
                    label="Generated Text",
                )
        # Inputs are listed in generate_response's positional parameter order.
        submit_button.click(
            fn=generate_response,
            inputs=[prompt_input, max_tokens, temperature, top_p],
            outputs=output_text,
        )

    with gr.Tab("OCR"):
        with gr.Row():
            # Left column: image upload and trigger.
            with gr.Column():
                # "filepath" makes Gradio hand ocr_image a path string,
                # which ocr_image then opens itself.
                image_input = gr.Image(
                    image_mode="RGB",
                    type="filepath",
                    label="Upload Image",
                )
                ocr_submit_button = gr.Button("Extract Text")
            # Right column: read-only OCR result.
            with gr.Column():
                ocr_output = gr.Textbox(
                    interactive=False,
                    lines=10,
                    label="Extracted Text",
                )
        ocr_submit_button.click(
            fn=ocr_image,
            inputs=[image_input],
            outputs=ocr_output,
        )

# Start the Gradio server.
demo.launch()