Python绘制混淆矩阵热力图
用matplotlib绘制混淆矩阵,可以通过改变 imshow 函数中的 cmap 参数来修改颜色。cmap 参数接受一个 colormap 的名字,你可以选择许多不同的 colormap,例如 ‘viridis’, ‘plasma’, ‘inferno’, ‘magma’, ‘cividis’, ‘cool’, ‘hot’ 等等。具体的 colormap 可以参考 matplotlib 的文档。
# 案例1:import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, Normalizeconf_arr = [[33, 2, 0, 0, 0, 0, 0, 0, 0, 1, 3],[3, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 4, 41, 0, 0, 0, 0, 0, 0, 0, 1],[0, 1, 0, 30, 0, 6, 0, 0, 0, 0, 1],[0, 0, 0, 0, 38, 10, 0, 0, 0, 0, 0],[0, 0, 0, 3, 1, 39, 0, 0, 0, 0, 4],[0, 2, 2, 0, 4, 1, 31, 0, 0, 0, 2],[0, 1, 0, 0, 0, 0, 0, 36, 0, 2, 0],[0, 0, 0, 0, 0, 0, 1, 5, 37, 5, 1],[3, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 38]]# 自定义渐变颜色
colors = ['#ffffff', '#ffcccc', '#ff6666', '#ff0000', '#990000']
custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)# 归一化对象,定义最小值和最大值
norm = Normalize(vmin=0, vmax=np.max(conf_arr))norm_conf = []
for i in conf_arr:a = 0tmp_arr = []a = sum(i, 0)for j in i:tmp_arr.append(float(j)/float(a))norm_conf.append(tmp_arr)fig = plt.figure()
plt.clf()
ax = fig.add_subplot(111)
ax.set_aspect(1)
res = ax.imshow(np.array(conf_arr), cmap=custom_cmap, norm=norm, interpolation='nearest')# 获取矩阵的宽度和高度
width, height = np.array(conf_arr).shape# 使用 range 替换 xrange
for x in range(width):for y in range(height):ax.annotate(str(conf_arr[x][y]), xy=(y, x),horizontalalignment='center',verticalalignment='center',color='black')cb = fig.colorbar(res)# 调整 alphabet 使其不会超出索引范围
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'[:width]
plt.xticks(range(width), alphabet)
plt.yticks(range(height), alphabet)plt.savefig('confusion_matrix.png', format='png')
plt.show()
效果图如下:
# 案例2:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns# 创建混淆矩阵数据
intra_inter = np.array([[84.9, 4.7, 7.3, 0.5, 0.5, 1.6,0.1,0.2],[4.0, 89.6, 0.6, 4.0, 1.7, 0.1,0.3,0.1],[6.7, 0.5, 85.7, 1.4, 4.3, 1.0,0.2,0.4],[2.3, 0.0, 5.7, 87.4, 4.6, 0.0,0.3,0.1],[1.0, 1.0, 4.9, 1.0, 92.2, 0.0,0.0,0.3],[0.4, 0.4, 2.2, 0.4, 0.9, 95.1,0.2,0.4],[0.6, 0.6, 0.6, 2.2, 5.0, 0.1,96.1,0.3],[2.4, 0.6, 0.6, 7.7, 3.6, 4.5,0.2,95.1]
])intra_inter_obj = np.array([[89.1, 1.6, 7.8, 0.5, 0.5, 1.0,0.1,0.2],[1.7, 93.1, 1.2, 3.5, 0.6, 0.0,0.1,0.2],[6.7, 1.0, 91.4, 0.5, 0.5, 1.0,0.1,0.2],[1.1, 1.1, 4.6, 94.3, 4.6, 0.0,0.1,0.2],[1.0, 1.0, 6.9, 1.0, 90.2, 0.0,0.1,0.2],[0.4, 0.4, 0.4, 0.4, 0.4, 95.1,0.1,0.2],[7.7, 3.6, 3.6, 7.7, 2.7, 0.1,93.1,0.1],[2.4, 0.6, 0.6, 7.7, 3.6, 4.5,0.2,95.1]
])labels = ["Right set", "Right spike", "Right pass", "Right winpoint", "Left winpoint", "Left pass", "Left spike", "Left set"]# 每个矩阵的维度
n_labels = len(labels)
cell_size = 1.5 # 每个小块的大小# 创建图形
fig, ax = plt.subplots(1, 2, figsize=(n_labels * cell_size * 2, n_labels * cell_size))# 绘制第一个混淆矩阵
sns.heatmap(intra_inter, annot=True, fmt=".1f", cmap="Blues", ax=ax[0], xticklabels=labels, yticklabels=labels, cbar=False)
ax[0].set_title("Intra+Inter")
ax[0].set_xlabel('')
ax[0].set_ylabel('')# 绘制第二个混淆矩阵
sns.heatmap(intra_inter_obj, annot=True, fmt=".1f", cmap="Blues", ax=ax[1], xticklabels=labels, yticklabels=labels, cbar=True)
ax[1].set_title("Intra+Inter+Object")
ax[1].set_xlabel('')
ax[1].set_ylabel('')# 调整底部标签的倾斜角度
plt.setp(ax[0].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.setp(ax[1].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")# 获取颜色条对象
cbar = ax[1].collections[0].colorbar# 调整颜色条的位置和大小
cbar.ax.set_position([0.92, ax[1].get_position().y0, 0.02, ax[1].get_position().height])# 调整布局
plt.tight_layout(rect=[0, 0, 0.9, 1])
plt.show()
效果图如下: