numpy的广播机制

注意

首先一定要注意,这个机制不适用于矩阵相乘,矩阵相乘一定按其法则进行。

broadcast的简单例子

举一个简单的例子,实现对一个1-d array的每一个元素乘以2。

1
2
3
4
>>> a = np.array([1., 2., 3.])
>>> b = np.array([2., 2., 2.])
>>> a*b
array([2., 4., 6.])

上面的这种是通用做法,而broadcast则是:

1
2
3
4
a = np.array([1, 2, 3])
b = 2
a * b
array([2, 4, 6])

广播过程

为了定义两个形状是否是可兼容的,Numpy从最后开始往前逐个比较它们的维度(dimensions)大小。比较过程中,如果两者的对应维度相同,或者其中之一(或者全是)等于1,比较继续进行直到最前面的维度。否则,你将看到ValueError错误出现(如,”operands could not be broadcast together with shapes …”)。

当其中之一的形状的维度超出范围(例如,a1 的dim=(2,3,4)而a2的dim=(3,4),当a1=2时a2超出范围),此时Numpy将会使用1进行比较直到另一个也超出dim范围

一旦Numpy确定两者的形状是可兼容的,最终结果的形状就成了每个维度上取两者之间最大的形状尺寸。

所以,针对这种设定有:

1
2
3
(2,3,4 )* (2, 3, 1)=(2, 3, 4)
(2,3,4 )* (2, 1, 4)=(2, 3, 4)
(2,3,4 )* (3, 4)=(2, 3, 4)

测试:

1
2
3
a = np.random.rand(1, 1, 4)
b = np.random.rand(2, 3, 4)
np.shape(a * b)

结果是:(2, 3, 4)