Spaces:
Running
Running
File size: 4,465 Bytes
83e5367 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# coding:utf-8
import argparse
import requests
import json
import os
from cerebras import CerebrasUnofficial
from flask import Flask, request, Response, stream_with_context, jsonify
import sys
# -- Start of Config --
# Replace with your cerebras.ai session token, found in the `authjs.session-token` cookie.
# Alternatively, set the AUTHJS_SESSION_TOKEN environment variable
# (e.g. `set AUTHJS_SESSION_TOKEN=<token>`); it overrides the value below.
# This token is valid for one month.
authjs_session_token = '12345678-abcd-abcd-abcd-12345678abcd'
# Replace with any string you wish, e.g. `my-api-key`.
# Alternatively, set the SERVER_API_KEY environment variable; it overrides the value below.
# Clients must present this key to use the chat endpoint and to renew the session token.
server_api_key = 'my-api-key'
# -- End of Config --
# Print only the exception type and message on errors, without traceback frames.
sys.tracebacklimit = 0
# Environment variables, when present, take precedence over the hard-coded defaults above.
authjs_session_token = os.environ.get('AUTHJS_SESSION_TOKEN', authjs_session_token)
server_api_key = os.environ.get('SERVER_API_KEY', server_api_key)
# NOTE(review): both secrets are written to stdout in clear text here —
# confirm that is acceptable wherever the process output is visible to others.
print(f'Using the cookie: authjs.session-token={authjs_session_token}')
print(f'Your api key: {server_api_key}')
# Upstream client built from the session token; the renew endpoint rebinds this global.
cerebras_ai = CerebrasUnofficial(authjs_session_token)
app = Flask(__name__)
# Emit JSON keys in insertion order rather than sorted alphabetically.
app.json.sort_keys = False
# NOTE(review): arguments are parsed at import time, so importing this module from
# another entry point will also parse that process's command line — confirm intended.
parser = argparse.ArgumentParser(description='Cerebras.AI API')
parser.add_argument('--host', type=str, help='Set the ip address.(default: 0.0.0.0)', default='0.0.0.0')
parser.add_argument('--port', type=int, help='Set the port.(default: 7860)', default=7860)
args = parser.parse_args()
class Provider:
    """Resolve the upstream endpoint and credential for one incoming request.

    When the caller's key equals the module-level ``server_api_key``, the
    instance copies the API URL and a freshly obtained API key from the
    module-level ``cerebras_ai`` client. Otherwise ``api_url`` and ``key``
    keep their empty class-level defaults.
    """

    # Class-level defaults, used when the request key does not match.
    key = ''
    # Not read or written anywhere else in this class.
    max_tokens = None
    api_url = ''

    def __init__(self, request_key, model):
        """Store the caller's key and requested model, then resolve upstream info."""
        self.request_key = request_key
        self.model = model
        self.init_request_info()

    def init_request_info(self):
        """Populate ``api_url`` and ``key`` from the shared client for a valid key."""
        if self.request_key != server_api_key:
            # Unknown key: leave the empty defaults in place.
            return
        self.api_url = cerebras_ai.api_url
        self.key = cerebras_ai.get_api_key()
@app.route('/api', methods=['GET', 'POST'])
@app.route('/', methods=['GET', 'POST'])
def index():
    """Return a short HTML usage page listing the renew and chat endpoints.

    Every URL shown is prefixed with the host URL of the current request.
    """
    base_url = request.host_url
    # The leading and trailing empty entries reproduce the newline that
    # opens and closes the page text.
    page_lines = [
        '',
        'renew/change token by visiting:<br>',
        f'{base_url}renew?key={{your server api key}}&token={{your Cerebras authjs_session_token}}<br>',
        '<br>',
        'Your interface:<br>',
        f'{base_url}v1/chat/completions OR<br>',
        f'{base_url}api/v1/chat/completions<br>',
        '<br>',
        'For more infomation by visiting:<br>',
        'https://github.com/tastypear/CerebrasUnofficial',
        '',
    ]
    return '\n'.join(page_lines)
@app.route('/api/renew', methods=['GET', 'POST'])
@app.route('/renew', methods=['GET', 'POST'])
def renew_token():
    """Replace the shared Cerebras client with one built from a new session token.

    Query parameters:
        key:   must equal the module-level ``server_api_key``.
        token: the new ``authjs.session-token`` value.

    Returns:
        A confirmation string on success; ``('invalid api key', 401)`` when
        the key does not match; ``('missing token', 400)`` when no token is
        supplied.
    """
    global cerebras_ai
    # Answer with 401 instead of raising a bare Exception, which would
    # surface to the client as an HTTP 500.
    if request.args.get('key', '') != server_api_key:
        return 'invalid api key', 401
    request_token = request.args.get('token', '')
    # Refuse an empty token so a malformed call cannot swap the working
    # client for one constructed from ''.
    if not request_token:
        return 'missing token', 400
    cerebras_ai = CerebrasUnofficial(request_token)
    return f'new authjs.session_token: {request_token}'
@app.route('/api/v1/models', methods=['GET', 'POST'])
@app.route('/v1/models', methods=['GET', 'POST'])
def model_list():
    """Return the fixed list of advertised models as a JSON list object."""
    # (model id, creation timestamp, owner) for each advertised model.
    catalogue = (
        ('llama3.1-8b', 1721692800, 'Meta'),
        ('llama-3.3-70b', 1733443200, 'Meta'),
        ('deepseek-r1-distill-llama-70b', 1733443200, 'deepseek'),
    )
    entries = [
        {
            'id': model_id,
            'object': 'model',
            'created': created_at,
            'owned_by': owner,
        }
        for model_id, created_at, owner in catalogue
    ]
    return jsonify({'object': 'list', 'data': entries})
@app.route('/api/v1/chat/completions', methods=['POST'])
@app.route('/v1/chat/completions', methods=['POST'])
def proxy():
    """Forward a chat-completion request upstream and stream the reply back.

    The caller authenticates with an ``Authorization`` header whose second
    space-separated token equals the module-level ``server_api_key``. The
    incoming headers (minus ``Host`` and ``Content-Length``) and the JSON
    body are forwarded to ``<provider.api_url>/v1/chat/completions`` with the
    upstream API key substituted into ``Authorization``.

    Returns:
        A streaming ``text/event-stream`` response relaying the upstream
        body; a JSON error with status 401 when the key is missing or wrong;
        a JSON error with status 400 when the body has no ``model`` field.
    """
    # A missing or malformed header is an authentication failure, not a
    # server error, so check the shape before indexing into it.
    auth_parts = request.headers.get('Authorization', '').split(' ')
    if len(auth_parts) < 2 or auth_parts[1] != server_api_key:
        return jsonify({'error': {'message': 'invalid api key'}}), 401
    request_key = auth_parts[1]

    # Read the body once, while the request context is certainly active,
    # and hand the parsed payload to the generator below.
    payload = request.get_json()
    if not isinstance(payload, dict) or 'model' not in payload:
        return jsonify({'error': {'message': 'missing "model" in request body'}}), 400

    headers = dict(request.headers)
    # These two describe the incoming request, not the outgoing one.
    headers.pop('Host', None)
    headers.pop('Content-Length', None)
    headers['X-Use-Cache'] = 'false'

    provider = Provider(request_key, payload['model'])
    headers['Authorization'] = f'Bearer {provider.key}'
    chat_api = f'{provider.api_url}/v1/chat/completions'

    def generate():
        """Yield the upstream response body as raw byte chunks."""
        with requests.post(chat_api, json=payload, headers=headers, stream=True) as resp:
            for chunk in resp.iter_content(chunk_size=1024):
                if chunk:
                    # Relay bytes untouched: decoding each 1024-byte chunk on
                    # its own fails when a multi-byte UTF-8 character is
                    # split across a chunk boundary.
                    yield chunk

    return Response(stream_with_context(generate()), content_type='text/event-stream')
if __name__ == '__main__':
    # Debug mode turns on the interactive Werkzeug debugger. The default host
    # is 0.0.0.0 (all interfaces), so debug is off unless explicitly requested
    # with FLASK_DEBUG=1 (or "true"/"yes").
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host=args.host, port=args.port, debug=debug_enabled)