mkoot007 committed
Commit 4ae8bae · 1 Parent(s): 8cb006a

Update app.py

Files changed (1)
  1. app.py +32 -13
app.py CHANGED
@@ -1,34 +1,53 @@
 import streamlit as st
-import io
 from PIL import Image
+import io
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoModelForSeq2SeqLM
 from easyocr import Reader
 
-
+# Load the OCR model and text generation model
 ocr_reader = Reader(['en'])
 text_generator = AutoModelForCausalLM.from_pretrained("gpt2")
 text_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+# Load the image captioning model
+processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+caption_model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/blip-image-captioning-large")
+
+# Define a function to extract text from an image using OCR
 def extract_text(image):
     return ocr_reader.readtext(image)
-def explain_text(text):
-    input_ids = text_tokenizer.encode(text, return_tensors="pt")
-    explanation_ids = text_generator.generate(input_ids, max_length=100, num_return_sequences=1)
-    explanation = text_tokenizer.decode(explanation_ids[0], skip_special_tokens=True)
-    return explanation
 
-st.title("Text Classification Model")
+# Define a function to explain the extracted text using text generation
+def explain_text(text, caption_model, processor):
+    # Extracted text
+    extracted_text = " ".join([res[1] for res in text])
+
+    # Generate an image caption using the image captioning model
+    inputs = processor(extracted_text, return_tensors="pt", padding="max_length", max_length=100, truncation=True)
+    input_ids = inputs["input_ids"]
+    caption = caption_model.generate(input_ids, max_length=50, num_return_sequences=1, no_repeat_ngram_size=2)
+
+    # Decode and return the generated caption
+    generated_caption = processor.decode(caption[0], skip_special_tokens=True)
+    return generated_caption
+
+# Create a Streamlit layout
+st.title("Text Extraction and Explanation")
+
+# Allow users to upload an image
 uploaded_file = st.file_uploader("Upload an image:")
 
+# Extract text from the uploaded image and explain it
 if uploaded_file is not None:
     image = Image.open(uploaded_file)
     ocr_results = extract_text(image)
-    extracted_text = " ".join([res[1] for res in ocr_results])
-    explanation = explain_text(extracted_text)
+    explanation = explain_text(ocr_results, caption_model, processor)
+
     st.markdown("**Extracted text:**")
-    st.markdown(extracted_text)
+    st.markdown(" ".join([res[1] for res in ocr_results]))
 
-    st.markdown("**Explanation:**")
+    st.markdown("**Explanation (Image Caption):**")
     st.markdown(explanation)
 
 else:
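Review note: as committed, the new code path is unlikely to run. Loading "Salesforce/blip-image-captioning-large" through AutoModelForSeq2SeqLM should fail at startup, since BLIP is a vision-language model (exposed in transformers as BlipForConditionalGeneration) rather than a text-to-text model; explain_text then generates from the OCR text's input_ids alone, so the "image caption" never actually sees the image. Separately, easyocr's Reader.readtext expects a file path, URL, bytes, or numpy array, not a PIL Image. Below is a minimal sketch of a working variant under those assumptions; the caption_image helper name and the convert("RGB") step are illustrative, not part of this commit.

import numpy as np
import streamlit as st
from PIL import Image
from easyocr import Reader
from transformers import BlipProcessor, BlipForConditionalGeneration

# Load the OCR reader and the BLIP captioning model (same checkpoint as the commit)
ocr_reader = Reader(['en'])
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

def extract_text(image):
    # easyocr accepts a numpy array, so convert the PIL image first
    return ocr_reader.readtext(np.array(image))

def caption_image(image):
    # BLIP conditions on pixel_values from the image itself, not on OCR text
    inputs = processor(images=image, return_tensors="pt")
    output_ids = caption_model.generate(**inputs, max_new_tokens=50)
    return processor.decode(output_ids[0], skip_special_tokens=True)

st.title("Text Extraction and Explanation")
uploaded_file = st.file_uploader("Upload an image:")

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    ocr_results = extract_text(image)
    st.markdown("**Extracted text:**")
    st.markdown(" ".join(res[1] for res in ocr_results))
    st.markdown("**Explanation (Image Caption):**")
    st.markdown(caption_image(image))

If the intent is to explain the extracted text rather than describe the picture, the removed GPT-2 path (tokenizer encode plus AutoModelForCausalLM.generate) was closer to the right tool; a captioning model adds nothing once the image has been reduced to text.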