糖尿病康复,内容丰富有趣,生活中的好帮手!
糖尿病康复 > Xception迁移学习:玉米叶片病害识别分类

Xception迁移学习:玉米叶片病害识别分类

时间:2019-10-02 04:43:40

相关推荐

Xception迁移学习:玉米叶片病害识别分类

Xception迁移学习:玉米叶片病害识别分类

数据集:来自网上公开的PlantVillage数据集中的玉米叶片部分。

运行环境:Tensorflow深度学习开源框架,选用Python 3.6.12作为编程语言。

本代码是自己查阅了很多博客代码最后根绝自己要用的数据集综合而成的,由于过于久远,不记得参考了哪些博客,这里就不放链接了。记录下来,便于自己以后查阅。也是刚入门的小白,欢迎大佬指教!

代码如下

1. 导入

import tensorflow as tfimport tensorflow.keras as kerasimport matplotlib.pyplot as pltimport tensorflow.keras.preprocessing.image as imageimport os as osfrom tensorflow.keras.applications import Xceptionfrom tensorflow.keras.layers import Dense,Flatten,GlobalAveragePooling2D,Dropoutfrom tensorflow.keras.models import Model,load_modelfrom tensorflow.keras.optimizers import SGD

2. 设置参数和路径

IMG_SIZE:输入图片的尺寸;

batch_size:每次读取图片的数量;

EPOCHS:训练轮次;

train_path:训练集路径;val_path:验证集路径。

IMG_SIZE = 150batch_size = 16EPOCHS=100IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)train_path = 'D:/tmp/New Maize Data set/Train_maize'val_path='D:/tmp/New Maize Data set/Vali_maize'

3. 数据增强

由于电脑的配置低,带不动很多图片,所以只选取了每种病害图片的几百张作为训练集,故需要数据增强操作,提高分类准确率。

使用keras提供的图像生成器ImageDataGenerator类来实现数据增强。主要做法是每次取一个批次即batch_size大小的样本数据提供给模型,同时对每批样本进行归一化、随机旋转40°、随机水平和上下位置平移、随机错切变换角度、随机缩放比例、随机将一半图像水平翻转等操作。这样每一轮训练时输入的样本批次就不会完全相同,可以增强模型的泛化能力。

数据增强后的结果如图:

from tensorflow.keras.preprocessing.image import ImageDataGeneratortrain_gen = ImageDataGenerator(rescale=1 / 255,rotation_range=40, # 角度值,0-180.表示图像随机旋转的角度范围width_shift_range=0.2, # 平移比例,下同height_shift_range=0.2,shear_range=0.2, # 随机错切变换角度zoom_range=0.2, # 随即缩放比例horizontal_flip=True, # 随机将一半图像水平翻转validation_split=0.2,fill_mode='nearest' # 填充新创建像素的方法)train_generator = train_gen.flow_from_directory(directory=train_path,shuffle = True,batch_size = batch_size,class_mode = 'categorical',target_size = IMG_SHAPE[:-1],color_mode='rgb',#classes =classes,#subset='training')validation_generator = train_gen.flow_from_directory(directory=val_path,shuffle = True,batch_size = batch_size,class_mode = 'categorical',target_size =IMG_SHAPE[:-1],color_mode='rgb',#classes =classes,#subset='validation')

4. 构建模型

这里所用模型直接调用keras中的Xception模型

#构建模型model = tf.keras.Sequential([tf.keras.applications.Xception(input_shape=(150,150,3),weights='imagenet',include_top=False),tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(4,activation='softmax')])

设置迁移学习冻结模型的层数:冻结部分网络层,即只训练其中的一部分网络层。

for i, layer in enumerate(model.layers[0].layers):if i > 85:layer.trainable = Trueelse:layer.trainable = False

5. 编译模型

#编译模型 pile(optimizer='adam',loss = 'categorical_crossentropy',metrics=['accuracy'])

6. 打印模型

model.summary()

模型打印结果可以看到可训练的参数数量

7. 训练模型

history=model.fit_generator(train_generator,steps_per_epoch=max(1, train_generator.n//batch_size),validation_data=validation_generator,validation_steps=max(1, validation_generator.n//batch_size),epochs =100,#initial_epoch=0,#callbacks=[checkpoint])

8. 保存模型

将模型保存为.h5文件

model.save('model/Xception_2_85_model.h5')

9. 绘制损失值曲线和准确率曲线

# 记录准确率和损失值history_dict = history.historytrain_loss = history_dict["loss"]train_accuracy = history_dict["acc"]val_loss = history_dict["val_loss"]val_accuracy = history_dict["val_acc"]# 绘制损失值曲线plt.figure()plt.title('InceptionV3-1')plt.plot(range(EPOCHS),train_loss,c='k' ,ls='--',label='train_loss')plt.plot(range(EPOCHS),val_loss,'k' ,label='val_loss' )plt.legend()plt.xlabel('epochs')plt.ylabel('loss')import matplotlib as mpl#中文字体设置mpl.rcParams["font.family"] = "SimHei"mpl.rcParams["axes.unicode_minus"] = Falsempl.rcParams["font.style"] = "normal"mpl.rcParams["font.size"] = 10# 绘制准确率曲线plt.figure()#plt.title('InceptionV3-1')plt.plot(range(EPOCHS), train_accuracy,ls='--', c="k",label="训练集准确率")plt.plot(range(EPOCHS), val_accuracy,c="k",label="验证集准确率")plt.ylim(0.5,1)plt.legend(loc='lower right')plt.xlabel("训练轮次")plt.ylabel("准确率")plt.show()

10. 测试模型

测试结果可以输出一个混淆矩阵,查看每种病害类别的准确率。

import numpy as npimport tensorflow as tfimport tensorflow.keras as kerasfrom tensorflow.keras.preprocessing.image import ImageDataGeneratorfrom tensorflow.keras.models import load_modelimport datetimefrom tensorflow.keras.callbacks import TensorBoard from keras.backend.tensorflow_backend import set_sessionimport matplotlib.pyplot as pltfrom sklearn.metrics import confusion_matriximport itertoolsconfig = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config)set_session(sess)keras.backend.clear_session() #清理session#test image directorydst_path = 'D:/tmp/New Maize Data set/Test_maize'#model pathmodel_file ='C:/Users/name/model/Xception_2_85_model.h5'batch_size = 8def plot_confusion_matrix(cm,target_names,title='Confusion Matrix',cmap=plt.cm.Greens, # 设置混淆矩阵的颜色主题normalize=True):accuracy = np.trace(cm) / float(np.sum(cm))misclass = 1 - accuracyif cmap is None:cmap = plt.get_cmap('Blues')plt.figure()plt.imshow(cm, interpolation='nearest', cmap=cmap)# plt.title(title)plt.title(title+'\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))plt.colorbar()if target_names is not None:tick_marks = np.arange(len(target_names))plt.xticks(tick_marks, target_names, rotation=45)plt.yticks(tick_marks, target_names)if normalize:cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]thresh = cm.max() / 1.5 if normalize else cm.max() / 2for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):if normalize:plt.text(j, i, "{:0.4f}".format(cm[i, j]),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")else:plt.text(j, i, "{:,}".format(cm[i, j]),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.ylabel('True label')plt.xlabel('Predicted label')# load modelmodel = load_model(model_file)# generator imagetest_datagen = ImageDataGenerator(rescale=1. / 255)test_generator = test_datagen.flow_from_directory(dst_path,target_size=(150, 150),batch_size=batch_size,shuffle=False)labels = test_generator.class_indices #查看类别的label#labels = ['blight', 'cercos', 'healthy','rust']#然后直接用predice_geneorator 可以进行预测test_generator.reset()pred = model.predict_generator(test_generator, verbose=1)# 输出每个图像的预测类别predicted_class_indices = np.argmax(pred, axis=1)#测试集的真实类别true_label= test_generator.classes#简单画出混淆矩阵import pandas as pdtable=pd.crosstab(true_label,predicted_class_indices,colnames=['predict'],rownames=['label'])print(table)#图片化显示混淆矩阵conf_mat = confusion_matrix(y_true=true_label,y_pred=predicted_class_indices)plt.figure()plot_confusion_matrix(conf_mat, normalize=False, target_names=labels, title='Confusion Matrix')

测试结果如下:可以看出每种类别的识别率都很高

如果觉得《Xception迁移学习:玉米叶片病害识别分类》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。