【scikit-learn】機械学習チュートリアル-3【Python】

本ページでは以下のページを要約するとともに、個人的な解説も記載しています。独特な解釈をしている部分があるので、誤りなどの指摘はtwitterまでお願いします。

参考ページ
はじめに — scikit-learn 1.1.2 ドキュメント
https://scikit-learn.org/stable/tutorial/basic/tutorial.html
閲覧日:2022年10月8日

https://www.yosoaidol.com/p/scikit-learnpython.html
よりも少しだけ高度な内容になっています。

「機械学習チュートリアル-2」はこちら

学習と予測

「手書き文字」のデータセットの場合のタスクは、与えられた画像(実際は8×8のそれぞれの画素の白黒の情報量を1次元配列で表している)から、それを表す数字を予測することです。
0から9のそれぞれのサンプルが学習データとして与えられ、AIモデルを作成し未知のデータを予測できるようにします。

scikit-learn ではfit(X, y)で学習し、predict(T)で予測を行います。

今回はサポートベクターマシンというアルゴリズムsklearn.svm.SVCを使用し、分類を行います。

分類用のサポートベクターマシンは以下のコードで呼び出します。

from  sklearn  import  svm
clf = svm.SVC(gamma=0.001, C=100.)

この例ではgammmaCを手動で設定します。細かい説明は省きますが、gammaは決定境界の複雑さに影響します。大きいほど複雑な境界になります。Cは誤分類を許容するパラメータです。

今回は前回ロードしたdigits.dataを利用してAIモデルを作成してみましょう
再掲

from  sklearn  import  datasets
digits.data
'''
array([[ 0., 0., 5., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 10., 0., 0.],
[ 0., 0., 0., ..., 16., 9., 0.],
...
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 2., ..., 12., 0., 0.],
[ 0., 0., 10., ..., 12., 1., 0.]])
'''

fitメソッドを使用することで学習をすることができます。今回は、最後のデータ以外を使って学習し、最後のデータで予測を行います。
[:-1]で配列の最初から一番最後の一つ手前までのデータを取得しています。

clf.fit(digits.data[:-1], digits.target[:-1])
#SVC(C=100.0, gamma=0.001)

以下のコードで最後のデータが何の数値を表すのかを予測した結果を出力します。

clf.predict(digits.data[-1:])
#array([8])

「8」と予測されました。
実際の画像データは以下になります。

参考
以下のコードで画像に戻すことができます。

import  numpy  as  np
import  matplotlib.pyplot  as  plt
img = np.reshape(digits.data[-1:], (8,8))
plt.imshow(img, cmap=plt.cm.gray_r, interpolation='nearest')
plt.axis('off')
plt.show()

あなたはこの予測結果が正しいと判断しますか?結局元の画像の解像度が低いため、うまく予測できているかどうかはわかりません。

本データセットのもう少し詳しい説明はこちら。Recognizing hand-written digits