Spaces:
Build error
Build error
import gradio as gr | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
import torch | |
import datetime | |
import re | |
import os | |
import pytz | |
import dateutil.parser | |
# Load the DistilBERT classification model and its tokenizer from the same
# pretrained checkpoint.
# NOTE(review): a sequence-classification head on the base checkpoint is
# presumably freshly initialised (untrained) — confirm before relying on it.
_CHECKPOINT = "distilbert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
def _load_events(path):
    """
    Read the events previously written by save_events from `path`.

    Each non-blank line is expected to be ``name|start|end|recurring``. Only the
    last three "|" are treated as separators, so a name that itself contains "|"
    still loads. Naive datetimes are treated as UTC so reloaded events match the
    timezone-aware events created at runtime. Lines without four fields are
    ignored, and lines whose dates cannot be parsed are reported and skipped, so
    one bad line does not stop the app from starting.

    Returns:
        A list of ``{"name", "start", "end", "recurring"}`` dicts; empty when
        the file does not exist.
    """
    loaded = []
    if not os.path.isfile(path):
        return loaded
    # errors="replace": a file written under a different locale encoding should
    # not prevent startup.
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line_number, line in enumerate(f, start=1):
            fields = line.strip().rsplit("|", 3)
            if len(fields) != 4:
                continue
            name, start_str, end_str, recurring = fields
            try:
                start = dateutil.parser.parse(start_str)
                end = dateutil.parser.parse(end_str)
            except (ValueError, OverflowError):
                print(f"Skipping malformed event on line {line_number} of {path}: {line.strip()}")
                continue
            if start.tzinfo is None:
                start = pytz.utc.localize(start)
            if end.tzinfo is None:
                end = pytz.utc.localize(end)
            loaded.append({"name": name, "start": start, "end": end, "recurring": recurring.lower() == "true"})
            print(f"Loaded event: {name} ({start} - {end})")
    return loaded


# In-memory event store, initialised from disk (empty if the file is missing).
events = _load_events("events.txt")
def generate_response(prompt):
    """
    Run `prompt` through the DistilBERT model and decode the predicted label.

    NOTE(review): `model` is a sequence-classification model, so the returned
    string is the tokenizer's decoding of the predicted class index treated as a
    token id, not generated text. Nothing in the visible code calls this function.

    Args:
        prompt: input text; prompts longer than the model's maximum input length
            are truncated instead of raising.

    Returns:
        The decoded string for the predicted class index.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=-1)[0].item()
    return tokenizer.decode(predicted_class, skip_special_tokens=True)
def list_events(start, end):
    """
    Summarise stored events whose start time lies in the half-open window [start, end).

    Event datetimes without tzinfo are treated as UTC before comparing, so
    `start` and `end` must be timezone-aware. Returns the matching events as
    "name (HH:MM AM - HH:MM PM)" joined by ", ", or a fixed message when none match.
    """
    def as_utc(moment):
        # Naive values (e.g. reloaded from disk) are assumed to be UTC.
        if moment.tzinfo is None:
            return pytz.utc.localize(moment)
        return moment

    matching = []
    for entry in events:
        begins = as_utc(entry["start"])
        finishes = as_utc(entry["end"])
        if not (start <= begins < end):
            continue
        begin_label = begins.strftime('%I:%M %p')
        finish_label = finishes.strftime('%I:%M %p')
        matching.append(f"{entry['name']} ({begin_label} - {finish_label})")
    return ", ".join(matching) if matching else "There are no events presently."
def create_event(summary, start, end, recurring=False):
    """
    Store a new event, persist the whole event list, and return a confirmation.

    Args:
        summary: event name shown to the user.
        start, end: datetimes for the event (formatted as HH:MM AM/PM in the reply).
        recurring: whether the event repeats.

    Returns:
        A confirmation message naming the event and its time range.
    """
    events.append(dict(name=summary, start=start, end=end, recurring=recurring))
    save_events()
    start_label = start.strftime('%I:%M %p')
    end_label = end.strftime('%I:%M %p')
    return f"Event '{summary}' has been scheduled from {start_label} to {end_label}."
def save_events(path="events.txt", event_list=None):
    """
    Write events to a pipe-delimited text file, one ``name|start|end|recurring`` line each.

    Datetimes are written with ``isoformat(sep=" ", timespec="seconds")``. For
    naive values this is the same text as "%Y-%m-%d %H:%M:%S"; timezone-aware
    values keep their UTC offset instead of silently losing it. Because "|"
    separates fields and a newline separates records, those characters are
    replaced in names so every event can be read back. The file is written to a
    temporary sibling and then moved into place, so a crash mid-write does not
    wipe the previously saved events.

    Args:
        path: destination file; defaults to "events.txt".
        event_list: events to write; defaults to the module-level `events` list.
    """
    if event_list is None:
        event_list = events
    lines = []
    for event in event_list:
        name = str(event["name"]).replace("|", "/").replace("\r", " ").replace("\n", " ")
        start_str = event["start"].isoformat(sep=" ", timespec="seconds")
        end_str = event["end"].isoformat(sep=" ", timespec="seconds")
        recurring_str = "True" if event["recurring"] else "False"
        lines.append(f"{name}|{start_str}|{end_str}|{recurring_str}\n")
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.writelines(lines)
    # Atomic on POSIX and Windows: readers see either the old or the new file.
    os.replace(tmp_path, path)
def process_input(user_input):
    """
    Route a chat message to the matching action and return the reply text.

    Messages containing "schedule" or "create" are parsed into an event and
    stored. Messages containing "list" or "show" return the events starting
    during the current UTC day. Anything else, including an empty or missing
    message, gets a fallback reply.
    """
    text = (user_input or "").lower()
    if any(keyword in text for keyword in ("schedule", "create")):
        summary, start, end, recurring = extract_event_details(user_input)
        if summary and start and end:
            return create_event(summary, start, end, recurring)
        return "I'm sorry, I couldn't understand the event details. Please try again."
    if any(keyword in text for keyword in ("list", "show")):
        day_start = datetime.datetime.now(pytz.utc).replace(hour=0, minute=0, second=0, microsecond=0)
        # list_events treats `end` as exclusive, so the window is exactly one day;
        # the previous "minus 1s and 1us" dropped events in the last second of the day.
        day_end = day_start + datetime.timedelta(days=1)
        return list_events(day_start, day_end)
    return "I'm sorry, I didn't understand your request. Please try again."
# Clock time such as "9:30 AM", "9:30am" or "09:30  pm".
_CLOCK_PATTERN = r"\d{1,2}:\d{2}\s*[ap]m"

# Shared prefix: "<schedule|create> <summary> from <start> to <end> ".
_EVENT_PREFIX = (
    r"(?:schedule|create)(?P<summary>.*?)from\s*(?P<start>" + _CLOCK_PATTERN + r")\s*"
    + r"to\s*(?P<end>" + _CLOCK_PATTERN + r")\s*"
)

# (kind, regex) pairs, most specific first, so that "on the first monday of
# every month" is not mistaken for the plain "on <weekday>" form.
_EVENT_PATTERNS = [
    ("nth_weekday", re.compile(
        _EVENT_PREFIX
        + r"on\s+the\s+(?P<ordinal>first|second|third|fourth|last)\s+(?P<weekday>\w+)\s+of\s+every\s+month\b",
        re.IGNORECASE)),
    ("day_of_month", re.compile(
        _EVENT_PREFIX + r"on\s+the\s+(?P<day>\d{1,2})(?:st|nd|rd|th)?\s+of\s+every\s+month\b",
        re.IGNORECASE)),
    ("weekly", re.compile(_EVENT_PREFIX + r"every\s+(?P<weekday>\w+)", re.IGNORECASE)),
    ("weekday", re.compile(_EVENT_PREFIX + r"on\s+(?P<weekday>\w+)", re.IGNORECASE)),
    ("tomorrow", re.compile(_EVENT_PREFIX + r"tomorrow", re.IGNORECASE)),
]

_WEEKDAYS = ("monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday")
# Weekday name -> date.weekday() index; the plural ("every mondays") is accepted too.
_WEEKDAY_INDEX = {form: index for index, day in enumerate(_WEEKDAYS) for form in (day, day + "s")}
# Ordinal word -> index into the list of that weekday's dates within a month.
_ORDINAL_INDEX = {"first": 0, "second": 1, "third": 2, "fourth": 3, "last": -1}


def _days_in_month(year, month):
    """Return the number of days in `month` of `year`."""
    import calendar
    return calendar.monthrange(year, month)[1]


def _nth_weekday(year, month, weekday, index):
    """Return the date of the `index`-th (0-based, -1 = last) `weekday` in year/month."""
    # Every month has at least four of each weekday, so indexes 0..3 and -1 always exist.
    dates = [datetime.date(year, month, day) for day in range(1, _days_in_month(year, month) + 1)]
    return [d for d in dates if d.weekday() == weekday][index]


def _next_monthly_date(today, pick):
    """Return the first date >= today produced by pick(year, month), or None after 13 months."""
    year, month = today.year, today.month
    for _ in range(13):
        candidate = pick(year, month)
        if candidate is not None and candidate >= today:
            return candidate
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)
    return None


def _parse_clock(text):
    """Parse "9:30 pm" / "9:30PM" into a time; raise ValueError if it is not a valid 12-hour time."""
    return datetime.datetime.strptime(re.sub(r"\s+", "", text).upper(), "%I:%M%p").time()


def _first_occurrence(kind, match, today):
    """Return (first event date, recurring) for a matched pattern, or None if the day is not understood."""
    if kind == "tomorrow":
        return today + datetime.timedelta(days=1), False
    if kind == "day_of_month":
        day = int(match.group("day"))
        first = _next_monthly_date(
            today,
            lambda year, month: datetime.date(year, month, day) if 1 <= day <= _days_in_month(year, month) else None,
        )
        return (first, True) if first is not None else None
    weekday = _WEEKDAY_INDEX.get(match.group("weekday").lower())
    if weekday is None:
        return None
    if kind == "nth_weekday":
        index = _ORDINAL_INDEX[match.group("ordinal").lower()]
        first = _next_monthly_date(today, lambda year, month: _nth_weekday(year, month, weekday, index))
        return (first, True) if first is not None else None
    # "weekly" and "weekday": the next such weekday, counting today.
    first = today + datetime.timedelta(days=(weekday - today.weekday()) % 7)
    return first, kind == "weekly"


def _fallback_event(user_input):
    """Use the first date/time dateutil can find anywhere in the text as a zero-length "Event"."""
    try:
        moment = dateutil.parser.parse(user_input, fuzzy=True)
    except (ValueError, OverflowError):
        return None, None, None, False
    # An explicit zone in the text is converted; naive results are taken as UTC.
    moment = pytz.utc.localize(moment) if moment.tzinfo is None else moment.astimezone(pytz.utc)
    return "Event", moment, moment, False


def extract_event_details(user_input):
    """
    Extract the event summary, start time, end time, and recurrence from the user input.

    Recognised phrasings (case-insensitive) have the form
    "schedule|create <summary> from <h:mm am/pm> to <h:mm am/pm> <when>", where <when> is:
      - "tomorrow": tomorrow, not recurring;
      - "on <weekday>": the next such weekday (today counts), not recurring;
      - "every <weekday>": the next such weekday (today counts), recurring;
      - "on the <first|second|third|fourth|last> <weekday> of every month": the
        next such date, recurring;
      - "on the <N>[st|nd|rd|th] of every month": the next month day N, recurring.
    An empty summary becomes "Event". If the end time is earlier than the start
    time, the event is taken to end on the following day. Clock times are
    interpreted as UTC, and "today" is the current UTC date. If no phrasing
    matches, the first date/time dateutil finds is used as a zero-length "Event".

    Returns:
        (summary, start, end, recurring), where start and end are timezone-aware
        UTC datetimes, or (None, None, None, False) when nothing usable was found.
    """
    today = datetime.datetime.now(pytz.utc).date()
    for kind, pattern in _EVENT_PATTERNS:
        match = pattern.search(user_input)
        if not match:
            continue
        occurrence = _first_occurrence(kind, match, today)
        if occurrence is None:
            continue
        first_date, recurring = occurrence
        try:
            start_time = _parse_clock(match.group("start"))
            end_time = _parse_clock(match.group("end"))
        except ValueError:  # e.g. "13:00 PM"
            return None, None, None, False
        summary = match.group("summary").strip() or "Event"
        start = datetime.datetime.combine(first_date, start_time)
        end = datetime.datetime.combine(first_date, end_time)
        if end < start:
            end += datetime.timedelta(days=1)
        return summary, pytz.utc.localize(start), pytz.utc.localize(end), recurring
    return _fallback_event(user_input)
# Gradio interface | |
def chat(user_input):
    """Gradio callback: hand the user's message to process_input and return its reply."""
    return process_input(user_input)
# Text-in/text-out chat UI backed by chat().
iface = gr.Interface(chat, inputs="text", outputs="text", title="AI Scheduling Assistant")

# Start the web server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()