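"""Synthetic data generation demo: physician-patient role-play in Gradio.

Two text-generation pipelines alternate turns on a user-supplied topic, and
the app returns the dialogue as JSON alongside rough per-role token counts.
"""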
import json
import tempfile

import gradio as gr
from transformers import pipeline

# Load the physician and patient pipelines from the Hugging Face Model Hub.
# Both roles share one general checkpoint here; swap the physician pipeline
# for a medically fine-tuned model when one is available.
physician = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-1.7B")
patient = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-1.7B")
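# Optional tweak, assuming a CUDA-capable machine: pin the pipelines to a GPU
# via the standard pipeline() kwargs, e.g.
#   physician = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-1.7B",
#                        device=0, torch_dtype="auto")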

def generate_conversation(topic, turns):
    """Alternate physician and patient turns, tallying rough token counts."""
    conversation = []
    physician_tokens = 0
    patient_tokens = 0

    # Initial prompt for the patient. return_full_text=False strips the prompt
    # from the pipeline output, and max_new_tokens bounds the reply itself
    # rather than prompt + reply (which is what max_length counts).
    patient_prompt = f"I'm here to talk about {topic}."
    patient_response = patient(
        patient_prompt, max_new_tokens=50, num_return_sequences=1, return_full_text=False
    )[0]["generated_text"].strip()
    # Whitespace word counts are a cheap approximation of tokenizer counts.
    patient_tokens += len(patient_response.split())
    conversation.append({"role": "patient", "message": patient_response, "tokens": len(patient_response.split())})

    for _ in range(int(turns)):
        # Physician's turn
        physician_prompt = f"As a physician, how would you respond to: {patient_response}"
        physician_response = physician(
            physician_prompt, max_new_tokens=50, num_return_sequences=1, return_full_text=False
        )[0]["generated_text"].strip()
        physician_tokens += len(physician_response.split())
        conversation.append({"role": "physician", "message": physician_response, "tokens": len(physician_response.split())})

        # Patient's turn
        patient_prompt = f"As a patient, how would you respond to: {physician_response}"
        patient_response = patient(
            patient_prompt, max_new_tokens=50, num_return_sequences=1, return_full_text=False
        )[0]["generated_text"].strip()
        patient_tokens += len(patient_response.split())
        conversation.append({"role": "patient", "message": patient_response, "tokens": len(patient_response.split())})

    # Summarize the conversation
    summary = {
        "total_tokens": physician_tokens + patient_tokens,
        "physician_tokens": physician_tokens,
        "patient_tokens": patient_tokens,
    }

    return conversation, summary
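
# Quick smoke test without the UI (illustrative; triggers a model download):
#   convo, stats = generate_conversation("chest pain", 1)
#   print(stats)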

def app_interface(topic, turns):
    conversation, summary = generate_conversation(topic, turns)
    output = {
        "input": {"topic": topic, "turns": turns},
        "conversation": conversation,
        "summary": summary
    }
    return output

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 👨‍⚕️ Synthetic Data Generation: Physician-Patient Role-Play 👀")
    with gr.Row():
        topic_input = gr.Textbox(label="Enter Disease/Topic", placeholder="e.g., chest pain")
        turns_input = gr.Number(label="Number of Turns", value=5, precision=0)
    submit_button = gr.Button("🚀 Start Interaction")
    output_json = gr.JSON(label="Generated Conversation")

    # Download: gr.File has no .download() API, so serialize the JSON output
    # to a temporary file and return its path through a File component.
    download_button = gr.Button("📥 Download Conversation")
    download_file = gr.File(label="Conversation JSON")

    def save_conversation(data):
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(data, f, indent=2)
            return f.name

    download_button.click(fn=save_conversation, inputs=output_json, outputs=download_file)

    submit_button.click(
        fn=app_interface,
        inputs=[topic_input, turns_input],
        outputs=output_json
    )

demo.launch()
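
# Illustrative shape of the JSON shown in the UI (values are made up):
# {
#   "input": {"topic": "chest pain", "turns": 2},
#   "conversation": [{"role": "patient", "message": "...", "tokens": 12}, ...],
#   "summary": {"total_tokens": 180, "physician_tokens": 95, "patient_tokens": 85}
# }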