|
package openai |
|
|
|
import ( |
|
"fmt" |
|
"net/http" |
|
"sort" |
|
"strconv" |
|
"strings" |
|
"sync/atomic" |
|
"time" |
|
|
|
"github.com/gofiber/fiber/v2" |
|
"github.com/mudler/LocalAI/core/config" |
|
"github.com/mudler/LocalAI/core/schema" |
|
"github.com/mudler/LocalAI/core/services" |
|
model "github.com/mudler/LocalAI/pkg/model" |
|
"github.com/mudler/LocalAI/pkg/utils" |
|
"github.com/rs/zerolog/log" |
|
) |
|
|
|
|
|
// ToolType names the kind of tool an assistant can be configured with.
type ToolType string

// Known values for Tool.Type.
const (
	CodeInterpreter ToolType = "code_interpreter"
	Retrieval       ToolType = "retrieval"
	Function        ToolType = "function"
)

// Size limits for assistant fields. Within this file only MaxFileIdSize is
// referenced (when attaching files to an assistant).
// NOTE(review): the remaining limits are declared but not enforced by the
// handlers here — confirm whether request validation should apply them.
const (
	MaxCharacterInstructions  = 32768
	MaxCharacterDescription   = 512
	MaxCharacterName          = 256
	MaxToolsSize              = 128
	MaxFileIdSize             = 20
	MaxCharacterMetadataKey   = 64
	MaxCharacterMetadataValue = 512
)

type (
	// Tool is one tool entry on an assistant; only its kind is modelled.
	Tool struct {
		Type ToolType `json:"type"`
	}

	// Assistant is the stored (and returned) representation of an assistant.
	// Created holds the Unix time in seconds at which it was created.
	Assistant struct {
		ID           string            `json:"id"`
		Object       string            `json:"object"`
		Created      int64             `json:"created"`
		Model        string            `json:"model"`
		Name         string            `json:"name,omitempty"`
		Description  string            `json:"description,omitempty"`
		Instructions string            `json:"instructions,omitempty"`
		Tools        []Tool            `json:"tools,omitempty"`
		FileIDs      []string          `json:"file_ids,omitempty"`
		Metadata     map[string]string `json:"metadata,omitempty"`
	}

	// AssistantRequest is the JSON body parsed by the create and modify
	// assistant handlers.
	AssistantRequest struct {
		Model        string            `json:"model"`
		Name         string            `json:"name,omitempty"`
		Description  string            `json:"description,omitempty"`
		Instructions string            `json:"instructions,omitempty"`
		Tools        []Tool            `json:"tools,omitempty"`
		FileIDs      []string          `json:"file_ids,omitempty"`
		Metadata     map[string]string `json:"metadata,omitempty"`
	}
)

var (
	// Assistants is the in-memory list of assistants. Handlers that change it
	// persist it with utils.SaveConfig under AssistantsConfigFile. It starts as
	// an empty, non-nil slice so that an empty list JSON-encodes as [] rather
	// than null.
	// NOTE(review): it is read and mutated without any locking.
	Assistants = []Assistant{}

	// AssistantsConfigFile is the file name (inside ApplicationConfig.ConfigsDir)
	// that Assistants is saved to.
	AssistantsConfigFile = "assistants.json"
)
|
|
|
|
|
|
|
|
|
|
|
|
|
func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
return func(c *fiber.Ctx) error { |
|
request := new(AssistantRequest) |
|
if err := c.BodyParser(request); err != nil { |
|
log.Warn().AnErr("Unable to parse AssistantRequest", err) |
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) |
|
} |
|
|
|
if !modelExists(cl, ml, request.Model) { |
|
log.Warn().Msgf("Model: %s was not found in list of models.", request.Model) |
|
return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found") |
|
} |
|
|
|
if request.Tools == nil { |
|
request.Tools = []Tool{} |
|
} |
|
|
|
if request.FileIDs == nil { |
|
request.FileIDs = []string{} |
|
} |
|
|
|
if request.Metadata == nil { |
|
request.Metadata = make(map[string]string) |
|
} |
|
|
|
id := "asst_" + strconv.FormatInt(generateRandomID(), 10) |
|
|
|
assistant := Assistant{ |
|
ID: id, |
|
Object: "assistant", |
|
Created: time.Now().Unix(), |
|
Model: request.Model, |
|
Name: request.Name, |
|
Description: request.Description, |
|
Instructions: request.Instructions, |
|
Tools: request.Tools, |
|
FileIDs: request.FileIDs, |
|
Metadata: request.Metadata, |
|
} |
|
|
|
Assistants = append(Assistants, assistant) |
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants) |
|
return c.Status(fiber.StatusOK).JSON(assistant) |
|
} |
|
} |
|
|
|
// currentId backs generateRandomID. It is accessed only through sync/atomic.
var currentId int64

// generateRandomID returns the next value of a process-local, monotonically
// increasing counter. Despite the name, the value is not random. The first
// call in a fresh process returns 1.
func generateRandomID() int64 {
	// Return the value produced by AddInt64 itself. Re-reading currentId after
	// the increment is a separate, non-atomic load that can observe another
	// goroutine's increment, so two concurrent callers could get the same ID.
	return atomic.AddInt64(&currentId, 1)
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
return func(c *fiber.Ctx) error { |
|
|
|
returnAssistants := Assistants |
|
|
|
limitQuery := c.Query("limit", "20") |
|
orderQuery := c.Query("order", "desc") |
|
afterQuery := c.Query("after") |
|
beforeQuery := c.Query("before") |
|
|
|
|
|
limit, err := strconv.Atoi(limitQuery) |
|
if err != nil { |
|
return c.Status(http.StatusBadRequest).SendString(fmt.Sprintf("Invalid limit query value: %s", limitQuery)) |
|
} |
|
|
|
|
|
sort.SliceStable(returnAssistants, func(i, j int) bool { |
|
if orderQuery == "asc" { |
|
return returnAssistants[i].Created < returnAssistants[j].Created |
|
} |
|
return returnAssistants[i].Created > returnAssistants[j].Created |
|
}) |
|
|
|
|
|
if afterQuery != "" { |
|
returnAssistants = filterAssistantsAfterID(returnAssistants, afterQuery) |
|
} |
|
if beforeQuery != "" { |
|
returnAssistants = filterAssistantsBeforeID(returnAssistants, beforeQuery) |
|
} |
|
|
|
|
|
if limit < len(returnAssistants) { |
|
returnAssistants = returnAssistants[:limit] |
|
} |
|
|
|
return c.JSON(returnAssistants) |
|
} |
|
} |
|
|
|
|
|
|
|
func filterAssistantsBeforeID(assistants []Assistant, id string) []Assistant { |
|
idInt, err := strconv.Atoi(id) |
|
if err != nil { |
|
return assistants |
|
} |
|
|
|
var filteredAssistants []Assistant |
|
|
|
for _, assistant := range assistants { |
|
aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_")) |
|
if err != nil { |
|
continue |
|
} |
|
|
|
if aid < idInt { |
|
filteredAssistants = append(filteredAssistants, assistant) |
|
} |
|
} |
|
|
|
return filteredAssistants |
|
} |
|
|
|
|
|
|
|
func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant { |
|
idInt, err := strconv.Atoi(id) |
|
if err != nil { |
|
return assistants |
|
} |
|
|
|
var filteredAssistants []Assistant |
|
|
|
for _, assistant := range assistants { |
|
aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_")) |
|
if err != nil { |
|
continue |
|
} |
|
|
|
if aid > idInt { |
|
filteredAssistants = append(filteredAssistants, assistant) |
|
} |
|
} |
|
|
|
return filteredAssistants |
|
} |
|
|
|
func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string) (found bool) { |
|
found = false |
|
models, err := services.ListModels(cl, ml, "", true) |
|
if err != nil { |
|
return |
|
} |
|
|
|
for _, model := range models { |
|
if model == modelName { |
|
found = true |
|
return |
|
} |
|
} |
|
return |
|
} |
|
|
|
|
|
|
|
|
|
|
|
func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
return func(c *fiber.Ctx) error { |
|
assistantID := c.Params("assistant_id") |
|
if assistantID == "" { |
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") |
|
} |
|
|
|
for i, assistant := range Assistants { |
|
if assistant.ID == assistantID { |
|
Assistants = append(Assistants[:i], Assistants[i+1:]...) |
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants) |
|
return c.Status(fiber.StatusOK).JSON(schema.DeleteAssistantResponse{ |
|
ID: assistantID, |
|
Object: "assistant.deleted", |
|
Deleted: true, |
|
}) |
|
} |
|
} |
|
|
|
log.Warn().Msgf("Unable to find assistant %s for deletion", assistantID) |
|
return c.Status(fiber.StatusNotFound).JSON(schema.DeleteAssistantResponse{ |
|
ID: assistantID, |
|
Object: "assistant.deleted", |
|
Deleted: false, |
|
}) |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
return func(c *fiber.Ctx) error { |
|
assistantID := c.Params("assistant_id") |
|
if assistantID == "" { |
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") |
|
} |
|
|
|
for _, assistant := range Assistants { |
|
if assistant.ID == assistantID { |
|
return c.Status(fiber.StatusOK).JSON(assistant) |
|
} |
|
} |
|
|
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)) |
|
} |
|
} |
|
|
|
// AssistantFile records that an uploaded file is attached to an assistant.
// ID is the uploaded file's ID, AssistantID the owning assistant's ID, and
// CreatedAt the Unix time in seconds at which the attachment was made.
type AssistantFile struct {
	ID          string `json:"id"`
	Object      string `json:"object"`
	CreatedAt   int64  `json:"created_at"`
	AssistantID string `json:"assistant_id"`
}

// AssistantFiles holds every assistant/file attachment record. Handlers that
// change it persist it with utils.SaveConfig under AssistantsFileConfigFile.
// NOTE(review): like Assistants, it is accessed without locking.
var AssistantFiles []AssistantFile

// AssistantsFileConfigFile is the file name (inside ApplicationConfig.ConfigsDir)
// that AssistantFiles is saved to.
var AssistantsFileConfigFile = "assistantsFile.json"
|
|
|
func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
return func(c *fiber.Ctx) error { |
|
request := new(schema.AssistantFileRequest) |
|
if err := c.BodyParser(request); err != nil { |
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) |
|
} |
|
|
|
assistantID := c.Params("assistant_id") |
|
if assistantID == "" { |
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") |
|
} |
|
|
|
for _, assistant := range Assistants { |
|
if assistant.ID == assistantID { |
|
if len(assistant.FileIDs) > MaxFileIdSize { |
|
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Max files %d for assistant %s reached.", MaxFileIdSize, assistant.Name)) |
|
} |
|
|
|
for _, file := range UploadedFiles { |
|
if file.ID == request.FileID { |
|
assistant.FileIDs = append(assistant.FileIDs, request.FileID) |
|
assistantFile := AssistantFile{ |
|
ID: file.ID, |
|
Object: "assistant.file", |
|
CreatedAt: time.Now().Unix(), |
|
AssistantID: assistant.ID, |
|
} |
|
AssistantFiles = append(AssistantFiles, assistantFile) |
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles) |
|
return c.Status(fiber.StatusOK).JSON(assistantFile) |
|
} |
|
} |
|
|
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find file_id: %s", request.FileID)) |
|
} |
|
} |
|
|
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find %q", assistantID)) |
|
} |
|
} |
|
|
|
func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
type ListAssistantFiles struct { |
|
Data []schema.File |
|
Object string |
|
} |
|
|
|
return func(c *fiber.Ctx) error { |
|
assistantID := c.Params("assistant_id") |
|
if assistantID == "" { |
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") |
|
} |
|
|
|
limitQuery := c.Query("limit", "20") |
|
order := c.Query("order", "desc") |
|
limit, err := strconv.Atoi(limitQuery) |
|
if err != nil || limit < 1 || limit > 100 { |
|
limit = 20 |
|
} |
|
|
|
|
|
if order == "asc" { |
|
sort.Slice(AssistantFiles, func(i, j int) bool { |
|
return AssistantFiles[i].CreatedAt < AssistantFiles[j].CreatedAt |
|
}) |
|
} else { |
|
sort.Slice(AssistantFiles, func(i, j int) bool { |
|
return AssistantFiles[i].CreatedAt > AssistantFiles[j].CreatedAt |
|
}) |
|
} |
|
|
|
|
|
var limitedFiles []AssistantFile |
|
hasMore := false |
|
if len(AssistantFiles) > limit { |
|
hasMore = true |
|
limitedFiles = AssistantFiles[:limit] |
|
} else { |
|
limitedFiles = AssistantFiles |
|
} |
|
|
|
response := map[string]interface{}{ |
|
"object": "list", |
|
"data": limitedFiles, |
|
"first_id": func() string { |
|
if len(limitedFiles) > 0 { |
|
return limitedFiles[0].ID |
|
} |
|
return "" |
|
}(), |
|
"last_id": func() string { |
|
if len(limitedFiles) > 0 { |
|
return limitedFiles[len(limitedFiles)-1].ID |
|
} |
|
return "" |
|
}(), |
|
"has_more": hasMore, |
|
} |
|
|
|
return c.Status(fiber.StatusOK).JSON(response) |
|
} |
|
} |
|
|
|
func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
return func(c *fiber.Ctx) error { |
|
request := new(AssistantRequest) |
|
if err := c.BodyParser(request); err != nil { |
|
log.Warn().AnErr("Unable to parse AssistantRequest", err) |
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) |
|
} |
|
|
|
assistantID := c.Params("assistant_id") |
|
if assistantID == "" { |
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") |
|
} |
|
|
|
for i, assistant := range Assistants { |
|
if assistant.ID == assistantID { |
|
newAssistant := Assistant{ |
|
ID: assistantID, |
|
Object: assistant.Object, |
|
Created: assistant.Created, |
|
Model: request.Model, |
|
Name: request.Name, |
|
Description: request.Description, |
|
Instructions: request.Instructions, |
|
Tools: request.Tools, |
|
FileIDs: request.FileIDs, |
|
Metadata: request.Metadata, |
|
} |
|
|
|
|
|
Assistants = append(Assistants[:i], Assistants[i+1:]...) |
|
Assistants = append(Assistants, newAssistant) |
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants) |
|
return c.Status(fiber.StatusOK).JSON(newAssistant) |
|
} |
|
} |
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)) |
|
} |
|
} |
|
|
|
func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
return func(c *fiber.Ctx) error { |
|
assistantID := c.Params("assistant_id") |
|
fileId := c.Params("file_id") |
|
if assistantID == "" { |
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required") |
|
} |
|
|
|
for i, assistant := range Assistants { |
|
if assistant.ID == assistantID { |
|
for j, fileId := range assistant.FileIDs { |
|
Assistants[i].FileIDs = append(Assistants[i].FileIDs[:j], Assistants[i].FileIDs[j+1:]...) |
|
|
|
|
|
for i, assistantFile := range AssistantFiles { |
|
if assistantFile.ID == fileId { |
|
|
|
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...) |
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles) |
|
return c.Status(fiber.StatusOK).JSON(schema.DeleteAssistantFileResponse{ |
|
ID: fileId, |
|
Object: "assistant.file.deleted", |
|
Deleted: true, |
|
}) |
|
} |
|
} |
|
} |
|
|
|
log.Warn().Msgf("Unable to locate file_id: %s in assistants: %s. Continuing to delete assistant file.", fileId, assistantID) |
|
for i, assistantFile := range AssistantFiles { |
|
if assistantFile.AssistantID == assistantID { |
|
|
|
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...) |
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles) |
|
|
|
return c.Status(fiber.StatusNotFound).JSON(schema.DeleteAssistantFileResponse{ |
|
ID: fileId, |
|
Object: "assistant.file.deleted", |
|
Deleted: true, |
|
}) |
|
} |
|
} |
|
} |
|
} |
|
log.Warn().Msgf("Unable to find assistant: %s", assistantID) |
|
|
|
return c.Status(fiber.StatusNotFound).JSON(schema.DeleteAssistantFileResponse{ |
|
ID: fileId, |
|
Object: "assistant.file.deleted", |
|
Deleted: false, |
|
}) |
|
} |
|
} |
|
|
|
func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
return func(c *fiber.Ctx) error { |
|
assistantID := c.Params("assistant_id") |
|
fileId := c.Params("file_id") |
|
if assistantID == "" { |
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required") |
|
} |
|
|
|
for _, assistantFile := range AssistantFiles { |
|
if assistantFile.AssistantID == assistantID { |
|
if assistantFile.ID == fileId { |
|
return c.Status(fiber.StatusOK).JSON(assistantFile) |
|
} |
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId)) |
|
} |
|
} |
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID)) |
|
} |
|
} |
|
|