浅谈Pytorch中的torch.gather函数的含义
Python  /  管理员 发布于 5年前   213
pytorch中的gather函数
pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验。
立个flag开始学习pytorch,新开一个分类整理学习pytorch中的一些踩到的泥坑。
今天刚开始接触,读了一下documentation,写一个一开始每太搞懂的函数gather
b = torch.Tensor([[1,2,3],[4,5,6]])print bindex_1 = torch.LongTensor([[0,1],[2,0]])index_2 = torch.LongTensor([[0,1,1],[0,0,0]])print torch.gather(b, dim=1, index=index_1)print torch.gather(b, dim=0, index=index_2)
观察它的输出结果:
1 2 3 4 5 6[torch.FloatTensor of size 2x3] 1 2 6 4[torch.FloatTensor of size 2x2] 1 5 6 1 2 3[torch.FloatTensor of size 2x3]
这里是官方文档的解释
torch.gather(input, dim, index, out=None) → Tensor Gathers values along an axis specified by dim. For a 3-D tensor the output is specified by: out[i][j][k] = input[index[i][j][k]][j][k] # dim=0 out[i][j][k] = input[i][index[i][j][k]][k] # dim=1 out[i][j][k] = input[i][j][index[i][j][k]] # dim=2 Parameters: input (Tensor) C The source tensor dim (int) C The axis along which to index index (LongTensor) C The indices of elements to gather out (Tensor, optional) C Destination tensor Example: >>> t = torch.Tensor([[1,2],[3,4]]) >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])) 1 1 4 3 [torch.FloatTensor of size 2x2]
可以看出,gather的作用是这样的,index实际上是索引,具体是行还是列的索引要看前面dim 的指定,比如对于我们的栗子,【1,2,3;4,5,6,】,指定dim=1,也就是横向,那么索引就是列号。index的大小就是输出的大小,所以比如index是【1,0;0,0】,那么看index第一行,1列指的是2, 0列指的是1,同理,第二行为4,4 。这样就输入为【2,1;4,4】,参考这样的解释看上面的输出结果,即可理解gather的含义。
gather在one-hot为输出的多分类问题中,可以把最大值坐标作为index传进去,然后提取到每一行的正确预测结果,这也是gather可能的一个作用。
以上这篇浅谈Pytorch中的torch.gather函数的含义就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
122 在
学历:一种延缓就业设计,生活需求下的权衡之选中评论 工作几年后,报名考研了,到现在还没认真学习备考,迷茫中。作为一名北漂互联网打工人..123 在
Clash for Windows作者删库跑路了,github已404中评论 按理说只要你在国内,所有的流量进出都在监控范围内,不管你怎么隐藏也没用,想搞你分..原梓番博客 在
在Laravel框架中使用模型Model分表最简单的方法中评论 好久好久都没看友情链接申请了,今天刚看,已经添加。..博主 在
佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 @1111老铁这个不行了,可以看看近期评论的其他文章..1111 在
佛跳墙vpn软件不会用?上不了网?佛跳墙vpn常见问题以及解决办法中评论 网站不能打开,博主百忙中能否发个APP下载链接,佛跳墙或极光..
Copyright·© 2019 侯体宗版权所有·
粤ICP备20027696号