File size: 3,829 Bytes
56a8be9
1c1b8b9
56a8be9
1c1b8b9
 
 
 
 
56a8be9
1c1b8b9
 
 
 
2cf432b
1c1b8b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d69dd7
56a8be9
 
 
 
 
 
 
1c1b8b9
8b369df
1c1b8b9
a5f7bd5
56a8be9
 
 
 
 
2cf432b
1c1b8b9
a5f7bd5
56a8be9
 
 
 
2cf432b
1c1b8b9
a5f7bd5
56a8be9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5f7bd5
56a8be9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c1b8b9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import json
import tempfile

import gradio as gr
from smolagents import CodeAgent, HfApiModel

# Define system prompts for the agents
patient_system_prompt = """
You are a patient describing your symptoms to a physician. You are here to talk about a health issue.
Be concise and provide relevant information about your symptoms.
"""

physician_system_prompt = """
You are a physician responding to a patient's symptoms. 
Ask relevant questions to understand the patient's condition and provide appropriate advice.
"""

# Both roles are played by the same instruction-tuned model.
MODEL_ID = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

# Load the models for the agents (one client per role).
patient_model = HfApiModel(model_id=MODEL_ID)
physician_model = HfApiModel(model_id=MODEL_ID)

# Initialize the agents.
# `tools` is a required argument of CodeAgent; these role-play agents use no
# tools, so an empty list is passed explicitly (omitting it raises TypeError).
# NOTE(review): depending on the installed smolagents version, a custom
# `system_prompt` may need the library's template placeholders or may have been
# superseded by `prompt_templates` -- confirm against the pinned version.
patient_agent = CodeAgent(
    tools=[],
    model=patient_model,
    system_prompt=patient_system_prompt,
    planning_interval=1,  # Interval, in agent steps within one run(), between planning steps
)

physician_agent = CodeAgent(
    tools=[],
    model=physician_model,
    system_prompt=physician_system_prompt,
    planning_interval=1,  # Interval, in agent steps within one run(), between planning steps
)

def _agent_turn(agent, role, prompt, conversation, *, reset):
    """Ask ``agent`` to reply to ``prompt`` and append the reply to ``conversation``.

    Returns ``(reply, word_count)``. The reply is converted with ``str`` because
    an agent's final answer is not guaranteed to be a string.
    """
    reply = str(agent.run(prompt, reset=reset))
    print(f"{role.capitalize()} Response: {reply}")  # Debugging
    word_count = len(reply.split())
    conversation.append({"role": role, "message": reply, "tokens": word_count})
    return reply, word_count


def generate_conversation(topic, turns, *, patient=None, physician=None):
    """Run a simulated patient/physician dialogue and collect its transcript.

    The patient opens the conversation about ``topic``. Then, for each of
    ``turns`` rounds, the physician replies to the patient's last message and
    the patient replies to the physician's.

    Args:
        topic: Disease or subject the patient wants to discuss.
        turns: Number of physician/patient exchange rounds. ``gr.Number`` may
            deliver a float (or ``None`` when cleared), so the value is
            truncated with ``int``; ``None`` means zero rounds.
        patient: Agent playing the patient; anything exposing
            ``run(task, reset=...)``. Defaults to the module-level ``patient_agent``.
        physician: Agent playing the physician. Defaults to the module-level
            ``physician_agent``.

    Returns:
        ``(conversation, summary)``. ``conversation`` is a list of
        ``{"role", "message", "tokens"}`` dicts in speaking order. ``summary``
        holds ``total_tokens``, ``physician_tokens`` and ``patient_tokens``.
        "Tokens" are whitespace-separated words, not tokenizer tokens.
    """
    if patient is None:
        patient = patient_agent
    if physician is None:
        physician = physician_agent
    # range() needs an int; Gradio numbers are floats unless precision=0.
    n_turns = 0 if turns is None else int(turns)

    conversation = []
    physician_tokens = 0
    patient_tokens = 0

    # Initial prompt for the patient. reset=True gives the agent a fresh memory
    # so a previous conversation does not leak into this one.
    patient_input = f"I'm here to talk about {topic}."
    print(f"Patient Initial Input: {patient_input}")  # Debugging
    patient_response, words = _agent_turn(
        patient, "patient", patient_input, conversation, reset=True
    )
    patient_tokens += words

    for turn in range(n_turns):
        # Physician's turn: start fresh only on its first reply of this
        # conversation, then keep memory of what the patient already said.
        print(f"Physician Turn {turn} Prompt: {patient_response}")  # Debugging
        physician_response, words = _agent_turn(
            physician, "physician", patient_response, conversation, reset=(turn == 0)
        )
        physician_tokens += words

        # Patient's turn: continue from its existing memory.
        print(f"Patient Turn {turn} Prompt: {physician_response}")  # Debugging
        patient_response, words = _agent_turn(
            patient, "patient", physician_response, conversation, reset=False
        )
        patient_tokens += words

    # Summarize the conversation
    summary = {
        "total_tokens": physician_tokens + patient_tokens,
        "physician_tokens": physician_tokens,
        "patient_tokens": patient_tokens,
    }

    return conversation, summary

def app_interface(topic, turns):
    """Gradio callback that builds the JSON payload shown in the UI.

    The returned dict echoes the user's inputs under ``"input"``. It carries the
    dialogue list under ``"conversation"`` and the word-count totals under
    ``"summary"``.
    """
    dialogue, totals = generate_conversation(topic, turns)
    echoed_inputs = dict(topic=topic, turns=turns)
    return dict(input=echoed_inputs, conversation=dialogue, summary=totals)

def _conversation_to_file(data):
    """Write the generated conversation payload to a temporary ``.json`` file.

    Returns the file path for a ``gr.File`` output. Returns ``None`` when
    nothing has been generated yet (the JSON component is empty).
    """
    if not data:
        return None
    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", prefix="conversation_", delete=False, encoding="utf-8"
    ) as handle:
        json.dump(data, handle, ensure_ascii=False, indent=2)
    return handle.name


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 👨‍⚕️ Synthetic Data Generation: Physician-Patient Role-Play 👀")
    with gr.Row():
        topic_input = gr.Textbox(label="Enter Disease/Topic", placeholder="e.g., chest pain")
        # precision=0 makes Gradio deliver an int rather than a float, as range() needs.
        turns_input = gr.Number(label="Number of Turns", value=1, precision=0)  # Default to 1 turn for debugging
    submit_button = gr.Button("🚀 Start Interaction")
    output_json = gr.JSON(label="Generated Conversation")

    # Download button for the conversation: serialize the JSON shown above to a
    # file and hand its path to a File component, which offers it for download.
    download_button = gr.Button("📥 Download Conversation")
    download_file = gr.File(label="Conversation File")
    download_button.click(
        fn=_conversation_to_file,
        inputs=output_json,
        outputs=download_file,
    )

    submit_button.click(
        fn=app_interface,
        inputs=[topic_input, turns_input],
        outputs=output_json,
    )

demo.launch()