trainNetwork训练神经网络进行深度学习
原地址 /help/deeplearning/ref/trainnetwork.html
几种调用方法
net = trainNetwork(imds,layers,options)
net = trainNetwork(ds,layers,options)
net = trainNetwork(X,Y,layers,options)
net = trainNetwork(sequences,Y,layers,options)
net = trainNetwork(tbl,layers,options)
net = trainNetwork(tbl,responseName,layers,options)
[net,info] = trainNetwork(___)
描述
使用trainNetwork
训练卷积神经网络(ConvNet,CNN),长短期记忆(LSTM)网络,或双向LSTM(BiLSTM)网络的深度学习分类和回归的问题。您可以在CPU或GPU上训练网络。对于图像分类和图像回归,您可以使用多个GPU或并行进行训练。使用GPU,多GPU和并行选项需要Parallel Computing Toolbox™。要使用深层学习GPU,你还必须有一个CUDA®启用NVIDIA®GPU计算能力3.0或更高版本。使用指定培训选项,包括用于执行环境的选项trainingOptions
。
为图像分类问题训练网络。图像数据存储区net
= trainNetwork(imds
,layers
,options
)imds
存储输入的图像数据,layers
定义网络体系结构,并options
定义训练选项。
使用数据存储训练网络net
= trainNetwork(ds
,layers
,options
)ds
。对于具有多个输入的网络,请将此语法与组合或转换后的数据存储区结合使用。
为图像分类和回归问题训练网络。数字数组net
= trainNetwork(X
,Y
,layers
,options
)X
包含预测变量,并Y
包含分类标签或数字响应。
训练网络以解决序列分类和回归问题(例如LSTM或BiLSTM网络),其中net
= trainNetwork(sequences
,Y
,layers
,options
)sequences
包含序列或时间序列预测变量并Y
包含响应。对于分类问题,Y
是分类向量或分类序列的单元格数组。对于回归问题,Y
是目标矩阵或数字序列的单元格数组。
为分类和回归问题训练网络。该表net
= trainNetwork(tbl
,layers
,options
)tbl
包含数字数据或数据的文件路径。预测变量必须位于的第一列中tbl
。有关目标或响应变量的信息,请参见tbl。
为分类和回归问题训练网络。预测变量必须位于的第一列中net
= trainNetwork(tbl
,responseName
,layers
,options
)tbl
。该responseName
参数指定在响应变量tbl
。[
还可以使用先前语法中的任何输入参数返回有关训练的信息。net
,info
] = trainNetwork(___)
例子
图像分类训练网络将数据作为ImageDatastore
对象加载。
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet',... 'nndemos','nndatasets','DigitDataset');imds = imageDatastore(digitDatasetPath,... 'IncludeSubfolders',true,... 'LabelSource','foldernames');
数据存储区包含10,000个从0到9的数字合成图像。这些图像是通过对使用不同字体创建的数字图像应用随机转换而生成的。每个数字图像为28 x 28像素。数据存储区每个类别包含相等数量的图像。
显示数据存储中的某些图像。
figurenumImages = 10000;perm = randperm(numImages,20);for i = 1:20subplot(4,5,i);imshow(imds.Files{perm(i)});end
指定卷积神经网络架构。对于回归问题,请在网络末端包括一个回归层。
layers = [ ...imageInputLayer([28 28 1])convolution2dLayer(5,20)reluLayermaxPooling2dLayer(2,'Stride',2)fullyConnectedLayer(10)softmaxLayerclassificationLayer];
指定网络训练选项。将初始学习速率设置为0.001。
options = trainingOptions('sgdm',... 'InitialLearnRate',0.001,... 'Verbose',false,... 'Plots','training-progress');
训练网络。
net = trainNetwork(imdsTrain,layers,options);
通过评估测试数据的预测准确性来测试网络的性能。使用predict
预测验证图像的旋转角度。
[XTest,〜,YTest] = digitTest4DArrayData;YPred =predict(net,XTest);
通过计算预测旋转角和实际旋转角的均方根误差(RMSE)来评估模型的性能。
rmse = sqrt(mean((YTest-YPred)。^ 2))
rmse = single6.0655
序列分类训练网络
查看MATLAB命令
训练用于序列到标签分类的深度学习LSTM网络。
如[1]和[2]中所述加载日语元音数据集。XTrain
是包含270个长度可变且特征尺寸为12的序列的单元格数组。Y
是标签1,2,...,9的分类向量。中的条目XTrain
是具有12行(每个要素一行)和不同列数(每个时间步长一列)的矩阵。
[XTrain,YTrain] = japaneseVowelsTrainData;
可视化图中的第一个时间序列。每行对应一个特征。
数字情节(XTrain {1}')标题(“训练观察1”)numFeatures = size(XTrain {1},1);图例(“ Feature” + string(1:numFeatures),'Location','northeastoutside')
定义LSTM网络体系结构。将输入大小指定为12(输入数据的特征数)。指定一个LSTM层,使其具有100个隐藏单元并输出序列的最后一个元素。最后,通过包括大小为9的完全连接的层,其后是softmax层和分类层,来指定九个类。
inputSize = 12;numHiddenUnits = 100;numClasses = 9;层数= [ ...sequenceInputLayer(inputSize)lstmLayer(numHiddenUnits,'OutputMode','last')fullyConnectedLayer(numClasses)softmaxLayer分类图层]
层数= 具有层的5x1层阵列:1英寸序列输入序列输入具有12个尺寸2英寸LSTM LSTM具有100个隐藏单元3英寸全连接9个全连接层4英寸Softmax softmax5''分类输出交叉熵
指定训练选项。将求解器指定为'adam'
和'GradientThreshold'
。1.将小批量大小设置为27,并将最大纪元数设置为100。
由于小批量生产的序列短,因此CPU更适合训练。设置'ExecutionEnvironment'
到'cpu'
。要在GPU上进行训练(如果有),请设置'ExecutionEnvironment'
为'auto'
(默认值)。
maxEpochs = 100;miniBatchSize = 27;options = trainingOptions('adam',... 'ExecutionEnvironment','cpu',... 'MaxEpochs',maxEpochs,... 'MiniBatchSize',miniBatchSize,... 'GradientThreshold',1,... '详细'',false,... ``情节'',``培训进度'');
使用指定的培训选项来培训LSTM网络。
net = trainNetwork(XTrain,YTrain,图层,选项);
加载测试集并将序列分类为扬声器。
[XTest,YTest] = japaneseVowelsTestData;
分类测试数据。指定用于训练的相同的小批量大小。
YPred = classify(net,XTest,'MiniBatchSize',miniBatchSize);
计算预测的分类准确性。
acc = sum(YPred == YTest)./ numel(YTest)
acc = 0.9541
输入参数
全部收缩
imds
—图像数据存储ImageDatastore
对象
ImageDatastore
对象图像数据存储,指定为ImageDatastore
对象。
ImageDatastore
允许使用预取功能批量读取JPG或PNG图像文件。如果您使用自定义功能读取图像,则ImageDatastore
不会预取。
小费
使用augmentedImageDatastore
针对深度学习包括图像大小调整图像的高效预处理。
不要使用readFcn
选项,imageDatastore
因为此选项通常会明显变慢。
ds
—数据存储数据存储
数据存储,用于内存不足数据和预处理。
对于只有一个输入的网络,数据存储区返回的表或单元格数组有两列,分别指定了网络输入和期望的响应。
对于具有多个输入的网络,数据存储区必须是组合或转换后的数据存储区,该数据存储区将返回具有(numInputs
+1)列的单元格数组,其中包含预测变量和响应,其中numInputs
是网络输入numResponses
的数量,是响应的数量。对于i
小于或等于的值,单元阵列numInputs
的i
第th个元素对应于inputlayers.InputNames(i)
,其中layers
是定义网络体系结构的层图。单元格数组的最后一列对应于响应。
下表列出了直接与兼容的数据存储trainNetwork
。您可以使用transform
和combine
函数将其他内置数据存储区用于训练深度学习网络。这些函数可以将从数据存储中读取的数据转换为所需的表或单元格数组格式trainNetwork
。有关更多信息,请参阅用于深度学习的数据存储。
X
—图像数据数字数组
图像数据,指定为数字数组。数组的大小取决于图像输入的类型:
如果数组包含NaN
,则它们将通过网络传播。
sequences
—数字数组的序列或时间序列数据单元格数组|数值数组|数据存储
序列或时间序列数据,指定为N乘1的数字数组单元格数组,其中N是观察数,代表单个序列的数字数组或数据存储。
对于单元格数组或数字数组输入,包含序列的数字数组的维数取决于数据类型。
对于数据存储区输入,数据存储区必须以序列的单元格数组或第一列包含序列的表的形式返回数据。序列数据的尺寸必须与上表相对应。
Y
—响应标签的分类向量|数值数组|分类序列的单元格数组|数字序列的单元格数组
响应,指定为标签的分类向量,数字数组,分类序列的单元格数组或数字序列的单元格数组。的格式Y
取决于任务的类型。响应中不得包含NaN
。
分类
对于一个观察到的序列到序列分类问题,sequences
也可以是向量。在这种情况下,Y
必须是标签的分类序列。
回归
对于只有一个观察值的逐序列回归问题,sequences
可以将其作为矩阵。在这种情况下,Y
必须是响应矩阵。
标准化响应通常有助于稳定和加速训练神经网络以进行回归。有关更多信息,请参阅训练卷积神经网络进行回归。
tbl
—输入数据table
table
输入数据,指定为包含第一列中的预测变量和其余列中的响应的表。表格中的每一行都对应一个观察值。
表列中预测变量和响应的排列方式取决于问题的类型。
分类
对于分类问题,如果您未指定responseName
,则该函数默认使用的第二列中的响应tbl
。
回归
对于回归问题,如果不指定responseName
,则该函数默认使用的其余列tbl
。标准化响应通常有助于稳定和加速训练神经网络以进行回归。有关更多信息,请参阅训练卷积神经网络进行回归。
响应中不能包含NaN
。如果预测变量数据包含NaN
,则它们将通过训练传播。但是,在大多数情况下,培训无法收敛。
资料类型:table
responseName
—输入表字符向量中的响应变量的名称|向量的元胞数组|字符串数组
输入表中响应变量的名称,指定为字符向量,字符向量的单元格数组或字符串数组。对于一个响应的问题,responseName
是中相应的变量名称tbl
。对于具有多个响应变量的回归问题,responseName
是中对应变量名称的数组tbl
。
数据类型:char
|cell
|string
layers
—网络层Layer
阵列|LayerGraph
目的
Layer
阵列|LayerGraph
目的网络层,指定为Layer
数组或LayerGraph
对象。
要创建依次连接所有层的网络,可以使用Layer
数组作为输入参数。在这种情况下,返回的网络是一个SeriesNetwork
对象。
有向无环图(DAG)网络具有复杂的结构,其中各层可以具有多个输入和输出。要创建DAG网络,请将网络体系结构指定为LayerGraph
对象,然后将该层图用作的输入参数trainNetwork
。
有关内置层的列表,请参阅深度学习层列表。
options
—培训选项TrainingOptionsSGDM
|TrainingOptionsRMSProp
|TrainingOptionsADAM
TrainingOptionsSGDM
|TrainingOptionsRMSProp
|TrainingOptionsADAM
培训选项,指定为TrainingOptionsSGDM
,TrainingOptionsRMSProp
或者TrainingOptionsADAM
对象通过返回的trainingOptions
功能。要指定求解器和其他用于网络训练的选项,请使用trainingOptions
。
输出参数
全部收缩
net
—训练有素的网络SeriesNetwork
对象|DAGNetwork
目的
SeriesNetwork
对象|DAGNetwork
目的经过训练的网络,作为SeriesNetwork
对象或DAGNetwork
对象返回。
如果使用Layer
数组作为layers
输入参数来训练网络,则它net
是一个SeriesNetwork
对象。如果使用LayerGraph
对象作为输入参数来训练网络,则net
该DAGNetwork
对象为对象。
info
—培训信息结构
训练信息,以结构形式返回,其中每个字段是标量或数字向量,每个训练迭代具有一个元素。
对于分类问题,info
包含以下字段:
TrainingLoss
—损失函数值
TrainingAccuracy
-训练精度
ValidationLoss
—损失函数值
ValidationAccuracy
—验证准确性
BaseLearnRate
—学习率
FinalValidationLoss
—最终验证损失
FinalValidationAccuracy
—最终验证准确性
对于回归问题,info
包含以下字段:
TrainingLoss
—损失函数值
TrainingRMSE
—训练RMSE值
ValidationLoss
—损失函数值
ValidationRMSE
—验证RMSE值
BaseLearnRate
—学习率
FinalValidationLoss
—最终验证损失
FinalValidationRMSE
—最终验证RMSE
结构只包含的字段ValidationLoss
,ValidationAccuracy
,ValidationRMSE
,FinalValidationLoss
,FinalValidationAccuracy
和FinalValidationRMSE
在options
指定的验证数据。所述'ValidationFrequency'
的选择trainingOptions
确定哪些迭代软件将计算验证指标。对于软件未计算验证指标的迭代,结构中的对应值为NaN
。
如果您的网络包含批处理规范化层,则最终验证指标通常与培训期间评估的验证指标不同。这是因为最终网络中的批处理归一化层执行的操作与训练期间不同。
更多关于
全部收缩
保存检查点网络并继续培训
深度学习工具箱™使您可以在训练期间的每个时期之后将网络另存为.mat文件。当您拥有大型网络或大型数据集并且训练需要很长时间时,这种定期保存特别有用。如果培训由于某种原因而中断,则可以从上次保存的检查点网络恢复培训。如果要trainNetwork
保存检查点网络,则必须使用的'CheckpointPath'
名称/值对参数指定路径的名称trainingOptions
。如果指定的路径不存在,则trainingOptions
返回错误。
trainNetwork
自动为检查点网络文件分配唯一的名称。在示例名称中net_checkpoint__351___04_12__18_09_52.mat
,351是迭代编号,_04_12
日期和保存网络18_09_52
的时间trainNetwork
。您可以通过双击或在命令行中使用load命令来加载检查点网络文件。例如:
<span style="color:#404040"><span style="color:inherit">加载net_checkpoint__351___04_12__18_09_52.mat</span></span>
然后,您可以使用网络的各层作为的输入参数来恢复训练trainNetwork
。例如:
<span style="color:#404040"><span style="color:inherit">trainNetwork(XTrain,YTrain,net.Layers,options)</span></span>
您必须手动指定培训选项和输入数据,因为检查点网络不包含此信息。有关示例,请参阅从Checkpoint Network继续培训。
浮点运算
深度学习工具箱中用于深度学习训练,预测和验证的所有功能都使用单精度浮点算术执行计算。深学习功能包括trainNetwork
,predict
,classify
,和activations
。当您同时使用CPU和GPU训练网络时,该软件使用单精度算术。
参考资料
[1] Kudo,M.,J。Toyama和M.Shimbo。“使用通过区域的多维曲线分类”。模式识别字母。卷20,第11-13号,第1103-1111页。
[2] Kudo,M.,J。Toyama和M.Shimbo。日本元音数据集。https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
扩展功能
自动并行支持通过使用Parallel Computing Toolbox™自动并行
运行计算来加速代码。
如果觉得《trainNetwork - Matlab官网介绍的中文版》对你有帮助,请点赞、收藏,并留下你的观点哦!