# medical-chatbot / medical_chatbot.py
# Uploaded to the Hugging Face Hub by 0208suin via huggingface_hub (commit bcdfa3b).
# Notebook cell for installing the required libraries:
# !pip install transformers torch pandas openpyxl
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pandas as pd
from datetime import datetime
class SimplifiedMedicalChatbot:
    """Rule-based symptom checker backed by a disease table loaded from Excel.

    Each table row describes one disease. User symptoms are matched against a
    row's symptom text by substring search; the best-scoring diseases are
    reported together with medicine and "when to see a doctor" information.
    """

    # Column headers of the disease table (must match the spreadsheet).
    COL_DISEASE = '์งˆํ™˜๋ช…'
    COL_SYMPTOMS = '์ฆ์ƒ'
    COL_MEDICINE = '์ถ”์ฒœ ์˜์•ฝํ’ˆ'
    COL_INGREDIENTS = '์„ฑ๋ถ„'
    COL_DOCTOR = '์˜์‚ฌ์˜ ์ง„๋ฃŒ๊ฐ€ ํ•„์š”ํ•œ ๊ฒฝ์šฐ'
    # NOTE: this header intentionally ends with a space.
    COL_AGE_NOTE = '๋‚˜์ด๋Œ€ '

    # Labels returned by get_age_group.
    AGE_CHILD = "์–ด๋ฆฐ์ด"
    AGE_ELDERLY = "๋…ธ์ธ"
    AGE_ADULT = "์„ฑ์ธ"

    # Scoring used by simple_symptom_matching.
    MAIN_SYMPTOM_SCORE = 60   # main complaint found in the row's symptoms
    EXTRA_SYMPTOM_SCORE = 20  # per description token found
    MAX_MATCHES = 3           # number of diseases reported

    def __init__(self, data_path="/content/sample_data/sick.xlsx"):
        """Load the disease table from *data_path* and set the disclaimer."""
        self.load_database(data_path)
        # Disclaimer appended to every response (leading and trailing newline).
        self.disclaimer = (
            "\n"
            "์ฃผ์˜: ์ด ์ฑ—๋ด‡์€ ์ฐธ๊ณ ์šฉ์œผ๋กœ๋งŒ ์‚ฌ์šฉํ•˜์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค.\n"
            "์ •ํ™•ํ•œ ์ง„๋‹จ์„ ์œ„ํ•ด์„œ๋Š” ๋ฐ˜๋“œ์‹œ ์˜๋ฃŒ ์ „๋ฌธ๊ฐ€์™€ ์ƒ๋‹ดํ•˜์„ธ์š”.\n"
            "์ฆ์ƒ์ด ์‹ฌ๊ฐํ•˜๋‹ค๊ณ  ํŒ๋‹จ๋˜๋ฉด ์ฆ‰์‹œ ๋ณ‘์›์„ ๋ฐฉ๋ฌธํ•˜์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค.\n"
        )

    def load_database(self, data_path):
        """Read the Excel table at *data_path* into ``self.df`` (blanks -> '').

        Best effort: on any read failure the error is printed and ``self.df``
        becomes an empty DataFrame, so construction never fails; later lookups
        then report "no match" or an error instead.
        """
        try:
            df = pd.read_excel(data_path)
        except Exception as e:
            print(f"๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ๋กœ๋”ฉ ์‹คํŒจ: {str(e)}")
            self.df = pd.DataFrame()
            return
        self.df = df.fillna('')
        print(f"๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ๋กœ๋“œ ์™„๋ฃŒ: {len(self.df)} ๊ฐœ์˜ ๋ ˆ์ฝ”๋“œ")

    def calculate_age(self, born_year):
        """Return the age in years implied by *born_year* (int or numeric str).

        Raises:
            ValueError: if *born_year* is not integer-like (e.g. '').
        """
        current_year = datetime.now().year
        return current_year - int(born_year)

    def get_age_group(self, age):
        """Map *age* to the child (<12), elderly (>=65) or adult label."""
        if age < 12:
            return self.AGE_CHILD
        elif age >= 65:
            return self.AGE_ELDERLY
        else:
            return self.AGE_ADULT

    def simple_symptom_matching(self, pain, description):
        """Score every disease row against the user's symptoms.

        Scoring uses substring search in the row's symptom text:
          * MAIN_SYMPTOM_SCORE if the (stripped) main complaint *pain* occurs,
          * EXTRA_SYMPTOM_SCORE for each whitespace-separated token of
            *description* that occurs.

        Returns:
            Up to MAX_MATCHES ``(disease_name, score)`` tuples with score > 0,
            highest score first (ties keep table order, one entry per disease
            name). Empty list if the symptom column is missing or nothing
            matches.
        """
        if self.COL_SYMPTOMS not in self.df.columns:
            print("์ฆ์ƒ ์ปฌ๋Ÿผ์ด ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์— ์—†์Šต๋‹ˆ๋‹ค.")
            return []

        main_symptom = (pain or "").strip()
        # Tokenize the description once rather than once per row.
        extra_symptoms = (description or "").split()

        best_scores = {}
        for _, row in self.df.iterrows():
            # str(): numeric cells survive fillna('') as numbers.
            db_symptoms = str(row[self.COL_SYMPTOMS])
            score = 0
            # An empty complaint must not count: '' is a substring of every
            # string and would otherwise "match" every disease.
            if main_symptom and main_symptom in db_symptoms:
                score += self.MAIN_SYMPTOM_SCORE
            score += self.EXTRA_SYMPTOM_SCORE * sum(
                1 for symptom in extra_symptoms if symptom in db_symptoms
            )
            if score > 0:
                name = row[self.COL_DISEASE]
                if score > best_scores.get(name, 0):
                    best_scores[name] = score

        # sorted() is stable, so equal scores keep their table order.
        ranked = sorted(best_scores.items(), key=lambda item: item[1], reverse=True)
        return ranked[:self.MAX_MATCHES]

    @staticmethod
    def _empty_medicine_info(error=None):
        """Return a blank medicine-info dict, adding an 'error' key if given."""
        info = {
            "ageWarning": "",
            "medicine": "",
            "ingredients": "",
            "medicalAttention": [],
        }
        if error is not None:
            info["error"] = error
        return info

    def get_medicine_info(self, disease, age_group, medi_tf=False):
        """Look up age warnings, medicines and see-a-doctor criteria for *disease*.

        Uses the first table row whose disease name equals *disease*.

        Args:
            disease: disease name, e.g. from ``simple_symptom_matching``.
            age_group: label from ``get_age_group``; only the child and
                elderly groups receive an ``ageWarning``.
            medi_tf: when true, fill ``medicine`` and ``ingredients``.

        Returns:
            Dict with ``ageWarning``, ``medicine``, ``ingredients`` and
            ``medicalAttention`` (list of criteria, one per line of the cell).
            If the disease is unknown or the lookup fails, the fields are
            blank and an extra ``error`` key holds the message.
        """
        try:
            disease_info = self.df[self.df[self.COL_DISEASE] == disease]
            if disease_info.empty:
                return self._empty_medicine_info(
                    "ํ•ด๋‹น ์งˆ๋ณ‘์— ๋Œ€ํ•œ ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
                )
            row = disease_info.iloc[0]
            result = self._empty_medicine_info()

            # Age-specific caution, shown only for children and the elderly.
            age_specific_info = row[self.COL_AGE_NOTE]
            if age_specific_info:
                if age_group == self.AGE_CHILD:
                    result["ageWarning"] = f"์–ด๋ฆฐ์ด(12์„ธ ๋ฏธ๋งŒ) ์ฃผ์˜์‚ฌํ•ญ: {age_specific_info}"
                elif age_group == self.AGE_ELDERLY:
                    result["ageWarning"] = f"๋…ธ์ธ(65์„ธ ์ด์ƒ) ์ฃผ์˜์‚ฌํ•ญ: {age_specific_info}"

            # Medicine details only when the caller asked for them (mediTF).
            if medi_tf:
                if row[self.COL_MEDICINE]:
                    result["medicine"] = row[self.COL_MEDICINE]
                if row[self.COL_INGREDIENTS]:
                    result["ingredients"] = row[self.COL_INGREDIENTS]

            # One criterion per line; blank lines are dropped.
            medical_attention = row[self.COL_DOCTOR]
            if medical_attention:
                result["medicalAttention"] = [
                    item.strip()
                    for item in str(medical_attention).splitlines()
                    if item.strip()
                ]
            return result
        except Exception as e:
            print(f"์˜์•ฝํ’ˆ ์ •๋ณด ์กฐํšŒ ์ค‘ ์˜ค๋ฅ˜: {str(e)}")
            return self._empty_medicine_info(
                f"์ •๋ณด ์กฐํšŒ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
            )

    def format_chat_response(self, response_data):
        """Render a ``process_request`` JSON payload as human-readable text.

        Expects keys ageGroup, age, sex, analysis and disclaimer. Each analysis
        entry with a disease gets its own section; an entry without one prints
        the "no matching disease" line. Each entry is followed by a 40-dash
        rule, and the disclaimer closes the message.
        """
        lines = [
            f"\n[{response_data['ageGroup']}({response_data['age']}์„ธ) / {response_data['sex']} ์‚ฌ์šฉ์ž ๋ถ„์„ ๊ฒฐ๊ณผ]"
        ]
        for analysis in response_data['analysis']:
            if analysis['disease']:
                lines.append(f"\n## ์ถ”์ • ์งˆํ™˜: {analysis['disease']}")
                if analysis['ageWarning']:
                    lines.append(analysis['ageWarning'])
                if analysis['medicine']:
                    lines.append(f"์ถ”์ฒœ ์˜์•ฝํ’ˆ: {analysis['medicine']}")
                if analysis['ingredients']:
                    lines.append(f"์ฃผ์š” ์„ฑ๋ถ„: {analysis['ingredients']}")
                if analysis['medicalAttention']:
                    lines.append("์˜์‚ฌ์˜ ์ง„๋ฃŒ๊ฐ€ ํ•„์š”ํ•œ ๊ฒฝ์šฐ:")
                    lines.extend(analysis['medicalAttention'])
            else:
                lines.append("\n์ผ์น˜ํ•˜๋Š” ์งˆ๋ณ‘์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
            lines.append("-" * 40)
        lines.append(f"\n{response_data['disclaimer']}")
        return "\n".join(lines)

    def process_request(self, request_data):
        """Handle one backend request and return JSON plus chat text.

        Args:
            request_data: mapping with optional keys ``userId``, ``bornYear``,
                ``sex``, ``pain`` (main complaint), ``description`` (extra
                symptoms, whitespace-separated) and ``mediTF`` (include
                medicine details).

        Returns:
            ``{"json_response": dict, "chat_message": str}``. Any exception
            while processing (e.g. a missing or non-numeric ``bornYear``) is
            turned into an ``error`` payload instead of propagating, since this
            is the request boundary.
        """
        try:
            user_id = request_data.get("userId", "")
            born_year = request_data.get("bornYear", "")
            sex = request_data.get("sex", "")
            pain = request_data.get("pain", "")
            pain_description = request_data.get("description", "")
            medi_tf = request_data.get("mediTF", False)

            age = self.calculate_age(born_year)
            age_group = self.get_age_group(age)
            matched_diseases = self.simple_symptom_matching(pain, pain_description)

            analysis = []
            for disease, _score in matched_diseases:
                medical_info = self.get_medicine_info(disease, age_group, medi_tf)
                analysis.append({
                    "disease": disease,
                    "ageWarning": medical_info["ageWarning"],
                    "medicine": medical_info["medicine"],
                    "ingredients": medical_info["ingredients"],
                    "medicalAttention": medical_info["medicalAttention"],
                })
            if not analysis:
                analysis.append({
                    "disease": None,
                    "confidence": 0.0,  # 0% confidence when nothing matched
                    "ageWarning": "",
                    "medicine": "",
                    "ingredients": "",
                    "medicalAttention": [],
                    "error": "์ผ์น˜ํ•˜๋Š” ์งˆ๋ณ‘์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.",
                })

            response = {
                "userId": user_id,
                "age": age,
                "ageGroup": age_group,
                "sex": sex,
                "analysis": analysis,
                "disclaimer": self.disclaimer,
            }
            return {
                "json_response": response,
                "chat_message": self.format_chat_response(response),
            }
        except Exception as e:
            error_response = {
                "error": f"์š”์ฒญ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}",
                "disclaimer": self.disclaimer,
            }
            return {
                "json_response": error_response,
                "chat_message": f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}\n\n{self.disclaimer}",
            }
def create_sample_data():
    """Return a small six-disease table with the columns the chatbot reads.

    Every value is a string. Each "see a doctor" cell holds two criteria
    separated by a newline.
    """
    columns = [
        '์งˆํ™˜๋ช…',
        '์ฆ์ƒ',
        '์ถ”์ฒœ ์˜์•ฝํ’ˆ',
        '์„ฑ๋ถ„',
        '์˜์‚ฌ์˜ ์ง„๋ฃŒ๊ฐ€ ํ•„์š”ํ•œ ๊ฒฝ์šฐ',
        '๋‚˜์ด๋Œ€ ',
    ]
    # One tuple per disease, fields in the same order as `columns`.
    records = [
        (
            'ํŽธ๋‘ํ†ต',
            '๋‘ํ†ต ํ•œ์ชฝ ๋จธ๋ฆฌ ํ†ต์ฆ ๊ตฌํ†  ์–ด์ง€๋Ÿฌ์›€ ๋ฉ”์Šค๊บผ์›€',
            '๊ฒŒ๋ณด๋ฆฐ',
            '์ด๋ถ€ํ”„๋กœํŽœ',
            '๋‘ํ†ต์ด 24์‹œ๊ฐ„ ์ด์ƒ ์ง€์†๋  ๋•Œ\n์‹œ์•ผ๊ฐ€ ํ๋ ค์งˆ ๋•Œ',
            '์ง„ํ†ต์ œ ๋ณต์šฉ๋Ÿ‰ ์กฐ์ ˆ ํ•„์š”',
        ),
        (
            '๊ตฐ๋ฐœ์„ฑ ๋‘ํ†ต',
            'ํ•œ์ชฝ ๋จธ๋ฆฌ ํ†ต์ฆ ๋ˆˆํ†ต์ฆ ์ฝง๋ฌผ ์–ด์ง€๋Ÿฌ์›€',
            'ํƒ€์ด๋ ˆ๋†€',
            '์•„์„ธํŠธ์•„๋ฏธ๋…ธํŽœ',
            'ํ†ต์ฆ์ด ๋งค์šฐ ์‹ฌํ•  ๋•Œ\nํ•˜๋ฃจ ์—ฌ๋Ÿฌ๋ฒˆ ๋ฐœ์ƒํ•  ๋•Œ',
            '๋…ธ์ธ ํˆฌ์•ฝ ์ฃผ์˜',
        ),
        (
            '๊ธด์žฅ์„ฑ ๋‘ํ†ต',
            '๋‘ํ†ต ๋ชฉํ†ต์ฆ ์–ด๊นจํ†ต์ฆ ๋ฉ”์Šค๊บผ์›€ ์ŠคํŠธ๋ ˆ์Šค',
            '์ด์ง€์—”6',
            '๋‚˜ํ”„๋ก์„ผ',
            '๋‘ํ†ต์ด ๋งŒ์„ฑ์ ์ผ ๋•Œ\n์ผ์ƒ์ƒํ™œ์ด ์–ด๋ ค์šธ ๋•Œ',
            '์ฒญ์†Œ๋…„ ๋ณต์šฉ๋Ÿ‰ ์กฐ์ ˆ',
        ),
        (
            '๋‡Œ์ˆ˜๋ง‰์—ผ',
            '๋‘ํ†ต ๊ตฌํ†  ๋ฐœ์—ด ๋ชฉ์ด ๋ปฃ๋ปฃํ•จ ๋ฉ”์Šค๊บผ์›€',
            'ํ•ญ์ƒ์ œ ์ฒ˜๋ฐฉ ํ•„์š”',
            '์ฒ˜๋ฐฉ์•ฝ๋งŒ ๊ฐ€๋Šฅ',
            '์ฆ‰์‹œ ๋ณ‘์› ๋ฐฉ๋ฌธ ํ•„์š”\n์‘๊ธ‰์ƒํ™ฉ์ผ ์ˆ˜ ์žˆ์Œ',
            '์—ฐ๋ น๋ณ„ ํ•ญ์ƒ์ œ ์ฒ˜๋ฐฉ ํ•„์š”',
        ),
        (
            '๊ฐ๊ธฐ',
            '๋‘ํ†ต ๋ฐœ์—ด ๊ธฐ์นจ ์ฝง๋ฌผ',
            'ํƒ€์ด๋ ˆ๋†€',
            '์•„์„ธํŠธ์•„๋ฏธ๋…ธํŽœ',
            '๊ณ ์—ด์ด 3์ผ ์ด์ƒ ์ง€์†๋  ๋•Œ\nํ˜ธํก๊ณค๋ž€์ด ์žˆ์„ ๋•Œ',
            'ํ•ด์—ด์ œ ๋ณต์šฉ ์‹œ ์ฃผ์˜',
        ),
        (
            '์œ„์—ผ',
            '๋ณตํ†ต ๋ฉ”์Šค๊บผ์›€ ๊ตฌํ†  ์†Œํ™”๋ถˆ๋Ÿ‰',
            '๊ฐœ๋น„์Šค์ฝ˜',
            '์•Œ๊ธด์‚ฐ๋‚˜ํŠธ๋ฅจ',
            '๋ณตํ†ต์ด ์‹ฌํ•˜๊ณ  ์ง€์†๋  ๋•Œ\n์œ„์žฅ ์ถœํ˜ˆ ์ฆ์ƒ์ด ์žˆ์„ ๋•Œ',
            '์œ„์‚ฐ ๋ถ„๋น„ ์กฐ์ ˆ์ œ ์ฃผ์˜',
        ),
    ]
    return pd.DataFrame(records, columns=columns)
def run_test(data_path=None):
    """Smoke-test the chatbot end to end on generated sample data.

    Writes ``create_sample_data()`` to an Excel workbook and loads it into a
    ``SimplifiedMedicalChatbot``. It then prints the chat and JSON responses
    for two canned requests.

    Args:
        data_path: where to write the sample workbook. Defaults to a file in
            a fresh temporary directory that is removed afterwards. This way
            the demo never overwrites a real database at the chatbot's
            default location (/content/sample_data/sick.xlsx).
    """
    import os
    import tempfile

    print("=== ์˜๋ฃŒ ์ƒ๋‹ด ์ฑ—๋ด‡ ํ…Œ์ŠคํŠธ ์‹œ์ž‘ ===")

    test_cases = [
        {
            "userId": "user1",
            "bornYear": "1990",
            "sex": "male",
            "pain": "๋‘ํ†ต",
            "description": "ํ•œ์ชฝ ๋จธ๋ฆฌ๊ฐ€ ์•„ํ”„๊ณ  ๋ฉ”์Šค๊บผ์›Œ์š”",
            "mediTF": True,
        },
        {
            "userId": "user2",
            "bornYear": "2015",
            "sex": "female",
            "pain": "๋ณตํ†ต",
            "description": "๋ฐฐ๊ฐ€ ์•„ํ”„๊ณ  ๊ตฌํ† ๋ฅผ ํ•ด์š”",
            "mediTF": True,
        },
    ]

    with tempfile.TemporaryDirectory() as tmp_dir:
        # The sample workbook is a throwaway fixture (a temporary file).
        temp_file = data_path if data_path else os.path.join(tmp_dir, "sick.xlsx")
        create_sample_data().to_excel(temp_file, index=False)
        print("์ƒ˜ํ”Œ ๋ฐ์ดํ„ฐ ์ƒ์„ฑ ์™„๋ฃŒ")
        # The constructor reads the whole workbook into memory, so the
        # temporary directory can be cleaned up right after this.
        chatbot = SimplifiedMedicalChatbot(temp_file)

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n=== ํ…Œ์ŠคํŠธ ์ผ€์ด์Šค {i} ===")
        print("์ž…๋ ฅ:", test_case)
        result = chatbot.process_request(test_case)
        print("\n[์ฑ—๋ด‡ ์‘๋‹ต]")
        print(result["chat_message"])
        print("\n[JSON ์‘๋‹ต]")
        print(result["json_response"])
        print("\n" + "=" * 50)
if __name__ == "__main__":
    # Script entry point: run the sample-data smoke test.
    run_test()