-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_ATI_curve_classification.py
More file actions
31 lines (26 loc) · 1.4 KB
/
plot_ATI_curve_classification.py
File metadata and controls
31 lines (26 loc) · 1.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
from sklearn.metrics import mean_absolute_error,f1_score
import matplotlib.pyplot as plt
# Report the weighted F1 score of the saved predictions on progressively
# larger subsets of the test set: for each threshold t in 0.1 ... 0.9, keep
# only the segments whose fraction of low-quality entries is strictly below t,
# then finish with the score on the full test set (reported as threshold 1.0).
data_folder = 'G:/Downstream_task/AF/'
quality_index = np.load(data_folder + 'test_quality_index2.npy')
true_label = np.load(data_folder + 'stanford_test_label.npy')
predicted_label = np.load(data_folder + 'predicted_labels2.npy')

# Build the thresholds by dividing integers rather than with
# np.arange(0.1, 1, 0.1): the float-step form accumulates rounding error
# (0.30000000000000004, 0.7000000000000001, ...), which lets a segment sitting
# exactly on a boundary slip into the bin despite the strict "<" test and also
# makes the dictionary keys impossible to look up with plain literals.
thresholds = np.arange(1, 10) / 10

# Fraction of entries in each segment's quality trace that are <= 0.5; the
# script treats those entries as artifacts. Computed once per segment because
# it does not depend on the threshold.
artifact_fraction = np.array(
    [np.sum(quality <= 0.5) / len(quality) for quality in quality_index]
)

# quality_bins[t] = [true labels, predicted labels] of the segments admitted
# at threshold t.
quality_bins = {}
for threshold in thresholds:
    # Integer positions (not a boolean mask) so the label arrays are indexed
    # exactly as before, one entry per row of quality_index.
    selected = np.flatnonzero(artifact_fraction < threshold)
    quality_bins[threshold] = [true_label[selected], predicted_label[selected]]
    # Output format is intentionally unchanged from the original script.
    print('threshold: %f sample number: %f F1 Score: %f'%(threshold,len(quality_bins[threshold][0]), f1_score(quality_bins[threshold][0], quality_bins[threshold][1], average='weighted')))

# Baseline: every test segment, regardless of quality.
print('threshold: %f sample number: %f F1 Score: %f'%(1.0,len(true_label),f1_score(true_label, predicted_label, average='weighted')))