1. Python

optimizer.zero_grad() 踩坑记录

最近花了很多时间去尝试复现别人paper里的网络模型,一个很简单的forked cnn + mlp结构。

网络模型倒是很好搭建,一下子就写好了,但是因为模型要求输入的数据格式需要准备,还有数据标注的处理花了一天时间,这部分代码可以用多线程来提高效率。

然后就是编写训练部分的代码,起初参考d2l的ch6代码进行编写,训练发现train loss剧烈波动,test loss 小幅波动且小于train loss。将训练好的模型加载进来,给它喂不同数据,发现输出都一样。通过调查,有可能是输入数据未归一化,数据标注未归一化问题,经查不是。有可能是lr设置太大的原因,经查不是。最后查看pytorch文档,对比我的代码,发现优化器没有在一个batch后执行梯度归零操作,导致梯度一直在累积直到爆炸。