|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
package main |
|
|
|
import ( |
|
"bytes" |
|
"encoding/json" |
|
"flag" |
|
"fmt" |
|
"io" |
|
"net/http" |
|
"os" |
|
"strings" |
|
"time" |
|
|
|
"github.com/gin-contrib/cors" |
|
"github.com/gin-gonic/gin" |
|
) |
|
|
|
func parseAuthorizationHeader(c *gin.Context) (string, error) { |
|
|
|
apiKey := os.Getenv("KEY") |
|
if apiKey != "" { |
|
return apiKey, nil |
|
} |
|
|
|
authorizationHeader := c.GetHeader("Authorization") |
|
if !strings.HasPrefix(authorizationHeader, "Bearer ") { |
|
return "", fmt.Errorf("invalid Authorization header format") |
|
} |
|
return strings.TrimPrefix(authorizationHeader, "Bearer "), nil |
|
} |
|
|
|
func cohereRequest(c *gin.Context, openAIReq OpenAIRequest) { |
|
apiKey, err := parseAuthorizationHeader(c) |
|
if err != nil { |
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
|
|
cohereReq := CohereRequest{ |
|
Model: openAIReq.Model, |
|
ChatHistory: []ChatMessage{}, |
|
Message: "", |
|
Stream: openAIReq.Stream, |
|
MaxTokens: openAIReq.MaxTokens, |
|
} |
|
|
|
for _, msg := range openAIReq.Messages { |
|
if msg.Role == "user" { |
|
cohereReq.Message = msg.Content |
|
} else { |
|
var role string |
|
if msg.Role == "assistant" { |
|
role = "CHATBOT" |
|
} else if msg.Role == "system" { |
|
role = "SYSTEM" |
|
} else { |
|
role = "USER" |
|
} |
|
cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatMessage{ |
|
Role: role, |
|
Message: msg.Content, |
|
}) |
|
} |
|
} |
|
|
|
reqBody, _ := json.Marshal(cohereReq) |
|
req, err := http.NewRequest("POST", "https://api.cohere.ai/v1/chat", bytes.NewBuffer(reqBody)) |
|
if err != nil { |
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
|
|
req.Header.Set("Accept", "application/json") |
|
req.Header.Set("Content-Type", "application/json") |
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) |
|
|
|
client := &http.Client{} |
|
resp, err := client.Do(req) |
|
if err != nil { |
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
defer resp.Body.Close() |
|
|
|
c.Header("Content-Type", "text/event-stream") |
|
c.Header("Cache-Control", "no-cache") |
|
c.Header("Connection", "keep-alive") |
|
|
|
reader := resp.Body |
|
buffer := make([]byte, 1048576) |
|
|
|
isFirstChunk := true |
|
|
|
for { |
|
n, err := reader.Read(buffer) |
|
if err != nil { |
|
if err == io.EOF { |
|
break |
|
} |
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
|
|
var cohereResp CohereResponse |
|
decoder := json.NewDecoder(bytes.NewReader(buffer[:n])) |
|
decoder.UseNumber() |
|
err = decoder.Decode(&cohereResp) |
|
if err != nil { |
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
if cohereResp.IsFinished { |
|
var resp OpenAIResponse |
|
resp.ID = "chatcmpl-123" |
|
resp.Object = "chat.completion.chunk" |
|
resp.Created = time.Now().Unix() |
|
resp.Model = openAIReq.Model |
|
resp.Choices = []OpenAIChoice{ |
|
{ |
|
Index: 0, |
|
Delta: OpenAIDelta{}, |
|
Logprobs: nil, |
|
FinishReason: stringPtr("stop"), |
|
}, |
|
} |
|
|
|
respBytes, _ := json.Marshal(resp) |
|
c.Data(http.StatusOK, "application/json", []byte("data: ")) |
|
c.Data(http.StatusOK, "application/json", respBytes) |
|
c.Data(http.StatusOK, "application/json", []byte("\n\n")) |
|
|
|
c.Data(http.StatusOK, "application/json", []byte("data: [DONE]\n\n")) |
|
break |
|
} else { |
|
var resp OpenAIResponse |
|
resp.ID = "chatcmpl-123" |
|
resp.Object = "chat.completion.chunk" |
|
resp.Created = time.Now().Unix() |
|
resp.Model = openAIReq.Model |
|
|
|
if !isFirstChunk { |
|
resp.Choices = []OpenAIChoice{ |
|
{ |
|
Index: 0, |
|
Delta: OpenAIDelta{Content: cohereResp.Text}, |
|
Logprobs: nil, |
|
FinishReason: nil, |
|
}, |
|
} |
|
} else { |
|
resp.Choices = []OpenAIChoice{ |
|
{ |
|
Index: 0, |
|
Delta: OpenAIDelta{}, |
|
Logprobs: nil, |
|
FinishReason: nil, |
|
}, |
|
} |
|
isFirstChunk = false |
|
} |
|
|
|
respBytes, _ := json.Marshal(resp) |
|
c.Data(http.StatusOK, "application/json", []byte("data: ")) |
|
c.Data(http.StatusOK, "application/json", respBytes) |
|
c.Data(http.StatusOK, "application/json", []byte("\n\n")) |
|
} |
|
} |
|
} |
|
|
|
func cohereNonStreamRequest(c *gin.Context, openAIReq OpenAIRequest) { |
|
apiKey, err := parseAuthorizationHeader(c) |
|
if err != nil { |
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
|
|
cohereReq := CohereRequest{ |
|
Model: openAIReq.Model, |
|
ChatHistory: []ChatMessage{}, |
|
Message: "", |
|
Stream: openAIReq.Stream, |
|
MaxTokens: openAIReq.MaxTokens, |
|
} |
|
|
|
for _, msg := range openAIReq.Messages { |
|
if msg.Role == "user" { |
|
cohereReq.Message = msg.Content |
|
} else { |
|
var role string |
|
if msg.Role == "assistant" { |
|
role = "CHATBOT" |
|
} else if msg.Role == "system" { |
|
role = "SYSTEM" |
|
} else { |
|
role = "USER" |
|
} |
|
cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatMessage{ |
|
Role: role, |
|
Message: msg.Content, |
|
}) |
|
} |
|
} |
|
|
|
reqBody, _ := json.Marshal(cohereReq) |
|
req, err := http.NewRequest("POST", "https://api.cohere.ai/v1/chat", bytes.NewBuffer(reqBody)) |
|
if err != nil { |
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
|
|
req.Header.Set("Content-Type", "application/json") |
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) |
|
|
|
client := &http.Client{} |
|
resp, err := client.Do(req) |
|
if err != nil { |
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
defer resp.Body.Close() |
|
|
|
c.Header("Content-Type", "application/json") |
|
c.Header("Cache-Control", "no-cache") |
|
c.Header("Connection", "keep-alive") |
|
reader := resp.Body |
|
buffer := make([]byte, 1048576) |
|
n, _ := reader.Read(buffer) |
|
var cohereResp CohereResponse |
|
err = json.Unmarshal(buffer[:n], &cohereResp) |
|
if err != nil { |
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
|
|
var aiResp OpenAINonStreamResponse |
|
aiResp.ID = "chatcmpl-123" |
|
aiResp.Object = "chat.completion" |
|
aiResp.Created = time.Now().Unix() |
|
aiResp.Model = openAIReq.Model |
|
aiResp.Choices = []OpenAINonStreamChoice{ |
|
{ |
|
Index: 0, |
|
Message: OpenAIDelta{Content: cohereResp.Text, Role: "assistant"}, |
|
FinishReason: stringPtr("stop"), |
|
}, |
|
} |
|
|
|
c.JSON(http.StatusOK, aiResp) |
|
} |
|
|
|
func handler(c *gin.Context) { |
|
var openAIReq OpenAIRequest |
|
|
|
if err := c.BindJSON(&openAIReq); err != nil { |
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |
|
return |
|
} |
|
|
|
allowModels := []string{"command-r-plus", "command-r", "command", "command-light", "command-light-nightly", "command-nightly"} |
|
|
|
if !isInSlice(openAIReq.Model, allowModels) { |
|
openAIReq.Model = "command-r-plus" |
|
} |
|
|
|
|
|
switch openAIReq.Model { |
|
case "command-light": |
|
openAIReq.MaxTokens = 4000 |
|
case "command": |
|
openAIReq.MaxTokens = 4000 |
|
case "command-light-nightly": |
|
openAIReq.MaxTokens = 4000 |
|
case "command-nightly": |
|
openAIReq.MaxTokens = 4000 |
|
case "command-r": |
|
openAIReq.MaxTokens = 4000 |
|
case "command-r-plus": |
|
openAIReq.MaxTokens = 4000 |
|
default: |
|
openAIReq.MaxTokens = 4096 |
|
} |
|
|
|
if openAIReq.Stream { |
|
cohereRequest(c, openAIReq) |
|
} else { |
|
cohereNonStreamRequest(c, openAIReq) |
|
} |
|
} |
|
|
|
func main() { |
|
|
|
var port, key string |
|
|
|
flag.StringVar(&port, "p", "", "Port to run the server on") |
|
flag.StringVar(&key, "k", "", "API key for Cohere") |
|
flag.Parse() |
|
|
|
if key != "" { |
|
os.Setenv("KEY", key) |
|
} |
|
|
|
if port == "" { |
|
port = os.Getenv("PORT") |
|
} |
|
|
|
if port == "" { |
|
port = "6600" |
|
} |
|
|
|
fmt.Println("Running on port " + port + "\nHave fun with Cohere2OpenAI!") |
|
|
|
gin.SetMode(gin.ReleaseMode) |
|
r := gin.Default() |
|
r.Use(cors.Default()) |
|
r.GET("/", func(c *gin.Context) { |
|
c.JSON(http.StatusOK, gin.H{ |
|
"message": "Thankyou", |
|
}) |
|
}) |
|
r.POST("/v1/chat/completions", handler) |
|
r.GET("/v1/models", func(c *gin.Context) { |
|
c.JSON(http.StatusOK, gin.H{ |
|
"object": "list", |
|
"data": []gin.H{ |
|
{ |
|
"id": "command-r", |
|
"object": "model", |
|
"created": 1692901427, |
|
"owned_by": "system", |
|
}, |
|
{ |
|
"id": "command-r-plus", |
|
"object": "model", |
|
"created": 1692901427, |
|
"owned_by": "system", |
|
}, |
|
{ |
|
"id": "command-light", |
|
"object": "model", |
|
"created": 1692901427, |
|
"owned_by": "system", |
|
}, |
|
{ |
|
"id": "command-light-nightly", |
|
"object": "model", |
|
"created": 1692901427, |
|
"owned_by": "system", |
|
}, |
|
{ |
|
"id": "command", |
|
"object": "model", |
|
"created": 1692901427, |
|
"owned_by": "system", |
|
}, |
|
{ |
|
"id": "command-nightly", |
|
"object": "model", |
|
"created": 1692901427, |
|
"owned_by": "system", |
|
}, |
|
}, |
|
}) |
|
}) |
|
|
|
r.NoRoute(func(c *gin.Context) { |
|
c.JSON(http.StatusNotFound, gin.H{ |
|
"code": http.StatusNotFound, |
|
"message": "Path not found", |
|
}) |
|
}) |
|
|
|
r.Run(":" + port) |
|
} |
|
|