import collections
import http.client
import json
import os
import random
import time
from argparse import ArgumentParser
from datetime import datetime as dt

import bittensor as bt
import lm_eval.tasks
from tqdm import tqdm

parser = ArgumentParser()

parser.add_argument("--validator", required=True, type=str, help="validator name", choices=["opentensor_foundation", "taostats"], default="float16")
args = parser.parse_args()
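
# Ensure the results directory used by the json.dump calls below exists up front,
# so the checkpoint writes do not fail on a fresh checkout.
os.makedirs(f"_results/few-shot/{args.validator}", exist_ok=True)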

default_prompt = '''
You are Chattensor.
Chattensor is a research project by Opentensor Cortex.
Chattensor is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Chattensor is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
'''

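# Queries to the taostats validator go through the BitAPAI HTTP gateway and require a
# BITAPAI_KEY environment variable; opentensor_foundation is queried directly via bittensor.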
if args.validator == "taostats":
    print("TAOOOOSATATS")
    try:
        bitapai_key = os.environ["BITAPAI_KEY"]
        conn = http.client.HTTPSConnection("dashboard.bitapai.io")
        headers = {
        'Content-Type': 'application/json',
        'X-API-KEY': bitapai_key
        }
    except KeyError:
        raise RuntimeError(f"BITAPAI_KEY does not exist and chosen validator is taostats. Please set your bitapai key using export BITAPAI_KEY=x.")
        

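# Send a single prompt to the selected validator and return its text response.
# For taostats the prompt is POSTed to BitAPAI's /api/v1/prompt endpoint; for
# opentensor_foundation it goes through bittensor's bt.prompt() directly.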
def get_response(prompt):
    if args.validator == "taostats":
        payload = json.dumps({
            "system": default_prompt,
            "user": prompt
        })
        conn.request("POST", "/api/v1/prompt", payload, headers)
        res = conn.getresponse()
        data = res.read()
        print(data)
        time.sleep(1)
        return data.decode("utf-8")
    else:
        return bt.prompt(prompt)

# Load all the LMEH tasks
tasks = ["hellaswag", "arc_challenge", "truthfulqa_mc", "hendrycksTest-abstract_algebra", "hendrycksTest-anatomy", "hendrycksTest-astronomy", "hendrycksTest-business_ethics", "hendrycksTest-clinical_knowledge", "hendrycksTest-college_biology", "hendrycksTest-college_chemistry", "hendrycksTest-college_computer_science", "hendrycksTest-college_mathematics", "hendrycksTest-college_medicine", "hendrycksTest-college_physics", "hendrycksTest-computer_security", "hendrycksTest-conceptual_physics", "hendrycksTest-econometrics", "hendrycksTest-electrical_engineering", "hendrycksTest-elementary_mathematics", "hendrycksTest-formal_logic", "hendrycksTest-global_facts", "hendrycksTest-high_school_biology", "hendrycksTest-high_school_chemistry", "hendrycksTest-high_school_computer_science", "hendrycksTest-high_school_european_history", "hendrycksTest-high_school_geography", "hendrycksTest-high_school_government_and_politics", "hendrycksTest-high_school_macroeconomics", "hendrycksTest-high_school_mathematics", "hendrycksTest-high_school_microeconomics", "hendrycksTest-high_school_physics", "hendrycksTest-high_school_psychology", "hendrycksTest-high_school_statistics", "hendrycksTest-high_school_us_history", "hendrycksTest-high_school_world_history", "hendrycksTest-human_aging", "hendrycksTest-human_sexuality", "hendrycksTest-international_law", "hendrycksTest-jurisprudence", "hendrycksTest-logical_fallacies", "hendrycksTest-machine_learning", "hendrycksTest-management", "hendrycksTest-marketing", "hendrycksTest-medical_genetics", "hendrycksTest-miscellaneous", "hendrycksTest-moral_disputes", "hendrycksTest-moral_scenarios", "hendrycksTest-nutrition", "hendrycksTest-philosophy", "hendrycksTest-prehistory", "hendrycksTest-professional_accounting", "hendrycksTest-professional_law", "hendrycksTest-professional_medicine", "hendrycksTest-professional_psychology", "hendrycksTest-public_relations", "hendrycksTest-security_studies", "hendrycksTest-sociology", "hendrycksTest-us_foreign_policy", "hendrycksTest-virology", "hendrycksTest-world_religions"]
task_dict = lm_eval.tasks.get_task_dict(tasks)
task_dict_items = [
    (name, task)
    for name, task in task_dict.items()
    if (task.has_validation_docs() or task.has_test_docs())
]
versions = collections.defaultdict(dict)

# get lists of each type of request
for task_name, task in task_dict_items:
    versions[task_name] = task.VERSION
    # default to test doc, fall back to val doc if validation unavailable
    # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
    if task.has_test_docs():
        task_doc_func = task.test_docs
        task_set = "test"  # Required for caching in the decontamination
    elif task.has_validation_docs():
        task_set = "val"  # Required for caching in the decontamination
        task_doc_func = task.validation_docs
    else:
        raise RuntimeError("Task has neither test_docs nor validation_docs")

    # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
    task_docs = list(task_doc_func())
    rnd = random.Random()
    rnd.seed(42)
    rnd.shuffle(task_docs)

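    # Query the validator for each doc, recording the response, the exact prompt sent and the
    # wall-clock inference time alongside the original doc.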
    i = 0
    for task_doc in tqdm(task_docs):
        print(task_name)
        print(task_doc)
        # Skip docs that already carry a non-empty result from an earlier, partially completed run.
        if ("result" in task_doc) and ("inference_time" in task_doc) and ("prompt" in task_doc) and (task_doc['result'] != ""):
            continue

        query = task_doc["query"] if "query" in task_doc else ""
        choices_list = "\n".join([str(number+1) + ". " + choice for number, choice in enumerate(task_doc["choices"])]) if "choices" in task_doc else ""
        number_list = ",".join([str(number) for number in range(1,len(task_doc["choices"])+1)]) if "choices" in task_doc else ""

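        # Build a few-shot prompt per task family: 10 examples for hellaswag, 25 for arc_challenge
        # and 5 for the hendrycksTest (MMLU) subjects; truthfulqa_mc is skipped.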
        if (task_name == "hellaswag"):
            prompt = ""
            prompt_list = list(task.training_docs())[:10]
            for prompt_item in prompt_list:
                prompt_item_query = prompt_item["query"]
                prompt_item_choices_list = "\n".join([str(number+1) + ". " + choice for number, choice in enumerate(prompt_item["choices"])])
                prompt_item_number_list = ",".join([str(number) for number in range(1,len(prompt_item["choices"])+1)])
                prompt_item_gold = prompt_item["gold"]+1

                prompt += f"""{prompt_item_query}...\n{prompt_item_choices_list}\nRespond with just one number only: {prompt_item_number_list}.\n{prompt_item_gold}\n\n"""

            prompt += f"""{query}...\n{choices_list}\nRespond with just one number only: {number_list}. """
        
        elif (task_name == "arc_challenge"):
            prompt = ""
            prompt_list = list(task.training_docs())[:25]
            for prompt_item in prompt_list:
                prompt_item_query = prompt_item["query"]
                prompt_item_choices_list = "\n".join([str(number+1) + ". " + choice for number, choice in enumerate(prompt_item["choices"])])
                prompt_item_number_list = ",".join([str(number) for number in range(1,len(prompt_item["choices"])+1)])
                prompt_item_gold = prompt_item["gold"]+1

                prompt += f"""{prompt_item_query}...\n{prompt_item_choices_list}\nRespond with just one number only: {prompt_item_number_list}.\n{prompt_item_gold}\n\n"""

            prompt += f"""{query}...\n{choices_list}\nRespond with just one number only: {number_list}. """


        elif (task_name == "truthfulqa_mc"):
            continue
            prompt = ""
            
        elif ("hendrycksTest" in task_name):
            prompt = ""
            prompt_list = list(task.test_docs())[:5]
            for prompt_item in prompt_list:
                prompt_item_query = prompt_item["query"]

                prompt += f"""{prompt_item_query.replace("Answer:", "Respond with just one letter only: A, B, C, D:")}\n{["A", "B", "C", "D"][prompt_item["gold"]]}\n\n"""

            prompt += query.replace("Answer:", "Respond with just one letter only: A, B, C, D:")

        # print(prompt)

        start = time.time()
        task_doc["result"] = get_response(prompt)
        end = time.time()
        task_doc["inference_time"] = end - start
        task_doc["prompt"] = prompt
        task_doc["datetime"] = dt.now().strftime(format = "%Y-%m-%d %H:%M:%S")
        print(task_doc["result"])

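        # Checkpoint partial results to disk every 100 docs so an interrupted run can resume
        # without re-querying docs that already have a result.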
        i = i + 1
        if (i % 100 == 0):
            with open(f"""_results/few-shot/{args.validator}/{task_name}_results.json""", "w") as final:
                json.dump(task_docs, final)

    with open(f"""_results/few-shot/{args.validator}/{task_name}_results.json""", "w") as final:
        json.dump(task_docs, final)