import { LlmIcon } from '@/components/svg-icon';
import { LlmModelType } from '@/constants/knowledge';
import { ResponseGetType } from '@/interfaces/database/base';
import {
  IFactory,
  IMyLlmValue,
  IThirdOAIModelCollection as IThirdAiModelCollection,
  IThirdOAIModelCollection,
} from '@/interfaces/database/llm';
import {
  IAddLlmRequestBody,
  IDeleteLlmRequestBody,
} from '@/interfaces/request/llm';
import userService from '@/services/user-service';
import { sortLLmFactoryListBySpecifiedOrder } from '@/utils/common-util';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { Flex, message } from 'antd';
import { DefaultOptionType } from 'antd/es/select';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';

/**
 * Fetches the LLM collection from `userService.llm_list`, optionally filtered
 * by model type.
 *
 * @param modelType - Forwarded to the service as `model_type`; omit to request
 *   the unfiltered list.
 * @returns The collection from the response's `data.data`; `{}` until the first
 *   response arrives (via `initialData`) or when the response has no data.
 */
export const useFetchLlmList = (
  modelType?: LlmModelType,
): IThirdAiModelCollection => {
  const { data } = useQuery({
    // The key must include every variable the queryFn depends on; otherwise
    // calls with different `modelType` values would share one cache entry.
    // The unfiltered call keeps the bare ['llmList'] key so exact-key lookups
    // and ['llmList'] prefix invalidation keep working.
    queryKey: modelType ? ['llmList', modelType] : ['llmList'],
    initialData: {},
    queryFn: async () => {
      const { data } = await userService.llm_list({ model_type: modelType });

      return data?.data ?? {};
    },
  });

  return data;
};

/**
 * Builds select-option groups (one group per collection key) covering every
 * fetched model, regardless of model type.
 *
 * Each option's value is `${llm_name}@${fid}`. Models whose `available` flag
 * is falsy are included but marked `disabled`.
 */
export const useSelectLlmOptions = () => {
  const llmInfo: IThirdOAIModelCollection = useFetchLlmList();

  const modelOptions = useMemo(() => {
    return Object.entries(llmInfo).map(([key, value]) => {
      return {
        label: key,
        options: value.map((x) => ({
          label: x.llm_name,
          value: `${x.llm_name}@${x.fid}`,
          disabled: !x.available,
        })),
      };
    });
  }, [llmInfo]);

  return modelOptions;
};

/**
 * Picks the icon name for a model. For the 'FastEmbed' factory it uses the
 * segment of `llm_name` before the first '/'; otherwise it uses the factory id.
 */
const getLLMIconName = (fid: string, llm_name: string) => {
  if (fid === 'FastEmbed') {
    return llm_name.split('/').at(0) ?? '';
  }

  return fid;
};

/**
 * Builds grouped select options for each supported model type.
 *
 * For a given type, only models whose `model_type` includes that type and
 * whose `available` flag is truthy are kept. Groups left with no options are
 * dropped.
 *
 * @returns An object keyed by Chat, Embedding, Image2text, Speech2text, Rerank
 *   and TTS, each holding that type's option groups.
 */
export const useSelectLlmOptionsByModelType = () => {
  const llmInfo: IThirdOAIModelCollection = useFetchLlmList();

  const groupOptionsByModelType = (modelType: LlmModelType) => {
    return Object.entries(llmInfo)
      .filter(([, value]) =>
        modelType ? value.some((x) => x.model_type.includes(modelType)) : true,
      )
      .map(([key, value]) => {
        return {
          label: key,
          options: value
            .filter(
              (x) =>
                (modelType ? x.model_type.includes(modelType) : true) &&
                x.available,
            )
            .map((x) => ({
              // NOTE(review): as shown, this `label` expression is not valid
              // TS. It looks like JSX markup was stripped here (Flex, LlmIcon
              // and getLLMIconName are otherwise unused in this file) —
              // confirm against the original source.
              label: (
                {x.llm_name}
              ),
              value: `${x.llm_name}@${x.fid}`,
              // Always false here: unavailable models were filtered out above.
              disabled: !x.available,
            })),
        };
      })
      .filter((x) => x.options.length > 0);
  };

  return {
    [LlmModelType.Chat]: groupOptionsByModelType(LlmModelType.Chat),
    [LlmModelType.Embedding]: groupOptionsByModelType(LlmModelType.Embedding),
    [LlmModelType.Image2text]: groupOptionsByModelType(LlmModelType.Image2text),
    [LlmModelType.Speech2text]: groupOptionsByModelType(
      LlmModelType.Speech2text,
    ),
    [LlmModelType.Rerank]: groupOptionsByModelType(LlmModelType.Rerank),
    [LlmModelType.TTS]: groupOptionsByModelType(LlmModelType.TTS),
  };
};

/**
 * Merges the option groups of several model types into one list.
 *
 * Groups with the same label are combined by concatenating their options, in
 * the order given by `modelTypes`.
 *
 * @param modelTypes - Model types whose option groups should be merged.
 */
export const useComposeLlmOptionsByModelTypes = (
  modelTypes: LlmModelType[],
) => {
  const allOptions = useSelectLlmOptionsByModelType();

  // NOTE(review): the reduce accumulator has no type argument here (the
  // imported DefaultOptionType is otherwise unused) — confirm against the
  // original source.
  return modelTypes.reduce((pre, cur) => {
    const options = allOptions[cur];
    options.forEach((x) => {
      const item = pre.find((y) => y.label === x.label);
      if (item) {
        item.options.push(...x.options);
      } else {
        // Push a copy so later merges never mutate the groups owned by
        // `allOptions` (e.g. when the same model type is requested twice).
        pre.push({ ...x, options: [...x.options] });
      }
    });

    return pre;
  }, []);
};

/**
 * Fetches the list of LLM factories via `userService.factories_list`.
 *
 * Uses `gcTime: 0`, so the cached entry is removed as soon as it has no
 * observers.
 *
 * @returns `{ data, loading }`, where `data` starts as `[]` and `loading`
 *   mirrors the query's `isFetching`.
 */
// NOTE(review): `ResponseGetType` appears without a type argument here (the
// imported IFactory is otherwise unused) — confirm against the original source.
export const useFetchLlmFactoryList = (): ResponseGetType => {
  const { data, isFetching: loading } = useQuery({
    queryKey: ['factoryList'],
    initialData: [],
    gcTime: 0,
    queryFn: async () => {
      const { data } = await userService.factories_list();

      return data?.data ?? [];
    },
  });

  return { data, loading };
};

/** A configured LLM factory entry together with its name and logo. */
export type LlmItem = { name: string; logo: string } & IMyLlmValue;

/**
 * Fetches the current user's configured LLMs via `userService.my_llm`.
 *
 * Uses `gcTime: 0`, so the cached entry is removed as soon as it has no
 * observers.
 *
 * @returns `{ data, loading }`, where `data` starts as `{}` and `loading`
 *   mirrors the query's `isFetching`.
 */
export const useFetchMyLlmList = (): ResponseGetType<
  Record
> => {
  const { data, isFetching: loading } = useQuery({
    queryKey: ['myLlmList'],
    initialData: {},
    gcTime: 0,
    queryFn: async () => {
      const { data } = await userService.my_llm();

      return data?.data ?? {};
    },
  });

  return { data, loading };
};

/**
 * Combines the user's configured LLMs with the factory list.
 *
 * @returns
 *   - `myLlmList`: one entry per configured factory, carrying its name, the
 *     matching factory logo (or '' if none) and its stored values.
 *   - `factoryList`: factories the user has not configured yet, ordered by
 *     `sortLLmFactoryListBySpecifiedOrder`.
 *   - `loading`: true while either query is fetching.
 */
export const useSelectLlmList = () => {
  const { data: myLlmList, loading: myLlmListLoading } = useFetchMyLlmList();
  const { data: factoryList, loading: factoryListLoading } =
    useFetchLlmFactoryList();

  const nextMyLlmList: Array = useMemo(() => {
    return Object.entries(myLlmList).map(([key, value]) => ({
      name: key,
      logo: factoryList.find((x) => x.name === key)?.logo ?? '',
      // Spread last: fields in `value` override `name`/`logo` if present.
      ...value,
    }));
  }, [myLlmList, factoryList]);

  const nextFactoryList = useMemo(() => {
    const currentList = factoryList.filter((x) =>
      Object.keys(myLlmList).every((y) => y !== x.name),
    );
    return sortLLmFactoryListBySpecifiedOrder(currentList);
  }, [factoryList, myLlmList]);

  return {
    myLlmList: nextMyLlmList,
    factoryList: nextFactoryList,
    loading: myLlmListLoading || factoryListLoading,
  };
};

/** Request payload for saving a factory API key. */
export interface IApiKeySavingParams {
  llm_factory: string;
  api_key: string;
  llm_name?: string;
  model_type?: string;
  base_url?: string;
}

/**
 * Mutation hook that saves an API key via `userService.set_api_key`.
 *
 * When the response code is 0, it shows a success toast and invalidates the
 * 'myLlmList' and 'factoryList' queries.
 *
 * @returns `{ data, loading, saveApiKey }`; `saveApiKey` resolves to the
 *   response code.
 */
// NOTE(review): the 'llmList' query used by useFetchLlmList is not invalidated
// here; confirm whether model availability should refresh after saving a key.
export const useSaveApiKey = () => {
  const queryClient = useQueryClient();
  const { t } = useTranslation();
  const {
    data,
    isPending: loading,
    mutateAsync,
  } = useMutation({
    mutationKey: ['saveApiKey'],
    mutationFn: async (params: IApiKeySavingParams) => {
      const { data } = await userService.set_api_key(params);
      if (data.code === 0) {
        message.success(t('message.modified'));
        queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
        queryClient.invalidateQueries({ queryKey: ['factoryList'] });
      }
      return data.code;
    },
  });

  return { data, loading, saveApiKey: mutateAsync };
};

/** Request payload for saving the tenant's default model settings. */
export interface ISystemModelSettingSavingParams {
  tenant_id: string;
  name?: string;
  asr_id: string;
  embd_id: string;
  img2txt_id: string;
  llm_id: string;
}

/**
 * Mutation hook that saves tenant model settings via
 * `userService.set_tenant_info`.
 *
 * When the response code is 0, it shows a success toast; no queries are
 * invalidated here.
 *
 * @returns `{ data, loading, saveTenantInfo }`; `saveTenantInfo` resolves to
 *   the response code.
 */
export const useSaveTenantInfo = () => {
  const { t } = useTranslation();
  const {
    data,
    isPending: loading,
    mutateAsync,
  } = useMutation({
    mutationKey: ['saveTenantInfo'],
    mutationFn: async (params: ISystemModelSettingSavingParams) => {
      const { data } = await userService.set_tenant_info(params);
      if (data.code === 0) {
        message.success(t('message.modified'));
      }
      return data.code;
    },
  });

  return { data, loading, saveTenantInfo: mutateAsync };
};

/**
 * Mutation hook that adds an LLM via `userService.add_llm`.
 *
 * When the response code is 0, it invalidates the 'myLlmList' and
 * 'factoryList' queries and shows a success toast.
 *
 * @returns `{ data, loading, addLlm }`; `addLlm` resolves to the response code.
 */
export const useAddLlm = () => {
  const queryClient = useQueryClient();
  const { t } = useTranslation();
  const {
    data,
    isPending: loading,
    mutateAsync,
  } = useMutation({
    mutationKey: ['addLlm'],
    mutationFn: async (params: IAddLlmRequestBody) => {
      const { data } = await userService.add_llm(params);
      if (data.code === 0) {
        queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
        queryClient.invalidateQueries({ queryKey: ['factoryList'] });
        message.success(t('message.modified'));
      }
      return data.code;
    },
  });

  return { data, loading, addLlm: mutateAsync };
};

/**
 * Mutation hook that deletes a single LLM via `userService.delete_llm`.
 *
 * When the response code is 0, it invalidates the 'myLlmList' and
 * 'factoryList' queries and shows a deletion toast.
 *
 * @returns `{ data, loading, deleteLlm }`; `deleteLlm` resolves to the
 *   response code.
 */
export const useDeleteLlm = () => {
  const queryClient = useQueryClient();
  const { t } = useTranslation();
  const {
    data,
    isPending: loading,
    mutateAsync,
  } = useMutation({
    mutationKey: ['deleteLlm'],
    mutationFn: async (params: IDeleteLlmRequestBody) => {
      const { data } = await userService.delete_llm(params);
      if (data.code === 0) {
        queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
        queryClient.invalidateQueries({ queryKey: ['factoryList'] });
        message.success(t('message.deleted'));
      }
      return data.code;
    },
  });

  return { data, loading, deleteLlm: mutateAsync };
};

/**
 * Mutation hook that deletes a factory via `userService.deleteFactory`.
 *
 * When the response code is 0, it invalidates the 'myLlmList' and
 * 'factoryList' queries and shows a deletion toast.
 *
 * @returns `{ data, loading, deleteFactory }`; `deleteFactory` resolves to the
 *   response code.
 */
export const useDeleteFactory = () => {
  const queryClient = useQueryClient();
  const { t } = useTranslation();
  const {
    data,
    isPending: loading,
    mutateAsync,
  } = useMutation({
    mutationKey: ['deleteFactory'],
    mutationFn: async (params: IDeleteLlmRequestBody) => {
      const { data } = await userService.deleteFactory(params);
      if (data.code === 0) {
        queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
        queryClient.invalidateQueries({ queryKey: ['factoryList'] });
        message.success(t('message.deleted'));
      }
      return data.code;
    },
  });

  return { data, loading, deleteFactory: mutateAsync };
};