- Python计算机视觉与深度学习实战
- 郭卡 戴亮编著
- 1492字
- 2025-02-24 11:09:48
1.2 数据集
人工智能的核心在于数据支持,近几年人工智能技术的快速发展与大数据技术的发展密切相关,大数据技术可以通过数据采集、分析及挖掘等方式,从海量复杂数据中快速提取出有价值的信息,为机器学习算法提供牢固的基础。
在机器学习任务中,数据集有三大功能:训练、验证和测试。
- 训练最好理解,是拟合模型的过程,模型会通过分析数据、调节内部参数从而得到最优的模型效果。
- 验证即验证模型效果,效果可以指导我们调整模型中的超参数(在开始训练之前设置参数,而不是通过训练得到参数),通常会使用少量未参与训练的数据对模型进行验证,在训练的间隙中进行。
- 测试的作用是检查模型是否具有泛化能力(泛化能力是指模型对训练集之外的数据集是否也有很好的拟合能力)。通常会在模型训练完毕之后,选用较多训练集以外的数据进行测试。
所以在机器学习(尤其是深度学习)任务开始前,需要收集大量高质量的数据,对于个人开发者来说,数据只能来源于开源的数据集和自己编写爬虫程序采集到的数据集,收集数据是一个费时费力的过程。
为了方便初学者学习以及进行小规模的算法测试,sklearn提供了不少小型的标准数据集和一些规模略大的真实数据集。除这些数据集之外,sklearn还能够按照一定规则自己生成数据集。3种类型的数据集分别通过load***
、fetch***
和make***
这3种函数形式获取,下面将对这几个接口做简单介绍。
1.2.1 自带的小型数据集
sklearn中最常用的数据集有3个:load_iris
、load_boston
和load_digits
。
直接从sklearn.datasets
中导入load_iris
,得到的数据是字典形式,可以通过字典中的键值选择数据的各项属性。
load_iris
是加载鸢尾花数据集的函数,该数据集包含了150条鸢尾花数据,其中包含的鸢尾花数据(在机器学习中,这种可以直接用于建模的数据叫作特征)有4种:
- 鸢尾花的花瓣长度(cm);
- 鸢尾花的花瓣宽度(cm);
- 鸢尾花的花萼长度(cm);
- 鸢尾花的花萼宽度(cm)。
标签是鸢尾花的种类,3个种类分别用0
、1
和2
表示。下面是load_iris
的使用方法:
>>> d = load_iris()
>>> d.keys()
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
>>> # 鸢尾花的类别名
>>> d['target_names']
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
>>> # 特征名称
>>> d['feature_names']
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
>>> d['data'].shape
(150, 4)
>>> set(list(d['target']))
{0, 1, 2}
在上述代码中,通过load_iris
函数取出了鸢尾花数据并将其赋值给d
,通过keys
方法查看数据集中各个项目的名称,如鸢尾花的类别名(target_names
)、特征名(feature_names
)、数据(data
)与标签(target
)等。
load_boston
是关于波士顿房屋特征与房价之间关系的数据集,包含13个房屋特征,是一个进行入门回归训练的好例子。下面是load_boston
的使用方法:
>>> data = load_boston()
>>> # 房屋特征名称
>>> data['feature_names']
array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
'TAX', 'PTRATIO', 'B', 'LSTAT'], dtype='<U7')
>>> data['data'].shape
(506, 13)
从上述代码中可以看到,load_boston
中共有506个样本,每条数据中包含了房屋和房屋周边的13个重要信息,如城市犯罪率、环保指标、周边老房子的比例、是否临河等。
load_digits
是一个比MNIST更小的手写数字图片数据集,里面的图片尺寸是8像素×8像素(后面将省略单位),通过如下代码可以查看手写数字图片:
>>> g = sklearn.datasets.load_digits()
>>> plt.imshow(g['data'][0].reshape(8,8),cmap='gray')
<matplotlib.image.AxesImage object at 0x7f07e42ddeb8>
>>> plt.show()
输出图片如图1-5所示,因为是8×8的图片,所以看起来不是很清晰。

图 1-5 手写数字
1.2.2 在线下载的数据集
Fetch
系列函数用于获取较大规模的数据集,这些数据集会自动从网上下载,得到的数据格式与load***
一样,是字典形式。我们可以自定义下载目录,同时可以选择单独下载训练集或者测试集,常用的数据集如下。
- 人脸数据集:
fetch_olivetti_faces
和fetch_lfw_people
。 - 文本分类数据集:
fetch_20newsgroups
。 - 房价回归数据集:
fetch_california_housing
。
1.2.3 计算机生成的数据集
用sklearn生成的数据集可以用来测试一些基础的模型功能,比如多分类数据集、聚类数据集以及高斯分布数据集等。还有一些特殊形状的数据集,比如make_circles
和make_moons
等,示例如下:
>>> circle = make_circles()[0]
>>> # 创建子图
>>> plt.subplot(121)
<matplotlib.axes._subplots.AxesSubplot object at 0x000000001719BE80>
>>> # 绘制散点图
>>> plt.scatter(circle[:,0],circle[:,1])
<matplotlib.collections.PathCollection object at 0x000000002081D828>
>>> moon = make_moons()[0]
>>> plt.subplot(122)
<matplotlib.axes._subplots.AxesSubplot object at 0x000000002081D048>
>>> plt.scatter(moon[:,0],moon[:,1])
<matplotlib.collections.PathCollection object at 0x0000000017171D30>
>>> plt.show()
上述代码的作用是通过make_circles
和make_moons
函数生成两组坐标点数据,并使用plt.scatter
函数将生成的坐标点绘制成散点图。生成的散点图如图1-6所示,其他数据集详情请参考sklearn官网。

图 1-6 生成的散点图