指标
此笔记本展示了 scikit-network 中的各种分类指标。
[1]:
from IPython.display import SVG
[2]:
import numpy as np
[3]:
from sknetwork.data import art_philo_science
from sknetwork.classification import get_accuracy_score, get_f1_scores, get_average_f1_score, get_confusion_matrix
from sknetwork.classification import DiffusionClassifier
from sknetwork.visualization import svg_graph
节点分类
[4]:
# Load the `art_philo_science` dataset with its metadata, then pull out the
# adjacency matrix, node positions, node names and ground-truth labels.
graph = art_philo_science(metadata=True)
adjacency, position, names, labels = (
    graph.adjacency,
    graph.position,
    graph.names,
    graph.labels,
)
[5]:
# Visualization: draw the graph with `svg_graph` (positions, names and true
# labels passed through) and display the resulting SVG inline.
image = svg_graph(adjacency, position=position, names=names, labels=labels)
SVG(image)
[5]:
[6]:
# Number of distinct label values in the ground truth
n_labels = len({*labels})
[7]:
# Classification by diffusion (classifier built with its default parameters)
diffusion = DiffusionClassifier()
[8]:
# Prediction: only nodes 0, 10, 15 and 20 are given their true labels as seeds;
# the classifier predicts a label for every node of the graph.
labels_pred = diffusion.fit_predict(adjacency, labels={node: labels[node] for node in (0, 10, 15, 20)})
准确率
[9]:
# Accuracy: share of nodes whose predicted label equals the true label
# (true labels first, predicted labels second).
# NOTE(review): every node is scored, including the seed nodes 0, 10, 15, 20
# whose true labels were given to the classifier, so this may be optimistic.
accuracy = get_accuracy_score(labels, labels_pred)
[10]:
# Show the accuracy rounded to 2 decimals
np.round(accuracy, 2)
[10]:
0.83
F1 分数
[11]:
# All F1 scores: one score per label (true labels first, predicted labels second)
f1_scores = get_f1_scores(labels, labels_pred)
[12]:
# Show the per-label F1 scores rounded to 2 decimals
np.round(f1_scores, 2)
[12]:
array([0.71, 1.  , 0.78])
[13]:
# Same F1 scores, but also returning the per-label precisions and recalls
f1_scores, precisions, recalls = get_f1_scores(labels, labels_pred, return_precision_recall=True)
[14]:
# Show the per-label precisions rounded to 2 decimals
np.round(precisions, 2)
[14]:
array([0.86, 1.  , 0.69])
[15]:
# Show the per-label recalls rounded to 2 decimals
np.round(recalls, 2)
[15]:
array([0.6, 1. , 0.9])
[16]:
# Average F1 score: a single number averaging the per-label F1 scores
f1_score = get_average_f1_score(labels, labels_pred)
[17]:
# Show the average F1 score rounded to 2 decimals
np.round(f1_score, 2)
[17]:
0.83
混淆矩阵
[18]:
# Confusion matrix between true and predicted labels.
# NOTE(review): presumably a sparse matrix, since it is densified with .toarray() to be displayed.
confusion = get_confusion_matrix(labels, labels_pred)
[19]:
# Dense view of the confusion matrix: true labels on rows, predicted labels on columns
confusion.toarray()
[19]:
array([[ 6,  0,  4],
       [ 0, 10,  0],
       [ 1,  0,  9]])