import requests
import json
import base64

class BhashiniClient:
    """
    A client for interacting with Bhashini's ASR, NMT, and TTS services.

    On construction the client fetches the ULCA pipeline configuration for
    ``pipeline_id`` and caches, per task, the languages and service IDs the
    configuration advertises. Inference methods validate the requested
    language(s) against that cache before posting to ``INFERENCE_ENDPOINT``.

    Methods:
        list_available_languages(task_type): Lists available languages for a given task.
        get_supported_voices(source_language): Gets supported genders for TTS in a language.
        asr(audio_content, source_language, audio_format='wav', sampling_rate=16000): Performs ASR.
        translate(text, source_language, target_language): Translates text from source to target language.
        tts(text, source_language, gender='female', sampling_rate=8000): Performs TTS.
    """

    PIPELINE_CONFIG_ENDPOINT = "https://meity-auth.ulcacontrib.org/ulca/apis/v0/model/getModelsPipeline"
    INFERENCE_ENDPOINT = "https://dhruva-api.bhashini.gov.in/services/inference/pipeline"
    PIPELINE_ID = "64392f96daac500b55c543cd"
    # (connect, read) timeout in seconds applied to every HTTP request.
    # Without a timeout, `requests` can block forever on an unresponsive server.
    DEFAULT_TIMEOUT = (10, 120)
    # Task types requested from, and parsed out of, the pipeline configuration.
    SUPPORTED_TASKS = ('asr', 'translation', 'tts')

    def __init__(self, user_id, api_key, pipeline_id=PIPELINE_ID, timeout=DEFAULT_TIMEOUT):
        """
        Initializes the BhashiniClient with user credentials and pipeline ID.

        This performs a network request to fetch the pipeline configuration.

        Args:
            user_id (str): Your user ID.
            api_key (str): Your ULCA API key.
            pipeline_id (str): The pipeline ID.
            timeout (float | tuple | None): Timeout passed to every HTTP request,
                either seconds or a ``(connect, read)`` tuple. ``None`` disables
                the timeout entirely.

        Raises:
            requests.HTTPError: If the pipeline configuration request returns an error status.
            requests.RequestException: On network failures, including timeouts.
            KeyError: If the configuration response lacks an expected field.
        """
        self.user_id = user_id
        self.api_key = api_key
        self.pipeline_id = pipeline_id
        # Must be set before the first request below.
        self.timeout = timeout
        self.headers = {
            "Content-Type": "application/json",
            "userID": self.user_id,
            "ulcaApiKey": self.api_key
        }
        self.config = self._get_pipeline_config()
        self.pipeline_data = self._parse_pipeline_config(self.config)
        self.inference_api_key = self.pipeline_data['inferenceApiKey']

    def _post(self, url, headers, payload):
        """
        POST ``payload`` as a JSON body to ``url`` with the client timeout applied.

        Returns:
            requests.Response: The raw response; status is not checked here.
        """
        return requests.post(
            url,
            headers=headers,
            data=json.dumps(payload),
            # Fall back to the class default for instances whose __init__
            # (e.g. a subclass override) never set a timeout.
            timeout=getattr(self, 'timeout', self.DEFAULT_TIMEOUT),
        )

    def _get_pipeline_config(self):
        """
        Retrieves the pipeline configuration.

        Returns:
            dict: The pipeline configuration.

        Raises:
            requests.HTTPError: If the request returns an error status.
            requests.RequestException: On network failures, including timeouts.
        """
        payload = {
            "pipelineTasks": [{"taskType": task} for task in self.SUPPORTED_TASKS],
            "pipelineRequestConfig": {
                "pipelineId": self.pipeline_id
            }
        }
        response = self._post(self.PIPELINE_CONFIG_ENDPOINT, self.headers, payload)
        response.raise_for_status()
        return response.json()

    def _parse_pipeline_config(self, config):
        """
        Parses the pipeline configuration and extracts necessary information.

        Args:
            config (dict): The pipeline configuration.

        Returns:
            dict: Parsed pipeline data with these keys:
                'asr' / 'tts': ``{source_language: [service_info, ...]}``.
                    TTS entries also carry 'supportedVoices'.
                'translation': ``{source_language: {target_language: [service_info, ...]}}``.
                'inferenceApiKey': the inference API key value.
                'callbackUrl': the callback URL from the configuration.
            Task types other than ``SUPPORTED_TASKS`` are ignored.

        Raises:
            KeyError: If a required field is missing from ``config``.
        """
        endpoint = config['pipelineInferenceAPIEndPoint']
        pipeline_data = {
            'asr': {},
            'tts': {},
            'translation': {},
            'inferenceApiKey': endpoint['inferenceApiKey']['value'],
            'callbackUrl': endpoint['callbackUrl'],
        }

        for pipeline in config['pipelineResponseConfig']:
            task_type = pipeline['taskType']
            if task_type not in self.SUPPORTED_TASKS:
                continue
            by_source = pipeline_data[task_type]

            for language_config in pipeline['config']:
                language = language_config['language']
                source_language = language['sourceLanguage']

                if task_type == 'translation':
                    # Translation services are keyed by (source, target) pair.
                    target_language = language['targetLanguage']
                    targets = by_source.setdefault(source_language, {})
                    targets.setdefault(target_language, []).append({
                        'serviceId': language_config['serviceId'],
                        'sourceScriptCode': language.get('sourceScriptCode'),
                        'targetScriptCode': language.get('targetScriptCode'),
                    })
                    continue

                language_info = {
                    'serviceId': language_config['serviceId'],
                    'sourceScriptCode': language.get('sourceScriptCode'),
                }
                if task_type == 'tts':
                    language_info['supportedVoices'] = language_config.get('supportedVoices', [])
                by_source.setdefault(source_language, []).append(language_info)

        return pipeline_data

    def list_available_languages(self, task_type):
        """
        Lists the available languages for the specified task.

        Args:
            task_type (str): The task type ('asr', 'translation', or 'tts').

        Returns:
            list or dict: A list of available languages for 'asr'/'tts'. For
            'translation', a dict mapping each source language to the list of
            its target languages.

        Raises:
            ValueError: If an invalid task type is provided.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            asr_languages = client.list_available_languages('asr')
            print("Available ASR Languages:", asr_languages)

            translation_languages = client.list_available_languages('translation')
            print("Available Translation Languages:", translation_languages)
        """
        if task_type not in self.SUPPORTED_TASKS:
            raise ValueError("Invalid task type. Choose from 'asr', 'translation', or 'tts'.")

        services = self.pipeline_data[task_type]
        if task_type == 'translation':
            return {src_lang: list(targets) for src_lang, targets in services.items()}
        return list(services)

    def _first_service(self, task_type, label, source_language):
        """
        Return the first cached service entry for an ASR/TTS ``source_language``.

        Args:
            task_type (str): 'asr' or 'tts'.
            label (str): Task name used in the error message (e.g. 'ASR').
            source_language (str): The language code to look up.

        Raises:
            ValueError: If the task is not available for ``source_language``.
        """
        services = self.pipeline_data[task_type]
        if source_language not in services:
            available_languages = ', '.join(self.list_available_languages(task_type))
            raise ValueError(
                f"{label} not supported for language '{source_language}'. "
                f"Available languages: {available_languages}"
            )
        return services[source_language][0]

    def get_supported_voices(self, source_language):
        """
        Returns the supported genders for TTS in the specified language.

        Only the first TTS service listed for the language is consulted, which
        is the same service ``tts`` uses.

        Args:
            source_language (str): The language code (e.g., 'hi' for Hindi).

        Returns:
            list: A new list of supported genders (e.g., ['male', 'female']).
                It is empty when the configuration lists none.

        Raises:
            ValueError: If TTS is not supported for the language.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            voices = client.get_supported_voices('hi')
            print("Supported voices for Hindi TTS:", voices)
        """
        service_info = self._first_service('tts', 'TTS', source_language)
        # Return a copy so callers cannot mutate the cached configuration.
        return list(service_info.get('supportedVoices', []))

    def _run_inference(self, task_type, task_config, input_data):
        """
        Post a single-task pipeline request to ``INFERENCE_ENDPOINT``.

        Args:
            task_type (str): The pipeline task type ('asr', 'translation', 'tts').
            task_config (dict): The task's ``config`` object (language, serviceId, ...).
            input_data (dict): The request's ``inputData`` object.

        Returns:
            dict: The decoded JSON response.

        Raises:
            Exception: If the API returns an HTTP error status.
            requests.RequestException: On network failures, including timeouts.
        """
        payload = {
            "pipelineTasks": [
                {
                    "taskType": task_type,
                    "config": task_config
                }
            ],
            "inputData": input_data
        }
        headers = {
            'Accept': '*/*',
            'Authorization': self.inference_api_key,
            'Content-Type': 'application/json'
        }
        response = self._post(self.INFERENCE_ENDPOINT, headers, payload)
        self._handle_response_errors(response)
        return response.json()

    def asr(self, audio_content, source_language, audio_format='wav', sampling_rate=16000):
        """
        Performs Automatic Speech Recognition on the provided audio content.

        Args:
            audio_content (bytes): The audio content in bytes.
            source_language (str): The language code of the audio (e.g., 'hi' for Hindi).
            audio_format (str): Format of the audio content, e.g. 'wav', 'mp3', 'flac', or 'ogg'.
            sampling_rate (int): The sampling rate of the audio in Hz.

        Returns:
            dict: The ASR response from the API.

        Raises:
            ValueError: If the language is not supported.
            Exception: If the API returns an HTTP error status.
            requests.RequestException: On network failures, including timeouts.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            with open('audio.wav', 'rb') as f:
                audio_content = f.read()
            asr_result = client.asr(audio_content, source_language='hi', audio_format='wav')
            print("ASR Result:", asr_result)
        """
        service_id = self._first_service('asr', 'ASR', source_language)['serviceId']
        # The API expects the raw audio bytes as a base64 text string.
        encoded_audio = base64.b64encode(audio_content).decode('utf-8')

        return self._run_inference(
            'asr',
            {
                "language": {
                    "sourceLanguage": source_language
                },
                "serviceId": service_id,
                "audioFormat": audio_format,
                "samplingRate": sampling_rate
            },
            {
                "audio": [
                    {
                        "audioContent": encoded_audio
                    }
                ]
            },
        )

    def translate(self, text, source_language, target_language):
        """
        Translates the provided text from the source language to the target language.

        Args:
            text (str): The text to translate.
            source_language (str): The source language code.
            target_language (str): The target language code.

        Returns:
            dict: The translation response from the API.

        Raises:
            ValueError: If the language pair is not supported.
            Exception: If the API returns an HTTP error status.
            requests.RequestException: On network failures, including timeouts.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            translation_result = client.translate(
                'मेरा नाम विहिर है।',
                source_language='hi',
                target_language='gu'
            )
            print("Translation Result:", translation_result)
        """
        translation = self.pipeline_data['translation']
        if source_language not in translation:
            available_languages = ', '.join(translation)
            raise ValueError(
                f"Translation not supported from language '{source_language}'. "
                f"Available source languages: {available_languages}"
            )

        targets = translation[source_language]
        if target_language not in targets:
            available_targets = ', '.join(targets)
            raise ValueError(
                f"Translation from '{source_language}' to '{target_language}' not supported. "
                f"Available target languages for '{source_language}': {available_targets}"
            )

        service_id = targets[target_language][0]['serviceId']

        return self._run_inference(
            'translation',
            {
                "language": {
                    "sourceLanguage": source_language,
                    "targetLanguage": target_language
                },
                "serviceId": service_id
            },
            {
                "input": [
                    {
                        "source": text
                    }
                ]
            },
        )

    def tts(self, text, source_language, gender='female', sampling_rate=8000):
        """
        Converts the provided text to speech in the specified language.

        Args:
            text (str): The text to convert to speech.
            source_language (str): The language code of the text.
            gender (str): The desired voice gender ('male' or 'female').
            sampling_rate (int): The sampling rate in Hz.

        Returns:
            dict: The TTS response from the API.

        Raises:
            ValueError: If the language or gender is not supported.
            Exception: If the API returns an HTTP error status.
            requests.RequestException: On network failures, including timeouts.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            tts_result = client.tts(
                'હેલો વર્લ્ડ',
                source_language='gu',
                gender='female'
            )
            # Save the audio output
            audio_base64 = tts_result['pipelineResponse'][0]['audio'][0]['audioContent']
            audio_data = base64.b64decode(audio_base64)
            with open('output_audio.wav', 'wb') as f:
                f.write(audio_data)
        """
        service_info = self._first_service('tts', 'TTS', source_language)
        service_id = service_info['serviceId']
        supported_voices = service_info.get('supportedVoices', [])

        if gender not in ['male', 'female']:
            raise ValueError("Gender must be 'male' or 'female'.")

        # An empty voice list means the configuration did not restrict voices.
        if supported_voices and gender not in supported_voices:
            available_genders = ', '.join(supported_voices)
            raise ValueError(
                f"Gender '{gender}' not supported for language '{source_language}'. "
                f"Available genders: {available_genders}"
            )

        return self._run_inference(
            'tts',
            {
                "language": {
                    "sourceLanguage": source_language
                },
                "serviceId": service_id,
                "gender": gender,
                "samplingRate": sampling_rate
            },
            {
                "input": [
                    {
                        "source": text
                    }
                ]
            },
        )

    @staticmethod
    def _extract_error_message(response):
        """
        Best-effort extraction of a human-readable error message from ``response``.

        Returns the JSON object's 'message' field when present. Otherwise it
        returns the raw response text, or a generic message if the body is empty.
        """
        try:
            error_info = response.json()
        except ValueError:
            # Covers json.JSONDecodeError, simplejson's JSONDecodeError and
            # requests.exceptions.JSONDecodeError, which all subclass ValueError.
            return response.text
        if isinstance(error_info, dict) and 'message' in error_info:
            return error_info['message']
        # Valid JSON but not an object with 'message' (e.g. a list or a
        # different error schema): keep the server's text rather than losing it.
        return response.text or 'An error occurred.'

    def _handle_response_errors(self, response):
        """
        Handles errors in the response.

        Args:
            response (requests.Response): The response object.

        Raises:
            Exception: If the response has an HTTP error status. The message
                includes the server's error message when one can be extracted,
                and the original ``requests.HTTPError`` is chained as the cause.
        """
        try:
            response.raise_for_status()
        except requests.HTTPError as http_err:
            error_message = self._extract_error_message(response)
            raise Exception(f"HTTP error occurred: {error_message}") from http_err