我正在尝试将一个掩码(二进制,只有一个通道)应用于RGB图像(3个通道,标准化为[0,1]).我目前的解决方案是,将RGB图像分割成它的通道,将其与掩码相乘并再次连接这些通道:
with tf.variable_scope('apply_mask') as scope:
# Output mask is in range [-1, 1], bring to range [0, 1] first
zero_one_mask = (output_mask + 1) / 2
# Apply mask to all channels.
channels = tf.split(3, 3, output_img)
channels = [tf.mul(c, zero_one_mask) for c in channels]
output_img = tf.concat(3, channels)
然而,这似乎效率很低,特别是因为根据我的理解,这些计算都不是就地完成的.有没有更有效的方法来做到这一点?
该tf.mul()
运营商支持numpy的风格的广播,这样可以让你简化和轻微优化代码.
让我们说这zero_one_mask
是一个m x n
张量,并且output_img
是一个b x m x n x 3
(b
批量大小 - 我从你output_img
在维度3上拆分的事实推断这个)*.您可以通过将其重塑为张量tf.expand_dims()
来使zero_one_mask
广播channels
成为m x n x 1
:
with tf.variable_scope('apply_mask') as scope: # Output mask is in range [-1, 1], bring to range [0, 1] first # NOTE: Assumes `output_mask` is a 2-D `m x n` tensor. zero_one_mask = tf.expand_dims((output_mask + 1) / 2, 2) # Apply mask to all channels. # NOTE: Assumes `output_img` is a 4-D `b x m x n x c` tensor. output_img = tf.mul(output_img, zero_one_mask)
(*由于广播的工作方式,如果output_img
是4-D b x m x n x c
(任意数量的频道c
)或3-D m x n x c
张量,这将同样有效.)