当前位置:  开发笔记 > 编程语言 > 正文

numpy dot()和Python 3.5+矩阵乘法之间的区别@

如何解决《numpydot()和Python3.5+矩阵乘法之间的区别@》经验,为你挑选了2个好方法。

我最近转向Python 3.5并注意到新的矩阵乘法运算符(@)有时与numpy点运算符的行为不同.例如,对于3d数组:

import numpy as np

a = np.random.rand(8,13,13)
b = np.random.rand(8,13,13)
c = a @ b  # Python 3.5+
d = np.dot(a, b)

@运算符返回形状的阵列:

c.shape
(8, 13, 13)

np.dot()函数返回:

d.shape
(8, 13, 8, 13)

如何用numpy dot重现相同的结果?还有其他重大差异吗?



1> Alex Riley..:

@运营商称阵列的__matmul__方法,而不是dot.该方法也作为函数存在于API中np.matmul.

>>> a = np.random.rand(8,13,13)
>>> b = np.random.rand(8,13,13)
>>> np.matmul(a, b).shape
(8, 13, 13)

从文档:

matmuldot两个重要方面不同.

不允许使用标量进行乘法运算.

矩阵堆栈一起广播,就好像矩阵是元素一样.

最后一点清楚地表明dot,matmul当传递3D(或更高维)数组时,方法的行为会有所不同.从文档中引用更多:

用于matmul:

如果任一参数是ND,N> 2,则将其视为驻留在最后两个索引中的矩阵堆栈并相应地进行广播.

用于np.dot:

对于2-D阵列,它相当于矩阵乘法,对于1-D阵列相当于矢量的内积(没有复共轭).对于N维,它是a的最后一个轴和b的倒数第二个轴的和积


这里的混淆可能是因为发行说明直接将"@"符号等同于示例代码中numpy的dot()函数.

2> Nathan..:

@ajcr的答案解释了dotmatmul(由@符号调用)的不同之处.通过一个简单的例子,可以清楚地看到两者在"堆叠的基质"或张量上进行操作时的行为方式.

为了澄清差异,需要使用4x4阵列,dotmatmul使用2x4x3'堆叠的matricies'或张量返回产品和产品.

import numpy as np
fourbyfour = np.array([
                       [1,2,3,4],
                       [3,2,1,4],
                       [5,4,6,7],
                       [11,12,13,14]
                      ])


twobyfourbythree = np.array([
                             [[2,3],[11,9],[32,21],[28,17]],
                             [[2,3],[1,9],[3,21],[28,7]],
                             [[2,3],[1,9],[3,21],[28,7]],
                            ])

print('4x4*4x2x3 dot:\n {}\n'.format(np.dot(fourbyfour,twobyfourbythree)))
print('4x4*4x2x3 matmul:\n {}\n'.format(np.matmul(fourbyfour,twobyfourbythree)))

每个操作的产品如下所示.注意点积是怎样的,

... a的最后一个轴和b的倒数第二个的和积

以及如何通过将矩阵一起广播来形成矩阵产品.

4x4*4x2x3 dot:
 [[[232 152]
  [125 112]
  [125 112]]

 [[172 116]
  [123  76]
  [123  76]]

 [[442 296]
  [228 226]
  [228 226]]

 [[962 652]
  [465 512]
  [465 512]]]

4x4*4x2x3 matmul:
 [[[232 152]
  [172 116]
  [442 296]
  [962 652]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]]


dot(a,b)[i,j,k,m] = sum(a [i,j,:]*b [k,:,m])-------类文档说:它是一个在a的最后一个轴和b的倒数第二个轴上求和乘积:
推荐阅读
LEEstarmmmmm
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有