import gradio as gr
from transformers import (
    pipeline,
    DonutProcessor,
    VisionEncoderDecoderModel,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from PIL import Image
import torch
import speech_recognition as sr
from pydub import AudioSegment
import os
import re
# Donut OCR model (document understanding, fine-tuned on CORD v2 receipts)
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
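# Donut is prompt-driven: generation starts from the task token "<s_cord-v2>",
# and the decoder emits an XML-like tag sequence that token2json() later turns
# into nested JSON. Sketch of the flow (field names depend on the image):
#   pixel_values -> model.generate(...) -> "<s_menu><s_nm>Tea</s_nm>..." -> {"menu": {"nm": "Tea", ...}}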
# Image Captioning Model (BLIP generates a text description of the image)
image_classifier = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# Sentiment Analysis Model (multilingual, rates text from 1 to 5 stars)
sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")

# Text Categorization Model (zero-shot classification via BART trained on MNLI)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
nli_model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
nli_pipeline = pipeline("zero-shot-classification", model=nli_model, tokenizer=tokenizer)
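# Zero-shot classification treats each candidate label as an NLI hypothesis
# (roughly "This example is {label}.") and scores entailment, so no
# Rail-Madad-specific training data is needed. A minimal standalone check
# (the label set here is illustrative):
#   nli_pipeline("The coach floor was littered with garbage",
#                candidate_labels=["coach cleanliness", "delay"])
#   # -> {'labels': ['coach cleanliness', 'delay'], 'scores': [~0.9, ~0.1]}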
# Function for Image Recognition (captions the uploaded complaint image)
def image_recognition(image):
    try:
        result = image_classifier(image)
        output = "<h4>Image Details:</h4><ul>"
        for item in result:
            output += f"<li>Description: {item['generated_text']}</li>"
        output += "</ul>"
        return output
    except Exception as e:
        return f"<b>Error in Image Recognition:</b> {str(e)}"
# Function to extract text from an image using Donut
def extract_text_from_image(input_img):
    try:
        pixel_values = processor(input_img, return_tensors="pt").pixel_values
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
        )
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
        final = processor.token2json(sequence)
        return final
    except Exception as e:
        return {"error": str(e)}
# Function for Sentiment Analysis
def analyze_sentiment(feedback_text):
    try:
        sentiment_result = sentiment_pipeline(feedback_text)
        output = "<h4>Feedback Sentiment Analysis:</h4><ul>"
        for item in sentiment_result:
            output += f"<li>Label: {item['label']}, Score: {item['score']:.2f}</li>"
        output += "</ul>"
        return output
    except Exception as e:
        return f"<b>Error in Sentiment Analysis:</b> {str(e)}"
# Function for Text Categorization
def categorize_complaint(complaint_text):
    try:
        labels = ["coach cleanliness", "damage", "staff behavior", "safety", "delay", "other"]
        result = nli_pipeline(complaint_text, candidate_labels=labels)
        output = f"<h4>Complaint Categorization:</h4><p>Text: {result['sequence']}</p><ul>"
        for label, score in zip(result['labels'], result['scores']):
            output += f"<li>{label}: {score:.2f}</li>"
        output += "</ul>"
        return output
    except Exception as e:
        return f"<b>Error in Complaint Categorization:</b> {str(e)}"
# Function to Process Voice Input
def process_audio(audio):
    recognizer = sr.Recognizer()
    # Convert the uploaded file (Gradio passes a file path) to WAV for SpeechRecognition
    try:
        sound = AudioSegment.from_file(audio)
        sound.export("temp.wav", format="wav")
    except Exception as e:
        return f"<b>Audio processing error:</b> {e}"
    try:
        with sr.AudioFile("temp.wav") as source:
            audio_data = recognizer.record(source)
        text = recognizer.recognize_google(audio_data)
        return f"<h4>Transcribed Audio:</h4><p>{text}</p>"
    except sr.UnknownValueError:
        return "<b>Could not understand the audio.</b>"
    except sr.RequestError as e:
        return f"<b>Could not request results:</b> {e}"
    finally:
        os.remove("temp.wav")  # Clean up the temporary file on every path
# Gradio Interface Components
def main(image, complaint_text, feedback_text, audio):
    # Process Image
    image_results = image_recognition(image) if image else "<i>No image provided.</i>"
    # Process OCR Text
    ocr_text = extract_text_from_image(image) if image else "<i>No image provided.</i>"
    # Process Complaint Categorization
    categorized_complaint = categorize_complaint(complaint_text) if complaint_text else "<i>No complaint text provided.</i>"
    # Process Sentiment Analysis
    sentiment_result = analyze_sentiment(feedback_text) if feedback_text else "<i>No feedback text provided.</i>"
    # Process Audio Input
    audio_text = process_audio(audio) if audio else "<i>No audio provided.</i>"
    return f"{image_results}<br>{ocr_text}<br>{categorized_complaint}<br>{sentiment_result}<br>{audio_text}"
# Build Gradio UI
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="pil", label="Upload Complaint Image"),
        gr.Textbox(lines=5, placeholder="Enter Complaint Text", label="Complaint Text"),
        gr.Textbox(lines=2, placeholder="Enter Feedback Text", label="Feedback Text"),
        gr.Audio(type="filepath", label="Upload Audio Complaint"),  # 'filepath' hands process_audio a path
    ],
    outputs=[
        gr.HTML(label="Results")  # HTML output renders the formatted <h4>/<ul> markup above
    ],
    title="Rail Madad Complaint Resolution System",
    description="AI-powered system for automated categorization, prioritization, and response to complaints on Rail Madad."
)

iface.launch()
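# To expose a temporary public URL when running outside Hugging Face Spaces,
# launch with share=True instead: iface.launch(share=True)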
# Quick local sanity check for the captioning pipeline (path is environment-specific;
# uncomment and run before iface.launch(), which blocks):
# img = Image.open("/content/Tech Mahindra hiring process.png")
# print(image_classifier(img)[0]["generated_text"])