TensorFlow Estimator API: How to pass a parameter from the input function

I'm trying to add class weights as a hyperparameter for my model, but to calculate the weights I need to read the input data. That happens inside input_fn, which is then passed to estimator.fit(). The output of input_fn is only features and labels, which should have the same shape, num_examples * num_features. My question: is there any way to propagate data from input_fn to model_fn's hyperparameter map? Or, as an alternative, is there a wrapper for the input_fn dataset that can oversample the minority class or undersample the majority class along with batching? In that case I would not need to propagate any parameter at all.

Stanislav Levental asked Jan 21 '18 23:01


1 Answer

Both features and labels can be dictionaries of tensors (not just a single tensor). The tensors can have any shape you want, though it's common for them to be num_examples * ...
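
A minimal sketch of an input_fn that returns dictionaries for both features and labels. Plain Python lists stand in for tensors so it runs without TensorFlow; in a real input_fn these would be tensors (or a tf.data.Dataset yielding (features, labels) tuples). All feature and label names here are hypothetical.

```python
from typing import Dict, List, Tuple


def input_fn() -> Tuple[Dict[str, list], Dict[str, List[int]]]:
    """Return (features, labels), both as dicts keyed by name.

    Lists stand in for tensors; every entry has num_examples (= 4)
    as its leading dimension, but the trailing dimensions differ.
    """
    features = {
        "age": [23.0, 35.0, 47.0, 51.0],            # shape [4]
        "embedding": [[0.1, 0.2], [0.3, 0.4],
                      [0.5, 0.6], [0.7, 0.8]],      # shape [4, 2]
    }
    labels = {
        "churned": [0, 0, 0, 1],                    # shape [4]
        "segment": [2, 0, 1, 2],                    # a second label, shape [4]
    }
    return features, labels


features, labels = input_fn()
print({name: len(values) for name, values in features.items()})  # {'age': 4, 'embedding': 4}
```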

If you don't use any of the predefined (canned) estimators, the easiest way is to add another feature containing whatever you need to compute the weights, compute the weights inside model_fn, and then use them (multiply the per-example loss by them, or pass them as the weights parameter of the loss function).
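
A framework-agnostic sketch of the arithmetic model_fn would do, assuming binary labels and inverse-frequency weighting, n_examples / (n_classes * count_of_class). That weighting is one common choice, not something the answer prescribes, and the function names are hypothetical. In model_fn the same steps would run on tensors; in TF 1.x, loss functions such as tf.losses.sparse_softmax_cross_entropy accept a weights argument that performs the per-example scaling.

```python
import math
from collections import Counter
from typing import Dict, List


def class_weights(labels: List[int]) -> Dict[int, float]:
    """Inverse-frequency weights: n_examples / (n_classes * count_of_class)."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * cnt) for c, cnt in counts.items()}


def weighted_log_loss(labels: List[int], probs: List[float],
                      weights: Dict[int, float]) -> float:
    """Binary cross-entropy with each example's loss scaled by its class
    weight, then averaged over the number of examples."""
    total = 0.0
    for y, p in zip(labels, probs):
        loss = -(y * math.log(p) + (1 - y) * math.log(1 - p))
        total += weights[y] * loss
    return total / len(labels)


labels = [0, 0, 0, 1]        # hypothetical extra feature: 3:1 class imbalance
w = class_weights(labels)    # {0: 0.666..., 1: 2.0}
print(weighted_log_loss(labels, [0.2, 0.1, 0.3, 0.6], w))
```

The minority example now contributes three times as much to the loss as each majority example, which is what balances the gradient.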

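You also have access to the hyperparameters inside input_fn (the Estimator passes its params to input_fn when the function's signature has a params argument), so you can compute the weights there and add them as a separate column.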
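
A minimal sketch of computing the weights inside input_fn and adding them as an extra feature column, assuming the Estimator hands input_fn a params dict. The hyperparameter key class_weight_power, the column name weight, and the toy data are all hypothetical; plain lists stand in for tensors.

```python
from collections import Counter
from typing import Dict, List, Tuple


def input_fn(params: Dict[str, float]) -> Tuple[Dict[str, list], List[int]]:
    """Read the data, derive per-example class weights, add them as a column."""
    # Hypothetical data; in practice this is whatever input_fn reads.
    features = {"x": [0.2, 0.4, 0.1, 0.9]}
    labels = [0, 0, 0, 1]

    counts = Counter(labels)
    n, k = len(labels), len(counts)
    # Hypothetical hyperparameter: 0.0 -> every weight is 1.0,
    # 1.0 -> full inverse-frequency weighting.
    power = params["class_weight_power"]
    features["weight"] = [(n / (k * counts[y])) ** power for y in labels]
    return features, labels


features, labels = input_fn({"class_weight_power": 1.0})
print(features["weight"])
```

Because the weight is just another feature, model_fn (or a canned estimator's weight column setting) can pick it up by name without any extra plumbing between the two functions.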

If you use a canned estimator, check its documentation: most of them accept a weight column (weight_column in tf.estimator, weight_column_name in the older tf.contrib.learn estimators). In that case, just give it the key you used in the features dictionary for the weight values.

Alternatively, if all else fails, you can resample the data the way you want (for example, oversample the minority class) before you feed it to TensorFlow.
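
A minimal sketch of resampling before the data reaches TensorFlow, using random oversampling with replacement: examples of smaller classes are duplicated at random until every class matches the largest one. The function name and signature are hypothetical, and this is only one of several resampling strategies (undersampling the majority class is the mirror image).

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def oversample(examples: Sequence[T], labels: Sequence[int],
               seed: int = 0) -> Tuple[List[T], List[int]]:
    """Duplicate random examples of smaller classes until all classes are
    as large as the biggest one, then shuffle the result."""
    rng = random.Random(seed)
    by_class: Dict[int, List[T]] = defaultdict(list)
    for x, y in zip(examples, labels):
        by_class[y].append(x)

    target = max(len(xs) for xs in by_class.values())
    pairs: List[Tuple[T, int]] = []
    for y, xs in by_class.items():
        extra = [rng.choice(xs) for _ in range(target - len(xs))]
        pairs.extend((x, y) for x in xs + extra)

    rng.shuffle(pairs)
    return [x for x, _ in pairs], [y for _, y in pairs]


xs, ys = oversample(["a", "b", "c", "d"], [0, 0, 0, 1])
print(ys.count(0), ys.count(1))  # 3 3
```

The balanced lists can then be fed to input_fn as usual, so no weights need to be passed to the model at all.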

Sorin answered Sep 25 '22 01:09