import os
import json
from typing import List, Tuple
from collections import OrderedDict

import gradio as gr
from openai import OpenAI


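# CONFIG is a JSON object mapping model names to endpoint credentials and
# persona prompts. Illustrative shape inferred from the lookups below
# (the env var names here are placeholders, not from the source):
# {
#   "stable": {
#     "api_url": "STABLE_API_URL",   # name of the env var holding the base URL
#     "api_key": "STABLE_API_KEY",   # name of the env var holding the API key
#     "personas": {"default": "You are ...", "other": null}
#   }
# }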
config = json.loads(os.environ['CONFIG'])


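# Expose every configured model as a choice, and collect the union of persona
# names across models, deduplicated while preserving first-seen order.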
model_names = list(config.keys())
personas = list(OrderedDict.fromkeys(
    persona for name in config for persona in config[name]["personas"]
))


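# One OpenAI-compatible client per model. The config stores the *names* of
# the environment variables that hold each endpoint's base URL and API key.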
clients = {}
for name in config:
    client = OpenAI(
        base_url=f"{os.environ[config[name]['api_url']]}/v1",
        api_key=os.environ[config[name]['api_key']],
    )
    clients[name] = client


def respond(
    message,
    history: List[Tuple[str, str]],
    persona,
    model,
    conversational,
    max_tokens,
):
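    """Chat callback for gr.ChatInterface.

    Builds an OpenAI-style message list from the selected persona's system
    prompt, the recent chat history, and the new user message, then queries
    the endpoint configured for the selected model.
    """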
    messages = []

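    # Resolve the persona to its system prompt; surface a user-visible error
    # if the selected model does not define this persona.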
    try:
        system_prompt = config[model]["personas"][persona]
    except KeyError:
        supported_personas = list(config[model]["personas"].keys())
        raise gr.Error(
            f"Model '{model}' does not support persona '{persona}'; "
            f"supported personas: {supported_personas}"
        )
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})

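    # In conversational mode, replay only the last two (user, assistant)
    # exchanges to keep the prompt short.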
    if conversational:
        for val in history[-2:]:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

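    # The model name is hardcoded for every endpoint; the "model" radio
    # selects which client (endpoint) to call, not the model string. The
    # extra_body fields are vLLM-style sampling parameters forwarded through
    # the OpenAI-compatible API.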
    completion = clients[model].chat.completions.create(
        model="neongeckocom/NeonLLM",
        messages=messages,
        max_tokens=max_tokens,
        temperature=0,
        extra_body={
            "repetition_penalty": 1.05,
            "use_beam_search": True,
            "best_of": 5,
        },
    )
    response = completion.choices[0].message.content
    return response


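# Wire the callback into a chat UI; the extra controls appear in a "Config"
# accordion and are passed to respond() after (message, history).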
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Radio(choices=personas, value="default", label="persona"),
        gr.Radio(choices=model_names, value="stable", label="model"),
        gr.Checkbox(value=True, label="conversational"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    ],
    additional_inputs_accordion=gr.Accordion(label="Config", open=True),
    title="NeonLLM (v2024-06-17)",
    concurrency_limit=5,
)


if __name__ == "__main__":
    demo.launch()