-
目的
:使用 MNIST 數(shù)據(jù)集,建立數(shù)字圖像識別模型,識別任意圖像中的數(shù)字;
1. 數(shù)據(jù)準備(MNIST)
-
MNIST
,一組由美國高中生和人口調(diào)查局員工手寫的 70000 個數(shù)字圖片;每張圖片都用其代表的數(shù)字標記;因廣泛被應(yīng)用于機器學(xué)習(xí)入門,被稱作機器學(xué)習(xí)領(lǐng)域的Hello World
;也可用于測試新分類算法的效果;
使用 Scikit-Learn
下載數(shù)據(jù)集的前置工作
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
Scikit-Learn
使用 Python 的 urllib
包通過 HTTPS
協(xié)議下載數(shù)據(jù)集,這里全局取消證書驗證(否則 Scikit-Learn
可能無法建立 ssl 連接);
使用 Scikit-Learn 下載 MNIST
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
mnist.keys()
dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])
### 查看數(shù)組
X, y = mnist["data"], mnist["target"]
X.shape
(70000, 784)
y.shape
(70000,)
共 70000 張圖片,每張圖片由 784 個特征(28 * 28 個像素,每個像素用 0(白色) 到 255(黑色) 表示);
Scikit-Learn 數(shù)據(jù)集通用字典結(jié)構(gòu)
-
DESCR
,描述數(shù)據(jù)集; -
data
,包含一個數(shù)組,每個實例為一行,每個特征為一列; -
target
,包含一個帶有標記的數(shù)組;
使用 Matplotlib 查看數(shù)字圖片
- 編寫繪圖函數(shù);
import matplotlib.pyplot as plt
import matplotlib as mpl
def plot_digit(data):
image = data.reshape(28, 28)
plt.imshow(image, cmap = mpl.cm.binary, interpolation="nearest")
plt.axis("off")
def plot_digits(instances, images_per_row=10, **options):
size = 28
images_per_row = min(len(instances), images_per_row)
# This is equivalent to n_rows = ceil(len(instances) / images_per_row):
n_rows = (len(instances) - 1) // images_per_row + 1
# Append empty images to fill the end of the grid, if needed:
n_empty = n_rows * images_per_row - len(instances)
padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)
# Reshape the array so it's organized as a grid containing 28×28 images:
image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))
# Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),
# and axes 1 and 3 (horizontal axes). We first need to move the axes that we
# want to combine next to each other, using transpose(), and only then we
# can reshape:
big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size, images_per_row * size)
# Now that we have a big image, we just need to show it:
plt.imshow(big_image, cmap = mpl.cm.binary, **options)
plt.axis("off")
- MNIST 的第一個圖片展示;
some_digit = X[:1].to_numpy()
plot_digit(some_digit)
plt.show()
# 查看圖片對應(yīng)標簽,驗證是一個數(shù)字 '5'
y[0]
'5'
- MNIST 的多圖樣例展示;
plt.figure(figsize=(9,9))
example_images = X[:100]
plot_digits(example_images, images_per_row=10)
# save_fig("more_digits_plot")
plt.show()
將字符標簽轉(zhuǎn)換成整數(shù)
import numpy as np
y = y.astype(np.uint8)
創(chuàng)建測試集
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
MNIST 數(shù)據(jù)集已經(jīng)分成訓(xùn)練集(前 6 萬張圖片)和測試集(最后 1 萬張圖片);
可以對訓(xùn)練集進行混洗,保障在做交叉驗證時所有折疊的實例分布相當;有一些算法對訓(xùn)練實例的順序敏感,連續(xù)輸入相同的實例可能導(dǎo)致性能不佳;也有一些情況時間序列也是實例特征(如股市架構(gòu)或天氣狀態(tài)),則不可混洗數(shù)據(jù)集;
2. 二元分類器(SGD)
-
二元分類器
,在兩個類中區(qū)分;
簡化問題,圖片數(shù)字識別,先從識別圖片 是 5
和 非 5
開始;
轉(zhuǎn)換圖片的標簽
y_train_5 = (y_train == 5) # True for all 5s, False for all other digits
y_test_5 = (y_test == 5)
使用 Scikit-Learn 的 SGDClassifier 訓(xùn)練隨機梯度下降(SGD)分類器
-
SGD
,獨立處理訓(xùn)練實例,一次一個,非常適合處理大型的數(shù)據(jù)集,也適合在線學(xué)習(xí);
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
給 random_state
設(shè)置固定值,如 =42
可以讓 SGD 的隨機訓(xùn)練變得結(jié)果可復(fù)現(xiàn);
sgd_clf.predict(X[:1])
array([ True])
SGD 分類器預(yù)測這是一張 5
,結(jié)果正確;
3. 性能測試
-
準確率
,正確預(yù)測的比率;
1. 交叉驗證
自定義實現(xiàn)交叉驗證
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
skfolds = StratifiedKFold(n_splits=3, random_state=42, shuffle=True)
for train_index, test_index in skfolds.split(X_train, y_train_5):
clone_clf = clone(sgd_clf)
X_train_folds = X_train.iloc[train_index]
y_train_folds = y_train_5.iloc[train_index]
X_test_fold = X_train.iloc[test_index]
y_test_fold = y_train_5.iloc[test_index]
clone_clf.fit(X_train_folds, y_train_folds)
y_pred = clone_clf.predict(X_test_fold)
n_correct = sum(y_pred == y_test_fold)
print(n_correct / len(y_pred))
0.9669
0.91625
0.96785
-
StratifiedKFold
,實現(xiàn)分層抽樣;讓每個折疊中各個類的比例與整體比例相當; -
clone
,為每個迭代創(chuàng)建一個分類器的副本,用于對訓(xùn)練集的訓(xùn)練和測試集的預(yù)測;
使用 Scikit-Learn 的 cross_val_score() 實現(xiàn) K-折交叉驗證
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
array([0.95035, 0.96035, 0.9604 ])
-
K-折交叉驗證
,將訓(xùn)練集分解成 K 個折疊(這里是 3 折),每次留 1 個折疊用于測試集,剩余用于訓(xùn)練集;
所有折疊交叉驗證的準確率都超過了 91%,這看似很準確,實則準確率不足以衡量這個分類器的優(yōu)劣;
自定義 非 5
分類器
from sklearn.base import BaseEstimator
class Never5Classifier(BaseEstimator):
def fit(self, X, y=None):
return self
def predict(self, X):
return np.zeros((len(X), 1), dtype=bool)
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")
array([0.91125, 0.90855, 0.90915])
使用自定義 非 5
分類器進行交叉驗證,得到所有折疊的準確率也在 90% 以上;這是因為所有圖片中只有約 10% 是數(shù)字 5,90% 非 5 是正確的;這進一步說明準確率不足以評判分類器的性能(特別是處理有偏數(shù)據(jù)集時);
2. 混淆矩陣
-
混淆矩陣
,對多個二分類或多分類進行訓(xùn)練/測試,統(tǒng)計 A 類實例被分類為 B 類別的次數(shù);是評估分類器性能的常見方法; -
使用 cross_val_predict() 進行 K-折交叉預(yù)測
from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
cross_val_predict 與 cross_val_score 類似,但返回的不是交叉驗證的評分,而是每個折疊的預(yù)測值;
- 使用 confusion_matrix() 獲取混淆矩陣
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5, y_train_pred)
array([[53892, 687],
[ 1891, 3530]])
混淆矩陣的行表示實際類別(實際為 非 5
、5
),列表示預(yù)測類別(預(yù)測為 非 5
、5
);
-
負類(Negative)
:實際為非 5
-
真負類(TN)
:53892 個正確
分類為非 5
; -
假正類(FP)
:687 個錯誤
分類為5
;
-
-
正類(Positive)
:實際為5
-
假負類(FN)
:1891 個錯誤
分類為非 5
; -
真正類(TP)
:3530 個正確
分類為5
;
-
完美的分類器只存在真正類與真負類,混淆矩陣的對角線(左上和右下)有非零值;
y_train_perfect_predictions = y_train_5 # pretend we reached perfection
confusion_matrix(y_train_5, y_train_perfect_predictions)
array([[54579, 0],
[ 0, 5421]])
3. 查準率與查全率
-
查準率(precision)
,真正類占真正類和假正類之和的比例;將忽略這個正類實例之外的所有內(nèi)容;
p r e c i s i o n = T P T P + F P precision = \frac{TP}{TP + FP} precision=TP+FPTP?
-
查全率(recall)
:召回率
,靈敏度
或真正類率
,真正類占所有正類(真正類和假負類)之和的比例;正確檢測到的正類實例的比率;
r e c a l l = T P T P + F N recall = \frac{TP}{TP + FN} recall=TP+FNTP?
使用 Scikit-Learn 計算查準率和查全率
from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_pred) # == 3530 / (3530 + 687)
0.8370879772350012
recall_score(y_train_5, y_train_pred) # == 3530 / (3530 + 1891)
0.6511713705958311
這說明,當這個 5-檢測器
說一張圖片是 5 時,只有 83% 時準確的,且只有 65% 的 5 被檢測出來了;
- F 1 F_1 F1? 分數(shù),查準率與查全率的諧波平均值,會給予低值更高的權(quán)重;更適用于查準率和查全率相近的分類器;
F 1 = 2 1 p r e c i s i o n + 1 r e c a l l = 2 × p r e c i s i o n × r e c a l l p r e c i s i o n + r e c a l l = T P T P = F N + F P 2 F_1 = \frac{2}{\frac{1}{precision} + \frac{1}{recall}} = 2 \times \frac{precision \times recall}{precision + recall} = \frac{TP}{TP = \frac{FN + FP}{2}} F1?=precision1?+recall1?2?=2×precision+recallprecision×recall?=TP=2FN+FP?TP?
使用 f1_score() 計算 F 1 F_1 F1? 分數(shù)
from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_pred)
0.7325171197343846
魚與熊掌不可得兼
,不能同時兼顧查準率和查全率;
-
對于
寧缺毋濫
類型的分類器,更在乎查準率(如給小孩子推薦視頻); -
對于
寧殺錯不放過
類型的分類器,更在乎查全率(如小區(qū)監(jiān)控抓小偷);
4. P-R 曲線
-
P-R 曲線
,將實例按預(yù)測為正類的概率高低排序,然后逐個把樣本作為正類進行預(yù)測評估,計算其查準率和查全率,以查全率為橫軸,查準率為縱軸繪制一個曲線圖;
SGDClassifier 的分類決策
基于決策函數(shù)計算處每個實例的分值;將每個實例按分數(shù)從低到高從左到右排列;取一個閾值,大于該閾值的實例為正類,否則為負類;(通常閾值越高,查全率越低,查準率越高);
- 若決策閾值在中間箭頭位置(兩個 5 之間),查準率為 80%(4/5),查全率為 67%(4/6);
- 若決策閾值在右邊箭頭位置(提升閾值),查準率為 100%(3/3),查全率為 50%(3/6);
- 若決策閾值在左邊箭頭位置(降低閾值),查準率為 75%(6/8),查全率為 100%(6/6);
使用 decision_function() 獲取每個實例的分數(shù)
y_scores = sgd_clf.decision_function(some_digit)
y_scores
array([2164.22030239])
- 通過閾值控制預(yù)測結(jié)果;
threshold = 0
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([ True])
- 提升閾值控制預(yù)測結(jié)果;
threshold = 8000
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([False])
提升閾值可以降低查全率(將本是 5 的圖片判定為了非 5
);
使用 cross_val_predict() 獲取訓(xùn)練集的實例分數(shù)
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method="decision_function")
使用 precision_recall_curve() 計算所有閾值對應(yīng)的查準率和查全率
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
繪制查準率和查全率與決策閾值的關(guān)系曲線
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
plt.legend(loc="center right", fontsize=16)
plt.xlabel("Threshold", fontsize=16)
plt.grid(True)
plt.axis([-50000, 50000, 0, 1])
recall_90_precision = recalls[np.argmax(precisions >= 0.90)]
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
plt.figure(figsize=(8, 4))
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.plot([threshold_90_precision, threshold_90_precision], [0., 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [0.9, 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [recall_90_precision, recall_90_precision], "r:")
plt.plot([threshold_90_precision], [0.9], "ro")
plt.plot([threshold_90_precision], [recall_90_precision], "ro")
plt.show()
查準率比查全率曲線要崎嶇一些,因為隨著閾值的提升,查準率可能會下降,但查全率只會下降;
繪制 P/R 曲線
以查全率為橫軸,查準率為縱軸,將上文決策閾值關(guān)系圖
轉(zhuǎn)化成一張 P-R 曲線
;
def plot_precision_vs_recall(precisions, recalls):
plt.plot(recalls, precisions, "b-", linewidth=2)
plt.xlabel("Recall", fontsize=16)
plt.ylabel("Precision", fontsize=16)
plt.axis([0, 1, 0, 1])
plt.grid(True)
plt.figure(figsize=(8, 6))
plot_precision_vs_recall(precisions, recalls)
plt.plot([recall_90_precision, recall_90_precision], [0., 0.9], "r:")
plt.plot([0.0, recall_90_precision], [0.9, 0.9], "r:")
plt.plot([recall_90_precision], [0.9], "ro")
plt.show()
查全率在 80% 之后,查準率急劇下降,說明可能需要在此之前選擇一個權(quán)衡點
;
通常若學(xué)習(xí)器 A 的 P-R 曲線能完全包住學(xué)習(xí)器 B 的,則可斷言 A 優(yōu)于 B;若存在交叉,可采用面積比較法,或平衡點比較法;
查找指定查準率/查全率的最低/最高閾值
>>> threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
3370.0194991439557 # 第一個 True 的最大索引
>>> threshold_90_recall = thresholds[np.argmin(recalls >= 0.90)]
-6861.032537940274 # 第一個 True 的最小索引
使用實例分數(shù)與閾值進行預(yù)測
>>> y_train_pred_90 = (y_scores >= threshold_90_precision)
array([False, False, False, ..., True, False, False])
- 查看預(yù)測的查準率與查全率;
>>> precision_score(y_train_5, y_train_pred_90)
0.9000345901072293
>>> recall_score(y_train_5, y_train_pred_90)
0.4799852425751706
查準率確實是指定的 90%;
5. ROC 曲線
-
ROC
(Receiver Operating Characteristic
,受試者工作特征
),以真正類率
為縱軸,以假正類率
為橫軸;描述的是查全率與(1 - 特異度)的關(guān)系;與 P-R 圖相似,若學(xué)習(xí)器 A 的曲線完全包住
學(xué)習(xí)器 B 的曲線,則可可斷言 A 優(yōu)于 B; -
真正類率
,查全率、靈敏度、召回率、True Positive Rate
,TPR
= T P T P + F N \frac{TP}{TP + FN} TP+FNTP?,所有正類中被測出來的正類的概率; -
假正類率
,False Positive Rate
,FPR
= F P T N + F P \frac{FP}{TN + FP} TN+FPFP?,所有負類中被錯認為正類的概率; -
真負類率
,TNR
,特異率
,正確被分類為負類的負類實例比率;
使用 roc_curve() 計算多種閾值的 TPR 和 FPR
from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
通過 Matplotlib 繪制 ROC 曲線
def plot_roc_curve(fpr, tpr, label=None):
plt.plot(fpr, tpr, linewidth=2, label=label)
plt.plot([0, 1], [0, 1], 'k--') # dashed diagonal
plt.axis([0, 1, 0, 1])
plt.xlabel('False Positive Rate (Fall-Out)', fontsize=16)
plt.ylabel('True Positive Rate (Recall)', fontsize=16)
plt.grid(True)
plt.figure(figsize=(8, 6))
plot_roc_curve(fpr, tpr)
fpr_90 = fpr[np.argmax(tpr >= recall_90_precision)]
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:")
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:")
plt.plot([fpr_90], [recall_90_precision], "ro")
plt.show()
召回率(TPR)越高,分類器的假正類(FPR)就越多(虛線表示純隨機分類器的 ROC 曲線,越高于虛線的 ROC 曲線,對應(yīng)的分類器越優(yōu));
使用 Scikit-Learn 計算 ROC 的 AUC
-
AUC
,Area Under ROC Curve
,ROC 曲線下的面積;當 ROC 曲線相交時,可通過 AUC 判定學(xué)習(xí)器的好壞;
from sklearn.metrics import roc_auc_score
>>> roc_auc_score(y_train_5, y_scores)
0.9604938554008616
這里 ROC AUC 分值看著很高,是因為正類(數(shù)字 5)比負類(非 5)的數(shù)量少很多;
P-R 曲線與 ROC 曲線的選擇
當正類非常少見或者更關(guān)注假正類而非假負類是,選擇 P-R 曲線;反之選擇 ROC 曲線;
6. RandomForestClassifier vs. SGDClassifier
RandomForestClassifier 沒有 decision_function(),代替的是 dict_proba();
-
dict_proba()
,返回一個數(shù)組,每行代表一個實例,每列表示一個類別,代表某個實例屬于某個給定類別的概率;
訓(xùn)練 RandomForestClassifier 分類器
from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method="predict_proba")
y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
這里將正類率
作為分數(shù)
傳遞給 roc_curve();
繪制 RandomForestClassifier 分類器的 ROC 曲線
plt.plot(fpr, tpr, "b:", label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.legend(loc="lower right")
plt.show()
RandomForestClassifier 的 ROC 曲線比 SGDClassifier 好很多;
# ROC AUC 分數(shù)
>>> roc_auc_score(y_train_5, y_scores_forest)
0.9983436731328145
# 查準率
y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)
>>> precision_score(y_train_5, y_train_pred_forest)
0.9905083315756169
# 查全率(召回率)
>>> recall_score(y_train_5, y_train_pred_forest)
0.8662608374838591
RandomForestClassifier 的效果確實好很多(查準率與查全率都比較高);
4. 多類分類器
-
多元分類器
,多項分類器,在兩個以上的類別中區(qū)分;
隨機森林、樸素貝葉斯等分類器可以直接處理多個類;支持向量機、線性分類器則是嚴格的二元分類器,但是可以通過一些策略讓二院分類器實現(xiàn)多分類的目的;
-
OvR
,一對剩余,一對多(one-versus-all),訓(xùn)練 10 個二元分類器(0-檢測器、1-檢測器、2-檢測器…),當需要檢測一張圖片時,先獲取每個分類器的決策分數(shù),哪個分類器的分值最高,圖片歸為哪一類; -
OvO
,一對一,訓(xùn)練 N × ( N ? 1 ) 2 \frac{N \times (N - 1)}{2} 2N×(N?1)? 個分類器,為每一對數(shù)字訓(xùn)練一個二元分類器(0-1 分類器、0-2 分類器、1-2 分類器…);優(yōu)點是,每個分類器只需要用到部分訓(xùn)練集對其必須區(qū)分的兩個類進行訓(xùn)練;
支持向量機在數(shù)據(jù)規(guī)模較大時表現(xiàn)較差,因此應(yīng)優(yōu)先選擇 OvO
策略,但對于大多數(shù)二分類器來書,OvR
是更好的選擇;
使用 Scikit-Learn 訓(xùn)練 SVM 分類器
>>> from sklearn.svm import SVC
>>> svm_clf = SVC()
>>> svm_clf.fit(X_train, y_train) # y_train, not y_train_5
>>> svm_clf.predict([some_digit])
array([5], dtype=uint8)
Scikit-Learn 檢測到嘗試使用二元分類算法進行多類分類任務(wù)時,會自動運行 OvR
或 OvO
;
這里 Scikit-Learn 實際訓(xùn)練了 45 個二元分類器,獲得它們對圖片的決策分數(shù),然后選擇了分數(shù)最高的類;
使用 decision_function() 查看 SVM 分類器的分數(shù)
>>> some_digit_scores = svm_clf.decision_function(some_digit)
>>> some_digit_scores
array([[ 1.72501977, 2.72809088, 7.2510018 , 8.3076379 , -0.31087254,
9.3132482 , 1.70975103, 2.76765202, 6.23049537, 4.84771048]])
查看分數(shù)最高的分類
>>> np.argmax(some_digit_scores)
5
>>> svm_clf.classes_
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8)
>>> svm_clf.classes_[5]
5
-
classes_
,存儲目標類的列表,按值的大小排序(索引與類值不一定相同);
強制使用 OneVsRestClassifier 策略訓(xùn)練 SVC 多類分類器
>>> from sklearn.multiclass import OneVsRestClassifier
>>> ovr_clf = OneVsRestClassifier(SVC())
>>> ovr_clf.fit(X_train, y_train)
>>> ovr_clf.predict(some_digit)
array([5], dtype=uint8)
>>> len(ovr_clf.estimators_)
10
-
OneVsRestClassifier
,OvR 策略實現(xiàn)類; -
OneVsOneClassifier
,OvO 策略實現(xiàn)類;
訓(xùn)練 SGDClassifier 的多類分類器
>>> sgd_clf.fit(X_train, y_train)
>>> sgd_clf.predict([some_digit])
array([3], dtype=uint8)
SGC 分類器可以直接將實例分為多個類,不必運行 OvR
或 OvO
;
使用 decision_function() 計算每個實例分類為每個類的概率
>>> sgd_clf.decision_function(some_digit)
array([[-31893.03095419, -34419.69069632, -9530.63950739,
1823.73154031, -22320.14822878, -1385.80478895,
-26188.91070951, -16147.51323997, -4604.35491274,
-12050.767298 ]])
第 3 類得分 1823,其他都是負分值(預(yù)測錯誤,實際是 5);
使用 scross_val_score() 評估 SGDClassifier 的準確性
>>> cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")
array([0.87365, 0.85835, 0.8689 ])
每個折疊的準確率在 85% 以上(隨機分類器準確率約為 10%);
通過縮放對 SGD 分離進行優(yōu)化
>>> from sklearn.preprocessing import StandardScaler
>>> scaler = StandardScaler()
>>> X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
>>> cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")
array([0.8983, 0.891 , 0.9018])
簡單縮放訓(xùn)練集數(shù)據(jù)后,準確率提升到 89%;
5. 誤差分析
使用 cross_val_predict() 進行預(yù)測并計算混淆矩陣
>>> y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
>>> conf_mx = confusion_matrix(y_train, y_train_pred)
>>> conf_mx
array([[5577, 0, 22, 5, 8, 43, 36, 6, 225, 1],
[ 0, 6400, 37, 24, 4, 44, 4, 7, 212, 10],
[ 27, 27, 5220, 92, 73, 27, 67, 36, 378, 11],
[ 22, 17, 117, 5227, 2, 203, 27, 40, 403, 73],
[ 12, 14, 41, 9, 5182, 12, 34, 27, 347, 164],
[ 27, 15, 30, 168, 53, 4444, 75, 14, 535, 60],
[ 30, 15, 42, 3, 44, 97, 5552, 3, 131, 1],
[ 21, 10, 51, 30, 49, 12, 3, 5684, 195, 210],
[ 17, 63, 48, 86, 3, 126, 25, 10, 5429, 44],
[ 25, 18, 30, 64, 118, 36, 1, 179, 371, 5107]])
使用 Matplotlib 的 matshow() 查看混淆矩陣
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()
大多數(shù)圖片被分到對角線上,說明它們被正確分類了;數(shù)字 5 略暗,說明可能數(shù)字 5 較少,也可能數(shù)字 5 的分類效果不如其他數(shù)字;
將混淆矩陣中的每個值除以相應(yīng)類中的圖片數(shù)量,這樣比較的就是錯誤率(而非錯誤的絕對值)
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums
重新繪制混淆矩陣效果圖
用 0 填充對角線,只看錯誤部分;
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()
每行代表實際類、每列代表預(yù)測類;
- 第 8 列比較亮,說明許多圖片被錯誤的分類為了 8;
- 改進數(shù)字 8 的分類錯誤,可以試著收集更多像數(shù)字 8 的訓(xùn)練數(shù)據(jù),以便分類器學(xué)會將它們與真實的數(shù)字 8 區(qū)分開;也可以開發(fā)一些新特征用來改進分類器(計算閉環(huán)的數(shù)量,如 8 有兩個、6 有一個、5 沒有);還可以對圖片進行預(yù)處理(Scikit-Image、Pillow、OpenCV 等),讓某些模式更為突出,如閉環(huán)等;
- 數(shù)字 3 和數(shù)字 5 經(jīng)常被混淆,兩個方向的交叉處較亮;
- 可以分析單個錯誤示例在做什么,為何失??;
查看數(shù)字 3 和數(shù)字 5
cl_a, cl_b = 3, 5
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]
plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
plt.show()
左側(cè)兩個 5 × 5 5 \times 5 5×5 矩陣顯示了唄分類為數(shù)字 3 的圖片,右側(cè)兩個 5 × 5 5 \times 5 5×5 矩陣顯示了被分類為數(shù)字 5 的圖片(左下和右上為分類錯誤示例);
SGD 是一個簡單的線性模型,它為每一個像素分配一個各個類別的權(quán)重,當它看到新圖片時,將加權(quán)后的 像素強度匯總,從而得到一個分數(shù)進行分類;而 3 和 5 的像素位大多重疊,因此容易混淆;
減少 3 和 5 之間混淆的方式可以是對圖片進行預(yù)處理,如確保他們在中心位置且沒有選擇;
6. 多標簽分類
-
多標簽分類
,分類器為每個實例輸出多個類(如一張圖片識別出多個人);
使用 KNeighborsClassifier 創(chuàng)建多標簽分類
-
KNeighborsClassifier
,支持多標簽分類,不是所有分類器都支持;
>>> from sklearn.neighbors import KNeighborsClassifier
>>> y_train_large = (y_train >= 7) # 大數(shù)標簽
>>> y_train_odd = (y_train % 2 == 1) # 奇數(shù)標簽
>>> y_multilabel = np.c_[y_train_large, y_train_odd] # 多標簽數(shù)組
>>> knn_clf = KNeighborsClassifier()
>>> knn_clf.fit(X_train, y_multilabel)
>>> knn_clf.predict(some_digit)
array([[False, True]])
分類正確:數(shù)字 5 不是大數(shù),是奇數(shù);
多標簽分類器的性能評估
>>> y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)
>>> f1_score(y_multilabel, y_train_knn_pred, average="macro")
0.976410265560605
假設(shè)所有標簽都同等重要,可以通過測量每個標簽的 F 1 F_1 F1? 分數(shù)(或其他任何二元分類器指標),并計算它們的平均分數(shù);
但實際往往并發(fā)如此,比如識別圖片中的多個人,其中有的人可能拍了很多照片,那這個人的權(quán)重就要高很多;這時需要給每個標簽設(shè)置一個相當?shù)臋?quán)重(可以是具有該目標標簽的實例的數(shù)量);
7. 多輸出分類
-
多輸出分類
,或稱多輸出多分類,是多標簽分類的泛化,其標簽也可以是多類的;
1. 消除圖片中的噪聲
目標:構(gòu)建一個系統(tǒng),輸入一張有噪聲的圖片,系統(tǒng)輸出一張干凈的數(shù)字圖片;
分類和回歸之間有時是模糊的,這個示例即可一說是多輸出分類任務(wù)
,也可以說是像素強度的回歸任務(wù)
;
使用 NumPy 的 randint() 為 MNIST 圖片添加噪聲
noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test
查看圖片樣例
plt.subplot(121)
plot_digit(X_test_mod[:1].to_numpy())
plt.subplot(122)
plot_digit(y_test_mod[:1].to_numpy())
plt.show()
通過訓(xùn)練分類器,清洗噪聲圖片
knn_clf.fit(X_train_mod, y_train_mod)
clean_digit = knn_clf.predict(X_test_mod[:1].to_numpy())
plot_digit(clean_digit)
清洗后的效果與原圖相近了!
- 上一篇:「ML 實踐篇」回歸系統(tǒng):房價中位數(shù)預(yù)測
- 下一篇:「ML 實踐篇」模型訓(xùn)練
PS:歡迎各路道友閱讀
與評論
,感謝道友點贊
、關(guān)注
、收藏
!文章來源:http://www.zghlxwxcb.cn/news/detail-401164.html
參考資料:文章來源地址http://www.zghlxwxcb.cn/news/detail-401164.html
- [1]《機器學(xué)習(xí)》
- [2]《機器學(xué)習(xí)實戰(zhàn)》
到了這里,關(guān)于「ML 實踐篇」分類系統(tǒng):圖片數(shù)字識別的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!