糖尿病康复,内容丰富有趣,生活中的好帮手!
糖尿病康复 > 对抗训练:FGM FGSM PGD

对抗训练:FGM FGSM PGD

时间:2019-12-20 14:32:51

相关推荐

对抗训练:FGM FGSM PGD

当前,在各大NLP竞赛中,对抗训练已然成为上分神器,尤其是fgm和pgd使用较多,下面来说说吧。对抗训练是一种引入噪声的训练方式,可以对参数进行正则化,提升模型鲁棒性和泛化能力。

一、什么是对抗训练?

对抗样本:对输入增加微小扰动得到的样本。旨在增加模型损失

对抗训练:训练模型去区分样例是真实样例还是对抗样本的过程。对抗训练不仅可以提升模型对对抗样本的防御能力,还能提升对原始样本的泛化能力

1、FGM——Fast Gradient Method

FSGM是每个方向上都走相同的一步,Goodfellow后续提出的FGM则是根据具体的梯度进行scale,得到更好的对抗样本:

对于每个x:1.计算x的前向loss、反向传播得到梯度2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r3.计算x+r的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上4.将embedding恢复为(1)时的值5.根据(3)的梯度对参数进行更新

Pytorch实现class FGM():""" 快速梯度对抗训练"""def __init__(self, model):self.model = modelself.backup = {}def attack(self, epsilon=1., emb_name='word_embeddings'):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:self.backup[name] = param.data.clone() # 保存原始参数,用于后续恢复norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = epsilon * param.grad / normparam.data.add_(r_at)def restore(self, emb_name='word_embeddings'):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:assert name in self.backupparam.data = self.backup[name]self.backup = {}

一、使用步骤

for step, batch in enumerate(train_dataloader):# 遍历批数据# Add batch to GPUbatch = tuple(t.to(device) for t in batch)# Unpack the inputs from our dataloader# 每一批数据展开# train_inputs.extend(one_freq_input_ids)# train_labels.extend(one_freq_labels)# train_masks.extend(one_freq_attention_masks)# train_token_types.extend(one_freq_token_types)# 接收batch的输入b_input_ids, b_input_mask, b_labels, b_token_types = batchoutputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)logits = outputs[0]loss_func = BCEWithLogitsLoss() # 计算损失loss = loss_func(logits.view(-1,num_labels),b_labels.type_as(logits).view(-1,num_labels)) train_loss_set.append(loss.item())# 记录loss # Backward passloss.backward(retain_graph=True) # loss反向求导#对抗训练fgm.attack()loss_adv = loss_func(logits.view(-1,num_labels),b_labels.type_as(logits).view(-1,num_labels))loss_adv.backward(retain_graph=True)fgm.restore()#梯度更新optimizer.step()model.zero_grad()

总结

对抗训练中关键的是需要找到对抗样本(尽量让模型预测出错的样本),通常是对原始的输入添加一定的扰动来构造,然后用来给模型训练.

FGM对抗训练_Mr.奇的博客-CSDN博客

对抗训练fgm、fgsm和pgd原理和源码分析_谈笑风生...的博客-CSDN博客_pgd对抗训练

如果觉得《对抗训练:FGM FGSM PGD》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。