Spaces:
Runtime error
Runtime error
import { type ActionFunctionArgs } from '@remix-run/cloudflare'; | |
import { createDataStream } from 'ai'; | |
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants'; | |
import { CONTINUE_PROMPT } from '~/lib/common/prompts/prompts'; | |
import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text'; | |
import SwitchableStream from '~/lib/.server/llm/switchable-stream'; | |
import type { IProviderSetting } from '~/types/model'; | |
import { createScopedLogger } from '~/utils/logger'; | |
/**
 * Remix action for the chat route. All request handling lives in `chatAction`;
 * this export only forwards the action arguments to it and returns its result.
 */
export async function action(args: ActionFunctionArgs) {
  const response = await chatAction(args);

  return response;
}

/** Logger created with the scope name 'api.chat', shared by the handlers in this module. */
const logger = createScopedLogger('api.chat');
/**
 * Parses a `Cookie` request header into a name → value map.
 *
 * Each `;`-separated pair is split on its first `=`; everything after it is the value, so
 * values may themselves contain `=`. A pair without `=` maps to `''`. Names and values are
 * URI-decoded. The header is client-controlled, so a name or value that is not valid
 * percent-encoding keeps its raw text instead of throwing `URIError` and failing the request.
 * Entries with an empty name are skipped; if a name repeats, the last occurrence wins.
 *
 * @param cookieHeader - Raw header value; pass `''` when the header is absent.
 * @returns Decoded cookies keyed by name.
 */
function parseCookies(cookieHeader: string): Record<string, string> {
  const cookies: Record<string, string> = {};

  for (const item of cookieHeader.split(';')) {
    const [rawName, ...rest] = item.trim().split('=');
    const name = rawName?.trim();

    if (!name) {
      continue;
    }

    cookies[safeDecodeURIComponent(name)] = safeDecodeURIComponent(rest.join('=').trim());
  }

  return cookies;
}

/** `decodeURIComponent` that returns its input unchanged instead of throwing `URIError`. */
function safeDecodeURIComponent(value: string): string {
  try {
    return decodeURIComponent(value);
  } catch {
    return value;
  }
}
/**
 * Parses a JSON-encoded cookie value into a plain object.
 *
 * Returns `{}` when the cookie is absent, is not valid JSON, or does not decode to a
 * non-array object. These cookies are client-controlled and are read before the handler's
 * `try` block, so a malformed value must not be allowed to throw.
 *
 * @param raw - Decoded cookie value, or `undefined` when the cookie is absent.
 * @param cookieName - Cookie name, used only in the log message.
 */
function parseJsonCookie(raw: string | undefined, cookieName: string): Record<string, unknown> {
  if (!raw) {
    return {};
  }

  try {
    const parsed: unknown = JSON.parse(raw);

    if (typeof parsed === 'object' && parsed !== null && !Array.isArray(parsed)) {
      return parsed as Record<string, unknown>;
    }
  } catch {
    // Invalid JSON: fall through to the empty default below.
  }

  logger.error(`Ignoring malformed '${cookieName}' cookie`);

  return {};
}

/**
 * Handles a chat request by streaming the model's reply back to the client.
 *
 * The JSON body supplies the conversation (`messages`), `files`, and prompt options. API keys
 * and provider settings come from the `apiKeys` / `providers` cookies. The reply is written to
 * a `SwitchableStream`:
 * - When a segment ends with `finishReason === 'length'`, the partial answer plus
 *   `CONTINUE_PROMPT` is sent back to the model. The new result is switched into the same
 *   stream, up to `MAX_RESPONSE_SEGMENTS` switches.
 * - When a segment ends for any other reason, a `usage` message annotation is written and the
 *   stream is closed. The annotation carries the cumulative token counts and the cache hit/miss
 *   flags.
 *
 * @returns A 200 streaming `Response` whose body is the switchable stream.
 * @throws `Response` 401 when the error message mentions "API key", otherwise `Response` 500,
 *   for errors raised while setting up the first model call.
 */
async function chatAction({ context, request }: ActionFunctionArgs) {
  const { messages, files, promptId, contextOptimization, isPromptCachingEnabled } = await request.json<{
    messages: Messages;
    files: any;
    promptId?: string;
    contextOptimization: boolean;
    isPromptCachingEnabled: boolean;
  }>();

  // Parse the header once, and tolerate malformed JSON in either cookie.
  const cookies = parseCookies(request.headers.get('Cookie') ?? '');
  const apiKeys = parseJsonCookie(cookies.apiKeys, 'apiKeys') as Record<string, string>;
  const providerSettings = parseJsonCookie(cookies.providers, 'providers') as Record<string, IProviderSetting>;

  const stream = new SwitchableStream();

  // Token usage summed across every continuation segment of this response.
  const cumulativeUsage = {
    completionTokens: 0,
    promptTokens: 0,
    totalTokens: 0,
  };

  try {
    const options: StreamingOptions = {
      toolChoice: 'none',
      // eslint-disable-next-line @typescript-eslint/naming-convention
      onFinish: async ({ text: content, finishReason, usage, experimental_providerMetadata }) => {
        logger.debug('usage', JSON.stringify(usage));

        const cacheUsage = experimental_providerMetadata?.anthropic;
        logger.debug('cacheUsage', JSON.stringify(cacheUsage));

        const isCacheHit = !!cacheUsage?.cacheReadInputTokens;
        const isCacheMiss = !!cacheUsage?.cacheCreationInputTokens && !isCacheHit;

        if (usage) {
          cumulativeUsage.completionTokens += Math.round(usage.completionTokens || 0);

          // Cache-creation tokens are weighted 1.25x and cache-read tokens 0.1x.
          // NOTE(review): presumably mirrors the provider's relative pricing — confirm.
          cumulativeUsage.promptTokens += Math.round(
            (usage.promptTokens || 0) +
              ((cacheUsage?.cacheCreationInputTokens as number) || 0) * 1.25 +
              ((cacheUsage?.cacheReadInputTokens as number) || 0) * 0.1,
          );
          cumulativeUsage.totalTokens = cumulativeUsage.completionTokens + cumulativeUsage.promptTokens;
        }

        if (finishReason !== 'length') {
          // The segment did not stop on the length limit: emit the usage annotation and end the response.
          const encoder = new TextEncoder();
          const usageStream = createDataStream({
            async execute(dataStream) {
              dataStream.writeMessageAnnotation({
                type: 'usage',
                value: {
                  completionTokens: cumulativeUsage.completionTokens,
                  promptTokens: cumulativeUsage.promptTokens,
                  totalTokens: cumulativeUsage.totalTokens,
                  isCacheHit,
                  isCacheMiss,
                },
              });
            },
            onError: (error: unknown) => `Custom error: ${error instanceof Error ? error.message : String(error)}`,
          }).pipeThrough(
            new TransformStream({
              transform: (chunk, controller) => {
                // Convert the string stream to a byte stream.
                const str = typeof chunk === 'string' ? chunk : JSON.stringify(chunk);
                controller.enqueue(encoder.encode(str));
              },
            }),
          );

          await stream.switchSource(usageStream);

          // Yield one macrotask before closing.
          // NOTE(review): presumably lets the annotation be pumped out first — confirm.
          await new Promise((resolve) => setTimeout(resolve, 0));
          stream.close();

          return;
        }

        if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
          // NOTE(review): this error is thrown from inside onFinish. By then chatAction has most
          // likely already returned its Response, so the try/catch below does not handle it.
          // Confirm how the AI SDK surfaces it to the client.
          throw new Error('Cannot continue message: Maximum segments reached');
        }

        const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;
        logger.info(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`);

        // Feed the truncated answer back and ask the model to continue from where it stopped.
        messages.push({ role: 'assistant', content });
        messages.push({ role: 'user', content: CONTINUE_PROMPT });

        const result = await streamText({
          messages,
          env: context.cloudflare.env,
          options,
          apiKeys,
          files,
          providerSettings,
          promptId,
          contextOptimization,
          isPromptCachingEnabled,
        });

        // NOTE(review): not awaited, unlike the usage switch above. Confirm whether awaiting is
        // safe from inside onFinish before changing it.
        stream.switchSource(result.toDataStream());

        return;
      },
    };

    const totalMessageContent = messages.reduce((acc, message) => acc + message.content, '');
    logger.debug(`Total message length: ${totalMessageContent.split(' ').length} words`);

    const result = await streamText({
      messages,
      env: context.cloudflare.env,
      options,
      apiKeys,
      files,
      providerSettings,
      promptId,
      contextOptimization,
      isPromptCachingEnabled,
    });

    // Consume fullStream in the background only to log the first model error. The client is
    // served from toDataStream() below; the .catch keeps a failed read from going unhandled.
    void (async () => {
      for await (const part of result.fullStream) {
        if (part.type === 'error') {
          logger.error(`${part.error}`);
          return;
        }
      }
    })().catch((error: unknown) => {
      logger.error(`Failed to read model stream: ${error}`);
    });

    stream.switchSource(result.toDataStream());

    return new Response(stream.readable, {
      status: 200,
      headers: {
        'Content-Type': 'text/event-stream; charset=utf-8',
        Connection: 'keep-alive',
        'Cache-Control': 'no-cache',
        // NOTE(review): `Text-Encoding` is not a standard HTTP header; transfer encoding is
        // chosen by the runtime. Left unchanged in case a client inspects it.
        'Text-Encoding': 'chunked',
      },
    });
  } catch (error: unknown) {
    logger.error(error);

    // Anything can be thrown (including null), so read `message` defensively.
    const message = (error as { message?: unknown } | null | undefined)?.message;

    if (typeof message === 'string' && message.includes('API key')) {
      throw new Response('Invalid or missing API key', {
        status: 401,
        statusText: 'Unauthorized',
      });
    }

    throw new Response(null, {
      status: 500,
      statusText: 'Internal Server Error',
    });
  }
}