 

Apply 1 channel mask to 3 channel Tensor in tensorflow

I'm trying to apply a mask (binary, only one channel) to an RGB image (3 channels, normalized to [0, 1]). My current solution is to split the RGB image into its channels, multiply each channel by the mask, and concatenate the channels again:

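```python
import tensorflow as tf
with tf.name_scope('apply_mask'):
  # Output mask is in range [-1, 1], bring to range [0, 1] first
  zero_one_mask = (output_mask + 1) / 2
  # Apply mask to all channels: split along the channel axis, multiply
  # each channel by the mask, then concatenate the channels again.
  # (TensorFlow >= 1.0 API: tf.mul is now tf.multiply, and tf.split and
  # tf.concat take the tensor first and the axis last.)
  channels = tf.split(output_img, 3, axis=3)
  channels = [tf.multiply(c, zero_one_mask) for c in channels]
  output_img = tf.concat(channels, axis=3)
```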

However, this seems pretty inefficient, especially since, as far as I understand, none of these computations are done in place. Is there a more efficient way of doing this?

Asked by panmari, Mar 14 '23 22:03


1 Answer

The tf.mul() operator (renamed tf.multiply() in TensorFlow 1.0) supports NumPy-style broadcasting, which lets you simplify the code and make it slightly more efficient.
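As a quick illustration of the NumPy broadcasting rules that `tf.mul()` follows, here is a small sketch. It uses NumPy rather than TensorFlow purely so the shape arithmetic is easy to run, and the sizes (`b=2, m=4, n=5, c=3`) are made up for illustration. Shapes are compared starting from the trailing dimension; each pair of sizes must be equal or one of them must be 1, and missing leading dimensions are treated as 1.

```python
import numpy as np

# Hypothetical sizes: batch b=2, height m=4, width n=5, channels c=3.
b, m, n, c = 2, 4, 5, 3
rng = np.random.default_rng(0)
image = rng.random((b, m, n, c))     # b x m x n x c
mask_2d = rng.random((m, n))         # m x n
mask_3d = mask_2d[:, :, np.newaxis]  # m x n x 1

# Shapes are aligned from the right: (m, n, 1) against (b, m, n, c).
# n matches n, m matches m, the size-1 axis stretches to c, and the
# missing leading axis stretches to b.
masked = image * mask_3d
print(masked.shape)  # (2, 4, 5, 3)

# A plain m x n mask is aligned against the image's last two axes
# (n, c) = (5, 3), which do not match, so broadcasting fails.
try:
    image * mask_2d
    plain_mask_broadcasts = True
except ValueError:
    plain_mask_broadcasts = False
print(plain_mask_broadcasts)  # False
```

The trailing size-1 axis is what makes the mask line up with the image's spatial dimensions rather than with its width and channel dimensions.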

Let's say that zero_one_mask is an m x n tensor and output_img is a b x m x n x 3 tensor (where b is the batch size; I'm inferring this from the fact that you split output_img on dimension 3)*. You can use tf.expand_dims() to make zero_one_mask broadcastable against output_img by reshaping it into an m x n x 1 tensor:

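```python
import tensorflow as tf
with tf.name_scope('apply_mask'):
  # Output mask is in range [-1, 1], bring to range [0, 1] first.
  # NOTE: Assumes `output_mask` is a 2-D `m x n` tensor; expanding it to
  # `m x n x 1` lets it broadcast across the channel dimension.
  zero_one_mask = tf.expand_dims((output_mask + 1) / 2, axis=2)
  # Apply mask to all channels with a single broadcasting multiply
  # (tf.multiply is the TensorFlow >= 1.0 name for tf.mul).
  # NOTE: Assumes `output_img` is a 4-D `b x m x n x c` tensor.
  output_img = tf.multiply(output_img, zero_one_mask)
```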
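To check that the broadcasting version computes the same result as the split/multiply/concatenate version from the question, you can compare both on random data. This is a sketch assuming TensorFlow 2.x with eager execution (hence `tf.multiply` instead of `tf.mul`) and an `m x n` mask; the helper names `mask_by_splitting` and `mask_by_broadcasting` and the sizes are hypothetical, chosen only for illustration.

```python
import numpy as np
import tensorflow as tf


def mask_by_splitting(img, mask_2d):
    # Question's approach: split along the channel axis, multiply each
    # b x m x n x 1 slice by the (m x n x 1) mask, then concatenate again.
    mask = tf.expand_dims(mask_2d, axis=2)
    channels = tf.split(img, num_or_size_splits=img.shape[-1], axis=3)
    channels = [tf.multiply(ch, mask) for ch in channels]
    return tf.concat(channels, axis=3)


def mask_by_broadcasting(img, mask_2d):
    # Answer's approach: one multiply, broadcasting an m x n x 1 mask
    # across the batch and channel dimensions.
    return tf.multiply(img, tf.expand_dims(mask_2d, axis=2))


# Hypothetical sizes: batch b, height m, width n, channels c.
b, m, n, c = 2, 4, 5, 3
rng = np.random.default_rng(0)
output_img = tf.constant(rng.random((b, m, n, c)).astype(np.float32))
output_mask = tf.constant(rng.uniform(-1.0, 1.0, (m, n)).astype(np.float32))
zero_one_mask_2d = (output_mask + 1) / 2  # bring [-1, 1] to [0, 1]

split_result = mask_by_splitting(output_img, zero_one_mask_2d)
broadcast_result = mask_by_broadcasting(output_img, zero_one_mask_2d)

print(tuple(broadcast_result.shape))  # (2, 4, 5, 3)
print(np.allclose(split_result.numpy(), broadcast_result.numpy()))  # True
```

Neither version works in place; both allocate a new output tensor. The broadcasting version mainly saves the per-channel intermediate tensors and the final concatenation.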

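(* This works equally well if output_img is a 4-D b x m x n x c tensor (for any number of channels c) or a 3-D m x n x c tensor, because broadcasting aligns shapes starting from the trailing dimension: the m x n x 1 mask matches the last three dimensions in either case, its size-1 axis stretches to c, and any missing leading dimension is broadcast.)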
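The claim in the footnote can be tried out directly: the same `m x n x 1` mask is applied to a batched 3-channel image, a batched 4-channel image, and an unbatched 3-channel image. This is a sketch assuming TensorFlow 2.x eager execution; the sizes, the random binary mask, and the shape labels are all made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical mask size and a random binary m x n mask, expanded to m x n x 1.
m, n = 4, 5
rng = np.random.default_rng(1)
mask_2d = tf.constant(rng.integers(0, 2, size=(m, n)).astype(np.float32))
mask = tf.expand_dims(mask_2d, axis=2)

# Image shapes to try: batched with 3 or 4 channels, and unbatched.
shapes = {
    "4-D, c=3": (2, m, n, 3),
    "4-D, c=4": (2, m, n, 4),
    "3-D, c=3": (m, n, 3),
}

results = {}
for label, shape in shapes.items():
    img = tf.random.uniform(shape)  # float32 values in [0, 1)
    out = tf.multiply(img, mask)    # mask is broadcast over batch and channels
    results[label] = (img.numpy(), out.numpy())
    print(label, "->", tuple(out.shape))
```

In every case the output keeps the image's shape, and each channel is multiplied by the same mask.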

Answered by mrry, Mar 23 '23 16:03