/** BUSINESS
 *
 * All utils that are bound to business logic
 * (and wouldn't be useful in another project)
 * should be here.
 *
 **/

import ctxLengthData from "$lib/data/context_length.json";
import { InferenceClient, snippets } from "@huggingface/inference";
import { ConversationClass, type ConversationEntityMembers } from "$lib/state/conversations.svelte";
import { token } from "$lib/state/token.svelte";
import {
	isCustomModel,
	isHFModel,
	Provider,
	type Conversation,
	type ConversationMessage,
	type CustomModel,
	type Model,
} from "$lib/types.js";
import { safeParse } from "$lib/utils/json.js";
import { omit, tryGet } from "$lib/utils/object.svelte.js";
import type { ChatCompletionInputMessage, InferenceSnippet } from "@huggingface/tasks";
import { type ChatCompletionOutputMessage } from "@huggingface/tasks";
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import OpenAI from "openai";
import { images } from "$lib/state/images.svelte.js";
import { projects } from "$lib/state/projects.svelte.js";
import { structuredForbiddenProviders } from "$lib/state/models.svelte.js";
import { modifySnippet } from "$lib/utils/snippets.js";

type ChatCompletionInputMessageChunk =
	NonNullable<ChatCompletionInputMessage["content"]> extends string | (infer U)[] ? U : never;

async function parseMessage(message: ConversationMessage): Promise<ChatCompletionInputMessage> {
	if (!message.images) return message;

	const urls = await Promise.all(message.images?.map(k => images.get(k)) ?? []);

	return {
		...omit(message, "images"),
		content: [
			{
				type: "text",
				text: message.content ?? "",
			},
			...message.images.map((_imgKey, i) => {
				return {
					type: "image_url",
					image_url: { url: urls[i] as string },
				} satisfies ChatCompletionInputMessageChunk;
			}),
		],
	};
}

type HFCompletionMetadata = {
	type: "huggingface";
	client: InferenceClient;
	args: Parameters<InferenceClient["chatCompletion"]>[0];
};

type OpenAICompletionMetadata = {
	type: "openai";
	client: OpenAI;
	args: OpenAI.ChatCompletionCreateParams;
};

type CompletionMetadata = HFCompletionMetadata | OpenAICompletionMetadata;

export function maxAllowedTokens(conversation: ConversationClass) {
	const ctxLength = (() => {
		const model = conversation.model;
		const { provider } = conversation.data;

		if (!provider || !isHFModel(model)) return;

		const idOnProvider = model.inferenceProviderMapping.find(data => data.provider === provider)?.providerId;
		if (!idOnProvider) return;

		const models = tryGet(ctxLengthData, provider);
		if (!models) return;

		return tryGet(models, idOnProvider) as number | undefined;
	})();

	if (!ctxLength) return customMaxTokens[conversation.model.id] ?? 100000;
	return ctxLength;
}

function getResponseFormatObj(conversation: ConversationClass | Conversation) {
	const data = conversation instanceof ConversationClass ? conversation.data : conversation;
	const json = safeParse(data.structuredOutput?.schema ?? "");
	// eslint-disable-next-line @typescript-eslint/no-explicit-any
	if (json && data.structuredOutput?.enabled && !structuredForbiddenProviders.includes(data.provider as any)) {
		switch (data.provider) {
			case "cohere": {
				return {
					type: "json_object",
					...json,
				};
			}
			case Provider.Cerebras: {
				return {
					type: "json_schema",
					json_schema: { ...json, name: "schema" },
				};
			}
			default: {
				return {
					type: "json_schema",
					json_schema: json,
				};
			}
		}
	}
}

async function getCompletionMetadata(
	conversation: ConversationClass | Conversation,
	signal?: AbortSignal,
): Promise<CompletionMetadata> {
	const data = conversation instanceof ConversationClass ? conversation.data : conversation;
	const model = conversation.model;
	const systemMessage = projects.current?.systemMessage;

	const messages: ConversationMessage[] = [
		...(isSystemPromptSupported(model) && systemMessage?.length ? [{ role: "system", content: systemMessage }] : []),
		...(data.messages || []),
	];
	const parsed = await Promise.all(messages.map(parseMessage));

	const baseArgs = {
		...data.config,
		messages: parsed,
		model: model.id,
		response_format: getResponseFormatObj(conversation),
		// eslint-disable-next-line @typescript-eslint/no-explicit-any
	} as any;

	// Handle OpenAI-compatible models
	if (isCustomModel(model)) {
		const openai = new OpenAI({
			apiKey: model.accessToken,
			baseURL: model.endpointUrl,
			dangerouslyAllowBrowser: true,
			fetch: (...args: Parameters<typeof fetch>) => {
				return fetch(args[0], { ...args[1], signal });
			},
		});

		const args = {
			...baseArgs,
			// eslint-disable-next-line @typescript-eslint/no-explicit-any
		} as any;

		return {
			type: "openai",
			client: openai,
			args,
		};
	}
	const args = {
		...baseArgs,
		provider: data.provider,
		// max_tokens: maxAllowedTokens(conversation) - currTokens,
		// eslint-disable-next-line @typescript-eslint/no-explicit-any
	} as any;

	// Handle HuggingFace models
	return {
		type: "huggingface",
		client: new InferenceClient(token.value),
		args,
	};
}

export async function handleStreamingResponse(
	conversation: ConversationClass | Conversation,
	onChunk: (content: string) => void,
	abortController: AbortController,
): Promise<void> {
	const metadata = await getCompletionMetadata(conversation, abortController.signal);

	if (metadata.type === "openai") {
		const stream = await metadata.client.chat.completions.create({
			...metadata.args,
			stream: true,
		} as OpenAI.ChatCompletionCreateParamsStreaming);

		let out = "";
		for await (const chunk of stream) {
			if (chunk.choices[0]?.delta?.content) {
				out += chunk.choices[0].delta.content;
				onChunk(out);
			}
		}
		return;
	}

	// HuggingFace streaming
	let out = "";
	for await (const chunk of metadata.client.chatCompletionStream(metadata.args, { signal: abortController.signal })) {
		if (chunk.choices && chunk.choices.length > 0 && chunk.choices[0]?.delta?.content) {
			out += chunk.choices[0].delta.content;
			onChunk(out);
		}
	}
}

export async function handleNonStreamingResponse(
	conversation: ConversationClass | Conversation,
): Promise<{ message: ChatCompletionOutputMessage; completion_tokens: number }> {
	const metadata = await getCompletionMetadata(conversation);

	if (metadata.type === "openai") {
		const response = await metadata.client.chat.completions.create({
			...metadata.args,
			stream: false,
		} as OpenAI.ChatCompletionCreateParamsNonStreaming);

		if (response.choices && response.choices.length > 0 && response.choices[0]?.message) {
			return {
				message: {
					role: "assistant",
					content: response.choices[0].message.content || "",
				},
				completion_tokens: response.usage?.completion_tokens || 0,
			};
		}
		throw new Error("No response from the model");
	}

	// HuggingFace non-streaming
	const response = await metadata.client.chatCompletion(metadata.args);
	if (response.choices && response.choices.length > 0) {
		const { message } = response.choices[0]!;
		const { completion_tokens } = response.usage;
		return { message, completion_tokens };
	}
	throw new Error("No response from the model");
}

export function isSystemPromptSupported(model: Model | CustomModel) {
	if (isCustomModel(model)) return true; // OpenAI-compatible models support system messages
	const template = model?.config.tokenizer_config?.chat_template;
	if (typeof template !== "string") return false;
	return template.includes("system");
}

export const defaultSystemMessage: { [key: string]: string } = {
	"Qwen/QwQ-32B-Preview":
		"You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
} as const;

export const customMaxTokens: { [key: string]: number } = {
	"01-ai/Yi-1.5-34B-Chat": 2048,
	"HuggingFaceM4/idefics-9b-instruct": 2048,
	"deepseek-ai/DeepSeek-Coder-V2-Instruct": 16384,
	"bigcode/starcoder": 8192,
	"bigcode/starcoderplus": 8192,
	"HuggingFaceH4/starcoderbase-finetuned-oasst1": 8192,
	"google/gemma-7b": 8192,
	"google/gemma-1.1-7b-it": 8192,
	"google/gemma-2b": 8192,
	"google/gemma-1.1-2b-it": 8192,
	"google/gemma-2-27b-it": 8192,
	"google/gemma-2-9b-it": 4096,
	"google/gemma-2-2b-it": 8192,
	"tiiuae/falcon-7b": 8192,
	"tiiuae/falcon-7b-instruct": 8192,
	"timdettmers/guanaco-33b-merged": 2048,
	"mistralai/Mixtral-8x7B-Instruct-v0.1": 32768,
	"Qwen/Qwen2.5-72B-Instruct": 32768,
	"Qwen/Qwen2.5-Coder-32B-Instruct": 32768,
	"meta-llama/Meta-Llama-3-70B-Instruct": 8192,
	"CohereForAI/c4ai-command-r-plus-08-2024": 32768,
	"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768,
	"meta-llama/Llama-2-70b-chat-hf": 8192,
	"HuggingFaceH4/zephyr-7b-alpha": 17432,
	"HuggingFaceH4/zephyr-7b-beta": 32768,
	"mistralai/Mistral-7B-Instruct-v0.1": 32768,
	"mistralai/Mistral-7B-Instruct-v0.2": 32768,
	"mistralai/Mistral-7B-Instruct-v0.3": 32768,
	"mistralai/Mistral-Nemo-Instruct-2407": 32768,
	"meta-llama/Meta-Llama-3-8B-Instruct": 8192,
	"mistralai/Mistral-7B-v0.1": 32768,
	"bigcode/starcoder2-3b": 16384,
	"bigcode/starcoder2-15b": 16384,
	"HuggingFaceH4/starchat2-15b-v0.1": 16384,
	"codellama/CodeLlama-7b-hf": 8192,
	"codellama/CodeLlama-13b-hf": 8192,
	"codellama/CodeLlama-34b-Instruct-hf": 8192,
	"meta-llama/Llama-2-7b-chat-hf": 8192,
	"meta-llama/Llama-2-13b-chat-hf": 8192,
	"OpenAssistant/oasst-sft-6-llama-30b": 2048,
	"TheBloke/vicuna-7B-v1.5-GPTQ": 2048,
	"HuggingFaceH4/starchat-beta": 8192,
	"bigcode/octocoder": 8192,
	"vwxyzjn/starcoderbase-triviaqa": 8192,
	"lvwerra/starcoderbase-gsm8k": 8192,
	"NousResearch/Hermes-3-Llama-3.1-8B": 16384,
	"microsoft/Phi-3.5-mini-instruct": 32768,
	"meta-llama/Llama-3.1-70B-Instruct": 32768,
	"meta-llama/Llama-3.1-8B-Instruct": 8192,
} as const;

// Order of the elements in InferenceModal.svelte is determined by this const
export const inferenceSnippetLanguages = ["python", "js", "sh"] as const;

export type InferenceSnippetLanguage = (typeof inferenceSnippetLanguages)[number];

export type GetInferenceSnippetReturn = InferenceSnippet[];

export function getInferenceSnippet(
	conversation: ConversationClass,
	language: InferenceSnippetLanguage,
	opts?: {
		accessToken?: string;
		messages?: ConversationEntityMembers["messages"];
		streaming?: ConversationEntityMembers["streaming"];
		max_tokens?: ConversationEntityMembers["config"]["max_tokens"];
		temperature?: ConversationEntityMembers["config"]["temperature"];
		top_p?: ConversationEntityMembers["config"]["top_p"];
		structured_output?: ConversationEntityMembers["structuredOutput"];
	},
): GetInferenceSnippetReturn {
	const model = conversation.model;
	const data = conversation.data;
	const provider = (isCustomModel(model) ? "hf-inference" : data.provider) as Provider;

	// If it's a custom model, we don't generate inference snippets
	if (isCustomModel(model)) {
		return [];
	}

	const providerMapping = model.inferenceProviderMapping.find(p => p.provider === provider);
	if (!providerMapping && provider !== "auto") return [];
	const allSnippets = snippets.getInferenceSnippets(
		{ ...model, inference: "" },
		provider,
		// eslint-disable-next-line @typescript-eslint/no-explicit-any
		{ ...providerMapping, hfModelId: model.id } as any,
		{ ...opts, directRequest: false },
	);
	console.log(allSnippets);

	return allSnippets
		.filter(s => s.language === language)
		.map(s => {
			if (
				opts?.structured_output?.schema &&
				opts.structured_output.enabled &&
				!structuredForbiddenProviders.includes(provider as Provider)
			) {
				return {
					...s,
					content: modifySnippet(s.content, {
						response_format: getResponseFormatObj(conversation),
					}),
				};
			}
			return s;
		});
}

const tokenizers = new Map<string, PreTrainedTokenizer | null>();

export async function getTokenizer(model: Model) {
	if (tokenizers.has(model.id)) return tokenizers.get(model.id)!;
	try {
		const tokenizer = await AutoTokenizer.from_pretrained(model.id);
		tokenizers.set(model.id, tokenizer);
		return tokenizer;
	} catch {
		tokenizers.set(model.id, null);
		return null;
	}
}

// When you don't have access to a tokenizer, guesstimate
export function estimateTokens(conversation: Conversation) {
	const content = conversation.messages.reduce((acc, curr) => {
		return acc + (curr?.content ?? "");
	}, "");

	return content.length / 4; // 1 token ~ 4 characters
}

export async function getTokens(conversation: Conversation): Promise<number> {
	const model = conversation.model;
	if (isCustomModel(model)) return estimateTokens(conversation);
	const tokenizer = await getTokenizer(model);
	if (tokenizer === null) return estimateTokens(conversation);

	// This is a simplified version - you might need to adjust based on your exact needs
	let formattedText = "";

	conversation.messages.forEach((message, index) => {
		let content = `<|start_header_id|>${message.role}<|end_header_id|>\n\n${message.content?.trim()}<|eot_id|>`;

		// Add BOS token to the first message
		if (index === 0) {
			content = "<|begin_of_text|>" + content;
		}

		formattedText += content;
	});

	// Encode the text to get tokens
	const encodedInput = tokenizer.encode(formattedText);

	// Return the number of tokens
	return encodedInput.length;
}