File size: 8,809 Bytes
f0250b1
c9e00de
 
15bbe10
fb39607
f0250b1
2cb6075
 
 
1fd9ae1
15bbe10
 
a4e6a71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15bbe10
a4e6a71
 
 
 
 
 
 
f5f5cd4
a4e6a71
f5f5cd4
 
 
 
 
 
a4e6a71
f5f5cd4
a4e6a71
 
 
f5f5cd4
a4e6a71
 
 
f5f5cd4
 
 
 
 
 
 
 
 
 
 
 
 
15bbe10
a4e6a71
f5f5cd4
a4e6a71
eebf495
 
 
f5f5cd4
 
 
 
 
 
 
 
 
 
 
 
eebf495
 
f5f5cd4
 
eebf495
a4e6a71
 
 
 
fe9a872
187c8cf
 
 
 
 
 
 
fe9a872
a4e6a71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9e00de
 
 
9d92eeb
 
 
 
 
 
fe9a872
 
a4e6a71
 
 
 
 
 
c9e00de
 
76ed6d2
c9e00de
a4e6a71
76ed6d2
 
187c8cf
76ed6d2
187c8cf
a4e6a71
76ed6d2
a4e6a71
 
 
 
 
 
 
 
 
fb39607
f5f5cd4
a4e6a71
 
 
 
 
fb39607
 
 
 
 
 
 
 
 
76ed6d2
 
 
 
a4e6a71
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import streamlit as st
from main import benchmark_model_multithreaded, benchmark_model_sequential
from prompts import questions as predefined_questions
import requests
import pandas as pd

# One string serves as both the browser-tab title and the on-page heading.
_APP_TITLE = "Aidan Bench - Generator"

# Configure the page (browser tab) before rendering any content.
st.set_page_config(page_title=_APP_TITLE)
st.title(_APP_TITLE)

# API key inputs. Confirmed keys live in st.session_state so they survive
# Streamlit's script reruns; the rest of the app only runs once both are set.
st.warning("Please keep your API keys secure and confidential. This app does not store or log your API keys.")

# Start with empty (unconfirmed) keys on the first run of a session.
if "open_router_key" not in st.session_state:
    st.session_state.open_router_key = ""
if "openai_api_key" not in st.session_state:
    st.session_state.openai_api_key = ""

# Pre-fill the inputs with any previously confirmed keys.
open_router_key = st.text_input("Enter your Open Router API Key:", type="password", value=st.session_state.open_router_key)
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password", value=st.session_state.openai_api_key)

if st.button("Confirm API Keys"):
    # Strip surrounding whitespace: a stray space picked up while pasting
    # would otherwise be stored verbatim and sent with every API call, and
    # whitespace-only input would wrongly pass the "both keys entered" check.
    open_router_key = open_router_key.strip()
    openai_api_key = openai_api_key.strip()
    if open_router_key and openai_api_key:
        st.session_state.open_router_key = open_router_key
        st.session_state.openai_api_key = openai_api_key
        st.success("API keys confirmed!")
    else:
        st.warning("Please enter both API keys.")

# --- Main application: rendered only once both API keys are confirmed ---

# OpenRouter endpoint that lists the available models together with pricing.
OPENROUTER_MODELS_URL = "https://openrouter.ai/api/v1/models"
# Seconds to wait for the model list. Without a timeout, `requests` can block
# forever on a stalled connection and freeze the whole Streamlit run.
MODELS_REQUEST_TIMEOUT = 30


def _fetch_openrouter_models(timeout=MODELS_REQUEST_TIMEOUT):
    """Return OpenRouter's model records, sorted alphabetically by ``id``.

    Raises:
        requests.exceptions.RequestException: network failure, timeout, or a
            non-2xx HTTP status.
        ValueError: the response body is not valid JSON.
        KeyError, TypeError: the payload does not contain a ``data`` list of
            objects that each carry an ``id``.
    """
    response = requests.get(OPENROUTER_MODELS_URL, timeout=timeout)
    response.raise_for_status()
    return sorted(response.json()["data"], key=lambda model: model["id"])


def _format_price_per_million(raw_price):
    """Format a per-token price as ``"$X.XX/Million tokens"``.

    Returns ``"N/A"`` when the value cannot be converted to a number
    (e.g. ``None`` or a non-numeric string) instead of raising.
    """
    try:
        return f"${float(raw_price) * 1000000:.2f}/Million tokens"
    except (TypeError, ValueError):
        return "N/A"


def _show_pricing(model):
    """Write prompt/completion pricing for an OpenRouter model record.

    *model* is one entry of the ``/models`` payload, or None when the selected
    id was not found (shown as "N/A"). A missing price key counts as 0.
    """
    if not model:
        st.write("**Pricing:** N/A")
        return
    # `or {}` also covers an explicit null "pricing" field.
    pricing_info = model.get("pricing") or {}
    prompt = _format_price_per_million(pricing_info.get("prompt", 0))
    completion = _format_price_per_million(pricing_info.get("completion", 0))
    st.write(f"**Prompt Pricing:** {prompt} (if applicable)")
    st.write(f"**Completion Pricing:** {completion}")


if st.session_state.open_router_key and st.session_state.openai_api_key:
    # Fetch models from OpenRouter API. Any failure (network, HTTP status,
    # malformed payload) is reported in the UI and leaves the list empty, so
    # the "No models available" check below stops the script.
    try:
        all_models = _fetch_openrouter_models()
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as e:
        st.error(f"Error fetching models from OpenRouter API: {e}")
        all_models = []

    # Model id -> full OpenRouter record (used for pricing lookups).
    models_by_id = {model["id"]: model for model in all_models}
    model_names = list(models_by_id)
    # Judge candidates are the models whose id contains "gpt"; model_names is
    # already sorted, so this list is sorted too.
    judge_models = [name for name in model_names if "gpt" in name]

    # Model Selection
    if not model_names:
        st.error("No models available. Please check your API connection.")
        st.stop()
    model_name = st.selectbox("Select a Contestant Model", model_names)
    _show_pricing(models_by_id.get(model_name))

    # Judge Model Selection
    if not judge_models:
        st.error("No judge models available. Please check your API connection.")
        st.stop()
    judge_model_name = st.selectbox("Select a Judge Model", judge_models)
    _show_pricing(models_by_id.get(judge_model_name))

    # Per-session list of questions the user has typed in.
    if "user_questions" not in st.session_state:
        st.session_state.user_questions = []

    # Threshold Sliders
    st.sidebar.subheader("Threshold Sliders")
    coherence_threshold = st.sidebar.slider("Coherence Threshold (0-5):", 0, 5, 3)
    novelty_threshold = st.sidebar.slider("Novelty Threshold (0-1):", 0.0, 1.0, 0.1)

    # Sampling parameters passed to the contestant model.
    st.sidebar.subheader("Temp Sliders")
    temp_threshold = st.sidebar.slider("Temperature (0-2):", 0.0, 2.0, 1.0)
    top_p = st.sidebar.slider("Top P (0-1):", 0.0, 1.0, 1.0)

    # Workflow Selection
    workflow = st.radio("Select Workflow:", ["Use Predefined Questions", "Use User-Defined Questions"])

    if workflow == "Use Predefined Questions":
        st.header("Question Selection")
        selected_questions = st.multiselect(
            "Select questions to benchmark:",
            predefined_questions,
            predefined_questions,  # Select all by default
        )
    else:  # "Use User-Defined Questions"
        st.header("Question Input")

        new_question = st.text_input("Enter a new question:")
        if st.button("Add Question") and new_question:
            new_question = new_question.strip()
            if new_question and new_question not in st.session_state.user_questions:
                st.session_state.user_questions.append(new_question)
                st.success(f"Question '{new_question}' added successfully.")
            else:
                st.warning("Question already exists or is empty!")

        # All stored custom questions are selected by default.
        selected_questions = st.multiselect(
            "Select your custom questions:",
            options=st.session_state.user_questions,
            default=st.session_state.user_questions,
        )

    st.write("Selected Questions:", selected_questions)

    # Choose execution mode; the thread-pool size only applies when multithreaded.
    execution_mode = st.radio("Execution Mode:", ["Sequential", "Multithreaded"])
    if execution_mode == "Multithreaded":
        max_threads = st.slider("Maximum Number of Threads:", 1, 10, 4)  # Default to 4 threads
    else:
        max_threads = None

    # Benchmark Execution
    if st.button("Start Benchmark"):
        if not selected_questions:
            st.warning("Please select at least one question.")
        else:
            # NOTE(review): there is no real cancellation. Clicking this button
            # only triggers a Streamlit rerun; on that rerun "Start Benchmark"
            # reads False, so this results section is not rendered and no
            # partial results can be shown. Whether an in-flight benchmark call
            # is interrupted depends on Streamlit's rerun handling.
            # TODO: implement proper cancellation.
            st.button("Stop Benchmark")

            if execution_mode == "Sequential":
                question_results = benchmark_model_sequential(
                    model_name, selected_questions,
                    st.session_state.open_router_key, st.session_state.openai_api_key,
                    judge_model_name, coherence_threshold, novelty_threshold,
                    temp_threshold, top_p,
                )
            else:  # Multithreaded
                question_results = benchmark_model_multithreaded(
                    model_name, selected_questions,
                    st.session_state.open_router_key, st.session_state.openai_api_key,
                    max_threads, judge_model_name, coherence_threshold,
                    novelty_threshold, temp_threshold, top_p,
                )

            # One table row per answer. Each result is read as a mapping with
            # "question", "answers" (iterable), "coherence_score" and
            # "novelty_score"; the two scores are repeated on every row.
            results_table = [
                {
                    "Question": result["question"],
                    "Answer": answer,
                    "Contestant Model": model_name,
                    "Judge Model": judge_model_name,
                    "Coherence Score": result["coherence_score"],
                    "Novelty Score": result["novelty_score"],
                }
                for result in question_results
                for answer in result["answers"]
            ]

            st.write("Results:")
            st.table(results_table)

            # Offer the same rows as a UTF-8 CSV download.
            csv = pd.DataFrame(results_table).to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Export Results as CSV",
                data=csv,
                file_name="benchmark_results.csv",
                mime='text/csv',
            )

            st.success("Benchmark completed!")

else:
    st.warning("Please confirm your API keys first.")