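"""Gradio demo: synthetic physician-patient dialogue generation.

A single SmolLM2-1.7B text-generation pipeline plays both roles,
alternating prompts for a user-chosen topic and number of turns,
then reports rough per-role word counts as a summary.
"""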
import json

import gradio as gr
from transformers import pipeline
import torch

# Load the small model once and share it between the two roles;
# instantiating the pipeline twice would keep two copies in memory.
device = 0 if torch.cuda.is_available() else -1
generator = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-1.7B", device=device)
physician = generator
patient = generator

def generate_conversation(topic, turns):
    conversation = []
    physician_tokens = 0
    patient_tokens = 0
    turns = int(turns)  # gr.Number delivers a float; range() needs an int

    # Initial prompt for the patient
    patient_prompt = f"I'm here to talk about {topic}."
    print(f"Patient Initial Prompt: {patient_prompt}")  # Debugging
    # max_new_tokens caps only the completion (max_length also counts the
    # prompt and errors out once prompts grow); return_full_text=False
    # strips the echoed prompt from the output.
    patient_response = patient(patient_prompt, max_new_tokens=50, return_full_text=False, num_return_sequences=1)[0]['generated_text'].strip()
    print(f"Patient Response: {patient_response}")  # Debugging
    # Whitespace word count as a cheap proxy for token usage.
    patient_tokens += len(patient_response.split())
    conversation.append({"role": "patient", "message": patient_response, "tokens": len(patient_response.split())})

    for turn in range(turns):
        # Physician's turn
        physician_prompt = f"As a physician, how would you respond to: {patient_response}"
        print(f"Physician Turn {turn} Prompt: {physician_prompt}")  # Debugging
        physician_response = physician(physician_prompt, max_new_tokens=50, return_full_text=False, num_return_sequences=1)[0]['generated_text'].strip()
        print(f"Physician Response: {physician_response}")  # Debugging
        physician_tokens += len(physician_response.split())
        conversation.append({"role": "physician", "message": physician_response, "tokens": len(physician_response.split())})

        # Patient's turn
        patient_prompt = f"As a patient, how would you respond to: {physician_response}"
        print(f"Patient Turn {turn} Prompt: {patient_prompt}")  # Debugging
        patient_response = patient(patient_prompt, max_new_tokens=50, return_full_text=False, num_return_sequences=1)[0]['generated_text'].strip()
        print(f"Patient Response: {patient_response}")  # Debugging
        patient_tokens += len(patient_response.split())
        conversation.append({"role": "patient", "message": patient_response, "tokens": len(patient_response.split())})

    # Summarize the conversation
    summary = {
        "total_tokens": physician_tokens + patient_tokens,
        "physician_tokens": physician_tokens,
        "patient_tokens": patient_tokens
    }

    return conversation, summary

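# Wrap the generator so the UI receives a single JSON payload
# (inputs, full transcript, and word-count summary).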
def app_interface(topic, turns):
    conversation, summary = generate_conversation(topic, turns)
    output = {
        "input": {"topic": topic, "turns": turns},
        "conversation": conversation,
        "summary": summary
    }
    return output

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 👨‍⚕️ Synthetic Data Generation: Physician-Patient Role-Play 👤")
    with gr.Row():
        topic_input = gr.Textbox(label="Enter Disease/Topic", placeholder="e.g., chest pain")
        turns_input = gr.Number(label="Number of Turns", value=1)  # Default to 1 turn for debugging
    submit_button = gr.Button("🚀 Start Interaction")
    output_json = gr.JSON(label="Generated Conversation")

    # Download button: gr.File has no download() method and cannot serve an
    # in-memory dict, so write the JSON to disk and return the file path.
    download_button = gr.Button("📥 Download Conversation")
    download_file = gr.File(label="Conversation File")

    def save_conversation(data):
        path = "conversation.json"
        with open(path, "w") as f:
            json.dump(data, f, indent=2)
        return path

    download_button.click(
        fn=save_conversation,
        inputs=output_json,
        outputs=download_file
    )

    submit_button.click(
        fn=app_interface,
        inputs=[topic_input, turns_input],
        outputs=output_json
    )

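# Launch the local Gradio server (http://127.0.0.1:7860 by default).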
demo.launch()