KeenWoo's picture
Update utils.py
22c89de verified
# utils.py
# Contains shared utility functions for text processing, audio transcription,
# date/time handling, and image analysis that can be used by any assessment module.
import os
import re
import time
from datetime import datetime
import gradio as gr
import cv2
import nltk
import numpy as np
import pytz
import whisper
from scipy.io.wavfile import write as write_wav
# from shapely.geometry import Polygon
# --- NLTK Setup ---
# NLTK resources are looked up in (and downloaded to) an `nltk_data` folder that
# sits next to this file; register it once on NLTK's search path.
LOCAL_NLTK_DATA_PATH = os.path.join(
    os.path.dirname(__file__),
    "nltk_data",
)
if LOCAL_NLTK_DATA_PATH not in nltk.data.path:
    nltk.data.path.append(LOCAL_NLTK_DATA_PATH)
def download_nltk_data_if_needed(resource_name, download_name):
    """
    Make sure an NLTK resource is available, downloading it if it is missing.

    Args:
        resource_name: Resource path probed with ``nltk.data.find``
            (e.g. ``'tokenizers/punkt'``).
        download_name: Package id passed to ``nltk.download``
            (e.g. ``'punkt'``).

    Missing packages are downloaded into ``LOCAL_NLTK_DATA_PATH``. A failed
    download is reported on stdout rather than raised, so module import
    still proceeds.
    """
    try:
        nltk.data.find(resource_name)
    except LookupError:
        print(f"Downloading NLTK resource '{download_name}'...")
        # exist_ok avoids the check-then-create race when several worker
        # processes import this module at the same time.
        os.makedirs(LOCAL_NLTK_DATA_PATH, exist_ok=True)
        # nltk.download signals failure by returning False (it does not raise
        # by default), so check the result instead of assuming success.
        if nltk.download(download_name, download_dir=LOCAL_NLTK_DATA_PATH):
            print("Download complete.")
        else:
            print(f"[NLTK ERROR] Failed to download '{download_name}'.")
# Download necessary NLTK packages, listed as
# (resource path to probe, package id to download).
for _resource, _package in (
    ('tokenizers/punkt', 'punkt'),
    ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('taggers/averaged_perceptron_tagger_eng', 'averaged_perceptron_tagger_eng'),
):
    download_nltk_data_if_needed(_resource, _package)
# Don't leak the loop variables as module attributes.
del _resource, _package
# --- Whisper Model Loading ---
# The model is loaded once, at import time, into the module-level `model`
# that transcribe() uses for every request.
print("Loading Whisper transcription model...")
model = whisper.load_model(name="large-v3")
print("Whisper model loaded.")
def transcribe(audio, lang_code="en"):
    """
    Transcribe a recorded answer with Whisper.

    Args:
        audio: ``(sample_rate, samples)`` tuple as written by
            ``scipy.io.wavfile.write``, or ``None`` when nothing was recorded.
        lang_code: Whisper language code of the speech. ``"en"`` runs a plain
            English transcription.

    Returns:
        The stripped transcript. For any other language, Whisper runs twice:
        once to transcribe and once to translate to English. The result is
        ``"original (english translation)"``, or just the original text when
        both passes agree. Returns ``""`` when ``audio`` is None.
    """
    if audio is None:
        return ""
    import tempfile  # local import keeps this function self-contained

    sample_rate, y = audio
    # Use a unique temp file per call. A fixed shared path lets concurrent
    # sessions overwrite each other's recording mid-transcription.
    fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        write_wav(temp_wav_path, sample_rate, y)
        if lang_code == "en":
            # Standard transcription for English
            result = model.transcribe(temp_wav_path, language="en", task="transcribe")
            return result["text"].strip()

        # For other languages, transcribe first, then translate.
        transcribed_result = model.transcribe(
            temp_wav_path, language=lang_code, task="transcribe"
        )
        original_text = transcribed_result["text"].strip()
        # Pass the known source language to the translate pass too, so Whisper
        # does not re-detect (and possibly mis-detect) it.
        translated_result = model.transcribe(
            temp_wav_path, language=lang_code, task="translate"
        )
        translated_text = translated_result["text"].strip()
        if original_text == translated_text:
            return original_text  # Avoids showing "(text)" if translation is same as transcription
        return f"{original_text} ({translated_text})"
    finally:
        # Always remove the temp file; a failed cleanup must not mask the result.
        try:
            os.remove(temp_wav_path)
        except OSError:
            pass
# --- Date & Time Utilities ---
# Timezone in which this module expresses local dates/times.
TARGET_TIMEZONE = pytz.timezone("America/New_York")
# NOTE: `now_utc` and `now` are evaluated ONCE, when this module is imported.
# In a long-running server they go stale (e.g. the wrong date after midnight).
# They are kept for backward compatibility; use current_time() for a fresh value.
now_utc = datetime.now(pytz.utc)
now = now_utc.astimezone(TARGET_TIMEZONE)


def current_time():
    """Return the current moment as an aware datetime in TARGET_TIMEZONE, computed at call time."""
    return datetime.now(pytz.utc).astimezone(TARGET_TIMEZONE)
def get_season(month):
    """
    Return the Northern-Hemisphere season name for a month number.

    Months 3-5 are "spring", 6-8 "summer", 9-11 "fall"; anything else
    (12, 1, 2, or out-of-range values) is "winter".
    """
    for first_month, last_month, season in (
        (3, 5, "spring"),
        (6, 8, "summer"),
        (9, 11, "fall"),
    ):
        if first_month <= month <= last_month:
            return season
    return "winter"
# --- Text Normalization and Cleaning Dictionaries & Functions ---
# Spoken number words -> digit strings. Insertion order is kept stable: the
# basic words 0-19 first, then the tens and compound phrases.
WORD_TO_DIGIT = {
    word: str(value)
    for value, word in enumerate((
        'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven',
        'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen',
        'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen',
    ))
}
WORD_TO_DIGIT.update({
    'twenty': '20', 'thirty': '30', 'thirty one': '31',
    # Expected answers of the 'sevens' test (100 minus 7, repeatedly),
    # added for robustness.
    'ninety three': '93', 'eighty six': '86', 'seventy nine': '79',
    'seventy two': '72', 'sixty five': '65',
})
ORDINAL_TO_DIGIT = {
    # Day-of-month ordinals (1 through 31) in several spellings -> digit strings.
    # NOTE: several keys are substrings of other keys (e.g. 'first' inside
    # 'twenty first', '1st' inside '21st'). Code that scans this table in
    # insertion order and stops at the first hit will prefer the shorter key;
    # match longest-first if compound ordinals must be recognised.
    # Single word ordinals
    'first': '1', 'second': '2', 'third': '3', 'fourth': '4', 'fifth': '5',
    'sixth': '6', 'seventh': '7', 'eighth': '8', 'ninth': '9', 'tenth': '10',
    'eleventh': '11', 'twelfth': '12', 'thirteenth': '13', 'fourteenth': '14',
    'fifteenth': '15', 'sixteenth': '16', 'seventeenth': '17', 'eighteenth': '18',
    'nineteenth': '19', 'twentieth': '20', 'thirtieth': '30',
    # Hyphenated compound ordinals
    'twenty-first': '21',
    'twenty-second': '22', 'twenty-third': '23', 'twenty-fourth': '24',
    'twenty-fifth': '25', 'twenty-sixth': '26', 'twenty-seventh': '27',
    'twenty-eighth': '28', 'twenty-ninth': '29', 'thirty-first': '31',
    # Unhyphenated compound ordinals (space-separated spellings)
    'twenty first': '21', 'twenty second': '22', 'twenty third': '23',
    'twenty fourth': '24', 'twenty fifth': '25', 'twenty sixth': '26',
    'twenty seventh': '27', 'twenty eighth': '28', 'twenty ninth': '29',
    'thirty first': '31',
    # Suffix-based ordinals (digits followed by st/nd/rd/th)
    '1st': '1', '2nd': '2', '3rd': '3', '4th': '4',
    '5th': '5', '6th': '6', '7th': '7', '8th': '8', '9th': '9', '10th': '10',
    '11th': '11', '12th': '12', '13th': '13', '14th': '14', '15th': '15',
    '16th': '16', '17th': '17', '18th': '18', '19th': '19', '20th': '20',
    '21st': '21', '22nd': '22', '23rd': '23', '24th': '24', '25th': '25',
    '26th': '26', '27th': '27', '28th': '28', '29th': '29', '30th': '30', '31st': '31'
}
def clean_text_answer(text: str) -> str:
    """
    Normalise free-text input before scoring.

    Lowercases, deletes every character that is neither a word character nor
    whitespace, and collapses runs of whitespace to single spaces (trimming
    both ends). Empty or None input yields "".
    """
    if not text:
        return ""
    without_punctuation = re.sub(r'[^\w\s]', '', text.lower())
    return " ".join(without_punctuation.split())
def normalize_date_answer(text: str, mapping=None) -> str:
    """
    Convert a spoken or written day-of-month answer into a bare digit string.

    The input is lowercased, its whitespace collapsed, and a leading "the "
    dropped. The first ordinal expression found is replaced by its number.
    Finally every non-digit character is removed. Examples:
    "The twenty-first" -> "21", "March 3rd" -> "3", "15" -> "15".

    Args:
        text: Raw answer text.
        mapping: Ordinal -> digit-string table (lowercase keys). Defaults to
            ORDINAL_TO_DIGIT.

    Returns:
        The digits of the answer, or "" for empty/None input.
    """
    if not text:
        return ""
    if mapping is None:
        mapping = ORDINAL_TO_DIGIT
    # Collapsing whitespace lets "twenty  first" match the "twenty first" key.
    clean_text = " ".join(text.lower().split())
    if clean_text.startswith("the "):
        clean_text = clean_text[4:]
    if mapping:
        # Try longer keys first so "twenty first" beats its suffix "first"
        # (a plain first-hit scan turned "twenty-first" into "1"), and anchor
        # on word boundaries so keys don't match inside other words.
        alternatives = "|".join(
            re.escape(key) for key in sorted(mapping, key=len, reverse=True)
        )
        clean_text = re.sub(
            r"\b(?:" + alternatives + r")\b",
            lambda match: mapping[match.group(0)],
            clean_text,
            count=1,  # convert only the first ordinal mentioned
        )
    return re.sub(r'\D', '', clean_text)
def clean_numeric_answer(text: str) -> str:
    """Keep only the digit characters of ``text``; None or empty input gives ""."""
    return re.sub(r'\D', '', text) if text else ""
def normalize_numeric_words(text: str, mapping=None) -> str:
    """
    Replace spoken number words in ``text`` with digit strings.

    Matching is whole-word. Multi-word entries take precedence over their
    component words, so "ninety three" becomes "93" rather than "ninety 3".
    The text is lowercased and stripped; unrecognised words are left in place.

    Args:
        text: Raw answer text.
        mapping: Number word -> digit-string table (lowercase keys). Defaults
            to WORD_TO_DIGIT.

    Returns:
        The converted text, or "" for empty/None input.
    """
    if not text:
        return ""
    if mapping is None:
        mapping = WORD_TO_DIGIT
    text = text.lower().strip()
    if not mapping:
        return text
    # Single pass, longest keys first. Replacing keys one by one in insertion
    # order turned "one" into "1" before "thirty one" could ever match.
    alternatives = "|".join(
        re.escape(word) for word in sorted(mapping, key=len, reverse=True)
    )
    return re.sub(
        r"\b(?:" + alternatives + r")\b",
        lambda match: mapping[match.group(0)],
        text,
    )
# --- Generic Scoring Utilities ---
def score_keyword_match(expected, user_input):
    """
    Score 1 if any accepted answer appears in the user's response, else 0.

    Args:
        expected: Accepted answers separated by '|', e.g. "apple|pear".
        user_input: Raw user response.

    Both sides are normalised with clean_text_answer. An alternative matches
    when its cleaned form occurs as a substring of the cleaned response.
    """
    if not expected or not user_input:
        return 0
    cleaned_user = clean_text_answer(user_input)
    for ans in expected.split('|'):
        cleaned_ans = clean_text_answer(ans)
        # Skip alternatives that clean to "" (from "a||b", a trailing '|', or a
        # punctuation-only entry): "" is a substring of every string and would
        # award the point unconditionally.
        if cleaned_ans and cleaned_ans in cleaned_user:
            return 1
    return 0
def score_sentence_structure(raw_user_input):
    """
    Score 1 if the raw (uncleaned) text contains at least one noun and one verb.

    Tokenises with NLTK and POS-tags the tokens. Inputs with fewer than two
    tokens score 0. Any NLTK failure is logged to stdout and scores 0.
    """
    try:
        tokens = nltk.word_tokenize(raw_user_input or "")
        if len(tokens) < 2:
            return 0
        tags = [tag for _, tag in nltk.pos_tag(tokens)]
        found_noun = any(tag.startswith('NN') for tag in tags)
        found_verb = any(tag.startswith('VB') for tag in tags)
        return int(found_noun and found_verb)
    except Exception as e:
        print(f"[NLTK ERROR] Failed to parse sentence: {e}")
        return 0
def score_drawing(image_path, expected_sides, min_contour_area=500, approx_epsilon=0.04):
    """
    Score a drawing by the side count of its smallest significant polygon.

    The image is grayscaled and inverse-thresholded at 240, so dark strokes
    become foreground. All contours are found, and those with area greater
    than ``min_contour_area`` are kept. The smallest kept contour is
    approximated with ``approxPolyDP`` at ``approx_epsilon`` times its
    perimeter.

    Args:
        image_path: Path to the uploaded image (may be None/empty).
        expected_sides: Side count that earns the point.
        min_contour_area: Contours must have strictly larger area than this.
        approx_epsilon: Polygon approximation tolerance, as a fraction of the
            contour perimeter.

    Returns:
        ``(score, sides)``: score is 1 when ``sides == expected_sides``, else 0.
        Returns ``(0, 0)`` when the file is missing or unreadable, when fewer
        than three significant contours exist, or when OpenCV fails.
    """
    if not image_path or not os.path.exists(image_path):
        return 0, 0
    try:
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports unreadable/unsupported files by returning
            # None rather than raising.
            print(f"[OpenCV ERROR] Could not read image: {image_path}")
            return 0, 0
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
        contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        significant_contours = [
            c for c in contours if cv2.contourArea(c) > min_contour_area
        ]
        if len(significant_contours) < 3:
            return 0, 0  # Not enough shapes to form a valid intersection
        # min() returns the first of equally small contours, as the manual loop did.
        smallest = min(significant_contours, key=cv2.contourArea)
        epsilon = approx_epsilon * cv2.arcLength(smallest, True)
        sides_of_smallest_shape = len(cv2.approxPolyDP(smallest, epsilon, True))
        score = 1 if sides_of_smallest_shape == expected_sides else 0
        return score, sides_of_smallest_shape
    except Exception as e:
        print(f"[OpenCV ERROR] Failed to process image: {e}")
        return 0, 0
# --- Gradio UI Helper Functions ---
def save_final_answer(current_index, current_answer, all_answers):
    """
    Store the on-screen answer into the answers list just before submission.

    ``all_answers`` is modified in place, and the same list object is returned.
    """
    answers = all_answers
    answers[current_index] = current_answer
    return answers
def update_view(new_index, all_answers, module):
    """
    Build the Gradio output tuple that displays question ``new_index``.

    ``module`` is an assessment module exposing STRUCTURED_QUESTIONS,
    TOTAL_QUESTIONS and QUESTION_CHOICES. Returns, in order:

    - the question button label
    - the stored answer for the question
    - the new index
    - the progress header
    - the instruction text
    - the jump-navigation choice
    - a visibility update (shown only when the question text contains
      "draw a copy")
    - None, which clears the audio input
    """
    question = module.STRUCTURED_QUESTIONS[new_index]
    header = (
        f"## {question['main_cat']} - Q{question['main_num']}{question['sub_letter']} "
        f"({new_index + 1} of {module.TOTAL_QUESTIONS})"
    )
    show_drawing = "draw a copy" in question["question"]
    return (
        f"Say 🔊 {question['question']}",
        all_answers[new_index],
        new_index,
        header,
        question["instruction"],
        module.QUESTION_CHOICES[new_index],
        gr.update(visible=show_drawing),
        None,  # audio_input is cleared on every navigation
    )
def save_and_navigate(direction, current_index, current_answer, all_answers, module):
    """
    Save the on-screen answer, then step one question forward or back.

    ``direction == "next"`` moves forward (clamped at the last question); any
    other value moves back (clamped at 0). Returns update_view's tuple with
    the (mutated) answers list appended.
    """
    all_answers[current_index] = current_answer
    target = (
        min(current_index + 1, module.TOTAL_QUESTIONS - 1)
        if direction == "next"
        else max(current_index - 1, 0)
    )
    view = update_view(target, all_answers, module)
    return (*view, all_answers)
def jump_to_question(selected_choice, current_index, current_answer, all_answers, module):
    """
    Save the on-screen answer, then jump to the question chosen in the dropdown.

    If ``selected_choice`` is empty, or is not one of module.QUESTION_CHOICES,
    the view stays on ``current_index``. Returns update_view's tuple with the
    (mutated) answers list appended.
    """
    # Always persist the visible answer first. update_view refills the answer
    # box from all_answers, so skipping the save (as the empty-selection path
    # used to) silently discarded whatever the user had typed.
    all_answers[current_index] = current_answer
    new_index = current_index
    if selected_choice:
        try:
            new_index = module.QUESTION_CHOICES.index(selected_choice)
        except ValueError:
            # Unknown choice: stay put rather than raising inside the UI callback.
            new_index = current_index
    return update_view(new_index, all_answers, module) + (all_answers,)
def reset_app(module):
    """
    Build the output tuple that returns an assessment tab to its first question.

    ``module`` is an assessment module exposing STRUCTURED_QUESTIONS,
    TOTAL_QUESTIONS and QUESTION_CHOICES. The slot order must match the Gradio
    outputs list the caller wires up (not visible here); each slot is
    labelled below.
    """
    first_q = module.STRUCTURED_QUESTIONS[0]
    # The original computed an unused `is_drawing_q` here; it was dropped.
    # NOTE(review): unlike update_view, no visibility update for the drawing
    # component is emitted. Confirm that is intended for the first question.
    return (
        0,  # question_index
        [""] * module.TOTAL_QUESTIONS,  # answers
        "",  # score_lines
        "",  # total
        f"Say 🔊 {first_q['question']}",  # question_button
        f"## {first_q['main_cat']} - Q{first_q['main_num']}{first_q['sub_letter']} (1 of {module.TOTAL_QUESTIONS})",  # progress header
        first_q["instruction"],  # instruction text
        "",  # answer_text
        module.QUESTION_CHOICES[0],  # jump_nav
        None,  # audio_input
        None,  # image_upload
        gr.update(visible=False),  # start_over_btn
        gr.update(visible=True),  # submit_btn
        None,  # tts_audio
        "",  # score_state
    )