# coding:utf-8
import argparse
import requests
import json
import os
from cerebras import CerebrasUnofficial
from flask import Flask, request, Response, stream_with_context, jsonify
import sys

# -- Start of Config --

# Replace with your cerebras.ai session token, found in the `authjs.session-token` cookie.
# Alternatively, set the AUTHJS_SESSION_TOKEN environment variable to the cookie value
# (`set AUTHJS_SESSION_TOKEN=...` on Windows, `export AUTHJS_SESSION_TOKEN=...` on Unix shells).
# This token is valid for one month.
authjs_session_token = '12345678-abcd-abcd-abcd-12345678abcd'

# Replace with any string you like, such as `my-api-key`.
# Alternatively, set the SERVER_API_KEY environment variable.
# Keep it private: clients must send it, and you need it to renew the session token later via /renew.
server_api_key = 'my-api-key'

# -- End of Config --

sys.tracebacklimit = 0

authjs_session_token = os.environ.get('AUTHJS_SESSION_TOKEN', authjs_session_token)
server_api_key = os.environ.get('SERVER_API_KEY', server_api_key)
print(f'Using the cookie: authjs.session-token={authjs_session_token}')
print(f'Your api key: {server_api_key}')

cerebras_ai = CerebrasUnofficial(authjs_session_token)

app = Flask(__name__)
app.json.sort_keys = False
parser = argparse.ArgumentParser(description='Cerebras.AI API')
parser.add_argument('--host', type=str, help='IP address to bind to (default: 0.0.0.0)', default='0.0.0.0')
parser.add_argument('--port', type=int, help='Port to listen on (default: 7860)', default=7860)
args = parser.parse_args()

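class Provider:
    """Resolves the upstream API URL and API key for a proxied request."""
    key = ''
    max_tokens = None  # currently unused
    api_url = ''

    def __init__(self, request_key, model):
        self.request_key = request_key
        self.model = model
        self.init_request_info()

    def init_request_info(self):
        # Only requests that present the server API key are routed to Cerebras;
        # otherwise `key` and `api_url` keep their empty defaults.
        if self.request_key == server_api_key:
            self.api_url = cerebras_ai.api_url
            self.key = cerebras_ai.get_api_key()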

@app.route('/api', methods=['GET', 'POST'])
@app.route('/', methods=['GET', 'POST'])
def index():
    return f'''
        Renew or change the token by visiting:<br>
        {request.host_url}renew?key={{your server api key}}&token={{your Cerebras authjs_session_token}}<br>
        <br>
        Your interface:<br>
        {request.host_url}v1/chat/completions OR<br>
        {request.host_url}api/v1/chat/completions<br>
        <br>
        For more information, visit:<br>
        https://github.com/tastypear/CerebrasUnofficial
    '''

@app.route('/api/renew', methods=['GET', 'POST'])
@app.route('/renew', methods=['GET', 'POST'])
def renew_token():
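    global cerebras_ai
    # Reject bad keys with 401 instead of raising, which would surface as a 500 error.
    if server_api_key != request.args.get('key', ''):
        return jsonify({'error': 'invalid api key'}), 401
    request_token = request.args.get('token', '')
    if not request_token:
        return jsonify({'error': 'missing token'}), 400
    cerebras_ai = CerebrasUnofficial(request_token)
    return f'new authjs.session-token: {request_token}'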
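
# Hypothetical helper, not used by this server: a minimal sketch of building the
# renew URL advertised on the index page. urllib.parse.urlencode escapes the key
# and token, so reserved characters such as '&' or spaces survive the query string.
from urllib.parse import urlencode


def build_renew_url(host_url, key, token):
    """Build '<host_url>renew?key=...&token=...'; host_url should end with '/'."""
    # host_url is assumed to look like Flask's request.host_url, e.g. 'http://host:7860/'.
    return f"{host_url}renew?{urlencode({'key': key, 'token': token})}"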

@app.route('/api/v1/models', methods=['GET', 'POST'])
@app.route('/v1/models', methods=['GET', 'POST'])
def model_list():
    model_list = {
        'object': 'list',
        'data': [{
            'id': 'llama3.1-8b',
            'object': 'model',
            'created': 1721692800,
            'owned_by': 'Meta'
        }, {
            'id': 'llama-3.3-70b',
            'object': 'model',
            'created': 1733443200,
            'owned_by': 'Meta'
        }, {
            'id': 'deepseek-r1-distill-llama-70b',
            'object': 'model',
            'created': 1733443200,
            'owned_by': 'deepseek'
        }]
    }
    return jsonify(model_list)


@app.route('/api/v1/chat/completions', methods=['POST'])
@app.route('/v1/chat/completions', methods=['POST'])
def proxy():
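    # Expect an OpenAI-style "Authorization: Bearer <key>" header; a missing or
    # malformed header is rejected with 401 instead of crashing the handler.
    request_key = request.headers.get('Authorization', '').partition(' ')[2].strip()
    if server_api_key != request_key:
        return jsonify({'error': 'invalid api key'}), 401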

    headers = dict(request.headers)
    headers.pop('Host', None)
    headers.pop('Content-Length', None)

    headers['X-Use-Cache'] = 'false'
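    payload = request.get_json()
    model = payload['model']
    provider = Provider(request_key, model)
    headers['Authorization'] = f'Bearer {provider.key}'
    chat_api = f'{provider.api_url}/v1/chat/completions'

    # Open the upstream request eagerly so its status code and Content-Type can be
    # relayed: non-streaming requests ("stream": false) return JSON, not SSE.
    upstream = requests.post(chat_api, json=payload, headers=headers, stream=True)

    def generate():
        # Relay raw bytes. Decoding fixed-size chunks as UTF-8 can split a
        # multi-byte character across chunks and raise UnicodeDecodeError.
        for chunk in upstream.iter_content(chunk_size=1024):
            if chunk:
                yield chunk

    response = Response(
        stream_with_context(generate()),
        status=upstream.status_code,
        content_type=upstream.headers.get('Content-Type', 'text/event-stream'),
    )
    # Release the upstream connection once the client response is closed.
    response.call_on_close(upstream.close)
    return response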
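
# Client-side sketch (hypothetical helper, not used by this server): how a caller
# might consume the stream relayed by proxy(), assuming the upstream emits
# OpenAI-style server-sent events, i.e. `data: {...}` lines ending with
# `data: [DONE]`. Feed it the lines from requests' Response.iter_lines() on a
# request sent with "stream": true.
import json


def iter_stream_deltas(lines):
    """Yield the text of each streamed chunk's delta, stopping at [DONE]."""
    for raw in lines:
        line = raw.decode('utf-8') if isinstance(raw, bytes) else raw
        if not line.startswith('data:'):
            continue  # blank keep-alive lines and other SSE fields carry no payload
        data = line[len('data:'):].strip()
        if data == '[DONE]':
            break
        chunk = json.loads(data)
        for choice in chunk.get('choices', []):
            content = (choice.get('delta') or {}).get('content')
            if content:
                yield content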

if __name__ == '__main__':
    app.run(host=args.host, port=args.port, debug=True)