WikiKG90M
2022-07-03 23:43:37 0 举报
竞赛代码流程图
作者其他创作
大纲/内容
multiprocessing.Process()
在trainer子进程中注册子进程使用Queue接收异步更新的梯度数据
...
loss.backward(),trace 中的参数生成梯度数据
share_memory()
这里直接采用了 HOGWILD 的方案,将节点和边的emb保存在共享内存中,然后多进程直接按照索引更新参数。实验显示,在数据比较大的时候,参数的更新往往是比较稀疏的,这样操作相对于其他方案既不会损失过多的精度,并且可以很好的提升性能。
multiprocessing.QueueQueue.get()
多进程 spawn
计算loss
根据 score function计算 pos_score 和neg_score
遍历trace并put结束,即可开始下一个step
共享内存
将trace中的数据idx,grad_data share_memory() 然后 put 到之前注册的队列Queue中,遍历完 trace即可进入下一个step,重新生成batch子图。训练完成,Queue.put(none)
初始化模型,未记录梯度
按照进程数量采样子图
正常情况下是计算loss,backward()计算梯度,更新参数,然后在进行下一step训练。这里的操作使得 trainer 子进程中的参数还没更新,就用和上一个step相同的参数计算了loss。这是为了加快训练速度的做法,会减缓模型的收敛。在此,为了多个进程模型的同步,即训练相同步数相同,每 1000 step 设置一个barrier进行阻塞。根据实验效果来看,这样做效果可以得到保证。一方面,可能是由于数据量大,参数的更新特别的稀疏,下一次的step的idx 和上一次没什么交集;另一方面,可能也是因为设置了queue 的大小为1,意味着这一次 step 的梯度需要等待上一次 put 到quque里面的梯度被 get 掉,才能put到队列 quque里,避免了参数过于不同步。
根据batch子图的node 和 relation 索引idx获取对应的数据发送到gpu,并添加到trace中记录,并且s.clone().detach().require_grad_()
遍历trace,put 到Queue中
子进程按照梯度和索引更新参数。还没put数据进来的时候,子进程循环等待。接收到 none 代表训练完成,退出循环
0 条评论
下一页