[Machine Learning] The First Step

A few prerequisites for machine learning:

  • Use the Python language, preferably Python 3
  • Use Jupyter Notebook
  • Be comfortable with NumPy/SciPy/pandas/matplotlib
  • Use scikit-learn, the main machine-learning framework

In addition, to make the data easier to present, this tutorial uses the mglearn module. There is no need to study it in depth; just know that it helps prettify charts and visualize data.

# Before we begin, import the commonly used modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mglearn

Building a simple machine learning application

Suppose we have already collected measurements of irises: the length and width of the petals, and the length and width of the sepals. The flowers belong to three species: setosa/versicolor/virginica, and each measured flower has been matched to its species in advance.

Now suppose a new batch of iris measurements arrives without species labels. Can we predict each flower's species from its petal length and width and its sepal length and width?

This is a classification problem: the possible outputs for the data are called classes.

Because the existing iris data was classified in advance, and the experience captured in that data is used to judge the class of new data, this style of learning is called supervised learning: from a given mapping of inputs to outputs, we infer the likely outputs for new data.

Step 1: Get the data

The iris dataset is included in scikit-learn's datasets module and can be loaded by calling the load_iris function:

# Import load_iris
from sklearn.datasets import load_iris
# Call the loading function
iris_dataset = load_iris()
# Show the result
iris_dataset
{'data': array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.1, 1.5, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]]),
'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
'DESCR': 'Iris Plants Database
====================

Notes
-----
Data Set Characteristics:
:Number of Instances: 150 (50 in each of three classes)
:Number of Attributes: 4 numeric, predictive attributes and the class
:Attribute Information:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
- class:
- Iris-Setosa
- Iris-Versicolour
- Iris-Virginica
:Summary Statistics:

============== ==== ==== ======= ===== ====================
Min Max Mean SD Class Correlation
============== ==== ==== ======= ===== ====================
sepal length: 4.3 7.9 5.84 0.83 0.7826
sepal width: 2.0 4.4 3.05 0.43 -0.4194
petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
============== ==== ==== ======= ===== ====================

:Missing Attribute Values: None
:Class Distribution: 33.3% for each of 3 classes.
:Creator: R.A. Fisher
:Donor: Michael Marshall (MARSHALL%[email protected])
:Date: July, 1988

This is a copy of UCI ML iris datasets.
http://archive.ics.uci.edu/ml/datasets/Iris

The famous Iris database, first used by Sir R.A Fisher

This is perhaps the best known database to be found in the
pattern recognition literature. Fisher's paper is a classic in the field and
is referenced frequently to this day. (See Duda & Hart, for example.) The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant. One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

References
----------
- Fisher,R.A. "The use of multiple measurements in taxonomic problems"
Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
Mathematical Statistics" (John Wiley, NY, 1950).
- Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
(Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
- Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
Structure and Classification Rule for Recognition in Partially Exposed
Environments". IEEE Transactions on Pattern Analysis and Machine
Intelligence, Vol. PAMI-2, No. 1, 67-71.
- Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions
on Information Theory, May 1972, 431-433.
- See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II
conceptual clustering system finds 3 classes in the data.
- Many, many more ...
',
'feature_names': ['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']}

load_iris returns a Bunch object, which is very similar to a dictionary: it contains keys and values:

# Inspect the keys
iris_dataset.keys()
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])

  • data is the array of iris measurements
  • target is the array of class labels
  • target_names holds the names of the classes
  • DESCR is a description of the dataset
  • feature_names is the list of feature names

# Inspect the data array
iris_dataset.data
array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.1, 1.5, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]])

The data array has four columns, giving the sepal length, sepal width, petal length, and petal width of each flower; it is stored as a NumPy array.

# Check the shape of the data
iris_dataset.data.shape
(150, 4)

As shown, the data has 150 rows and 4 columns.

# Inspect the class labels
iris_dataset.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

Each row represents one flower and has one corresponding label; here 0, 1, and 2 stand for the three species.
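A quick way to check the class balance is np.bincount; as the dataset description above notes, each of the three classes should contain 50 samples:

# Count the samples carrying each label; the description promises 50 per class
np.bincount(iris_dataset.target)  # -> array([50, 50, 50])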

Key terms:

The individual items in a dataset are called samples; their properties are called features, and the outputs we want to predict are called labels.

# Inspect the feature names
iris_dataset.feature_names
['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']

Step 2: Prepare the training data

We will not use all of the data for training; part of it must be held out for prediction. Prediction here means that once training has produced a model, we feed it data it has never seen and test whether it can predict the classes accurately.

The train_test_split function in scikit-learn shuffles the dataset and splits it; by default, 75% of the data is used as the training set and 25% as the test set.

By convention, the data is written as a capital X and the labels as a lowercase y: uppercase generally denotes a two-dimensional matrix, lowercase a one-dimensional vector.

# Import the train_test_split function
from sklearn.model_selection import train_test_split
# Split the data
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)

Here train_test_split receives three arguments: the dataset, the label array, and a random seed.

Because iris_dataset is a Bunch object, its values can be retrieved either as attributes or with bracket notation.
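For example, target_names can be read either way:

# Both access styles return the same array
iris_dataset['target_names']
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
iris_dataset.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')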

random_state seeds the pseudorandom number generator used to shuffle the dataset, so the split is reproducible.
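We can also confirm the default 75/25 split: 112 of the 150 samples go into the training set and the remaining 38 into the test set.

# Check the split sizes
X_train.shape
(112, 4)
X_test.shape
(38, 4)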

# Inspect the training data
X_train
array([[5.9, 3. , 4.2, 1.5],
[5.8, 2.6, 4. , 1.2],
[6.8, 3. , 5.5, 2.1],
[4.7, 3.2, 1.3, 0.2],
[6.9, 3.1, 5.1, 2.3],
[5. , 3.5, 1.6, 0.6],
[5.4, 3.7, 1.5, 0.2],
[5. , 2. , 3.5, 1. ],
[6.5, 3. , 5.5, 1.8],
[6.7, 3.3, 5.7, 2.5],
[6. , 2.2, 5. , 1.5],
[6.7, 2.5, 5.8, 1.8],
[5.6, 2.5, 3.9, 1.1],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.3, 4.7, 1.6],
[5.5, 2.4, 3.8, 1.1],
[6.3, 2.7, 4.9, 1.8],
[6.3, 2.8, 5.1, 1.5],
[4.9, 2.5, 4.5, 1.7],
[6.3, 2.5, 5. , 1.9],
[7. , 3.2, 4.7, 1.4],
[6.5, 3. , 5.2, 2. ],
[6. , 3.4, 4.5, 1.6],
[4.8, 3.1, 1.6, 0.2],
[5.8, 2.7, 5.1, 1.9],
[5.6, 2.7, 4.2, 1.3],
[5.6, 2.9, 3.6, 1.3],
[5.5, 2.5, 4. , 1.3],
[6.1, 3. , 4.6, 1.4],
[7.2, 3.2, 6. , 1.8],
[5.3, 3.7, 1.5, 0.2],
[4.3, 3. , 1.1, 0.1],
[6.4, 2.7, 5.3, 1.9],
[5.7, 3. , 4.2, 1.2],
[5.4, 3.4, 1.7, 0.2],
[5.7, 4.4, 1.5, 0.4],
[6.9, 3.1, 4.9, 1.5],
[4.6, 3.1, 1.5, 0.2],
[5.9, 3. , 5.1, 1.8],
[5.1, 2.5, 3. , 1.1],
[4.6, 3.4, 1.4, 0.3],
[6.2, 2.2, 4.5, 1.5],
[7.2, 3.6, 6.1, 2.5],
[5.7, 2.9, 4.2, 1.3],
[4.8, 3. , 1.4, 0.1],
[7.1, 3. , 5.9, 2.1],
[6.9, 3.2, 5.7, 2.3],
[6.5, 3. , 5.8, 2.2],
[6.4, 2.8, 5.6, 2.1],
[5.1, 3.8, 1.6, 0.2],
[4.8, 3.4, 1.6, 0.2],
[6.5, 3.2, 5.1, 2. ],
[6.7, 3.3, 5.7, 2.1],
[4.5, 2.3, 1.3, 0.3],
[6.2, 3.4, 5.4, 2.3],
[4.9, 3. , 1.4, 0.2],
[5.7, 2.5, 5. , 2. ],
[6.9, 3.1, 5.4, 2.1],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.6, 1.4, 0.2],
[7.2, 3. , 5.8, 1.6],
[5.1, 3.5, 1.4, 0.3],
[4.4, 3. , 1.3, 0.2],
[5.4, 3.9, 1.7, 0.4],
[5.5, 2.3, 4. , 1.3],
[6.8, 3.2, 5.9, 2.3],
[7.6, 3. , 6.6, 2.1],
[5.1, 3.5, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.2, 3.4, 1.4, 0.2],
[5.7, 2.8, 4.5, 1.3],
[6.6, 3. , 4.4, 1.4],
[5. , 3.2, 1.2, 0.2],
[5.1, 3.3, 1.7, 0.5],
[6.4, 2.9, 4.3, 1.3],
[5.4, 3.4, 1.5, 0.4],
[7.7, 2.6, 6.9, 2.3],
[4.9, 2.4, 3.3, 1. ],
[7.9, 3.8, 6.4, 2. ],
[6.7, 3.1, 4.4, 1.4],
[5.2, 4.1, 1.5, 0.1],
[6. , 3. , 4.8, 1.8],
[5.8, 4. , 1.2, 0.2],
[7.7, 2.8, 6.7, 2. ],
[5.1, 3.8, 1.5, 0.3],
[4.7, 3.2, 1.6, 0.2],
[7.4, 2.8, 6.1, 1.9],
[5. , 3.3, 1.4, 0.2],
[6.3, 3.4, 5.6, 2.4],
[5.7, 2.8, 4.1, 1.3],
[5.8, 2.7, 3.9, 1.2],
[5.7, 2.6, 3.5, 1. ],
[6.4, 3.2, 5.3, 2.3],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 4.9, 1.5],
[6.7, 3. , 5. , 1.7],
[5. , 3. , 1.6, 0.2],
[5.5, 2.4, 3.7, 1. ],
[6.7, 3.1, 5.6, 2.4],
[5.8, 2.7, 5.1, 1.9],
[5.1, 3.4, 1.5, 0.2],
[6.6, 2.9, 4.6, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.9, 3.2, 4.8, 1.8],
[6.3, 2.3, 4.4, 1.3],
[5.5, 3.5, 1.3, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.9, 3.1, 1.5, 0.1],
[6.3, 2.9, 5.6, 1.8],
[5.8, 2.7, 4.1, 1. ],
[7.7, 3.8, 6.7, 2.2],
[4.6, 3.2, 1.4, 0.2]])
# Inspect the training labels
y_train
array([1, 1, 2, 0, 2, 0, 0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 1, 2,
1, 0, 2, 1, 1, 1, 1, 2, 0, 0, 2, 1, 0, 0, 1, 0, 2, 1, 0, 1, 2, 1,
0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 2, 0, 2, 2, 0, 0, 2, 0, 0, 0, 1, 2,
2, 0, 0, 0, 1, 1, 0, 0, 1, 0, 2, 1, 2, 1, 0, 2, 0, 2, 0, 0, 2, 0,
2, 1, 1, 1, 2, 2, 1, 1, 0, 1, 2, 2, 0, 1, 1, 1, 1, 0, 0, 0, 2, 1,
2, 0])
# Inspect the test data
X_test
array([[5.8, 2.8, 5.1, 2.4],
[6. , 2.2, 4. , 1. ],
[5.5, 4.2, 1.4, 0.2],
[7.3, 2.9, 6.3, 1.8],
[5. , 3.4, 1.5, 0.2],
[6.3, 3.3, 6. , 2.5],
[5. , 3.5, 1.3, 0.3],
[6.7, 3.1, 4.7, 1.5],
[6.8, 2.8, 4.8, 1.4],
[6.1, 2.8, 4. , 1.3],
[6.1, 2.6, 5.6, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.5, 2.8, 4.6, 1.5],
[6.1, 2.9, 4.7, 1.4],
[4.9, 3.1, 1.5, 0.1],
[6. , 2.9, 4.5, 1.5],
[5.5, 2.6, 4.4, 1.2],
[4.8, 3. , 1.4, 0.3],
[5.4, 3.9, 1.3, 0.4],
[5.6, 2.8, 4.9, 2. ],
[5.6, 3. , 4.5, 1.5],
[4.8, 3.4, 1.9, 0.2],
[4.4, 2.9, 1.4, 0.2],
[6.2, 2.8, 4.8, 1.8],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.8, 1.9, 0.4],
[6.2, 2.9, 4.3, 1.3],
[5. , 2.3, 3.3, 1. ],
[5. , 3.4, 1.6, 0.4],
[6.4, 3.1, 5.5, 1.8],
[5.4, 3. , 4.5, 1.5],
[5.2, 3.5, 1.5, 0.2],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.2],
[5.2, 2.7, 3.9, 1.4],
[5.7, 3.8, 1.7, 0.3],
[6. , 2.7, 5.1, 1.6]])
# Inspect the test labels
y_test
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 1])

Inspecting the data

The plot below draws every pairwise combination of the four features, so we can see how the data behaves. (We will not dig into the underlying mechanics here.)

# Convert the training data into a DataFrame
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# Draw the matrix of plots with scatter_matrix
grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', hist_kwds={'bins': 20}, s=60, alpha=.8, cmap=mglearn.cm3)

Building a first model: the k-nearest neighbors algorithm

To train on the data we need an algorithm model. Here we choose the k-nearest neighbors classifier.

The k in the k-nearest neighbors classifier refers to the k training points closest to a new data point: the new point is assigned the label carried by its k nearest neighbors.
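To make the idea concrete, here is a minimal NumPy sketch of 1-nearest-neighbor classification, assuming a plain Euclidean distance (scikit-learn's own implementation is more general and far more efficient):

# A bare-bones 1-nearest-neighbor classifier using Euclidean distance
def predict_1nn(X_train, y_train, x_new):
    # Distance from the new point to every training point
    distances = np.sqrt(((X_train - x_new) ** 2).sum(axis=1))
    # The prediction is the label of the single closest training point
    return y_train[np.argmin(distances)]

For the new flower used below, this returns the same answer as KNeighborsClassifier with n_neighbors=1, whose default Minkowski metric with p=2 is exactly the Euclidean distance.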

All machine learning models in scikit-learn are implemented in their own classes. The k-nearest neighbors algorithm is implemented in the KNeighborsClassifier class of the neighbors module; we must instantiate that class into an object before we can use the model.

# Import KNeighborsClassifier
from sklearn.neighbors import KNeighborsClassifier
# Instantiate the classifier
knn = KNeighborsClassifier(n_neighbors=1)

The n_neighbors parameter is k; a value of 1 means each new point is classified according to its single nearest neighbor.

To build the model on the training set, call the knn object's fit method, passing X_train and y_train as input.

# Train on the data and return the model
knn.fit(X_train,y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=1, n_neighbors=1, p=2,
weights='uniform')

The fit method returns the knn object itself, so what we see here is a string representation of that object.
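Because fit returns the estimator itself, instantiation and training can also be chained into a single line:

# fit returns self, so construction and fitting can be chained
knn = KNeighborsClassifier(n_neighbors=1).fit(X_train, y_train)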

Step 3: Make a prediction

# Suppose we have the measurements of a new flower
X_new = np.array([[5, 2.9, 1, 0.2]])

Note that the input here must be a two-dimensional array: scikit-learn expects data of shape (n_samples, n_features), so a single flower is one row.
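We can confirm that X_new is one sample in a two-dimensional array:

X_new.shape
(1, 4)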

Call knn's predict method to make the prediction:

# Call predict to make the prediction
prediction = knn.predict(X_new)
# Inspect the returned value
prediction
array([0])
iris_dataset['target_names'][prediction]
array(['setosa'], dtype='<U10')

The predict method returns a label value, and from that label we can look up the corresponding species name.

Step 4: Evaluate the model

Now use the test set: predict the species of every iris in the test data and compare the predictions with the known labels. We can measure the quality of the model by its accuracy, the fraction of flowers whose species is predicted correctly.

y_pred = knn.predict(X_test)
y_pred
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 2])

So does the predicted set of classes match the original one? We need to compare y_pred with y_test:

np.mean(y_pred==y_test)
0.9736842105263158

Alternatively, call knn's score method directly to compute the accuracy:

knn.score(X_test,y_test)
0.9736842105263158

As shown, the model's predictions agree with the true classes for roughly 97% of the test set.
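n_neighbors was fixed at 1 throughout; a natural follow-up experiment is to try several values of k and compare test accuracies (a quick sketch; the exact numbers depend on the split):

# Compare test accuracy for several values of k
for k in range(1, 11):
    model = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)
    print(k, model.score(X_test, y_test))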

That is the basic workflow of machine learning. O(∩_∩)

