# demo/examples/code_editor_scripts.py
# Web-view residue (not part of the program):
#   huangzhii | update | 32486dc | raw | history blame | 5.77 kB
import streamlit as st
from streamlit_elements import elements, mui, editor, dashboard
from stqdm import stqdm
import textgrad as tg
import os
class CodeEditor:
    """Streamlit page that iteratively refines a code snippet with textgrad.

    The page lets the user edit a problem description, a loss system prompt,
    an instruction and an initial solution (``load_layout``).  Each call to
    ``show_results`` runs one optimisation step (``_run``) and renders the
    evaluation text and gradients of every step so far, one tab per step.

    State that has to survive Streamlit reruns is kept in
    ``st.session_state``: ``iteration`` (int), ``results`` (list of dicts)
    and ``code_content`` (the current solution text).
    """

    def __init__(self, data) -> None:
        """Store the defaults, set up the LLM engine and the session state.

        Args:
            data: Mapping read by ``load_layout``; it is indexed with the
                keys ``"default_problem_description"``,
                ``"default_loss_system_prompt"``, ``"instruction"`` and
                ``"default_initial_solution"``.
        """
        self.data = data
        self.llm_engine = tg.get_engine("gpt-4o")
        print("="*50, "init", "="*50)
        self.loss_value = ""
        # Never updated by this class; kept so existing readers of the
        # attribute do not break.  The live gradients are in ``self.gradients``.
        self.code_gradients = ""
        # Filled by ``_run``; initialised here so they exist before the
        # first optimisation step.
        self.gradients = []
        self.graph = None
        if 'iteration' not in st.session_state:
            st.session_state.iteration = 0
        if 'results' not in st.session_state:
            st.session_state.results = []
        tg.set_backward_engine(self.llm_engine, override=True)

    def load_layout(self):
        """Render the input widgets and the two code editors.

        Sets ``self.problem``, ``self.loss_system_prompt`` and
        ``self.instruction`` from the text areas; ``_run`` reads all three,
        so this must be called before ``show_results``.
        """
        col1, col2 = st.columns([1, 1])
        with col1:
            self.problem = st.text_area("Problem description:", self.data["default_problem_description"], height=300)
        with col2:
            self.loss_system_prompt = st.text_area("Loss system prompt:", self.data["default_loss_system_prompt"], height=150)
            self.instruction = st.text_area("Instruction for formatted LLM call:", self.data["instruction"], height=100)

        # Seed the solution only once per session so that reruns keep the
        # edited / optimised code.
        if "code_content" not in st.session_state:
            st.session_state.code_content = self.data["default_initial_solution"]

        def update_code_content(value):
            # onChange callback of the editable editor.
            st.session_state.code_content = value

        col1, col2 = st.columns(2)
        with col1:
            with elements("monaco_editors_1"):
                mui.Typography("Initial Solution:", sx={"fontSize": "20px", "fontWeight": "bold"})
                # Editable editor: seeded through defaultValue, edits are
                # written back to the session state through onChange.
                editor.Monaco(
                    height=300,
                    defaultLanguage="python",
                    defaultValue=st.session_state.code_content,
                    onChange=update_code_content
                )
        with col2:
            with elements("monaco_editors_2"):
                mui.Typography("Current Solution:", sx={"fontSize": "20px", "fontWeight": "bold"})
                editor.Monaco(
                    height=300,
                    defaultLanguage="python",
                    value=st.session_state.code_content,
                    options={"readOnly": True}  # Make the editor read-only
                )

        # Disabled diff view, kept for reference:
        # format_string = f"{instruction}\nProblem: {problem}\nCurrent Code: {st.session_state.code_content}"
        # mui.Typography(format_string)
        # mui.Typography("Final Snippet vs. Current Solution:", sx={"fontSize": "20px", "fontWeight": "bold"})
        # editor.MonacoDiff(
        #     original=self.data["default_target_solution"],
        #     modified=st.session_state.code_content,
        #     height=300,
        # )

    @staticmethod
    def _build_format_string(instruction: str) -> str:
        """Return the loss template with ``{problem}`` and ``{code}`` fields.

        The instruction is user-typed text, and the returned template is
        formatted again later to fill in the two fields.  Literal braces in
        the instruction are therefore doubled so that the later formatting
        pass reproduces them verbatim instead of treating them as fields.
        """
        escaped = instruction.replace("{", "{{").replace("}", "}}")
        return escaped + "\nProblem: {problem}\nCurrent Code: {code}"

    def _run(self):
        """Run one textgrad optimisation step on the current solution.

        Reads ``st.session_state.code_content`` plus the values set by
        ``load_layout``; writes ``self.loss_value``, ``self.graph``,
        ``self.gradients`` and stores the updated code back into
        ``st.session_state.code_content``.
        """
        # Code is the variable of interest we want to optimize -- so requires_grad=True
        code = tg.Variable(value=st.session_state.code_content,
                           requires_grad=True,
                           role_description="code instance to optimize")

        # We are not interested in optimizing the problem -- so requires_grad=False
        problem = tg.Variable(self.problem,
                              requires_grad=False,
                              role_description="the coding problem")

        # Let TGD know to update code!
        optimizer = tg.TGD(parameters=[code])

        system_prompt = tg.Variable(self.loss_system_prompt,
                                    requires_grad=False,
                                    role_description="system prompt to the loss function")

        formatted_llm_call = tg.autograd.FormattedLLMCall(
            engine=self.llm_engine,
            format_string=self._build_format_string(self.instruction),
            fields={"problem": None, "code": None},
            system_prompt=system_prompt)

        # Finally, the loss function
        def loss_fn(problem: tg.Variable, code: tg.Variable) -> tg.Variable:
            inputs = {"problem": problem, "code": code}
            return formatted_llm_call(inputs=inputs,
                                      response_role_description=f"evaluation of the {code.get_role_description()}")

        loss = loss_fn(problem, code)
        self.loss_value = loss.value
        self.graph = loss.generate_graph()

        loss.backward()
        self.gradients = code.gradients

        optimizer.step()  # Let's update the code
        st.session_state.code_content = code.value

    def show_results(self):
        """Run one optimisation step and render every recorded step.

        Appends a dict with the keys ``'iteration'``, ``'loss_value'`` and
        ``'gradients'`` to ``st.session_state.results`` and then draws one
        tab per recorded result.
        """
        self._run()
        st.session_state.iteration += 1
        st.session_state.results.append({
            'iteration': st.session_state.iteration,
            'loss_value': self.loss_value,
            'gradients': self.gradients
        })

        # Build the tabs from the results list itself, so the tab count can
        # never disagree with the number of stored results.
        results = st.session_state.results
        tabs = st.tabs([f"Iteration {i+1}" for i in range(len(results))])
        for tab, result in zip(tabs, results):
            with tab:
                st.markdown(f"Current iteration: **{result['iteration']}**")
                col1, col2 = st.columns([1, 1])
                with col1:
                    st.markdown("## Loss value")
                    st.markdown(result['loss_value'])
                with col2:
                    st.markdown("## Code gradients")
                    for j, g in enumerate(result['gradients']):
                        st.markdown(f"### Gradient {j}")
                        st.markdown(g.value)