# MapLocNet / app.py
# Author: wangerniu
# Commit 4291dca ("Update app.py"), 1.5 kB
import gradio as gr
import cv2
import gradio as gr
import torch
from torchvision import transforms
import requests
from PIL import Image
# Pretrained ImageNet ResNet-18 fetched via torch.hub; .eval() puts it in
# inference mode (e.g. fixed batch-norm statistics).
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
# Page title.
# NOTE(review): the title ("Extractive QA") and the description below describe
# a question-answering demo, not the image classifier this app runs -- confirm
# the intended text.
title = "抽取式问答"
# Description shown under the title; Markdown is supported.
description = "输入上下文与问题后,点击submit按钮,可从上下文中抽取出答案,赶快试试吧!"
# Human-readable ImageNet labels, one per line, read from a local file
# (the same list is published at http://git.io/JJkYN).
# The context manager guarantees the file is closed, and splitlines() drops the
# trailing newline that readlines() would otherwise leave on every label.
with open('label.txt', 'r', encoding='utf-8') as label_file:
    labels = label_file.read().splitlines()
def to_black(inp, long, lat, Area):
    """Classify an image with the module-level ResNet-18 and return label confidences.

    Despite its name, this function does not recolor the image; it runs
    ImageNet classification on ``inp`` using the module-level ``model`` and
    ``labels``.

    Args:
        inp: Image array from the Gradio "image" input (presumably an
            H x W x C NumPy array -- TODO confirm for non-default image modes).
        long: Longitude entered in the UI. Currently unused.
        lat: Latitude entered in the UI. Currently unused.
        Area: Area slider value from the UI. Currently unused.

    Returns:
        dict mapping each class label (surrounding whitespace stripped) to its
        softmax probability, the format ``gr.Label`` displays.

    Raises:
        ValueError: If no image was supplied.
    """
    if inp is None:
        # Gradio may pass None when the form is submitted without an image.
        raise ValueError("An input image is required.")
    # convert("RGB") also accepts grayscale/RGBA arrays instead of failing on them.
    image = Image.fromarray(inp.astype('uint8')).convert('RGB')
    # torchvision's pretrained ImageNet weights expect mean/std-normalized input.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    batch = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair each label with its probability rather than hard-coding 1000 classes;
    # strip() guards against newline-terminated labels.
    return {label.strip(): float(prob)
            for label, prob in zip(labels, prediction.tolist())}
# Show the three most probable classes. gr.Label replaces the deprecated
# gr.outputs.Label (the gr.outputs namespace was removed in Gradio 4) and
# matches the gr.Number / gr.Slider components used for the inputs.
outputs = gr.Label(num_top_classes=3)
interface = gr.Interface(
    fn=to_black,
    inputs=[
        "image",
        gr.Number(label="longitude"),
        gr.Number(label="latitude"),
        gr.Slider(256, 512, label='Area'),
    ],
    outputs=outputs,
    title=title,
    description=description,
    # Each example row supplies: image path, longitude, latitude, Area.
    examples=[["gradio/未命名1688700109.png", 70.1, 40.0, 256]],
)
# Start the web server (blocks until the app is shut down).
interface.launch()