# Hugging Face Space file: test / app.py (author: mskov)
# Last commit: "Update app.py" (091bc92), 2.01 kB
import transformers
from transformers import pipeline
import gradio as gr
import pandas
import matplotlib.pyplot as plt
import os
import sys
# Runtime dependency bootstrap for the hosting environment.
# NOTE(review): matplotlib is already imported above this point, so upgrading
# it here only affects later runs of this script, not the current process.
import subprocess


def _pip_install(*args):
    """Best-effort ``pip install *args`` into the interpreter running this file.

    Uses ``sys.executable -m pip`` instead of a bare ``pip`` on PATH (which may
    belong to a different Python), and passes an argument list so no shell is
    involved. Like the original ``os.system`` calls, a failed install does not
    abort start-up. Unlike them, a non-zero exit status is reported on stderr
    instead of being silently ignored.
    """
    cmd = [sys.executable, "-m", "pip", "install", *args]
    returncode = subprocess.run(cmd, check=False).returncode
    if returncode != 0:
        print(
            f"warning: {' '.join(cmd)} exited with status {returncode}",
            file=sys.stderr,
        )


# Upgrade before the sklearn imports below so they load the upgraded version.
_pip_install("-U", "scikit-learn", "scipy", "matplotlib")
from sklearn import model_selection
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC

# Apply the numba pin before installing/importing whisper. In the original
# the pin ran after `import whisper`, so it could not change what that import
# had already loaded in this process.
# NOTE(review): confirm 0.53 is still the intended pin; it may conflict with
# whisper's own requirements or the current Python version.
_pip_install("numba==0.53")
_pip_install("git+https://github.com/openai/whisper.git")
import whisper
# `load_dataset` was called below but never imported (NameError at start-up).
from datasets import load_dataset

# The original read the CSV from `url`, which is never defined in this file.
# Resolve the location first, so a missing setting fails immediately instead
# of after the model downloads below.
url = os.environ.get("MISO_CSV_URL")
if not url:
    raise RuntimeError(
        "Set MISO_CSV_URL to the path or URL of the CSV with columns "
        "path, file_name, category"
    )

# transformers pipelines for the Hub models mskov/whisper_esc50 and
# mskov/whisper_miso.
whisper_esc50 = pipeline(model="mskov/whisper_esc50")
whisper_miso = pipeline(model="mskov/whisper_miso")
# Checkpoints loaded through the `whisper` package.
whisper_tiny = whisper.load_model("tiny")
whisper_base = whisper.load_model("base")

# NOTE(review): `dataset` is not used anywhere else in this file; the CSV
# below is what actually feeds the evaluation.
dataset = load_dataset("mskov/miso_test")

names = ['path', 'file_name', 'category']
dataframe = pandas.read_csv(url, names=names)
array = dataframe.values
# Features: the first two columns (path, file_name); target: category.
# NOTE(review): both feature columns look like strings, and the scikit-learn
# estimators used later need numeric features. Confirm, and add feature
# extraction if so.
X = array[:, 0:2]
Y = array[:, 2]
# Cross-validation test harness: every entry of `models` is a
# (label, estimator) pair scored on the same 10-fold split of X / Y.
seed = 7

# NOTE: the Whisper models (whisper_esc50, whisper_miso, whisper_tiny,
# whisper_base) are deliberately not listed. The original put the bare objects
# in this list, which breaks the `for name, model in models` unpacking.
# cross_val_score also needs scikit-learn estimators it can clone and fit,
# which transformers pipelines and whisper models are not.
models = [
    ('LR', LogisticRegression()),
    ('LDA', LinearDiscriminantAnalysis()),
    ('KNN', KNeighborsClassifier()),
    ('CART', DecisionTreeClassifier()),
    ('NB', GaussianNB()),
    ('SVM', SVC()),
]

# evaluate each model in turn
results = []
names = []
scoring = 'accuracy'
# random_state only has an effect with shuffle=True. Current scikit-learn
# raises ValueError for random_state without shuffle. With an int seed every
# model is scored on identical folds, so one KFold is built outside the loop.
kfold = model_selection.KFold(n_splits=10, shuffle=True, random_state=seed)
for name, model in models:
    cv_results = model_selection.cross_val_score(
        model, X, Y, cv=kfold, scoring=scoring
    )
    results.append(cv_results)
    names.append(name)
    # Same output as the original "%s: %f (%f)" format: label, mean (std).
    print(f"{name}: {cv_results.mean():f} ({cv_results.std():f})")
# Box plot of the per-fold scores collected in `results`, one box per model,
# labelled with the matching entries of `names`.
fig, ax = plt.subplots()
fig.suptitle('Algorithm Comparison')
ax.boxplot(results)
ax.set_xticklabels(names)
plt.show()