File size: 1,828 Bytes
f6086aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
from PIL import Image
from io import BytesIO
import base64
import re
import argparse
import torch
from transformers import AutoTokenizer
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize
import hashlib
from .visualglm import VisualGLMModel

def get_infer_setting(gpu_device=0, quant=None):
    """Load the VisualGLM model and ChatGLM tokenizer for inference.

    Args:
        gpu_device: Value written (as a string) to the
            ``CUDA_VISIBLE_DEVICES`` environment variable before loading.
        quant: ``None`` to skip quantization, or ``4`` / ``8`` to run
            ``quantize`` on ``model.transformer`` with that value.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has the cached
        auto-regressive mixin attached, is in eval mode, and has been
        moved to CUDA.

    Raises:
        ValueError: If ``quant`` is not ``None``, ``4`` or ``8``.
    """
    # Validate before touching the environment or loading weights so a bad
    # argument fails immediately, and use a real exception rather than an
    # ``assert`` (asserts are stripped when Python runs with -O).
    if quant not in (None, 4, 8):
        raise ValueError(f"quant must be None, 4 or 8, got {quant!r}")

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_device)
    args = argparse.Namespace(
        fp16=True,
        skip_init=True,
        # When quantizing, the requested device is 'cpu'; the model is moved
        # to CUDA only after quantize() below has run.
        device='cuda' if quant is None else 'cpu',
    )
    model, args = VisualGLMModel.from_pretrained('visualglm-6b', args)
    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
    if quant is not None:
        quantize(model.transformer, quant)
    model.eval()
    model = model.cuda()
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    return model, tokenizer

def is_chinese(text):
    """Search ``text`` for a run of characters in the U+4E00..U+9FA5 range.

    Returns:
        The ``re.Match`` for the first such run, or ``None`` when ``text``
        contains no character in that range. The result is therefore truthy
        exactly when at least one such character is present.
    """
    return re.search(u'[\u4e00-\u9fa5]+', text)

def generate_input(input_text, input_image_prompt, history=None, input_para=None, image_is_encoded=True):
    """Bundle a query, its image and generation settings into one dict.

    Args:
        input_text: The query text, stored under ``'input_query'``.
        input_image_prompt: Either a base64-encoded image (when
            ``image_is_encoded`` is true) or a value that is passed through
            unchanged as the image.
        history: Conversation history stored under ``'history'``. When
            omitted (or ``None``) a fresh empty list is used for each call.
        input_para: Stored unchanged under ``'gen_kwargs'``.
        image_is_encoded: If true, ``input_image_prompt`` is base64-decoded
            and opened with PIL; otherwise it is used as-is.

    Returns:
        A dict with the keys ``'input_query'``, ``'input_image'``,
        ``'history'`` and ``'gen_kwargs'``.
    """
    # A literal ``[]`` default would be one list shared by every call; since
    # the list is handed back inside the result, any caller appending to it
    # would leak history into later, unrelated calls.
    if history is None:
        history = []

    if not image_is_encoded:
        image = input_image_prompt
    else:
        decoded_image = base64.b64decode(input_image_prompt)
        image = Image.open(BytesIO(decoded_image))

    input_data = {'input_query': input_text, 'input_image': image, 'history': history, 'gen_kwargs': input_para}
    return input_data


def process_image(image_encoded):
    """Decode a base64 image and store it as a PNG under ``./examples``.

    The file name is the SHA-256 hex digest of the decoded image's raw
    bytes (``image.tobytes()``), so an image whose file already exists is
    not written again.

    Args:
        image_encoded: Base64-encoded image data.

    Returns:
        The absolute path of the PNG file for this image.
    """
    decoded_image = base64.b64decode(image_encoded)
    image = Image.open(BytesIO(decoded_image))
    image_hash = hashlib.sha256(image.tobytes()).hexdigest()
    image_path = f'./examples/{image_hash}.png'
    if not os.path.isfile(image_path):
        # The target directory is relative to the current working directory
        # and may not exist yet; saving would otherwise fail with
        # FileNotFoundError.
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        image.save(image_path)
    return os.path.abspath(image_path)