# Spaces: Build error  (hosting-page status text; not part of the program)
import os
import tempfile

import streamlit as st
from paddleocr import PaddleOCR
import cv2
from langchain.chains import LLMChain
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from dotenv import load_dotenv
from sqlalchemy import create_engine, Column, Integer, String, JSON
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import google.generativeai as genai
# Load environment variables (e.g. from a local .env file) into os.environ.
load_dotenv()

# Configuration values read by the rest of the app.
# Google Generative AI key; None when API_KEY is not set in the environment.
api_key = os.environ.get("API_KEY")
# SQLAlchemy URL of a SQLite file, resolved relative to the working directory.
DATABASE_URL = "sqlite:///mydatabase.db"
# Setup database | |
Base = declarative_base() | |
class MyDataModel(Base): | |
__tablename__ = 'my_data_table' | |
id = Column(Integer, primary_key=True) | |
name = Column(String) | |
data = Column(JSON) | |
# Engine and session factory bound to DATABASE_URL. create_all only creates
# tables that do not exist yet (checkfirst defaults to True). The module-level
# `session` is the one the UI code uses for writes.
engine = create_engine(DATABASE_URL)
Session = sessionmaker(bind=engine)
Base.metadata.create_all(bind=engine)
session = Session()
# Initialize the google.generativeai SDK globally with the configured key.
# NOTE(review): the LangChain Gemini client built later is also handed
# api_key explicitly, so this global setup may only matter for direct `genai` use.
genai.configure(api_key=api_key)
# Define OCR function using PaddleOCR
def ocr_with_paddle(img_path):
    """Run PaddleOCR on an image file and return the recognised text.

    Args:
        img_path: Path to an image file that ``cv2.imread`` can decode.

    Returns:
        The text of every detected line, in the order PaddleOCR reports them,
        joined with single spaces; ``''`` when no text is detected.

    Raises:
        ValueError: If OpenCV cannot read ``img_path`` as an image.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports missing/undecodable files by returning None, not by raising.
        raise ValueError(f"Could not read image file: {img_path}")

    # NOTE(review): building the engine per call reloads the OCR models on every
    # upload. Caching it (e.g. st.cache_resource) would be faster but shares one
    # instance across sessions/threads -- confirm PaddleOCR is thread-safe first.
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)

    # One entry per input image; that entry may be None (or empty) when no
    # text was detected, which previously crashed the loop below.
    lines = result[0] if result else None
    if not lines:
        return ''

    # PaddleOCR 2.x reports each detected line as [box, (text, confidence)].
    # Take only the text: the old loop walked both elements and so emitted a
    # box x-coordinate plus the repr of the (text, confidence) tuple.
    return ' '.join(str(line[1][0]) for line in lines).strip()
# Define the prompt template for extracting invoice details.
_INVOICE_SYSTEM_PROMPT = (
    "You are a helpful assistant that extracts invoice details such as invoice number, "
    "customer name, date, amount, and other relevant information from a provided invoice text."
)

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _INVOICE_SYSTEM_PROMPT),
        # The invoice text is substituted for {input} when the chain runs.
        ("human", "{input}"),
    ]
)

# Gemini chat model used for extraction.
llm = ChatGoogleGenerativeAI(
    api_key=api_key,
    model="gemini-1.5-pro",
    temperature=0.5,
    max_tokens=None,  # None: leave unset (library/model default)
    timeout=None,  # None: leave unset (library default)
    max_retries=2,
)

# NOTE(review): LLMChain is deprecated in recent LangChain releases (`prompt | llm`
# is the replacement). It is kept because callers read the "text" key it returns.
invoice_chain = LLMChain(llm=llm, prompt=prompt)
def extract_invoice_details(input_text):
    """Ask the invoice LLM chain to pull invoice fields out of OCR text.

    Args:
        input_text: Raw text recognised from the invoice image.

    Returns:
        The model's response text with surrounding whitespace stripped, or
        ``''`` when ``input_text`` is empty, whitespace-only or None (in which
        case no model request is made).
    """
    if not input_text or not input_text.strip():
        # Nothing to extract from; skip a pointless (and billable) model call.
        return ''
    # Chain.invoke replaces calling the chain object directly (Chain.__call__),
    # which LangChain has deprecated; LLMChain still returns its output under "text".
    response = invoice_chain.invoke({"input": input_text})
    return response["text"].strip()
# Streamlit UI
st.title("Invoice OCR and Details Extraction")
st.write(
    "Upload an image file to extract the text and invoice details such as invoice number, customer name, date, and amount."
)

# Image Upload
uploaded_image = st.file_uploader("Choose an Image", type=["jpg", "jpeg", "png"])

if uploaded_image is not None:
    # Save the upload to a uniquely named temporary file. A fixed
    # "temp_image.png" in the working directory could be overwritten by another
    # session between the write and the OCR read, and was never cleaned up.
    suffix = os.path.splitext(uploaded_image.name)[1] or ".png"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
        tmp_file.write(uploaded_image.getbuffer())
        img_path = tmp_file.name
    try:
        # Perform OCR on the uploaded image. The temp file is already closed
        # here, so it can be reopened by path on every platform.
        text = ocr_with_paddle(img_path)
    finally:
        # Remove the temporary image even if OCR raises.
        os.remove(img_path)

    st.write("Extracted Text:")
    st.text_area("OCR Output", text, height=300)

    if not text.strip():
        # Nothing for the LLM to work with; avoid an empty request and an empty DB row.
        st.warning("No text was found in the image, so no invoice details were extracted or saved.")
    else:
        # Extract invoice details from the text
        invoice_details = extract_invoice_details(text)
        st.write("Extracted Invoice Details:")
        st.text_area("Invoice Details", invoice_details, height=300)

        # Save details to the database. close() runs even if commit() raises,
        # and closing the session rolls back any transaction left open.
        new_entry = MyDataModel(name="invoice_details", data=invoice_details)
        try:
            session.add(new_entry)
            session.commit()
        finally:
            session.close()
        st.success("Invoice details saved to the database!")