update submit
Browse files- src/submission/submit.py +178 -35
    	
        src/submission/submit.py
    CHANGED
    
    | @@ -2,8 +2,15 @@ import json | |
| 2 | 
             
            import os
         | 
| 3 | 
             
            from datetime import datetime, timezone
         | 
| 4 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 5 | 
             
            from src.display.formatting import styled_error, styled_message, styled_warning
         | 
| 6 | 
            -
            from src.envs import API, EVAL_REQUESTS_PATH, TOKEN, QUEUE_REPO
         | 
| 7 | 
             
            from src.submission.check_validity import (
         | 
| 8 | 
             
                already_submitted_models,
         | 
| 9 | 
             
                check_model_card,
         | 
| @@ -14,6 +21,130 @@ from src.submission.check_validity import ( | |
| 14 | 
             
            REQUESTED_MODELS = None
         | 
| 15 | 
             
            USERS_TO_SUBMISSION_DATES = None
         | 
| 16 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 17 | 
             
            def add_new_eval(
         | 
| 18 | 
             
                model: str,
         | 
| 19 | 
             
                base_model: str,
         | 
| @@ -21,6 +152,7 @@ def add_new_eval( | |
| 21 | 
             
                precision: str,
         | 
| 22 | 
             
                weight_type: str,
         | 
| 23 | 
             
                model_type: str,
         | 
|  | |
| 24 | 
             
            ):
         | 
| 25 | 
             
                global REQUESTED_MODELS
         | 
| 26 | 
             
                global USERS_TO_SUBMISSION_DATES
         | 
| @@ -72,48 +204,59 @@ def add_new_eval( | |
| 72 | 
             
                if not modelcard_OK:
         | 
| 73 | 
             
                    return styled_error(error_msg)
         | 
| 74 |  | 
| 75 | 
            -
                # Seems good, creating the eval
         | 
| 76 | 
            -
                print("Adding new eval")
         | 
| 77 | 
            -
             | 
| 78 | 
            -
                eval_entry = {
         | 
| 79 | 
            -
                    "model": model,
         | 
| 80 | 
            -
                    "base_model": base_model,
         | 
| 81 | 
            -
                    "revision": revision,
         | 
| 82 | 
            -
                    "precision": precision,
         | 
| 83 | 
            -
                    "weight_type": weight_type,
         | 
| 84 | 
            -
                    "status": "PENDING",
         | 
| 85 | 
            -
                    "submitted_time": current_time,
         | 
| 86 | 
            -
                    "model_type": model_type,
         | 
| 87 | 
            -
                    "likes": model_info.likes,
         | 
| 88 | 
            -
                    "params": model_size,
         | 
| 89 | 
            -
                    "license": license,
         | 
| 90 | 
            -
                    "private": False,
         | 
| 91 | 
            -
                }
         | 
| 92 | 
            -
             | 
| 93 | 
             
                # Check for duplicate submission
         | 
| 94 | 
             
                if f"{model}_{revision}_{precision}" in REQUESTED_MODELS:
         | 
| 95 | 
             
                    return styled_warning("This model has been already submitted.")
         | 
| 96 |  | 
| 97 | 
            -
                 | 
| 98 | 
            -
                 | 
| 99 | 
            -
             | 
| 100 | 
            -
             | 
|  | |
|  | |
|  | |
| 101 |  | 
| 102 | 
            -
                 | 
| 103 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 104 |  | 
| 105 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 106 | 
             
                API.upload_file(
         | 
| 107 | 
            -
                    path_or_fileobj= | 
| 108 | 
            -
                    path_in_repo= | 
| 109 | 
            -
                    repo_id= | 
| 110 | 
             
                    repo_type="dataset",
         | 
| 111 | 
            -
                    commit_message=f"Add {model} | 
| 112 | 
             
                )
         | 
| 113 |  | 
| 114 | 
            -
                # Remove the local file
         | 
| 115 | 
            -
                os.remove( | 
| 116 |  | 
| 117 | 
            -
                return styled_message(
         | 
| 118 | 
            -
                    "Your request has been submitted to the evaluation queue!\nPlease wait for up to an hour for the model to show in the PENDING list."
         | 
| 119 | 
            -
                )
         | 
|  | |
| 2 | 
             
            import os
         | 
| 3 | 
             
            from datetime import datetime, timezone
         | 
| 4 |  | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import pandas as pd
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            from datasets import load_dataset
         | 
| 9 | 
            +
            from transformers import AutoTokenizer, AutoModelForCausalLM
         | 
| 10 | 
            +
            from langchain.prompts import PromptTemplate
         | 
| 11 | 
            +
             | 
| 12 | 
             
            from src.display.formatting import styled_error, styled_message, styled_warning
         | 
| 13 | 
            +
            from src.envs import API, EVAL_REQUESTS_PATH, TOKEN, QUEUE_REPO, EVAL_RESULTS_PATH, RESULTS_REPO
         | 
| 14 | 
             
            from src.submission.check_validity import (
         | 
| 15 | 
             
                already_submitted_models,
         | 
| 16 | 
             
                check_model_card,
         | 
|  | |
| 21 | 
             
            REQUESTED_MODELS = None
         | 
| 22 | 
             
            USERS_TO_SUBMISSION_DATES = None
         | 
| 23 |  | 
| 24 | 
            +
            def get_top_prediction(text, tokenizer, model):
         | 
| 25 | 
            +
                inputs = tokenizer(text, return_tensors='pt')
         | 
| 26 | 
            +
                if torch.cuda.is_available():
         | 
| 27 | 
            +
                    model = model.cuda()
         | 
| 28 | 
            +
                    inputs = {k: v.cuda() for k, v in inputs.items()}
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                with torch.no_grad():
         | 
| 31 | 
            +
                    outputs = model(**inputs)
         | 
| 32 | 
            +
                    logits = outputs.logits[0, -1]
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                options = [' A', ' B', ' C', ' D']
         | 
| 35 | 
            +
                option_logits = []
         | 
| 36 | 
            +
                for option in options:
         | 
| 37 | 
            +
                    option_id = tokenizer(option).input_ids[-1]
         | 
| 38 | 
            +
                    option_logit = logits[option_id]
         | 
| 39 | 
            +
                    option_logits.append((option_logit.item(), option.strip()))
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                # Get the option with the highest logit
         | 
| 42 | 
            +
                top_option = max(option_logits, key=lambda x: x[0])[1]
         | 
| 43 | 
            +
                return top_option
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            def evaluate_model_accuracy(model_name, num_examples):
         | 
| 46 | 
            +
                try:
         | 
| 47 | 
            +
                    # Load the model and tokenizer
         | 
| 48 | 
            +
                    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         | 
| 49 | 
            +
                    tokenizer.pad_token = tokenizer.eos_token
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    model = AutoModelForCausalLM.from_pretrained(
         | 
| 52 | 
            +
                        model_name,
         | 
| 53 | 
            +
                        trust_remote_code=True
         | 
| 54 | 
            +
                    )
         | 
| 55 | 
            +
                    if torch.cuda.is_available():
         | 
| 56 | 
            +
                        model = model.cuda()  # Move model to GPU if available
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    # Load your dataset
         | 
| 59 | 
            +
                    dataset = load_dataset("Omartificial-Intelligence-Space/Arabic_Openai_MMMLU")
         | 
| 60 | 
            +
                    dataset = dataset['test']
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    # Convert the dataset to a pandas DataFrame for easier manipulation
         | 
| 63 | 
            +
                    df_dataset = dataset.to_pandas()
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    # Get list of unique subjects
         | 
| 66 | 
            +
                    subjects = df_dataset['Subject'].unique()
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    # Define prompt template
         | 
| 69 | 
            +
                    template = """Answer the following multiple choice question by giving the most appropriate response. Answer should be one among [A, B, C, D].
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            Question: {Question}
         | 
| 72 | 
            +
            A) {A}
         | 
| 73 | 
            +
            B) {B}
         | 
| 74 | 
            +
            C) {C}
         | 
| 75 | 
            +
            D) {D}
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            Answer:"""
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    prompt_template = PromptTemplate(template=template, input_variables=['Question', 'A', 'B', 'C', 'D'])
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    # Initialize counters and results
         | 
| 82 | 
            +
                    overall_correct_predictions = 0
         | 
| 83 | 
            +
                    overall_total_questions = 0
         | 
| 84 | 
            +
                    per_subject_results = []
         | 
| 85 | 
            +
                    detailed_results = []
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    for subject in subjects:
         | 
| 88 | 
            +
                        # Filter dataset for the current subject
         | 
| 89 | 
            +
                        subject_df = df_dataset[df_dataset['Subject'] == subject]
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                        # Select up to num_examples questions
         | 
| 92 | 
            +
                        subject_df = subject_df.sample(n=min(num_examples, len(subject_df)), random_state=42)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                        # Initialize counters for this subject
         | 
| 95 | 
            +
                        correct_predictions = 0
         | 
| 96 | 
            +
                        total_questions = 0
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                        for idx, data in subject_df.iterrows():
         | 
| 99 | 
            +
                            # Prepare text input
         | 
| 100 | 
            +
                            text = prompt_template.format(
         | 
| 101 | 
            +
                                Question=data['Question'],
         | 
| 102 | 
            +
                                A=data['A'],
         | 
| 103 | 
            +
                                B=data['B'],
         | 
| 104 | 
            +
                                C=data['C'],
         | 
| 105 | 
            +
                                D=data['D']
         | 
| 106 | 
            +
                            )
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                            # Get the top prediction
         | 
| 109 | 
            +
                            top_prediction = get_top_prediction(text, tokenizer, model)
         | 
| 110 | 
            +
                            is_correct = (top_prediction == data['Answer'])
         | 
| 111 | 
            +
                            correct_predictions += int(is_correct)
         | 
| 112 | 
            +
                            total_questions += 1
         | 
| 113 | 
            +
                            overall_correct_predictions += int(is_correct)
         | 
| 114 | 
            +
                            overall_total_questions +=1
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                            detailed_results.append({
         | 
| 117 | 
            +
                                'Subject': subject,
         | 
| 118 | 
            +
                                'Question': data['Question'],
         | 
| 119 | 
            +
                                'Answer': data['Answer'],
         | 
| 120 | 
            +
                                'Prediction': top_prediction,
         | 
| 121 | 
            +
                                'Correct': is_correct
         | 
| 122 | 
            +
                            })
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                        # Compute accuracy for this subject
         | 
| 125 | 
            +
                        subject_accuracy = correct_predictions / total_questions if total_questions > 0 else 0
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        per_subject_results.append({
         | 
| 128 | 
            +
                            'Subject': subject,
         | 
| 129 | 
            +
                            'Total Score': correct_predictions,
         | 
| 130 | 
            +
                            'Total Questions': total_questions,
         | 
| 131 | 
            +
                            'Accuracy (%)': subject_accuracy * 100
         | 
| 132 | 
            +
                        })
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    # Compute overall accuracy
         | 
| 135 | 
            +
                    overall_accuracy = overall_correct_predictions / overall_total_questions if overall_total_questions > 0 else 0
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    # Convert per_subject_results to DataFrame
         | 
| 138 | 
            +
                    df_per_subject = pd.DataFrame(per_subject_results)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    # Convert detailed_results to DataFrame
         | 
| 141 | 
            +
                    df_detailed_results = pd.DataFrame(detailed_results)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    return overall_accuracy, df_per_subject, df_detailed_results
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                except Exception as e:
         | 
| 146 | 
            +
                    return f"Error: {str(e)}", pd.DataFrame(), pd.DataFrame()
         | 
| 147 | 
            +
             | 
| 148 | 
             
            def add_new_eval(
         | 
| 149 | 
             
                model: str,
         | 
| 150 | 
             
                base_model: str,
         | 
|  | |
| 152 | 
             
                precision: str,
         | 
| 153 | 
             
                weight_type: str,
         | 
| 154 | 
             
                model_type: str,
         | 
| 155 | 
            +
                num_examples: int  # New parameter
         | 
| 156 | 
             
            ):
         | 
| 157 | 
             
                global REQUESTED_MODELS
         | 
| 158 | 
             
                global USERS_TO_SUBMISSION_DATES
         | 
|  | |
| 204 | 
             
                if not modelcard_OK:
         | 
| 205 | 
             
                    return styled_error(error_msg)
         | 
| 206 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 207 | 
             
                # Check for duplicate submission
         | 
| 208 | 
             
                if f"{model}_{revision}_{precision}" in REQUESTED_MODELS:
         | 
| 209 | 
             
                    return styled_warning("This model has been already submitted.")
         | 
| 210 |  | 
| 211 | 
            +
                # Now, perform the evaluation
         | 
| 212 | 
            +
                try:
         | 
| 213 | 
            +
                    overall_accuracy, df_per_subject, df_detailed_results = evaluate_model_accuracy(model, int(num_examples))
         | 
| 214 | 
            +
                    if isinstance(overall_accuracy, str) and overall_accuracy.startswith("Error"):
         | 
| 215 | 
            +
                        return styled_error(overall_accuracy)
         | 
| 216 | 
            +
                except Exception as e:
         | 
| 217 | 
            +
                    return styled_error(f"An error occurred during evaluation: {str(e)}")
         | 
| 218 |  | 
| 219 | 
            +
                # Prepare results for storage
         | 
| 220 | 
            +
                results_dict = {
         | 
| 221 | 
            +
                    "config": {
         | 
| 222 | 
            +
                        "model_name": model,
         | 
| 223 | 
            +
                        "model_sha": revision,
         | 
| 224 | 
            +
                        "model_dtype": precision,
         | 
| 225 | 
            +
                        "submitted_time": current_time,
         | 
| 226 | 
            +
                        "model_type": model_type,
         | 
| 227 | 
            +
                        "weight_type": weight_type,
         | 
| 228 | 
            +
                        "license": license,
         | 
| 229 | 
            +
                        "likes": model_info.likes,
         | 
| 230 | 
            +
                        "params": model_size,
         | 
| 231 | 
            +
                        "still_on_hub": True,
         | 
| 232 | 
            +
                        "precision": precision,
         | 
| 233 | 
            +
                    },
         | 
| 234 | 
            +
                    "results": {
         | 
| 235 | 
            +
                        "average": overall_accuracy * 100,
         | 
| 236 | 
            +
                    },
         | 
| 237 | 
            +
                }
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                # Include per-subject accuracies
         | 
| 240 | 
            +
                for idx, row in df_per_subject.iterrows():
         | 
| 241 | 
            +
                    subject_name = row['Subject']
         | 
| 242 | 
            +
                    accuracy = row['Accuracy (%)']
         | 
| 243 | 
            +
                    results_dict['results'][subject_name] = accuracy
         | 
| 244 |  | 
| 245 | 
            +
                # Save results to a JSON file
         | 
| 246 | 
            +
                results_file_path = f"{EVAL_RESULTS_PATH}/{model.replace('/', '_')}_results.json"
         | 
| 247 | 
            +
                with open(results_file_path, "w") as f:
         | 
| 248 | 
            +
                    json.dump(results_dict, f)
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                # Upload the results file
         | 
| 251 | 
             
                API.upload_file(
         | 
| 252 | 
            +
                    path_or_fileobj=results_file_path,
         | 
| 253 | 
            +
                    path_in_repo=results_file_path.split(f"{EVAL_RESULTS_PATH}/")[1],
         | 
| 254 | 
            +
                    repo_id=RESULTS_REPO,
         | 
| 255 | 
             
                    repo_type="dataset",
         | 
| 256 | 
            +
                    commit_message=f"Add results for {model}"
         | 
| 257 | 
             
                )
         | 
| 258 |  | 
| 259 | 
            +
                # Remove the local results file
         | 
| 260 | 
            +
                os.remove(results_file_path)
         | 
| 261 |  | 
| 262 | 
            +
                return styled_message("Your model has been evaluated and the results are now on the leaderboard!")
         | 
|  | |
|  | 
