aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-03-19 17:34:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-19 17:39:00 -0700
commit2714c07c93c2fd84480f816e0da44030a0a2bd45 (patch)
treea37ff0e0e805c37f6c77e92d44c72cc590bbb9de
parentb6b4ec642a632af9abaf3ca7a2b1348ab2e94bef (diff)
Make _USE_C_API = True and_USE_C_SHAPES = False work with import_graph_def.
Without this change, shapes wouldn't be correctly computed for operations created via import_graph_def. PiperOrigin-RevId: 189670312
-rw-r--r--tensorflow/python/client/session_test.py3
-rw-r--r--tensorflow/python/framework/importer.py34
-rw-r--r--tensorflow/python/framework/importer_test.py4
-rw-r--r--tensorflow/python/framework/meta_graph_test.py3
-rw-r--r--tensorflow/python/framework/ops.py29
-rw-r--r--tensorflow/python/training/saver_test.py3
6 files changed, 43 insertions, 33 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 44ff440cc5..6e2640efd1 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -62,8 +62,7 @@ from tensorflow.python.util import compat
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
class SessionTest(test_util.TensorFlowTestCase):
def setUp(self):
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 783e9259ad..a9e399f59b 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -489,23 +489,25 @@ def import_graph_def(graph_def,
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
- _ProcessNewOps(graph)
+ # Create _DefinedFunctions for any imported functions.
+ #
+ # We do this by creating _DefinedFunctions directly from `graph_def`, and
+ # adding them to `graph`. Adding an existing function to a TF_Graph is a
+ # no-op, so this only has the effect of updating the Python state (usually
+ # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
+ #
+ # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
+ # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
+ # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
+ # _USE_C_SHAPES is removed.
+ if graph_def.library and graph_def.library.function:
+ # pylint: disable=protected-access
+ functions = function._from_library(graph_def.library)
+ for f in functions:
+ f.add_to_graph(graph)
+ # pylint: enable=protected-access
- # Create _DefinedFunctions for any imported functions.
- #
- # We do this by creating _DefinedFunctions directly from `graph_def`, and
- # adding them to `graph`. Adding an existing function to a TF_Graph is a
- # no-op, so this only has the effect of updating the Python state (usually
- # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
- #
- # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
- # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
- if graph_def.library and graph_def.library.function:
- # pylint: disable=protected-access
- functions = function._from_library(graph_def.library)
- for f in functions:
- f.add_to_graph(graph)
- # pylint: enable=protected-access
+ _ProcessNewOps(graph)
# Treat input mappings that don't appear in the graph as an error, because
# they are likely to be due to a typo.
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index c39191e6d9..bf5d9fe093 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
+from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -43,8 +44,7 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
class ImportGraphDefTest(test.TestCase):
def _MakeGraphDef(self,
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 06cec504e4..21963d0bee 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -285,8 +285,7 @@ class SimpleMetaGraphTest(test.TestCase):
self.assertIs(global_vars[0], trainable_vars[0])
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
class ScopedMetaGraphTest(test.TestCase):
def _testScopedExport(self, test_dir, exported_filenames):
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f1cd341d66..4be2e2c15d 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3303,6 +3303,20 @@ class Graph(object):
input_types=input_types,
original_op=self._default_original_op,
op_def=op_def)
+
+ # TODO(vrv): Instead of eagerly filling in shape property for every op,
+ # only populate the shape when requested.
+ #
+ # TODO(skyewm): unlike in the original Python implementation, the C API
+ # always computes shape information (even for function calls, which the
+ # original Python shape inference code doesn't handle). Deprecate the
+ # compute_shapes argument.
+ #
+ # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
+ # is removed
+ if (ret._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
+ set_shapes_for_outputs(ret)
+
self._create_op_helper(ret, compute_shapes=compute_shapes,
compute_device=compute_device)
return ret
@@ -3336,15 +3350,6 @@ class Graph(object):
def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
"""Common logic for creating an op in this graph."""
- # TODO(vrv): Instead of eagerly filling in shape property for every op, only
- # populate the shape when requested.
- #
- # TODO(skyewm): unlike in the original Python implementation, the C API
- # always computes shape information (even for function calls, which the
- # original Python shape inference code doesn't handle). Deprecate the
- # compute_shapes argument.
- if (op._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
- set_shapes_for_outputs(op)
# TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
self._add_op(op)
@@ -3449,6 +3454,12 @@ class Graph(object):
]
for op in new_ops:
+ # The Python shape inference code does not support imported functions. It
+ # also needs access to op.inputs, which is why we call it here.
+ # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
+ # is removed.
+ if not self._is_function(op.type) or _USE_C_SHAPES:
+ set_shapes_for_outputs(op)
new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
# pylint: disable=protected-access
op._add_control_inputs(new_control_inputs)
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 787582ae70..7de778f298 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -1739,8 +1739,7 @@ class CheckpointStateTest(test.TestCase):
os.path.join(save_dir, "./model.ckpt-687529"))
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
class MetaGraphTest(test.TestCase):
def _get_test_dir(self, dirname):