scikit-learn 提供的繪圖工具

2023-02-20 14:33 更新

Scikit-learn定義了一個簡單的API,創(chuàng)建用于機器學習的可視化對象。該API的特點是無需重新計算即可進行快速繪圖和視覺調整。在以下示例中,我們繪制了利用支持向量機算法產(chǎn)生的ROC曲線:

from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_wine

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)

svc_disp = plot_roc_curve(svc, X_test, y_test)


返回的svc_disp對象使我們可以在以后的圖中繼續(xù)使用已經(jīng)計算出的SVC的ROC曲線。在本例中,svc_disp是一個 RocCurveDisplay,它將計算得到的值儲存到稱作roc_auc,fpr,和tpr的屬性中。接下來,我們訓練一個隨機森林分類器,并使用Display對象的plot 方法再次繪制先前計算的roc曲線。

import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier

rfc = RandomForestClassifier(random_state=42)
rfc.fit(X_train, y_train)

ax = plt.gca()
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)


請注意,我們傳遞alpha=0.8給繪圖函數(shù)來調整曲線的透明度。

例子:
帶有可視化API的ROC曲線
局部依賴的高級繪圖
顯示對象的可視化

5.1.1 函數(shù)

inspection.plot_partial_dependence(…[, …]) 部分依賴圖。
metrics.plot_confusion_matrix(estimator, X, …) 繪制混淆矩陣。
metrics.plot_precision_recall_curve(…[, …]) 繪制二元分類器的精確度、召回率曲線。
metrics.plot_roc_curve(estimator, X, y, *) 繪制受試者工作特性(ROC)曲線。

5.1.2 可視化對象

inspection.PartialDependenceDisplay(…) 部分依賴圖(PDP)可視化。
metrics.ConfusionMatrixDisplay(…[, …]) 混淆矩陣可視化。
metrics.PrecisionRecallDisplay(precision, …) 精確度、召回率可視化。
metrics.RocCurveDisplay(*, fpr, tpr[, …]) ROC曲線可視化。


以上內容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號