Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from dataclasses import dataclass | |
| from concurrent.futures import ThreadPoolExecutor, TimeoutError | |
| import os | |
| import re | |
| import subprocess | |
| import tempfile | |
| import json | |
| import datasets | |
| import random | |
| from typing import Tuple, Dict, Any, List | |
| from sympy import N, simplify | |
| from sympy.parsing.latex import parse_latex | |
| from openai import OpenAI | |
| import base64 | |
# OpenAI-compatible client for the inference server. The endpoint comes from
# the SERVER_URL environment variable and the API key from HF_TOKEN.
client = OpenAI(
    api_key=os.environ.get("HF_TOKEN"),
    base_url=os.environ.get("SERVER_URL"),
)
@dataclass
class Config:
    """Settings for a solving run: model selection, sampling and bookkeeping.

    Declared as a dataclass so it can be built from keyword arguments
    (as the module-level setup does) and printed with a readable repr.
    """

    model_id: str  # SELECT MODEL
    revision: str  # SELECT REVISION
    # Append an optional system prompt to each problem
    system_prompt: str
    # Number of samples to generate per problem
    num_samples: int
    # Number of generate/execute rounds per problem (see solve_problem)
    num_generations: int
    # Generation parameters
    do_sample: bool
    temperature: float
    top_p: float
    top_k: int
    max_new_tokens: int
    # When True, process_code resets the text instead of pruning on failure
    restart_on_fail: bool
    # Enable 4-bit quantization
    is_quantized: bool
    # Run on train or test data?
    is_submission: bool = True if os.getenv("KAGGLE_IS_COMPETITION_RERUN") else False
    validation_set: str = "kaggle-validation-set-medium"
    notebook_time_limit: int = 9 * 60 * 60 - 15 * 60  # 9 hours - 15 minute buffer
    # Debug by solving only the first problem
    debug: bool = False
    # Push solutions to the Hub
    push_to_hub: bool = False
class PythonREPL:
    """Run model-written Python snippets in a separate ``python3`` process.

    Each snippet is prefixed with ``import math``, ``import numpy as np`` and
    ``import sympy as sp``, written to a temporary ``tmp.py`` and executed
    with a wall-clock limit of ``timeout`` seconds.
    """

    def __init__(self, timeout=5):
        # Time limit in seconds, used both for the subprocess and for __call__.
        self.timeout = timeout

    def execute(self, query: str) -> Tuple[bool, str]:
        """Execute *query* and return ``(success, output)``.

        If the last line does not contain ``print(``, it is wrapped in
        ``print(...)`` (after dropping any ``#`` comment) so the value of the
        final expression is shown. On success, ``output`` is the stripped
        stdout. On failure, it is a condensed traceback in which the temporary
        path is shown simply as ``tmp.py``. A run that exceeds the timeout
        returns ``(False, "Timed out after N seconds.")``.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        query = query.strip().split("\n")
        if "print(" not in query[-1]:
            if "#" in query[-1]:
                query[-1] = query[-1].split("#")[0]
            query[-1] = "print(" + query[-1] + ")"
        query = "\n".join(query)
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w") as f:
                f.write(query)
            try:
                result = subprocess.run(
                    ["python3", temp_file_path],
                    capture_output=True,
                    check=False,
                    text=True,
                    timeout=self.timeout,
                )
            except subprocess.TimeoutExpired:
                # subprocess.TimeoutExpired is not a TimeoutError, so callers
                # (which only catch TimeoutError) would otherwise crash.
                return False, f"Timed out after {self.timeout} seconds."
            if result.returncode == 0:
                return True, result.stdout.strip()

            msgs = result.stderr.strip().split("\n")
            new_msgs = []
            want_next = False
            for m in msgs:
                if "Traceback" in m:
                    new_msgs.append(m)
                elif m == msgs[-1]:
                    # The final line carries the exception type and message.
                    new_msgs.append(m)
                elif temp_file_path in m:
                    # Hide the random temp directory but keep the file name,
                    # line number and scope, e.g. 'File "tmp.py", line 4'.
                    new_msgs.append(m.replace(temp_file_path, "tmp.py"))
                    # Keep the following line too (the offending source line).
                    want_next = True
                elif want_next:
                    new_msgs.append(m)
                    want_next = False
            return False, "\n".join(new_msgs).strip()

    def __call__(self, query: str) -> Tuple[bool, str]:
        """Run :meth:`execute` in a worker thread and wait up to ``self.timeout`` seconds.

        Leaving the ``with`` block waits for the worker thread. That thread
        finishes once ``subprocess.run``'s own timeout fires.
        """
        with ThreadPoolExecutor() as executor:
            future = executor.submit(self.execute, query)
            try:
                return future.result(timeout=self.timeout)
            except TimeoutError:
                return False, f"Timed out after {self.timeout} seconds."
| def execute_completion( | |
| executor: PythonREPL, | |
| completion: str, | |
| return_status: bool = False, | |
| last_code_block: bool = False, | |
| ) -> str | Tuple[str, bool]: | |
| # executions = ["!" + code for code in re.findall(r"```bash(.*?)```", completion, re.DOTALL) if "!" not in code] | |
| executions = re.findall(r"```python(.*?)```", completion, re.DOTALL) | |
| if len(executions) == 0: # directly return cot result | |
| return completion, False if return_status else completion | |
| else: | |
| if last_code_block: | |
| executions = [executions[-1]] | |
| # Python | |
| execution_outputs = [] | |
| successes = [] | |
| for code in executions: | |
| success = False | |
| if "subprocess" in code: | |
| output = "subprocess is not allowed" | |
| execution_outputs.append(output) | |
| successes.append(success) | |
| continue | |
| if "venv" in code: | |
| output = "venv is not allowed" | |
| execution_outputs.append(output) | |
| successes.append(success) | |
| continue | |
| try: | |
| success, output = executor(code) | |
| except TimeoutError as e: | |
| print("time out") | |
| output = e | |
| if not success and not return_status: | |
| output = "" | |
| execution_outputs.append(output) | |
| successes.append(success) | |
| output = str(execution_outputs[-1]).strip() | |
| success = successes[-1] | |
| if return_status: | |
| return output, success | |
| else: | |
| return output | |
def postprocess_completion(
    text: str, return_status: bool = False, last_code_block=False, timeout=5
) -> str | Tuple[str, bool]:
    """Execute the ```python blocks in *text* using a fresh PythonREPL.

    The REPL is created with the given *timeout*. The result is whatever
    execute_completion returns for *return_status* and *last_code_block*.
    """
    return execute_completion(
        PythonREPL(timeout=timeout),
        text,
        return_status=return_status,
        last_code_block=last_code_block,
    )
def apply_template(example: Dict[str, Any], prompt: str) -> str:
    """Fill the first ``{}`` placeholder of *prompt* with ``example["prompt"]``.

    A second ``{}`` placeholder, if present, is re-emitted literally as
    ``{}`` so it can be formatted again later. ``str.format`` ignores the
    extra argument when *prompt* has only one placeholder.

    Returns:
        The formatted prompt string.
    """
    return prompt.format(example["prompt"], "{}")
def last_boxed_only_string(string):
    """
    Return the last ``\\boxed{...}`` (or, failing that, ``\\fbox{...}``) span.

    The span runs from the command name through the brace that closes the
    first opening brace after it.

    Args:
        string (str): Text that may contain LaTeX boxed expressions.
    Returns:
        str or None: The span, or None if there is no command or its braces
        never balance.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None
def remove_boxed(s):
    """
    Remove the LaTeX ``\\boxed{...}`` or ``\\fbox{...}`` wrapper.

    ``\\fbox`` is accepted because last_boxed_only_string can return it.
    Validation uses explicit checks rather than ``assert``, which is
    stripped when Python runs with ``-O``.

    Args:
        s (str): A string that should be exactly one boxed expression.
    Returns:
        str or None: The content inside the braces, or None if *s* is not a
        string of the form ``\\boxed{...}`` / ``\\fbox{...}``.
    """
    if not isinstance(s, str) or not s.endswith("}"):
        return None
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left):
            return s[len(left) : -1]
    return None
def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
    """
    Extract the content of the last boxed expression in a prediction string.

    Args:
        pred_str (str): Text containing one or more LaTeX boxed expressions.
        strip_double_curly_brace (bool): If True and the content is itself
            wrapped in ``{...}``, remove that extra layer of braces.
    Returns:
        str or None: The extracted answer, or None if no well-formed boxed
        expression is found.
    """
    boxed_str = last_boxed_only_string(pred_str)
    if boxed_str is None:
        return None
    answer = remove_boxed(boxed_str)
    if answer is None:
        return None
    if strip_double_curly_brace:
        # Raw string: "\{" in a normal literal is an invalid escape sequence.
        match = re.match(r"^\{(.*)\}$", answer)
        if match:
            answer = match.group(1)
    return answer
def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.

    Strips LaTeX and text decoration (unit words, ``\\text{}``/``\\boxed{}``
    wrappers, dollar signs, spaces, ...) so that answers can be compared as
    plain strings. Anything from the first "Problem:" onwards is discarded
    first, so a completion that starts writing a new problem does not leak
    into the answer.

    Args:
        final_answer (str): The answer string to normalize.
    Returns:
        str: The normalized answer string.
    """
    match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
    if match:
        # Keep only the text before the first "Problem:".
        final_answer = match.group(1)

    SUBSTITUTIONS = [
        ("an ", ""),
        ("a ", ""),
        (".$", "$"),
        ("\\$", ""),
        (r"\ ", ""),
        (" ", ""),
        ("mbox", "text"),
        (",\\text{and}", ","),
        ("\\text{and}", ","),
        ("\\text{m}", "\\text{}"),
        ("\\le", "<"),
    ]
    REMOVED_EXPRESSIONS = [
        "square",
        "ways",
        "integers",
        "dollars",
        "mph",
        "inches",
        "ft",
        "hours",
        "km",
        "units",
        "\\ldots",
        "sue",
        "points",
        "feet",
        "minutes",
        "digits",
        "cents",
        "degrees",
        "cm",
        "gm",
        "pounds",
        "meters",
        "meals",
        "edges",
        "students",
        "childrentickets",
        "multiples",
        "\\text{s}",
        "\\text{.}",
        "\\text{\ns}",
        "\\text{}^2",
        "\\text{}^3",
        "\\text{\n}",
        "\\text{}",
        r"\mathrm{th}",
        r"^\circ",
        r"^{\circ}",
        r"\;",
        r",\!",
        "{,}",
        '"',
        "\\dots",
        "\n",
        "\r",
        "\f",
        r"\%",  # same value as the former "\%", which was an invalid escape sequence
    ]
    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
    # These always hold: the characters were removed by REMOVED_EXPRESSIONS above.
    assert "\n" not in final_answer
    assert "\r" not in final_answer
    assert "\f" not in final_answer

    # Prefer explicitly marked answers when present (spaces are already removed).
    if len(re.findall(r"finalansweris(.*)", final_answer)) > 0:
        final_answer = re.findall(r"finalansweris(.*)", final_answer)[-1]
    if len(re.findall(r"answer?is:?(.*)", final_answer)) > 0:
        final_answer = re.findall(r"answer?is:?(.*)", final_answer)[-1]
    if len(re.findall(r"oxed\{(.*?)\}", final_answer)) > 0:
        final_answer = re.findall(r"oxed\{(.*?)\}", final_answer)[-1]
    if len(re.findall(r"\$(.*?)\$", final_answer)) > 0:
        final_answer = re.findall(r"\$(.*?)\$", final_answer)[-1]
    final_answer = final_answer.strip()

    # Repair "\frac" whose "\f" was lost, then brace shorthand like frac12 / sqrt2.
    if "rac" in final_answer and "\\frac" not in final_answer:
        final_answer = final_answer.replace("rac", "\\frac")
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")

    # Drop thousands separators from plain integers, e.g. "1,000" -> "1000".
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")
    return final_answer
def naive_parse(answer: str) -> str:
    """
    Return the last run of consecutive digits in *answer*.

    The input is scanned from the end. Digits are collected from the first
    digit seen until the next non-digit character, then returned in their
    original left-to-right order.

    Example:
        >>> naive_parse("abc123def")
        '123'
        >>> naive_parse("def456ghi")
        '456'
        >>> naive_parse("no numbers here")
        ''
    """
    collected = []
    for char in reversed(list(answer)):
        if char in "0123456789":
            collected.append(char)
        elif collected:
            # The trailing digit run has ended; earlier digits are ignored.
            break
    collected.reverse()
    return "".join(collected)
| def validate_answer_is_numeric(x: str | int | float) -> int: | |
| FLOAT_TOLERANCE = 0.2 | |
| try: | |
| x = round(float(x)) | |
| f = float(x) | |
| if abs(x - f) > FLOAT_TOLERANCE: | |
| x = -1 | |
| except Exception: | |
| x = -1 | |
| return x | |
def filter_answers(answers: List[str]) -> List[int]:
    """Map answers to integers in [0, 999], dropping invalid or negative ones.

    Each answer goes through validate_answer_is_numeric. Results below zero
    (including its -1 failure marker) are discarded, and the rest are reduced
    modulo 1000, which already keeps them within 0..999.
    """
    numeric = (validate_answer_is_numeric(a) for a in answers)
    return [value % 1_000 for value in numeric if value >= 0]
def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
    """Return True if two LaTeX answers are mathematically equal per SymPy.

    Both strings are parsed with ``parse_latex``. They match when the
    simplified difference evaluates numerically to within 1e-12 of zero, or
    is symbolically zero. Any parsing or evaluation error is printed and
    treated as a mismatch.
    """
    try:
        difference = simplify(parse_latex(ref_answer) - parse_latex(model_answer))
        return bool(-1e-12 < N(difference) < 1e-12 or difference.is_zero)
    except Exception as e:
        print(e)
        return False
def check_string_match(ref_answer: str, model_answer: str) -> bool:
    """Return whether the two answers compare equal; comparison errors are printed and yield False."""
    try:
        is_same = ref_answer == model_answer
    except Exception as err:
        print(err)
        is_same = False
    return is_same
def check_answer(ref_answer: str, model_answer: str) -> bool:
    """Return True if the answers match exactly or are SymPy-equivalent.

    The cheap exact string comparison runs first. The SymPy check is only
    attempted when it fails.
    """
    return bool(
        check_string_match(ref_answer, model_answer)
        or check_sympy_equivalence(ref_answer, model_answer)
    )
# Run parameters, kept as module-level names (papermill-style).
debug = False
push_to_hub = False
model_id = "Numina-Math-7B"
revision = "main"
system_prompt = "{}"
validation_set = "kaggle-validation-set-medium"
is_submission = True
is_quantized = False
restart_on_fail = False
# Sampling settings
num_samples = 4
num_generations = 4
temperature = 0.8
top_p = 1.0
top_k = 0
max_new_tokens = 2048
# Papermill related variables
notebook_name = ""

config = Config(
    model_id=model_id,
    revision=revision,
    system_prompt=system_prompt,
    num_samples=num_samples,
    num_generations=num_generations,
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    max_new_tokens=max_new_tokens,
    restart_on_fail=restart_on_fail,
    is_quantized=is_quantized,
    is_submission=is_submission,
    validation_set=validation_set,
    debug=debug,
    push_to_hub=push_to_hub,
)
print(f"=== Running submission with config ===\n\n{config}")
def generate(message, temperature):
    """
    Stream a chat completion from the inference server, piece by piece.

    The raw HTTP body is read line by line, and each server-sent-event
    ``data:`` payload is decoded as JSON. The OpenAI client's own stream
    iterator is not used because, when the server reports an error, the body
    is plain JSON rather than an event stream, and the library then fails.

    Parameters:
        message (list): Chat messages (``role``/``content`` dicts) sent to the model.
        temperature (float): Sampling temperature; higher values make the model take more risks.

    Yields:
        tuple: ``(content, False)`` for each piece of generated text.
        ``(error_message, True)`` once if the server reports an error, or
        ``("", True)`` if a decoded chunk has an unexpected structure.
    """
    stream = client.chat.completions.create(
        model="tgi",
        messages=message,
        stream=True,
        max_tokens=1024,
        stop=["```output\n"],
        temperature=temperature,
        timeout=30,
    )
    response = stream.response
    # iter_lines() reassembles lines across network chunks and decodes UTF-8
    # incrementally. Applying json.loads to raw byte chunks could see half
    # an event, several events at once, or a split multi-byte character.
    for line in response.iter_lines():
        line = line.strip()
        if not line or line.startswith(":"):
            continue  # blank event separator or SSE comment/keep-alive
        if line.startswith("data:"):
            line = line[len("data:") :].strip()
        if line == "[DONE]":
            break
        try:
            chunk_json = json.loads(line)
        except json.JSONDecodeError as e:
            print(f"func: generate could not decode line\nline:{line}\nerror:{e}")
            continue
        try:
            if "error" in chunk_json and chunk_json["error"]:
                yield chunk_json["error"], True
                break
            # A delta without "content" (e.g. a final chunk) is not an error.
            content = chunk_json["choices"][0]["delta"].get("content")
            if content is not None:
                yield content, False
        except Exception as e:
            print(f"func: generate error occurred\njson:{chunk_json}\nerror:{e}")
            yield "", True
def get_majority_text(data):
    """Return the generated text whose model answer is the most common one.

    ``data["model_answers"]`` is indexed in parallel with ``data["gen_texts"]``.
    On a tie the answer encountered first wins, and the text at that answer's
    first occurrence is returned.
    """
    from collections import Counter

    answers = data["model_answers"]
    winner = Counter(answers).most_common(1)[0][0]
    first_position = answers.index(winner)
    return data["gen_texts"][first_position]
def extract_solution(text):
    """Return the stripped text after the first "### Solution:" marker, or "" if absent."""
    _, marker, remainder = text.partition("### Solution:")
    return remainder.strip() if marker else ""
def process_code(
    example: Dict[str, Any],
    config: Config,
    restart_on_fail: bool = False,
    last_step: bool = False,
) -> Dict[str, Any]:
    """Advance one generation round: record a final answer or run the latest code.

    ``example["gen_texts"]`` holds the text generated so far. Depending on
    its tail, this function:

    * prunes the sample (or, with *restart_on_fail*, resets ``gen_texts`` to
      ``example["text"]``) when no ```python block has been written;
    * stores ``model_answers`` and prunes when the text ends with a final
      answer ("answer is" or ``\\boxed``) after at least one ```output block;
    * otherwise, if the text ends with an open ```output block, executes the
      last ```python block and appends its output, truncated to 200 chars.

    The dict is modified in place and also returned.
    """
    gen_text = example["gen_texts"]
    num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL))
    if num_python_blocks == 0:
        if restart_on_fail:
            print("no code has ever been generated, RESTARTING")
            # reset the text to the original
            example["gen_texts"] = example["text"]
        else:
            print("no code has ever been generated, STOP")
            example["should_prune"] = True
            example["has_code"] = False
        return example

    tail = gen_text[-100:]
    if gen_text[-10:] != "```output\n" and ("answer is" in tail or "\\boxed" in tail):
        num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
        if num_output_blocks == 0:
            print("the model hallucinated the code answer")
            example["should_prune"] = True
            return example

        if "boxed" in tail:
            try:
                # Search the whole text: extract_boxed_answer takes the last
                # \boxed, and a 100-char window could cut off its "\boxed{".
                boxed = extract_boxed_answer(gen_text)
                answer = "-1" if boxed is None else normalize_final_answer(boxed)
            except Exception:
                answer = "-1"
        else:
            answer = normalize_final_answer(tail)

        example["model_answers"] = answer
        if not config.is_submission:
            example["corrects"] = check_answer(example["ground_truth"], answer)
        example["should_prune"] = True
        print("Answer is: ", answer, example["ground_truth"], example["corrects"])
        return example

    if last_step:
        # no point in continuing if we are at the last step
        return example

    if gen_text[-10:] != "```output\n":
        # something else has gone wrong with the generation
        print("warning: output block not found: ", gen_text[-40:])
        if restart_on_fail:
            example["gen_texts"] = example["text"]
        else:
            example["should_prune"] = True
        return example

    code_result, _status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
    # add the code result for the next round of generation
    TRUNCATION_LIMIT = 200
    if len(code_result) > TRUNCATION_LIMIT:
        code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
    example["gen_texts"] = gen_text + f"{code_result}\n```"
    return example
def solve_problem(problem, temperature, progress=gr.Progress()):
    """Stream a tool-integrated solution for *problem* to the UI.

    Runs up to ``config.num_generations`` rounds. Each round streams model
    text from generate() and then lets process_code() either record a final
    answer or execute the latest code block. After every streamed piece, the
    full transcript so far is yielded so the Markdown output updates live.
    """
    problem = apply_template({"prompt": problem}, prompt=config.system_prompt)
    print(f"Problem: {problem}")

    sample = {
        "problem": problem,  # not used for the submission TODO Remove
        "ground_truth": "unknown",  # not used for the submission TODO Remove
        "text": "## Solution:\n",
        "gen_texts": "## Solution:\n",  # used to store all the generated text
        "should_prune": False,
        "problem_index": -1,  # not used for the submission TODO Remove
        "model_answers": "-1",
        "has_code": True,
        "corrects": False,  # not used for the submission TODO Remove
    }

    for step in progress.tqdm(
        range(config.num_generations), desc="Generating candidates"
    ):  # Depth of the tree (e.g. 6 steps = 5 code blocks)
        step_response = sample["gen_texts"]

        # The partial solution is sent as an assistant turn for the model to continue.
        messages = [
            {"role": "user", "content": sample["problem"]},
            {"role": "assistant", "content": sample["gen_texts"]},
        ]
        for response_message, error in generate(messages, temperature):
            if response_message is not None:
                step_response += response_message
                yield step_response
            if error:
                return

        sample["gen_texts"] = step_response

        # TODO: Maybe it should just return the result of running the code
        sample = process_code(
            sample,
            config=config,
            restart_on_fail=config.restart_on_fail,
            last_step=(step == (config.num_generations - 1)),
        )
        sample["gen_texts"] = sample["gen_texts"] + "\n"

        new_text = sample["gen_texts"]
        if new_text.startswith(step_response):
            # Stream only what process_code appended (the code output),
            # character by character.
            for char in new_text[len(step_response) :]:
                step_response += char
                yield step_response
        else:
            # process_code reset the transcript (restart_on_fail); show it as is.
            yield new_text

        if sample["should_prune"]:
            break

    yield sample["gen_texts"]
# Pool of sample problems shown in the "Problem example" panel. The optional
# HF_DATASET_TOKEN environment variable is passed for authenticated access.
example_data = datasets.load_dataset(
    "AI-MO/kaggle-validation-set-medium-extended",
    split="train",
    # `token` replaces the deprecated `use_auth_token`, which newer `datasets`
    # releases no longer accept.
    token=os.environ.get("HF_DATASET_TOKEN", None),
)
# Custom stylesheet applied to the Blocks layout.
with open("app.css", "r") as css_file:
    css = css_file.read()

# Render text between "[" and "]" as display math in the Markdown components.
latex_delimiters = [
    {"left": "[", "right": "]", "display": True},
]
def get_random_problem():
    """Return the ``problem`` text of a uniformly random row of example_data."""
    # random.choice only needs len() and integer indexing, both supported by
    # the dataset, so there is no need to copy every row into a list per call.
    example = random.choice(example_data)
    return example["problem"]
def update_example_problem():
    """Pick a random example problem.

    Returns:
        (preview, full_text): the first 100 characters followed by "..."
        when the problem is longer than 100 characters (otherwise the full
        text), and the untruncated problem text.
    """
    full_text = get_random_problem()
    if len(full_text) > 100:
        preview = full_text[:100] + "..."
    else:
        preview = full_text
    return preview, full_text
def clear():
    """Reset the input, temperature and solution, and show a new example problem.

    Returns values for (inp, temperature, out, problem_example,
    problem_example_text_hidden), in that order: the visible example gets the
    truncated preview and the hidden component gets the full text used by
    the Copy button.
    """
    # update_example_problem() returns (preview, full_text); the names were
    # previously swapped, so the panels received each other's text.
    problem_example_text_display, problem_example_text = update_example_problem()
    return "", 0.1, "", problem_example_text_display, problem_example_text
with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
    # update_example_problem() returns (preview, full_text).
    problem_example_text_display, problem_example_text = update_example_problem()

    with gr.Row(elem_classes="title"):
        gr.HTML("Math Olympiad Solver", elem_classes="title-content")
    with gr.Row(elem_classes="sub-title"):
        gr.HTML("Here may need to add some description string", elem_classes="sub-title-content")

    with gr.Row(elem_classes="main-area"):
        with gr.Column(scale=1, elem_classes="left"):
            with gr.Row(elem_classes="probelm-example-container"):
                with gr.Blocks(elem_classes="probelm-example-title"):
                    gr.HTML("Problem example", elem_classes="probelm-example-title-content")
                with gr.Blocks(elem_classes="action-container"):
                    # Reset icon: shows another random example (wired below).
                    another_btn = gr.Button(
                        "",
                        elem_classes="probelm-example-another",
                        icon="./static/images/reset.png",
                    )
                    copy_btn = gr.Button("Copy", elem_classes="probelm-example-copy")
            with gr.Row(elem_classes="copy-icon-container"):
                problem_example = gr.Markdown(
                    problem_example_text_display,
                    latex_delimiters=latex_delimiters,
                    elem_classes="probelm-example-content",
                )
            inp = gr.Textbox(placeholder="Problem", label="Problem input", lines=5)
            with gr.Accordion("Advanced Options", open=False):
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature")
            with gr.Row():
                btn_run = gr.Button("Run")
                btn_clear = gr.Button("Clear")

        with gr.Column(scale=1, elem_classes="right"):
            gr.HTML("Solution", elem_classes="solution-title-content")
            out = gr.Markdown(elem_classes="solution-content", latex_delimiters=latex_delimiters)

    # Hidden holder of the full (untruncated) example text for the Copy button.
    problem_example_text_hidden = gr.Markdown(value=problem_example_text, visible=False)

    btn_run.click(fn=solve_problem, inputs=[inp, temperature], outputs=out)
    copy_btn.click(fn=lambda example: example, inputs=[problem_example_text_hidden], outputs=[inp])
    another_btn.click(
        fn=update_example_problem,
        inputs=None,
        outputs=[
            problem_example,
            problem_example_text_hidden,
        ],
    )
    btn_clear.click(
        fn=clear,
        inputs=[],
        outputs=[
            inp,
            temperature,
            out,
            problem_example,
            problem_example_text_hidden,
        ],
    )
    # Pick a fresh example on every page load.
    demo.load(
        update_example_problem,
        inputs=None,
        outputs=[
            problem_example,
            problem_example_text_hidden,
        ],
    )
if __name__ == "__main__":
    # Enable the request queue; default_concurrency_limit=5 lets up to five
    # runs of an event listener execute at the same time.
    queued_demo = demo.queue(default_concurrency_limit=5)
    queued_demo.launch()