diff options
author | 2017-12-15 12:10:18 -0800 | |
---|---|---|
committer | 2017-12-15 12:14:32 -0800 | |
commit | f962b77042b6fb207d18d00fd9ef9aa838e14a3d (patch) | |
tree | 24585d41f0a41ee538ec408f968c1460a1f2b164 | |
parent | fbb5392a65ebaeca19f95cb13fca9166bb5ba3ce (diff) |
Capture tensors that do not trigger convert_to_tensor in defun
Returning a closed-over Tensor does not trigger a call to convert_to_tensor,
so we need to manually coerce such Tensors to graph tensors and capture them.
PiperOrigin-RevId: 179224063
-rw-r--r-- | tensorflow/compiler/xla/BUILD | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/local_computation_builder.i | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/xla_client_test.py | 6 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 23 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 19 |
6 files changed, 44 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index be0dd0bc82..cd69c69889 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -641,12 +641,6 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) -py_proto_library( - name = "xla_data_proto_py_pb2", - api_version = 2, - deps = [":xla_data_proto"], -) - # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. cc_header_only_library( name = "xla_headers_lib", diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 7734e55967..a6b8158671 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -11,7 +11,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":pywrap_xla", - "//tensorflow/compiler/xla:xla_data_proto_py_pb2", + "//tensorflow/compiler/xla:xla_data_proto_py", ], ) @@ -23,7 +23,6 @@ py_test( deps = [ ":xla_client", "//tensorflow/python:platform_test", - "//third_party/py/numpy", ], ) @@ -52,7 +51,6 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:lib", - "//tensorflow/stream_executor/host:host_platform", ], ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 678de3e762..ac8f3e4277 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -106,7 +106,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.proto.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/compiler/xla/python/numpy_bridge.h" #include "tensorflow/compiler/xla/python/local_computation_builder.h" diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index cf71212fdb..878cd83edc 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -23,10 +23,10 @@ import itertools import numpy as np from tensorflow.compiler.xla.python import xla_client -from tensorflow.python.platform import googletest +import unittest -class LocalComputationTest(googletest.TestCase): +class LocalComputationTest(unittest.TestCase): """Base class for running an XLA Computation through the local client.""" def _NewComputation(self, name=None): @@ -895,4 +895,4 @@ class EmbeddedComputationsTest(LocalComputationTest): if __name__ == "__main__": - googletest.main() + unittest.main() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 239216243a..b068d5e584 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -91,13 +91,24 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): is not enabled. A placeholder which will have the value of the tensor at runtime otherwise. """ + del as_ref # Unused. + if context.in_eager_mode(): return value - _ = as_ref + + default_graph = ops.get_default_graph() + if not default_graph.building_function: + return value + tensor_map = _scoped_captures.tensors if tensor_map is None: # Capturing is not enabled. return constant_op.constant(value.numpy()) + if type(value) == ops.Tensor and value.graph is default_graph: + # The tensor has already been converted and captured. The type check + # is intentional: we are checking that value is a Tensor and not an + # EagerTensor. + return value return capture_value(tensor_map, value, dtype, name) @@ -499,20 +510,26 @@ def _defun_internal(name, func, args, kwds): func_outputs = func(*func_inputs, **kwds) finally: variables = tape.pop_tape().watched_variables() + + # Returning a closed-over tensor as an output does not trigger a + # call to convert_to_tensor, so we manually capture all such tensors. + outputs_list = nest.flatten(func_outputs) + func_def_outputs = [ + _convert_to_graph_tensor(x) for x in outputs_list if x is not None + ] + ids = list(sorted(captures.keys())) if ids: extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids]) else: extra_inputs = [] extra_placeholders = [] - outputs_list = nest.flatten(func_outputs) output_shapes = tuple(x.shape for x in outputs_list if x is not None) flat_inputs = [x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)] all_inputs = flat_inputs + list(extra_placeholders) all_ignored_ops = frozenset(x.op for x in all_inputs) - func_def_outputs = [x for x in outputs_list if x is not None] fname = _inference_name(name) operations = tuple(x for x in tmp_graph.get_operations() if x not in all_ignored_ops) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index aee2a91a0e..7018027386 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -323,6 +323,25 @@ class FunctionTest(test.TestCase): self.assertEqual(1, int(outer())) + def testReturnCapturedEagerTensor(self): + t = constant_op.constant(1) + + @function.defun + def read(): + return t + + self.assertEqual(1, int(read())) + + def testReturnCapturedGraphTensor(self): + with context.graph_mode(), self.test_session(): + t = constant_op.constant(1) + + @function.defun + def read(): + return t + + self.assertEqual(1, int(self.evaluate(read()))) + def testSequenceInputs(self): clip_by_global_norm = function.defun(clip_ops.clip_by_global_norm) t_list = [constant_op.constant(1.0), constant_op.constant(2.0)] |