|
package openai |
|
|
|
import ( |
|
"encoding/json" |
|
"fmt" |
|
"io" |
|
"net/http" |
|
"net/http/httptest" |
|
"os" |
|
"path/filepath" |
|
"strings" |
|
"testing" |
|
"time" |
|
|
|
"github.com/gofiber/fiber/v2" |
|
"github.com/mudler/LocalAI/core/config" |
|
"github.com/mudler/LocalAI/core/schema" |
|
"github.com/mudler/LocalAI/pkg/model" |
|
"github.com/stretchr/testify/assert" |
|
) |
|
|
|
// configsDir is used as ApplicationConfig.ConfigsDir by the tests in this
// file; tearDown removes the assistant config files from it.
var configsDir = "/tmp/localai/configs"

// MockLoader holds a list of model names.
// NOTE(review): not referenced in this file — confirm whether other tests in
// the package use it before removing.
type MockLoader struct{ models []string }
|
|
|
func tearDown() func() { |
|
return func() { |
|
UploadedFiles = []schema.File{} |
|
Assistants = []Assistant{} |
|
AssistantFiles = []AssistantFile{} |
|
_ = os.Remove(filepath.Join(configsDir, AssistantsConfigFile)) |
|
_ = os.Remove(filepath.Join(configsDir, AssistantsFileConfigFile)) |
|
} |
|
} |
|
|
|
func TestAssistantEndpoints(t *testing.T) { |
|
|
|
cl := &config.BackendConfigLoader{} |
|
|
|
modelPath := "/tmp/localai/model" |
|
var ml = model.NewModelLoader(modelPath) |
|
|
|
appConfig := &config.ApplicationConfig{ |
|
ConfigsDir: configsDir, |
|
UploadLimitMB: 10, |
|
UploadDir: "test_dir", |
|
ModelPath: modelPath, |
|
} |
|
|
|
_ = os.RemoveAll(appConfig.ConfigsDir) |
|
_ = os.MkdirAll(appConfig.ConfigsDir, 0750) |
|
_ = os.MkdirAll(modelPath, 0750) |
|
os.Create(filepath.Join(modelPath, "ggml-gpt4all-j")) |
|
|
|
app := fiber.New(fiber.Config{ |
|
BodyLimit: 20 * 1024 * 1024, |
|
}) |
|
|
|
|
|
app.Get("/assistants", ListAssistantsEndpoint(cl, ml, appConfig)) |
|
app.Post("/assistants", CreateAssistantEndpoint(cl, ml, appConfig)) |
|
app.Delete("/assistants/:assistant_id", DeleteAssistantEndpoint(cl, ml, appConfig)) |
|
app.Get("/assistants/:assistant_id", GetAssistantEndpoint(cl, ml, appConfig)) |
|
app.Post("/assistants/:assistant_id", ModifyAssistantEndpoint(cl, ml, appConfig)) |
|
|
|
app.Post("/files", UploadFilesEndpoint(cl, appConfig)) |
|
app.Get("/assistants/:assistant_id/files", ListAssistantFilesEndpoint(cl, ml, appConfig)) |
|
app.Post("/assistants/:assistant_id/files", CreateAssistantFileEndpoint(cl, ml, appConfig)) |
|
app.Delete("/assistants/:assistant_id/files/:file_id", DeleteAssistantFileEndpoint(cl, ml, appConfig)) |
|
app.Get("/assistants/:assistant_id/files/:file_id", GetAssistantFileEndpoint(cl, ml, appConfig)) |
|
|
|
t.Run("CreateAssistantEndpoint", func(t *testing.T) { |
|
t.Cleanup(tearDown()) |
|
ar := &AssistantRequest{ |
|
Model: "ggml-gpt4all-j", |
|
Name: "3.5-turbo", |
|
Description: "Test Assistant", |
|
Instructions: "You are computer science teacher answering student questions", |
|
Tools: []Tool{{Type: Function}}, |
|
FileIDs: nil, |
|
Metadata: nil, |
|
} |
|
|
|
resultAssistant, resp, err := createAssistant(app, *ar) |
|
assert.NoError(t, err) |
|
assert.Equal(t, fiber.StatusOK, resp.StatusCode) |
|
|
|
assert.Equal(t, 1, len(Assistants)) |
|
|
|
|
|
assert.Equal(t, ar.Name, resultAssistant.Name) |
|
assert.Equal(t, ar.Model, resultAssistant.Model) |
|
assert.Equal(t, ar.Tools, resultAssistant.Tools) |
|
assert.Equal(t, ar.Description, resultAssistant.Description) |
|
assert.Equal(t, ar.Instructions, resultAssistant.Instructions) |
|
assert.Equal(t, ar.FileIDs, resultAssistant.FileIDs) |
|
assert.Equal(t, ar.Metadata, resultAssistant.Metadata) |
|
}) |
|
|
|
t.Run("ListAssistantsEndpoint", func(t *testing.T) { |
|
var ids []string |
|
var resultAssistant []Assistant |
|
for i := 0; i < 4; i++ { |
|
ar := &AssistantRequest{ |
|
Model: "ggml-gpt4all-j", |
|
Name: fmt.Sprintf("3.5-turbo-%d", i), |
|
Description: fmt.Sprintf("Test Assistant - %d", i), |
|
Instructions: fmt.Sprintf("You are computer science teacher answering student questions - %d", i), |
|
Tools: []Tool{{Type: Function}}, |
|
FileIDs: []string{"fid-1234"}, |
|
Metadata: map[string]string{"meta": "data"}, |
|
} |
|
|
|
|
|
ra, _, err := createAssistant(app, *ar) |
|
|
|
time.Sleep(time.Second) |
|
resultAssistant = append(resultAssistant, ra) |
|
assert.NoError(t, err) |
|
ids = append(ids, resultAssistant[i].ID) |
|
} |
|
|
|
t.Cleanup(cleanupAllAssistants(t, app, ids)) |
|
|
|
tests := []struct { |
|
name string |
|
reqURL string |
|
expectedStatus int |
|
expectedResult []Assistant |
|
expectedStringResult string |
|
}{ |
|
{ |
|
name: "Valid Usage - limit only", |
|
reqURL: "/assistants?limit=2", |
|
expectedStatus: http.StatusOK, |
|
expectedResult: Assistants[:2], |
|
}, |
|
{ |
|
name: "Valid Usage - order asc", |
|
reqURL: "/assistants?order=asc", |
|
expectedStatus: http.StatusOK, |
|
expectedResult: Assistants, |
|
}, |
|
{ |
|
name: "Valid Usage - order desc", |
|
reqURL: "/assistants?order=desc", |
|
expectedStatus: http.StatusOK, |
|
expectedResult: []Assistant{Assistants[3], Assistants[2], Assistants[1], Assistants[0]}, |
|
}, |
|
{ |
|
name: "Valid Usage - after specific ID", |
|
reqURL: "/assistants?after=2", |
|
expectedStatus: http.StatusOK, |
|
|
|
expectedResult: Assistants[:3], |
|
}, |
|
{ |
|
name: "Valid Usage - before specific ID", |
|
reqURL: "/assistants?before=4", |
|
expectedStatus: http.StatusOK, |
|
expectedResult: Assistants[2:], |
|
}, |
|
{ |
|
name: "Invalid Usage - non-integer limit", |
|
reqURL: "/assistants?limit=two", |
|
expectedStatus: http.StatusBadRequest, |
|
expectedStringResult: "Invalid limit query value: two", |
|
}, |
|
{ |
|
name: "Invalid Usage - non-existing id in after", |
|
reqURL: "/assistants?after=100", |
|
expectedStatus: http.StatusOK, |
|
expectedResult: []Assistant(nil), |
|
}, |
|
} |
|
|
|
for _, tt := range tests { |
|
t.Run(tt.name, func(t *testing.T) { |
|
request := httptest.NewRequest(http.MethodGet, tt.reqURL, nil) |
|
response, err := app.Test(request) |
|
assert.NoError(t, err) |
|
assert.Equal(t, tt.expectedStatus, response.StatusCode) |
|
if tt.expectedStatus != fiber.StatusOK { |
|
all, _ := io.ReadAll(response.Body) |
|
assert.Equal(t, tt.expectedStringResult, string(all)) |
|
} else { |
|
var result []Assistant |
|
err = json.NewDecoder(response.Body).Decode(&result) |
|
assert.NoError(t, err) |
|
|
|
assert.Equal(t, tt.expectedResult, result) |
|
} |
|
}) |
|
} |
|
}) |
|
|
|
t.Run("DeleteAssistantEndpoint", func(t *testing.T) { |
|
ar := &AssistantRequest{ |
|
Model: "ggml-gpt4all-j", |
|
Name: "3.5-turbo", |
|
Description: "Test Assistant", |
|
Instructions: "You are computer science teacher answering student questions", |
|
Tools: []Tool{{Type: Function}}, |
|
FileIDs: nil, |
|
Metadata: nil, |
|
} |
|
|
|
resultAssistant, _, err := createAssistant(app, *ar) |
|
assert.NoError(t, err) |
|
|
|
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID) |
|
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil) |
|
_, err = app.Test(deleteReq) |
|
assert.NoError(t, err) |
|
assert.Equal(t, 0, len(Assistants)) |
|
}) |
|
|
|
t.Run("GetAssistantEndpoint", func(t *testing.T) { |
|
ar := &AssistantRequest{ |
|
Model: "ggml-gpt4all-j", |
|
Name: "3.5-turbo", |
|
Description: "Test Assistant", |
|
Instructions: "You are computer science teacher answering student questions", |
|
Tools: []Tool{{Type: Function}}, |
|
FileIDs: nil, |
|
Metadata: nil, |
|
} |
|
|
|
resultAssistant, _, err := createAssistant(app, *ar) |
|
assert.NoError(t, err) |
|
t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID})) |
|
|
|
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID) |
|
request := httptest.NewRequest(http.MethodGet, target, nil) |
|
response, err := app.Test(request) |
|
assert.NoError(t, err) |
|
|
|
var getAssistant Assistant |
|
err = json.NewDecoder(response.Body).Decode(&getAssistant) |
|
assert.NoError(t, err) |
|
|
|
assert.Equal(t, resultAssistant.ID, getAssistant.ID) |
|
}) |
|
|
|
t.Run("ModifyAssistantEndpoint", func(t *testing.T) { |
|
ar := &AssistantRequest{ |
|
Model: "ggml-gpt4all-j", |
|
Name: "3.5-turbo", |
|
Description: "Test Assistant", |
|
Instructions: "You are computer science teacher answering student questions", |
|
Tools: []Tool{{Type: Function}}, |
|
FileIDs: nil, |
|
Metadata: nil, |
|
} |
|
|
|
resultAssistant, _, err := createAssistant(app, *ar) |
|
assert.NoError(t, err) |
|
|
|
modifiedAr := &AssistantRequest{ |
|
Model: "ggml-gpt4all-j", |
|
Name: "4.0-turbo", |
|
Description: "Modified Test Assistant", |
|
Instructions: "You are math teacher answering student questions", |
|
Tools: []Tool{{Type: CodeInterpreter}}, |
|
FileIDs: nil, |
|
Metadata: nil, |
|
} |
|
|
|
modifiedArJson, err := json.Marshal(modifiedAr) |
|
assert.NoError(t, err) |
|
|
|
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID) |
|
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(modifiedArJson))) |
|
request.Header.Set(fiber.HeaderContentType, "application/json") |
|
|
|
modifyResponse, err := app.Test(request) |
|
assert.NoError(t, err) |
|
var getAssistant Assistant |
|
err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant) |
|
assert.NoError(t, err) |
|
|
|
t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID})) |
|
|
|
assert.Equal(t, resultAssistant.ID, getAssistant.ID) |
|
assert.Equal(t, modifiedAr.Tools, getAssistant.Tools) |
|
assert.Equal(t, modifiedAr.Name, getAssistant.Name) |
|
assert.Equal(t, modifiedAr.Instructions, getAssistant.Instructions) |
|
assert.Equal(t, modifiedAr.Description, getAssistant.Description) |
|
}) |
|
|
|
t.Run("CreateAssistantFileEndpoint", func(t *testing.T) { |
|
t.Cleanup(tearDown()) |
|
file, assistant, err := createFileAndAssistant(t, app, appConfig) |
|
assert.NoError(t, err) |
|
|
|
afr := schema.AssistantFileRequest{FileID: file.ID} |
|
af, _, err := createAssistantFile(app, afr, assistant.ID) |
|
|
|
assert.NoError(t, err) |
|
assert.Equal(t, assistant.ID, af.AssistantID) |
|
}) |
|
t.Run("ListAssistantFilesEndpoint", func(t *testing.T) { |
|
t.Cleanup(tearDown()) |
|
file, assistant, err := createFileAndAssistant(t, app, appConfig) |
|
assert.NoError(t, err) |
|
|
|
afr := schema.AssistantFileRequest{FileID: file.ID} |
|
af, _, err := createAssistantFile(app, afr, assistant.ID) |
|
assert.NoError(t, err) |
|
|
|
assert.Equal(t, assistant.ID, af.AssistantID) |
|
}) |
|
t.Run("GetAssistantFileEndpoint", func(t *testing.T) { |
|
t.Cleanup(tearDown()) |
|
file, assistant, err := createFileAndAssistant(t, app, appConfig) |
|
assert.NoError(t, err) |
|
|
|
afr := schema.AssistantFileRequest{FileID: file.ID} |
|
af, _, err := createAssistantFile(app, afr, assistant.ID) |
|
assert.NoError(t, err) |
|
t.Cleanup(cleanupAssistantFile(t, app, af.ID, af.AssistantID)) |
|
|
|
target := fmt.Sprintf("/assistants/%s/files/%s", assistant.ID, file.ID) |
|
request := httptest.NewRequest(http.MethodGet, target, nil) |
|
response, err := app.Test(request) |
|
assert.NoError(t, err) |
|
|
|
var assistantFile AssistantFile |
|
err = json.NewDecoder(response.Body).Decode(&assistantFile) |
|
assert.NoError(t, err) |
|
|
|
assert.Equal(t, af.ID, assistantFile.ID) |
|
assert.Equal(t, af.AssistantID, assistantFile.AssistantID) |
|
}) |
|
t.Run("DeleteAssistantFileEndpoint", func(t *testing.T) { |
|
t.Cleanup(tearDown()) |
|
file, assistant, err := createFileAndAssistant(t, app, appConfig) |
|
assert.NoError(t, err) |
|
|
|
afr := schema.AssistantFileRequest{FileID: file.ID} |
|
af, _, err := createAssistantFile(app, afr, assistant.ID) |
|
assert.NoError(t, err) |
|
|
|
cleanupAssistantFile(t, app, af.ID, af.AssistantID)() |
|
|
|
assert.Empty(t, AssistantFiles) |
|
}) |
|
|
|
} |
|
|
|
func createFileAndAssistant(t *testing.T, app *fiber.App, o *config.ApplicationConfig) (schema.File, Assistant, error) { |
|
ar := &AssistantRequest{ |
|
Model: "ggml-gpt4all-j", |
|
Name: "3.5-turbo", |
|
Description: "Test Assistant", |
|
Instructions: "You are computer science teacher answering student questions", |
|
Tools: []Tool{{Type: Function}}, |
|
FileIDs: nil, |
|
Metadata: nil, |
|
} |
|
|
|
assistant, _, err := createAssistant(app, *ar) |
|
if err != nil { |
|
return schema.File{}, Assistant{}, err |
|
} |
|
t.Cleanup(cleanupAllAssistants(t, app, []string{assistant.ID})) |
|
|
|
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, o) |
|
t.Cleanup(func() { |
|
_, err := CallFilesDeleteEndpoint(t, app, file.ID) |
|
assert.NoError(t, err) |
|
}) |
|
return file, assistant, nil |
|
} |
|
|
|
func createAssistantFile(app *fiber.App, afr schema.AssistantFileRequest, assistantId string) (AssistantFile, *http.Response, error) { |
|
afrJson, err := json.Marshal(afr) |
|
if err != nil { |
|
return AssistantFile{}, nil, err |
|
} |
|
|
|
target := fmt.Sprintf("/assistants/%s/files", assistantId) |
|
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(afrJson))) |
|
request.Header.Set(fiber.HeaderContentType, "application/json") |
|
request.Header.Set("OpenAi-Beta", "assistants=v1") |
|
|
|
resp, err := app.Test(request) |
|
if err != nil { |
|
return AssistantFile{}, resp, err |
|
} |
|
|
|
var assistantFile AssistantFile |
|
all, err := io.ReadAll(resp.Body) |
|
if err != nil { |
|
return AssistantFile{}, resp, err |
|
} |
|
err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile) |
|
if err != nil { |
|
return AssistantFile{}, resp, err |
|
} |
|
|
|
return assistantFile, resp, nil |
|
} |
|
|
|
func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Response, error) { |
|
assistant, err := json.Marshal(ar) |
|
if err != nil { |
|
return Assistant{}, nil, err |
|
} |
|
|
|
request := httptest.NewRequest(http.MethodPost, "/assistants", strings.NewReader(string(assistant))) |
|
request.Header.Set(fiber.HeaderContentType, "application/json") |
|
request.Header.Set("OpenAi-Beta", "assistants=v1") |
|
|
|
resp, err := app.Test(request) |
|
if err != nil { |
|
return Assistant{}, resp, err |
|
} |
|
|
|
bodyString, err := io.ReadAll(resp.Body) |
|
if err != nil { |
|
return Assistant{}, resp, err |
|
} |
|
|
|
var resultAssistant Assistant |
|
err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant) |
|
return resultAssistant, resp, err |
|
} |
|
|
|
func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() { |
|
return func() { |
|
for _, assistant := range ids { |
|
target := fmt.Sprintf("/assistants/%s", assistant) |
|
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil) |
|
_, err := app.Test(deleteReq) |
|
if err != nil { |
|
t.Fatalf("Failed to delete assistant %s: %v", assistant, err) |
|
} |
|
} |
|
} |
|
} |
|
|
|
func cleanupAssistantFile(t *testing.T, app *fiber.App, fileId, assistantId string) func() { |
|
return func() { |
|
target := fmt.Sprintf("/assistants/%s/files/%s", assistantId, fileId) |
|
request := httptest.NewRequest(http.MethodDelete, target, nil) |
|
request.Header.Set(fiber.HeaderContentType, "application/json") |
|
request.Header.Set("OpenAi-Beta", "assistants=v1") |
|
|
|
resp, err := app.Test(request) |
|
assert.NoError(t, err) |
|
|
|
var dafr schema.DeleteAssistantFileResponse |
|
err = json.NewDecoder(resp.Body).Decode(&dafr) |
|
assert.NoError(t, err) |
|
assert.True(t, dafr.Deleted) |
|
} |
|
} |
|
|