Upload 8 files

- src/chatbot.py +30 -0
- src/config.py +49 -0
- src/download_models.py +57 -0
- src/gemini.py +32 -0
- src/graph.py +111 -0
- src/model_loader.py +45 -0
- src/util.py +213 -0
- src/vag_util.py +29 -0
src/chatbot.py
ADDED
@@ -0,0 +1,30 @@
import numpy as np
from util import load_uploaded_file, segment_signal
from gemini import query_gemini_rest

CLASSES = ["N", "V", "/", "A", "F", "~"]
LABEL_MAP = {
    "N": "Normal sinus beat",
    "V": "Premature Ventricular Contraction (PVC)",
    "/": "Paced beat (pacemaker)",
    "A": "Atrial premature beat",
    "F": "Fusion of ventricular & normal beat",
    "~": "Unclassifiable / noise"
}

def analyze_signal(file, model, gemini_key="", signal_type="ECG"):
    signal = load_uploaded_file(file, signal_type)
    segments = segment_signal(signal)

    preds = model.predict(segments, verbose=0)[0]
    idx = int(np.argmax(preds))
    conf = float(preds[idx])
    label = CLASSES[idx]
    human = LABEL_MAP[label]

    gemini_txt = None
    if gemini_key:
        gemini_txt = query_gemini_rest(signal_type, human, conf, gemini_key)

    return label, human, conf, gemini_txt
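For orientation, a minimal sketch of how analyze_signal might be driven end to end; the file object is assumed to expose .name and .read() like Gradio/Streamlit uploads, and sample_ecg.csv is a placeholder path, not part of this commit:

# Hypothetical driver wiring load_mitbih_model() to analyze_signal().
from model_loader import load_mitbih_model
from chatbot import analyze_signal

model = load_mitbih_model()
with open("sample_ecg.csv", "rb") as f:  # hypothetical sample upload
    label, human, conf, note = analyze_signal(f, model, signal_type="ECG")
print(f"{label}: {human} (confidence {conf:.1%})")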
src/config.py
ADDED
@@ -0,0 +1,49 @@
import argparse

parser = argparse.ArgumentParser()

def str2bool(v):
    # argparse's type=bool treats any non-empty string as True
    # (so "--flag False" would become True); parse booleans explicitly.
    if isinstance(v, bool):
        return v
    return str(v).lower() in ("yes", "true", "t", "1")

def add_argument_group(name):
    arg = parser.add_argument_group(name)
    return arg


misc_arg = add_argument_group('misc')
misc_arg.add_argument('--split', type=str2bool, default=True)
misc_arg.add_argument('--input_size', type=int, default=256,
                      help='multiples of 256 by the structure of the model')
misc_arg.add_argument('--use_network', type=str2bool, default=False)

data_arg = add_argument_group('data')
data_arg.add_argument('--downloading', type=str2bool, default=False)

graph_arg = add_argument_group('graph')
graph_arg.add_argument('--filter_length', type=int, default=32)
graph_arg.add_argument('--kernel_size', type=int, default=16)
graph_arg.add_argument('--drop_rate', type=float, default=0.2)

train_arg = add_argument_group('train')
train_arg.add_argument('--feature', type=str, default="MLII",
                       help='one of MLII, V1, V2, V4, V5. Preferably MLII or V1')
train_arg.add_argument('--epochs', type=int, default=80)
train_arg.add_argument('--batch', type=int, default=256)
train_arg.add_argument('--patience', type=int, default=10)
train_arg.add_argument('--min_lr', type=float, default=0.00005)
train_arg.add_argument('--checkpoint_path', type=str, default=None)
train_arg.add_argument('--resume_epoch', type=int)
train_arg.add_argument('--ensemble', type=str2bool, default=False)
train_arg.add_argument('--trained_model', type=str, default=None,
                       help='dir and filename of the trained model for usage.')

predict_arg = add_argument_group('predict')
predict_arg.add_argument('--num', type=int, default=None)
predict_arg.add_argument('--upload', type=str2bool, default=False)
predict_arg.add_argument('--sample_rate', type=int, default=None)
predict_arg.add_argument('--cinc_download', type=str2bool, default=False)


def get_config():
    # parse_known_args tolerates extra flags injected by hosting environments
    config, unparsed = parser.parse_known_args()
    return config
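Since get_config() uses parse_known_args, unknown flags injected by the hosting environment are ignored rather than fatal. A quick sketch of consuming the config (the flag values and script name are illustrative):

# e.g. invoked as: python train.py --epochs 40 --feature V1
from config import get_config

config = get_config()
print(config.feature, config.epochs, config.input_size)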
src/download_models.py
ADDED
@@ -0,0 +1,57 @@
import os
import gdown
from dotenv import load_dotenv


load_dotenv()

def extract_file_id_from_url(url):
    """Extract file ID from Google Drive URL"""
    if "drive.google.com" in url:
        if "/file/d/" in url:
            return url.split("/file/d/")[1].split("/")[0]
        elif "id=" in url:
            return url.split("id=")[1].split("&")[0]
    return url

def get_model_urls():
    """Get model URLs from environment variables"""
    return {
        "../models/MLII-latest.keras": os.getenv("ECG_MODEL_URL", ""),
        "../models/pcg_model.h5": os.getenv("PCG_MODEL_URL", ""),
        "../models/emg_model.h5": os.getenv("EMG_MODEL_URL", ""),
        "../models/vag_feature_classifier.pkl": os.getenv("VAG_MODEL_URL", "")
    }

def download_from_gdrive(url, output_path):
    """Download file from Google Drive using gdown"""
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    file_id = extract_file_id_from_url(url)

    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
    gdown.download(download_url, output_path, quiet=False)

def ensure_models_downloaded():
    """Download models if they don't exist locally"""
    model_urls = get_model_urls()

    for local_path, url in model_urls.items():
        if not url:
            print(f"⚠️ No URL found for {local_path}")
            continue

        if not os.path.exists(local_path):
            print(f"Downloading {local_path}...")
            try:
                download_from_gdrive(url, local_path)
                print(f"✅ Downloaded {local_path}")
            except Exception as e:
                print(f"❌ Failed to download {local_path}: {e}")
        else:
            print(f"✅ {local_path} already exists")

if __name__ == "__main__":
    ensure_models_downloaded()
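ensure_models_downloaded() reads the four *_MODEL_URL variables via python-dotenv, so a Space only needs a .env file (or repository secrets) along these lines; the Drive file IDs below are placeholders, not real values:

ECG_MODEL_URL=https://drive.google.com/file/d/<ECG_FILE_ID>/view
PCG_MODEL_URL=https://drive.google.com/file/d/<PCG_FILE_ID>/view
EMG_MODEL_URL=https://drive.google.com/file/d/<EMG_FILE_ID>/view
VAG_MODEL_URL=https://drive.google.com/file/d/<VAG_FILE_ID>/view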
src/gemini.py
ADDED
@@ -0,0 +1,32 @@
import requests

GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"

def query_gemini_rest(signal_type, label, confidence, api_key):
    headers = {
        "Content-Type": "application/json",
        "X-goog-api-key": api_key,
    }

    prompt = (
        f"Explain the meaning of a {signal_type} signal classified as '{label}' "
        f"with a confidence of {confidence:.1%} in a medical diagnostic context."
    )

    payload = {
        "contents": [
            {
                "parts": [
                    {"text": prompt}
                ]
            }
        ]
    }

    try:
        # A timeout keeps the app from hanging if the API is unreachable.
        response = requests.post(GEMINI_ENDPOINT, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        content = response.json()
        return content["candidates"][0]["content"]["parts"][0]["text"]
    except Exception as e:
        return f"⚠️ Gemini API error: {str(e)}"
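A minimal sketch of a direct call, assuming a valid key in a GEMINI_API_KEY environment variable (the variable name is an assumption; this commit does not define it):

import os
from gemini import query_gemini_rest

key = os.getenv("GEMINI_API_KEY", "")
if key:
    print(query_gemini_rest("ECG", "Premature Ventricular Contraction (PVC)", 0.97, key))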
src/graph.py
ADDED
@@ -0,0 +1,111 @@
from __future__ import division, print_function
from keras.models import Model
from keras.layers import Input, Conv1D, Dense, add, Flatten, Dropout, MaxPooling1D, Activation, BatchNormalization, Lambda
from keras import backend as K
from keras.optimizers import Adam
from keras.saving import register_keras_serializable
import tensorflow as tf

@register_keras_serializable(package="custom")
def zeropad(x):
    """
    zeropad and zeropad_output_shape are from
    https://github.com/awni/ecg/blob/master/ecg/network.py
    """
    y = tf.zeros_like(x)
    return tf.concat([x, y], axis=2)

@register_keras_serializable(package="custom")
def zeropad_output_shape(input_shape):
    shape = list(input_shape)
    assert len(shape) == 3
    shape[2] *= 2
    return tuple(shape)


def ECG_model(config):
    """
    Implementation of the model in https://www.nature.com/articles/s41591-018-0268-3,
    with reference to code at
    https://github.com/awni/ecg/blob/master/ecg/network.py
    and
    https://github.com/fernandoandreotti/cinc-challenge2017/blob/master/deeplearn-approach/train_model.py
    """
    def first_conv_block(inputs, config):
        layer = Conv1D(filters=config.filter_length,
                       kernel_size=config.kernel_size,
                       padding='same',
                       strides=1,
                       kernel_initializer='he_normal')(inputs)
        layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)

        shortcut = MaxPooling1D(pool_size=1,
                                strides=1)(layer)

        layer = Conv1D(filters=config.filter_length,
                       kernel_size=config.kernel_size,
                       padding='same',
                       strides=1,
                       kernel_initializer='he_normal')(layer)
        layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)
        layer = Dropout(config.drop_rate)(layer)
        layer = Conv1D(filters=config.filter_length,
                       kernel_size=config.kernel_size,
                       padding='same',
                       strides=1,
                       kernel_initializer='he_normal')(layer)
        return add([shortcut, layer])

    def main_loop_blocks(layer, config):
        filter_length = config.filter_length
        n_blocks = 15
        for block_index in range(n_blocks):
            # Subsample every other block; the shortcut pools to match.
            subsample_length = 2 if block_index % 2 == 0 else 1
            shortcut = MaxPooling1D(pool_size=subsample_length)(layer)

            # Double the filters every fourth block and zero-pad the shortcut
            # channels so the residual addition still lines up.
            if block_index % 4 == 0 and block_index > 0:
                shortcut = Lambda(zeropad, output_shape=zeropad_output_shape)(shortcut)
                filter_length *= 2

            layer = BatchNormalization()(layer)
            layer = Activation('relu')(layer)
            layer = Conv1D(filters=filter_length,
                           kernel_size=config.kernel_size,
                           padding='same',
                           strides=subsample_length,
                           kernel_initializer='he_normal')(layer)
            layer = BatchNormalization()(layer)
            layer = Activation('relu')(layer)
            layer = Dropout(config.drop_rate)(layer)
            layer = Conv1D(filters=filter_length,
                           kernel_size=config.kernel_size,
                           padding='same',
                           strides=1,
                           kernel_initializer='he_normal')(layer)
            layer = add([shortcut, layer])
        return layer

    def output_block(layer, config):
        layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)
        layer = Flatten()(layer)
        outputs = Dense(len_classes, activation='softmax')(layer)
        model = Model(inputs=inputs, outputs=outputs)

        adam = Adam(learning_rate=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=False)
        model.compile(optimizer=adam,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        model.summary()
        return model

    classes = ['N', 'V', '/', 'A', 'F', '~']
    len_classes = len(classes)

    inputs = Input(shape=(config.input_size, 1), name='input')
    layer = first_conv_block(inputs, config)
    layer = main_loop_blocks(layer, config)
    return output_block(layer, config)
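ECG_model only reads filter_length, kernel_size, drop_rate and input_size from its config, so it can be smoke-tested with a stand-in namespace instead of the full argparse config — a sketch using the defaults from config.py:

from types import SimpleNamespace
from graph import ECG_model

cfg = SimpleNamespace(filter_length=32, kernel_size=16, drop_rate=0.2, input_size=256)
model = ECG_model(cfg)  # builds the 15-block residual network and prints its summary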
src/model_loader.py
ADDED
@@ -0,0 +1,45 @@
from keras.models import load_model
from graph import zeropad, zeropad_output_shape
from pathlib import Path
import joblib
from download_models import ensure_models_downloaded

def load_mitbih_model():
    ensure_models_downloaded()
    return load_model(
        "models/MLII-latest.keras",
        custom_objects={
            "zeropad": zeropad,
            "zeropad_output_shape": zeropad_output_shape
        },
        compile=False
    )

def load_pcg_model():
    ensure_models_downloaded()
    model_path = Path("models/pcg_model.h5")
    if not model_path.exists():
        raise FileNotFoundError(f"PCG model not found at {model_path.resolve()}")

    model = load_model(model_path, compile=False)
    model.compile()
    return model

def load_emg_model():
    ensure_models_downloaded()
    model_path = Path("models/emg_classifier_txt.h5")
    if not model_path.exists():
        raise FileNotFoundError(f"EMG model not found at {model_path.resolve()}")
    model = load_model(model_path, compile=False)
    model.compile()
    return model


def load_vag_model():
    ensure_models_downloaded()
    p = Path("models/vag_feature_classifier.pkl")
    if not p.exists():
        raise FileNotFoundError(f"No VAG model at {p.resolve()}")
    return joblib.load(p)
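Every loader calls ensure_models_downloaded() first, so a cold start only needs the URLs in the environment. A hypothetical startup snippet:

from model_loader import load_mitbih_model, load_pcg_model, load_emg_model, load_vag_model

ecg = load_mitbih_model()
pcg = load_pcg_model()
emg = load_emg_model()
vag = load_vag_model()  # a joblib bundle; util.py expects "model", "scaler", "encoder" keys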
src/util.py
ADDED
@@ -0,0 +1,213 @@
import os
import numpy as np
import pandas as pd
from scipy.signal import resample
from sklearn.preprocessing import scale
import soundfile as sf
from gemini import query_gemini_rest
import librosa
import tempfile


EXPECTED_LEN = 256
STEP = 128

PCG_LABELS = [
    "Normal",
    "Aortic Stenosis",
    "Mitral Stenosis",
    "Mitral Valve Prolapse",
    "Pericardial Murmurs"
]

LABELS_EMG = ["healthy", "myopathy", "neuropathy"]

def load_uploaded_file(file, signal_type="ECG") -> np.ndarray:
    name = file.name.lower()

    # ECG/EMG uploads are plain text: comma-separated or one value per line.
    if signal_type in ("ECG", "EMG"):
        text = file.read().decode("utf-8").strip()
        if "," in text:
            vals = [float(x) for x in text.split(",") if x.strip()]
        else:
            vals = [float(x) for x in text.splitlines() if x.strip()]
        return np.array(vals, dtype=np.float32)

    if signal_type == "VAG":
        if name.endswith(".csv"):
            df = pd.read_csv(file)
            features = [
                "rms_amplitude",
                "peak_frequency",
                "spectral_entropy",
                "zero_crossing_rate",
                "mean_frequency",
            ]
            return df[features].iloc[0].values.astype(np.float32)

        elif name.endswith(".npy"):
            return np.load(file)

        elif name.endswith(".wav"):
            data, _ = sf.read(file)
            return data.astype(np.float32)

        raise ValueError("Unsupported VAG file format.")

    if signal_type == "PCG" and name.endswith((".wav", ".flac", ".mp3")):
        data, _ = sf.read(file)
        if data.ndim > 1:
            data = data[:, 0]  # keep the first channel of multi-channel audio
        return data.astype(np.float32)

    raise ValueError("Unsupported file format.")


def preprocess_signal(x: np.ndarray) -> np.ndarray:
    if x.size != EXPECTED_LEN:
        x = resample(x, EXPECTED_LEN)
    return scale(x).astype(np.float32)


def segment_signal(raw: np.ndarray) -> np.ndarray:
    raw = preprocess_signal(raw)
    seg = raw.reshape(EXPECTED_LEN, 1)
    return seg[np.newaxis, ...]  # shape (1, EXPECTED_LEN, 1) for the model


PCG_INPUT_LEN = 995

def preprocess_pcg_waveform(wave: np.ndarray) -> np.ndarray:
    if wave.ndim > 1:
        wave = wave.mean(axis=1)  # mix down to mono

    # Pad or truncate to the model's fixed input length.
    if len(wave) < PCG_INPUT_LEN:
        wave = np.pad(wave, (0, PCG_INPUT_LEN - len(wave)))
    else:
        wave = wave[:PCG_INPUT_LEN]

    wave = (wave - np.mean(wave)) / (np.std(wave) + 1e-8)
    return wave.astype(np.float32)

def analyze_pcg_signal(file, model, gemini_key=None):
    signal, _ = sf.read(file)
    signal = preprocess_pcg_waveform(signal)

    input_data = signal.reshape(1, PCG_INPUT_LEN, 1)
    preds = model.predict(input_data, verbose=0)[0]

    idx = int(np.argmax(preds))
    confidence = float(preds[idx])
    label = PCG_LABELS[idx]

    gem_txt = None
    if gemini_key:
        gem_txt = query_gemini_rest("PCG", label, confidence, gemini_key)

    # label is returned twice to match the (label, human, conf, note)
    # tuple shape used by the other analyzers.
    return label, label, confidence, gem_txt


def pcg_to_features(file_obj, target_sr=16000, n_mels=128, n_frames=112):
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(file_obj.read())
        tmp_path = tmp.name

    y, sr = librosa.load(tmp_path, sr=target_sr, mono=True)
    os.remove(tmp_path)  # clean up the temp file once loaded

    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=512, hop_length=256, n_mels=n_mels)
    logmel = librosa.power_to_db(mel, ref=np.max)
    if logmel.shape[1] < n_frames:
        pad_width = n_frames - logmel.shape[1]
        pad = np.zeros((n_mels, pad_width))
        logmel = np.hstack((logmel, pad))
    else:
        logmel = logmel[:, :n_frames]

    feat = logmel.flatten().astype(np.float32)

    return feat[np.newaxis, ...]


def analyze_emg_signal(file, model, gemini_key=""):
    raw = load_uploaded_file(file, signal_type="EMG")

    WINDOW = 1000

    # Split into non-overlapping windows, z-scoring each one;
    # short recordings are zero-padded to a single window.
    wins = []
    if len(raw) < WINDOW:
        pad = np.pad(raw, (0, WINDOW - len(raw)))
        wins.append(((pad - pad.mean()) / (pad.std() + 1e-6)).reshape(WINDOW, 1))
    else:
        for i in range(0, len(raw) - WINDOW + 1, WINDOW):
            win = raw[i:i+WINDOW]
            win = (win - win.mean()) / (win.std() + 1e-6)
            wins.append(win.reshape(WINDOW, 1))
    X = np.array(wins, dtype=np.float32)

    preds = model.predict(X, verbose=0)
    classes = np.argmax(preds, axis=1)
    final = int(np.bincount(classes).argmax())  # majority vote across windows
    conf = float(preds[:, final].mean())
    human = LABELS_EMG[final]

    gemini_txt = None
    if gemini_key:
        gemini_txt = query_gemini_rest("EMG", human, conf, gemini_key)

    return human, conf, gemini_txt


FEATURE_COLS = [
    "rms_amplitude",
    "peak_frequency",
    "spectral_entropy",
    "zero_crossing_rate",
    "mean_frequency",
]

def vag_to_features(file_obj) -> np.ndarray:
    df = pd.read_csv(file_obj)
    x = df[FEATURE_COLS].iloc[0].values.astype(np.float32)
    return x.reshape(1, -1)


def predict_vag_from_features(file_obj, model_bundle, gemini_key=""):
    model = model_bundle["model"]
    scaler = model_bundle["scaler"]
    encoder = model_bundle["encoder"]

    x = vag_to_features(file_obj)
    x_s = scaler.transform(x)
    prob = model.predict_proba(x_s)[0]
    idx = int(np.argmax(prob))
    conf = float(prob[idx])
    label = encoder.inverse_transform([idx])[0].title()

    gem_note = (
        query_gemini_rest("VAG", label, conf, gemini_key)
        if gemini_key else None
    )
    return label, label, conf, gem_note
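segment_signal collapses any upload to a single 256-sample window (resampling when the length differs), so the classifier always sees a (1, 256, 1) batch — a quick shape check:

import numpy as np
from util import segment_signal

x = np.random.randn(1000).astype(np.float32)  # arbitrary-length input
print(segment_signal(x).shape)  # (1, 256, 1)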
src/vag_util.py
ADDED
@@ -0,0 +1,29 @@
# Standalone variant of util.predict_vag_from_features: takes a bare
# classifier with hard-coded labels instead of a model/scaler/encoder bundle.
import pandas as pd
import numpy as np

def predict_vag_from_features(file, model, gemini_key=None):
    df = pd.read_csv(file)
    required_features = [
        "rms_amplitude",
        "peak_frequency",
        "spectral_entropy",
        "zero_crossing_rate",
        "mean_frequency"
    ]

    x = df[required_features].values.astype(np.float32)
    preds = model.predict_proba(x)[0]
    idx = int(np.argmax(preds))
    confidence = float(preds[idx])

    labels = ["normal", "osteoarthritis", "ligament_injury"]
    label = labels[idx]

    gem_txt = None
    if gemini_key:
        from gemini import query_gemini_rest
        gem_txt = query_gemini_rest("VAG", label, confidence, gemini_key)

    return label, label, confidence, gem_txt
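Since this variant expects the raw classifier rather than the bundle returned by load_vag_model(), a caller would unpack it first — a sketch, assuming the pickle follows the bundle layout util.py relies on and a hypothetical feature CSV:

import joblib
from vag_util import predict_vag_from_features

bundle = joblib.load("models/vag_feature_classifier.pkl")  # assumed bundle layout
with open("vag_features.csv") as f:  # hypothetical feature CSV
    label, _, conf, note = predict_vag_from_features(f, bundle["model"])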