我最近转向Python 3.5并注意到新的矩阵乘法运算符(@)有时与numpy点运算符的行为不同.例如,对于3d数组:
import numpy as np a = np.random.rand(8,13,13) b = np.random.rand(8,13,13) c = a @ b # Python 3.5+ d = np.dot(a, b)
的@
运算符返回形状的阵列:
c.shape (8, 13, 13)
而np.dot()
函数返回:
d.shape (8, 13, 8, 13)
如何用numpy dot重现相同的结果?还有其他重大差异吗?
该@
运营商称阵列的__matmul__
方法,而不是dot
.该方法也作为函数存在于API中np.matmul
.
>>> a = np.random.rand(8,13,13) >>> b = np.random.rand(8,13,13) >>> np.matmul(a, b).shape (8, 13, 13)
从文档:
matmul
与dot
两个重要方面不同.
不允许使用标量进行乘法运算.
矩阵堆栈一起广播,就好像矩阵是元素一样.
最后一点清楚地表明dot
,matmul
当传递3D(或更高维)数组时,方法的行为会有所不同.从文档中引用更多:
用于matmul
:
如果任一参数是ND,N> 2,则将其视为驻留在最后两个索引中的矩阵堆栈并相应地进行广播.
用于np.dot
:
对于2-D阵列,它相当于矩阵乘法,对于1-D阵列相当于矢量的内积(没有复共轭).对于N维,它是a的最后一个轴和b的倒数第二个轴的和积
@ajcr的答案解释了dot
和matmul
(由@
符号调用)的不同之处.通过一个简单的例子,可以清楚地看到两者在"堆叠的基质"或张量上进行操作时的行为方式.
为了澄清差异,需要使用4x4阵列,dot
并matmul
使用2x4x3'堆叠的matricies'或张量返回产品和产品.
import numpy as np fourbyfour = np.array([ [1,2,3,4], [3,2,1,4], [5,4,6,7], [11,12,13,14] ]) twobyfourbythree = np.array([ [[2,3],[11,9],[32,21],[28,17]], [[2,3],[1,9],[3,21],[28,7]], [[2,3],[1,9],[3,21],[28,7]], ]) print('4x4*4x2x3 dot:\n {}\n'.format(np.dot(fourbyfour,twobyfourbythree))) print('4x4*4x2x3 matmul:\n {}\n'.format(np.matmul(fourbyfour,twobyfourbythree)))
每个操作的产品如下所示.注意点积是怎样的,
... a的最后一个轴和b的倒数第二个的和积
以及如何通过将矩阵一起广播来形成矩阵产品.
4x4*4x2x3 dot: [[[232 152] [125 112] [125 112]] [[172 116] [123 76] [123 76]] [[442 296] [228 226] [228 226]] [[962 652] [465 512] [465 512]]] 4x4*4x2x3 matmul: [[[232 152] [172 116] [442 296] [962 652]] [[125 112] [123 76] [228 226] [465 512]] [[125 112] [123 76] [228 226] [465 512]]]