
Can Numba be used with TensorFlow?

Can Numba be used to compile Python code that interfaces with TensorFlow? That is, computations outside of TensorFlow would be compiled with Numba for speed. I have not found any resources on how to do this.

Ron Cohen asked Dec 14 '16 00:12


1 Answer

You can use tf.numpy_function (or tf.py_func in TensorFlow 1.x) to wrap a Python function, including a Numba-compiled one, and use it as a TensorFlow op. Here is an example that I used:

import numpy as np
import tensorflow as tf
from numba import jit

@jit(nopython=True)
def dice_coeff_nb(y_true, y_pred):
    "Calculates dice coefficient"
    smooth = np.float32(1)
    # ravel() flattens the arrays and is supported in Numba's nopython mode
    y_true_f = y_true.ravel()
    y_pred_f = y_pred.ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (np.sum(y_true_f) +
                                            np.sum(y_pred_f) + smooth)
    return score

@jit(nopython=True)
def dice_loss_nb(y_true, y_pred):
    "Calculates dice loss"
    loss = 1 - dice_coeff_nb(y_true, y_pred)
    return loss

def bce_dice_loss_nb(y_true, y_pred):
    "Adds dice_loss to categorical_crossentropy"
    # tf.numpy_function is not differentiable, so this term adds no gradient
    dice = tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64)
    # Cast to the prediction dtype so the two terms can be added
    loss = tf.cast(dice, y_pred.dtype) + \
            tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return loss
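
Before wiring a Numba function into tf.numpy_function, it can help to check it against plain NumPy on a small input, without TensorFlow involved. The following is only a sketch under assumptions, not part of my original code: dice_coeff_np, the toy arrays, and the ImportError fallback (which runs the function uncompiled if Numba is missing) are hypothetical additions for illustration.

```python
import numpy as np

try:
    from numba import njit  # assumption: Numba is installed
except ImportError:
    def njit(func):
        # Hypothetical fallback: run the function uncompiled without Numba.
        return func


@njit
def dice_coeff_nb(y_true, y_pred):
    # Same dice formula as above, with smooth = 1.
    smooth = np.float32(1)
    intersection = np.sum(y_true.ravel() * y_pred.ravel())
    return (2.0 * intersection + smooth) / (
        np.sum(y_true) + np.sum(y_pred) + smooth
    )


def dice_coeff_np(y_true, y_pred):
    # Plain NumPy reference used only for comparison.
    smooth = 1.0
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (
        np.sum(y_true) + np.sum(y_pred) + smooth
    )


# Hypothetical toy data: two samples, two classes.
y_true = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
y_pred = np.array([[0.5, 0.5], [0.0, 1.0]], dtype=np.float32)

compiled = dice_coeff_nb(y_true, y_pred)
reference = dice_coeff_np(y_true, y_pred)

# Both paths should agree up to floating-point rounding.
assert np.isclose(compiled, reference)
```

For these inputs the intersection is 1.5 and both sums are 2, so the coefficient is (3 + 1) / (2 + 2 + 1), roughly 0.8. If the two versions disagree, the problem is in the Numba function rather than in the TensorFlow wrapper.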

I then used this loss function to train a tf.keras model:

...
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss=bce_dice_loss_nb)
Kamin answered Sep 20 '22 19:09