import base64
import io
import os

import requests
from gtts import gTTS

# Credentials and endpoint for the hosted model, read from the environment at
# import time: importing this module raises KeyError if either is unset.
API_Key = os.environ['API_Key']
API_URL = os.environ['API_URL']

# Bearer-token auth header attached to every request made by query().
headers = {"Authorization": f"Bearer {API_Key}"}

# System prompt that build_prompt() places at the start of every request. It
# fixes the assistant persona ("Mysteria"), the patient/guardian/doctor details
# and the behaviour rules the model is asked to follow.
# NOTE(review): the personal details below look like sample/placeholder data;
# confirm before real use — patient data hard-coded in source is difficult to
# update per patient and to audit.
basic_prompt = '''
You are a Virtual Assistant designed for assisting Alzheimer's Patients. Your name is Mysteria. You are currently assigned to a patient named Loki. The Guardian assigned to this patient is Sylvie. The Doctor assigned to the patient is Kang.

Details of Patient:
DOB: 14/04/1965
Last name: Odinson

Details of Guardian:
Name: Sylvie
Relation: Wife

Details of Doctor:
Name: Kang
Field: Psychology
Experience: 5 years
Office: End of Time
Next Appointment: 12/01/2024, 6:30 pm

You should respond only when "Mysteria" is announced.
When you are asked to shut up, You should stop responding, until you are awakened.
When you are awakened, Try to maintain a Conversation.

Maintain a conversation with the user, and, Answer the Questions properly.

'''
# Turn markers used by build_prompt(): '[Q]' prefixes user turns and '[A]'
# prefixes assistant turns; 'stop_query' is appended right after the system
# prompt and is currently empty. generate_response() truncates the model output
# at the first '[' so that a generated follow-up marker ends the reply.
tags = {'user':'[Q]', 'assistant':"[A]", 'stop_query':''}

def build_prompt(query, conversation, system_prompt=None, role_tags=None):
	"""Assemble the complete text prompt sent to the model.

	The prompt is the system prompt (plus the optional 'stop_query' marker),
	then one line per prior turn (role tag immediately followed by the
	content), then the new user query, and finally an open assistant tag
	for the model to continue from.

	Args:
		query: The user's latest message. (Shadows the module-level
			``query`` function inside this body; kept for caller
			compatibility.)
		conversation: Iterable of dicts with 'role' and 'content' keys;
			each role must be a key of ``role_tags``.
		system_prompt: Text placed before the conversation. Defaults to
			the module-level ``basic_prompt``.
		role_tags: Mapping with 'user' and 'assistant' entries and an
			optional 'stop_query' entry. Defaults to the module-level
			``tags``.

	Returns:
		A ``(prompt, len(prompt))`` tuple. The length lets the caller slice
		the echoed prompt off the model's generated text.

	Raises:
		KeyError: if a turn's role, 'user' or 'assistant' is missing from
			``role_tags``, or a turn lacks 'role'/'content'.
	"""
	if system_prompt is None:
		system_prompt = basic_prompt
	if role_tags is None:
		role_tags = tags
	# Collect the pieces and join once instead of repeatedly re-copying a
	# growing string with +=.
	parts = [system_prompt + role_tags.get('stop_query', '')]
	parts.extend(role_tags[msg['role']] + msg['content'] for msg in conversation)
	parts.append(role_tags['user'] + query)
	parts.append(role_tags['assistant'])
	prompt = '\n'.join(parts)
	return prompt, len(prompt)

def query(payload, timeout=60):
	"""POST *payload* as JSON to ``API_URL`` and return the decoded JSON reply.

	Args:
		payload: JSON-serialisable request body.
		timeout: Seconds ``requests`` waits for the endpoint before raising.
			Without a timeout ``requests`` can wait indefinitely, which
			would freeze the assistant if the endpoint stalls. Pass ``None``
			to restore the old unbounded wait.

	Returns:
		The response body parsed as JSON. Non-2xx responses are not raised
		here; their JSON body (presumably an error object) is returned as-is.

	Raises:
		requests.exceptions.Timeout: if the endpoint does not answer in time.
		requests.exceptions.RequestException: for other transport failures.
		An error from ``Response.json`` if the body is not valid JSON.
	"""
	response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
	return response.json()

def generate_response(inputs, conversation, max_new_tokens=50):
	"""Ask the model for the assistant's next turn and return just that turn.

	Args:
		inputs: The user's latest message.
		conversation: Prior turns as dicts with 'role' and 'content' keys,
			in the form ``build_prompt`` expects.
		max_new_tokens: Upper bound on tokens the model may generate
			(defaults to 50, the previously hard-coded value).

	Returns:
		The newly generated text, truncated before the first '[' so that
		any follow-up turn marker the model produces (e.g. '[Q]') and
		everything after it is dropped.

	Raises:
		KeyError / IndexError / TypeError: if the endpoint's reply is not a
			list whose first item has a 'generated_text' key (for example an
			error object returned by ``query``).
	"""
	prompt, prompt_length = build_prompt(inputs, conversation)
	payload = {
		'inputs': prompt,
		'parameters': {'max_new_tokens': max_new_tokens},
	}
	model_response = query(payload)
	generated_text = model_response[0]['generated_text']
	# Slicing by the prompt length assumes the endpoint echoes the prompt at
	# the start of generated_text — NOTE(review): confirm for the deployed model.
	reply = generated_text[prompt_length:]
	# partition() yields the whole string when '[' is absent, replacing the
	# old bare try/except around str.index.
	return reply.partition('[')[0]
	
def audio_response(response):
	"""Synthesize *response* as speech and return an autoplaying HTML audio tag.

	The MP3 produced by gTTS is embedded in the tag as a base64 data URI, so
	the returned string can be rendered as HTML without serving a file.

	Args:
		response: Text to speak.

	Returns:
		An ``<audio autoplay>`` element string with the MP3 inlined.
	"""
	# Render into memory rather than a fixed "response_audio.mp3" in the
	# working directory: a shared path lets concurrent calls overwrite each
	# other's audio and leaves a stray file behind.
	buffer = io.BytesIO()
	gTTS(response).write_to_fp(buffer)
	audio_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
	# audio/mpeg is the registered MIME type for MP3 (RFC 3003); <audio> needs
	# an explicit end tag in HTML, which the original omitted.
	return f'<audio autoplay="true" src="data:audio/mpeg;base64,{audio_base64}"></audio>'