import json
import os
import tempfile

import gradio as gr
import torch
from transformers import pipeline

# Load the instruct version of the model once and share the pipeline between
# both roles (loading the same 1.7B model twice would double memory use)
physician = pipeline(
    "text-generation",
    model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
    device=0 if torch.cuda.is_available() else -1,
)
patient = physician

# System prompts to define roles; they are prepended to every input prompt
patient_system_prompt = "You are a patient describing your symptoms to a physician."
physician_system_prompt = "You are a physician responding to a patient's symptoms."
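# Note: SmolLM2-Instruct is a chat model, so an alternative (a sketch, not
# used below) is the pipeline's chat-message format, which applies the
# model's chat template automatically:
#     messages = [{"role": "system", "content": patient_system_prompt},
#                 {"role": "user", "content": "I'm here to talk about chest pain."}]
#     reply = patient(messages, max_new_tokens=50)[0]["generated_text"][-1]["content"]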
def generate_conversation(topic, turns):
    conversation = []
    physician_tokens = 0
    patient_tokens = 0

    # Initial prompt for the patient, with the system prompt prepended
    patient_input = f"{patient_system_prompt}\nPatient: I'm here to talk about {topic}."
    print(f"Patient Initial Input: {patient_input}")  # Debugging
    patient_response = patient(
        patient_input,
        max_new_tokens=50,       # Allow the model to generate up to 50 new tokens
        num_return_sequences=1,
        truncation=True,         # Explicitly enable truncation
        do_sample=True,          # Enable sampling
        temperature=0.7,         # Control randomness
        return_full_text=False,  # Return only the continuation, not the prompt echo
    )[0]["generated_text"]
    print(f"Patient Response: {patient_response}")  # Debugging
    # "Tokens" are approximated here as whitespace-separated words
    patient_tokens += len(patient_response.split())
    conversation.append({"role": "patient", "message": patient_response, "tokens": len(patient_response.split())})
    for turn in range(int(turns)):  # cast defensively; gr.Number can pass a float
        # Physician's turn: reply to the patient's last message
        print(f"Physician Turn {turn} Prompt: {patient_response}")  # Debugging
        physician_response = physician(
            f"{physician_system_prompt}\nPatient: {patient_response}\nPhysician:",
            max_new_tokens=50,       # Allow the model to generate up to 50 new tokens
            num_return_sequences=1,
            truncation=True,         # Explicitly enable truncation
            do_sample=True,          # Enable sampling
            temperature=0.7,         # Control randomness
            return_full_text=False,  # Return only the continuation, not the prompt echo
        )[0]["generated_text"]
        print(f"Physician Response: {physician_response}")  # Debugging
        physician_tokens += len(physician_response.split())
        conversation.append({"role": "physician", "message": physician_response, "tokens": len(physician_response.split())})

        # Patient's turn: reply to the physician's last message
        print(f"Patient Turn {turn} Prompt: {physician_response}")  # Debugging
        patient_response = patient(
            f"{patient_system_prompt}\nPhysician: {physician_response}\nPatient:",
            max_new_tokens=50,
            num_return_sequences=1,
            truncation=True,
            do_sample=True,
            temperature=0.7,
            return_full_text=False,
        )[0]["generated_text"]
        print(f"Patient Response: {patient_response}")  # Debugging
        patient_tokens += len(patient_response.split())
        conversation.append({"role": "patient", "message": patient_response, "tokens": len(patient_response.split())})
    # Summarize the conversation
    summary = {
        "total_tokens": physician_tokens + patient_tokens,
        "physician_tokens": physician_tokens,
        "patient_tokens": patient_tokens,
    }
    return conversation, summary


def app_interface(topic, turns):
    conversation, summary = generate_conversation(topic, turns)
    output = {
        "input": {"topic": topic, "turns": turns},
        "conversation": conversation,
        "summary": summary,
    }
    return output
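# Illustrative shape of the JSON returned for topic="chest pain", turns=1
# (messages and counts vary with sampling):
# {
#   "input": {"topic": "chest pain", "turns": 1},
#   "conversation": [
#     {"role": "patient", "message": "...", "tokens": 14},
#     {"role": "physician", "message": "...", "tokens": 21},
#     {"role": "patient", "message": "...", "tokens": 17}
#   ],
#   "summary": {"total_tokens": 52, "physician_tokens": 21, "patient_tokens": 31}
# }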
# Helper for the download button: gr.File has no download() method, so the
# conversation JSON is written to a temporary file and its path is returned
def save_conversation(data):
    path = os.path.join(tempfile.gettempdir(), "conversation.json")
    with open(path, "w") as f:
        json.dump(data, f, indent=2)
    return path


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 👨‍⚕️ Synthetic Data Generation: Physician-Patient Role-Play 🤖")
    with gr.Row():
        topic_input = gr.Textbox(label="Enter Disease/Topic", placeholder="e.g., chest pain")
        turns_input = gr.Number(label="Number of Turns", value=1, precision=0)  # Default to 1 turn for debugging
    submit_button = gr.Button("🚀 Start Interaction")
    output_json = gr.JSON(label="Generated Conversation")

    # Download button for the conversation
    download_button = gr.Button("📥 Download Conversation")
    download_file = gr.File(label="Conversation JSON")
    download_button.click(
        fn=save_conversation,
        inputs=output_json,
        outputs=download_file,
    )

    submit_button.click(
        fn=app_interface,
        inputs=[topic_input, turns_input],
        outputs=output_json,
    )

demo.launch()
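# To run locally (assuming the usual dependencies for a Gradio + transformers app):
#     pip install gradio transformers torch
#     python app.py
# Generation on CPU will be slow for a 1.7B-parameter model; GPU hardware is
# recommended when hosting this as a Space.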