aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2017-12-15 12:10:18 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 12:14:32 -0800
commitf962b77042b6fb207d18d00fd9ef9aa838e14a3d (patch)
tree24585d41f0a41ee538ec408f968c1460a1f2b164
parentfbb5392a65ebaeca19f95cb13fca9166bb5ba3ce (diff)
Capture tensors that do not trigger convert_to_tensor in defun
Returning a closed-over Tensor does not trigger a call to convert_to_tensor, so we need to manually coerce such Tensors to graph tensors and capture them. PiperOrigin-RevId: 179224063
-rw-r--r--tensorflow/compiler/xla/BUILD6
-rw-r--r--tensorflow/compiler/xla/python/BUILD4
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i2
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py6
-rw-r--r--tensorflow/python/eager/function.py23
-rw-r--r--tensorflow/python/eager/function_test.py19
6 files changed, 44 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index be0dd0bc82..cd69c69889 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -641,12 +641,6 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
-py_proto_library(
- name = "xla_data_proto_py_pb2",
- api_version = 2,
- deps = [":xla_data_proto"],
-)
-
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
cc_header_only_library(
name = "xla_headers_lib",
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 7734e55967..a6b8158671 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -11,7 +11,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
":pywrap_xla",
- "//tensorflow/compiler/xla:xla_data_proto_py_pb2",
+ "//tensorflow/compiler/xla:xla_data_proto_py",
],
)
@@ -23,7 +23,6 @@ py_test(
deps = [
":xla_client",
"//tensorflow/python:platform_test",
- "//third_party/py/numpy",
],
)
@@ -52,7 +51,6 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:lib",
- "//tensorflow/stream_executor/host:host_platform",
],
)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 678de3e762..ac8f3e4277 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -106,7 +106,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/xla_data.proto.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index cf71212fdb..878cd83edc 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -23,10 +23,10 @@ import itertools
import numpy as np
from tensorflow.compiler.xla.python import xla_client
-from tensorflow.python.platform import googletest
+import unittest
-class LocalComputationTest(googletest.TestCase):
+class LocalComputationTest(unittest.TestCase):
"""Base class for running an XLA Computation through the local client."""
def _NewComputation(self, name=None):
@@ -895,4 +895,4 @@ class EmbeddedComputationsTest(LocalComputationTest):
if __name__ == "__main__":
- googletest.main()
+ unittest.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 239216243a..b068d5e584 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -91,13 +91,24 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
is not enabled. A placeholder which will have the value of the
tensor at runtime otherwise.
"""
+ del as_ref # Unused.
+
if context.in_eager_mode():
return value
- _ = as_ref
+
+ default_graph = ops.get_default_graph()
+ if not default_graph.building_function:
+ return value
+
tensor_map = _scoped_captures.tensors
if tensor_map is None:
# Capturing is not enabled.
return constant_op.constant(value.numpy())
+ if type(value) == ops.Tensor and value.graph is default_graph:
+ # The tensor has already been converted and captured. The type check
+ # is intentional: we are checking that value is a Tensor and not an
+ # EagerTensor.
+ return value
return capture_value(tensor_map, value, dtype, name)
@@ -499,20 +510,26 @@ def _defun_internal(name, func, args, kwds):
func_outputs = func(*func_inputs, **kwds)
finally:
variables = tape.pop_tape().watched_variables()
+
+ # Returning a closed-over tensor as an output does not trigger a
+ # call to convert_to_tensor, so we manually capture all such tensors.
+ outputs_list = nest.flatten(func_outputs)
+ func_def_outputs = [
+ _convert_to_graph_tensor(x) for x in outputs_list if x is not None
+ ]
+
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
else:
extra_inputs = []
extra_placeholders = []
- outputs_list = nest.flatten(func_outputs)
output_shapes = tuple(x.shape for x in outputs_list if x is not None)
flat_inputs = [x for x in nest.flatten(func_inputs)
if isinstance(x, ops.Tensor)]
all_inputs = flat_inputs + list(extra_placeholders)
all_ignored_ops = frozenset(x.op for x in all_inputs)
- func_def_outputs = [x for x in outputs_list if x is not None]
fname = _inference_name(name)
operations = tuple(x for x in tmp_graph.get_operations()
if x not in all_ignored_ops)
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index aee2a91a0e..7018027386 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -323,6 +323,25 @@ class FunctionTest(test.TestCase):
self.assertEqual(1, int(outer()))
+ def testReturnCapturedEagerTensor(self):
+ t = constant_op.constant(1)
+
+ @function.defun
+ def read():
+ return t
+
+ self.assertEqual(1, int(read()))
+
+ def testReturnCapturedGraphTensor(self):
+ with context.graph_mode(), self.test_session():
+ t = constant_op.constant(1)
+
+ @function.defun
+ def read():
+ return t
+
+ self.assertEqual(1, int(self.evaluate(read())))
+
def testSequenceInputs(self):
clip_by_global_norm = function.defun(clip_ops.clip_by_global_norm)
t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]