Add support for passing an API key or any other custom token in the authorization header (#579)
* Add support for passing an API key or any other custom token in the authorization header
* Make linter happy
* Fix README as per linter suggestions
* Refactor endpoints to actually parse zod config
* Remove top level env var and simplify header addition
* Skip section on API key or other, remove obsolete comment in endpointTgi.ts and remove CUSTOM_AUTHORIZATION_TOKEN from .env
---------
Co-authored-by: Nathan Sarrazin <[email protected]>
README.md
CHANGED
|
@@ -397,6 +397,8 @@ You can then add the generated information and the `authorization` parameter to
|
|
| 397 |
]
|
| 398 |
```
|
| 399 |
|
|
|
|
|
|
|
| 400 |
#### Models hosted on multiple custom endpoints
|
| 401 |
|
| 402 |
If the model being hosted will be available on multiple servers/instances add the `weight` parameter to your `.env.local`. The `weight` will be used to determine the probability of requesting a particular endpoint.
|
|
|
|
| 397 |
]
|
| 398 |
```
|
| 399 |
|
| 400 |
+
Please note that if `HF_ACCESS_TOKEN` is set to a non-empty value, it will take precedence over the `authorization` parameter.
|
| 401 |
+
|
| 402 |
#### Models hosted on multiple custom endpoints
|
| 403 |
|
| 404 |
If the model being hosted will be available on multiple servers/instances add the `weight` parameter to your `.env.local`. The `weight` will be used to determine the probability of requesting a particular endpoint.
|
src/lib/server/endpoints/aws/endpointAws.ts
CHANGED
|
@@ -15,15 +15,9 @@ export const endpointAwsParametersSchema = z.object({
|
|
| 15 |
region: z.string().optional(),
|
| 16 |
});
|
| 17 |
|
| 18 |
-
export async function endpointAws(
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
secretKey,
|
| 22 |
-
sessionToken,
|
| 23 |
-
model,
|
| 24 |
-
region,
|
| 25 |
-
service,
|
| 26 |
-
}: z.infer<typeof endpointAwsParametersSchema>): Promise<Endpoint> {
|
| 27 |
let AwsClient;
|
| 28 |
try {
|
| 29 |
AwsClient = (await import("aws4fetch")).AwsClient;
|
|
@@ -31,6 +25,9 @@ export async function endpointAws({
|
|
| 31 |
throw new Error("Failed to import aws4fetch");
|
| 32 |
}
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
const aws = new AwsClient({
|
| 35 |
accessKeyId: accessKey,
|
| 36 |
secretAccessKey: secretKey,
|
|
|
|
| 15 |
region: z.string().optional(),
|
| 16 |
});
|
| 17 |
|
| 18 |
+
export async function endpointAws(
|
| 19 |
+
input: z.input<typeof endpointAwsParametersSchema>
|
| 20 |
+
): Promise<Endpoint> {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
let AwsClient;
|
| 22 |
try {
|
| 23 |
AwsClient = (await import("aws4fetch")).AwsClient;
|
|
|
|
| 25 |
throw new Error("Failed to import aws4fetch");
|
| 26 |
}
|
| 27 |
|
| 28 |
+
const { url, accessKey, secretKey, sessionToken, model, region, service } =
|
| 29 |
+
endpointAwsParametersSchema.parse(input);
|
| 30 |
+
|
| 31 |
const aws = new AwsClient({
|
| 32 |
accessKeyId: accessKey,
|
| 33 |
secretAccessKey: secretKey,
|
src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts
CHANGED
|
@@ -12,10 +12,10 @@ export const endpointLlamacppParametersSchema = z.object({
|
|
| 12 |
accessToken: z.string().min(1).default(HF_ACCESS_TOKEN),
|
| 13 |
});
|
| 14 |
|
| 15 |
-
export function endpointLlamacpp(
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
}
|
| 19 |
return async ({ conversation }) => {
|
| 20 |
const prompt = await buildPrompt({
|
| 21 |
messages: conversation.messages,
|
|
|
|
| 12 |
accessToken: z.string().min(1).default(HF_ACCESS_TOKEN),
|
| 13 |
});
|
| 14 |
|
| 15 |
+
export function endpointLlamacpp(
|
| 16 |
+
input: z.input<typeof endpointLlamacppParametersSchema>
|
| 17 |
+
): Endpoint {
|
| 18 |
+
const { url, model } = endpointLlamacppParametersSchema.parse(input);
|
| 19 |
return async ({ conversation }) => {
|
| 20 |
const prompt = await buildPrompt({
|
| 21 |
messages: conversation.messages,
|
src/lib/server/endpoints/ollama/endpointOllama.ts
CHANGED
|
@@ -11,11 +11,9 @@ export const endpointOllamaParametersSchema = z.object({
|
|
| 11 |
ollamaName: z.string().min(1).optional(),
|
| 12 |
});
|
| 13 |
|
| 14 |
-
export function endpointOllama({
|
| 15 |
-
url,
|
| 16 |
-
|
| 17 |
-
ollamaName,
|
| 18 |
-
}: z.infer<typeof endpointOllamaParametersSchema>): Endpoint {
|
| 19 |
return async ({ conversation }) => {
|
| 20 |
const prompt = await buildPrompt({
|
| 21 |
messages: conversation.messages,
|
|
|
|
| 11 |
ollamaName: z.string().min(1).optional(),
|
| 12 |
});
|
| 13 |
|
| 14 |
+
export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
|
| 15 |
+
const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);
|
| 16 |
+
|
|
|
|
|
|
|
| 17 |
return async ({ conversation }) => {
|
| 18 |
const prompt = await buildPrompt({
|
| 19 |
messages: conversation.messages,
|
src/lib/server/endpoints/openai/endpointOai.ts
CHANGED
|
@@ -16,12 +16,10 @@ export const endpointOAIParametersSchema = z.object({
|
|
| 16 |
.default("chat_completions"),
|
| 17 |
});
|
| 18 |
|
| 19 |
-
export async function endpointOai(
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
completion,
|
| 23 |
-
model,
|
| 24 |
-
}: z.infer<typeof endpointOAIParametersSchema>): Promise<Endpoint> {
|
| 25 |
let OpenAI;
|
| 26 |
try {
|
| 27 |
OpenAI = (await import("openai")).OpenAI;
|
|
|
|
| 16 |
.default("chat_completions"),
|
| 17 |
});
|
| 18 |
|
| 19 |
+
export async function endpointOai(
|
| 20 |
+
input: z.input<typeof endpointOAIParametersSchema>
|
| 21 |
+
): Promise<Endpoint> {
|
| 22 |
+
const { baseURL, apiKey, completion, model } = endpointOAIParametersSchema.parse(input);
|
|
|
|
|
|
|
| 23 |
let OpenAI;
|
| 24 |
try {
|
| 25 |
OpenAI = (await import("openai")).OpenAI;
|
src/lib/server/endpoints/tgi/endpointTgi.ts
CHANGED
|
@@ -10,13 +10,11 @@ export const endpointTgiParametersSchema = z.object({
|
|
| 10 |
type: z.literal("tgi"),
|
| 11 |
url: z.string().url(),
|
| 12 |
accessToken: z.string().default(HF_ACCESS_TOKEN),
|
|
|
|
| 13 |
});
|
| 14 |
|
| 15 |
-
export function endpointTgi({
|
| 16 |
-
url,
|
| 17 |
-
accessToken,
|
| 18 |
-
model,
|
| 19 |
-
}: z.infer<typeof endpointTgiParametersSchema>): Endpoint {
|
| 20 |
return async ({ conversation }) => {
|
| 21 |
const prompt = await buildPrompt({
|
| 22 |
messages: conversation.messages,
|
|
@@ -33,7 +31,19 @@ export function endpointTgi({
|
|
| 33 |
inputs: prompt,
|
| 34 |
accessToken,
|
| 35 |
},
|
| 36 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
);
|
| 38 |
};
|
| 39 |
}
|
|
|
|
| 10 |
type: z.literal("tgi"),
|
| 11 |
url: z.string().url(),
|
| 12 |
accessToken: z.string().default(HF_ACCESS_TOKEN),
|
| 13 |
+
authorization: z.string().optional(),
|
| 14 |
});
|
| 15 |
|
| 16 |
+
export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
|
| 17 |
+
const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
|
|
|
|
|
|
|
|
|
|
| 18 |
return async ({ conversation }) => {
|
| 19 |
const prompt = await buildPrompt({
|
| 20 |
messages: conversation.messages,
|
|
|
|
| 31 |
inputs: prompt,
|
| 32 |
accessToken,
|
| 33 |
},
|
| 34 |
+
{
|
| 35 |
+
use_cache: false,
|
| 36 |
+
fetch: async (endpointUrl, info) => {
|
| 37 |
+
if (info && authorization && !accessToken) {
|
| 38 |
+
// Set authorization header if it is defined and HF_ACCESS_TOKEN is empty
|
| 39 |
+
info.headers = {
|
| 40 |
+
...info.headers,
|
| 41 |
+
Authorization: authorization,
|
| 42 |
+
};
|
| 43 |
+
}
|
| 44 |
+
return fetch(endpointUrl, info);
|
| 45 |
+
},
|
| 46 |
+
}
|
| 47 |
);
|
| 48 |
};
|
| 49 |
}
|