"""Call API providers."""

import json
import os
import random
import re
from typing import Optional
import time

import requests

from src.utils import build_logger


# Module-level logger shared by the provider iterators below for request and
# error logging.
logger = build_logger("gradio_web_server", "gradio_web_server.log")


def get_api_provider_stream_iter(
    conv,
    model_name,
    model_api_dict,
    temperature,
    top_p,
    max_new_tokens,
    state,
):
    """Return the streaming iterator for the provider named by ``api_type``.

    The conversation is converted to the message format the selected provider
    expects and handed to the matching ``*_api_stream_iter`` function of this
    module.

    Args:
        conv: Conversation object; must provide the ``to_*_api_messages``
            converter used by the selected branch (and ``messages`` /
            ``system_message`` for the ``openai_assistant`` / ``yandexgpt``
            branches).
        model_name: Display/registered model name. Some branches forward this
            name, others forward ``model_api_dict["model_name"]`` instead.
        model_api_dict: Provider configuration. ``"api_type"`` is required;
            other keys (``"model_name"``, ``"api_base"``, ``"api_key"``, ...)
            are read depending on the branch. ``"vision-arena"`` is optional
            and treated as false when absent.
        temperature: Sampling temperature forwarded to the provider.
        top_p: Nucleus-sampling value forwarded to the provider (unused by
            the ``yandexgpt`` and ``openai_assistant`` branches).
        max_new_tokens: Output-length limit forwarded to the provider.
        state: Session state; only forwarded to the ``openai_assistant``
            branch.

    Returns:
        The iterator produced by the selected provider function.

    Raises:
        KeyError: If ``model_api_dict`` lacks a key the selected branch needs.
        NotImplementedError: If ``api_type`` is not one of the supported
            values.
    """
    api_type = model_api_dict["api_type"]
    # Text-only registrations may omit the flag entirely; treat that as "off"
    # instead of failing with a KeyError.
    is_vision = bool(model_api_dict.get("vision-arena", False))

    if api_type == "openai":
        if is_vision:
            prompt = conv.to_openai_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif api_type == "openai_assistant":
        # The second-to-last entry of conv.messages is used as the prompt.
        last_prompt = conv.messages[-2][1]
        stream_iter = openai_assistant_api_stream_iter(
            state,
            last_prompt,
            assistant_id=model_api_dict["assistant_id"],
            api_key=model_api_dict["api_key"],
        )
    elif api_type == "anthropic":
        if is_vision:
            prompt = conv.to_anthropic_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = anthropic_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif api_type == "anthropic_message":
        if is_vision:
            prompt = conv.to_anthropic_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = anthropic_message_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif api_type == "anthropic_message_vertex":
        if is_vision:
            prompt = conv.to_anthropic_vision_api_messages()
        else:
            prompt = conv.to_openai_api_messages()
        stream_iter = anthropic_message_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            vertex_ai=True,
        )
    elif api_type == "gemini":
        prompt = conv.to_gemini_api_messages()
        stream_iter = gemini_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_key=model_api_dict["api_key"],
        )
    elif api_type == "bard":
        prompt = conv.to_openai_api_messages()
        stream_iter = bard_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            api_key=model_api_dict["api_key"],
        )
    elif api_type == "mistral":
        prompt = conv.to_openai_api_messages()
        stream_iter = mistral_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif api_type == "nvidia":
        prompt = conv.to_openai_api_messages()
        stream_iter = nvidia_api_stream_iter(
            model_name,
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            model_api_dict["api_base"],
        )
    elif api_type == "ai2":
        prompt = conv.to_openai_api_messages()
        stream_iter = ai2_api_stream_iter(
            model_name,
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif api_type == "vertex":
        prompt = conv.to_vertex_api_messages()
        stream_iter = vertex_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif api_type == "yandexgpt":
        # note: top_p parameter is unused by yandexgpt

        messages = []
        if conv.system_message:
            messages.append({"role": "system", "text": conv.system_message})
        # Entries whose text is None are skipped.
        messages += [
            {"role": role, "text": text}
            for role, text in conv.messages
            if text is not None
        ]

        # A configured fixed temperature overrides the caller's value.
        fixed_temperature = model_api_dict.get("fixed_temperature")
        if fixed_temperature is not None:
            temperature = fixed_temperature

        stream_iter = yandexgpt_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict.get("api_key"),
            folder_id=model_api_dict.get("folder_id"),
        )
    elif api_type == "cohere":
        messages = conv.to_openai_api_messages()
        stream_iter = cohere_api_stream_iter(
            client_name=model_api_dict.get("client_name", "FastChat"),
            model_id=model_api_dict["model_name"],
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif api_type == "reka":
        messages = conv.to_reka_api_messages()
        stream_iter = reka_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    else:
        raise NotImplementedError(f"Unsupported api_type: {api_type!r}")

    return stream_iter


def openai_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
):
    """Stream a chat completion through the ``openai`` client.

    Args:
        model_name: Model to request. If the name contains ``"azure"`` an
            ``AzureOpenAI`` client is used, otherwise a regular ``OpenAI``
            client.
        messages: Chat messages; each ``content`` is either a string or a
            list of typed parts (entries with a ``"type"`` key).
        temperature: Sampling temperature sent with the request.
        top_p: Logged with the request parameters only.
        max_new_tokens: Sent as ``max_tokens``.
        api_base: Endpoint URL; defaults to ``https://api.openai.com/v1``.
        api_key: API key; falls back to the ``OPENAI_API_KEY`` environment
            variable.

    Yields:
        ``{"text": <text accumulated so far>, "error_code": 0}`` for every
        streamed chunk that carries at least one choice.

    Raises:
        KeyError: If no ``api_key`` is given and ``OPENAI_API_KEY`` is unset.
    """
    import openai

    api_key = api_key or os.environ["OPENAI_API_KEY"]

    if "azure" in model_name:
        client = openai.AzureOpenAI(
            api_version="2023-07-01-preview",
            azure_endpoint=api_base or "https://api.openai.com/v1",
            api_key=api_key,
        )
    else:
        client = openai.OpenAI(
            base_url=api_base or "https://api.openai.com/v1",
            api_key=api_key,
            timeout=180,
        )

    # Build a text-only copy of the messages for logging: non-text parts are
    # dropped so only "text" entries end up in the log.
    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # NOTE(review): top_p is logged above but not forwarded to the API call;
    # confirm whether that omission is intentional before changing it.
    res = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        max_tokens=max_new_tokens,
        stream=True,
    )
    text = ""
    for chunk in res:
        # Chunks without choices are skipped.
        if chunk.choices:
            # delta.content may be None; treat it as empty text.
            text += chunk.choices[0].delta.content or ""
            yield {
                "text": text,
                "error_code": 0,
            }


def upload_openai_file_to_gcs(file_id):
    """Copy an OpenAI file into the ``arena_user_content`` GCS bucket.

    The file content is fetched through ``openai.files.content``, uploaded
    under a blob named after ``file_id``, made public, and the blob's public
    URL is returned.
    """
    import openai
    from google.cloud import storage

    gcs = storage.Client()

    payload = openai.files.content(file_id)
    # Store the content under the file id and expose it publicly.
    target = gcs.get_bucket("arena_user_content").blob(f"{file_id}")
    target.upload_from_string(payload.read())
    target.make_public()
    return target.public_url


def openai_assistant_api_stream_iter(
    state,
    prompt,
    assistant_id,
    api_key=None,
):
    """Stream an OpenAI Assistants run on the thread stored in ``state``.

    A thread is created (and its id saved on ``state.oai_thread_id``) when the
    state has none yet. ``prompt`` is appended as a user message, a streaming
    run is started over HTTP, and the server-sent events are folded into one
    growing text.

    Args:
        state: Object with a readable/writable ``oai_thread_id`` attribute.
        prompt: User message content to append to the thread.
        assistant_id: Assistant to run on the thread.
        api_key: API key; falls back to the ``OPENAI_API_KEY`` environment
            variable.

    Yields:
        ``{"text": <all content parts joined by newlines>, "error_code": 0}``
        after every processed content delta, or a single
        ``{"text": "**API REQUEST ERROR** ...", "error_code": 1}`` when an
        event reports ``status == "failed"`` (the generator then stops).
    """
    import openai

    api_key = api_key or os.environ["OPENAI_API_KEY"]
    client = openai.OpenAI(base_url="https://api.openai.com/v1", api_key=api_key)

    if state.oai_thread_id is None:
        logger.info("==== create thread ====")
        thread = client.beta.threads.create()
        state.oai_thread_id = thread.id
    logger.info(f"==== thread_id ====\n{state.oai_thread_id}")
    thread_message = client.beta.threads.messages.with_raw_response.create(
        state.oai_thread_id,
        role="user",
        content=prompt,
        timeout=3,
    )
    # logger.info(f"header {thread_message.headers}")
    thread_message = thread_message.parse()
    # Make requests
    gen_params = {
        "assistant_id": assistant_id,
        "thread_id": state.oai_thread_id,
        "message": prompt,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = requests.post(
        f"https://api.openai.com/v1/threads/{state.oai_thread_id}/runs",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "OpenAI-Beta": "assistants=v1",
        },
        json={"assistant_id": assistant_id, "stream": True},
        timeout=30,
        stream=True,
    )

    # One slot per content part seen so far, across all messages of the run:
    # list_of_text holds the rendered text, list_of_raw_text the unmodified
    # model text (annotation indices refer to the latter).
    list_of_text = []
    list_of_raw_text = []
    # Number of slots taken by already-completed messages; added to each
    # delta's per-message index.
    offset_idx = 0
    full_ret_text = ""
    # Maps a citation marker such as "【3】" to its 1-based display number.
    idx_mapping = {}
    for line in res.iter_lines():
        if not line:
            continue
        data = line.decode("utf-8")
        # logger.info("data:", data)
        if data.endswith("[DONE]"):
            break
        if data.startswith("event"):
            # Everything after the first ":" is the event name.
            event = data.partition(":")[2].strip()
            if event == "thread.message.completed":
                # The next message's part indices restart at 0, so they must
                # be shifted by the total number of slots used so far.
                # (Accumulating with "+=" over-counted from the second
                # completed message on and indexed past the end of the list.)
                offset_idx = len(list_of_text)
            continue
        # Strip the 6-character "data: " prefix before decoding the payload.
        data = json.loads(data[6:])

        if data.get("status") == "failed":
            yield {
                "text": f"**API REQUEST ERROR** Reason: {data['last_error']['message']}",
                "error_code": 1,
            }
            return

        if data.get("status") == "completed":
            logger.info(f"[debug]: {data}")

        # Only message deltas carry content; payloads without an "object"
        # field are ignored as well.
        if data.get("object") != "thread.message.delta":
            continue

        for delta in data["delta"]["content"]:
            text_index = delta["index"] + offset_idx
            # Grow both lists until the slot exists (covers index gaps too).
            while len(list_of_text) <= text_index:
                list_of_text.append("")
                list_of_raw_text.append("")

            text = list_of_text[text_index]
            raw_text = list_of_raw_text[text_index]

            if delta["type"] == "text":
                # text, url_citation or file_path
                content = delta["text"]
                if "annotations" in content and len(content["annotations"]) > 0:
                    annotations = content["annotations"]

                    # Annotation indices refer to the raw text; cur_offset
                    # tracks how much earlier replacements shifted it.
                    cur_offset = 0
                    raw_text_copy = raw_text
                    for anno in annotations:
                        if anno["type"] == "url_citation":
                            anno_text = anno["text"]
                            # Skip citations whose marker was never seen in
                            # the streamed text.
                            if anno_text not in idx_mapping:
                                continue
                            citation_number = idx_mapping[anno_text]

                            start_idx = anno["start_index"] + cur_offset
                            end_idx = anno["end_index"] + cur_offset
                            url = anno["url_citation"]["url"]

                            citation = f" [[{citation_number}]]({url})"
                            raw_text_copy = (
                                raw_text_copy[:start_idx]
                                + citation
                                + raw_text_copy[end_idx:]
                            )
                            cur_offset += len(citation) - (end_idx - start_idx)
                        elif anno["type"] == "file_path":
                            file_public_url = upload_openai_file_to_gcs(
                                anno["file_path"]["file_id"]
                            )
                            raw_text_copy = raw_text_copy.replace(
                                anno["text"], f"{file_public_url}"
                            )
                    # The annotated raw text replaces the rendered text.
                    text = raw_text_copy
                else:
                    text_content = content["value"]
                    raw_text += text_content

                    # re-index citation number
                    pattern = r"【\d+】"
                    matches = re.findall(pattern, content["value"])
                    if len(matches) > 0:
                        for match in matches:
                            if match not in idx_mapping:
                                idx_mapping[match] = len(idx_mapping) + 1
                            citation_number = idx_mapping[match]
                            text_content = text_content.replace(
                                match, f" [{citation_number}]"
                            )
                    text += text_content
                    # yield {"text": text, "error_code": 0}
            elif delta["type"] == "image_file":
                image_public_url = upload_openai_file_to_gcs(
                    delta["image_file"]["file_id"]
                )
                # raw_text += f"![image]({image_public_url})"
                text += f"![image]({image_public_url})"

            list_of_text[text_index] = text
            list_of_raw_text[text_index] = raw_text

            full_ret_text = "\n".join(list_of_text)
            yield {"text": full_ret_text, "error_code": 0}


def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
    """Stream a completion from Anthropic's ``completions`` endpoint.

    The API key is read from the ``ANTHROPIC_API_KEY`` environment variable.
    Yields ``{"text": <text so far>, "error_code": 0}`` for every streamed
    chunk.
    """
    import anthropic

    client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    # Log the request parameters before calling the API.
    request_summary = dict(
        model=model_name,
        prompt=prompt,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
    )
    logger.info(f"==== request ====\n{request_summary}")

    stream = client.completions.create(
        prompt=prompt,
        stop_sequences=[anthropic.HUMAN_PROMPT],
        max_tokens_to_sample=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        model=model_name,
        stream=True,
    )

    completion_so_far = ""
    for event in stream:
        completion_so_far += event.completion
        yield {"text": completion_so_far, "error_code": 0}


def anthropic_message_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    vertex_ai=False,
):
    """Stream a reply from Anthropic's Messages API (direct or via Vertex AI).

    Args:
        model_name: Model to request.
        messages: Chat messages; each ``content`` is either a string or a
            list of typed parts. A leading ``"system"`` message is removed
            from the list and passed as the ``system`` prompt instead.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value.
        max_new_tokens: Sent as ``max_tokens``.
        vertex_ai: When true, use ``AnthropicVertex`` configured from the
            ``GCP_LOCATION`` / ``GCP_PROJECT_ID`` environment variables;
            otherwise use ``Anthropic`` with ``ANTHROPIC_API_KEY``.

    Yields:
        ``{"text": <text so far>, "error_code": 0}`` for every streamed text
        chunk.
    """
    import anthropic

    if vertex_ai:
        client = anthropic.AnthropicVertex(
            region=os.environ["GCP_LOCATION"],
            project_id=os.environ["GCP_PROJECT_ID"],
            max_retries=5,
        )
    else:
        client = anthropic.Anthropic(
            api_key=os.environ["ANTHROPIC_API_KEY"],
            max_retries=5,
        )

    # Text-only copy of the messages for logging: non-text parts are dropped.
    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    # Make requests for logging
    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    system_prompt = ""
    # Guard against an empty message list before peeking at the first entry.
    if messages and messages[0]["role"] == "system":
        system_content = messages[0]["content"]
        if isinstance(system_content, dict):
            system_prompt = system_content["text"]
        elif isinstance(system_content, str):
            system_prompt = system_content
        # Any other content shape leaves the system prompt empty.
        # remove system prompt
        messages = messages[1:]

    text = ""
    with client.messages.stream(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_new_tokens,
        messages=messages,
        model=model_name,
        system=system_prompt,
    ) as stream:
        for chunk in stream.text_stream:
            text += chunk
            yield {
                "text": text,
                "error_code": 0,
            }


def gemini_api_stream_iter(
    model_name, messages, temperature, top_p, max_new_tokens, api_key=None
):
    """Stream a Gemini chat reply through ``google.generativeai``.

    Args:
        model_name: Model to request.
        messages: Chat messages with ``"role"`` and ``"content"`` keys. All
            but the last become chat history (a ``"system"`` message becomes
            the system instruction); the last one is sent as the new message.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value.
        max_new_tokens: Sent as ``max_output_tokens``.
        api_key: API key; falls back to the ``GEMINI_API_KEY`` environment
            variable.

    Yields:
        ``{"text": <text so far>, "error_code": 0}`` per streamed chunk, or a
        single ``{"text": "**API REQUEST ERROR** ...", "error_code": 1}`` if
        consuming the stream raises.
    """
    import google.generativeai as genai  # pip install google-generativeai

    if api_key is None:
        api_key = os.environ["GEMINI_API_KEY"]
    genai.configure(api_key=api_key)

    generation_config = {
        "temperature": temperature,
        "max_output_tokens": max_new_tokens,
        "top_p": top_p,
    }
    params = {
        "model": model_name,
        "prompt": messages,
    }
    params.update(generation_config)
    logger.info(f"==== request ====\n{params}")

    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
    ]

    history = []
    system_prompt = None
    for message in messages[:-1]:
        if message["role"] == "system":
            system_prompt = message["content"]
            continue
        history.append({"role": message["role"], "parts": message["content"]})

    model = genai.GenerativeModel(
        model_name=model_name,
        system_instruction=system_prompt,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    convo = model.start_chat(history=history)
    response = convo.send_message(messages[-1]["content"], stream=True)

    # Remember the most recent chunk so the error handler can report its
    # candidates; it stays None when the stream fails before the first chunk.
    last_chunk = None
    try:
        text = ""
        for chunk in response:
            last_chunk = chunk
            text += chunk.candidates[0].content.parts[0].text
            yield {
                "text": text,
                "error_code": 0,
            }
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        # Without any chunk there are no candidates to show, so fall back to
        # the exception itself instead of raising on an unbound variable.
        reason = last_chunk.candidates if last_chunk is not None else e
        yield {
            "text": f"**API REQUEST ERROR** Reason: {reason}.",
            "error_code": 1,
        }


def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None):
    """Call the ``generateMessage`` REST endpoint and simulate streaming.

    The endpoint is called once (non-streaming); the reply is then yielded in
    growing prefixes to imitate token streaming.

    Args:
        model_name: Model name placed in the request URL.
        conv: List of ``{"role", "content"}`` messages; only ``"user"`` and
            ``"assistant"`` roles are accepted.
        temperature: Ignored (not supported).
        top_p: Ignored (not supported).
        api_key: API key; falls back to the ``BARD_API_KEY`` environment
            variable.

    Yields:
        ``{"text": <prefix of the reply>, "error_code": 0}`` on success, or a
        single ``{"text": "**API REQUEST ERROR** ...", "error_code": 1}`` when
        the request fails, returns a non-200 status, or has no candidates.

    Raises:
        ValueError: If a message has a role other than user/assistant.
    """
    del top_p  # not supported
    del temperature  # not supported

    if api_key is None:
        api_key = os.environ["BARD_API_KEY"]

    # convert conv to conv_bard
    conv_bard = []
    for turn in conv:
        if turn["role"] == "user":
            conv_bard.append({"author": "0", "content": turn["content"]})
        elif turn["role"] == "assistant":
            conv_bard.append({"author": "1", "content": turn["content"]})
        else:
            raise ValueError(f"Unsupported role: {turn['role']}")

    params = {
        "model": model_name,
        "prompt": conv_bard,
    }
    logger.info(f"==== request ====\n{params}")

    try:
        res = requests.post(
            f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}",
            json={
                "prompt": {
                    "messages": conv_bard,
                },
            },
            timeout=30,
        )
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}.",
            "error_code": 1,
        }
        # Stop here: there is no response object to inspect.
        return

    if res.status_code != 200:
        logger.error(f"==== error ==== ({res.status_code}): {res.text}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.",
            "error_code": 1,
        }
        # Stop here: the body of a failed response is not a usable reply.
        return

    response_json = res.json()
    if "candidates" not in response_json:
        logger.error(f"==== error ==== response blocked: {response_json}")
        reason = response_json["filters"][0]["reason"]
        yield {
            "text": f"**API REQUEST ERROR** Reason: {reason}.",
            "error_code": 1,
        }
        # Stop here: without candidates there is nothing to stream.
        return

    response = response_json["candidates"][0]["content"]
    pos = 0
    while pos < len(response):
        # simulate token streaming
        pos += random.randint(3, 6)
        time.sleep(0.002)
        yield {
            "text": response[:pos],
            "error_code": 0,
        }


def ai2_api_stream_iter(
    model_name,
    model_id,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    """Stream a reply from the AI2 InferD endpoint.

    ``model_name`` is only logged; ``model_id`` is what the request carries.
    The key falls back to the ``AI2_API_KEY`` environment variable and the
    endpoint to ``https://inferd.allen.ai/api/v1/infer``.

    Yields ``{"text": <text so far>, "error_code": 0}`` for each non-empty
    response line. Raises ``ValueError`` for the rejected temperature/top_p
    combination, a non-200 status, or a line without ``result.output``.
    """
    # Resolve credentials and endpoint.
    token = api_key or os.environ.get("AI2_API_KEY")
    endpoint = api_base or "https://inferd.allen.ai/api/v1/infer"

    # Log the request parameters.
    request_summary = dict(
        model=model_name,
        prompt=messages,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
    )
    logger.info(f"==== request ====\n{request_summary}")

    # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
    # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
    if temperature == 0.0:
        if top_p < 1.0:
            raise ValueError("top_p must be 1 when temperature is 0.0")

    sampling_opts = {
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "logprobs": 1,  # increase for more choices
    }
    # This input format is specific to the Tulu2 model. Other models
    # may require different input formats. See the model's schema
    # documentation on InferD for more information.
    body = {
        "model_id": model_id,
        "input": {"messages": messages, "opts": sampling_opts},
    }
    res = requests.post(
        endpoint,
        stream=True,
        headers={"Authorization": f"Bearer {token}"},
        json=body,
        timeout=5,
    )

    if res.status_code != 200:
        logger.error(f"unexpected response ({res.status_code}): {res.text}")
        raise ValueError("unexpected response from InferD", res)

    accumulated = ""
    for raw_line in res.iter_lines():
        if not raw_line:
            continue
        part = json.loads(raw_line)
        if "result" not in part or "output" not in part["result"]:
            logger.error(f"unexpected part: {part}")
            raise ValueError("empty result in InferD response")

        # Append every text piece of this part, in order.
        for piece in part["result"]["output"]["text"]:
            accumulated += piece

        yield {"text": accumulated, "error_code": 0}


def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
    """Stream a chat reply through the ``mistralai`` client.

    The API key is read from the ``MISTRAL_API_KEY`` environment variable.
    Yields ``{"text": <text so far>, "error_code": 0}`` for every chunk whose
    delta content is not ``None``.
    """
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage

    client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"], timeout=5)

    # Log the request parameters.
    request_summary = dict(
        model=model_name,
        prompt=messages,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
    )
    logger.info(f"==== request ====\n{request_summary}")

    # Convert the plain dict messages into the SDK's message objects.
    chat_messages = []
    for entry in messages:
        chat_messages.append(ChatMessage(role=entry["role"], content=entry["content"]))

    stream = client.chat_stream(
        model=model_name,
        temperature=temperature,
        messages=chat_messages,
        max_tokens=max_new_tokens,
        top_p=top_p,
    )

    accumulated = ""
    for chunk in stream:
        if chunk.choices[0].delta.content is None:
            continue
        accumulated += chunk.choices[0].delta.content
        yield {"text": accumulated, "error_code": 0}


def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base):
    """Stream a chat reply from an NVIDIA endpoint via server-sent events.

    The API key is read from the ``NVIDIA_API_KEY`` environment variable and
    the request is posted to ``api_base``. ``model_name`` is not used in the
    request. Yields ``{"text": <text so far>, "error_code": 0}`` per event
    until a line ending in ``[DONE]`` arrives.
    """
    token = os.environ["NVIDIA_API_KEY"]
    request_headers = dict(
        [
            ("Authorization", f"Bearer {token}"),
            ("accept", "text/event-stream"),
            ("content-type", "application/json"),
        ]
    )
    # nvidia api does not accept 0 temperature
    effective_temp = 0.000001 if temp == 0.0 else temp

    payload = dict(
        messages=messages,
        temperature=effective_temp,
        top_p=top_p,
        max_tokens=max_tokens,
        seed=42,
        stream=True,
    )
    logger.info(f"==== request ====\n{payload}")

    response = requests.post(
        api_base, headers=request_headers, json=payload, stream=True, timeout=1
    )
    accumulated = ""
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        event = raw_line.decode("utf-8")
        if event.endswith("[DONE]"):
            break
        # Skip the 6-character prefix before the JSON payload.
        delta_text = json.loads(event[6:])["choices"][0]["delta"]["content"]
        accumulated += delta_text
        yield {"text": accumulated, "error_code": 0}


def yandexgpt_api_stream_iter(
    model_name, messages, temperature, max_tokens, api_base, api_key, folder_id
):
    """Stream a YandexGPT completion from ``api_base``.

    The key falls back to the ``YANDEXGPT_API_KEY`` environment variable.
    Each non-empty response line is parsed as JSON and the text of its first
    alternative is yielded as ``{"text": ..., "error_code": 0}`` (the text is
    taken as-is from each line, not accumulated here). Iteration stops after
    an alternative with a final status.
    """
    final_statuses = (
        "ALTERNATIVE_STATUS_FINAL",
        "ALTERNATIVE_STATUS_TRUNCATED_FINAL",
    )

    api_key = api_key or os.environ["YANDEXGPT_API_KEY"]
    request_headers = {
        "Authorization": f"Api-Key {api_key}",
        "content-type": "application/json",
    }

    completion_options = {
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": True,
    }
    payload = {
        "modelUri": f"gpt://{folder_id}/{model_name}",
        "completionOptions": completion_options,
        "messages": messages,
    }
    logger.info(f"==== request ====\n{payload}")

    # https://llm.api.cloud.yandex.net/foundationModels/v1/completion
    response = requests.post(
        api_base, headers=request_headers, json=payload, stream=True, timeout=60
    )
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        result = json.loads(raw_line.decode("utf-8"))["result"]
        best = result["alternatives"][0]
        yield {"text": best["message"]["text"], "error_code": 0}

        # The status is inspected only after the text has been handed out.
        if best["status"] in final_statuses:
            break


def cohere_api_stream_iter(
    client_name: str,
    model_id: str,
    messages: list,
    # The SDK or API handles None for all parameters following
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    api_key: Optional[str] = None,  # default is env var CO_API_KEY
    api_base: Optional[str] = None,
):
    """Stream a chat reply through the ``cohere`` client.

    All messages but the last become the chat history (roles mapped to
    Cohere's names); the last message's content is sent as the prompt.

    Yields ``{"text": <text so far>, "error_code": 0}`` for each
    ``text-generation`` event, or a single error dict with ``error_code`` 1
    if the stream raises ``cohere.core.ApiError``.
    """
    import cohere

    role_names = {
        "user": "User",
        "assistant": "Chatbot",
        "system": "System",
    }

    client = cohere.Client(
        api_key=api_key,
        base_url=api_base,
        client_name=client_name,
    )

    # Split the conversation into history and the prompt to send.
    chat_history = []
    for earlier in messages[:-1]:
        chat_history.append(
            {"role": role_names[earlier["role"]], "message": earlier["content"]}
        )
    actual_prompt = messages[-1]["content"]

    request_summary = dict(
        model=model_id,
        messages=messages,
        chat_history=chat_history,
        prompt=actual_prompt,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
    )
    logger.info(f"==== request ====\n{request_summary}")

    # Issue the request and relay the streamed text.
    events = client.chat_stream(
        message=actual_prompt,
        chat_history=chat_history,
        model=model_id,
        temperature=temperature,
        max_tokens=max_new_tokens,
        p=top_p,
    )
    try:
        accumulated = ""
        for event in events:
            if event.event_type != "text-generation":
                continue
            accumulated += event.text
            yield {"text": accumulated, "error_code": 0}
    except cohere.core.ApiError as e:
        logger.error(f"==== error from cohere api: {e} ====")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}",
            "error_code": 1,
        }


def vertex_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
    """Stream generated content from a Vertex AI generative model.

    ``vertexai`` is initialised from the ``GCP_PROJECT_ID`` and
    ``GCP_LOCATION`` environment variables (``None`` when unset). Only the
    plain-string entries of ``messages`` are logged; the full list is sent.
    Yields ``{"text": <text so far>, "error_code": 0}`` per streamed chunk.
    """
    import vertexai
    from vertexai import generative_models
    from vertexai.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )

    vertexai.init(
        project=os.environ.get("GCP_PROJECT_ID", None),
        location=os.environ.get("GCP_LOCATION", None),
    )

    # Keep only plain strings for the log entry.
    loggable_messages = [entry for entry in messages if type(entry) is str]

    request_summary = dict(
        model=model_name,
        prompt=loggable_messages,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
    )
    logger.info(f"==== request ====\n{request_summary}")

    # Every listed harm category is configured with the BLOCK_NONE threshold.
    harm_categories = (
        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
        generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
    safety_settings = [
        generative_models.SafetySetting(
            category=category,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        )
        for category in harm_categories
    ]

    config = GenerationConfig(
        top_p=top_p, max_output_tokens=max_new_tokens, temperature=temperature
    )
    generator = GenerativeModel(model_name).generate_content(
        messages,
        stream=True,
        generation_config=config,
        safety_settings=safety_settings,
    )

    accumulated = ""
    for chunk in generator:
        # NOTE(chris): This may be a vertex api error, below is HOTFIX: https://github.com/googleapis/python-aiplatform/issues/3129
        accumulated += chunk.candidates[0].content.parts[0]._raw_part.text
        yield {"text": accumulated, "error_code": 0}


def reka_api_stream_iter(
    model_name: str,
    messages: list,
    # None is forwarded unchanged for the optional sampling parameters.
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    api_key: Optional[str] = None,  # falls back to env var REKA_API_KEY
    api_base: Optional[str] = None,
):
    """Stream a reply from the Reka endpoint at ``api_base``.

    A ``-online`` marker in ``model_name`` is stripped and turns on
    ``use_search_engine`` in the request. Only the ``type`` and ``text``
    fields of each message are logged; the full messages are sent.

    Yields ``{"text": ..., "error_code": 0}`` for every ``data: `` line (the
    text is taken as-is from each event), or a single error dict with
    ``error_code`` 1 when the response status is not 200.
    """
    api_key = api_key or os.environ["REKA_API_KEY"]

    use_search_engine = "-online" in model_name
    if use_search_engine:
        model_name = model_name.replace("-online", "")

    request = dict(
        model_name=model_name,
        conversation_history=messages,
        temperature=temperature,
        request_output_len=max_new_tokens,
        runtime_top_p=top_p,
        stream=True,
        use_search_engine=use_search_engine,
    )

    # Log a copy of the request whose history keeps only type and text.
    loggable_history = [
        {"type": entry["type"], "text": entry["text"]} for entry in messages
    ]
    logged_request = {**request, "conversation_history": loggable_history}

    logger.info(f"==== request ====\n{logged_request}")

    response = requests.post(
        api_base,
        stream=True,
        json=request,
        headers={"X-Api-Key": api_key},
    )

    if response.status_code != 200:
        error_message = response.text
        logger.error(f"==== error from reka api: {error_message} ====")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {error_message}",
            "error_code": 1,
        }
        return

    for raw_line in response.iter_lines():
        decoded = raw_line.decode("utf8")
        if decoded.startswith("data: "):
            # Drop the "data: " prefix and decode the event payload.
            event = json.loads(decoded[6:])
            yield {"text": event["text"], "error_code": 0}