From 00d0886815fee047e89d3328df4f76033c085fea Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 6 Nov 2017 10:17:58 -0800 Subject: Basic plumbing for calling C API from import_graph_def() PiperOrigin-RevId: 174724070 --- tensorflow/python/framework/c_api_util.py | 31 +++++++++++++++++++++++--- tensorflow/python/framework/importer.py | 33 ++++++++++++++++++++-------- tensorflow/python/framework/importer_test.py | 23 +++++++++++++++++++ 3 files changed, 75 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/framework/c_api_util.py b/tensorflow/python/framework/c_api_util.py index ddababd5b8..1d0dd88dc5 100644 --- a/tensorflow/python/framework/c_api_util.py +++ b/tensorflow/python/framework/c_api_util.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.util import compat from tensorflow.python.util import tf_contextlib @@ -49,22 +50,46 @@ class ScopedTFGraph(object): c_api.TF_DeleteGraph(self.graph) +class ScopedTFImportGraphDefOptions(object): + """Wrapper around TF_ImportGraphDefOptions that handles deletion.""" + + def __init__(self): + self.options = c_api.TF_NewImportGraphDefOptions() + + def __del__(self): + # Note: when we're destructing the global context (i.e when the process is + # terminating) we can have already deleted other modules. + if c_api.TF_DeleteImportGraphDefOptions is not None: + c_api.TF_DeleteImportGraphDefOptions(self.options) + + @tf_contextlib.contextmanager -def tf_buffer(): +def tf_buffer(data=None): """Context manager that creates and deletes TF_Buffer. Example usage: - wtih tf_buffer() as buf: + with tf_buffer() as buf: # get serialized graph def into buf ... proto_data = c_api.TF_GetBuffer(buf) graph_def.ParseFromString(compat.as_bytes(proto_data)) # buf has been deleted + with tf_buffer(some_string) as buf: + c_api.TF_SomeFunction(buf) + # buf has been deleted + + Args: + data: An optional `bytes`, `str`, or `unicode` object. If not None, the + yielded buffer will contain this data. + Yields: Created TF_Buffer """ - buf = c_api.TF_NewBuffer() + if data: + buf = c_api.TF_NewBufferFromString(compat.as_bytes(data)) + else: + buf = c_api.TF_NewBuffer() try: yield buf finally: diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index c6b335e661..e4b94e1a34 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -25,8 +25,11 @@ import copy from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import types_pb2 +from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops @@ -242,12 +245,6 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, input_map = _ProcessInputMapParam(input_map) return_elements = _ProcessReturnElementsParam(return_elements) - # Use a canonical representation for all tensor names. - input_map = {_CanonicalInputName(k): v for k, v in input_map.items()} - used_input_keys = set() - - name_to_op = {} - op_dict = op_def_registry.get_registered_ops() if producer_op_list is None: @@ -255,10 +252,28 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, else: producer_op_dict = {op.name: op for op in producer_op_list.op} - g = ops.get_default_graph() - if g._c_graph: # pylint: disable=protected-access - assert 'import_graph_def not yet implemented with C API' + graph = ops.get_default_graph() + + if graph._c_graph: # pylint: disable=protected-access + scoped_options = c_api_util.ScopedTFImportGraphDefOptions() + + with errors.raise_exception_on_not_ok_status() as status: + with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: + c_api.TF_GraphImportGraphDefWithResults( + graph._c_graph, serialized, scoped_options.options, status) # pylint: disable=protected-access + + if return_elements is not None: + raise ValueError('return_elements not yet implemented with C API') + return None + else: + g = graph + + # Use a canonical representation for all tensor names. + input_map = {_CanonicalInputName(k): v for k, v in input_map.items()} + used_input_keys = set() + name_to_op = {} + # Add any functions defined in `graph_def` to `g` if graph_def.library and graph_def.library.function: # Copy op_dict so we don't clobber the original diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 24934f264d..d27ec1e30c 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import test_ops # pylint: disable=unused-import +from tensorflow.python.framework import test_util from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl @@ -55,6 +56,28 @@ class ImportGraphDefTest(test.TestCase): text_format.Merge(text, ret) return ret + # The C API doesn't currently support return elements (or anything else beyond + # the most basic import). This test only checks that the import can run + # without error, and will be removed once more functionality is implemented + # and we can get coverage from the other tests. + @test_util.enable_c_api + def testCApi(self): + importer.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'IntOutputFloatOutput' } + node { name: 'B' op: 'ListOutput' + attr { key: 'T' + value { list { type: DT_INT32 type: DT_FLOAT } } } } + node { name: 'C' op: 'ListInput' + attr { key: 'N' value { i: 2 } } + attr { key: 'T' value { type: DT_INT32 } } + input: 'A:0' input: 'B:0' } + node { name: 'D' op: 'ListInput' + attr { key: 'N' value { i: 2 } } + attr { key: 'T' value { type: DT_FLOAT } } + input: 'A:1' input: 'B:1' } + """)) + def testBasic(self): with ops.Graph().as_default(): a, b, c, d = importer.import_graph_def( -- cgit v1.2.3