// Source revision: commit fb2cf4c by dvc890 ("Update api/chatgpt/api.go").

package chatgpt
import (
"bytes"
"encoding/json"
"fmt"
"io"
"strings"
"github.com/PuerkitoBio/goquery"
"github.com/gin-gonic/gin"
"github.com/linweiyuan/go-chatgpt-api/api"
"github.com/linweiyuan/go-chatgpt-api/util/logger"
http "github.com/bogdanfinn/fhttp"
)
// GetConversations proxies the upstream conversation list. The optional
// "offset" and "limit" query parameters default to "0" and "20" respectively
// and are appended to the upstream URL as given.
//
//goland:noinspection GoUnhandledErrorResult
func GetConversations(c *gin.Context) {
	pageOffset := c.DefaultQuery("offset", "0")
	pageLimit := c.DefaultQuery("limit", "20")
	endpoint := apiPrefix + "/conversations?offset=" + pageOffset + "&limit=" + pageLimit
	handleGet(c, endpoint, getConversationsErrorMessage)
}
// CreateConversation forwards a new chat message to the upstream
// "/conversation" endpoint and streams the answer back to the client through
// api.HandleConversationResponse. When that handler reports that more output
// is pending, the follow-up requests are issued via ContinueConversation.
//
// An empty conversation_id is normalised to nil. Requests for gpt4Model first
// obtain an Arkose token (see fetchArkoseTokenForGPT4). A request with an
// empty message list is rejected with 400 instead of panicking on Messages[0].
//
//goland:noinspection GoUnhandledErrorResult
func CreateConversation(c *gin.Context) {
	var request CreateConversationRequest
	if err := c.BindJSON(&request); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage))
		return
	}
	if request.ConversationID == nil || *request.ConversationID == "" {
		request.ConversationID = nil
	}
	if len(request.Messages) == 0 {
		// Messages[0] is accessed below; an empty list would panic.
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage))
		return
	}
	if request.Messages[0].Author.Role == "" {
		request.Messages[0].Author.Role = defaultRole
	}
	if request.Model == gpt4Model {
		token, err := fetchArkoseTokenForGPT4()
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
			return
		}
		request.ArkoseToken = token
	}

	shouldContinue, parentMessageID, part, ok := streamConversationRequest(c, request, "ConversationRequest", "")
	if !ok || !shouldContinue {
		return
	}
	if request.ConversationID == nil {
		// NOTE(review): for a brand-new conversation ConversationID was set to
		// nil above, and api.HandleConversationResponse (as called here) does
		// not return a conversation ID, so no continuation can be addressed.
		// Dereferencing the nil pointer used to panic mid-stream.
		logger.Info("ConversationRequest: continuation requested for a new conversation without an ID; skipping")
		return
	}
	ContinueConversation(c, *request.ConversationID, parentMessageID, request.Model, part)
}

// ContinueConversation asks the upstream "/conversation" endpoint to continue
// the answer in conversationID after parentMessageID and streams each
// continuation to the client. Before each response is handed to
// api.HandleConversationResponse, oldpart is stored under the "oldpart"
// context key; the part value that handler returns becomes oldpart for the
// next round.
//
// Continuations are issued in a loop rather than recursively, so a long chain
// of continuations does not grow the stack; the loop ends as soon as the
// handler reports that no further output is pending or a request fails.
func ContinueConversation(c *gin.Context, conversationID string, parentMessageID string, model string, oldpart string) {
	for {
		var request ContinueConversationRequest
		request.ConversationID = &conversationID
		request.ParentMessageID = parentMessageID
		request.Model = model
		request.Action = "continue"
		if request.Model == gpt4Model {
			token, err := fetchArkoseTokenForGPT4()
			if err != nil {
				c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
				return
			}
			request.ArkoseToken = token
		}

		shouldContinue, nextParentMessageID, part, ok := streamConversationRequest(c, request, "ContinueConversationRequest", oldpart)
		if !ok || !shouldContinue {
			return
		}
		parentMessageID, oldpart = nextParentMessageID, part
	}
}

// fetchArkoseTokenForGPT4 POSTs "public_key=<gpt4PublicKey>" (Content-Type
// api.ContentType) to gpt4TokenUrl and returns the "token" field of the JSON
// reply. Transport errors are returned; a reply that cannot be decoded yields
// an empty token, preserving the previous best-effort behaviour.
//
//goland:noinspection GoUnhandledErrorResult
func fetchArkoseTokenForGPT4() (string, error) {
	formParams := fmt.Sprintf("public_key=%s", gpt4PublicKey)
	req, err := http.NewRequest(http.MethodPost, gpt4TokenUrl, strings.NewReader(formParams))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", api.ContentType)
	resp, err := api.Client.Do(req)
	if err != nil {
		return "", err
	}
	// The body was previously never closed, leaking the underlying connection.
	defer resp.Body.Close()

	responseMap := make(map[string]string)
	_ = json.NewDecoder(resp.Body).Decode(&responseMap) // best effort: missing token => ""
	return responseMap["token"], nil
}

// streamConversationRequest marshals payload, logs it as "<logLabel>: <json>",
// POSTs it to the upstream "/conversation" endpoint with an event-stream
// Accept header, stores oldPart under the "oldpart" context key and hands the
// response to api.HandleConversationResponse.
//
// It returns that handler's three results plus ok, which is false when the
// request could not be completed and an error response has already been
// written to the client. The upstream body is always closed before returning,
// i.e. before any continuation request is issued.
//
//goland:noinspection GoUnhandledErrorResult
func streamConversationRequest(c *gin.Context, payload interface{}, logLabel string, oldPart string) (shouldContinue bool, parentMessageID string, part string, ok bool) {
	jsonBytes, err := json.Marshal(payload)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return false, "", "", false
	}
	logger.Info(fmt.Sprintf("%s: %s", logLabel, jsonBytes))

	req, err := http.NewRequest(http.MethodPost, apiPrefix+"/conversation", bytes.NewBuffer(jsonBytes))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return false, "", "", false
	}
	req.Header.Set("User-Agent", api.UserAgent)
	req.Header.Set("Authorization", api.GetAccessToken(c.GetHeader(api.AuthorizationHeader)))
	req.Header.Set("Accept", "text/event-stream")

	resp, err := api.Client.Do(req)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return false, "", "", false
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Relay the upstream error body as-is; an undecodable body becomes {}.
		responseMap := make(map[string]interface{})
		_ = json.NewDecoder(resp.Body).Decode(&responseMap)
		c.AbortWithStatusJSON(resp.StatusCode, responseMap)
		return false, "", "", false
	}

	c.Set("oldpart", oldPart)
	shouldContinue, parentMessageID, part = api.HandleConversationResponse(c, resp)
	return shouldContinue, parentMessageID, part, true
}
// GenerateTitle re-encodes the client's GenerateTitleRequest and posts it to
// the upstream "/conversation/gen_title/<id>" endpoint, where <id> is the
// ":id" route parameter. A body that fails to bind yields 400.
//
//goland:noinspection GoUnhandledErrorResult
func GenerateTitle(c *gin.Context) {
	var titleRequest GenerateTitleRequest
	if bindErr := c.BindJSON(&titleRequest); bindErr != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage))
		return
	}
	payload, _ := json.Marshal(titleRequest)
	target := apiPrefix + "/conversation/gen_title/" + c.Param("id")
	handlePost(c, target, string(payload), generateTitleErrorMessage)
}
// GetConversation proxies the upstream "/conversation/<id>" endpoint, where
// <id> is the ":id" route parameter.
//
//goland:noinspection GoUnhandledErrorResult
func GetConversation(c *gin.Context) {
	conversationURL := apiPrefix + "/conversation/" + c.Param("id")
	handleGet(c, conversationURL, getContentErrorMessage)
}
// UpdateConversation forwards a PatchConversationRequest for the ":id"
// conversation to the upstream "/conversation/<id>" endpoint. A body that
// fails to bind yields 400.
//
//goland:noinspection GoUnhandledErrorResult
func UpdateConversation(c *gin.Context) {
	var patch PatchConversationRequest
	if bindErr := c.BindJSON(&patch); bindErr != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage))
		return
	}
	// IsVisible defaults to false, which hides (deletes) the conversation, so a
	// rename (Title present) must force it to stay visible.
	patch.IsVisible = patch.IsVisible || patch.Title != nil
	payload, _ := json.Marshal(patch)
	handlePatch(c, apiPrefix+"/conversation/"+c.Param("id"), string(payload), updateConversationErrorMessage)
}
// FeedbackMessage re-encodes the client's FeedbackMessageRequest and posts it
// to the upstream "/conversation/message_feedback" endpoint. A body that
// fails to bind yields 400.
//
//goland:noinspection GoUnhandledErrorResult
func FeedbackMessage(c *gin.Context) {
	var feedback FeedbackMessageRequest
	if bindErr := c.BindJSON(&feedback); bindErr != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage))
		return
	}
	payload, _ := json.Marshal(feedback)
	target := apiPrefix + "/conversation/message_feedback"
	handlePost(c, target, string(payload), feedbackMessageErrorMessage)
}
// ClearConversations hides all conversations by PATCHing the upstream
// "/conversations" endpoint with a PatchConversationRequest whose IsVisible
// is false.
//
//goland:noinspection GoUnhandledErrorResult
func ClearConversations(c *gin.Context) {
	var hideAll PatchConversationRequest // zero value: IsVisible == false
	payload, _ := json.Marshal(hideAll)
	handlePatch(c, apiPrefix+"/conversations", string(payload), clearConversationsErrorMessage)
}
// GetModels proxies the upstream "/models" endpoint.
//
//goland:noinspection GoUnhandledErrorResult
func GetModels(c *gin.Context) {
	modelsURL := apiPrefix + "/models"
	handleGet(c, modelsURL, getModelsErrorMessage)
}
// GetAccountCheck proxies the upstream "/accounts/check" endpoint.
func GetAccountCheck(c *gin.Context) {
	accountCheckURL := apiPrefix + "/accounts/check"
	handleGet(c, accountCheckURL, getAccountCheckErrorMessage)
}
// Login runs the username/password login flow upstream with a fresh HTTP
// client and, on success, writes the access token to the response body as
// plain text.
//
// The steps are: fetch a CSRF token from csrfUrl, resolve the authorized URL,
// obtain the login state, check the username, check the password, and finally
// request the access token. A failing step aborts the request with the status
// code and error message that step reports.
//
//goland:noinspection GoUnhandledErrorResult
func Login(c *gin.Context) {
	var loginInfo api.LoginInfo
	if err := c.ShouldBindJSON(&loginInfo); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(api.ParseUserInfoErrorMessage))
		return
	}

	userLogin := UserLogin{
		client: api.NewHttpClient(),
	}

	// get csrf token
	req, err := http.NewRequest(http.MethodGet, csrfUrl, nil)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return
	}
	req.Header.Set("User-Agent", api.UserAgent)
	resp, err := userLogin.client.Do(req)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// A 403 page may explain the refusal in a ".message" element; surface it
		// when present. The parse error must be checked: on failure goquery
		// returns a nil document, and calling Find on it would panic.
		if resp.StatusCode == http.StatusForbidden {
			if doc, docErr := goquery.NewDocumentFromReader(resp.Body); docErr == nil {
				if alert := strings.TrimSpace(doc.Find(".message").Text()); alert != "" {
					c.AbortWithStatusJSON(resp.StatusCode, api.ReturnMessage(alert))
					return
				}
			}
		}
		c.AbortWithStatusJSON(resp.StatusCode, api.ReturnMessage(getCsrfTokenErrorMessage))
		return
	}

	// get authorized url
	responseMap := make(map[string]string)
	if err := json.NewDecoder(resp.Body).Decode(&responseMap); err != nil {
		// Without a decodable body there is no csrfToken to continue with.
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(getCsrfTokenErrorMessage))
		return
	}
	authorizedUrl, statusCode, err := userLogin.GetAuthorizedUrl(responseMap["csrfToken"])
	if err != nil {
		c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error()))
		return
	}

	// get state
	state, statusCode, err := userLogin.GetState(authorizedUrl)
	if err != nil {
		c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error()))
		return
	}

	// check username
	statusCode, err = userLogin.CheckUsername(state, loginInfo.Username)
	if err != nil {
		c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error()))
		return
	}

	// check password
	_, statusCode, err = userLogin.CheckPassword(state, loginInfo.Username, loginInfo.Password)
	if err != nil {
		c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error()))
		return
	}

	// get access token
	accessToken, statusCode, err := userLogin.GetAccessToken("")
	if err != nil {
		c.AbortWithStatusJSON(statusCode, api.ReturnMessage(err.Error()))
		return
	}

	c.Writer.WriteString(accessToken)
}
// Fallback proxies a request that no dedicated route matched to the same path
// under apiPrefix, re-encoding the query string. GET is forwarded without a
// body; POST and PATCH forward the raw request body. Any other method is
// answered with 405 and fallbackMethodNotAllowedMessage.
func Fallback(c *gin.Context) {
	method := c.Request.Method

	// EscapedPath keeps the client's percent-encoding (%2F, %25, spaces, ...)
	// intact. Concatenating the decoded Path could change the path's meaning
	// or yield a URL that http.NewRequest rejects.
	url := apiPrefix + c.Request.URL.EscapedPath()
	if queryParams := c.Request.URL.Query().Encode(); queryParams != "" {
		url += "?" + queryParams
	}

	var requestBody string
	if method == http.MethodPost || method == http.MethodPatch {
		body, err := io.ReadAll(c.Request.Body)
		if err != nil {
			// Forwarding a truncated body upstream would be silently wrong.
			c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(err.Error()))
			return
		}
		requestBody = string(body)
	}

	c.Status(http.StatusOK)
	switch method {
	case http.MethodGet:
		handleGet(c, url, fallbackErrorMessage)
	case http.MethodPost:
		handlePost(c, url, requestBody, fallbackErrorMessage)
	case http.MethodPatch:
		handlePatch(c, url, requestBody, fallbackErrorMessage)
	default:
		c.JSON(http.StatusMethodNotAllowed, gin.H{"message": fallbackMethodNotAllowedMessage})
	}
}
// handleGet sends a GET for url upstream, authenticated with the access token
// derived from the caller's Authorization header. On 200 OK it streams the
// upstream body back to the client.
//
// If url cannot be turned into a request (e.g. a malformed percent escape
// coming from a client-supplied path segment), the client gets 400 with
// errorMessage instead of the handler panicking on a nil request. A transport
// failure yields 500 with the error text. Any other upstream status is
// relayed together with errorMessage.
//
//goland:noinspection GoUnhandledErrorResult
func handleGet(c *gin.Context, url string, errorMessage string) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(errorMessage))
		return
	}
	req.Header.Set("User-Agent", api.UserAgent)
	req.Header.Set("Authorization", api.GetAccessToken(c.GetHeader(api.AuthorizationHeader)))
	resp, err := api.Client.Do(req)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		c.AbortWithStatusJSON(resp.StatusCode, api.ReturnMessage(errorMessage))
		return
	}
	io.Copy(c.Writer, resp.Body)
}
// handlePost builds a POST to url carrying requestBody and delegates sending
// and response relaying to handlePostOrPatch. If url cannot be turned into a
// request, the client gets 400 with errorMessage, rather than a nil request
// being passed on and dereferenced.
//
//goland:noinspection GoUnhandledErrorResult
func handlePost(c *gin.Context, url string, requestBody string, errorMessage string) {
	req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(errorMessage))
		return
	}
	handlePostOrPatch(c, req, errorMessage)
}
// handlePatch builds a PATCH to url carrying requestBody and delegates sending
// and response relaying to handlePostOrPatch. If url cannot be turned into a
// request, the client gets 400 with errorMessage, rather than a nil request
// being passed on and dereferenced.
//
//goland:noinspection GoUnhandledErrorResult
func handlePatch(c *gin.Context, url string, requestBody string, errorMessage string) {
	req, err := http.NewRequest(http.MethodPatch, url, strings.NewReader(requestBody))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(errorMessage))
		return
	}
	handlePostOrPatch(c, req, errorMessage)
}
// handlePostOrPatch adds the User-Agent and the caller-derived Authorization
// header to req and sends it upstream. On 200 OK the upstream body is streamed
// back to the client. A transport failure yields 500 with the error text, and
// any other upstream status is relayed with errorMessage.
//
//goland:noinspection GoUnhandledErrorResult
func handlePostOrPatch(c *gin.Context, req *http.Request, errorMessage string) {
	headers := req.Header
	headers.Set("User-Agent", api.UserAgent)
	headers.Set("Authorization", api.GetAccessToken(c.GetHeader(api.AuthorizationHeader)))

	upstream, doErr := api.Client.Do(req)
	if doErr != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(doErr.Error()))
		return
	}
	defer upstream.Body.Close()

	switch upstream.StatusCode {
	case http.StatusOK:
		io.Copy(c.Writer, upstream.Body)
	default:
		c.AbortWithStatusJSON(upstream.StatusCode, api.ReturnMessage(errorMessage))
	}
}