# NOTE: "Spaces: Runtime error" was page chrome scraped from the Hugging Face
# Spaces listing, not part of the program; kept here only as a comment.
import gradio as gr
from transformers import pipeline
import torch

# Load the physician and patient models via the Hugging Face Model Hub.
# NOTE(review): both roles currently share the same general-purpose SmolLM2
# model; swap the physician pipeline for a medical-domain model when available.
physician = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-1.7B")
patient = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-1.7B")
def generate_conversation(topic, turns):
    """Simulate a physician-patient role-play dialogue about *topic*.

    Args:
        topic: Disease or complaint the patient raises (e.g. "chest pain").
        turns: Number of physician/patient exchange rounds. Gradio's Number
            component delivers a float, so it is coerced to int here — the
            original ``range(turns)`` raised TypeError on float input.

    Returns:
        tuple: ``(conversation, summary)`` where *conversation* is a list of
        ``{"role", "message", "tokens"}`` dicts in chronological order and
        *summary* holds whitespace-token counts per role and in total.
    """
    conversation = []
    physician_tokens = 0
    patient_tokens = 0

    def _record(role, message):
        # Append one utterance; token counts are approximated by
        # whitespace splitting, matching the summary's accounting.
        nonlocal physician_tokens, patient_tokens
        n_tokens = len(message.split())
        if role == "physician":
            physician_tokens += n_tokens
        else:
            patient_tokens += n_tokens
        conversation.append({"role": role, "message": message, "tokens": n_tokens})

    # Initial prompt for the patient.
    patient_prompt = f"I'm here to talk about {topic}."
    patient_response = patient(
        patient_prompt, max_length=50, num_return_sequences=1
    )[0]['generated_text']
    _record("patient", patient_response)

    # Coerce to int: gr.Number hands the callback a float by default.
    for _ in range(int(turns)):
        # Physician's turn.
        physician_prompt = f"As a physician, how would you respond to: {patient_response}"
        physician_response = physician(
            physician_prompt, max_length=50, num_return_sequences=1
        )[0]['generated_text']
        _record("physician", physician_response)

        # Patient's turn.
        patient_prompt = f"As a patient, how would you respond to: {physician_response}"
        patient_response = patient(
            patient_prompt, max_length=50, num_return_sequences=1
        )[0]['generated_text']
        _record("patient", patient_response)

    # Summarize the conversation.
    summary = {
        "total_tokens": physician_tokens + patient_tokens,
        "physician_tokens": physician_tokens,
        "patient_tokens": patient_tokens,
    }
    return conversation, summary
def app_interface(topic, turns):
    """Run one role-play session and package it for the JSON output widget.

    Returns a dict echoing the inputs alongside the generated conversation
    and its token-count summary.
    """
    conversation, summary = generate_conversation(topic, turns)
    return {
        "input": {"topic": topic, "turns": turns},
        "conversation": conversation,
        "summary": summary,
    }
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 👨‍⚕️ Synthetic Data Generation: Physician-Patient Role-Play 🤖")
    with gr.Row():
        topic_input = gr.Textbox(label="Enter Disease/Topic", placeholder="e.g., chest pain")
        # precision=0 makes the component hand the callback an int, not a float.
        turns_input = gr.Number(label="Number of Turns", value=5, precision=0)
    submit_button = gr.Button("🚀 Start Interaction")
    output_json = gr.JSON(label="Generated Conversation")

    def _save_conversation(data):
        """Write the displayed JSON to a temp file so Gradio can serve it."""
        import json
        import os
        import tempfile

        if data is None:
            return None
        path = os.path.join(tempfile.mkdtemp(), "conversation.json")
        with open(path, "w", encoding="utf-8") as fh:
            json.dump(data, fh, ensure_ascii=False, indent=2)
        return path

    # Download button for the conversation.
    # gr.File.download() does not exist (the original crashed on click);
    # the idiomatic pattern is to return a file path from the callback and
    # let a gr.File output component serve it.
    download_button = gr.Button("📥 Download Conversation")
    download_button.click(
        fn=_save_conversation,
        inputs=output_json,
        outputs=gr.File(label="Conversation File"),
    )
    submit_button.click(
        fn=app_interface,
        inputs=[topic_input, turns_input],
        outputs=output_json,
    )
demo.launch()