
Error when parsing graph_def from string

I am trying to save a very simple TensorFlow graph as a .pb file, but I get this error when parsing it back:

Traceback (most recent call last):
  File "test_import_stripped_bm.py", line 28, in <module>
    graph_def.ParseFromString(fileContent)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/message.py", line 185, in ParseFromString
    self.MergeFromString(serialized)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1069, in MergeFromString
    if self._InternalParse(serialized, 0, length) != length:
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 633, in DecodeField
    if value._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 743, in DecodeMap
    if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1095, in InternalParse
    new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 850, in SkipField
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 799, in _SkipGroup
    new_pos = SkipField(buffer, pos, end, tag_bytes)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 850, in SkipField
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 814, in _SkipFixed32
    raise _DecodeError('Truncated message.')
google.protobuf.message.DecodeError: Truncated message.

This is the code that I use to write it to .pb:

import tensorflow as tf

builder = tf.saved_model.builder.SavedModelBuilder('models/TEST-3')

w1 = tf.Variable(tf.random_normal((2,2)), name="w1")
w2 = tf.Variable(tf.random_normal((2,2)), name="w2")

sess = tf.Session()
sess.run(tf.global_variables_initializer())

builder.add_meta_graph_and_variables(sess, tags=[tf.saved_model.tag_constants.SERVING], clear_devices = True)

builder.save()
sess.close()

And this is the code to parse it:

import tensorflow as tf
import os

model_path = os.path.join('models/TEST-3', 'saved_model.pb')
with open(model_path, mode='rb') as f:
    fileContent = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

To see the exact error I had to do

export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

before running it. I have also tried this on Python 2 and 3 with different TensorFlow versions, running on Ubuntu 16.04. On Python 2.7 with TensorFlow 0.9.0rc0 I got a slightly different error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py", line 185, in ParseFromString
    self.MergeFromString(serialized)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1091, in MergeFromString
    if self._InternalParse(serialized, 0, length) != length:
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 633, in DecodeField
    if value._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 489, in DecodeRepeatedField
    value.append(_ConvertToUnicode(buffer[pos:new_pos]))
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 469, in _ConvertToUnicode
    return local_unicode(byte_str, 'utf-8')
UnicodeDecodeError: 'utf8' codec can't decode byte 0x80 in position 18: 'utf8' codec can't decode byte 0x80 in position 18: invalid start byte in field: tensorflow.FunctionDef.Node.ret

I am able to parse other .pb graphs with this code, for example this one: https://github.com/taey16/tf/blob/master/imagenet/classify_image_graph_def.pb

Thanks in advance.

vntzy asked Oct 03 '17 14:10


2 Answers

The problem here is that you are trying to parse a SavedModel protocol buffer as if it were a GraphDef. Although a SavedModel contains a GraphDef, the two have different binary formats, so saved_model.pb cannot be parsed directly with tf.GraphDef. The following code, using tf.saved_model.loader.load(), should work:

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], "models/TEST-3") 
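
To make the "different binary formats" point concrete, here is a minimal pure-Python sketch (no TensorFlow needed) that looks at the first protobuf tag of a .pb file. It assumes the usual field layouts: a SavedModel starts with its schema version (field 1, varint), while a GraphDef with at least one node starts with a NodeDef (field 1, length-delimited). The helper names read_varint, first_field and guess_pb_kind are made up for this illustration, and the result is only a heuristic, not an official check.

```python
def read_varint(buf, pos):
    """Decode one base-128 varint at buf[pos]; return (value, next_pos)."""
    result = 0
    shift = 0
    while True:
        if pos >= len(buf):
            raise ValueError("truncated varint")
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:  # high bit clear: last byte of the varint
            return result, pos
        shift += 7


def first_field(buf):
    """Return (field_number, wire_type) of the first top-level field."""
    tag, _ = read_varint(buf, 0)
    return tag >> 3, tag & 0x07


def guess_pb_kind(buf):
    """Heuristic guess of the message type from its first tag (assumption-based)."""
    if not buf:
        return "unknown"
    field_number, wire_type = first_field(buf)
    if (field_number, wire_type) == (1, 0):
        # Field 1 as a varint: consistent with SavedModel's schema version.
        return "SavedModel"
    if (field_number, wire_type) == (1, 2):
        # Field 1 as a length-delimited message: consistent with GraphDef nodes.
        return "GraphDef"
    return "unknown"
```

Calling guess_pb_kind on the bytes read from models/TEST-3/saved_model.pb (opened in 'rb' mode) would be expected to report a SavedModel, which is why tf.GraphDef chokes on it. If you want the GraphDef without starting a session, the SavedModel message from tensorflow.core.protobuf.saved_model_pb2 can be parsed instead, and the graph is then available under meta_graphs[0].graph_def.
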

mrry answered Sep 27 '22 23:09


graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

Here fileContent must be a serialized GraphDef, such as a "frozen graph" (a GraphDef whose variables have been converted to constants). TensorFlow provides a tool for creating one as well; see the TensorFlow freeze_graph API.

Another way to create a frozen graph is:

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # Rebuild the graph from the .meta file, then restore the variable values.
    saver = tf.train.import_meta_graph(<.meta file>)
    saver.restore(sess, <checkpoint>)
    # Replace each variable with a constant holding its restored value.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        tf.get_default_graph().as_graph_def(),
        [<output node names>]  # a Python list of node name strings
    )
    # Serialize the frozen GraphDef to disk.
    with tf.gfile.GFile('frozen_graph.pb', "wb") as f:
        f.write(output_graph_def.SerializeToString())

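Then read frozen_graph.pb back and parse its contents. Note that ParseFromString expects the serialized bytes, not the file name:

with tf.gfile.GFile('frozen_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())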

So first use a Saver object to generate the .meta and checkpoint files. Once that is done, create the frozen graph.

vipin bansal answered Sep 27 '22 22:09