niol08 committed
Commit a8e4c2f · verified · 1 Parent(s): 486ca74

Upload 8 files

Files changed (8)
  1. src/chatbot.py +30 -0
  2. src/config.py +49 -0
  3. src/download_models.py +57 -0
  4. src/gemini.py +32 -0
  5. src/graph.py +111 -0
  6. src/model_loader.py +45 -0
  7. src/util.py +213 -0
  8. src/vag_util.py +29 -0
src/chatbot.py ADDED
@@ -0,0 +1,30 @@
+
+ import numpy as np
+ from util import load_uploaded_file, segment_signal
+ from gemini import query_gemini_rest
+
+ CLASSES = ["N", "V", "/", "A", "F", "~"]
+ LABEL_MAP = {
+     "N": "Normal sinus beat",
+     "V": "Premature Ventricular Contraction (PVC)",
+     "/": "Paced beat (pacemaker)",
+     "A": "Atrial premature beat",
+     "F": "Fusion of ventricular & normal beat",
+     "~": "Unclassifiable / noise"
+ }
+
+ def analyze_signal(file, model, gemini_key="", signal_type="ECG"):
+     signal = load_uploaded_file(file, signal_type)
+     segments = segment_signal(signal)
+
+     preds = model.predict(segments, verbose=0)[0]
+     idx = int(np.argmax(preds))
+     conf = float(preds[idx])
+     label = CLASSES[idx]
+     human = LABEL_MAP[label]
+
+     gemini_txt = None
+     if gemini_key:
+         gemini_txt = query_gemini_rest(signal_type, human, conf, gemini_key)
+
+     return label, human, conf, gemini_txt
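A minimal usage sketch for analyze_signal (not part of this commit; the driver and the sample-file name are assumptions):

    # hypothetical driver — assumes a one-beat ECG sample in ecg_sample.txt
    from chatbot import analyze_signal
    from model_loader import load_mitbih_model

    model = load_mitbih_model()
    with open("ecg_sample.txt", "rb") as f:  # comma- or newline-separated samples
        label, human, conf, note = analyze_signal(f, model, signal_type="ECG")
    print(f"{label}: {human} ({conf:.1%})")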
src/config.py ADDED
@@ -0,0 +1,49 @@
+
+ import argparse
+
+ parser = argparse.ArgumentParser()
+
+ def add_argument_group(name):
+     arg = parser.add_argument_group(name)
+     return arg
+
+
+ misc_arg = add_argument_group('misc')
+ misc_arg.add_argument('--split', action=argparse.BooleanOptionalAction, default=True)  # type=bool parses any non-empty string as True, so use --split / --no-split
+ misc_arg.add_argument('--input_size', type=int, default=256,
+                       help='multiple of 256, as required by the structure of the model')
+ misc_arg.add_argument('--use_network', action=argparse.BooleanOptionalAction, default=False)
+
+ data_arg = add_argument_group('data')
+ data_arg.add_argument('--downloading', action=argparse.BooleanOptionalAction, default=False)
+
+ graph_arg = add_argument_group('graph')
+ graph_arg.add_argument('--filter_length', type=int, default=32)
+ graph_arg.add_argument('--kernel_size', type=int, default=16)
+ graph_arg.add_argument('--drop_rate', type=float, default=0.2)
+
+ train_arg = add_argument_group('train')
+ train_arg.add_argument('--feature', type=str, default="MLII",
+                        help='one of MLII, V1, V2, V4, V5; preferably MLII or V1')
+ train_arg.add_argument('--epochs', type=int, default=80)
+ train_arg.add_argument('--batch', type=int, default=256)
+ train_arg.add_argument('--patience', type=int, default=10)
+ train_arg.add_argument('--min_lr', type=float, default=0.00005)
+ train_arg.add_argument('--checkpoint_path', type=str, default=None)
+ train_arg.add_argument('--resume_epoch', type=int)
+ train_arg.add_argument('--ensemble', action=argparse.BooleanOptionalAction, default=False)
+ train_arg.add_argument('--trained_model', type=str, default=None,
+                        help='dir and filename of the trained model to use.')
+
+ predict_arg = add_argument_group('predict')
+ predict_arg.add_argument('--num', type=int, default=None)
+ predict_arg.add_argument('--upload', action=argparse.BooleanOptionalAction, default=False)
+ predict_arg.add_argument('--sample_rate', type=int, default=None)
+ predict_arg.add_argument('--cinc_download', action=argparse.BooleanOptionalAction, default=False)
+
+
+
+ def get_config():
+     config, unparsed = parser.parse_known_args()
+
+     return config
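get_config() uses parse_known_args, so unrecognized flags are ignored rather than raising; a quick sketch of how the config is consumed (CLI values are examples only):

    # example only — any subset of the flags above may be passed on the command line,
    # e.g. python train.py --feature MLII --epochs 80
    from config import get_config

    config = get_config()
    print(config.input_size, config.batch)  # defaults: 256, 256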
src/download_models.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import gdown
+ from dotenv import load_dotenv
+
+
+ load_dotenv()
+
+ def extract_file_id_from_url(url):
+     """Extract file ID from Google Drive URL"""
+     if "drive.google.com" in url:
+         if "/file/d/" in url:
+             return url.split("/file/d/")[1].split("/")[0]
+         elif "id=" in url:
+             return url.split("id=")[1].split("&")[0]
+     return url
+
+ def get_model_urls():
+     """Get model URLs from environment variables"""
+     return {
+         "../models/MLII-latest.keras": os.getenv("ECG_MODEL_URL", ""),
+         "../models/pcg_model.h5": os.getenv("PCG_MODEL_URL", ""),
+         "../models/emg_model.h5": os.getenv("EMG_MODEL_URL", ""),
+         "../models/vag_feature_classifier.pkl": os.getenv("VAG_MODEL_URL", "")
+     }
+
+ def download_from_gdrive(url, output_path):
+     """Download file from Google Drive using gdown"""
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+
+     file_id = extract_file_id_from_url(url)
+
+
+     download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
+     gdown.download(download_url, output_path, quiet=False)
+
+ def ensure_models_downloaded():
+     """Download models if they don't exist locally"""
+     model_urls = get_model_urls()
+
+     for local_path, url in model_urls.items():
+         if not url:
+             print(f"⚠️ No URL found for {local_path}")
+             continue
+
+         if not os.path.exists(local_path):
+             print(f"Downloading {local_path}...")
+             try:
+                 download_from_gdrive(url, local_path)
+                 print(f"✅ Downloaded {local_path}")
+             except Exception as e:
+                 print(f"❌ Failed to download {local_path}: {e}")
+         else:
+             print(f"✅ {local_path} already exists")
+
+ if __name__ == "__main__":
+     ensure_models_downloaded()
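The URLs come from a .env file read by python-dotenv; a sketch of the expected entries (the Drive links are placeholders, matching the two URL shapes extract_file_id_from_url accepts):

    # .env — placeholder values; real Google Drive share links go here
    ECG_MODEL_URL=https://drive.google.com/file/d/<file-id>/view
    PCG_MODEL_URL=https://drive.google.com/uc?id=<file-id>
    EMG_MODEL_URL=https://drive.google.com/file/d/<file-id>/view
    VAG_MODEL_URL=https://drive.google.com/file/d/<file-id>/view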
src/gemini.py ADDED
@@ -0,0 +1,32 @@
+ import requests
+
+ GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
+
+ def query_gemini_rest(signal_type, label, confidence, api_key):
+     headers = {
+         "Content-Type": "application/json",
+         "X-goog-api-key": api_key,
+     }
+
+     prompt = (
+         f"Explain the meaning of a {signal_type} signal classified as '{label}' "
+         f"with a confidence of {confidence:.1%} in a medical diagnostic context."
+     )
+
+     payload = {
+         "contents": [
+             {
+                 "parts": [
+                     {"text": prompt}
+                 ]
+             }
+         ]
+     }
+
+     try:
+         response = requests.post(GEMINI_ENDPOINT, headers=headers, json=payload, timeout=30)  # timeout so a stalled connection can't hang the app
+         response.raise_for_status()
+         content = response.json()
+         return content["candidates"][0]["content"]["parts"][0]["text"]
+     except Exception as e:
+         return f"⚠️ Gemini API error: {str(e)}"
src/graph.py ADDED
@@ -0,0 +1,111 @@
+ from __future__ import division, print_function
+ from keras.models import Model
+ from keras.layers import Input, Conv1D, Dense, add, Flatten, Dropout, MaxPooling1D, Activation, BatchNormalization, Lambda
+ from keras import backend as K
+ from keras.optimizers import Adam
+ from keras.saving import register_keras_serializable
+ import tensorflow as tf
+
+ @register_keras_serializable(package="custom")
+ def zeropad(x):
+     """
+     zeropad and zeropad_output_shape are from
+     https://github.com/awni/ecg/blob/master/ecg/network.py
+     """
+     y = tf.zeros_like(x)
+     return tf.concat([x, y], axis=2)
+
+ @register_keras_serializable(package="custom")
+ def zeropad_output_shape(input_shape):
+     shape = list(input_shape)
+     assert len(shape) == 3
+     shape[2] *= 2
+     return tuple(shape)
+
+
+ def ECG_model(config):
+     """
+     Implementation of the model in https://www.nature.com/articles/s41591-018-0268-3,
+     with reference to the code at
+     https://github.com/awni/ecg/blob/master/ecg/network.py
+     and
+     https://github.com/fernandoandreotti/cinc-challenge2017/blob/master/deeplearn-approach/train_model.py
+     """
+     def first_conv_block(inputs, config):
+         layer = Conv1D(filters=config.filter_length,
+                        kernel_size=config.kernel_size,
+                        padding='same',
+                        strides=1,
+                        kernel_initializer='he_normal')(inputs)
+         layer = BatchNormalization()(layer)
+         layer = Activation('relu')(layer)
+
+         shortcut = MaxPooling1D(pool_size=1,
+                                 strides=1)(layer)
+
+         layer = Conv1D(filters=config.filter_length,
+                        kernel_size=config.kernel_size,
+                        padding='same',
+                        strides=1,
+                        kernel_initializer='he_normal')(layer)
+         layer = BatchNormalization()(layer)
+         layer = Activation('relu')(layer)
+         layer = Dropout(config.drop_rate)(layer)
+         layer = Conv1D(filters=config.filter_length,
+                        kernel_size=config.kernel_size,
+                        padding='same',
+                        strides=1,
+                        kernel_initializer='he_normal')(layer)
+         return add([shortcut, layer])
+
+     def main_loop_blocks(layer, config):
+         filter_length = config.filter_length
+         n_blocks = 15
+         for block_index in range(n_blocks):
+
+             subsample_length = 2 if block_index % 2 == 0 else 1
+             shortcut = MaxPooling1D(pool_size=subsample_length)(layer)
+
+             if block_index % 4 == 0 and block_index > 0:
+                 shortcut = Lambda(zeropad, output_shape=zeropad_output_shape)(shortcut)
+                 filter_length *= 2
+
+             layer = BatchNormalization()(layer)
+             layer = Activation('relu')(layer)
+             layer = Conv1D(filters=filter_length,
+                            kernel_size=config.kernel_size,
+                            padding='same',
+                            strides=subsample_length,
+                            kernel_initializer='he_normal')(layer)
+             layer = BatchNormalization()(layer)
+             layer = Activation('relu')(layer)
+             layer = Dropout(config.drop_rate)(layer)
+             layer = Conv1D(filters=filter_length,
+                            kernel_size=config.kernel_size,
+                            padding='same',
+                            strides=1,
+                            kernel_initializer='he_normal')(layer)
+             layer = add([shortcut, layer])
+         return layer
+
+     def output_block(layer, config):
+         layer = BatchNormalization()(layer)
+         layer = Activation('relu')(layer)
+         layer = Flatten()(layer)
+         outputs = Dense(len_classes, activation='softmax')(layer)
+         model = Model(inputs=inputs, outputs=outputs)
+
+         adam = Adam(learning_rate=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=False)
+         model.compile(optimizer=adam,
+                       loss='categorical_crossentropy',
+                       metrics=['accuracy'])
+         model.summary()
+         return model
+
+     classes = ['N', 'V', '/', 'A', 'F', '~']
+     len_classes = len(classes)
+
+     inputs = Input(shape=(config.input_size, 1), name='input')
+     layer = first_conv_block(inputs, config)
+     layer = main_loop_blocks(layer, config)
+     return output_block(layer, config)
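ECG_model stacks a first conv block, fifteen residual blocks (temporal subsampling every other block, filter count doubled with a zero-padded shortcut every fourth), and a softmax head over the six beat classes. A build sketch, assuming the defaults from config.py:

    # builds and summarizes the network; expects inputs of shape (batch, 256, 1)
    from config import get_config
    from graph import ECG_model

    model = ECG_model(get_config())  # filter_length=32, kernel_size=16, input_size=256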
src/model_loader.py ADDED
@@ -0,0 +1,45 @@
+
+ from keras.models import load_model
+ from graph import zeropad, zeropad_output_shape
+ from pathlib import Path
+ import joblib
+ from download_models import ensure_models_downloaded
+
+ def load_mitbih_model():
+     ensure_models_downloaded()
+     return load_model(
+         "models/MLII-latest.keras",
+         custom_objects={
+             "zeropad": zeropad,
+             "zeropad_output_shape": zeropad_output_shape
+         },
+         compile=False
+     )
+
+ def load_pcg_model():
+     ensure_models_downloaded()
+     model_path = Path("models/pcg_model.h5")
+     if not model_path.exists():
+         raise FileNotFoundError(f"PCG model not found at {model_path.resolve()}")
+
+     model = load_model(model_path, compile=False)
+     model.compile()
+     return model
+
+ def load_emg_model():
+     ensure_models_downloaded()
+     model_path = Path("models/emg_classifier_txt.h5")
+     if not model_path.exists():
+         raise FileNotFoundError(f"EMG model not found at {model_path.resolve()}")
+     model = load_model(model_path, compile=False)
+     model.compile()
+     return model
+
+
+ def load_vag_model():
+     ensure_models_downloaded()
+     p = Path("models/vag_feature_classifier.pkl")
+     if not p.exists():
+         raise FileNotFoundError(f"No VAG model at {p.resolve()}")
+     return joblib.load(p)
+
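A quick smoke-test sketch (assumes models/ is resolvable from the current working directory; note that download_models.py writes to ../models/, so whether both point at the same folder depends on where the app is launched):

    # hypothetical check that the ECG model loads with its custom layers
    from model_loader import load_mitbih_model

    model = load_mitbih_model()  # runs ensure_models_downloaded() first
    print(model.input_shape)     # expected (None, 256, 1) per graph.py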
src/util.py ADDED
@@ -0,0 +1,213 @@
+
+ import numpy as np
+ import pandas as pd
+ from scipy.signal import resample
+ from sklearn.preprocessing import scale
+ import soundfile as sf
+ from gemini import query_gemini_rest
+ import librosa
+ import tempfile
+
+
+ EXPECTED_LEN = 256
+ STEP = 128
+
+ PCG_LABELS = [
+     "Normal",
+     "Aortic Stenosis",
+     "Mitral Stenosis",
+     "Mitral Valve Prolapse",
+     "Pericardial Murmurs"
+ ]
+
+ LABELS_EMG = ["healthy", "myopathy", "neuropathy"]
+
+ def load_uploaded_file(file, signal_type="ECG") -> np.ndarray:
+     name = file.name.lower()
+
+
+     if signal_type in ("ECG", "EMG"):
+         text = file.read().decode("utf-8").strip()
+         if "," in text:
+             vals = [float(x) for x in text.split(",") if x.strip()]
+         else:
+             vals = [float(x) for x in text.splitlines() if x.strip()]
+         return np.array(vals, dtype=np.float32)
+
+
+     if signal_type == "VAG":
+         if name.endswith(".csv"):
+             df = pd.read_csv(file)
+             features = [
+                 "rms_amplitude",
+                 "peak_frequency",
+                 "spectral_entropy",
+                 "zero_crossing_rate",
+                 "mean_frequency",
+             ]
+             return df[features].iloc[0].values.astype(np.float32)
+
+         elif name.endswith(".npy"):
+             return np.load(file)
+
+         elif name.endswith(".wav"):
+             data, _ = sf.read(file)
+             return data.astype(np.float32)
+
+         raise ValueError("Unsupported VAG file format.")
+
+
+     if signal_type == "PCG" and name.endswith((".wav", ".flac", ".mp3")):
+         data, _ = sf.read(file)
+         if data.ndim > 1:
+             data = data[:, 0]
+         return data.astype(np.float32)
+
+     raise ValueError("Unsupported file format.")
+
+
+ def preprocess_signal(x: np.ndarray) -> np.ndarray:
+     if x.size != EXPECTED_LEN:
+         x = resample(x, EXPECTED_LEN)
+     return scale(x).astype(np.float32)
+
+
+ def segment_signal(raw: np.ndarray) -> np.ndarray:
+     raw = preprocess_signal(raw)
+     seg = raw.reshape(EXPECTED_LEN, 1)
+     return seg[np.newaxis, ...]
+
+
+
+ PCG_INPUT_LEN = 995
+
+ def preprocess_pcg_waveform(wave: np.ndarray) -> np.ndarray:
+
+     if wave.ndim > 1:
+         wave = wave.mean(axis=1)
+
+
+     if len(wave) < PCG_INPUT_LEN:
+         wave = np.pad(wave, (0, PCG_INPUT_LEN - len(wave)))
+     else:
+         wave = wave[:PCG_INPUT_LEN]
+
+
+     wave = (wave - np.mean(wave)) / (np.std(wave) + 1e-8)
+     return wave.astype(np.float32)
+
+ def analyze_pcg_signal(file, model, gemini_key=None):
+
+     signal, _ = sf.read(file)
+     signal = preprocess_pcg_waveform(signal)
+
+     input_data = signal.reshape(1, PCG_INPUT_LEN, 1)
+     preds = model.predict(input_data, verbose=0)[0]
+
+     labels = [
+         "Normal",
+         "Aortic Stenosis",
+         "Mitral Stenosis",
+         "Mitral Valve Prolapse",
+         "Pericardial Murmurs",
+     ]
+     idx = int(np.argmax(preds))
+     confidence = float(preds[idx])
+     label = labels[idx]
+
+     gem_txt = None
+     if gemini_key:
+         gem_txt = query_gemini_rest("PCG", label, confidence, gemini_key)
+
+     return label, label, confidence, gem_txt
+
+
+
+
+ def pcg_to_features(file_obj, target_sr=16000, n_mels=128, n_frames=112):
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+         tmp.write(file_obj.read())
+         tmp_path = tmp.name
+
+
+     y, sr = librosa.load(tmp_path, sr=target_sr, mono=True)
+
+
+     mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=512, hop_length=256, n_mels=n_mels)
+     logmel = librosa.power_to_db(mel, ref=np.max)
+     if logmel.shape[1] < n_frames:
+
+         pad_width = n_frames - logmel.shape[1]
+         pad = np.zeros((n_mels, pad_width))
+         logmel = np.hstack((logmel, pad))
+     else:
+         logmel = logmel[:, :n_frames]
+
+
+     feat = logmel.flatten().astype(np.float32)
+
+     return feat[np.newaxis, ...]
+
+
+ def analyze_emg_signal(file, model, gemini_key=""):
+     raw = load_uploaded_file(file, signal_type="EMG")
+
+     WINDOW = 1000
+
+     wins = []
+     if len(raw) < WINDOW:
+         pad = np.pad(raw, (0, WINDOW - len(raw)))
+         wins.append(((pad - pad.mean()) / (pad.std() + 1e-6)).reshape(WINDOW, 1))
+     else:
+         for i in range(0, len(raw) - WINDOW + 1, WINDOW):
+             win = raw[i:i+WINDOW]
+             win = (win - win.mean()) / (win.std() + 1e-6)
+             wins.append(win.reshape(WINDOW, 1))
+     X = np.array(wins, dtype=np.float32)
+
+     preds = model.predict(X, verbose=0)
+     classes = np.argmax(preds, axis=1)
+     final = int(np.bincount(classes).argmax())
+     conf = float(preds[:, final].mean())
+     human = LABELS_EMG[final]
+
+     gemini_txt = None
+     if gemini_key:
+         gemini_txt = query_gemini_rest("EMG", human, conf, gemini_key)
+
+     return human, conf, gemini_txt
+
+
+
+ FEATURE_COLS = [
+     "rms_amplitude",
+     "peak_frequency",
+     "spectral_entropy",
+     "zero_crossing_rate",
+     "mean_frequency",
+ ]
+
+ def vag_to_features(file_obj) -> np.ndarray:
+     df = pd.read_csv(file_obj)
+     x = df[FEATURE_COLS].iloc[0].values.astype(np.float32)
+     return x.reshape(1, -1)
+
+
+ def predict_vag_from_features(file_obj, model_bundle, gemini_key=""):
+     model = model_bundle["model"]
+     scaler = model_bundle["scaler"]
+     encoder = model_bundle["encoder"]
+
+     x = vag_to_features(file_obj)
+     x_s = scaler.transform(x)
+     prob = model.predict_proba(x_s)[0]
+     idx = int(np.argmax(prob))
+     conf = float(prob[idx])
+     label = encoder.inverse_transform([idx])[0].title()
+
+     gem_note = (
+         query_gemini_rest("VAG", label, conf, gemini_key)
+         if gemini_key else None
+     )
+     return label, label, conf, gem_note
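vag_to_features and predict_vag_from_features expect a single-row CSV containing exactly the five FEATURE_COLS; a conforming file might look like this (values are placeholders):

    rms_amplitude,peak_frequency,spectral_entropy,zero_crossing_rate,mean_frequency
    0.42,118.5,0.87,0.031,96.2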
src/vag_util.py ADDED
@@ -0,0 +1,29 @@
+
+
+ import pandas as pd
+ import numpy as np
+
+ def predict_vag_from_features(file, model, gemini_key=None):
+     df = pd.read_csv(file)
+     required_features = [
+         "rms_amplitude",
+         "peak_frequency",
+         "spectral_entropy",
+         "zero_crossing_rate",
+         "mean_frequency"
+     ]
+
+     x = df[required_features].values.astype(np.float32)
+     preds = model.predict_proba(x)[0]
+     idx = int(np.argmax(preds))
+     confidence = float(preds[idx])
+
+     labels = ["normal", "osteoarthritis", "ligament_injury"]
+     label = labels[idx]
+
+     gem_txt = None
+     if gemini_key:
+         from gemini import query_gemini_rest
+         gem_txt = query_gemini_rest("VAG", label, confidence, gemini_key)
+
+     return label, label, confidence, gem_txt