|
import base64
import io
import os

import requests
from gtts import gTTS
|
|
|
# Credentials and endpoint of the hosted text-generation model. Both come from
# the environment; a missing variable raises KeyError when this module loads.
API_Key, API_URL = (os.environ[var_name] for var_name in ("API_Key", "API_URL"))

# Bearer-token header sent with every model request.
headers = {"Authorization": "Bearer " + API_Key}
|
|
|
# System prompt that starts every prompt sent to the model. It fixes the
# assistant persona ("Mysteria") and hard-codes the patient, guardian and
# doctor details used in conversation.
# NOTE(review): personal patient details are committed in source here;
# consider loading them from configuration instead.
# NOTE(review): the text never explains the [Q]/[A] turn markers that are
# appended after it when the prompt is assembled; confirm the model copes.
basic_prompt = '''

You are a Virtual Assistant designed for assisting Alzheimer's Patients. Your name is Mysteria. You are currently assigned to a patient named Loki. The Guardian assigned to this patient is Sylvie. The Doctor assigned to the patient is Kang.



Details of Patient:

DOB: 14/04/1965

Last name: Odinson



Details of Guardian:

Name: Sylvie

Relation: Wife



Details of Doctor:

Name: Kang

Field: Psychology

Experience: 5 years

Office: End of Time

Next Appointment: 12/01/2024, 6:30 pm



You should respond only when "Mysteria" is announced.

When you are asked to shut up, You should stop responding, until you are awakened.

When you are awakened, Try to maintain a Conversation.



Maintain a conversation with the user, and, Answer the Questions properly.



'''
|
# Markers prefixed to each turn in the flattened prompt, keyed by role.
# 'stop_query' is inserted right after the system prompt and is currently empty.
tags = dict(user='[Q]', assistant='[A]', stop_query='')
|
|
|
def build_prompt(query, conversation, base_prompt=None, role_tags=None):
    """Assemble the full text prompt sent to the language model.

    The result is the system prompt, then the ``'stop_query'`` tag, then one
    line per previous turn prefixed with its role tag, then the new user
    query, and finally an open assistant tag for the model to complete.

    Args:
        query: The new user message.
        conversation: Iterable of previous turns. Each turn is a mapping with a
            ``'role'`` key (a key of the tag mapping, e.g. ``'user'`` or
            ``'assistant'``) and a ``'content'`` string. ``None`` is treated
            as an empty history.
        base_prompt: System text to start from; defaults to the module-level
            ``basic_prompt``.
        role_tags: Mapping of role name to marker; defaults to the
            module-level ``tags``. Must contain ``'user'``, ``'assistant'``
            and ``'stop_query'``.

    Returns:
        ``(prompt, length)`` where ``length == len(prompt)``, so a caller can
        slice the echoed prompt off the model's output.

    Raises:
        KeyError: if a turn's role has no entry in the tag mapping.
    """
    if base_prompt is None:
        base_prompt = basic_prompt
    if role_tags is None:
        role_tags = tags

    # Collect pieces and join once instead of repeated string concatenation.
    parts = [base_prompt, role_tags['stop_query']]
    for msg in conversation or ():
        parts.append('\n' + role_tags[msg['role']] + msg['content'])
    parts.append('\n' + role_tags['user'] + query)
    # Leave the assistant marker open so the model continues from here.
    parts.append('\n' + role_tags['assistant'])

    prompt = ''.join(parts)
    return prompt, len(prompt)
|
|
|
def query(payload, timeout=60):
    """POST ``payload`` as JSON to the model endpoint and return the decoded reply.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server, passed to ``requests.post``.
            Without a timeout a stalled endpoint would block the caller
            forever; pass ``None`` to restore that behaviour explicitly.

    Returns:
        The response body parsed from JSON. The HTTP status is not checked,
        so error replies are returned in parsed form as well.

    Raises:
        requests.Timeout: if the server does not respond within ``timeout``.
        ValueError: (requests' JSON decode error subclasses it) if the body
            is not valid JSON.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
|
|
|
def generate_response(inputs, conversation, max_new_tokens=50):
    """Ask the model for the assistant's next reply and return it as text.

    Args:
        inputs: The new user message.
        conversation: Previous turns, in the form accepted by ``build_prompt``.
        max_new_tokens: Generation length limit sent in the request parameters.

    Returns:
        The text generated after the prompt, cut at the first ``'['``
        (possibly an empty string).

    Raises:
        RuntimeError: if the API reply is not a non-empty list of generations
            (for example an error object).
    """
    prompt, prompt_length = build_prompt(inputs, conversation)
    payload = {
        'inputs': prompt,
        'parameters': {'max_new_tokens': max_new_tokens},
    }
    model_response = query(payload)
    if not isinstance(model_response, list) or not model_response:
        raise RuntimeError(f"Unexpected response from model API: {model_response!r}")

    # The slice assumes the API echoes the prompt verbatim before the
    # completion; keep only the newly generated part.
    generated = model_response[0]['generated_text'][prompt_length:]
    # Role markers start with '[', so cutting at the first '[' presumably stops
    # the model from running on into an invented next turn.
    return generated.partition('[')[0]
|
|
|
def audio_response(response):
    """Speak ``response`` with gTTS and wrap the audio in an autoplaying HTML tag.

    The MP3 is produced in memory and embedded as a base64 ``data:`` URI.
    No file is written, so concurrent calls cannot overwrite each other's
    audio and nothing is left behind in the working directory.

    Args:
        response: Text to speak.

    Returns:
        An ``<audio ...></audio>`` HTML string, or ``''`` when ``response`` is
        empty or whitespace-only (gTTS rejects empty text).
    """
    if not response or not response.strip():
        return ''

    buffer = io.BytesIO()
    gTTS(response).write_to_fp(buffer)
    audio_base64 = base64.b64encode(buffer.getvalue()).decode('ascii')
    # audio/mpeg is the registered MIME type for MP3, and <audio> requires an
    # explicit end tag.
    return f'<audio autoplay="true" src="data:audio/mpeg;base64,{audio_base64}"></audio>'