jacobmp commited on
Commit
4130caf
·
verified ·
1 Parent(s): c9bae28

Fix: decode the full batch instead of keeping only generated_text[0]

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -7,7 +7,7 @@ from transformers import pipeline
7
  from ultralytics import YOLO
8
  from PIL import Image
9
 
10
- def process(path, progress = gr.Progress()):
11
  progress(0, desc="Starting")
12
  LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
13
  OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
@@ -16,6 +16,7 @@ def process(path, progress = gr.Progress()):
16
  # Load the model and processor
17
  processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
18
  model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
 
19
 
20
  # Open an image of handwritten text
21
  image = Image.open(path).convert("RGB")
@@ -44,9 +45,12 @@ def process(path, progress = gr.Progress()):
44
 
45
  #Predict and decode the entire batch
46
  progress(0, desc="Recognizing..")
 
 
47
  generated_ids = model.generate(torch.cat(batch))
48
  progress(0, desc="Decoding (token -> str)")
49
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
50
  print(generated_text)
51
  full_text = " ".join(generated_text)
52
  print(full_text)
 
7
  from ultralytics import YOLO
8
  from PIL import Image
9
 
10
+ def process(path, progress = gr.Progress(), device = 'cpu'):
11
  progress(0, desc="Starting")
12
  LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
13
  OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
 
16
  # Load the model and processor
17
  processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
18
  model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
19
+ model.to(device)
20
 
21
  # Open an image of handwritten text
22
  image = Image.open(path).convert("RGB")
 
45
 
46
  #Predict and decode the entire batch
47
  progress(0, desc="Recognizing..")
48
+ batch = torch.cat(batch).to(device)
49
+ print("batch.shape", batch.shape)
50
  generated_ids = model.generate(batch)
51
  progress(0, desc="Decoding (token -> str)")
52
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
53
+
54
  print(generated_text)
55
  full_text = " ".join(generated_text)
56
  print(full_text)