本ページでは以下のページを要約するとともに、個人的な解説も記載しています。独特な解釈をしている部分があるので、誤りなどの指摘はtwitterまでお願いします。
参考ページ
はじめに — scikit-learn 1.1.2 ドキュメント
https://scikit-learn.org/stable/tutorial/basic/tutorial.html
閲覧日:2022年10月8日
https://www.yosoaidol.com/p/scikit-learnpython.html
よりも少しだけ高度な内容になっています。
学習と予測
「手書き文字」のデータセットの場合のタスクは、与えられた画像(実際は8×8のそれぞれの画素の白黒の情報量を1次元配列で表している)から、それを表す数字を予測することです。
0から9のそれぞれのサンプルが学習データとして与えられ、AIモデルを作成し未知のデータを予測できるようにします。
scikit-learn ではfit(X, y)
で学習し、predict(T)
で予測を行います。
今回はサポートベクターマシンというアルゴリズムsklearn.svm.SVC
を使用し、分類を行います。
分類用のサポートベクターマシンは以下のコードで呼び出します。
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100.)
この例ではgammma
とC
を手動で設定します。細かい説明は省きますが、gamma
は決定境界の複雑さに影響します。大きいほど複雑な境界になります。C
は誤分類を許容するパラメータです。
今回は前回ロードしたdigits.data
を利用してAIモデルを作成してみましょう
再掲
from sklearn import datasets
digits.data
'''
array([[ 0., 0., 5., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 10., 0., 0.],
[ 0., 0., 0., ..., 16., 9., 0.],
...
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 2., ..., 12., 0., 0.],
[ 0., 0., 10., ..., 12., 1., 0.]])
'''
fit
メソッドを使用することで学習をすることができます。今回は、最後のデータ以外を使って学習し、最後のデータで予測を行います。
※[:-1]
で配列の最初から一番最後の一つ手前までのデータを取得しています。
clf.fit(digits.data[:-1], digits.target[:-1])
#SVC(C=100.0, gamma=0.001)
以下のコードで最後のデータが何の数値を表すのかを予測した結果を出力します。
clf.predict(digits.data[-1:])
#array([8])
「8」と予測されました。
実際の画像データは以下になります。
参考
以下のコードで画像に戻すことができます。
import numpy as np
import matplotlib.pyplot as plt
img = np.reshape(digits.data[-1:], (8,8))
plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
plt.axis('off')
plt.show()
あなたはこの予測結果が正しいと判断しますか?結局元の画像の解像度が低いため、うまく予測できているかどうかはわかりません。
本データセットのもう少し詳しい説明はこちら。Recognizing hand-written digits