Embrw's picture
Update app.py
50baa1e verified
raw
history blame
3.7 kB
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import gradio as gr
# Dùng CPU thay vì GPU
device = torch.device("cpu")
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
"""
Xây dựng chuỗi các phép biến đổi cho ảnh:
- Đảm bảo ảnh ở chế độ RGB.
- Resize ảnh về kích thước (input_size x input_size) với nội suy BICUBIC.
- Chuyển ảnh sang tensor và chuẩn hóa theo chuẩn ImageNet.
"""
return T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])
def load_image(image, input_size=448):
"""
Xử lý đầu vào ảnh:
- Nếu image là None, báo lỗi.
- Nếu image không phải là đối tượng PIL, thử mở ảnh từ file path.
- Áp dụng các phép biến đổi và thêm batch dimension.
"""
if image is None:
raise ValueError("Vui lòng tải lên một hình ảnh hợp lệ.")
if not isinstance(image, Image.Image):
try:
image = Image.open(image)
except Exception as e:
raise ValueError(f"Lỗi khi mở ảnh từ file: {e}")
try:
transform = build_transform(input_size)
pixel_values = transform(image).unsqueeze(0) # Thêm chiều batch
except Exception as e:
raise ValueError(f"Lỗi khi biến đổi ảnh: {e}")
return pixel_values
# Tải mô hình trên CPU
model = AutoModel.from_pretrained(
"5CD-AI/Vintern-1B-v3_5",
torch_dtype=torch.float32, # Dùng float32 cho CPU
low_cpu_mem_usage=True,
trust_remote_code=True,
).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(
"5CD-AI/Vintern-1B-v3_5",
trust_remote_code=True,
use_fast=False
)
def process_image(image, user_request):
"""
1. Xử lý ảnh đầu vào và chuyển thành tensor.
2. Nếu không có yêu cầu từ người dùng, sử dụng mặc định.
3. Gọi mô hình với phương thức chat và trả về kết quả.
"""
try:
pixel_values = load_image(image).to(device)
except ValueError as e:
return str(e)
generation_config = {
"max_new_tokens": 256, # Giới hạn số token tạo mới
"do_sample": False,
"num_beams": 3,
"repetition_penalty": 2.0
}
if not user_request.strip():
user_request = "Trích xuất toàn bộ thông tin trong ảnh và trả về dạng Markdown."
question = f"<image>\n{user_request}"
with torch.inference_mode():
try:
response, _ = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
except Exception as e:
return f"Lỗi khi xử lý ảnh: {e}"
return response
# Đặt tiêu đề mà không căn giữa
title = "Vietnamese Hand Writing OCR"
iface = gr.Interface(
fn=process_image,
inputs=[
gr.Image(type="pil", label="Chọn ảnh"),
gr.Textbox(lines=2, placeholder="Nhập yêu cầu của bạn, ví dụ: 'Nhận dạng chữ viết tay và trả về dạng văn bản'", label="Yêu cầu")
],
outputs=gr.Textbox(label="Kết quả"),
title=title
)
if __name__ == "__main__":
iface.launch("=True")