简介
Term-1第二节课是进行交通标志分类,数据集主要来自于German Traffic Sign,包含了42种交通标志,通过深度学习网络进行分类。
环境准备
- python 2.7
- numpy
- scikit-learn
- tensorflow
- keras
处理流程
处理流程如下图所示
数据读取
我们拿到的数据集是一系列交通标志图像,每个类别的交通标志放在了同一个文件夹下,并且有一个csv文件用于描述每个交通标志图片的ROI区域和该标志所属类别。下载图像的时候网站提供了一份用于数据处理的python程序(Python code for GTSRB文件夹下),在这里可以用到。
这里按csv描述的图像信息进行ROI部分的提取,并使用pickle保存为.p文件方便后续模型训练使用。代码示例如下:
1 | import numpy as np |
放几张效果图
数据处理
目前已经得到了训练集和测试集,需要对数据进行shuffle以及对label进行one-hot编码。
1 | def shuffle_data(X_train,Y_train,X_test,Y_test): |
模型构建
下面使用keras构建卷积神经网络,keras的使用方法具体参见Keras中文文档。
模型构建部分代码示例如下:
1 | def build_model(): |
模型训练
模型定义完成后开始进行模型训练
1 | //指定训练集、测试机,迭代次数等 |
效果评估
经过40轮迭代,模型准确率基本能达到99%左右,为了验证模型效果,特意挑选了10张图像进行比对。
最终结果如下图所示
其中第6张图类别应该是“Priority road”,其他分类全部正确。
总结
这节课主要是使用深度学习模型进行图像分类,在没有进行深入调优的情况下,在测试集上获得了90%的正确率,如果想进一步提高分类准确率可以采用效果更好的模型进行轮数更多的训练。这部分的代码详见Term-1-p2-traffic-sign-classifier。