File size: 2,239 Bytes
8f243be
 
 
 
d304119
eba494a
8f243be
 
 
43de206
8f243be
43de206
8f243be
 
0b10362
8f243be
 
 
 
 
 
d304119
eba494a
9a0738a
eba494a
d304119
 
 
 
eba494a
9a0738a
eba494a
d304119
 
 
 
eba494a
9a0738a
eba494a
d304119
 
 
 
 
 
 
8f243be
 
d304119
 
8f243be
d304119
8f243be
d304119
8f243be
 
4404cfc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from inference import Inference
import argparse
import gradio as gr
import glob
from huggingface_hub import hf_hub_download
import os

def parse_option(argv=None):
    """Parse the command-line options for the MetaFG Gradio inference demo.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). ``None`` keeps the original
            behaviour; passing a list makes the parser usable from code/tests.

    Returns:
        argparse.Namespace with attributes ``cfg``, ``model_path``,
        ``img_size`` (default 384), ``meta_path`` (default "meta.txt") and
        ``names_path``. ``cfg``, ``model_path`` and ``names_path`` are ``None``
        when not given on the command line.
    """
    # add_help is left at its default (True) so -h/--help actually prints the
    # per-option descriptions below; previously they could never be shown.
    parser = argparse.ArgumentParser('MetaFG Inference script')
    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file')
    # easy config modification
    parser.add_argument('--model-path', type=str, help="path to model data")
    parser.add_argument('--img-size', type=int, default=384,
                        help='side length in pixels of the square input image')
    parser.add_argument('--meta-path', default="meta.txt", type=str, help='path to meta data')
    parser.add_argument('--names-path', type=str, help='path to class names file')
    return parser.parse_args(argv)

# Hugging Face Hub repository holding the default checkpoint, config and class names.
_DEFAULT_REPO_ID = "joshvm/inaturalist_sgd_4k"


def _resolve_asset(local_path, filename):
    """Return *local_path* if given, else download *filename* from the default Hub repo.

    The access token is read from the ``HUGGINGFACE_TOKEN`` environment variable
    when it is set. When it is unset, ``None`` is passed so ``hf_hub_download``
    uses its default authentication (e.g. a cached login) instead of this
    script crashing with ``KeyError``.
    """
    if local_path:
        return local_path
    return hf_hub_download(repo_id=_DEFAULT_REPO_ID,
                           filename=filename,
                           token=os.environ.get("HUGGINGFACE_TOKEN"))


def main():
    """Parse CLI options, load the MetaFG model and serve it through a Gradio UI."""
    args = parse_option()

    # Any asset not supplied on the command line is fetched from the Hub.
    model_path = _resolve_asset(args.model_path, "inat_sgd_6k.pth")
    model_config = _resolve_asset(args.cfg, "MetaFG_2_384_inat.yaml")
    names_path = _resolve_asset(args.names_path, "inat_sgd_names.txt")

    model = Inference(config_path=model_config,
                      model_path=model_path,
                      names_path=names_path)

    def classify(image):
        """Run the model on one uploaded PIL image.

        Whatever ``Inference.infer`` returns is passed straight to the
        ``gr.Label`` output component.
        """
        # Honour --meta-path (default "meta.txt") instead of a hard-coded file.
        return model.infer(img_path=image, meta_data_path=args.meta_path)

    gr.Interface(fn=classify,
                 inputs=gr.Image(shape=(args.img_size, args.img_size), type="pil"),
                 outputs=gr.Label(num_top_classes=10),
                 examples=glob.glob("./example_images/*.jpg")).launch()


if __name__ == '__main__':
    main()