 

Writing a custom cost function in TensorFlow

I'm trying to write my own cost function in TensorFlow, but apparently I cannot 'slice' the tensor object:

import tensorflow as tf
import numpy as np

# Establish variables
x = tf.placeholder("float", [None, 3])
W = tf.Variable(tf.zeros([3,6]))
b = tf.Variable(tf.zeros([6]))

# Establish model
y = tf.nn.softmax(tf.matmul(x,W) + b)

# Truth
y_ = tf.placeholder("float", [None,6])

def angle(v1, v2):
  return np.arccos(np.sum(v1*v2,axis=1))

def normVec(y):
  return np.cross(y[:,[0,2,4]],y[:,[1,3,5]])

angle_distance = -tf.reduce_sum(angle(normVec(y_),normVec(y)))
# This is the example code they give for cross entropy
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

I get the following error: TypeError: Bad slice index [0, 2, 4] of type <type 'list'>

asked Nov 13 '15 at 01:11 by kmace



1 Answer

At the time of writing, TensorFlow can't gather on axes other than the first; support for that has been requested.

For this specific case, though, you can transpose, gather rows 0, 2 and 4, and then transpose back. It won't be especially fast, but it works:

tf.transpose(tf.gather(tf.transpose(y), [0,2,4]))

This is a useful workaround for some of the limitations in the current implementation of tf.gather.
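A minimal sketch of that workaround on a small made-up 2x6 tensor standing in for y (the data is an assumption for illustration). It assumes a newer TensorFlow 2 install, where tf.gather accepts an axis argument, so the direct column gather is shown alongside the transpose trick for comparison:

```python
import numpy as np
import tensorflow as tf

# Illustrative stand-in for y: 2 rows, 6 columns, values 0..11.
y = tf.constant(np.arange(12, dtype=np.float32).reshape(2, 6))

# Workaround from the answer: move columns onto axis 0, gather rows 0, 2, 4, move back.
even_cols_workaround = tf.transpose(tf.gather(tf.transpose(y), [0, 2, 4]))

# In newer TensorFlow releases, tf.gather can select along axis 1 directly.
even_cols_direct = tf.gather(y, [0, 2, 4], axis=1)

print(even_cols_workaround.numpy())
print(even_cols_direct.numpy())
```

Both forms pick the same columns; the axis argument simply avoids the two extra transposes.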

It's also correct that you can't use a NumPy slice on a TensorFlow node: you can run the node and slice the output instead, and you also need to initialize those variables before you run. :) You're mixing tf and np in a way that doesn't work.

x = tf.Something(...)

is a TensorFlow graph object. NumPy has no idea how to cope with such objects.

foo = sess.run(x)

(where sess is a tf.Session) gives you back an object Python can handle, a NumPy array.
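A minimal sketch of the run-then-slice pattern, built from the question's own model. It assumes TensorFlow 2 and therefore goes through the tf.compat.v1 graph/session API; the batch of ones fed to x is made-up input. It also shows the variable initialization step mentioned above:

```python
import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Same model as in the question; these are symbolic graph nodes, not arrays.
    x = tf.compat.v1.placeholder(tf.float32, [None, 3])
    W = tf.Variable(tf.zeros([3, 6]))
    b = tf.Variable(tf.zeros([6]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)  # variables must be initialized before y can be evaluated
    # sess.run evaluates the node and returns a plain NumPy array.
    y_val = sess.run(y, feed_dict={x: np.ones((2, 3), dtype=np.float32)})

# Ordinary NumPy fancy indexing works on the result.
even_cols = y_val[:, [0, 2, 4]]
print(even_cols)
```

With all-zero weights every softmax output is 1/6. Note that slicing outside the graph like this gives you values, not something TensorFlow can differentiate, so it does not help with writing a trainable loss.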

You typically want to keep your loss calculation in pure TensorFlow, so do the cross product and the other functions in tf as well. You'll probably have to compute the arccos the long way, as tf doesn't have a function for it.
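A sketch of the question's angle loss kept entirely in TensorFlow. It assumes TensorFlow 2, where tf.math.acos and tf.linalg.cross are available, so the "long way" arccos is no longer needed. Two things are additions not in the original code: the cross products are L2-normalized before taking the dot product (arccos of a dot product only gives the angle for unit vectors), and the cosine is clipped to [-1, 1] so rounding error cannot produce NaN. The function names are hypothetical, and the sum is returned without the question's leading minus sign:

```python
import math

import tensorflow as tf

def norm_vec(y):
    # Columns 0, 2, 4 and 1, 3, 5 form two 3-vectors per row, as in the question's normVec.
    a = tf.gather(y, [0, 2, 4], axis=1)
    b = tf.gather(y, [1, 3, 5], axis=1)
    return tf.linalg.cross(a, b)

def angle(v1, v2):
    # Normalize each row so the row-wise dot product equals cos(theta).
    v1 = tf.math.l2_normalize(v1, axis=1)
    v2 = tf.math.l2_normalize(v2, axis=1)
    cos = tf.reduce_sum(v1 * v2, axis=1)
    # Keep the value inside acos's domain despite floating-point rounding.
    cos = tf.clip_by_value(cos, -1.0, 1.0)
    return tf.math.acos(cos)

def angle_distance(y_true, y_pred):
    # Sum of per-row angles between the true and predicted normal vectors.
    return tf.reduce_sum(angle(norm_vec(y_true), norm_vec(y_pred)))

# Made-up example: normals (0, 0, 1) and (0, -1, 0) are perpendicular.
y_true = tf.constant([[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]])
y_pred = tf.constant([[1.0, 0.0, 0.0, 0.0, 0.0, 1.0]])
loss = angle_distance(y_true, y_pred)
print(float(loss), math.pi / 2)
```

Because every step is a TensorFlow op, the result stays differentiable and can be handed to an optimizer just like the cross_entropy example.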

answered Oct 07 '22 at 17:10 by dga
