#!/usr/bin/env python3
"""
what_comes_next.py – Hugging Face Space implementation of **What Comes Next**
A slow, contemplative global guessing game.

🔮 HOW IT WORKS 🔮
• A single Llama-3.1-8B-Instruct model (FP32 on CPU) generates one very long
  completion for a chosen mystical prompt. It runs continuously in the
  background for everyone.
• Any visitor sees the same prompt and the Oracle's current partial response.
• Players may submit *one* of two kinds of guesses:
    1. 🧠 **Exact Completion** – the full sentence/paragraph they think the
       Oracle will eventually write.
    2. 💡 **General Idea** – a short summary of the direction or theme they
       expect.
• Each guess is recorded immediately (with timestamp, Oracle progress, etc.)
  to `data.json` (JSON-Lines). When the Oracle finally finishes, offline
  evaluation can score the guesses against the final text.

The game then moves on to the next prompt and the cycle repeats.
"""

import os
import json
import time
import random
import threading
import logging
from datetime import datetime, timezone
from typing import Dict, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import gradio as gr

###############################################################################
#                                  Settings                                   #
###############################################################################
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"  # FP32, CPU-only
PROMPTS_PATH = "oracle_prompts.json"             # 100 unfinished lines
STATE_PATH = "current_state.json"                # persistent Oracle state
DATA_PATH = "data.json"                          # JSONL of user guesses

TOKENS_PER_PROMPT = 2048     # stop after N tokens
SECS_BETWEEN_TOKENS = 15     # pacing (~8.5 h of sleep per prompt; slow CPU steps add more)
TEMPERATURE = 0.8
TOP_P = 0.95
MAX_CONTEXT_TOKENS = 8192
###############################################################################

logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s",
                    level=logging.INFO)
log = logging.getLogger("what-comes-next")

lock = threading.Lock()  # global file/variable lock

# --------------------------------------------------------------------------- #
#                              Helper functions                               #
# --------------------------------------------------------------------------- #

def _read_json(path: str, default: Any):
    """Read a JSON file, returning `default` if it does not exist."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return default

def _write_json(path: str, obj: Any):
    """Write JSON atomically: dump to a temp file, then rename over the target."""
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)

def load_prompts() -> list[str]:
    if not os.path.exists(PROMPTS_PATH):
        raise FileNotFoundError(
            f"Missing {PROMPTS_PATH}. Please add 100 prompts."
        )
    with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
        return json.load(f)

prompts = load_prompts()
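# `load_prompts()` expects PROMPTS_PATH to hold a plain JSON array of strings
# (one unfinished line per entry). A minimal illustrative layout – these
# example lines are invented, not the actual shipped prompts:
#
#   [
#     "The last lighthouse keeper opened the final logbook and wrote:",
#     "When the river finally remembered its name, it told the valley that"
#   ]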
# --------------------------------------------------------------------------- #
#                          Model loading (FP32 - CPU)                         #
# --------------------------------------------------------------------------- #
log.info("Loading Llama-3.1-8B-Instruct in FP32 on CPU (this is *slow*) …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.truncation_side = "left"  # if context ever overflows, keep the newest text
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map={"": "cpu"},  # force CPU
)
model.eval()
log.info("Model loaded.")

# --------------------------------------------------------------------------- #
#                          Oracle generation thread                           #
# --------------------------------------------------------------------------- #

def init_state() -> Dict[str, Any]:
    """Return the existing state, or create a fresh one for a new prompt."""
    state = _read_json(STATE_PATH, {})
    if state.get("finished", False):
        state = {}  # previous prompt finished – start a new one
    if not state:
        prompt_idx = random.randrange(len(prompts))
        prompt = prompts[prompt_idx]
        state = {
            "prompt_idx": prompt_idx,
            "prompt": prompt,
            "generated": "",  # Oracle's text so far (string)
            "start_time": time.time(),
            "finished": False,
            "tokens_done": 0,
        }
        _write_json(STATE_PATH, state)
        log.info(f"Starting new Oracle prompt #{prompt_idx}: {prompt[:60]}…")
    return state

def oracle_loop():
    """Continuously extend the Oracle's text by one token every SECS_BETWEEN_TOKENS."""
    while True:
        with lock:
            state = init_state()
            finished = state["finished"]
            prompt_text = state["prompt"]
            generated_text = state["generated"]
        if finished:
            # Should not happen (init_state resets finished prompts), but guard
            # anyway – and sleep *outside* the lock so readers are not blocked.
            time.sleep(SECS_BETWEEN_TOKENS)
            continue

        # Build input_ids (prompt + generated so far). Note this re-tokenizes
        # the full text on every step instead of reusing a KV cache, which is
        # simple but adds a full forward pass over the whole context per token.
        full_input = prompt_text + generated_text
        input_ids = tokenizer(full_input,
                              return_tensors="pt",
                              truncation=True,
                              max_length=MAX_CONTEXT_TOKENS).input_ids

        # Generate ONE token
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                pad_token_id=tokenizer.eos_token_id,  # silence the pad-token warning
            )
        next_token_id = outputs[0, -1].unsqueeze(0)
        next_token_text = tokenizer.decode(next_token_id,
                                           skip_special_tokens=True,
                                           clean_up_tokenization_spaces=False)

        with lock:
            # Update state. Finish on the token budget, or early if the model
            # emitted EOS (otherwise we would keep appending empty strings).
            state["generated"] += next_token_text
            state["tokens_done"] += 1
            if (state["tokens_done"] >= TOKENS_PER_PROMPT
                    or next_token_id.item() == tokenizer.eos_token_id):
                state["finished"] = True
                log.info("Prompt complete. Oracle will pick a new one next cycle.")
            _write_json(STATE_PATH, state)

        time.sleep(SECS_BETWEEN_TOKENS)  # pacing
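# For reference, a `current_state.json` snapshot mid-run looks roughly like
# this (all values below are invented for illustration):
#
#   {
#     "prompt_idx": 17,
#     "prompt": "When the river finally remembered its name, it told the valley that",
#     "generated": " the mountains had been listening",
#     "start_time": 1714070400.0,
#     "finished": false,
#     "tokens_done": 6
#   }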
threading.Thread(target=oracle_loop, daemon=True).start()

# --------------------------------------------------------------------------- #
#                              Gradio Interface                               #
# --------------------------------------------------------------------------- #

def human_readable_elapsed(start: float) -> str:
    delta = int(time.time() - start)
    h, rem = divmod(delta, 3600)
    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"

def get_current_state() -> Dict[str, Any]:
    with lock:
        state = _read_json(STATE_PATH, {})
    if not state:
        return {"prompt": "…loading…", "generated": "", "elapsed": "0h 0m 0s"}
    return {
        "prompt": state["prompt"],
        "generated": state["generated"],
        "elapsed": human_readable_elapsed(state["start_time"]),
    }

def record_guess(full_guess: str, idea_guess: str):
    state = get_current_state()
    guess_text = full_guess.strip() or idea_guess.strip()
    if not guess_text:
        return (gr.update(value="⚠️ Please enter a guess in one of the boxes …"),
                gr.update())
    guess_type = "full" if full_guess.strip() else "idea"

    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": state["prompt"],
        "point-in-time": state["elapsed"],
        "response-point": state["generated"],
        "user-guess": guess_text,
        "guess-type": guess_type,
    }

    # Append to JSONL (data.json)
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    log.info(f"Recorded {guess_type} guess ({len(guess_text)} chars).")
    return (gr.update(value="✅ Guess recorded – check back when the Oracle finishes!"),
            gr.update(value=""))

with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
    gr.Markdown("""# ✨ What Comes Next
A global, slow-burn guessing game.

The Oracle is continuously writing its story. Read the prompt, see the
Oracle's progress, and predict **what comes next**!

*(FP32 CPU inference – deliberately unhurried.)*""")

    ### Live Oracle view
    prompt_box = gr.Markdown(label="🔮 Current Oracle Prompt")
    oracle_box = gr.Textbox(label="📜 Oracle's current text", lines=10,
                            interactive=False)
    elapsed_box = gr.Textbox(label="⏱️ Elapsed", interactive=False)

    ### Guess inputs
    gr.Markdown("**Make your prediction:** fill in **either** the exact "
                "continuation *or* a general idea.")
    with gr.Row():
        full_guess = gr.Textbox(label="🧠 Exact continuation (full)")
        idea_guess = gr.Textbox(label="💡 General idea")
    submit_btn = gr.Button("Submit Guess")
    status_msg = gr.Textbox(label="Status", interactive=False)

    ### Refresh button
    refresh_btn = gr.Button("🔄 Refresh Oracle progress")

    def refresh():
        st = get_current_state()
        return st["prompt"], st["generated"], st["elapsed"]

    refresh_btn.click(refresh, outputs=[prompt_box, oracle_box, elapsed_box])
    demo.load(refresh, outputs=[prompt_box, oracle_box, elapsed_box])  # auto-load on launch

    submit_btn.click(record_guess,
                     inputs=[full_guess, idea_guess],
                     outputs=[status_msg, full_guess])  # clear full_guess box on success

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
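# --------------------------------------------------------------------------- #
#                      Offline evaluation (sketch only)                       #
# --------------------------------------------------------------------------- #
# The docstring above says finished prompts can be scored offline against the
# guesses in data.json, but no scorer ships with this file. The helper below
# is a minimal illustrative sketch, not part of the Space's runtime:
# `score_guesses` is a hypothetical name, and stdlib difflib similarity is a
# stand-in metric (a real evaluation could use embeddings, BLEU, or human
# judging). It is meant to be imported and run in a separate offline script.

def score_guesses(final_text: str, data_path: str = DATA_PATH):
    """Yield (guess_type, guess, similarity) for every recorded guess.

    Similarity is difflib's ratio in [0, 1]. If data.json mixes guesses from
    several prompts, filter on rec["prompt"] before scoring.
    """
    import difflib
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            ratio = difflib.SequenceMatcher(
                None, rec["user-guess"], final_text).ratio()
            yield rec["guess-type"], rec["user-guess"], ratio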