I have two tensors with the same size:
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
Tensor a has three regions, demarcated by runs of consecutive values: region 1 is [1, 2, 3, 4, 5], region 2 is [10, 11, 12, 13], and region 3 is [20, 21, 22, 23, 24, 25, 26, 27, 28].
For each of those regions, I want to apply the following logic: if one of the values of b is 1, then the following i values are set to 0. If they are already 0, they continue as 0. After i values are changed, nothing happens until another value of b is 1. In that case, the next i values are forced to 0...
Some examples:
# i = 1
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1]
# i = 2
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]
# i = 4
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
Not sure if this would help, but I was able to separate the regions into segments by doing:
a_shifted = tf.roll(a - 1, shift=-1, axis=0)
a_shifted_segs = tf.math.cumsum(tf.cast(a_shifted != a, dtype=tf.int64), exclusive=True)
# a_shifted_segs = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]
Do you know any way of doing this efficiently?
Here you have a tensorflow solution, based on tf.scan
. I know the conditionals are a bit complicated; if you have suggestions on how to simplify them, I'm open to them. However, once you can read the conditionals, it should be quite clear what the code does.
Here, the variable i tells us, for each position in the array, how many more b values still have to be overwritten with 0.
import tensorflow as tf

a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
b = tf.constant([0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1])

# A position starts a new region whenever its value is not its predecessor + 1.
# Prepending min(a) - 2 as the "predecessor" of the first element makes the
# very first position always count as a switch (a[0] can never equal min(a) - 1).
predecessors = tf.concat([[tf.reduce_min(a) - 2], a[:-1]], axis=0)
switches = tf.not_equal(a, predecessors + 1)

# Carried state and per-position inputs for the tf.scan below:
# 'b' is the output value at the previous position, 'i' the remaining
# number of positions that must still be forced to 0.
initializer = {'b': tf.constant(False), 'i': tf.constant(0)}
elems = {'switches': switches, 'b': tf.cast(b, dtype=tf.bool)}
@tf.function
def step(last_out, new_in, max_i):
    """One tf.scan step of the zeroing logic.

    Args:
        last_out: dict with 'b' (tf.bool, the output written at the previous
            position) and 'i' (tf.int32, how many more positions must still
            be forced to 0).
        new_in: dict with 'switches' (tf.bool, True at the first position of
            a region) and 'b' (tf.bool, the original b value here).
        max_i: Python int, the number of positions to zero out after each
            kept 1 (the "i" of the question).

    Returns:
        dict with 'b' (the new output value) and 'i' (the updated countdown).
    """
    # The countdown (re)starts "fresh" when it has run out, or when a new
    # region begins -- zeroing never crosses a region boundary.
    fresh = tf.logical_or(last_out['i'] <= 0, new_in['switches'])
    new_i = tf.cond(
        fresh,
        # Fresh position: a 1 in b is kept and arms the countdown at max_i;
        # a 0 leaves it disarmed.
        lambda: tf.cond(
            new_in['b'],
            lambda: tf.constant(max_i),
            lambda: tf.constant(0),
        ),
        # Still counting down: this position is forced to 0, one fewer to go.
        lambda: tf.maximum(last_out['i'] - 1, 0),
    )
    # The output is 1 exactly at the positions where the countdown was just
    # (re)armed. (tf.equal replaces the original tf.cond that merely mapped
    # the same predicate to True/False.)
    return {'b': tf.equal(new_i, max_i), 'i': new_i}
Examples:
# Run the scan once per example window size and print b_new as int32.
for max_i in (1, 2, 4):
    outp = tf.scan(lambda last, inp: step(last, inp, max_i=max_i),
                   elems=elems, initializer=initializer)
    print(tf.cast(outp['b'], tf.int32))
# max_i = 1 -> tf.Tensor([0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 0 0 1], shape=(18,), dtype=int32)
# max_i = 2 -> tf.Tensor([0 1 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 1], shape=(18,), dtype=int32)
# max_i = 4 -> tf.Tensor([0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 1], shape=(18,), dtype=int32)
This answer is sponsored by lambda.
Here is a pure Tensorflow
approach, which will work in Eager Execution
and Graph
mode:
# copy, paste, acknowledge
import tensorflow as tf
def split_regions_and_modify(a, b, i):
    """Zero out the i positions of b following each kept 1, per region of a.

    Works in both Eager and Graph mode. `a` and `b` are rank-1 int tensors of
    the same length; `i` is the number of positions to force to 0 after a 1.
    Returns the modified copy of `b`.
    """
    # Positions (shifted by +1) where a breaks the "+1 consecutive" pattern,
    # i.e. the start index of every region except the first.
    indices = tf.squeeze(tf.where(a[:-1] != a[1:] - 1), axis=-1) + 1
    # Region end boundaries. If there is at least one break, append the total
    # length so the last region is closed; otherwise (a single region) the
    # only boundary is len(a).
    row_splits = tf.cast(tf.cond(tf.not_equal(tf.shape(indices)[0], 0),
                      lambda: tf.concat([indices, [indices[-1] + (tf.cast(tf.shape(a), dtype=tf.int64)[0] - indices[-1])]], axis=0),
                      lambda: tf.shape(a)[0][None]), dtype=tf.int32)
    def body(i, j, k, tensor, row_splits):
        # One loop step: j is the current position, k indexes the boundary of
        # the region j belongs to; i (the zeroing width) and row_splits are
        # carried through unchanged.
        # Advance to the next region once j reaches the current boundary.
        k = tf.cond(tf.equal(row_splits[k], j), lambda: tf.add(k, 1), lambda: k)
        # The up-to-i positions after j, clipped at the region boundary.
        current_indices = tf.range(j + 1, tf.minimum(j + 1 + i, row_splits[k]), dtype=tf.int32)
        # If position j holds a 1, scatter zeros over those following
        # positions. NOTE(review): the `j != row_splits[k]` guard appears to
        # skip the update right at a boundary -- confirm intent for that edge.
        tensor = tf.cond(tf.logical_and(tf.equal(tensor[j], 1), tf.not_equal(j, row_splits[k])), lambda:
                 tf.tensor_scatter_nd_update(tensor, current_indices[..., None], tf.zeros_like(current_indices)), lambda: tensor)
        return i, tf.add(j, 1), k, tensor, row_splits
    j0 = tf.constant(0)  # position cursor
    k0 = tf.constant(0)  # region-boundary cursor
    # Iterate while there are positions left and boundaries left.
    c = lambda i, j0, k0, b, row_splits: tf.logical_and(tf.less(j0, tf.shape(b)[0]), tf.less(k0, tf.shape(row_splits)[0]))
    _, _, _, output, _ = tf.while_loop(c, body, loop_vars=[i, j0, k0, b, row_splits])
    return output
Usage:
# Same example tensors as in the question.
a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
b = tf.constant([0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1])
# One call per example value of i; the comment below each call shows its result.
split_regions_and_modify(a, b, 1)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1], dtype=int32)>
split_regions_and_modify(a, b, 2)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int32)>
split_regions_and_modify(a, b, 4)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)>
If you found this helpful, you can donate via PayPal or buy me a coffee so we can maintain and grow. Thank you!
Donate Us With