import warnings

from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer

from vouchervision.utils_LLM import SystemLoadMonitor

warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")


class FlorenceOCR:
    def __init__(self, logger, model_id='microsoft/Florence-2-large'):
        self.MAX_TOKENS = 1024
        self.logger = logger
        self.model_id = model_id
        self.monitor = SystemLoadMonitor(logger)

        # Florence-2 vision-language model used for the OCR pass
        self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

        # Mistral instruct model used to re-insert spaces into the raw OCR output
        # self.model_id_clean = "mistralai/Mistral-7B-v0.3"
        self.model_id_clean = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
        self.tokenizer_clean = AutoTokenizer.from_pretrained(self.model_id_clean)
        self.model_clean = AutoModelForCausalLM.from_pretrained(self.model_id_clean)
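        # Note: the bnb-4bit checkpoint above requires the bitsandbytes package. One possible
        # alternative (an assumption, not part of the original code) is to let Accelerate place
        # the quantized weights automatically, e.g.:
        #   self.model_clean = AutoModelForCausalLM.from_pretrained(self.model_id_clean, device_map="auto")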

    def ocr_florence(self, image, task_prompt='<OCR>', text_input=None):
        self.monitor.start_monitoring_usage()

        # Open the image if a path was provided
        if isinstance(image, str):
            image = Image.open(image)

        if text_input is None:
            prompt = task_prompt
        else:
            prompt = task_prompt + text_input

        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        # Move input_ids and pixel_values to the same device as the model
        inputs = {key: value.to(self.model.device) for key, value in inputs.items()}

        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=self.MAX_TOKENS,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Florence-2 post-processing expects the raw decoded text, including special tokens
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer_dirty = self.processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height)
        )
        # post_process_generation returns a dict keyed by the task prompt, e.g. {'<OCR>': '...'}.
        # Pass the raw OCR string to the Mistral model to re-insert missing spaces.
        inputs = self.tokenizer_clean(
            f"Insert spaces into this text to make all the words valid. This text contains scientific names of plants, locations, habitat, and coordinates: {parsed_answer_dirty[task_prompt]}",
            return_tensors="pt")
        inputs = {key: value.to(self.model_clean.device) for key, value in inputs.items()}
        outputs = self.model_clean.generate(**inputs, max_new_tokens=self.MAX_TOKENS)
        parsed_answer = self.tokenizer_clean.decode(outputs[0], skip_special_tokens=True)
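        # Sketch of an alternative cleaning call (an assumption, not the original code):
        # instruct-tuned Mistral checkpoints generally respond better when the request is
        # wrapped in their chat template rather than passed as plain text, e.g.:
        #   messages = [{"role": "user", "content": f"Insert spaces ... : {parsed_answer_dirty[task_prompt]}"}]
        #   input_ids = self.tokenizer_clean.apply_chat_template(
        #       messages, add_generation_prompt=True, return_tensors="pt").to(self.model_clean.device)
        #   outputs = self.model_clean.generate(input_ids, max_new_tokens=self.MAX_TOKENS)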
        print(parsed_answer_dirty)
        print(parsed_answer)

        self.monitor.stop_inference_timer()  # Also starts the tool timer
        usage_report = self.monitor.stop_monitoring_report_usage()

        return parsed_answer, parsed_answer_dirty[task_prompt], parsed_answer_dirty, usage_report


def main():
    img_path = '/home/brlab/Downloads/gem_2024_06_26__02-26-02/Cropped_Images/By_Class/label/1.jpg'
    # img_path = 'D:/D_Desktop/BR_1839468565_Ochnaceae_Campylospermum_reticulatum_label.jpg'
    image = Image.open(img_path)

    ocr = FlorenceOCR(logger=None)
    # ocr_florence returns (cleaned text, raw OCR text, raw task dict, usage report)
    results_text, results_dirty, results_dict, usage_report = ocr.ocr_florence(image, task_prompt='<OCR>', text_input=None)
    print(results_text)


if __name__ == '__main__':
    main()