|
|
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
import pandas as pd |
|
from datetime import datetime |
|
|
|
class SimplifiedMedicalChatbot:
    """Keyword-based disease lookup over a tabular database.

    The database is a pandas DataFrame read from an Excel file.  A request
    supplies a main complaint (``pain``) and a free-text ``description``;
    rows whose symptom text contains those words are scored, and the top
    matches are enriched with medicine, ingredient and "see a doctor"
    information taken from the same row.

    NOTE(review): the non-ASCII literals in this class look like mis-decoded
    text.  They are kept exactly as found, because they double as DataFrame
    column keys and must keep matching the data file.  Where the original
    literal was broken across two physical lines, the break is written as an
    explicit ``\\x85`` escape (presumably a stray NEL character) so the
    literal is valid on one line -- confirm against the real column headers.
    """

    # Column keys of the database.
    _COL_DISEASE = '์งํ๋ช\x85'
    _COL_SYMPTOM = '์ฆ์'
    _COL_MEDICINE = '์ถ์ฒ ์์ฝํ'
    _COL_INGREDIENTS = '์ฑ๋ถ'
    _COL_ATTENTION = '์์ฌ์ ์ง๋ฃ๊ฐ ํ์ํ ๊ฒฝ์ฐ'
    _COL_AGE_NOTE = '๋์ด๋ '

    # Labels returned by get_age_group().
    _AGE_CHILD = "์ด๋ฆฐ์ด"    # age < 12
    _AGE_SENIOR = "๋\x85ธ์ธ"   # age >= 65
    _AGE_ADULT = "์ฑ์ธ"       # everything in between

    # Scoring weights used by simple_symptom_matching().
    _PAIN_SCORE = 60      # main complaint found in the symptom text
    _KEYWORD_SCORE = 20   # each description word found in the symptom text
    _MAX_RESULTS = 3

    def __init__(self, data_path="/content/sample_data/sick.xlsx"):
        """Load the database from *data_path* and set the disclaimer text."""
        self.load_database(data_path)
        self.disclaimer = (
            "\n"
            "์ฃผ์: ์ด ์ฑ๋ด์ ์ฐธ๊ณ ์ฉ์ผ๋ก๋ง ์ฌ์ฉํ์๊ธฐ ๋ฐ๋๋๋ค.\n"
            "์ ํํ ์ง๋จ์ ์ํด์๋ ๋ฐ๋์ ์๋ฃ ์ ๋ฌธ๊ฐ์ ์๋ดํ์ธ์.\n"
            "์ฆ์์ด ์ฌ๊ฐํ๋ค๊ณ ํ๋จ๋๋ฉด ์ฆ์ ๋ณ์์ ๋ฐฉ๋ฌธํ์๊ธฐ ๋ฐ๋๋๋ค.\n"
        )

    def load_database(self, data_path):
        """Read the Excel file at *data_path* into ``self.df``.

        Missing cells are replaced by empty strings.  On any failure the
        error is printed and ``self.df`` becomes an empty DataFrame, so the
        instance stays usable (lookups then simply find nothing).
        """
        try:
            self.df = pd.read_excel(data_path).fillna('')
            print(f"๋ฐ์ดํฐ๋ฒ ์ด์ค ๋ก๋ ์๋ฃ: {len(self.df)} ๊ฐ์ ๋ ์ฝ๋")
        except Exception as e:
            # Deliberate best-effort boundary: a missing file, a missing
            # Excel engine or a malformed sheet must not prevent construction.
            print(f"๋ฐ์ดํฐ๋ฒ ์ด์ค ๋ก๋ฉ ์คํจ: {str(e)}")
            self.df = pd.DataFrame()

    def calculate_age(self, born_year):
        """Return the current calendar year minus *born_year*.

        *born_year* may be an int or a numeric string.

        Raises:
            ValueError: if *born_year* cannot be converted to ``int``.
        """
        return datetime.now().year - int(born_year)

    def get_age_group(self, age):
        """Map *age* to a label: child (< 12), senior (>= 65) or adult."""
        if age < 12:
            return self._AGE_CHILD
        if age >= 65:
            return self._AGE_SENIOR
        return self._AGE_ADULT

    def simple_symptom_matching(self, pain, description):
        """Score every database row against *pain* and *description*.

        A row earns 60 points when the (stripped, non-empty) *pain* text
        occurs in its symptom text, plus 20 points for every
        whitespace-separated word of *description* that occurs in it.
        Matching is plain substring containment.

        Returns:
            Up to three ``(disease_name, score)`` tuples with ``score > 0``,
            highest score first; ties keep database order.  An empty list is
            returned when nothing matches or the symptom column is missing.
        """
        if self._COL_SYMPTOM not in self.df.columns:
            print("์ฆ์ ์ปฌ๋ผ์ด ๋ฐ์ดํฐ๋ฒ ์ด์ค์ ์์ต๋๋ค.")
            return []

        # Normalise once.  An empty complaint must never count as a match:
        # `"" in text` is always True and used to make every row "match".
        pain = str(pain or "").strip()
        keywords = str(description or "").split()

        scored = []
        rows = zip(self.df[self._COL_DISEASE], self.df[self._COL_SYMPTOM])
        for disease_name, symptom_text in rows:
            symptom_text = str(symptom_text)  # tolerate non-string cells
            score = 0
            if pain and pain in symptom_text:
                score += self._PAIN_SCORE
            score += self._KEYWORD_SCORE * sum(
                keyword in symptom_text for keyword in keywords
            )
            if score > 0:
                scored.append((disease_name, score))

        # list.sort is stable, so equal scores stay in database order.
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored[:self._MAX_RESULTS]

    @staticmethod
    def _blank_info(error=None):
        """Return an empty medicine-info dict, with ``error`` set if given."""
        info = {
            "ageWarning": "",
            "medicine": "",
            "ingredients": "",
            "medicalAttention": [],
        }
        if error is not None:
            info["error"] = error
        return info

    def get_medicine_info(self, disease, age_group, medi_tf=False):
        """Return medicine information for *disease*.

        Args:
            disease: value to look up in the disease-name column; the first
                matching row is used.
            age_group: a label as returned by ``get_age_group``; an age
                warning is produced only for the child and senior labels.
            medi_tf: when true, the recommended medicine and its ingredients
                are included; otherwise those fields stay empty.

        Returns:
            A dict with keys ``ageWarning``, ``medicine``, ``ingredients``
            and ``medicalAttention`` (a list of stripped, non-empty lines).
            An ``error`` key is added when the disease is unknown or the
            lookup fails; this method does not raise.
        """
        try:
            rows = self.df[self.df[self._COL_DISEASE] == disease]
            if rows.empty:
                return self._blank_info(
                    error="ํด๋น ์ง๋ณ์ ๋ํ ์ ๋ณด๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."
                )

            record = rows.iloc[0]
            result = self._blank_info()

            age_specific_info = record[self._COL_AGE_NOTE]
            if age_specific_info:
                if age_group == self._AGE_CHILD:
                    result["ageWarning"] = f"์ด๋ฆฐ์ด(12์ธ ๋ฏธ๋ง) ์ฃผ์์ฌํญ: {age_specific_info}"
                elif age_group == self._AGE_SENIOR:
                    result["ageWarning"] = f"๋\x85ธ์ธ(65์ธ ์ด์) ์ฃผ์์ฌํญ: {age_specific_info}"

            if medi_tf:
                if record[self._COL_MEDICINE]:
                    result["medicine"] = record[self._COL_MEDICINE]
                if record[self._COL_INGREDIENTS]:
                    result["ingredients"] = record[self._COL_INGREDIENTS]

            medical_attention = record[self._COL_ATTENTION]
            if medical_attention:
                # Split on "\n" only: str.splitlines() would also break on
                # other line-boundary characters that may occur in the text.
                result["medicalAttention"] = [
                    item.strip()
                    for item in str(medical_attention).split('\n')
                    if item.strip()
                ]

            return result

        except Exception as e:
            # Deliberate best-effort boundary (e.g. a missing column):
            # report the problem in the result instead of raising.
            print(f"์์ฝํ ์ ๋ณด ์กฐํ ์ค ์ค๋ฅ: {str(e)}")
            return self._blank_info(error=f"์ ๋ณด ์กฐํ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")

    def format_chat_response(self, response_data):
        """Render the response dict built by ``process_request`` as text.

        The output is a header line (age group, age, sex), one section per
        entry of ``response_data['analysis']`` followed by a 40-dash
        separator, and finally the disclaimer.
        """
        lines = [
            f"\n[{response_data['ageGroup']}({response_data['age']}์ธ) / {response_data['sex']} ์ฌ์ฉ์ ๋ถ์ ๊ฒฐ๊ณผ]"
        ]

        for analysis in response_data['analysis']:
            if analysis['disease']:
                lines.append(f"\n## ์ถ์ ์งํ: {analysis['disease']}")
                if analysis['ageWarning']:
                    lines.append(analysis['ageWarning'])
                if analysis['medicine']:
                    lines.append(f"์ถ์ฒ ์์ฝํ: {analysis['medicine']}")
                if analysis['ingredients']:
                    lines.append(f"์ฃผ์ ์ฑ๋ถ: {analysis['ingredients']}")
                if analysis['medicalAttention']:
                    lines.append("์์ฌ์ ์ง๋ฃ๊ฐ ํ์ํ ๊ฒฝ์ฐ:")
                    lines.extend(analysis['medicalAttention'])
            else:
                lines.append("\n์ผ์นํ๋ ์ง๋ณ์ ์ฐพ์ ์ ์์ต๋๋ค.")
            lines.append("-" * 40)

        lines.append(f"\n{response_data['disclaimer']}")
        return "\n".join(lines)

    def process_request(self, request_data):
        """Handle one request dict and return both JSON and chat output.

        Keys read from *request_data* (all optional): ``userId``,
        ``bornYear``, ``sex``, ``pain``, ``description`` and ``mediTF``.

        Returns:
            ``{"json_response": ..., "chat_message": ...}``.  On success the
            JSON part holds ``userId``, ``age``, ``ageGroup``, ``sex``,
            ``analysis`` and ``disclaimer``.  On any failure (for example a
            non-numeric ``bornYear``) it holds only ``error`` and
            ``disclaimer``; this method does not raise.
        """
        try:
            user_id = request_data.get("userId", "")
            born_year = request_data.get("bornYear", "")
            sex = request_data.get("sex", "")
            pain = request_data.get("pain", "")
            pain_description = request_data.get("description", "")
            medi_tf = request_data.get("mediTF", False)

            age = self.calculate_age(born_year)
            age_group = self.get_age_group(age)
            matched_diseases = self.simple_symptom_matching(pain, pain_description)

            analysis = []
            for disease, _score in matched_diseases:
                medical_info = self.get_medicine_info(disease, age_group, medi_tf)
                entry = {
                    "disease": disease,
                    "ageWarning": medical_info["ageWarning"],
                    "medicine": medical_info["medicine"],
                    "ingredients": medical_info["ingredients"],
                    "medicalAttention": medical_info["medicalAttention"],
                }
                # Surface lookup problems instead of silently dropping them.
                if "error" in medical_info:
                    entry["error"] = medical_info["error"]
                analysis.append(entry)

            if not analysis:
                # Placeholder entry so consumers always get one analysis item.
                # Only this entry carries a "confidence" key.
                analysis.append({
                    "disease": None,
                    "confidence": 0.0,
                    "ageWarning": "",
                    "medicine": "",
                    "ingredients": "",
                    "medicalAttention": [],
                    "error": "์ผ์นํ๋ ์ง๋ณ์ ์ฐพ์ ์ ์์ต๋๋ค.",
                })

            response = {
                "userId": user_id,
                "age": age,
                "ageGroup": age_group,
                "sex": sex,
                "analysis": analysis,
                "disclaimer": self.disclaimer,
            }
            return {
                "json_response": response,
                "chat_message": self.format_chat_response(response),
            }

        except Exception as e:
            # Top-level request boundary: always answer with a well-formed
            # error payload rather than propagating the exception.
            error_response = {
                "error": f"์์ฒญ ์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}",
                "disclaimer": self.disclaimer,
            }
            return {
                "json_response": error_response,
                "chat_message": f"์ฃ์กํฉ๋๋ค. ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}\n\n{self.disclaimer}",
            }
|
|
|
def create_sample_data():
    """Build a small in-memory sample database for testing.

    Returns:
        A pandas DataFrame with six rows and six text columns: disease name,
        symptom text, recommended medicine, ingredient, "see a doctor"
        conditions (two items separated by ``\\n``) and an age-related note.

    NOTE(review): the non-ASCII literals look like mis-decoded text and are
    kept exactly as found.  Literals that were broken across two physical
    lines are rejoined with an explicit ``\\x85`` escape at the break
    (presumably a stray NEL character) -- confirm against the intended text.
    """
    data = {
        '์งํ๋ช\x85': [
            'ํธ๋ํต',
            '๊ตฐ๋ฐ์ฑ ๋ํต',
            '๊ธด์ฅ์ฑ ๋ํต',
            '๋์๋ง์ผ',
            '๊ฐ๊ธฐ',
            '์์ผ',
        ],
        '์ฆ์': [
            '๋ํต ํ์ชฝ ๋จธ๋ฆฌ ํต์ฆ ๊ตฌํ ์ด์ง๋ฌ์ ๋ฉ์ค๊บผ์',
            'ํ์ชฝ ๋จธ๋ฆฌ ํต์ฆ ๋ํต์ฆ ์ฝง๋ฌผ ์ด์ง๋ฌ์',
            '๋ํต ๋ชฉํต์ฆ ์ด๊นจํต์ฆ ๋ฉ์ค๊บผ์ ์คํธ๋ ์ค',
            '๋ํต ๊ตฌํ ๋ฐ์ด ๋ชฉ์ด ๋ปฃ๋ปฃํจ ๋ฉ์ค๊บผ์',
            '๋ํต ๋ฐ์ด ๊ธฐ์นจ ์ฝง๋ฌผ',
            '๋ณตํต ๋ฉ์ค๊บผ์ ๊ตฌํ ์ํ๋ถ๋',
        ],
        '์ถ์ฒ ์์ฝํ': [
            '๊ฒ๋ณด๋ฆฐ',
            'ํ์ด๋ ๋',
            '์ด์ง์6',
            'ํญ์์ ์ฒ๋ฐฉ ํ์',
            'ํ์ด๋ ๋',
            '๊ฐ๋น์ค์ฝ',
        ],
        '์ฑ๋ถ': [
            '์ด๋ถํ๋กํ',
            '์์ธํธ์๋ฏธ๋\x85ธํ',
            '๋ํ๋ก์ผ',
            '์ฒ๋ฐฉ์ฝ๋ง ๊ฐ๋ฅ',
            '์์ธํธ์๋ฏธ๋\x85ธํ',
            '์๊ธด์ฐ๋ํธ๋ฅจ',
        ],
        '์์ฌ์ ์ง๋ฃ๊ฐ ํ์ํ ๊ฒฝ์ฐ': [
            '๋ํต์ด 24์๊ฐ ์ด์ ์ง์๋ ๋\n์์ผ๊ฐ ํ๋ ค์ง ๋',
            'ํต์ฆ์ด ๋งค์ฐ ์ฌํ ๋\nํ๋ฃจ ์ฌ๋ฌ๋ฒ ๋ฐ์ํ ๋',
            '๋ํต์ด ๋ง์ฑ์ ์ผ ๋\n์ผ์์ํ์ด ์ด๋ ค์ธ ๋',
            '์ฆ์ ๋ณ์ ๋ฐฉ๋ฌธ ํ์\n์๊ธ์ํฉ์ผ ์ ์์',
            '๊ณ ์ด์ด 3์ผ ์ด์ ์ง์๋ ๋\nํธํก๊ณค๋์ด ์์ ๋',
            '๋ณตํต์ด ์ฌํ๊ณ ์ง์๋ ๋\n์์ฅ ์ถํ ์ฆ์์ด ์์ ๋',
        ],
        '๋์ด๋ ': [
            '์งํต์ ๋ณต์ฉ๋ ์กฐ์ ํ์',
            '๋\x85ธ์ธ ํฌ์ฝ ์ฃผ์',
            '์ฒญ์๋\x85๋ณต์ฉ๋ ์กฐ์ ',
            '์ฐ๋ น๋ณ ํญ์์ ์ฒ๋ฐฉ ํ์',
            'ํด์ด์ ๋ณต์ฉ ์ ์ฃผ์',
            '์์ฐ ๋ถ๋น ์กฐ์ ์ ์ฃผ์',
        ],
    }
    return pd.DataFrame(data)
|
|
|
def run_test(temp_file="/content/sample_data/sick.xlsx"):
    """Run an end-to-end smoke test and print the results.

    Writes the sample database to *temp_file* as an Excel file, builds a
    chatbot from it, sends two fixed requests through ``process_request``
    and prints both the chat message and the JSON payload for each.

    Args:
        temp_file: where the sample Excel file is written; the parent
            directory is created if it does not exist.

    NOTE(review): the non-ASCII literals look like mis-decoded text and are
    kept exactly as found; literals that were broken across two physical
    lines are rejoined with an explicit ``\\x85`` escape at the break.
    """
    import os  # local import: only this helper touches the filesystem layout

    print("=== ์๋ฃ ์๋ด ์ฑ๋ด ํ\x85์คํธ ์์ ===")

    sample_df = create_sample_data()
    # The default location only exists in some environments; create the
    # directory so writing the file does not fail elsewhere.
    os.makedirs(os.path.dirname(temp_file) or ".", exist_ok=True)
    sample_df.to_excel(temp_file, index=False)
    print("์ํ ๋ฐ์ดํฐ ์์ฑ ์๋ฃ")

    chatbot = SimplifiedMedicalChatbot(temp_file)

    test_cases = [
        {
            "userId": "user1",
            "bornYear": "1990",
            "sex": "male",
            "pain": "๋ํต",
            "description": "ํ์ชฝ ๋จธ๋ฆฌ๊ฐ ์ํ๊ณ ๋ฉ์ค๊บผ์์",
            "mediTF": True,
        },
        {
            "userId": "user2",
            "bornYear": "2015",
            "sex": "female",
            "pain": "๋ณตํต",
            "description": "๋ฐฐ๊ฐ ์ํ๊ณ ๊ตฌํ ๋ฅผ ํด์",
            "mediTF": True,
        },
    ]

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n=== ํ\x85์คํธ ์ผ์ด์ค {i} ===")
        print("์\x85๋ ฅ:", test_case)

        result = chatbot.process_request(test_case)

        print("\n[์ฑ๋ด ์๋ต]")
        print(result["chat_message"])

        print("\n[JSON ์๋ต]")
        print(result["json_response"])

        print("\n" + "=" * 50)
|
|
|
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    run_test()