File size: 9,302 Bytes
59d3355 6738563 59d3355 6738563 59d3355 6738563 59d3355 3f89a2a 59d3355 6738563 59d3355 6738563 59d3355 6569632 6738563 6569632 6738563 6569632 6738563 6569632 6738563 6569632 8357c17 6569632 d336c1b 6738563 08a61a8 6738563 d336c1b 08a61a8 f584587 6738563 08a61a8 6738563 08a61a8 6738563 08a61a8 6738563 08a61a8 6738563 08a61a8 d336c1b 6738563 d336c1b 6738563 08a61a8 6738563 d336c1b 6738563 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 |
import base64
import io
import json
import os
from typing import Any, Dict, List
import chromadb
import google.generativeai as palm
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as pd
import requests
import streamlit as st
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SentenceTransformersTokenTextSplitter,
)
from PIL import Image, ImageDraw, ImageFont
from pypdf import PdfReader
# Configure the google.generativeai (PaLM) client once, at import time.
# The key is read from the PALM_API_KEY environment variable; on Streamlit
# Community Cloud the same value could instead come from
# st.secrets["PALM_API_KEY"].
try:
    api_key = os.environ["PALM_API_KEY"]
except KeyError as err:
    # Same exception type as before, but with an actionable message.
    raise KeyError(
        "PALM_API_KEY environment variable is not set; "
        "export it before starting the app"
    ) from err
palm.configure(api_key=api_key)
def convert_image_to_bytes(image):
    """Encode *image* as JPEG and return the raw bytes (e.g. for a download).

    JPEG cannot store an alpha channel or a palette, and PIL refuses to save
    such images (RGBA, LA, P, ...) as JPEG. Any image whose mode JPEG cannot
    hold is therefore converted to RGB first. ``convert`` returns a new
    image, so the caller's image is left untouched.

    Args:
        image: A PIL ``Image`` (anything exposing ``mode``, ``convert`` and
            ``save``).

    Returns:
        bytes: The JPEG-encoded image data.
    """
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return buffered.getvalue()
def resize_image(image, width=512):
    """Scale *image* to *width* pixels wide, preserving its aspect ratio.

    Args:
        image: A PIL ``Image`` (anything exposing ``width``, ``height`` and
            ``resize``).
        width: Target width in pixels. Defaults to 512, the value that was
            previously hard-coded.

    Returns:
        Whatever ``image.resize`` returns (a new PIL image).
    """
    # Truncate exactly as before, but never let a very wide, short image
    # collapse to 0 rows, which is not a valid target size.
    height = max(1, int(image.height * width / image.width))
    return image.resize((width, height))
def convert_image_to_base64(image):
    """Encode *image* as JPEG and return it as a base64 ASCII string.

    The result is suitable for the ``inline_data`` field of a Gemini request.
    Images whose mode JPEG cannot hold (RGBA, LA, P, ...) are converted to
    RGB first, because PIL refuses to save them as JPEG. ``convert`` returns
    a new image, so the caller's image is not modified.

    Args:
        image: A PIL ``Image`` (anything exposing ``mode``, ``convert`` and
            ``save``).

    Returns:
        str: Base64 text of the JPEG bytes.
    """
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()
def call_palm(
    prompt: str,
    *,
    model: str = "models/text-bison-001",
    temperature: float = 0,
    max_output_tokens: int = 800,
) -> str:
    """Generate a text completion for *prompt* with the PaLM text API.

    Uses the client configured by ``palm.configure`` at import time. The
    keyword-only options default to the values that were previously
    hard-coded, so existing ``call_palm(prompt)`` calls behave the same.

    Args:
        prompt: The full prompt text.
        model: PaLM model name.
        temperature: Sampling temperature; 0 asks for the least random output.
        max_output_tokens: Upper bound on the completion length in tokens.

    Returns:
        The ``result`` attribute of the completion.
        NOTE(review): this can presumably be ``None`` when the service
        returns no candidate (e.g. the prompt was filtered) — confirm before
        relying on it being a ``str``.
    """
    completion = palm.generate_text(
        model=model,
        prompt=prompt,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
    )
    return completion.result
def call_gemini_api(
    image_base64, api_key=api_key, prompt="What is this picture?", timeout=60
):
    """Ask the Gemini Pro Vision model a question about a JPEG image.

    Args:
        image_base64: Base64-encoded JPEG bytes, sent as ``inline_data``.
        api_key: Google Generative Language API key. Defaults to the
            module-level key read at import time.
        prompt: Text sent alongside the image.
        timeout: Seconds ``requests`` waits to connect and for each read
            before raising ``requests.exceptions.Timeout``. Previously there
            was no timeout, so a stalled connection blocked forever.

    Returns:
        dict: The decoded JSON response. The HTTP status is not checked, so
        an API error is returned as its JSON error body rather than raised.
    """
    url = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        "gemini-pro-vision:generateContent"
    )
    headers = {
        "Content-Type": "application/json",
        # The key goes in a header instead of the query string, keeping the
        # secret out of the URL (which tends to end up in logs and errors).
        "x-goog-api-key": api_key,
    }
    data = {
        "contents": [
            {
                "parts": [
                    {"text": prompt},
                    {"inline_data": {"mime_type": "image/jpeg", "data": image_base64}},
                ]
            }
        ]
    }
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    return response.json()
def safely_get_text(response):
    """Return the first generated text from a Gemini ``generateContent`` reply.

    Reads ``response["candidates"][0]["content"]["parts"][0]["text"]``.
    Instead of raising, it prints the problem and returns ``None`` when any
    step of that path is missing. This covers an API error body without
    "candidates", an empty list, or a ``None`` response.

    The previous body only evaluated ``response`` inside ``try``, which
    cannot fail, so the function always returned ``None``.

    Args:
        response: Decoded JSON response (e.g. from ``call_gemini_api``).

    Returns:
        str | None: The text of the first part of the first candidate, or
        ``None`` if it is not present.
    """
    try:
        return response["candidates"][0]["content"]["parts"][0]["text"]
    except (KeyError, IndexError, TypeError) as e:
        print(f"An error occurred: {e}")
        # Return None or a default value if the path does not exist
        return None
def post_request_and_parse_response(
    url: str, payload: Dict[str, Any], timeout: float = 60.0
) -> Dict[str, Any]:
    """
    POST ``payload`` as JSON to ``url`` and parse the JSON response body.

    Args:
        url (str): The URL to which the POST request is sent.
        payload (Dict[str, Any]): JSON-serialisable body of the request.
        timeout (float): Seconds ``requests`` waits to connect and for each
            read. Without a timeout a stalled server would block the caller
            indefinitely.

    Returns:
        Dict[str, Any]: The response body, decoded as UTF-8 and parsed as
        JSON. The HTTP status is not checked, so a JSON error body is
        returned like any other.

    Raises:
        requests.exceptions.Timeout: If the server does not respond in time.
        json.JSONDecodeError: If the body is not valid JSON (e.g. an HTML
            error page).
    """
    headers = {"Content-Type": "application/json"}
    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    return json.loads(response.content.decode("utf-8"))
def extract_line_items(input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Extract the items whose "BlockType" is "LINE" from Textract-style data.

    Args:
        input_data (Dict[str, Any]): Dictionary whose "body" holds the list of
            blocks, either JSON-encoded (str/bytes) or already decoded. A
            missing or ``None`` body is treated as an empty list.

    Returns:
        List[Dict[str, Any]]: The LINE items, in their original order.
        Entries that are not dictionaries are skipped.

    Raises:
        json.JSONDecodeError: If "body" is a string that is not valid JSON.
    """
    raw_body = input_data.get("body")
    if raw_body is None:
        return []
    # The body normally arrives JSON-encoded; an already-decoded list is
    # accepted as-is instead of letting json.loads raise TypeError on it.
    if isinstance(raw_body, (str, bytes, bytearray)):
        body_items = json.loads(raw_body)
    else:
        body_items = raw_body
    return [
        item
        for item in body_items
        if isinstance(item, dict) and item.get("BlockType") == "LINE"
    ]
def rag(query: str, retrieved_documents: list, api_key: str = api_key) -> str:
    """
    Answer ``query`` with the PaLM text model, grounded on retrieved documents.

    The documents are joined with blank lines and appended to the question in
    a single prompt. That prompt is sent through ``call_palm``; the Gemini
    REST API is not used here.

    Args:
        query (str): The user's question.
        retrieved_documents (list): Retrieved context snippets; they must be
            strings, because they are joined with ``str.join``.
        api_key (str): Unused. ``call_palm`` relies on the key passed to
            ``palm.configure`` at import time. The parameter is kept so
            existing callers that pass it keep working.

    Returns:
        str: The model's completion text, returned unmodified.
    """
    information = "\n\n".join(retrieved_documents)
    prompt = f"Question: {query}. \n Information: {information}"
    return call_palm(prompt=prompt)
def displayPDF(file: str) -> None:
    """
    Render a PDF inline in the Streamlit page.

    The whole file is read and base64-encoded. It is then shown through an
    ``<embed>`` tag whose ``src`` is a ``data:`` URI, sized 700x1000 pixels.

    Parameters:
    - file (str): Filesystem path of the PDF to show.
    """
    with open(file, "rb") as pdf_handle:
        raw_bytes = pdf_handle.read()
    encoded = base64.b64encode(raw_bytes).decode("utf-8")
    html = f'<embed src="data:application/pdf;base64,{encoded}" width="700" height="1000" type="application/pdf">'
    # Raw HTML has to be explicitly allowed for st.markdown to render the tag.
    st.markdown(html, unsafe_allow_html=True)
def draw_boxes(image: Any, predictions: List[Dict[str, Any]]) -> Any:
    """
    Draw bounding boxes and "label (score)" captions onto an image in place.

    Parameters:
    - image (Any): The image to annotate; it must work with ``ImageDraw.Draw``.
    - predictions (List[Dict[str, Any]]): Each prediction has 'label', a
      numeric 'score', and 'box'. 'box' is a dict with 'xmin', 'ymin',
      'xmax' and 'ymax' pixel coordinates; any extra keys are ignored.

    Returns:
    - Any: The same image object, now annotated. Boxes are drawn as green
      1 px outlines and captions in red at each box's top-left corner.

    Note:
    - The input image is modified directly, not copied.
    """
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for pred in predictions:
        label = pred["label"]
        score = pred["score"]
        box = pred["box"]
        # Read coordinates by name. Unpacking box.values() depended on the
        # dict's key order and failed outright if the box had extra keys.
        xmin, ymin = box["xmin"], box["ymin"]
        xmax, ymax = box["xmax"], box["ymax"]
        draw.rectangle([xmin, ymin, xmax, ymax], outline="green", width=1)
        draw.text((xmin, ymin), f"{label} ({score:.2f})", fill="red", font=font)
    return image
def draw_bounding_boxes_for_textract(
    image: Image.Image, json_data: Dict[str, Any]
) -> Image.Image:
    """
    Outline the LINE and WORD blocks of a Textract-style response on an image.

    Box geometry is given as fractions of the page size; it is scaled by the
    image's width and height before drawing. Problems with the input are
    reported with ``st.error``, and the image is then returned unchanged.

    Args:
        image: PIL image to draw on. It is modified in place and also returned.
        json_data: Dictionary whose "body" holds the list of blocks, either
            JSON-encoded (str/bytes) or already decoded.

    Returns:
        The same image with a red 2 px rectangle around every LINE/WORD block
        that has a ``Geometry.BoundingBox``.
    """
    raw_body = json_data.get("body") if isinstance(json_data, dict) else None
    if raw_body is None:
        st.error("No bounding box data found.")
        return image

    if isinstance(raw_body, (str, bytes, bytearray)):
        try:
            blocks = json.loads(raw_body)
        except json.JSONDecodeError:
            st.error("Invalid JSON data.")
            return image
    else:
        blocks = raw_body

    # Anything other than a list of blocks (e.g. an error object) carries no
    # boxes to draw.
    if not isinstance(blocks, list):
        st.error("No bounding box data found.")
        return image

    draw = ImageDraw.Draw(image)
    img_w, img_h = image.width, image.height
    for item in blocks:
        if not isinstance(item, dict) or item.get("BlockType") not in ("LINE", "WORD"):
            continue
        bbox = (item.get("Geometry") or {}).get("BoundingBox")
        if not bbox:
            # Skip blocks without geometry instead of failing on a KeyError.
            continue
        left, top = bbox["Left"], bbox["Top"]
        right = left + bbox["Width"]
        bottom = top + bbox["Height"]
        draw.rectangle(
            [(left * img_w, top * img_h), (right * img_w, bottom * img_h)],
            outline="red",
            width=2,
        )
    return image
|