path: root/tensorflow/python/eager
author     Cao Zongyan <zongyan.cao@alibaba-inc.com>   2018-09-11 19:59:11 +0800
committer  Cao Zongyan <zongyan.cao@alibaba-inc.com>   2018-09-11 19:59:11 +0800
commit     9b3a93edf5a1f259bfe5230cc3b6c076573d4ec9 (patch)
tree       cbb0548282ba1584ed91a1be8f89b03ec882f287 /tensorflow/python/eager
parent     90cf7fb7786c8a9c135ef73482856b082e80f61a (diff)
parent     e18f84a394bcbde62b344a3b32e8d8fd248fea58 (diff)
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--    tensorflow/python/eager/BUILD                    39
-rw-r--r--    tensorflow/python/eager/backprop.py             157
-rw-r--r--    tensorflow/python/eager/backprop_test.py        101
-rw-r--r--    tensorflow/python/eager/benchmarks_test.py       55
-rw-r--r--    tensorflow/python/eager/context.py               41
-rw-r--r--    tensorflow/python/eager/core_test.py             24
-rw-r--r--    tensorflow/python/eager/execution_callbacks.py    8
-rw-r--r--    tensorflow/python/eager/function.py             875
-rw-r--r--    tensorflow/python/eager/function_test.py        290
-rw-r--r--    tensorflow/python/eager/graph_callable.py       435
-rw-r--r--    tensorflow/python/eager/graph_callable_test.py  249
-rw-r--r--    tensorflow/python/eager/graph_only_ops_test.py    4
-rw-r--r--    tensorflow/python/eager/imperative_grad.py       10
-rw-r--r--    tensorflow/python/eager/pywrap_tensor.cc         90
-rwxr-xr-x [-rw-r--r--]  tensorflow/python/eager/pywrap_tfe.h  29
-rw-r--r--    tensorflow/python/eager/pywrap_tfe_src.cc       375
-rw-r--r--    tensorflow/python/eager/tape.py                  26
-rw-r--r--    tensorflow/python/eager/tape_test.py              4
-rw-r--r--    tensorflow/python/eager/tensor_test.py           15
19 files changed, 1298 insertions, 1529 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index de93b1e2e1..85da1baaf0 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -47,7 +47,6 @@ py_library(
":core",
":execute",
":function",
- ":graph_callable",
":graph_only_ops",
":tape",
":test",
@@ -238,10 +237,11 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":graph_only_ops",
+ "//tensorflow/python:cond_v2_impl",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
+ "//tensorflow/python:gradients_impl",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
@@ -254,41 +254,6 @@ py_library(
)
py_library(
- name = "graph_callable",
- srcs = ["graph_callable.py"],
- srcs_version = "PY2AND3",
- visibility = ["//tensorflow:internal"],
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/eager:function",
- "//tensorflow/python/eager:tape",
- ],
-)
-
-py_test(
- name = "graph_callable_test",
- srcs = ["graph_callable_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":backprop",
- ":graph_callable",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:function",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/eager:test",
- ],
-)
-
-py_library(
name = "backprop",
srcs = ["backprop.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 553f761a14..be392c7a0f 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
@@ -180,10 +181,10 @@ def implicit_val_and_grad(f):
```
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar.
Returns:
A function which, when called, returns a tuple pair.
@@ -215,9 +216,7 @@ def implicit_val_and_grad(f):
"function was being computed.")
sources = [v.handle for v in variables]
- grad = imperative_grad.imperative_grad(_default_vspace,
- this_tape,
- nest.flatten(end_node),
+ grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
sources)
return end_node, list(zip(grad, variables))
@@ -255,10 +254,10 @@ def implicit_grad(f):
```
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar.
Returns:
A function which, when called, returns a list of (gradient, variable) pairs.
@@ -343,24 +342,24 @@ def gradients_function(f, params=None):
Note that only tensors with real or complex dtypes are differentiable.
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar. If desired, the tensors can be elementwise multiplied by the
- tensors passed as the `dy` keyword argument to the returned gradient
- function.
- params: list of parameter names of f or list of integers indexing the
- parameters with respect to which we'll differentiate. Passing None
- differentiates with respect to all parameters.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar. If desired, the tensors can be elementwise multiplied by the
+ tensors passed as the `dy` keyword argument to the returned gradient
+ function.
+ params: list of parameter names of f or list of integers indexing the
+ parameters with respect to which we'll differentiate. Passing None
+ differentiates with respect to all parameters.
Returns:
function which, when called, returns the value of f and the gradient
- of f with respect to all of `params`. The function takes an extra optional
- keyword argument "dy". Setting it allows computation of vector jacobian
+ of `f` with respect to all of `params`. The function takes an extra optional
+ keyword argument `dy`. Setting it allows computation of vector jacobian
products for vectors other than the vector of ones.
Raises:
- ValueError: if the params are not all strings or all integers.
+ ValueError: if the params are not all strings or all integers.
"""
def decorated(*args, **kwds):
@@ -440,23 +439,24 @@ def val_and_grad_function(f, params=None):
```
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar. If desired, the tensors can be elementwise multiplied by the
- tensors passed as the `dy` keyword argument to the returned gradient
- function.
- params: list of parameter names of f or list of integers indexing the
- parameters with respect to which we'll differentiate. Passing `None`
- differentiates with respect to all parameters.
-
- Returns: function which, when called, returns the value of f and the gradient
- of f with respect to all of `params`. The function takes an extra optional
- keyword argument "dy". Setting it allows computation of vector jacobian
- products for vectors other than the vector of ones.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar. If desired, the tensors can be elementwise multiplied by the
+ tensors passed as the `dy` keyword argument to the returned gradient
+ function.
+ params: list of parameter names of f or list of integers indexing the
+ parameters with respect to which we'll differentiate. Passing `None`
+ differentiates with respect to all parameters.
+
+ Returns:
+ function which, when called, returns the value of f and the gradient
+ of f with respect to all of `params`. The function takes an extra optional
+ keyword argument "dy". Setting it allows computation of vector jacobian
+ products for vectors other than the vector of ones.
Raises:
- ValueError: if the params are not all strings or all integers.
+ ValueError: if the params are not all strings or all integers.
"""
def decorated(*args, **kwds):
@@ -520,7 +520,7 @@ def make_vjp(f, params=None, persistent=True):
args = _ensure_unique_tensor_objects(parameter_positions, args)
for i in parameter_positions:
sources.append(args[i])
- tape.watch(args[i])
+ tape.watch(this_tape, args[i])
result = f(*args)
if result is None:
raise ValueError("Cannot differentiate a function that returns None; "
@@ -535,8 +535,8 @@ def make_vjp(f, params=None, persistent=True):
if dy is not None:
dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
return imperative_grad.imperative_grad(
- _default_vspace, this_tape, nest.flatten(result), sources,
- output_gradients=dy)
+ this_tape, nest.flatten(result), sources, output_gradients=dy)
+
return result, vjp
return decorated
@@ -557,7 +557,7 @@ def _aggregate_grads(gradients):
if len(gradients) == 1:
return gradients[0]
if all([isinstance(g, ops.Tensor) for g in gradients]):
- return math_ops.add_n(gradients)
+ return gen_math_ops.add_n(gradients)
else:
assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
for g in gradients])
@@ -592,7 +592,9 @@ def _num_elements(grad):
def _fast_fill(value, shape, dtype):
- return array_ops.fill(shape, constant_op.constant(value, dtype=dtype))
+ return array_ops.fill(
+ constant_op.constant(shape, dtype=dtypes.int32),
+ constant_op.constant(value, dtype=dtype))
def _zeros(shape, dtype):
@@ -627,9 +629,9 @@ def _ones(shape, dtype):
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
- tensor_id=ops.tensor_id,
zeros=_zeros,
ones=_ones)
+pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):
@@ -691,19 +693,57 @@ class GradientTape(object):
del g # Drop the reference to the tape
```
+ By default GradientTape will automatically watch any trainable variables that
+ are accessed inside the context. If you want fine grained control over which
+ variables are watched you can disable automatic tracking by passing
+ `watch_accessed_variables=False` to the tape constructor:
+
+ ```python
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(variable_a)
+ y = variable_a ** 2 # Gradients will be available for `variable_a`.
+ z = variable_b ** 3 # No gradients will be available since `variable_b` is
+ # not being watched.
+ ```
+
+ Note that when using models you should ensure that your variables exist when
+ using `watch_accessed_variables=False`. Otherwise it's quite easy to make your
+ first iteration not have any gradients:
+
+ ```python
+ a = tf.keras.layers.Dense(32)
+ b = tf.keras.layers.Dense(32)
+
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(a.variables) # Since `a.build` has not been called at this point
+ # `a.variables` will return an empty list and the
+ # tape will not be watching anything.
+ result = b(a(inputs))
+ tape.gradient(result, a.variables) # The result of this computation will be
+ # a list of `None`s since a's variables
+ # are not being watched.
+ ```
+
Note that only tensors with real or complex dtypes are differentiable.
"""
- def __init__(self, persistent=False):
+ def __init__(self, persistent=False, watch_accessed_variables=True):
"""Creates a new GradientTape.
Args:
persistent: Boolean controlling whether a persistent gradient tape
is created. False by default, which means at most one call can
be made to the gradient() method on this object.
+ watch_accessed_variables: Boolean controlling whether the tape will
+ automatically `watch` any (trainable) variables accessed while the tape
+ is active. Defaults to True meaning gradients can be requested from any
+ result computed in the tape derived from reading a trainable `Variable`.
+ If False users must explicitly `watch` any `Variable`s they want to
+ request gradients from.
"""
self._tape = None
self._persistent = persistent
+ self._watch_accessed_variables = watch_accessed_variables
self._recording = False
context.context().start_step()
@@ -717,15 +757,15 @@ class GradientTape(object):
if self._recording:
self._pop_tape()
- def _push_tape(self, existing_tape=False):
+ def _push_tape(self):
if self._recording:
raise ValueError("Tape is already recording.")
- if existing_tape:
- if self._tape is None:
- raise ValueError("There is no existing tape.")
- tape.push_tape(self._tape)
+ if self._tape is None:
+ self._tape = tape.push_new_tape(
+ persistent=self._persistent,
+ watch_accessed_variables=self._watch_accessed_variables)
else:
- self._tape = tape.push_new_tape(persistent=self._persistent)
+ tape.push_tape(self._tape)
self._recording = True
def _pop_tape(self):
@@ -744,7 +784,13 @@ class GradientTape(object):
tensor: a Tensor or list of Tensors.
"""
for t in nest.flatten(tensor):
- tape.watch(_handle_or_self(t))
+ if hasattr(t, "handle"):
+ # There are many variable-like objects, all of them currently have
+ # `handle` attribute that points to a tensor. If this changes, internals
+ # of watch_variable need to change as well.
+ tape.watch_variable(self._tape, t)
+ else:
+ tape.watch(self._tape, t)
@tf_contextlib.contextmanager
def stop_recording(self):
@@ -776,7 +822,7 @@ class GradientTape(object):
try:
yield
finally:
- self._push_tape(existing_tape=True)
+ self._push_tape()
def reset(self):
"""Clears all information stored in this tape.
@@ -810,6 +856,7 @@ class GradientTape(object):
```
"""
self._pop_tape()
+ self._tape = None
self._push_tape()
def watched_variables(self):
@@ -861,7 +908,9 @@ class GradientTape(object):
for x in nest.flatten(output_gradients)]
flat_grad = imperative_grad.imperative_grad(
- _default_vspace, self._tape, nest.flatten(target), flat_sources,
+ self._tape,
+ nest.flatten(target),
+ flat_sources,
output_gradients=output_gradients)
if not self._persistent:
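
The most user-visible part of the backprop.py change above is the new `watch_accessed_variables` argument on `GradientTape`. A minimal usage sketch (assuming a TensorFlow build that already includes this change, with eager execution enabled; the variable names are made up for illustration):

```python
import tensorflow as tf

tf.enable_eager_execution()

x1 = tf.Variable(1.0)
x2 = tf.Variable(1.0)

# Automatic watching is disabled, so only explicitly watched variables are traced.
with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(x2)
  y = x1 ** 2  # x1 is not watched: its gradient will be None
  z = x2 ** 3  # x2 is watched:     its gradient will be 3 * x2 ** 2

dy, dz = tape.gradient([y, z], [x1, x2])
print(dy)         # None
print(float(dz))  # 3.0
```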
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 3d3f54b9c4..f938ed5df8 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -23,7 +23,6 @@ import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
-from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -65,7 +64,7 @@ class BackpropTest(test.TestCase):
grad = backprop.gradients_function(fn, [0])(var)[0]
grad = self.evaluate(ops.convert_to_tensor(grad))
- with context.graph_mode(), self.test_session():
+ with context.graph_mode():
tf_var = array_ops.constant(var_np, dtypes.float32)
tf_ind1 = array_ops.constant([0, 1])
tf_ind2 = array_ops.constant([2, 3])
@@ -80,14 +79,13 @@ class BackpropTest(test.TestCase):
tf_dense_grad = math_ops.unsorted_segment_sum(
tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
- self.assertAllClose(grad, tf_dense_grad.eval())
+ self.assertAllClose(grad, self.evaluate(tf_dense_grad))
def testImplicitGradWithResourceVariable(self):
x = resource_variable_ops.ResourceVariable(
initial_value=constant_op.constant(1.0), name='x')
def fn():
- tape.watch_variable(x)
b = constant_op.constant(2.0)
c = math_ops.add(x.value(), b)
return math_ops.add(c, constant_op.constant(3.0))
@@ -194,14 +192,13 @@ class BackpropTest(test.TestCase):
initial_value=random_init, dtype=dtypes.float32, name='embedding')
def f():
- tape.watch_variable(embedding)
embedded_x = embedding_ops.embedding_lookup(embedding, x)
return constant_op.constant(1.0, dtypes.float32) - embedded_x
grad = backprop.implicit_grad(f)()[0][0]
opt = training.GradientDescentOptimizer(lrn_rate)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_x = array_ops.ones((batch_size), dtypes.int64)
# TODO(ashankar,apassos): Change to ResourceVariable.
tf_embedding = variables.Variable(
@@ -316,6 +313,24 @@ class BackpropTest(test.TestCase):
grad = backprop.gradients_function(second, [0])(f)[0]
self.assertAllEqual([[0.0]], grad)
+ @test_util.run_in_graph_and_eager_modes
+ def testWatchingIsTapeLocal(self):
+ x1 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
+ x2 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
+
+ with backprop.GradientTape() as tape1:
+ with backprop.GradientTape() as tape2:
+ tape1.watch(x1)
+ tape2.watch([x1, x2])
+ y = x1 ** 3
+ z = x2 ** 2
+ dy, dz = tape2.gradient([y, z], [x1, x2])
+ d2y, d2z = tape1.gradient([dy, dz], [x1, x2])
+
+ self.evaluate([x1.initializer, x2.initializer])
+ self.assertEqual(self.evaluate(d2y), 12.0)
+ self.assertIsNone(d2z)
+
@test_util.assert_no_new_tensors
def testMakeVJP(self):
@@ -404,7 +419,6 @@ class BackpropTest(test.TestCase):
def f():
with context.device('gpu:0'):
- tape.watch_variable(v)
return v.read_value()
self.assertEqual(
@@ -460,6 +474,18 @@ class BackpropTest(test.TestCase):
self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
@test_util.assert_no_new_tensors
+ def testGradientTapeReEnterContext(self):
+ g = backprop.GradientTape()
+ with g:
+ x = constant_op.constant(3.0)
+ g.watch(x)
+ y = 2*x
+ with g:
+ z = 2*y
+ grad = g.gradient(target=z, sources=[x])
+ self.assertEqual(self.evaluate(grad), [4.0])
+
+ @test_util.assert_no_new_tensors
@test_util.run_in_graph_and_eager_modes
def testGradientTapeRepeatedSource(self):
with backprop.GradientTape(persistent=False) as g:
@@ -784,7 +810,6 @@ class BackpropTest(test.TestCase):
initial_value=array_ops.constant([1.0]), name='x')
def fn():
- tape.watch_variable(x)
a = math_ops.add(x.value(), 1.0)
# Make sure convert_to_tensor works correctly with list of TensorNodes.
b = array_ops.stack([a, a], axis=0)
@@ -928,21 +953,75 @@ class BackpropTest(test.TestCase):
def testZerosCacheDoesntLeakAcrossGraphs(self):
with context.graph_mode():
def get_grad():
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
- with backprop.GradientTape() as gt:
+ with backprop.GradientTape() as tape:
tape.watch(x)
x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
y1 = x1**2
y = array_ops.concat([y1, t], axis=1)
- return self.evaluate(gt.gradient(y, x))
+ return self.evaluate(tape.gradient(y, x))
grad1 = get_grad()
grad2 = get_grad()
self.assertAllEqual(grad1, grad2)
+ @test_util.run_in_graph_and_eager_modes
+ def testSelectivelyWatchVariables(self):
+ x1 = resource_variable_ops.ResourceVariable(1.0)
+ x2 = resource_variable_ops.ResourceVariable(1.0)
+ with backprop.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(x2)
+ y = x1**2
+ z = x2**3
+ self.assertTupleEqual(tape.watched_variables(), (x2,))
+ dy, dz = tape.gradient([y, z], [x1, x2])
+ self.evaluate([x1.initializer, x2.initializer])
+ self.assertIsNone(dy)
+ self.assertEqual(self.evaluate(dz), 3.0)
+
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDifferentiatingScalarCache(self):
+ # In the following test, if x2 = x1 (i.e. the objects are the exact same),
+ # then y is essentially 2*x1, and dy/dx1 = 2.
+ # When we had a pure scalar cache in eager, this would be the case. This
+ # test prevents us from going back to that case.
+ with backprop.GradientTape(persistent=False) as g:
+ x1 = constant_op.constant(3.0)
+ x2 = constant_op.constant(3.0)
+ g.watch(x1)
+ g.watch(x2)
+ y = x1 + x2
+ grad = g.gradient(target=y, sources=[x1])
+ self.assertEqual(self.evaluate(grad), [1.0])
+
+ def testVariablesAndConstantsProduceTheSameGradients(self):
+
+ # In the following test, differentiating [y, z] against [a, b] gives:
+ # (dy/da + dz/da, dy/db + dz/db).
+ # If a and b are the same constant, dz/da will not be 0 (which it should
+ # be).
+ # This is solved by using variable since doing a read_value on a tensor will
+ # produce a new tensor and corresponding TensorHandle, and not reuse the
+ # same tensor (which would happen if we are using a cache and reusing
+ # EagerTensor objects).
+ def get_grads(a, b):
+ with backprop.GradientTape() as tape:
+ tape.watch([a, b])
+ y = a**3
+ z = b**2
+ return tape.gradient([y, z], [a, b])
+
+ gradients_constants = get_grads(
+ constant_op.constant(2.0), constant_op.constant(2.0))
+ gradients_variables = get_grads(
+ resource_variable_ops.ResourceVariable(2.0),
+ resource_variable_ops.ResourceVariable(2.0))
+ self.assertAllEqual(gradients_constants, gradients_variables)
+
if __name__ == '__main__':
test.main()
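
The `testWatchingIsTapeLocal` case above relies on watching being per-tape, which is also what makes nested tapes usable for higher-order gradients. A shorter, illustrative sketch of that nested-tape pattern (assumes eager execution; not taken verbatim from the test suite):

```python
import tensorflow as tf

tf.enable_eager_execution()

x = tf.constant(3.0)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = x ** 3
  dy_dx = inner.gradient(y, x)      # 3 * x**2 -> 27.0, recorded on the outer tape
d2y_dx2 = outer.gradient(dy_dx, x)  # 6 * x    -> 18.0
```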
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index e2b1890c2f..3fe79ef244 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -42,6 +42,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -175,6 +176,11 @@ class MicroBenchmarks(test.Benchmark):
self._run(func, 30000)
+ def benchmark_create_constant(self):
+ func = lambda: constant_op.constant(3.0)
+
+ self._run(func, 30000)
+
def benchmark_create_float_tensor_from_list_CPU(self):
self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU)
@@ -350,6 +356,21 @@ class MicroBenchmarks(test.Benchmark):
func = lambda: f(m, m, transpose_b)
self._run(func, num_iters, execution_mode=execution_mode)
+ def _benchmark_defun_matmul_forward_backward(self,
+ m,
+ transpose_b,
+ num_iters,
+ execution_mode=None):
+ f = function.defun(math_ops.matmul)
+
+ def func():
+ with backprop.GradientTape() as gt:
+ gt.watch(m)
+ y = f(m, m, transpose_b)
+ _ = gt.gradient(y, m)
+
+ self._run(func, num_iters, execution_mode=execution_mode)
+
def _benchmark_read_variable(self, m, num_iters):
self._run(m.value, num_iters)
@@ -421,6 +442,21 @@ class MicroBenchmarks(test.Benchmark):
num_iters=self._num_iters_2_by_2,
execution_mode=context.ASYNC)
+ def benchmark_defun_matmul_forward_backward_2_by_2_CPU(self):
+ with context.device(CPU):
+ m = self._m_2_by_2.cpu()
+ self._benchmark_defun_matmul_forward_backward(
+ m, transpose_b=False, num_iters=self._num_iters_2_by_2)
+
+ def benchmark_defun_matmul_forward_backward_2_by_2_CPU_async(self):
+ with context.device(CPU):
+ m = self._m_2_by_2.cpu()
+ self._benchmark_defun_matmul_forward_backward(
+ m,
+ transpose_b=False,
+ num_iters=self._num_iters_2_by_2,
+ execution_mode=context.ASYNC)
+
def benchmark_tf_matmul_2_by_2_GPU(self):
if not context.num_gpus():
return
@@ -682,6 +718,25 @@ class MicroBenchmarks(test.Benchmark):
assert np.equal(func(), make_keras_model()(data)).all()
self._run(func, 30000)
+ def benchmarkScan(self):
+ elems = math_ops.range(1600)
+
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
+ def benchmarkScanDefun(self):
+ elems = math_ops.range(1600)
+
+ @function.defun
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 6a327bd010..778ff85342 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -37,7 +37,7 @@ GRAPH_MODE = 0
EAGER_MODE = 1
# Default execution mode.
-_default_mode = GRAPH_MODE
+default_execution_mode = GRAPH_MODE
# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
@@ -56,14 +56,18 @@ SYNC = 0
ASYNC = 1
-class _TensorCache(object):
+class _EagerTensorCache(object):
"""Simple cache which evicts items based on length in a FIFO manner."""
- def __init__(self, max_items=256):
+ def __init__(self, max_items=256, max_tensor_size=10000):
self._data = collections.OrderedDict()
- self._max_items = max_items if max_items else 256
+ self._max_items = max_items
+ self._max_tensor_size = max_tensor_size
def put(self, key, value):
+ if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access
+ return
+
self._data[key] = value
if len(self._data) > self._max_items:
@@ -84,14 +88,14 @@ class _EagerContext(threading.local):
super(_EagerContext, self).__init__()
self.device_spec = pydev.DeviceSpec.from_string("")
self.device_name = self.device_spec.to_string()
- self.mode = _default_mode
- self.is_eager = _default_mode == EAGER_MODE
+ self.mode = default_execution_mode
+ self.is_eager = default_execution_mode == EAGER_MODE
self.scope_name = ""
self.recording_summaries = False
self.summary_writer_resource = None
self.scalar_cache = {}
- self.ones_rank_cache = _TensorCache()
- self.zeros_cache = _TensorCache()
+ self.ones_rank_cache = _EagerTensorCache()
+ self.zeros_cache = _EagerTensorCache()
self.execution_mode = None
@@ -111,8 +115,8 @@ class _ContextSwitchStack(threading.local):
# Initialize the stack with a pointer to enter the eager context; this
# ensures that the fact that eager execution was enabled is propagated
# across threads, since (1) `enable_eager_execution` modifies a
- # process-level flag (`_default_mode`) and (2) `__init__` is called each
- # time a threading.local object is used in a separate thread.
+ # process-level flag (`default_execution_mode`) and (2) `__init__` is
+ # called each time a threading.local object is used in a separate thread.
self.push(is_building_function=False, enter_context_fn=eager_mode)
def push(self, is_building_function, enter_context_fn):
@@ -504,9 +508,7 @@ class Context(object):
Args:
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
"""
- pywrap_tensorflow.TFE_ContextAddFunction(
- self._handle, # pylint: disable=protected-access
- fn)
+ pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)
def add_function_def(self, fdef):
"""Add a function definition to the context.
@@ -519,9 +521,7 @@ class Context(object):
"""
fdef_string = fdef.SerializeToString()
pywrap_tensorflow.TFE_ContextAddFunctionDef(
- self._handle, # pylint: disable=protected-access
- fdef_string,
- len(fdef_string))
+ self._handle, fdef_string, len(fdef_string))
def add_post_execution_callback(self, callback):
"""Add a post-execution callback to the context.
@@ -633,14 +633,7 @@ def context():
def context_safe():
- return _context
-
-
-# TODO(agarwal): remove this.
-def get_default_context():
- """Same as context."""
- if _context is None:
- _initialize_context()
+ """Returns current context (or None if one hasn't been initialized)."""
return _context
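
The renamed `_EagerTensorCache` in context.py also gains a size cap: `put` now skips any tensor with more than `max_tensor_size` elements. The sketch below is a plain-Python stand-in (a hypothetical `SizeCappedFifoCache`, not the TensorFlow class) that mirrors the skip-and-FIFO-evict policy shown in the diff:

```python
import collections

class SizeCappedFifoCache(object):
  """Plain-Python illustration of the cache policy (hypothetical class)."""

  def __init__(self, max_items=256, max_tensor_size=10000):
    self._data = collections.OrderedDict()
    self._max_items = max_items
    self._max_tensor_size = max_tensor_size

  def put(self, key, value, num_elements):
    # Oversized values are silently skipped, mirroring the new put() guard.
    if num_elements > self._max_tensor_size:
      return
    self._data[key] = value
    if len(self._data) > self._max_items:
      self._data.popitem(last=False)  # evict the oldest entry (FIFO)

  def get(self, key):
    return self._data.get(key, None)

cache = SizeCappedFifoCache(max_items=100, max_tensor_size=3)
cache.put('1', 'big', num_elements=4)    # skipped: 4 > 3
cache.put('2', 'small', num_elements=2)  # stored
assert cache.get('1') is None
assert cache.get('2') == 'small'
```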
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index cc765725a4..fb5442b646 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+import pickle
import threading
import numpy as np
@@ -185,6 +187,17 @@ class TFETest(test_util.TensorFlowTestCase):
device_count={'GPU': 0}))
self.assertEquals(0, ctx.num_gpus())
+ def testPickle(self):
+ tmp_dir = self.get_temp_dir()
+ fname = os.path.join(tmp_dir, 't.pickle')
+ with open(fname, 'wb') as f:
+ t = constant_op.constant(10.0)
+ pickle.dump(t, f)
+
+ with open(fname, 'rb') as f:
+ t = pickle.load(f)
+ self.assertAllEqual(t.numpy(), 10.0)
+
def testTensorPlacement(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
@@ -676,5 +689,16 @@ class SendRecvTest(test_util.TensorFlowTestCase):
2.0)
+class EagerTensorCacheTest(test_util.TensorFlowTestCase):
+
+ def testCacheSkipsTensorsTooLarge(self):
+ cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)
+ cache.put('1', array_ops.zeros((2, 2)))
+ self.assertEqual(cache.get('1'), None)
+
+ cache.put('2', array_ops.zeros((2)))
+ self.assertNotEqual(cache.get('2'), None)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py
index 9a08259653..80ff4459d6 100644
--- a/tensorflow/python/eager/execution_callbacks.py
+++ b/tensorflow/python/eager/execution_callbacks.py
@@ -146,7 +146,7 @@ def inf_nan_callback(op_type,
"""
del attrs, inputs # Not used.
- ctx = context.get_default_context()
+ ctx = context.context()
for index, output in enumerate(outputs):
if not output.dtype.is_numpy_compatible:
@@ -263,12 +263,12 @@ def add_execution_callback(callback):
Return value(s) from the callback are ignored.
"""
execute.execute = execute.execute_with_callbacks
- context.get_default_context().add_post_execution_callback(callback)
+ context.context().add_post_execution_callback(callback)
def clear_execution_callbacks():
"""Clear all execution callbacks from the default eager context."""
- context.get_default_context().clear_post_execution_callbacks()
+ context.context().clear_post_execution_callbacks()
def seterr(inf_or_nan=None):
@@ -309,7 +309,7 @@ def seterr(inf_or_nan=None):
"Valid actions are %s." % (inf_or_nan, _VALID_CALLBACK_ACTIONS))
old_settings = {"inf_or_nan": "ignore"}
- default_context = context.get_default_context()
+ default_context = context.context()
carryover_callbacks = []
for callback in default_context.post_execution_callbacks:
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 5afba466bc..03f12139f6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,12 +21,12 @@ from __future__ import print_function
import collections
import functools
+import sys
import threading
import numpy as np
import six
-from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -34,10 +34,12 @@ from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
@@ -49,8 +51,15 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
+# This is to avoid a circular dependency with cond_v2_impl
+# (function -> gradients_impl -> control_flow_ops -> cond_v2_impl).
+cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
-def create_substitute_placeholder(value, name, dtype=None):
+# This is to avoid a circular dependency with gradients_impl
+gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+
+
+def _create_substitute_placeholder(value, name, dtype=None):
"""Creates a placeholder for `value` and propagates shape info to it."""
# Note: setting ops.control_dependencies(None) ensures we always put
# capturing placeholders outside of any control flow context.
@@ -82,82 +91,131 @@ def create_substitute_placeholder(value, name, dtype=None):
return placeholder
-def capture_value(tensor_map, value, dtype, name):
- """Capture a value from outside the function, to pass in as an extra arg."""
- captured_value = tensor_map.get(value, None)
- if captured_value is None:
- captured_value = create_substitute_placeholder(value, name=name,
- dtype=dtype)
- tensor_map[value] = captured_value
- tape.record_operation("captured_value", [captured_value], [value],
- lambda x: [x])
- return captured_value
+def _get_device_functions(ctx, graph):
+ """Returns a tuple of device functions representing the device stack."""
+ if ctx.executing_eagerly():
+ return (pydev.merge_device(ctx.device_name),)
+ else:
+ return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access
-class CapturingGraph(ops.Graph):
- """Graph that can capture tensors from other graphs.
+class FuncGraph(ops.Graph):
+ """Graph representing a function body.
Attributes:
- captures: Maps external tensor -> internal tensor (e.g. input placeholder).
+ name: The name of the function.
+ inputs: Placeholder tensors representing the inputs to this function. The
+ tensors are in this FuncGraph. This represents "regular" inputs as well as
+ captured inputs (i.e. the values of self.captures), with the regular
+ inputs coming first.
+ outputs: Tensors that will be returned by this function. The tensors are in
+ this FuncGraph.
+ structured_outputs: A possibly-nested python object which will be returned
+ by this function. The Tensors in this structure are the same as those of
+ self.outputs. Note that this structure might contain Python `None`s.
+ variables: Variables that should be watched during function execution.
+ outer_graph: The graph this function is defined in. May be another FuncGraph
+ or the global default Graph.
+ captures: Maps external tensor -> internal tensor (i.e. input placeholder).
The entries are in the order they were captured.
+ seed: The graph-level random seed.
"""
- def __init__(self):
- super(CapturingGraph, self).__init__()
+ def __init__(self, name):
+ """Construct a new FuncGraph.
+
+ The graph will inherit its graph key, collections, seed, device stack, and
+ distribution strategy stack from the current context or graph.
+ Args:
+ name: the name of the function.
+ """
+ super(FuncGraph, self).__init__()
+
+ self.name = name
+ self.inputs = []
+ self.outputs = []
+ self.structured_outputs = None
+ self.variables = []
+ self.outer_graph = ops.get_default_graph()
self.captures = collections.OrderedDict()
- self._building_function = True
+ self._building_function = True
# Map from resource tensor name to last op (in program order) which uses
# this tensor. Used to enforce that execution order matches program order
# for resource tensors.
self._last_op_using_resource_tensor = {}
- # TODO(apassos) remove once the C API is used by default.
- def _use_c_api_hack(self):
- return True
-
- def clear_resource_control_flow_state(self):
- self._last_op_using_resource_tensor = {}
+ graph = self.outer_graph
- # TODO(skyewm): get rid of name and use the name of `tensor`.
- def capture(self, tensor, name=None):
- """Capture `tensor` if it's external to this graph.
-
- If `tensor` is from a different graph, returns a placeholder for it.
- `tensor` and the placeholder will also appears in self.captures. Multiple
- calls to this method with the same `tensor` argument will return the same
- placeholder. If `tensor` is from this graph, returns `tensor`.
-
- Args:
- tensor: Tensor. May be from this FuncGraph or a different graph.
- name: Optional name if a placeholder is created.
-
- Returns:
- Tensor from this FuncGraph.
- """
- if isinstance(tensor, ops.EagerTensor):
- if name is None:
- name = str(ops.uid())
- return capture_value(self.captures, tensor, tensor.dtype, name)
- if tensor.graph is not self:
- if name is None:
- name = tensor.op.name
- return capture_value(self.captures, tensor, tensor.dtype, name)
- return tensor
+ if context.executing_eagerly():
+ self.seed = context.global_seed()
+ self._xla_compile = (context.context().device_spec.device_type == "TPU")
+ self._add_device_to_stack(context.context().device_name)
+ else:
+ self.seed = graph.seed
+ self._xla_compile = getattr(graph, "_xla_compile", False)
+ self._device_function_stack = graph._device_function_stack.copy() # pylint: disable=protected-access
+ self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access
+
+ # TODO(b/112165328, b/112906995): summaries depend on inheriting collections
+ # from the default graph even in eager mode. It'd be nice to not have a
+ # default graph with eager execution, so hopefully this will go away when we
+ # remove collections.
+ # pylint: disable=protected-access
+ self._collections = graph._collections
+ # TODO(b/112906995): distribution strategy depends on inheriting this stack
+ # from the default graph even in eager mode. Maybe it should be part of the
+ # eager context?
+ self._distribution_strategy_stack = graph._distribution_strategy_stack
+ # Inherit the graph key, since this is used for matching variables in
+ # optimizers.
+ self._graph_key = graph._graph_key
+ # pylint: enable=protected-access
def create_op(
self,
op_type,
inputs,
- dtypes, # pylint: disable=redefined-outer-name
+ dtypes,
input_types=None,
name=None,
attrs=None,
op_def=None,
compute_shapes=True,
compute_device=True):
- """Captures an external inputs before calling Graph.capture_op."""
+ """Like Graph.create_op, except handles external input tensors.
+
+ This overload adds functionality to create_op to "capture" any external
+ input tensors, i.e. tensors from the eager context or outer function graphs
+ if this is a nested function. See `capture` for more information.
+
+ Args:
+ op_type: The `Operation` type to create. This corresponds to the
+ `OpDef.name` field for the proto that defines the operation.
+ inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
+ dtypes: A list of `DType` objects that will be the types of the tensors
+ that the operation produces.
+ input_types: (Optional.) A list of `DType`s that will be the types of
+ the tensors that the operation consumes. By default, uses the base
+ `DType` of each input in `inputs`. Operations that expect
+ reference-typed inputs must specify `input_types` explicitly.
+ name: (Optional.) A string name for the operation. If not specified, a
+ name is generated based on `op_type`.
+ attrs: (Optional.) A dictionary where the key is the attribute name (a
+ string) and the value is the respective `attr` attribute of the
+ `NodeDef` proto that will represent the operation (an `AttrValue`
+ proto).
+ op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
+ the operation will have.
+ compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+ computed).
+ compute_device: (Optional.) If True, device functions will be executed
+ to compute the device property of the Operation.
+
+ Returns:
+ An `Operation` object.
+ """
# This capturing logic interacts poorly with control flow contexts which
# want to replace inputs of ops far too late in the process. This can lead
# the context to get confused and try to create an Enter for an Enter. We
@@ -171,80 +229,61 @@ class CapturingGraph(ops.Graph):
# to capture the inputs.
ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
for i, inp in enumerate(inputs):
+ # TPU Estimator defines a control flow context with no AddValue method.
if ctxt is not None and hasattr(ctxt, "AddValue"):
inp = ctxt.AddValue(inp)
inp = self.capture(inp)
inputs[i] = inp
- return super(CapturingGraph, self).create_op(
+ return super(FuncGraph, self).create_op(
op_type, inputs, dtypes, input_types, name, attrs, op_def,
compute_device=compute_device)
+ def capture(self, tensor, name=None):
+ """Captures `tensor` if it's external to this graph.
-class FuncGraph(CapturingGraph):
- """Graph representing a function body.
-
- Attributes:
- name: The name of the function.
-
- inputs: Placeholder tensors representing the inputs to this function. The
- tensors are in this FuncGraph. This represents "regular" inputs as well as
- captured inputs (i.e. the values of self.captures), with the regular
- inputs coming first.
- outputs: Tensors that will be returned by this function. The tensors are in
- this FuncGraph.
- structured_outputs: A possibly-nested python object which will be returned
- by this function. The Tensors in this structure are the same as those of
- self.outputs. Note that this structure might contain Python `None`s.
- variables: Variables that should be watched during function execution.
- seed: The graph-level random seed.
- """
-
- def __init__(self, name, graph=None):
- """Construct a new FuncGraph.
+ If `tensor` is from a different graph, returns a placeholder for it.
+ `tensor` and the placeholder will appear in self.captures, and the
+ placeholder will appear in self.inputs. Multiple calls to this method with
+ the same `tensor` argument will return the same placeholder. If `tensor` is
+ from this graph, returns `tensor`.
Args:
- name: the name of the function.
- graph: if specified, this FuncGraph will inherit its graph key,
- collections, and seed from `graph`.
- """
- super(FuncGraph, self).__init__()
-
- self.name = name
- self.inputs = []
- self.outputs = []
- self.structured_outputs = None
- self.variables = []
-
- if graph is not None:
- # Inherit the graph key, since this is used for matching variables in
- # optimizers.
- self._graph_key = graph._graph_key # pylint: disable=protected-access
-
- # Copy the graph collections to ensure summaries and other things work.
- # This lets the function access (but not mutate) collections of the
- # containing graph, such as the global step and the summary writer
- # collections.
- for collection in graph.collections:
- self.get_collection_ref(collection)[:] = graph.get_collection(
- collection)
-
- # Copy distribution strategy scope from the containing graph as well.
- self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access
+ tensor: Tensor. May be from this FuncGraph or a different graph.
+ name: Optional name if a placeholder is created.
- if context.executing_eagerly():
- self.seed = context.global_seed()
- else:
- self.seed = graph.seed
+ Returns:
+ Tensor from this FuncGraph.
+ """
+ if isinstance(tensor, ops.EagerTensor):
+ if name is None:
+ name = str(ops.uid())
+ return self._capture_helper(tensor, name)
+ if tensor.graph is not self:
+ if name is None:
+ name = tensor.op.name
+ return self._capture_helper(tensor, name)
+ return tensor
- def capture(self, tensor, name=None):
- """Calls CapturingGraph.capture and updates self.inputs if necessary."""
- new_capture = tensor not in self.captures
- internal_tensor = super(FuncGraph, self).capture(tensor, name)
+ def _capture_helper(self, tensor, name):
+ captured_tensor = self.captures.get(tensor, None)
+ if captured_tensor is None:
+ captured_tensor = _create_substitute_placeholder(tensor, name=name,
+ dtype=tensor.dtype)
+ self.captures[tensor] = captured_tensor
+ self.inputs.append(captured_tensor)
+ tape.record_operation("captured_value", [captured_tensor], [tensor],
+ lambda x: [x])
+ return captured_tensor
- if new_capture and tensor is not internal_tensor:
- self.inputs.append(internal_tensor)
+ @property
+ def external_captures(self):
+ """External tensors captured by this function."""
+ return list(self.captures.keys())
- return internal_tensor
+ @property
+ def internal_captures(self):
+ """Placeholders in this function corresponding captured tensors."""
+ return list(self.captures.values())
def _forward_name(n):
@@ -267,9 +306,6 @@ def _register(fn):
context.context().add_function(fn)
-_xla_compile_attr = "_XlaCompile"
-
-
# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
@@ -282,18 +318,20 @@ class _EagerDefinedFunction(object):
class may be provided as the value of these `func` attributes.
"""
- def __init__(self, name, graph, operations, inputs, outputs, attrs):
+ def __init__(self, name, graph, inputs, outputs, attrs):
"""Initializes an eager defined function.
Args:
name: str, the name for the created function.
graph: Graph, the graph containing the operations in the function
- operations: list of Operation; the subset of operations in the graph
- which will be in the function
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
attrs: dict mapping names of attributes to their AttrValue values
"""
+ operations = [
+ op for op in graph.get_operations()
+ if op not in set(arg.op for arg in inputs)
+ ]
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
compat.as_str(name),
@@ -311,7 +349,6 @@ class _EagerDefinedFunction(object):
# It might be worth creating a convenient way to re-use status.
pywrap_tensorflow.TF_FunctionSetAttrValueProto(
fn, compat.as_str(name), serialized)
- self._xla_compile = _xla_compile_attr in attrs
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it.
@@ -327,6 +364,7 @@ class _EagerDefinedFunction(object):
self.signature = function_def.signature
self._num_outputs = len(self.signature.output_arg)
self._output_types = [o.type for o in self.signature.output_arg]
+ self._output_shapes = [o.shape for o in outputs]
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -347,7 +385,7 @@ class _EagerDefinedFunction(object):
def stateful_ops(self):
return self._stateful_ops
- def call(self, ctx, args, output_shapes):
+ def call(self, ctx, args):
"""Calls this function with `args` as inputs.
Function execution respects device annotations only if the function won't
@@ -356,8 +394,6 @@ class _EagerDefinedFunction(object):
Args:
ctx: a Context object
args: a list of arguments to supply this function with.
- output_shapes: shapes to which outputs should be set; ignored when
- executing eagerly.
Returns:
The outputs of the function call.
@@ -365,10 +401,7 @@ class _EagerDefinedFunction(object):
executing_eagerly = ctx.executing_eagerly()
- xla_compile = self._xla_compile or (executing_eagerly and
- ctx.device_spec.device_type == "TPU")
-
- if xla_compile:
+ if self._graph._xla_compile: # pylint: disable=protected-access
# XLA compilation relies upon a custom kernel creator to run functions.
signature = self.signature
if executing_eagerly:
@@ -406,7 +439,7 @@ class _EagerDefinedFunction(object):
if executing_eagerly:
return outputs
else:
- for i, shape in enumerate(output_shapes):
+ for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
return outputs
@@ -427,179 +460,117 @@ def _flatten(sequence):
return outputs
-# TODO(akshayka): Perhaps rename to something more appropriate.
-class GraphModeFunction(object):
+class Function(object):
"""Callable object encapsulating a function definition and its gradient.
- `GraphModeFunction` is a callable that encapsulates a function definition and
+ `Function` is a callable that encapsulates a function definition and
is differentiable under `tf.GradientTape` objects.
"""
- def __init__(self,
- name,
- input_placeholders,
- extra_inputs,
- graph,
- operations,
- outputs,
- python_func_outputs,
- output_shapes,
- variables=None,
- attrs=None):
- """Initialize a GraphModeFunction.
+ def __init__(self, func_graph, attrs=None):
+ """Initialize a Function.
Args:
- name: str the name of the created function
- input_placeholders: list of placeholder values (tensors) to feed when
- calling the wrapped function.
- extra_inputs: Tensor inputs this function definition closed over which
- are passed as arguments. Need to track so gradients are supported
- correctly.
- graph: the Graph from which the operations will be pulled. Used as
- a context when computing gradients.
- operations: the subset of Operations in the graph used in the function
- definition.
- outputs: a flat list of the Tensors in the graph used as outputs to the
- function
- python_func_outputs: a possibly nested python object which will be
- returned by this function. The Tensors in this structure will be
- replaced by their corresponding values in outputs. Note that this
- structure might contain Python `None`s.
- output_shapes: List of shapes of all tensors in outputs
- variables: (optional) List of variables to watch during function
- execution.
+ func_graph: An instance of FuncGraph: the function body to wrap.
attrs: (optional) dict mapping names of attributes to their AttrValue
values. Attributes in `attrs` will be included in this function's
definition.
+
+ Raises:
+ ValueError: If number of input_placeholders is not equal to the number
+ of function inputs.
"""
+ self._func_graph = func_graph
+ self._captured_inputs = list(self._func_graph.captures.keys())
+ self._num_outputs = len(self._func_graph.outputs)
+ self._output_shapes = tuple(
+ output.shape for output in self._func_graph.outputs)
self._attrs = attrs or {}
- defined_function = _EagerDefinedFunction(
- name, graph, operations, input_placeholders, outputs, self._attrs)
- if len(input_placeholders) != len(defined_function.signature.input_arg):
- raise ValueError("Internal error: invalid lengths. %s %s" % (
- len(input_placeholders), len(defined_function.signature.input_arg)))
- self._input_placeholders = input_placeholders
- self._extra_inputs = list(extra_inputs)
- self._graph = graph
- self._backward_function = None
- self._func_name = name
- self._function_def = defined_function
- self._num_outputs = len(defined_function.signature.output_arg)
- self._python_func_outputs = python_func_outputs
- self._python_returns = [python_func_outputs] if isinstance(
- python_func_outputs,
- (ops.Tensor, type(None))) else _flatten(python_func_outputs)
- self._output_shapes = output_shapes
- self._variables = variables if variables is not None else []
-
- # Find the variables that are components of something distributed and
- # put them into a {handle_tensor -> distributed variable object} map.
+ self._device_functions = tuple(
+ self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+
+ self._inference_function = _EagerDefinedFunction(
+ _inference_name(self._func_graph.name), self._func_graph,
+ self._func_graph.inputs, self._func_graph.outputs, self._attrs)
+ self._backward_graph_function = None
+
+ # Map holding distributed variables, keyed by resource handle tensors.
self._distributed_variables = {}
strategy = distribution_strategy_context.get_distribution_strategy()
- for variable in self._variables:
+ for variable in self._func_graph.variables:
# If variable is not distributed, unwrap returns [variable].
component_variables = strategy.unwrap(variable)
- # Only add to the dictionary when the variable is actually distributed,
- # i.e. more than one component or the component is different from the
- # variable itself. component_variables cannot be empty.
+ # Only update the dictionary when the variable is actually distributed.
if (len(component_variables) > 1 or component_variables[0] != variable):
for component_variable in component_variables:
self._distributed_variables[component_variable.handle] = variable
- @property
- def variables(self):
- return self._variables
+ def __call__(self, *args):
+ """Executes the wrapped function."""
+ ctx = context.context()
+ device_functions = _get_device_functions(ctx, ops.get_default_graph())
+ if device_functions != self._device_functions:
+ raise ValueError(
+ "The current device stack does not match the device stack under "
+ "which the TensorFlow function '%s' was created.\n"
+ "Current device stack: %s\n%s device stack: %s" %
+ (self._inference_function.name, device_functions,
+ self._inference_function.name, self._device_functions))
+
+ for v in self._func_graph.variables:
+ if v.trainable:
+ tape.variable_accessed(v)
- def _construct_backprop_function(self):
- """Constructs the backprop function object for this function."""
- filtered_outputs = [x for x in self._python_returns if x is not None]
- # TODO(skyewm): use FuncGraph
- backwards_graph = CapturingGraph()
- backwards_graph._graph_key = self._graph._graph_key # pylint: disable=protected-access
- for collection in self._graph.collections:
- backwards_graph.get_collection_ref(
- collection)[:] = self._graph.get_collection(collection)
- backwards_graph.seed = self._graph.seed
- with backwards_graph.as_default():
- self._out_grad_placeholders = [
- graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
- in_gradients = gradients_impl._GradientsHelper( # pylint: disable=protected-access
- filtered_outputs,
- self._input_placeholders,
- grad_ys=self._out_grad_placeholders,
- src_graph=self._graph)
-
- backward_outputs = tuple(
- grad for grad in _flatten(in_gradients) if grad is not None)
- output_shapes = tuple(grad.shape for grad in backward_outputs)
-
- extra_inputs = backwards_graph.captures.keys()
- extra_placeholders = backwards_graph.captures.values()
-
- forward_name = _forward_name(self._func_name)
- # Note: we cannot have placeholder ops in the graph or the TPU compilation
- # pass fails.
- placeholder_ops = set([y.op for y in self._input_placeholders])
- function_ops = [x for x in self._graph.get_operations()
- if x not in placeholder_ops]
- self._forward_fdef = _EagerDefinedFunction(
- forward_name, self._graph, function_ops,
- self._input_placeholders, filtered_outputs + list(extra_inputs),
- self._attrs)
- all_inputs = self._out_grad_placeholders + list(extra_placeholders)
- # Excluding input ops from the body as we do not intend to execute these
- # operations when the function is executed.
- all_ignored_ops = frozenset(x.op for x in all_inputs)
- # Enforce a deterministic order of operations in the generated graph. This
- # means rerunning the function-defining code will always define the same
- # function, which is useful if we serialize this etc.
- function_def_ops = tuple(x
- for x in sorted(backwards_graph.get_operations(),
- key=lambda x: x.name)
- if x not in all_ignored_ops)
- bname = _backward_name(self._func_name)
- self._backward_function = GraphModeFunction(
- bname, all_inputs, [], backwards_graph, function_def_ops,
- backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
+ captures = self._resolve_captured_inputs()
+ tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
+ args = tensor_inputs + captures
- def _backprop_call(self, args):
- """Calls the wrapped function and records the result on a tape.
+ if tape.should_record(tensor_inputs) or tape.should_record(captures):
+ return self._backprop_call(args)
- (Only records results on a tape if the function has outputs)
+ outputs = self._inference_function.call(ctx, args)
+ return self._build_call_outputs(outputs)
- Args:
- args: All inputs to the function, including resolved extra inputs
- Returns:
- The call output.
- """
- ctx = context.context()
- outputs = self._forward_fdef.call(ctx, args, self._output_shapes)
- if isinstance(outputs, ops.Operation) or outputs is None:
- return outputs
+ @property
+ def graph(self):
+ """Returns the graph from which this function was constructed."""
+ return self._func_graph
- # `real_outputs` are the actual outputs of the inference graph function;
- # `side_outputs` are the intermediate Tensors that were added as outputs to
- # the forward graph function so that we can compute its gradient.
- real_outputs = outputs[:self._num_outputs]
- side_outputs = outputs[self._num_outputs:]
+ @property
+ def variables(self):
+ """Returns all variables touched by this function."""
+ return self._func_graph.variables
- def backward_function(*args):
- return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable
+ @property
+ def inputs(self):
+ """Returns tensors in `self.graph` corresponding to arguments."""
+ return self._func_graph.inputs
- tape.record_operation(
- self._forward_fdef.signature.name,
- real_outputs,
- args,
- backward_function)
+ @property
+ def outputs(self):
+ """Returns tensors in `self.graph` corresponding to return values."""
+ return self._func_graph.outputs
- return self._build_call_outputs(real_outputs)
+ @property
+ def captured_inputs(self):
+ """Returns external Tensors captured by this function.
+
+ self.__call__(*args) passes `args + self.captured_inputs` to the function.
+ """
+ return self._captured_inputs
+
+ @property
+ def function_def(self):
+ """Returns a `FunctionDef` object representing this function."""
+ return self._inference_function.definition
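Editor's note: these properties expose the traced graph to callers. A minimal sketch of how they can be inspected after this change, assuming TF 1.x with eager execution enabled and `tf.contrib.eager.defun` as the public alias of the `defun` defined below (the function and variable names are illustrative, not part of the patch):

```python
import tensorflow as tf
tf.enable_eager_execution()

@tf.contrib.eager.defun
def scale(x):
  return 2.0 * x

concrete = scale.get_concrete_function(tf.ones([2, 2]))
print(concrete.graph)            # the FuncGraph this Function was traced into
print(concrete.inputs)           # placeholders corresponding to the arguments
print(concrete.outputs)          # tensors corresponding to the return values
print(concrete.captured_inputs)  # external tensors passed in implicitly
print(concrete.function_def)     # the underlying FunctionDef proto
```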
@property
def output_shapes(self):
"""The function's output shapes."""
# TODO(ebrevdo): Should we only keep the output shapes associated
# with len(self._python_returns) outputs?
- outputs_list = nest.flatten(self._python_func_outputs)
+ # TODO(akshayka): Consider removing this.
+ outputs_list = nest.flatten(self._func_graph.structured_outputs)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
@@ -613,23 +584,80 @@ class GraphModeFunction(object):
else:
outputs_list[i] = self._output_shapes[j]
j += 1
- return nest.pack_sequence_as(self._python_func_outputs, outputs_list)
+ return nest.pack_sequence_as(self._func_graph.structured_outputs,
+ outputs_list)
@property
def output_dtypes(self):
- return nest.map_structure(
- lambda x: x.dtype if x is not None else None, self._python_func_outputs)
+ # TODO(akshayka): Consider removing this.
+ return nest.map_structure(lambda x: x.dtype if x is not None else None,
+ self._func_graph.structured_outputs)
- @property
- def captured_inputs(self):
- return self._extra_inputs
+ def _construct_backprop_function(self):
+ """Constructs the backprop function object for this function."""
+ backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
+ with backwards_graph.as_default():
+ gradients_wrt_outputs = [
+ graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
+ ]
+ gradients_wrt_inputs = gradients_impl._GradientsHelper( # pylint: disable=protected-access
+ self._func_graph.outputs,
+ self._func_graph.inputs,
+ grad_ys=gradients_wrt_outputs,
+ src_graph=self._func_graph)
+
+ self._forward_function = _EagerDefinedFunction(
+ _forward_name(
+ self._func_graph.name), self._func_graph, self._func_graph.inputs,
+ self._func_graph.outputs + list(backwards_graph.captures.keys()),
+ self._attrs)
- @property
- def name(self):
- """Returns the name of the function in Eager-compatible format."""
- return self._function_def.name.encode("utf-8")
+ # The ordering of `backwards_graph.inputs` is important: inputs of
+ # `self._backward_graph_function` correspond to outputs of
+ # `self._forward_function`.
+ backwards_graph.inputs = gradients_wrt_outputs + list(
+ backwards_graph.captures.values())
+ # Clear captures, since we pass them in as inputs.
+ backwards_graph.captures = {}
+ backwards_graph.outputs.extend(
+ grad for grad in _flatten(gradients_wrt_inputs) if grad is not None)
+ backwards_graph.structured_outputs = gradients_wrt_inputs
+ self._backward_graph_function = Function(
+ backwards_graph, attrs=self._attrs)
+
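Editor's note: for context, a minimal sketch of the path that reaches `_backprop_call`. Differentiating a defun-ed function under a `tf.GradientTape` makes `tape.should_record(...)` true in `__call__`, so the forward function runs and the lazily constructed backward function is recorded on the tape. Assumes TF 1.x eager execution; the example function is illustrative:

```python
import tensorflow as tf
tf.enable_eager_execution()

@tf.contrib.eager.defun
def square(x):
  return x * x

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)        # make the tape record operations involving `x`
  y = square(x)        # tape.should_record(...) is true -> _backprop_call
print(tape.gradient(y, x))   # 6.0
```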
+ def _backprop_call(self, args):
+ """Calls the forward function and records the result on a tape.
+
+    (Results are recorded on a tape only if the function has outputs.)
- def _resolve_extra_inputs(self):
+ Args:
+ args: All inputs to the function, including resolved captured inputs
+
+ Returns:
+ The call output.
+ """
+ if self._backward_graph_function is None:
+ self._construct_backprop_function()
+
+ ctx = context.context()
+ outputs = self._forward_function.call(ctx, args)
+ if isinstance(outputs, ops.Operation) or outputs is None:
+ return outputs
+
+ # `real_outputs` are the actual outputs of the inference graph function;
+ # `side_outputs` are the intermediate Tensors that were added as outputs to
+ # the forward graph function so that we can compute its gradient.
+ real_outputs = outputs[:self._num_outputs]
+ side_outputs = outputs[self._num_outputs:]
+
+ def backward_function(*args):
+ return self._backward_graph_function(*(list(args) + side_outputs)) # pylint: disable=not-callable
+
+ tape.record_operation(self._forward_function.signature.name, real_outputs,
+ args, backward_function)
+ return self._build_call_outputs(real_outputs)
+
+ def _resolve_captured_inputs(self):
"""Resolve captured distributed variables to their current values.
Some inputs can be distributed variables. Such variables yield a different
@@ -637,44 +665,23 @@ class GraphModeFunction(object):
execution.
Returns:
- a list of resolved extra input tensors.
+ a list of resolved captured input tensors.
"""
if self._distributed_variables:
- # Loop over each extra_inputs and check if it corresponds to something
+ # Loop over each captured input and check if it corresponds to something
# distributed. If so, get its _distributed_container and fetch the
# component appropriate for the current execution context.
- resolved_extra_inputs = self._extra_inputs[:]
- for i, extra_input in enumerate(self._extra_inputs):
- distributed_var = self._distributed_variables.get(extra_input, None)
+ resolved_captured_inputs = self._captured_inputs[:]
+ for i, captured_input in enumerate(self._captured_inputs):
+ distributed_var = self._distributed_variables.get(captured_input, None)
if distributed_var is not None:
# distributed variables override __getattr__ and substitute the
# right component variable. In here, `distributed_var.handle`
# actually does the equivalent of
# distributed_var.get_current_component_var().handle.
- resolved_extra_inputs[i] = distributed_var.handle
- return resolved_extra_inputs
-
- return self._extra_inputs
-
- def __call__(self, *args):
- """Executes the passed function in eager mode."""
- for v in self._variables:
- if v.trainable:
- tape.watch_variable(v)
-
- resolved_extra_inputs = self._resolve_extra_inputs()
-
- tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
- args = tensor_inputs + resolved_extra_inputs
- if tape.should_record(tensor_inputs) or tape.should_record(
- resolved_extra_inputs):
- if self._backward_function is None:
- self._construct_backprop_function()
- return self._backprop_call(args)
-
- ctx = context.context()
- outputs = self._function_def.call(ctx, args, self._output_shapes)
- return self._build_call_outputs(outputs)
+ resolved_captured_inputs[i] = distributed_var.handle
+ return resolved_captured_inputs
+ return self._captured_inputs
def _build_call_outputs(self, result):
"""Maps the fdef output list to actual output structure.
@@ -684,12 +691,12 @@ class GraphModeFunction(object):
Returns:
The actual call output.
"""
- if self._python_func_outputs is None:
+ if self._func_graph.structured_outputs is None:
return result
# Use `nest.flatten` instead of `_flatten` in order to preserve any
- # IndexedSlices in `self._python_func_outputs`.
- outputs_list = nest.flatten(self._python_func_outputs)
+ # IndexedSlices in `self._func_graph.structured_outputs`.
+ outputs_list = nest.flatten(self._func_graph.structured_outputs)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
@@ -703,13 +710,13 @@ class GraphModeFunction(object):
j += 3
else:
outputs_list[i] = ops.IndexedSlices(
- values=result[j],
- indices=result[j + 1])
+ values=result[j], indices=result[j + 1])
j += 2
else:
outputs_list[i] = result[j]
j += 1
- ret = nest.pack_sequence_as(self._python_func_outputs, outputs_list)
+ ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
+ outputs_list)
return ret
@@ -725,20 +732,18 @@ def _get_defun_inputs_from_signature(signature):
def _get_defun_inputs_from_args(args):
"""Maps python function args to graph-construction inputs."""
function_inputs = [
- graph_placeholder(arg.dtype, arg.shape) if isinstance(arg, ops.Tensor)
- else arg for arg in nest.flatten(args)
+ graph_placeholder(arg.dtype, arg.shape)
+ if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args)
]
return nest.pack_sequence_as(args, function_inputs)
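Editor's note: the helper above keeps the nested argument structure and only swaps `Tensor` leaves for placeholders. A small stand-alone illustration of the same flatten/replace/repack pattern, using `tf.placeholder` as a stand-in for the internal `graph_placeholder` (an assumption made so the sketch is runnable in plain graph mode):

```python
import tensorflow as tf
from tensorflow.python.util import nest

with tf.Graph().as_default():
  args = ({'a': tf.constant([[1.0, 2.0]])}, 3)
  flat = [
      tf.placeholder(x.dtype, x.shape) if isinstance(x, tf.Tensor) else x
      for x in nest.flatten(args)
  ]
  inputs = nest.pack_sequence_as(args, flat)
  # `inputs` has the same nesting as `args`: the Tensor leaf became a
  # placeholder, while the plain Python int passed through unchanged.
```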
-def _trace_and_define_function(name, python_func, compiled, args, kwds,
- signature=None):
- """Defines and returns graph-mode version of `python_func`.
+def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+ """Returns a `FuncGraph` generated from `python_func`.
Args:
name: an identifier for the function.
python_func: the Python function to trace.
- compiled: whether the graph function should be compiled through XLA.
args: the positional args with which the Python function should be called;
ignored if a signature is provided.
kwds: the keyword args with which the Python function should be called;
@@ -750,14 +755,13 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
inputs.
Returns:
- A GraphModeFunction.
+ A FuncGraph.
Raises:
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
- func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
-
+ func_graph = FuncGraph(name)
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
@@ -771,8 +775,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
func_graph.inputs.extend(
x for x in nest.flatten(func_args) + nest.flatten(func_kwds)
- if isinstance(x, ops.Tensor)
- )
+ if isinstance(x, ops.Tensor))
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
@@ -797,6 +800,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
this_tape = tape.push_new_tape()
try:
func_outputs = python_func(*func_args, **func_kwds)
+ # invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
def check_mutation(n1, n2):
@@ -816,53 +820,34 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
check_mutation(func_args_before, func_args)
check_mutation(func_kwds_before, func_kwds)
-
finally:
tape.pop_tape(this_tape)
+
func_graph.structured_outputs = func_outputs
+ # Returning a closed-over tensor does not trigger convert_to_tensor.
+ func_graph.outputs.extend(
+ func_graph.capture(x)
+ for x in _flatten(func_graph.structured_outputs)
+ if x is not None)
+
+ # Some captured variables might be components of DistributedValues.
+ # Instead of storing non-distributed component variables, we
+ # store their distributed containers so we can retrieve the correct
+ # component variables at call-time.
variables = list(this_tape.watched_variables())
-
- # Some variables captured by the tape can come from a DistributedValue.
- # At call time, DistributedValue can return another variable (e.g. if
- # the function is run on a different device). Thus, instead of storing
- # the specific captured variable, we replace it with its distributed
- # container.
strategy = distribution_strategy_context.get_distribution_strategy()
for i, variable in enumerate(variables):
# If variable is not distributed value_container returns itself.
variables[i] = strategy.value_container(variable)
-
func_graph.variables = variables
- # Returning a closed-over tensor as an output does not trigger a
- # call to convert_to_tensor, so we manually capture all such tensors.
- func_graph.outputs.extend(
- func_graph.capture(x) for x in _flatten(func_graph.structured_outputs)
- if x is not None
- )
-
- output_shapes = tuple(
- x.shape if isinstance(x, ops.Tensor) else None
- for x in func_graph.outputs)
-
- all_ignored_ops = frozenset(x.op for x in func_graph.inputs)
- operations = tuple(x for x in func_graph.get_operations()
- if x not in all_ignored_ops)
- # Register any other functions defined in the graph
- # TODO(ashankar): Oh lord, forgive me for this lint travesty.
+ # Register any other functions defined in the graph.
if context.executing_eagerly():
for f in func_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register(f._c_func.func) # pylint: disable=protected-access
- attrs = {}
- if compiled:
- attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
-
- return GraphModeFunction(
- func_graph.name, func_graph.inputs, func_graph.captures.keys(),
- func_graph, operations, func_graph.outputs, func_graph.structured_outputs,
- output_shapes, func_graph.variables, attrs)
+ return func_graph
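Editor's note: tracing and wrapping are now decoupled: `func_graph_from_py_func` returns a `FuncGraph`, and `Function` wraps it, which is exactly how `_maybe_define_function` builds cache entries below. A rough sketch of that composition using these private APIs directly (illustrative only; the exact constructor behavior may differ outside this commit):

```python
import tensorflow as tf
from tensorflow.python.eager import function
tf.enable_eager_execution()

def py_add(a, b):
  return a + b

t = tf.constant(1.0)
func_graph = function.func_graph_from_py_func('py_add', py_add, (t, t), {})
graph_fn = function.Function(func_graph)   # wraps the traced FuncGraph
print(graph_fn(t, t))                      # 2.0
```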
_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
@@ -911,13 +896,13 @@ def _deterministic_dict_values(dictionary):
return tuple(dictionary[key] for key in sorted(dictionary))
-class _PolymorphicFunction(object):
+class PolymorphicFunction(object):
"""Wrapper class for the graph functions defined for a Python function.
See the documentation for `defun` for more information on the semantics of
defined functions.
- _PolymorphicFunction class is thread-compatible meaning that minimal
+  PolymorphicFunction class is thread-compatible, meaning that minimal
usage of defuns (defining and calling) is thread-safe, but if users call other
methods or invoke the base `python_function` themselves, external
synchronization is necessary.
@@ -926,8 +911,7 @@ class _PolymorphicFunction(object):
def __init__(self,
python_function,
name,
- input_signature=None,
- compiled=False):
+ input_signature=None):
"""Initializes a polymorphic function.
Args:
@@ -936,14 +920,10 @@ class _PolymorphicFunction(object):
input_signature: a possibly nested sequence of `TensorSpec` objects
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
- compiled: if True, the framework will attempt to compile func with XLA.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
argspec has keyword arguments.
- TypeError: if `input_signature` contains anything other than
- `TensorSpec` objects, or (if not None) is anything other than a tuple or
- list.
"""
if isinstance(python_function, functools.partial):
@@ -955,8 +935,7 @@ class _PolymorphicFunction(object):
self._args_to_prepend = tuple()
self._kwds_to_include = {}
self._name = name
- self._compiled = compiled
- self._arguments_to_functions = {}
+ self._function_cache = collections.OrderedDict()
self._variables = []
self._lock = threading.Lock()
@@ -991,15 +970,40 @@ class _PolymorphicFunction(object):
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
- if any(not isinstance(arg, tensor_spec.TensorSpec)
- for arg in self._flat_input_signature):
- raise TypeError("Invalid input_signature %s; input_signature must be "
- "a possibly nested sequence of TensorSpec objects.")
+
+ def __call__(self, *args, **kwds):
+ """Calls a graph function specialized to the inputs."""
+ graph_function, inputs = self._maybe_define_function(*args, **kwds)
+ return graph_function(*inputs)
+
+ @property
+ def python_function(self):
+ """Returns the wrapped Python function."""
+ return self._python_function
+
+ # TODO(akshayka): Remove this property.
+ @property
+ def variables(self):
+ """Returns the union of all variables referenced by cached `Function`s`."""
+ return self._variables
+
+ def get_concrete_function(self, *args, **kwargs):
+ """Returns a `Function` object specialized to inputs and execution context.
+
+ `args` and `kwargs` are ignored if this `PolymorphicFunction` was created
+ with an `input_signature`.
+
+ Args:
+ *args: inputs to specialize on.
+ **kwargs: inputs to specialize on.
+ """
+ graph_function, _ = self._maybe_define_function(*args, **kwargs)
+ return graph_function
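Editor's note: `get_concrete_function` is the replacement for the removed `make_defun_op` API; the updated `testBasicGraphFunction` below follows the same pattern. A minimal sketch, assuming TF 1.x eager execution:

```python
import tensorflow as tf
tf.enable_eager_execution()

@tf.contrib.eager.defun
def sq(a):
  return tf.matmul(a, a)

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = sq.get_concrete_function(t)   # a Function specialized to `t`
print(sq_op.output_shapes)            # (2, 2)
print(sq_op(t))                       # same value as tf.matmul(t, t)
```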
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
del owner
- # `instance` here is the instance that this `_PolymorphicFunction` was
+ # `instance` here is the instance that this `PolymorphicFunction` was
# accessed through; e.g., for
#
# class Foo(object):
@@ -1009,29 +1013,42 @@ class _PolymorphicFunction(object):
# ...
#
# foo = Foo()
- # foo.bar() # `foo.bar` is a `_PolymorphicFunction` instance
+ # foo.bar() # `foo.bar` is a `PolymorphicFunction` instance
#
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
- def _cache_key(self, args, kwds):
- """Computes the cache key given inputs."""
+ def _cache_key(self, args, kwds, ctx, graph):
+ """Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwds) if kwds else args
cache_key = tuple(_encode_arg(arg) for arg in inputs)
else:
del args, kwds
cache_key = self._flat_input_signature
+
# The graph, or whether we're executing eagerly, should be a part of the
# cache key so we don't improperly capture tensors such as variables.
- return cache_key + (context.executing_eagerly() or ops.get_default_graph(),)
+ executing_eagerly = ctx.executing_eagerly()
+ execution_context = executing_eagerly or graph
+
+ # Putting the device in the cache key ensures that call-site device
+ # annotations are respected.
+ device_functions = _get_device_functions(ctx, graph)
+
+ # `ops.colocate_with` directives translate into `ops.device` directives when
+ # eager execution is enabled.
+ colocation_stack = (None if executing_eagerly else
+ tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+
+ return cache_key + (execution_context, device_functions, colocation_stack)
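Editor's note: because the device functions (and, in graph mode, the colocation stack) are folded into the key, the same Python function is retraced when the call-site device stack changes; the updated `testDeviceAnnotationsRespected` below checks exactly this. A small sketch, assuming eager execution and peeking at the private `_function_cache` the way the tests do:

```python
import tensorflow as tf
tf.enable_eager_execution()

defined = tf.contrib.eager.defun(lambda: tf.constant(0))

defined()                # traced once under the default (empty) device stack
with tf.device('/cpu:0'):
  defined()              # different device functions -> typically a second trace
print(len(defined._function_cache))   # 2 in the scenario above
```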
def _canonicalize_function_inputs(self, *args, **kwds):
"""Canonicalizes `args` and `kwds`.
Canonicalize the inputs to the Python function using its fullargspec. In
particular, we parse the varargs and kwargs that this
- `_PolymorphicFunction` was called with into a tuple corresponding to the
+ `PolymorphicFunction` was called with into a tuple corresponding to the
Python function's positional (named) arguments and a dictionary
corresponding to its kwargs.
@@ -1085,8 +1102,9 @@ class _PolymorphicFunction(object):
if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
raise ValueError("When input_signature is provided, all inputs to "
"the Python function must be Tensors.")
- tensor_specs = [tensor_spec.TensorSpec.from_tensor(tensor)
- for tensor in flat_inputs]
+ tensor_specs = [
+ tensor_spec.TensorSpec.from_tensor(tensor) for tensor in flat_inputs
+ ]
if any(not spec.is_compatible_with(other)
for spec, other in zip(self._flat_input_signature, tensor_specs)):
raise ValueError("Python inputs incompatible with input_signature: "
@@ -1111,42 +1129,33 @@ class _PolymorphicFunction(object):
"""
args, kwds = self._canonicalize_function_inputs(*args, **kwds)
- cache_key = self._cache_key(args, kwds)
+ cache_key = self._cache_key(args, kwds, context.context(),
+ ops.get_default_graph())
with self._lock:
try:
- graph_function = self._arguments_to_functions.get(cache_key, None)
+ graph_function = self._function_cache.get(cache_key, None)
except TypeError:
raise TypeError("Arguments supplied to `defun`-generated functions "
"must be hashable.")
if graph_function is None:
- graph_function = _trace_and_define_function(
- self._name, self._python_function, self._compiled, args, kwds,
- self._input_signature)
+ graph_function = Function(
+ func_graph_from_py_func(self._name, self._python_function, args,
+ kwds, self._input_signature))
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
- self._arguments_to_functions[cache_key] = graph_function
+ self._function_cache[cache_key] = graph_function
return graph_function, (args, kwds)
- def __call__(self, *args, **kwds):
- """Calls a graph function specialized for this input signature."""
- graph_function, inputs = self._maybe_define_function(*args, **kwds)
- return graph_function(*inputs)
-
- def call_python_function(self, *args, **kwargs):
- """Directly calls the wrapped python function."""
- return self._python_function(*args, **kwargs)
- @property
- def variables(self):
- """Returns a list of variables used in any of the defined functions."""
- return self._variables
+def _validate_signature(signature):
+ if any(not isinstance(arg, tensor_spec.TensorSpec)
+ for arg in nest.flatten(signature)):
+ raise TypeError("Invalid input_signature %s; input_signature must be "
+ "a possibly nested sequence of TensorSpec objects.")
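Editor's note: signature validation now happens once at decoration time rather than on the first call, which is why the updated tests below no longer invoke the decorated function. A minimal sketch, assuming the `tf.contrib.eager` aliases for `defun` and `TensorSpec`:

```python
import tensorflow as tf
tf.enable_eager_execution()

def foo(a):
  return a

signature = [tf.contrib.eager.TensorSpec(shape=[None], dtype=tf.float32)]
defined = tf.contrib.eager.defun(foo, input_signature=signature)
print(defined(tf.ones([2])))          # OK: input matches the declared signature

try:
  tf.contrib.eager.defun(foo, input_signature=[(2, 3)])  # not TensorSpec objects
except TypeError as e:
  print(e)                            # raised here, before any call
```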
-# TODO(akshayka): Remove the `compiled` flag and create a separate
-# API for xla compilation (`defun` is already complicated enough
-# as it is, and the keyword argument makes 'compiled' an overloaded concept)
-def defun(func=None, input_signature=None, compiled=False):
+def defun(func=None, input_signature=None):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
@@ -1221,6 +1230,7 @@ def defun(func=None, input_signature=None, compiled=False):
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.keep_probability = keep_probability
+ @tf.contrib.eager.defun
def call(self, inputs, training=True):
x = self.dense2(self.dense1(inputs))
if training:
@@ -1229,7 +1239,6 @@ def defun(func=None, input_signature=None, compiled=False):
return x
model = MyModel()
- model.call = tf.contrib.eager.defun(model.call)
model(x, training=True) # executes a graph, with dropout
model(x, training=False) # executes a graph, without dropout
@@ -1437,9 +1446,10 @@ def defun(func=None, input_signature=None, compiled=False):
func: function to be compiled. If `func` is None, returns a
decorator that can be invoked with a single argument - `func`. The
end result is equivalent to providing all the arguments up front.
- In other words, defun(compiled=True)(func) is equivalent to
- defun(func, compiled=True). The former allows the following use case:
- @tf.contrib.eager.defun(compiled=True)
+ In other words, defun(input_signature=...)(func) is equivalent to
+ defun(func, input_signature=...). The former allows
+ the following use case:
+ @tf.contrib.eager.defun(input_signature=...)
def foo(...):
...
@@ -1450,17 +1460,20 @@ def defun(func=None, input_signature=None, compiled=False):
signature is specified, every input to `func` must be a `Tensor`, and
`func` cannot accept `**kwargs`.
- compiled: If True, an attempt to compile `func` with XLA will be made.
- If it fails, function will be run normally. Experimental. Currently
- supported only for execution on TPUs. For the vast majority of users,
- this argument should be False.
-
Returns:
If `func` is not None, returns a callable that will execute the compiled
function (and return zero or more `tf.Tensor` objects).
If `func` is None, returns a decorator that, when invoked with a single
`func` argument, returns a callable equivalent to the case above.
+
+ Raises:
+ TypeError: If `input_signature` is neither `None` nor a sequence of
+ `tf.contrib.eager.TensorSpec` objects.
"""
+
+ if input_signature is not None:
+ _validate_signature(input_signature)
+
# TODO(apassos): deal with captured global state. Deal with control flow.
def decorated(function):
try:
@@ -1469,8 +1482,7 @@ def defun(func=None, input_signature=None, compiled=False):
name = "function"
return tf_decorator.make_decorator(
function,
- _PolymorphicFunction(
- function, name, input_signature=input_signature, compiled=compiled))
+ PolymorphicFunction(function, name, input_signature=input_signature))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1486,51 +1498,6 @@ def defun(func=None, input_signature=None, compiled=False):
return decorated
-def make_defun_op(func, *args, **kwds):
- """Compile func into graph_mode, assuming func arguments are *args, **kwargs.
-
- `make_defun_op` converts a function that constructs a TensorFlow graph into
- a function object and attaches it to the graph. The resulting function
- object can be queried for its properties, and called directly with different
- inputs to execute.
-
- More details on use cases and limitations are available in the
- documentation for `defun`.
-
- Example:
- ```python
- def f(x, y):
- return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
-
- def g(x, y):
- return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
-
- z = tf.constant([[0.0, 0.0]])
- g_op = make_defun_op(g, z, z)
-
- assert g_op.output_shapes == tf.TensorShape([])
- assert g_op.output_types == tf.float32
-
- x = tf.constant([[2.0, 3.0]])
- y = tf.constant([[3.0, -2.0]])
-
- # The plain function and defun-compiled function should return the same value.
- assert f(x, y).numpy() == g_op(x, y).numpy()
- ```
-
- Args:
- func: function to be compiled.
- *args: List arguments to pass to `func` when attaching to the graph.
- **kwds: Keyword arguments to pass to `func` when attaching to the graph.
-
- Returns:
- A wrapper object which can be queried for its output properties,
- and which can be called directly the way a `@defun` wrapped function
- can.
- """
- return _trace_and_define_function(func.__name__, func, False, args, kwds)
-
-
class AutomaticControlDependencies(object):
"""Context manager to automatically add control dependencies.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 380bcf763f..92254a2c00 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
-from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -105,7 +104,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(step(), 2.0)
def testGraphGradientVariable(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
@@ -130,16 +129,16 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
self.assertEqual(f().shape, ())
- def testBasicDefunOpGraphMode(self):
+ def testBasicGraphFunction(self):
matmul = function.defun(math_ops.matmul)
+ @function.defun
def sq(a):
return matmul(a, a)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
- sq_op = function.make_defun_op(sq, t)
-
+ sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
@@ -211,33 +210,44 @@ class FunctionTest(test.TestCase):
random_seed.set_random_seed(1)
self.assertAllEqual(f(), x)
- def testNestedInputsDefunOpGraphMode(self):
+ def testSymGradGatherNd(self):
+ with ops.Graph().as_default(), self.cached_session() as sess:
+
+ @function.defun
+ def f(x):
+ return array_ops.gather_nd(x, [[0]])
+
+ c = constant_op.constant([[2.]])
+ f_c = f(c)
+ g, = gradients_impl.gradients(f_c, c)
+ self.assertAllEqual(sess.run(g), [[1.0]])
+
+ def testNestedInputsGraphFunction(self):
matmul = function.defun(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
+ @function.defun
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
-
inputs = pair({'a': t}, {'b': t})
- sq_op = function.make_defun_op(a_times_b, inputs)
-
+ sq_op = a_times_b.get_concrete_function(inputs)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(inputs)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
- def testNestedOutputDefunOpGraphMode(self):
+ def testNestedOutputGraphFunction(self):
matmul = function.defun(math_ops.matmul)
+ @function.defun
def sq(a):
return (matmul(a, a), {'b': constant_op.constant(1.0)})
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
- sq_op = function.make_defun_op(sq, t)
-
+ sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes,
(tensor_shape.TensorShape([2, 2]),
{'b': tensor_shape.TensorShape([])}))
@@ -247,28 +257,28 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
self.assertAllEqual(b['b'].numpy(), 1.0)
- def testDefunOpGraphModeWithGradients(self):
+ def testGraphFunctionWithGradients(self):
v = resource_variable_ops.ResourceVariable(1.0, name='v')
+ @function.defun
def step():
def inner():
return v * v
return backprop.implicit_grad(inner)()[0][0]
- step_op = function.make_defun_op(step)
-
+ step_op = step.get_concrete_function()
self.assertEqual(step_op.output_dtypes, dtypes.float32)
self.assertEqual(step_op.output_shapes, tensor_shape.TensorShape([]))
self.assertAllEqual(step_op(), 2.0)
- def testDefunOpGraphModeNoneOutput(self):
+ def testGraphFunctionNoneOutput(self):
+ @function.defun
def fn(unused_a, unused_b):
return None
x = constant_op.constant(1)
- fn_op = function.make_defun_op(fn, x, x)
-
+ fn_op = fn.get_concrete_function(x, x)
self.assertEqual(fn_op.output_dtypes, None)
self.assertEqual(fn_op.output_shapes, None)
self.assertAllEqual(fn_op(x, x), None)
@@ -309,13 +319,13 @@ class FunctionTest(test.TestCase):
x = random_ops.random_uniform([2, 2]).numpy()
defined = function.defun(f)
defined(x)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
x = random_ops.random_uniform([2, 2]).numpy()
defined(x)
# A NumPy array with different values but the same shape and dtype
# shouldn't trigger another function definition.
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
@@ -346,6 +356,47 @@ class FunctionTest(test.TestCase):
self.assertEqual(3.0, float(test_assign_add()))
+ @test_util.run_in_graph_and_eager_modes
+ def testTensorInitializationInFunctionRaisesError(self):
+ error_msg = ('Tensor-typed variable initializers must either be '
+ 'wrapped in an init_scope or callable.*')
+
+ @function.defun
+ def tensor_init():
+ with self.assertRaisesRegexp(ValueError, error_msg):
+ resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
+
+ tensor_init()
+
+ @test_util.run_in_graph_and_eager_modes
+ def testCallableTensorInitializationInFunction(self):
+
+ @function.defun
+ def tensor_init():
+ v = resource_variable_ops.ResourceVariable(
+ lambda: constant_op.constant(2.0))
+ return v.read_value()
+
+ value = tensor_init()
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEqual(self.evaluate(value), 2.0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testInitScopeTensorInitializationInFunction(self):
+
+ @function.defun
+ def tensor_init():
+ with ops.init_scope():
+ const = constant_op.constant(2.0)
+ v = resource_variable_ops.ResourceVariable(const)
+ return v.read_value()
+
+ value = tensor_init()
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEqual(self.evaluate(value), 2.0)
+
def testDefunShapeInferenceWithCapturedResourceVariable(self):
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
@@ -430,7 +481,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
def testGraphModeCaptureVariable(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
class HasAVar(object):
@@ -458,12 +509,12 @@ class FunctionTest(test.TestCase):
x = constant_op.constant(1.0)
l = f(x, v)
_, dv = gradients_impl.gradients(l, [x, v])
- with self.test_session():
+ with self.cached_session():
v.initializer.run()
self.assertAllEqual(dv.eval(), 0.0)
def testGraphModeManyFunctions(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
@function.defun
def f(x):
@@ -564,7 +615,6 @@ class FunctionTest(test.TestCase):
@function.defun
def g(x):
- tape.watch_variable(x)
y = math_ops.add(x, three)
f(y)
@@ -578,7 +628,6 @@ class FunctionTest(test.TestCase):
return math_ops.add(x, three)
def g(x):
- tape.watch_variable(three)
return f(x)
g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0]
@@ -633,17 +682,19 @@ class FunctionTest(test.TestCase):
def testReturningIndexedSlicesWithDefun(self):
def validate(indexed_slice):
+ @function.defun
def f():
return indexed_slice
- output = function.defun(f)()
+ output = f()
self.assertTrue(isinstance(output, ops.IndexedSlices))
self.assertAllEqual(indexed_slice.values, output.values)
self.assertAllEqual(indexed_slice.indices, output.indices)
self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)
self.assertEqual(
- function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
+ f.get_concrete_function().output_shapes,
+ indexed_slice.values.shape)
arg = ops.IndexedSlices(
values=constant_op.constant([1, 2]),
@@ -883,7 +934,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(1, int(read()))
def testReturnCapturedGraphTensor(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
t = constant_op.constant(1)
@function.defun
@@ -966,39 +1017,109 @@ class FunctionTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
def testDeviceAnnotationsRespected(self):
- @function.defun
def multi_device_fn():
with ops.device('/cpu:0'):
- s1 = iterator_ops.Iterator.from_structure(
+ s0 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
with ops.device('/cpu:1'):
- s2 = iterator_ops.Iterator.from_structure(
+ s1 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
with ops.device('/cpu:2'):
- s3 = iterator_ops.Iterator.from_structure(
- (dtypes.float32,)).string_handle()
- with ops.device(''):
- # TODO(akshayka): This is unfortunate and brittle. It prevents
- # `Iterator.from_structure` from assigning the iterator op to 'cpu:0'.
- # Remove this hack once we have a way of obtaining metadata about
- # function execution.
- s4 = iterator_ops.Iterator.from_structure(
+ s2 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
- return s1, s2, s3, s4
+ s3 = iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+ return s0, s1, s2, s3
- with ops.device('/cpu:3'):
- outputs = self.evaluate(multi_device_fn())
+ defined = function.defun(multi_device_fn)
+ outputs = self.evaluate(defined())
+ self.assertEqual(len(defined._function_cache), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
- self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
- with ops.device('/cpu:0'):
- outputs = self.evaluate(multi_device_fn())
+ with ops.device('/cpu:3'):
+ outputs = self.evaluate(defined())
+ self.assertEqual(len(defined._function_cache), 2)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
- self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
+ self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
+
+ # This should retrieve the call-site-device agnostic function
+ defined()
+ self.assertEqual(len(defined._function_cache), 2)
+
+ # And this should retrieve the function created for '/cpu:3'
+ with ops.device('/cpu:3'):
+ defined()
+ self.assertEqual(len(defined._function_cache), 2)
+
+ @test_util.run_in_graph_and_eager_modes(
+ config=config_pb2.ConfigProto(device_count={'CPU': 2}))
+ def testCallingGraphFunctionOnIncompatibleDeviceRaisesError(self):
+
+ def func():
+ return constant_op.constant(0)
+
+ defined = function.defun(func)
+ with ops.device('cpu:0'):
+ cpu_graph_function = defined.get_concrete_function()
+
+ with ops.device('cpu:0'):
+ self.assertEqual(
+ self.evaluate(cpu_graph_function()), self.evaluate(func()))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'The current device stack does not match the device stack under '
+ 'which the TensorFlow function \'.*func.*\' was created.\n'
+ 'Current device stack: .*\n.*func.* device stack.*'):
+ with ops.device('cpu:1'):
+ cpu_graph_function()
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'The current device stack does not match the device stack under '
+ 'which the TensorFlow function \'.*func.*\' was created.\n'
+ 'Current device stack: .*\n.*func.* device stack.*'):
+ with ops.device(None):
+ cpu_graph_function()
+
+ default_graph_function = defined.get_concrete_function()
+ self.assertEqual(
+ self.evaluate(default_graph_function()), self.evaluate(func()))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'The current device stack does not match the device stack under '
+ 'which the TensorFlow function \'.*func.*\' was created.\n'
+ 'Current device stack: .*\n.*func.* device stack.*'):
+ with ops.device('cpu:1'):
+ default_graph_function()
+
+ @test_util.run_in_graph_and_eager_modes
+ def testColocateWithRespected(self):
+ # TODO(b/113291792): Use multiple CPUs instead of a GPU.
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found.')
+
+ with ops.device('cpu:0'):
+ x = constant_op.constant(1.0)
+
+ with ops.device('gpu:0'):
+ y = constant_op.constant(1.0)
+
+ @function.defun
+ def foo():
+ return iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+
+ with ops.colocate_with(x):
+ self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))
+
+ with ops.colocate_with(y):
+ self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
def testVariablesAreTracked(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -1027,26 +1148,31 @@ class FunctionTest(test.TestCase):
defined = function.defun(func)
defined(0, baz=20)
+
+ def cache_keys():
+ """Sanitizes cache keys of non-input metadata."""
+ return tuple(key[:3] for key in defined._function_cache)
+
# `True` corresponds to the fact that we're executing eagerly
- self.assertIn((0, 1, 20, True), defined._arguments_to_functions)
+ self.assertIn((0, 1, 20), cache_keys())
defined(1) # bar=1, baz=2
- self.assertIn((1, 1, 2, True), defined._arguments_to_functions)
+ self.assertIn((1, 1, 2), cache_keys())
# This matches the previous call.
defined(foo=1)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
defined(1, 2, 3)
- self.assertIn((1, 2, 3, True), defined._arguments_to_functions)
+ self.assertIn((1, 2, 3), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
- self.assertEqual(len(defined._arguments_to_functions), 3)
+ self.assertEqual(len(defined._function_cache), 3)
# This matches the previous call.
defined(1, baz=3, bar=2)
- self.assertEqual(len(defined._arguments_to_functions), 3)
+ self.assertEqual(len(defined._function_cache), 3)
def testFunctoolsPartialUnwrappedCorrectly(self):
@@ -1072,7 +1198,7 @@ class FunctionTest(test.TestCase):
defined = function.defun(foo, input_signature=signature)
a = array_ops.ones([2])
out = defined(a)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
self.assertAllEqual(out, a)
def bar(a):
@@ -1083,13 +1209,13 @@ class FunctionTest(test.TestCase):
defined = function.defun(bar, input_signature=signature)
a = array_ops.ones([2, 1])
out = defined(a)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
self.assertAllEqual(out, a)
# Changing the second dimension shouldn't create a new function.
b = array_ops.ones([2, 3])
out = defined(b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
self.assertAllEqual(out, b)
def testNestedInputSignatures(self):
@@ -1106,7 +1232,7 @@ class FunctionTest(test.TestCase):
a = array_ops.ones([2, 1])
b = array_ops.ones([1])
out = defined([a, a], b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
nest.assert_same_structure(out, [[a, a], b])
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], a)
@@ -1117,7 +1243,7 @@ class FunctionTest(test.TestCase):
b = array_ops.ones([2, 5])
c = array_ops.ones([1])
out = defined([a, b], c)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
nest.assert_same_structure(out, [[a, b], c])
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], b)
@@ -1153,13 +1279,13 @@ class FunctionTest(test.TestCase):
# Signatures must consist exclusively of `TensorSpec` objects.
signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
with self.assertRaisesRegexp(TypeError, 'Invalid input_signature.*'):
- function.defun(foo, input_signature=signature)(1, 2)
+ function.defun(foo, input_signature=signature)
# Signatures must be either lists or tuples on their outermost levels.
signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
with self.assertRaisesRegexp(TypeError, 'input_signature must be either a '
'tuple or a list.*'):
- function.defun(foo, input_signature=signature)(1, 2)
+ function.defun(foo, input_signature=signature)
def testInputsIncompatibleWithSignatureRaisesError(self):
@@ -1213,22 +1339,22 @@ class FunctionTest(test.TestCase):
integer = constant_op.constant(2, dtypes.int64)
out1, out2 = foo(flt, integer)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt=flt, integer=integer)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(integer=integer, flt=flt)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt, integer=integer)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
@@ -1258,27 +1384,27 @@ class FunctionTest(test.TestCase):
a = constant_op.constant(2.0)
b = constant_op.constant([1.0, 2.0])
one = defined(a, b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
two = defined(a=a, b=b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
three = defined(b=b, a=a)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
four = defined(a, b=b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
# The next call corresponds to a new input signature, hence
# we expect another function to be defined.
five = defined(b, a)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
six = defined(a=b, b=a)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
seven = defined(b=a, a=b)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
self.assertAllEqual(one, [1.0, 2.0])
self.assertAllEqual(two, [1.0, 2.0])
@@ -1298,14 +1424,14 @@ class FunctionTest(test.TestCase):
grad_t, = backprop.gradients_function(sq, [0])(t)
self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
- with backprop.GradientTape(persistent=True) as gtape:
- gtape.watch(t)
+ with backprop.GradientTape(persistent=True) as tape:
+ tape.watch(t)
one = matmul(t, b=t, transpose_a=True)
two = matmul(b=t, a=t, transpose_a=True)
three = matmul(a=t, b=t, transpose_a=True)
for output in [one, two, three]:
- self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]])
+ self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]])
def testGradientInFunctionWithKeywordArguments(self):
@@ -1363,7 +1489,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(state, [0])
# Whereas calling the python function directly should create a side-effect.
- side_effecting_function.call_python_function()
+ side_effecting_function.python_function()
self.assertAllEqual(state, [0, 0])
@@ -1371,7 +1497,7 @@ class FunctionTest(test.TestCase):
class AutomaticControlDependenciesTest(test.TestCase):
def testBasic(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
with function.AutomaticControlDependencies() as c:
@@ -1382,7 +1508,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(), 4.0)
def testCondMustRun(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1403,7 +1529,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
def testCondMustRunSeparateRead(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1426,7 +1552,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(v.read_value().eval(), 6.0)
def testCondNested(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1460,7 +1586,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
def testCondOneBranch(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1480,7 +1606,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
def testCondOneBranchUpdateBefore(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1501,7 +1627,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
def testCondOneBranchUpdateAfter(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1537,7 +1663,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(out, [3, 4, 5])
def testDecorator(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
deleted file mode 100644
index 7105d2e399..0000000000
--- a/tensorflow/python/eager/graph_callable.py
+++ /dev/null
@@ -1,435 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Decorator that produces a callable object that executes a TensorFlow graph.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import contextlib
-
-from tensorflow.python.eager import context
-from tensorflow.python.eager import function
-from tensorflow.python.eager import tape
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-from tensorflow.python.util import tf_decorator
-from tensorflow.python.util import tf_inspect
-
-
-def _default_initializer(name, shape, dtype):
- """The default initializer for variables."""
- # pylint: disable=protected-access
- store = variable_scope._get_default_variable_store()
- initializer = store._get_default_initializer(name, shape=shape, dtype=dtype)
- # pylint: enable=protected-access
- return initializer[0]
-
-
-class _CapturedVariable(object):
- """Variable captured by graph_callable.
-
- Internal to the implementation of graph_callable. Created only by
- _VariableCapturingScope and used only to read the variable values when calling
- the function after the variables are initialized.
- """
-
- def __init__(self, name, initializer, shape, dtype, trainable):
- self.name = name
- if initializer is None:
- initializer = _default_initializer(name, shape, dtype)
- initial_value = lambda: initializer(shape, dtype=dtype)
-
- with context.eager_mode():
- self.variable = resource_variable_ops.ResourceVariable(
- initial_value=initial_value, name=name, dtype=dtype,
- trainable=trainable)
- self.shape = shape
- self.dtype = dtype
- self.placeholder = None
- self.trainable = trainable
-
- def read(self, want_gradients=True):
- if want_gradients and self.trainable:
- v = tape.watch_variable(self.variable)
- else:
- v = self.variable
- return v.read_value()
-
-
-class _VariableCapturingScope(object):
- """Variable-scope-like object which captures tf.get_variable calls.
-
- This is responsible for the main difference between the initialization version
- of a function object and the calling version of a function object.
-
- capturing_scope replaces calls to tf.get_variable with placeholder tensors to
- be fed the variable's current value. TODO(apassos): these placeholders should
- instead be objects implementing a similar API to tf.Variable, for full
- compatibility.
-
- initializing_scope replaces calls to tf.get_variable with creation of
- variables and initialization of their values. This allows eventual support of
- initialized_value and friends.
-
- TODO(apassos): once the eager mode layers API is implemented support eager
- func-to-object as well.
- """
-
- def __init__(self):
- self.variables = {}
- self.tf_variables = {}
-
- @contextlib.contextmanager
- def capturing_scope(self):
- """Context manager to capture variable creations.
-
- Replaces variable accesses with placeholders.
-
- Yields:
- nothing
- """
- # TODO(apassos) ignoring the regularizer and partitioner here; figure out
- # how to deal with these.
- def _custom_getter( # pylint: disable=missing-docstring
- getter=None,
- name=None,
- shape=None,
- dtype=dtypes.float32,
- initializer=None,
- regularizer=None,
- reuse=None,
- trainable=None,
- collections=None,
- caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- aggregation=variable_scope.VariableAggregation.NONE,
- synchronization=variable_scope.VariableSynchronization.AUTO):
- del getter, regularizer, partitioner, validate_shape, use_resource, dtype
- del collections, initializer, trainable, reuse, caching_device, shape
- del aggregation, synchronization
- assert name in self.variables
- v = self.variables[name]
- return v.variable
-
- scope = variable_scope.get_variable_scope()
- with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
- yield
-
- @contextlib.contextmanager
- def initializing_scope(self):
- """Context manager to capture variable creations.
-
- Forcibly initializes all created variables.
-
- Yields:
- nothing
- """
- # TODO(apassos) ignoring the regularizer and partitioner here; figure out
- # how to deal with these.
- def _custom_getter( # pylint: disable=missing-docstring
- getter=None,
- name=None,
- shape=None,
- dtype=dtypes.float32,
- initializer=None,
- regularizer=None,
- reuse=None,
- trainable=None,
- collections=None,
- caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- aggregation=variable_scope.VariableAggregation.NONE,
- synchronization=variable_scope.VariableSynchronization.AUTO):
- del getter, regularizer, collections, caching_device, partitioner
- del use_resource, validate_shape, aggregation, synchronization
- if name in self.tf_variables:
- if reuse:
- return self.tf_variables[name].initialized_value()
- else:
- raise ValueError("Specified reuse=%s but tried to reuse variables."
- % reuse)
- # TODO(apassos): ensure this is on the same device as above
- v = _CapturedVariable(name, initializer, shape, dtype, trainable)
- self.variables[name] = v
-
- graph_mode_resource = v.variable.handle
- if initializer is None:
- initializer = _default_initializer(name, shape, dtype)
- resource_variable_ops.shape_safe_assign_variable_handle(
- graph_mode_resource, v.variable.shape, initializer(shape, dtype))
- return v.variable
-
- scope = variable_scope.get_variable_scope()
- with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
- yield
-
-
-class _InitializingFunctionObject(object):
- """Responsible for deciding which version of func-to-object to call.
-
- call_fn is the version which calls the function with the current values of the
- variables and init_fn is the version which calls the function to initialize
- all variables.
-
- TODO(apassos): figure out a way to support initializing only _some_
- variables. This requires a way to pull out a variable's initialization code
- from the graph, which might not be possible in general.
- """
-
- def __init__(self, call_fn, init_fn, shape_and_dtypes):
- self._init_fn = init_fn
- self._call_fn = call_fn
- self.shape_and_dtypes = shape_and_dtypes
- self.flattened_shapes = [tensor_shape.as_shape(sd.shape) for sd in
- nest.flatten(self.shape_and_dtypes)]
-
- @property
- def variables(self):
- return self._call_fn.variables
-
- def __call__(self, *args):
- nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False)
- if not all([
- shape.is_compatible_with(arg.shape)
- for shape, arg in zip(self.flattened_shapes, nest.flatten(args))
- ]):
- raise ValueError(
- "Declared shapes do not match argument shapes: Expected %s, found %s."
- % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)]))
-
- initialized = [resource_variable_ops.var_is_initialized_op(
- v.handle).numpy() for v in self._call_fn.variables]
- if all(x for x in initialized):
- for v in self._call_fn.variables:
- if v.trainable:
- tape.watch_variable(v)
- return self._call_fn(*args)
- elif all(not x for x in initialized):
- return self._init_fn(*args)
- else:
- raise ValueError("Some, but not all, variables are initialized.")
-
-
-def _get_graph_callable_inputs(shape_and_dtypes):
- """Maps specified shape_and_dtypes to graph inputs."""
- ret = []
- for x in shape_and_dtypes:
- if isinstance(x, ShapeAndDtype):
- ret.append(array_ops.placeholder(x.dtype, x.shape))
- elif isinstance(x, (tuple, list)):
- ret.append(_get_graph_callable_inputs(x))
- else:
- raise errors.InvalidArgumentError(
- None, None, "Expected the argument to @graph_callable to be a "
- "(possibly nested) list or tuple of ShapeAndDtype objects, "
- "but got an object of type: %s" % type(x))
-
- return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret
-
-
-def _graph_callable_internal(func, shape_and_dtypes):
- """Defines and returns a template version of func.
-
- Under the hood we make two function objects, each wrapping a different version
- of the graph-mode code. One version immediately runs variable initialization
- before making the variable's Tensors available for use, while the other
- version replaces the Variables with placeholders which become function
- arguments and get the current variable's value.
-
- Some limitations of graph_callable() stem from the fact that this does not
- implement a graph-mode Variable class with a convert_to_tensor(as_ref=True)
- method and an initialized_value method. This is fixable.
-
- Args:
- func: The tfe Python function to compile.
- shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.
-
- Raises:
- ValueError: If any one of func's outputs is not a Tensor.
-
- Returns:
- Callable graph object.
- """
- container = tf_ops.get_default_graph()._container # pylint: disable=protected-access
- graph_key = tf_ops.get_default_graph()._graph_key # pylint: disable=protected-access
- with context.graph_mode():
- # This graph will store both the initialization and the call version of the
- # wrapped function. It will later be used by the backprop code to build the
- # backprop graph, if necessary.
- tmp_graph = function.CapturingGraph()
- # Inherit the graph key from the original graph to ensure optimizers don't
- # misbehave.
- tmp_graph._container = container # pylint: disable=protected-access
- tmp_graph._graph_key = graph_key # pylint: disable=protected-access
- with tmp_graph.as_default():
- # Placeholders for the non-variable inputs.
- func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
- func_num_args = len(tf_inspect.getfullargspec(func).args)
- if len(func_inputs) != func_num_args:
- raise TypeError("The number of arguments accepted by the decorated "
- "function `%s` (%d) must match the number of "
- "ShapeAndDtype objects passed to the graph_callable() "
- "decorator (%d)." %
- (func.__name__, func_num_args, len(func_inputs)))
-
- # First call the function to generate a graph which can initialize all
- # variables. As a side-effect this will populate the variable capturing
- # scope's view of which variables exist.
- variable_captures = _VariableCapturingScope()
- with variable_captures.initializing_scope(
- ), function.AutomaticControlDependencies() as a:
- func_outputs = func(*func_inputs)
- outputs_list = nest.flatten(func_outputs)
- for i, x in enumerate(outputs_list):
- if x is not None:
- outputs_list[i] = a.mark_as_return(x)
- if len(outputs_list) == 1 and outputs_list[0] is None:
- outputs_list = []
- output_shapes = [x.shape for x in outputs_list]
- if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
- raise ValueError("Found non-tensor output in %s" % str(outputs_list))
- initializing_operations = tmp_graph.get_operations()
-
- # Call the function again, now replacing usages of variables with
- # placeholders. This assumes the variable capturing scope created above
- # knows about all variables.
- tmp_graph.clear_resource_control_flow_state()
- with variable_captures.capturing_scope(
- ), function.AutomaticControlDependencies() as a:
- captured_outputs = func(*func_inputs)
- captured_outlist = nest.flatten(captured_outputs)
- for i, x in enumerate(captured_outlist):
- if x is not None:
- captured_outlist[i] = a.mark_as_return(x)
- capturing_operations = tmp_graph.get_operations()[
- len(initializing_operations):]
-
- sorted_variables = sorted(variable_captures.variables.values(),
- key=lambda x: x.name)
-
- extra_inputs = tmp_graph.captures.keys()
- extra_placeholders = tmp_graph.captures.values()
-
- flat_inputs = [x for x in nest.flatten(func_inputs)
- if isinstance(x, tf_ops.Tensor)]
- placeholder_inputs = flat_inputs + list(extra_placeholders)
-
- func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
- initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access
- # TODO(ashankar): Oh lord, forgive me for this lint travesty.
- # Also, what about the gradient registry of these functions? Those need to be
- # addressed as well.
- for f in tmp_graph._functions.values(): # pylint: disable=protected-access
- function._register(f._c_func.func) # pylint: disable=protected-access
- initializer_function = function.GraphModeFunction(
- initialization_name,
- placeholder_inputs,
- extra_inputs,
- tmp_graph,
- initializing_operations,
- func_def_outputs,
- func_outputs,
- output_shapes)
-
- capture_func_def_outputs = [
- x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
- captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access
- captured_function = function.GraphModeFunction(
- captured_function_name,
- placeholder_inputs,
- extra_inputs,
- tmp_graph,
- capturing_operations,
- capture_func_def_outputs,
- captured_outputs,
- output_shapes,
- variables=[x.variable for x in sorted_variables])
-
- return _InitializingFunctionObject(captured_function, initializer_function,
- shape_and_dtypes)
-
-
-class ShapeAndDtype(object):
- """Data type that packages together shape and type information.
-
- Used for arguments to graph callables. See graph_callable() for an example.
- """
-
- def __init__(self, shape, dtype):
- self.shape = shape
- self.dtype = dtype
-
-
-def graph_callable(shape_and_dtypes):
- """Decorator that produces a callable that executes a TensorFlow graph.
-
- When applied on a function that constructs a TensorFlow graph, this decorator
- produces a callable object that:
-
- 1. Executes the graph when invoked. The first call will initialize any
- variables defined in the graph.
-
- 2. Provides a .variables() method to return the list of TensorFlow variables
- defined in the graph.
-
- Note that the wrapped function is not allowed to change the values of the
- variables, just use them.
-
- The return value of the wrapped function must be one of the following:
- (1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors.
-
- Example:
-
- ```python
- @tfe.graph_callable([tfe.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def foo(x):
- v = tf.get_variable('v', initializer=tf.ones_initializer(), shape=())
- return v + x
-
- ret = foo(tfe.Tensor(2.0)) # `ret` here is a Tensor with value 3.0.
-
- foo.variables[0].assign(7.0) # Modify the value of variable `v`.
- ret = foo(tfe.Tensor(2.0)) # `ret` here now is a Tensor with value 9.0.
- ```
- Args:
- shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects
- that specifies shape and type information for each of the callable's
- arguments. The length of this list must be equal to the number of
- arguments accepted by the wrapped function.
-
- Returns:
- A callable graph object.
- """
- # TODO(alive,apassos): support initialized_value and friends from tf.Variable.
- assert context.executing_eagerly(), (
- "graph_callable can only be used when Eager execution is enabled.")
- def decorator(func):
- return tf_decorator.make_decorator(func,
- _graph_callable_internal(
- func, shape_and_dtypes))
-
- return decorator
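
For reference, the dispatch rule implemented by the removed `_InitializingFunctionObject` (run the initializing version on the first call, reuse current variable values afterwards) can be sketched in plain Python. The names `FakeVariable`, `init_fn` and `call_fn` below are illustrative stand-ins, not TensorFlow APIs.

```python
class FakeVariable(object):
  """Toy stand-in for a captured variable; value is None until initialized."""

  def __init__(self):
    self.value = None

  @property
  def initialized(self):
    return self.value is not None


class InitializingCallable(object):
  """Calls init_fn when no variable is initialized, call_fn when all are."""

  def __init__(self, call_fn, init_fn, variables):
    self._call_fn = call_fn
    self._init_fn = init_fn
    self.variables = variables

  def __call__(self, *args):
    initialized = [v.initialized for v in self.variables]
    if all(initialized):
      return self._call_fn(*args)
    if not any(initialized):
      return self._init_fn(*args)
    raise ValueError("Some, but not all, variables are initialized.")


v = FakeVariable()

def init_fn(x):
  v.value = 0.0          # run the initializer, then compute like call_fn
  return v.value + x

def call_fn(x):
  return v.value + x     # read the variable's current value

f = InitializingCallable(call_fn, init_fn, [v])
print(f(2.0))            # 2.0: the first call initializes v to 0.0
v.value = 1.0
print(f(2.0))            # 3.0: later calls pick up the updated value
```
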
diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py
deleted file mode 100644
index b9e6ca2a93..0000000000
--- a/tensorflow/python/eager/graph_callable_test.py
+++ /dev/null
@@ -1,249 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.eager import backprop
-from tensorflow.python.eager import graph_callable
-from tensorflow.python.eager import test
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
-
-
-class GraphCallableTest(test.TestCase):
-
- def testBasic(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- self.assertEqual(
- 2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- my_function.variables[0].assign(1.)
- self.assertEqual(
- 3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- def testFunctionWithoutReturnValue(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- v.assign(x)
-
- my_function(constant_op.constant(4, dtype=dtypes.float32))
- self.assertAllEqual(4, my_function.variables[0].read_value())
-
- def testFunctionWithoutReturnValueAndArgs(self):
-
- @graph_callable.graph_callable([])
- def my_function():
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- v.assign(4)
-
- my_function()
- self.assertAllEqual(4, my_function.variables[0].read_value())
-
- def testVariableAPI(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- return v.read_value() + x
-
- self.assertEqual(
- 2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- my_function.variables[0].assign(1.)
- self.assertEqual(
- 3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- def testTensorShape(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(1), dtype=dtypes.float32)])
- def my_function(x):
- _ = x.get_shape()
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=[x.shape[0]])
- self.assertEqual(v.shape[0], x.shape[0])
- return v + x
-
- self.assertEqual([2.],
- my_function(
- constant_op.constant([2.],
- dtype=dtypes.float32)).numpy())
-
- def testUpdatesAreOrdered(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- v.assign(x + 1)
- v.assign(v * x)
- return v.read_value()
-
- self.assertAllEqual(my_function(constant_op.constant(2.0)), 6.0)
-
- def testEmptyInitializer(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(1), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable("v", shape=[1])
- return x + 0 * v
-
- self.assertEqual([2.],
- my_function(
- constant_op.constant([2.],
- dtype=dtypes.float32)).numpy())
-
- def testMismatchingNumArgs(self):
- # pylint: disable=anomalous-backslash-in-string
- with self.assertRaisesRegexp(TypeError,
- "The number of arguments accepted by the "
- "decorated function `my_function` \(2\) must "
- "match the number of ShapeAndDtype objects "
- "passed to the graph_callable\(\) decorator "
- "\(1\)."):
- @graph_callable.graph_callable([
- graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x, y): # pylint: disable=unused-variable
- return x + y
- # pylint: enable=anomalous-backslash-in-string
-
- def testPureFunction(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def f(x):
- return math_ops.add(x, constant_op.constant(3))
-
- self.assertAllEqual(5, f(constant_op.constant(2)))
-
- def testNestedFunction(self):
- # TensorFlow function (which is what would be used in TensorFlow graph
- # construction).
- @function.Defun(dtypes.int32, dtypes.int32)
- def add(a, b):
- return math_ops.add(a, b)
-
- # A graph_callable that will invoke the TensorFlow function.
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def add_one(x):
- return add(x, 1)
-
- self.assertAllEqual(3, add_one(constant_op.constant(2)))
-
- # TODO(ashankar): Make this work.
- # The problem is that the two graph_callables (for add_one and add_two)
- # are both trying to register the FunctionDef corresponding to "add".
- def DISABLED_testRepeatedUseOfSubFunction(self):
-
- @function.Defun(dtypes.int32, dtypes.int32)
- def add(a, b):
- return math_ops.add(a, b)
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def add_one(x):
- return add(x, 1)
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def add_two(x):
- return add(x, 2)
-
- two = constant_op.constant(2)
- self.assertAllEqual(3, add_one(two))
- self.assertAllEqual(4, add_two(two))
-
- def testNestedSequenceInputs(self):
- sd = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)
- @graph_callable.graph_callable([[sd, tuple([sd, sd]), sd]])
- def my_op(inputs):
- a, b, c = inputs
- e, f = b
- v = variable_scope.get_variable(
- "my_v", initializer=init_ops.zeros_initializer(), shape=())
- return [a + a + v, tuple([e + e, f + f]), c + c], a + e + f + c + v
-
- inputs = [constant_op.constant(1.),
- [constant_op.constant(2.), constant_op.constant(3.)],
- constant_op.constant(4.)]
- ret = my_op(inputs)
- self.assertEqual(len(ret), 2.)
- self.assertAllEqual(ret[1], 10.)
-
- my_op.variables[0].assign(1.)
- ret = my_op(inputs)
- self.assertAllEqual(ret[1], 11.)
-
- def testVariableShapeIsTensorShape(self):
- @graph_callable.graph_callable([])
- def my_function():
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- self.assertIsInstance(v.get_shape(), tensor_shape.TensorShape)
-
- my_function()
-
- def testIncorrectlyShapedInputs(self):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(3), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- with self.assertRaises(ValueError):
- my_function([1, 2])
-
- self.assertTrue(([1, 2, 3] == my_function(
- constant_op.constant([1, 2, 3], dtype=dtypes.float32)).numpy()).all())
-
- def testGradients(self):
- @graph_callable.graph_callable([])
- def my_function():
- v = variable_scope.get_variable(
- "v", initializer=init_ops.constant_initializer(3.), shape=())
- return v * v
-
- grad_fn = backprop.implicit_grad(my_function)
- grads_and_vars = list(zip(*grad_fn()))
- self.assertAllEqual(6., grads_and_vars[0][0])
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/eager/graph_only_ops_test.py b/tensorflow/python/eager/graph_only_ops_test.py
index d2a2b4e223..3cf3a61a62 100644
--- a/tensorflow/python/eager/graph_only_ops_test.py
+++ b/tensorflow/python/eager/graph_only_ops_test.py
@@ -32,13 +32,13 @@ class GraphOnlyOpsTest(test_util.TensorFlowTestCase):
def testGraphZerosLike(self):
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
z_tf = graph_only_ops.graph_zeros_like(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(np.zeros((2, 3)), z_tf.eval())
def testGraphPlaceholder(self):
x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,))
y_tf = math_ops.square(x_tf)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = np.array([42])
y = sess.run(y_tf, feed_dict={x_tf: np.array([42])})
self.assertAllClose(np.square(x), y)
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 000152855d..5f027d107c 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -24,12 +24,10 @@ from tensorflow.python import pywrap_tensorflow
VSpace = collections.namedtuple(
- "VSpace",
- ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"])
+ "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
def imperative_grad(
- vspace,
tape,
target,
sources,
@@ -41,7 +39,6 @@ def imperative_grad(
gradients for all sources.
Args:
- vspace: the vector space in which to differentiate.
tape: the gradient tape which stores the trace.
target: either a Tensor or list of Tensors to be differentiated.
sources: list of Tensors for which we want gradients
@@ -60,4 +57,7 @@ def imperative_grad(
computation of target.
"""
return pywrap_tensorflow.TFE_Py_TapeGradient(
- tape._tape, vspace, target, sources, output_gradients) # pylint: disable=protected-access
+ tape._tape, # pylint: disable=protected-access
+ target,
+ sources,
+ output_gradients)
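
The trimmed `VSpace` tuple and the dropped `vspace` parameter reflect a register-once design: the Python VSpace is handed to the C layer a single time (via `TFE_Py_RegisterVSpace`, added below) instead of being rebuilt on every gradient call. A self-contained sketch of that pattern, using illustrative stubs rather than the real tape machinery:

```python
import collections

VSpace = collections.namedtuple(
    "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])

_registered_vspace = None  # plays the role of the C-side global set by
                           # TFE_Py_RegisterVSpace

def register_vspace(vspace):
  global _registered_vspace
  _registered_vspace = vspace

def imperative_grad(tape, target, sources, output_gradients=None):
  """Mirrors the new signature: no per-call vspace argument."""
  if _registered_vspace is None:
    raise RuntimeError("VSpace must be registered before computing gradients")
  # A real implementation walks `tape`; this stub only shows the call shape.
  return [_registered_vspace.zeros(()) for _ in sources]

register_vspace(VSpace(aggregate_fn=sum,
                       num_elements_fn=lambda t: 1,
                       zeros=lambda shape: 0.0,
                       ones=lambda shape: 1.0))
print(imperative_grad(tape=None, target=[3.0], sources=[1.0, 2.0]))  # [0.0, 0.0]
```
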
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 15d2ccf9d2..f34ce6af79 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
+#include "structmember.h" // NOLINT // For PyMemberDef
+
// forward declare
struct EagerTensor;
@@ -263,6 +265,14 @@ typedef struct EagerTensor {
TF_Status* status;
PyObject* weakreflist; /* List of weak references */
+
+ // Per-instance attribute dictionary, to support monkey patching
+ // (e.g. EagerTensor.assign when slicing variables). This dictionary is
+ // created by CPython the first time an attribute is assigned, pointed to by
+ // tp_dictoffset. Note that garbage collection is not enabled for
+ // EagerTensors, so assigning objects to EagerTensor attributes which require
+ // garbage collection is likely to cause issues.
+ PyObject* dict;
} EagerTensor;
namespace {
@@ -311,17 +321,42 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
Py_INCREF(Py_None);
self->tensor_shape = Py_None;
self->status = TF_NewStatus();
+ self->dict = nullptr;
self->weakreflist = nullptr;
PyObject* value;
PyObject* context = nullptr;
PyObject* device = nullptr;
PyObject* dtype = Py_None;
- const char* kwlist[] = {"value", "context", "device", "dtype", nullptr};
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O",
+ PyObject* other_value = nullptr;
+ const char* kwlist[] = {"value", "context", "device",
+ "dtype", "other_value", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO",
const_cast<char**>(kwlist), &value, &context,
- &device, &dtype)) {
+ &device, &dtype, &other_value)) {
return -1;
}
+
+ if (other_value != nullptr) {
+ if (!EagerTensor_CheckExact(other_value)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "Expecting an EagerTensor for other_value, got ",
+ Py_TYPE(other_value)->tp_name)
+ .c_str());
+
+ return -1;
+ }
+ EagerTensor* other = reinterpret_cast<EagerTensor*>(other_value);
+ self->handle =
+ TFE_TensorHandleCopySharingTensor(other->handle, self->status);
+
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ return -1;
+ }
+
+ return 0;
+ }
+
// Extract dtype
int desired_dtype = -1;
if (dtype != Py_None) {
@@ -410,6 +445,10 @@ void EagerTensor_dealloc(EagerTensor* self) {
Py_DECREF(self->handle_data);
Py_DECREF(self->keras_mask);
Py_DECREF(self->tensor_shape);
+ // If an attribute dictionary has been created, release it. Note that this
+ // is only ever created by CPython's attribute setting methods; we don't
+ // create it ourselves.
+ Py_CLEAR(self->dict);
if (self->handle != nullptr) {
TFE_DeleteTensorHandle(self->handle);
self->handle = nullptr;
@@ -474,6 +513,30 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
#endif
}
+// Implements the `_num_elements` method: the product of the tensor's dimensions.
+static PyObject* EagerTensor_num_elements(EagerTensor* self) {
+ auto handle = self->handle;
+ int n = TFE_TensorHandleNumDims(handle, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ return nullptr;
+ }
+ tensorflow::int64 value = 1;
+ if (PyErr_Occurred()) return nullptr;
+ for (int i = 0; i < n; ++i) {
+ int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
+ return nullptr;
+ }
+ value *= dim;
+ }
+ return PyLong_FromLongLong(value);
+}
+
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
Py_INCREF(self->handle_data);
return self->handle_data;
@@ -582,6 +645,15 @@ static PyGetSetDef EagerTensor_getseters[] = {
{nullptr} /* Sentinel */
};
+#if PY_MAJOR_VERSION < 3
+// Only used for Python2 since Python3 seems to set the __dict__ correctly.
+static PyMemberDef EagerTensor_members[] = {
+ {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
+ READONLY},
+ {nullptr},
+};
+#endif
+
static PyMethodDef EagerTensor_methods[] = {
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
PyDoc_STR("_numpy")},
@@ -592,6 +664,8 @@ static PyMethodDef EagerTensor_methods[] = {
{"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
+ {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
+ PyDoc_STR("_num_elements")},
{nullptr, nullptr},
};
@@ -654,13 +728,13 @@ static PyTypeObject _EagerTensorType = {
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
EagerTensor_methods, /* tp_methods */
- nullptr, /* tp_members */
+ EagerTensor_members, /* tp_members */
EagerTensor_getseters, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
- 0, /* tp_dictoffset */
+ offsetof(EagerTensor, dict), /* tp_dictoffset */
(initproc)EagerTensor_init, /* tp_init */
nullptr, /* tp_alloc */
nullptr, /* tp_new */
@@ -788,8 +862,9 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
return nullptr;
}
+ EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
#else
- _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
+ _EagerTensorType.tp_base = base_class_type;
if (PyType_Ready(&_EagerTensorType) < 0) {
if (PyErr_Occurred()) return nullptr;
@@ -800,9 +875,6 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
EagerTensorType = &_EagerTensorType;
Py_INCREF(EagerTensorType);
#endif
- // We disable instance based attribute lookup. Its not clear if these
- // dictionaries are correctly initialized in the first place.
- EagerTensorType->tp_dictoffset = 0;
return reinterpret_cast<PyObject*>(EagerTensorType);
}
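
With `tp_dictoffset` pointing at the new per-instance `dict`, eager tensors can carry monkey-patched attributes, which the `testTensorDir` test added further down relies on. A rough usage sketch, assuming the public 1.x API of this era (`tf.enable_eager_execution`) and keeping in mind the comment above about attributes that need garbage collection:

```python
import tensorflow as tf

tf.enable_eager_execution()      # EagerTensor instances require eager mode

t = tf.constant([1.0, 2.0, 3.0])
t.patched_note = "set via the new per-instance __dict__"   # previously impossible
print(t.patched_note)
print(t._num_elements())         # 3, via the _num_elements method added above
```
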
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index a916a75f00..f1b4042ec9 100644..100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -59,6 +59,10 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
// This function is not thread-safe.
PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);
+// Registers `e` as the VSpace to use.
+// `e` must be an imperative_grad.py:VSpace named tuple.
+PyObject* TFE_Py_RegisterVSpace(PyObject* e);
+
// Registers e as the Exception to be raised when the conditions of
// TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
// is a signal to the calling code that it should fall back to the safer (and
@@ -89,7 +93,7 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
PyObject* exception);
// Returns the string associated with the passed-in python object.
-char* TFE_GetPythonString(PyObject* o);
+const char* TFE_GetPythonString(PyObject* o);
// Returns a unique id on each call.
int64_t get_uid();
@@ -124,9 +128,10 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
// To unset the profiler, pass Py_None as the value of `profiler`.
PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
-// Creates a new tape and adds it to the active set. `persistent` must be a
-// PyBool_Type, i.e either Py_True or Py_False
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
+// Creates a new tape and adds it to the active set. `persistent` and
+// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`).
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+ PyObject* watch_accessed_variables);
// Removes the passed tape from the set of active tapes.
void TFE_Py_TapeSetRemove(PyObject* tape);
@@ -138,7 +143,7 @@ void TFE_Py_TapeSetAdd(PyObject* tape);
PyObject* TFE_Py_TapeSetIsEmpty();
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);
-void TFE_Py_TapeSetWatch(PyObject* tensor);
+void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor);
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);
// Stops any gradient recording on the current thread.
@@ -158,18 +163,20 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
+// Notifies all tapes that a variable has been accessed.
+void TFE_Py_TapeVariableAccessed(PyObject* variable);
+
// Watches the given variable object on the given tape.
-void TFE_Py_TapeSetWatchVariable(PyObject* variable);
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);
// Computes a gradient based on information recorded on the tape.`tape` must
-// have been produced by TFE_Py_NewTape. `vspace` must be a
-// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python
+// have been produced by TFE_Py_NewTape. `target` and `sources` must be python
// lists of Tensor objects. `output_gradients` is either None or a python list
// of either Tensor or None, and if not None should have the same length as
// target.
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
- PyObject* target, PyObject* sources,
- PyObject* output_gradients, TF_Status* status);
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
+ PyObject* sources, PyObject* output_gradients,
+ TF_Status* status);
// Execute a tensorflow operation assuming that all provided inputs are
// correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors,
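
The reshaped tape API (per-tape `TFE_Py_TapeWatch`/`TFE_Py_TapeWatchVariable`, set-wide `TFE_Py_TapeVariableAccessed`, and the `watch_accessed_variables` flag on `TFE_Py_TapeSetNew`) is what lets a tape opt out of automatically recording every variable it touches. A hedged sketch of the user-visible behaviour, assuming `tf.GradientTape` exposes a `watch_accessed_variables` argument as in later releases; the Python plumbing lives in backprop.py, outside this part of the diff:

```python
import tensorflow as tf

tf.enable_eager_execution()

v = tf.Variable(3.0)
w = tf.Variable(5.0)

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(v)            # explicit per-tape watch -> TFE_Py_TapeWatchVariable
  y = v * v + w * w        # w is accessed but never watched by this tape

dv, dw = tape.gradient(y, [v, w])
print(dv.numpy())          # 6.0
print(dw)                  # None: gradients are only tracked for watched variables
```
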
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 18fafd0de1..46dcf7c8a8 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -216,7 +216,7 @@ bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(py_value)) {
Py_ssize_t size = 0;
- char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
+ const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
if (buf == nullptr) return false;
*value = tensorflow::StringPiece(buf, size);
return true;
@@ -825,7 +825,7 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
return -1;
}
-char* TFE_GetPythonString(PyObject* o) {
+const char* TFE_GetPythonString(PyObject* o) {
if (PyBytes_Check(o)) {
return PyBytes_AsString(o);
}
@@ -892,9 +892,10 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
public:
- explicit GradientTape(bool persistent)
+ explicit GradientTape(bool persistent, bool watch_accessed_variables)
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent) {}
+ persistent),
+ watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
for (const IdAndVariable& v : watched_variables_) {
@@ -902,6 +903,12 @@ class GradientTape
}
}
+ void VariableAccessed(PyObject* v) {
+ if (watch_accessed_variables_) {
+ WatchVariable(v);
+ }
+ }
+
void WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
@@ -951,6 +958,7 @@ class GradientTape
}
};
+ bool watch_accessed_variables_;
tensorflow::mutex watched_variables_mu_;
std::set<IdAndVariable, CompareById> watched_variables_
GUARDED_BY(watched_variables_mu_);
@@ -1056,11 +1064,13 @@ void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+ PyObject* watch_accessed_variables) {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
- tape->tape = new GradientTape(persistent == Py_True);
+ tape->tape = new GradientTape(persistent == Py_True,
+ watch_accessed_variables == Py_True);
Py_INCREF(tape);
GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
return reinterpret_cast<PyObject*>(tape);
@@ -1154,7 +1164,7 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
Py_RETURN_FALSE;
}
-void TFE_Py_TapeSetWatch(PyObject* tensor) {
+void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
if (*ThreadTapeIsStopped()) {
return;
}
@@ -1162,9 +1172,7 @@ void TFE_Py_TapeSetWatch(PyObject* tensor) {
if (PyErr_Occurred()) {
return;
}
- for (TFE_Py_Tape* tape : *GetTapeSet()) {
- tape->tape->Watch(tensor_id);
- }
+ reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
@@ -1235,15 +1243,22 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
return list;
}
-void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
+void TFE_Py_TapeVariableAccessed(PyObject* variable) {
if (*ThreadTapeIsStopped()) {
return;
}
for (TFE_Py_Tape* tape : SafeTapeSet()) {
- tape->tape->WatchVariable(variable);
+ tape->tape->VariableAccessed(variable);
}
}
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
+ if (*ThreadTapeIsStopped()) {
+ return;
+ }
+ reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
+}
+
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}
@@ -1350,7 +1365,9 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
class PyVSpace
: public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
public:
- explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+ Py_INCREF(py_vspace_);
+ }
tensorflow::Status Initialize() {
num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
@@ -1378,6 +1395,8 @@ class PyVSpace
Py_XDECREF(aggregate_fn_);
Py_XDECREF(zeros_);
Py_XDECREF(ones_);
+
+ Py_DECREF(py_vspace_);
}
tensorflow::int64 NumElements(PyObject* tensor) const final {
@@ -1493,6 +1512,22 @@ class PyVSpace
PyObject* zeros_;
PyObject* ones_;
};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+ if (py_vspace != nullptr) {
+ delete py_vspace;
+ }
+
+ py_vspace = new PyVSpace(e);
+ auto status = py_vspace->Initialize();
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ delete py_vspace;
+ return nullptr;
+ }
+
+ Py_RETURN_NONE;
+}
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -1509,9 +1544,9 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
return list;
}
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
- PyObject* target, PyObject* sources,
- PyObject* output_gradients, TF_Status* status) {
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
+ PyObject* sources, PyObject* output_gradients,
+ TF_Status* status) {
TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
if (!tape_obj->tape->IsPersistent()) {
auto* tape_set = GetTapeSet();
@@ -1526,10 +1561,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
return nullptr;
}
}
- PyVSpace c_vspace(vspace);
- if (!c_vspace.Initialize().ok()) {
- return nullptr;
- }
std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
if (PyErr_Occurred()) {
@@ -1553,7 +1584,7 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
}
std::vector<PyObject*> result;
status->status = tape_obj->tape->ComputeGradient(
- c_vspace, target_vec, sources_vec, outgrad_vec, &result);
+ *py_vspace, target_vec, sources_vec, outgrad_vec, &result);
if (!status->status.ok()) {
if (PyErr_Occurred()) {
// Do not propagate the erroneous status as that would swallow the
@@ -1709,118 +1740,169 @@ PyObject* MaybeGetDTypeForAttr(const string& attr,
Py_RETURN_NONE;
}
-bool OpDoesntRequireOutput(const string& op_name) {
- static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_outputs =
- new tensorflow::gtl::FlatSet<string>({
- "Identity",
- "MatMul",
- "Conv2DBackpropInput",
- "Conv2DBackpropFilter",
- "Conv3D",
- "Conv3DBackpropInputV2",
- "AvgPool3D",
- "AvgPool3DGrad",
- "MaxPool3D",
- "MaxPool3DGrad",
- "MaxPool3DGradGrad",
- "BiasAdd",
- "BiasAddV1",
- "BiasAddGrad",
- "Softplus",
- "SoftplusGrad",
- "Softsign",
- "ReluGrad",
- "LeakyRelu",
- "LeakyReluGrad",
- "Conv2D",
- "DepthwiseConv2dNative",
- "Dilation2D",
- "AvgPool",
- "AvgPoolGrad",
- "BatchNormWithGlobalNormalization",
- "L2Loss",
- "Sum",
- "Prod",
- "SegmentSum",
- "SegmentMean",
- "SparseSegmentSum",
- "SparseSegmentMean",
- "SparseSegmentSqrtN",
- "SegmentMin",
- "SegmentMax",
- "UnsortedSegmentSum",
- "UnsortedSegmentMax",
- "Abs",
- "Neg",
- "ReciprocalGrad",
- "Square",
- "Expm1",
- "Log",
- "Log1p",
- "TanhGrad",
- "SigmoidGrad",
- "Sign",
- "Sin",
- "Cos",
- "Tan",
- "Add",
- "Sub",
- "Mul",
- "Div",
- "RealDiv",
- "Maximum",
- "Minimum",
- "SquaredDifference",
- "Select",
- "SparseMatMul",
- "BatchMatMul",
- "Complex",
- "Real",
- "Imag",
- "Angle",
- "Conj",
- "Cast",
- "Cross",
- "Cumsum",
- "Cumprod",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
+// Looks up `op_name`; if its gradient's output usage is known, returns true and
+// points `*output` at a pair whose first value says whether all outputs are
+// unused. If that value is false, the second value is the set of unused indices.
+bool OpGradientDoesntRequireOutputIndices(
+ const string& op_name,
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+ static tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+ new tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+ // Ops that don't require any outputs.
+ {"Identity", {true, {}}},
+ {"MatMul", {true, {}}},
+ {"Conv2DBackpropInput", {true, {}}},
+ {"Conv2DBackpropFilter", {true, {}}},
+ {"Conv3D", {true, {}}},
+ {"Conv3DBackpropInputV2", {true, {}}},
+ {"AvgPool3D", {true, {}}},
+ {"AvgPool3DGrad", {true, {}}},
+ {"MaxPool3D", {true, {}}},
+ {"MaxPool3DGrad", {true, {}}},
+ {"MaxPool3DGradGrad", {true, {}}},
+ {"BiasAdd", {true, {}}},
+ {"BiasAddV1", {true, {}}},
+ {"BiasAddGrad", {true, {}}},
+ {"Softplus", {true, {}}},
+ {"SoftplusGrad", {true, {}}},
+ {"Softsign", {true, {}}},
+ {"ReluGrad", {true, {}}},
+ {"LeakyRelu", {true, {}}},
+ {"LeakyReluGrad", {true, {}}},
+ {"Conv2D", {true, {}}},
+ {"DepthwiseConv2dNative", {true, {}}},
+ {"Dilation2D", {true, {}}},
+ {"AvgPool", {true, {}}},
+ {"AvgPoolGrad", {true, {}}},
+ {"BatchNormWithGlobalNormalization", {true, {}}},
+ {"L2Loss", {true, {}}},
+ {"Sum", {true, {}}},
+ {"Prod", {true, {}}},
+ {"SegmentSum", {true, {}}},
+ {"SegmentMean", {true, {}}},
+ {"SparseSegmentSum", {true, {}}},
+ {"SparseSegmentMean", {true, {}}},
+ {"SparseSegmentSqrtN", {true, {}}},
+ {"SegmentMin", {true, {}}},
+ {"SegmentMax", {true, {}}},
+ {"UnsortedSegmentSum", {true, {}}},
+ {"UnsortedSegmentMax", {true, {}}},
+ {"Abs", {true, {}}},
+ {"Neg", {true, {}}},
+ {"ReciprocalGrad", {true, {}}},
+ {"Square", {true, {}}},
+ {"Expm1", {true, {}}},
+ {"Log", {true, {}}},
+ {"Log1p", {true, {}}},
+ {"TanhGrad", {true, {}}},
+ {"SigmoidGrad", {true, {}}},
+ {"Sign", {true, {}}},
+ {"Sin", {true, {}}},
+ {"Cos", {true, {}}},
+ {"Tan", {true, {}}},
+ {"Add", {true, {}}},
+ {"Sub", {true, {}}},
+ {"Mul", {true, {}}},
+ {"Div", {true, {}}},
+ {"RealDiv", {true, {}}},
+ {"Maximum", {true, {}}},
+ {"Minimum", {true, {}}},
+ {"SquaredDifference", {true, {}}},
+ {"Select", {true, {}}},
+ {"SparseMatMul", {true, {}}},
+ {"BatchMatMul", {true, {}}},
+ {"Complex", {true, {}}},
+ {"Real", {true, {}}},
+ {"Imag", {true, {}}},
+ {"Angle", {true, {}}},
+ {"Conj", {true, {}}},
+ {"Cast", {true, {}}},
+ {"Cross", {true, {}}},
+ {"Cumsum", {true, {}}},
+ {"Cumprod", {true, {}}},
+ {"ReadVariableOp", {true, {}}},
+ {"VarHandleOp", {true, {}}},
+ {"Shape", {true, {}}},
+ {"StridedSlice", {true, {}}},
+ {"Fill", {true, {}}},
+
+ // Ops that don't require a subset of outputs.
+ {"FusedBatchNorm", {false, {0, 1, 2}}},
});
- return ops_that_dont_require_outputs->find(op_name) !=
- ops_that_dont_require_outputs->end();
-}
-
-bool OpDoesntRequireInput(const string& op_name) {
- static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_inputs =
- new tensorflow::gtl::FlatSet<string>({
- "Identity",
- "Softmax",
- "LogSoftmax",
- "BiasAdd",
- "Relu",
- "Relu6",
- "Elu",
- "Selu",
- "SparseSoftmaxCrossEntropyWithLogits",
- "Neg",
- "Inv",
- "Reciprocal",
- "Sqrt",
- "Exp",
- "Tanh",
- "Sigmoid",
- "Real",
- "Imag",
- "Conj",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
+ auto it = m->find(op_name);
+
+ if (it == m->end()) return false;
+
+ *output = &it->second;
+ return true;
+}
+
+// Looks up `op_name`; if its gradient's input usage is known, returns true and
+// points `*output` at a pair whose first value says whether all inputs are
+// unused. If that value is false, the second value is the set of unused indices.
+bool OpGradientDoesntRequireInputIndices(
+ const string& op_name,
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+ static tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+ new tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+ // Ops that don't require any inputs.
+ {"Identity", {true, {}}},
+ {"Softmax", {true, {}}},
+ {"LogSoftmax", {true, {}}},
+ {"BiasAdd", {true, {}}},
+ {"Relu", {true, {}}},
+ {"Relu6", {true, {}}},
+ {"Elu", {true, {}}},
+ {"Selu", {true, {}}},
+ {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
+ {"Neg", {true, {}}},
+ {"Inv", {true, {}}},
+ {"Reciprocal", {true, {}}},
+ {"Sqrt", {true, {}}},
+ {"Exp", {true, {}}},
+ {"Tanh", {true, {}}},
+ {"Sigmoid", {true, {}}},
+ {"Real", {true, {}}},
+ {"Imag", {true, {}}},
+ {"Conj", {true, {}}},
+ {"ReadVariableOp", {true, {}}},
+ {"VarHandleOp", {true, {}}},
+ {"Shape", {true, {}}},
+ {"Fill", {true, {}}},
+
+ // Ops that don't require a subset of inputs.
+ {"FusedBatchNorm", {false, {2}}},
});
- return ops_that_dont_require_inputs->find(op_name) !=
- ops_that_dont_require_inputs->end();
+ auto it = m->find(op_name);
+
+ if (it == m->end()) return false;
+
+ *output = &it->second;
+ return true;
+}
+
+PyObject* CopySequenceSettingIndicesToNull(
+ PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
+ tensorflow::Safe_PyObjectPtr fast_seq(
+ PySequence_Fast(seq, "unable to allocate"));
+ PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get()));
+ for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) {
+ PyObject* item;
+ if (indices.find(i) != indices.end()) {
+ item = Py_None;
+ } else {
+ item = PySequence_Fast_GET_ITEM(fast_seq.get(), i);
+ }
+ Py_INCREF(item);
+ PyTuple_SET_ITEM(result, i, item);
+ }
+ return result;
}
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
@@ -1840,16 +1922,35 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
if (!should_record) Py_RETURN_NONE;
string c_op_name = TFE_GetPythonString(op_name);
+
PyObject* op_outputs;
- if (OpDoesntRequireOutput(c_op_name)) {
- op_outputs = Py_None;
+ bool op_outputs_tuple_created = false;
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
+
+ if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
+ if (outputs_not_required->first) {
+ op_outputs = Py_None;
+ } else {
+ op_outputs_tuple_created = true;
+ op_outputs = CopySequenceSettingIndicesToNull(
+ results, outputs_not_required->second);
+ }
} else {
op_outputs = results;
}
PyObject* op_inputs;
- if (OpDoesntRequireInput(c_op_name)) {
- op_inputs = Py_None;
+ bool op_inputs_tuple_created = false;
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;
+
+ if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
+ if (inputs_not_required->first) {
+ op_inputs = Py_None;
+ } else {
+ op_inputs_tuple_created = true;
+ op_inputs =
+ CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
+ }
} else {
op_inputs = inputs;
}
@@ -1892,18 +1993,20 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
});
Py_DECREF(num_inputs);
+ if (op_outputs_tuple_created) Py_DECREF(op_outputs);
+ if (op_inputs_tuple_created) Py_DECREF(op_inputs);
Py_RETURN_NONE;
}
-void MaybeWatchVariable(PyObject* input) {
+void MaybeNotifyVariableAccessed(PyObject* input) {
DCHECK(CheckResourceVariable(input));
DCHECK(PyObject_HasAttrString(input, "_trainable"));
tensorflow::Safe_PyObjectPtr trainable(
PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
- TFE_Py_TapeSetWatchVariable(input);
+ TFE_Py_TapeVariableAccessed(input);
}
bool CastTensor(const FastPathOpExecInfo& op_exec_info,
@@ -1934,7 +2037,7 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info,
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
PyObject* input, tensorflow::Safe_PyObjectPtr* output,
TF_Status* status) {
- MaybeWatchVariable(input);
+ MaybeNotifyVariableAccessed(input);
TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
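
The new tables can mark only a subset of an op's inputs or outputs as unused (FusedBatchNorm above), and `CopySequenceSettingIndicesToNull` then blanks exactly those slots before they are handed to the tape's backward function. Its effect is easy to state in Python; the helper below is an illustrative restatement, not code from this commit:

```python
def copy_sequence_setting_indices_to_none(seq, unused_indices):
  """Returns a tuple with the entries at unused_indices replaced by None."""
  return tuple(None if i in unused_indices else item
               for i, item in enumerate(seq))

# Per the table above, FusedBatchNorm's gradient does not need input index 2,
# so that entry is dropped (set to None) instead of being kept alive.
held_inputs = copy_sequence_setting_indices_to_none(
    ["in0", "in1", "in2", "in3", "in4"], {2})
print(held_inputs)   # ('in0', 'in1', None, 'in3', 'in4')
```
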
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index caa217b70c..399d90223c 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -33,9 +33,10 @@ class Tape(object):
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
-def push_new_tape(persistent=False):
+def push_new_tape(persistent=False, watch_accessed_variables=True):
"""Pushes a new tape onto the tape stack."""
- tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent)
+ tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
+ watch_accessed_variables)
return Tape(tape)
@@ -44,22 +45,19 @@ def push_tape(tape):
pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
-def watch(tensor):
- """Marks this tensor to be watched by all tapes in the stack.
+def watch(tape, tensor):
+ """Marks this tensor to be watched by the given tape."""
+ pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
- Args:
- tensor: tensor to be watched.
- """
- pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor)
+def watch_variable(tape, variable):
+ """Marks this variable to be watched by the given tape."""
+ pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access
-def watch_variable(variable):
- """Marks this variable to be watched by all tapes in the stack.
- Args:
- variable: variable to be watched.
- """
- pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable)
+def variable_accessed(variable):
+ """Notifies all tapes in the stack that a variable has been accessed."""
+ pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
def pop_tape(tape):
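
The module now works strictly per tape: callers create a tape, watch specific tensors or variables on it, and later hand that same tape to `imperative_grad`. A rough end-to-end sketch using these internal modules (not public API, and not exercised against this exact commit); `ops.enable_eager_execution` is assumed here to be the internal alias of `tf.enable_eager_execution`:

```python
from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops

ops.enable_eager_execution()

x = constant_op.constant(3.0)
t = tape_lib.push_new_tape(persistent=False, watch_accessed_variables=False)
tape_lib.watch(t, x)                  # new signature: watch on a specific tape
y = math_ops.multiply(x, x)           # recorded because x is watched
tape_lib.pop_tape(t)
(dx,) = imperative_grad.imperative_grad(t, target=[y], sources=[x])
print(dx.numpy())                     # 6.0
```
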
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index 4326d5efa3..acd0e569f1 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -72,7 +72,7 @@ class TapeTest(test.TestCase):
a = constant_op.constant([[1., 0.], [0., 1.]])
b = constant_op.constant([[1., 2.], [3., 4.]])
da, db = backprop.gradients_function(fn, [0, 1])(a, b)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
tf_c = tf_a + tf_b
@@ -135,7 +135,7 @@ class TapeTest(test.TestCase):
a = constant_op.constant([[1., 0.], [0., 1.]])
b = constant_op.constant([[1., 2.], [3., 4.]])
da, db = backprop.gradients_function(fn, [0, 1])(a, b)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
tf_mm = math_ops.matmul(tf_a, tf_b)
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 871136e2c8..344a9b25bd 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
def _create_tensor(value, device=None, dtype=None):
@@ -295,6 +296,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
def testFloatTensor(self):
self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype)
self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype)
+ self.assertEqual(dtypes.float16, _create_tensor(np.float16()).dtype)
self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype)
def testSliceDimOutOfRange(self):
@@ -332,6 +334,19 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
"but tensor at index 2 has rank 0"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testTensorDir(self):
+ t = array_ops.zeros(1)
+ t.test_attr = "Test"
+
+ instance_dir = dir(t)
+ type_dir = dir(ops.EagerTensor)
+
+ # Monkey patched attributes should show up in dir(t)
+ self.assertIn("test_attr", instance_dir)
+ instance_dir.remove("test_attr")
+ self.assertEqual(instance_dir, type_dir)
+
if __name__ == "__main__":
test.main()