aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-11-06 10:17:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-06 10:22:00 -0800
commit00d0886815fee047e89d3328df4f76033c085fea (patch)
tree4a516b88d66ff5718a4aa7339334ea2ef6e68cbd
parent61644e5dd762cc56ba18e7e9d1e4ce53ff0b9008 (diff)
Basic plumbing for calling C API from import_graph_def()
PiperOrigin-RevId: 174724070
-rw-r--r--tensorflow/python/framework/c_api_util.py31
-rw-r--r--tensorflow/python/framework/importer.py33
-rw-r--r--tensorflow/python/framework/importer_test.py23
3 files changed, 75 insertions, 12 deletions
diff --git a/tensorflow/python/framework/c_api_util.py b/tensorflow/python/framework/c_api_util.py
index ddababd5b8..1d0dd88dc5 100644
--- a/tensorflow/python/framework/c_api_util.py
+++ b/tensorflow/python/framework/c_api_util.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as c_api
+from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
@@ -49,22 +50,46 @@ class ScopedTFGraph(object):
c_api.TF_DeleteGraph(self.graph)
+class ScopedTFImportGraphDefOptions(object):
+ """Wrapper around TF_ImportGraphDefOptions that handles deletion."""
+
+ def __init__(self):
+ self.options = c_api.TF_NewImportGraphDefOptions()
+
+ def __del__(self):
+ # Note: when we're destructing the global context (i.e when the process is
+ # terminating) we can have already deleted other modules.
+ if c_api.TF_DeleteImportGraphDefOptions is not None:
+ c_api.TF_DeleteImportGraphDefOptions(self.options)
+
+
@tf_contextlib.contextmanager
-def tf_buffer():
+def tf_buffer(data=None):
"""Context manager that creates and deletes TF_Buffer.
Example usage:
- wtih tf_buffer() as buf:
+ with tf_buffer() as buf:
# get serialized graph def into buf
...
proto_data = c_api.TF_GetBuffer(buf)
graph_def.ParseFromString(compat.as_bytes(proto_data))
# buf has been deleted
+ with tf_buffer(some_string) as buf:
+ c_api.TF_SomeFunction(buf)
+ # buf has been deleted
+
+ Args:
+ data: An optional `bytes`, `str`, or `unicode` object. If not None, the
+ yielded buffer will contain this data.
+
Yields:
Created TF_Buffer
"""
- buf = c_api.TF_NewBuffer()
+ if data:
+ buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
+ else:
+ buf = c_api.TF_NewBuffer()
try:
yield buf
finally:
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index c6b335e661..e4b94e1a34 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -25,8 +25,11 @@ import copy
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
+from tensorflow.python import pywrap_tensorflow as c_api
+from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
@@ -242,12 +245,6 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
input_map = _ProcessInputMapParam(input_map)
return_elements = _ProcessReturnElementsParam(return_elements)
- # Use a canonical representation for all tensor names.
- input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
- used_input_keys = set()
-
- name_to_op = {}
-
op_dict = op_def_registry.get_registered_ops()
if producer_op_list is None:
@@ -255,10 +252,28 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
else:
producer_op_dict = {op.name: op for op in producer_op_list.op}
- g = ops.get_default_graph()
- if g._c_graph: # pylint: disable=protected-access
- assert 'import_graph_def not yet implemented with C API'
+ graph = ops.get_default_graph()
+
+ if graph._c_graph: # pylint: disable=protected-access
+ scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
+
+ with errors.raise_exception_on_not_ok_status() as status:
+ with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
+ c_api.TF_GraphImportGraphDefWithResults(
+ graph._c_graph, serialized, scoped_options.options, status) # pylint: disable=protected-access
+
+ if return_elements is not None:
+ raise ValueError('return_elements not yet implemented with C API')
+ return None
+
else:
+ g = graph
+
+ # Use a canonical representation for all tensor names.
+ input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
+ used_input_keys = set()
+ name_to_op = {}
+
# Add any functions defined in `graph_def` to `g`
if graph_def.library and graph_def.library.function:
# Copy op_dict so we don't clobber the original
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 24934f264d..d27ec1e30c 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
+from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
@@ -55,6 +56,28 @@ class ImportGraphDefTest(test.TestCase):
text_format.Merge(text, ret)
return ret
+ # The C API doesn't currently support return elements (or anything else beyond
+ # the most basic import). This test only checks that the import can run
+ # without error, and will be removed once more functionality is implemented
+ # and we can get coverage from the other tests.
+ @test_util.enable_c_api
+ def testCApi(self):
+ importer.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'IntOutputFloatOutput' }
+ node { name: 'B' op: 'ListOutput'
+ attr { key: 'T'
+ value { list { type: DT_INT32 type: DT_FLOAT } } } }
+ node { name: 'C' op: 'ListInput'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'ListInput'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_FLOAT } }
+ input: 'A:1' input: 'B:1' }
+ """))
+
def testBasic(self):
with ops.Graph().as_default():
a, b, c, d = importer.import_graph_def(