path: root/tensorflow/python/eager
author     Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-09-26 11:54:30 +0800
committer  Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-09-26 11:54:30 +0800
commit     35174f46b973c66a2e6894a12b3018d60e8414ec (patch)
tree       5bdae0172159bc02ec3a470722bf959b14dd47ba /tensorflow/python/eager
parent     f0886f7269de900d226455d4831722f6fc94a71b (diff)
parent     6666516f390f125ed70ddbd4e6f89b83d953c408 (diff)
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--  tensorflow/python/eager/BUILD                  36
-rw-r--r--  tensorflow/python/eager/backprop.py            43
-rw-r--r--  tensorflow/python/eager/backprop_test.py       12
-rw-r--r--  tensorflow/python/eager/def_function.py       235
-rw-r--r--  tensorflow/python/eager/def_function_test.py   87
-rw-r--r--  tensorflow/python/eager/function.py           393
-rw-r--r--  tensorflow/python/eager/function_test.py      409
-rw-r--r--  tensorflow/python/eager/imperative_grad.py      5
-rw-r--r--  tensorflow/python/eager/pywrap_tensor.cc       41
-rw-r--r--  tensorflow/python/eager/pywrap_tensor.h         5
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc     473
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_test.py     25
12 files changed, 1404 insertions, 360 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 85da1baaf0..d3d997e6df 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -17,7 +17,10 @@ cc_library(
"pywrap_tensor.h",
"pywrap_tfe.h",
],
- visibility = ["//tensorflow:internal"],
+ visibility = [
+ "//learning/deepmind/courier:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
@@ -34,6 +37,7 @@ cc_library(
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/types:variant",
],
)
@@ -45,6 +49,7 @@ py_library(
":backprop",
":context",
":core",
+ ":def_function",
":execute",
":function",
":graph_only_ops",
@@ -146,6 +151,7 @@ cuda_py_test(
"//tensorflow/python:clip_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:layers",
+ "//tensorflow/python:list_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
],
@@ -345,6 +351,7 @@ py_test(
deps = [
":backprop",
":context",
+ ":core",
":test",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
@@ -377,3 +384,30 @@ cuda_py_test(
"optonly", # The test is too slow in non-opt mode
],
)
+
+py_library(
+ name = "def_function",
+ srcs = ["def_function.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":context",
+ ":function",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/training/checkpointable:base",
+ ],
+)
+
+py_test(
+ name = "def_function_test",
+ srcs = ["def_function_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":def_function",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ ],
+)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index be392c7a0f..78f3198011 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -120,27 +120,6 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function)
-_tracing = False
-
-
-# TODO(agarwal): use an automatic mechanism for handling None arguments to
-# gradient functions.
-# Some gradient functions can accept None arguments for gradients. The following
-# maps the operation name to the indices at which the corresponding gradient
-# function can accept None values.
-# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
-# during backprop. However the gradient function uses only the first of those
-# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
-# indicates that only the gradient corresponding to index 0 is used, and the
-# gradient values at indices 1-4 are ignored (and hence can be None). The
-# backprop algorithm can then leverage this by not constructing zeros to
-# pass for those indices.
-_grad_fn_accepts_none_for_indices = {
- "SoftmaxCrossEntropyWithLogits": [1],
- "FusedBatchNorm": [1, 2, 3, 4]
-}
-
-
def _record_gradient(op_name, inputs, attrs, results, name):
return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
results, name)
@@ -585,7 +564,10 @@ def _aggregate_grads(gradients):
def _num_elements(grad):
"""The number of elements in the `grad` tensor."""
if isinstance(grad, ops.Tensor):
- return functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
+ shape_tuple = grad._shape_tuple() # pylint: disable=protected-access
+ if shape_tuple is None or None in shape_tuple:
+ return 0
+ return functools.reduce(operator.mul, shape_tuple, 1)
if isinstance(grad, ops.IndexedSlices):
return functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
raise ValueError("`grad` not a Tensor or IndexedSlices.")
@@ -629,8 +611,9 @@ def _ones(shape, dtype):
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
- zeros=_zeros,
- ones=_ones)
+ zeros_fn=_zeros,
+ ones_fn=_ones,
+ graph_shape_fn=gen_array_ops.shape)
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
@@ -648,8 +631,8 @@ class GradientTape(object):
Operations are recorded if they are executed within this context manager and
at least one of their inputs is being "watched".
- Trainable variables (created by `tf.Variable` or `tf.get_variable`,
- trainable=True is default in both cases) are automatically watched. Tensors
+ Trainable variables (created by `tf.Variable` or `tf.get_variable`, where
+ `trainable=True` is default in both cases) are automatically watched. Tensors
can be manually watched by invoking the `watch` method on this context
manager.
@@ -669,6 +652,7 @@ class GradientTape(object):
```python
x = tf.constant(3.0)
with tf.GradientTape() as g:
+ g.watch(x)
with tf.GradientTape() as gg:
gg.watch(x)
y = x * x
@@ -745,7 +729,9 @@ class GradientTape(object):
self._persistent = persistent
self._watch_accessed_variables = watch_accessed_variables
self._recording = False
- context.context().start_step()
+ self._created_eagerly = context.executing_eagerly()
+ if self._created_eagerly:
+ context.context().start_step()
def __enter__(self):
"""Enters a context inside which operations are recorded on this tape."""
@@ -775,7 +761,8 @@ class GradientTape(object):
self._recording = False
def __del__(self):
- context.context().end_step()
+ if self._created_eagerly:
+ context.context().end_step()
def watch(self, tensor):
"""Ensures that `tensor` is being traced by this tape.
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index f938ed5df8..32731747b7 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1022,6 +1022,18 @@ class BackpropTest(test.TestCase):
resource_variable_ops.ResourceVariable(2.0))
self.assertAllEqual(gradients_constants, gradients_variables)
+ def testUnknownShapes(self):
+ with context.graph_mode():
+ with backprop.GradientTape() as tape:
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ tape.watch(a)
+ b = a**3
+
+ db_da = tape.gradient(b, a)
+
+ with self.cached_session() as sess:
+ self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0}))
+
if __name__ == '__main__':
test.main()
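
As context for the GradientTape changes above: trainable variables are watched automatically, while plain tensors need an explicit watch() call. A minimal sketch of that behavior, assuming eager execution is enabled (the values are illustrative only, not taken from this change):

# Illustrative sketch, not part of this change; assumes eager execution,
# e.g. tf.enable_eager_execution() on a 1.x build.
import tensorflow as tf

v = tf.Variable(2.0)      # trainable=True by default, so it is watched automatically
x = tf.constant(3.0)

with tf.GradientTape() as tape:
  tape.watch(x)           # plain tensors must be watched explicitly
  y = v * x * x

dv, dx = tape.gradient(y, [v, x])   # dv == 9.0 (x*x), dx == 12.0 (2*v*x)
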
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
new file mode 100644
index 0000000000..8dcacd5c99
--- /dev/null
+++ b/tensorflow/python/eager/def_function.py
@@ -0,0 +1,235 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=unidiomatic-typecheck
+"""Prototype decorator for defining graph-mode functions with eager semantics."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training.checkpointable import base as checkpointable
+
+
+class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
+ """Variable which does not lift its initializer out of function context.
+
+ Instances of this variable, when created, build a graph which runs their
+ initializer inside a tf.cond(is_initialized) block.
+
+ This can only be created inside a defun called from (eventually) eager
+ mode. That is, non-function-building graphs are not supported.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ caching_device=None,
+ name=None,
+ dtype=None,
+ constraint=None,
+ **unused_kwargs):
+ """Creates a variable.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called.
+ (Note that initializer functions from init_ops.py must first be bound
+ to a shape before being used here.)
+ trainable: If `True`, GradientTapes automatically watch uses of this
+ Variable.
+ caching_device: Optional device string or function describing where the
+ Variable should be cached for reading. Defaults to the Variable's
+ device. If not `None`, caches on another device. Typical use is to
+ cache on the device where the Ops using the Variable reside, to
+ deduplicate copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If None, either the datatype will be kept (if initial_value is
+ a Tensor) or float32 will be used (if it is a Python object convertible
+ to a Tensor).
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+
+ Raises:
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ RuntimeError: If called outside of a function definition.
+ """
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "UnliftedInitializerVariable should not be created "
+ "outside of functions.")
+ with ops.init_scope():
+ if not context.executing_eagerly():
+ raise RuntimeError(
+ "UnliftedInitializerVariable does not support legacy graph mode.")
+ self._in_graph_mode = False
+ if initial_value is None:
+ raise ValueError("initial_value must be specified.")
+ init_from_fn = callable(initial_value)
+
+ if constraint is not None and not callable(constraint):
+ raise ValueError("The `constraint` argument must be a callable.")
+
+ if isinstance(initial_value, checkpointable.CheckpointInitialValue):
+ self._maybe_initialize_checkpointable()
+ self._update_uid = initial_value.checkpoint_position.restore_uid
+ initial_value = initial_value.wrapped_value
+
+ self._trainable = trainable
+ self._save_slice_info = None
+ self._initial_value = None
+ self._initializer_op = None
+ self._is_initialized_op = None
+ self._graph_element = None
+ self._cached_value = None
+ # Store the graph key so optimizers know how to only retrieve variables from
+ # this graph. Guaranteed to be the same as the eager graph_key.
+ self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ with ops.name_scope(name, "Variable", []
+ if init_from_fn else [initial_value]) as name:
+ # pylint: disable=protected-access
+ with ops.init_scope():
+ assert context.executing_eagerly()
+ shared_name = ops._name_from_scope_name(name)
+ shared_name = "%s_%d" % (shared_name, ops.uid())
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ with ops.name_scope("Initializer"), ops.device(None):
+ initial_value = ops.convert_to_tensor(
+ initial_value() if init_from_fn else initial_value,
+ name="initial_value", dtype=dtype)
+ with ops.init_scope():
+ self._handle = resource_variable_ops.eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=False)
+ self._shape = initial_value.shape
+ self._unique_id = shared_name
+ self._handle_name = shared_name + ":0"
+ self._dtype = initial_value.dtype.base_dtype
+ self._constraint = constraint
+ assert initial_value is not None
+ def assign_fn():
+ with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+ resource_variable_ops.assign_variable_op(
+ self._handle,
+ initial_value,
+ name=n)
+ # Returning values to keep tf.cond happy.
+ return ops.convert_to_tensor(1)
+ def not_assign_fn():
+ return ops.convert_to_tensor(0)
+ # Note: this cond is always guaranteed to run because we're inside a defun
+ # which will insert automatic control dependencies.
+ control_flow_ops.cond(
+ resource_variable_ops.var_is_initialized_op(self._handle),
+ not_assign_fn, assign_fn)
+
+ # After the handle has been created, set up a way to clean it up when
+ # executing eagerly. We'll hold the only reference to the deleter, so that
+ # when this object is garbage collected the deleter will be too. This
+ # means ResourceVariables can be part of reference cycles without those
+ # cycles being uncollectable.
+ self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._handle, handle_device=self._handle.device)
+ self._cached_shape_as_list = None
+
+
+def _defun_with_scope(scope, fn):
+
+ def wrapped_fn(*args, **kwds):
+ with variable_scope.variable_creator_scope(scope):
+ return fn(*args, **kwds)
+
+ return function.defun(wrapped_fn)
+
+
+def def_function(fn):
+ """Defines a function as per the "functions, not sessions" document."""
+
+  # Wrapping the values in lists to work around Python's lack of a way to
+  # mutate symbols from an outer scope.
+ first_call = [True]
+ function_to_call = []
+
+ # TODO(apassos) represent this as an object and not as a closure.
+ def decorated_fn(*args, **kwds):
+ """Graph function for fn."""
+ if not first_call[0]:
+ return function_to_call[0](*args, **kwds)
+
+ first_call[0] = False
+ created_variables = []
+
+ def variable_creator_scope(unused_next_creator, **kwds):
+ """Creates UnliftedInitializerVariables and saves references to them."""
+ v = UnliftedInitializerVariable(**kwds)
+ created_variables.append(v)
+ return v
+
+ first_graph_function = _defun_with_scope(variable_creator_scope, fn)
+
+ # Force the definition of the function for these arguments
+ first_concrete = first_graph_function.get_concrete_function(*args, **kwds)
+
+ def invalid_creator_scope(*unused_args, **unused_kwds):
+ """Disables variable creation."""
+ raise ValueError(
+ "def_function-decorated function tried to create "
+ "variables on second call.")
+
+ second_graph_function = _defun_with_scope(invalid_creator_scope, fn)
+
+ function_to_call.append(second_graph_function)
+ if not created_variables:
+ # Note: this retracing might be unnecessary, but running the function
+ # forever in the scope which disallows variable creation is safer than not
+ # doing so.
+ return second_graph_function(*args, **kwds)
+
+ def fn_with_cond(*inner_args, **inner_kwds):
+ """Conditionally runs initialization if it's needed."""
+ condition = True
+ for variable in created_variables:
+ condition = condition and resource_variable_ops.var_is_initialized_op(
+ variable.handle)
+ # We want to call second_graph_function if possible because it avoids
+ # recomputing potentially expensive initializers.
+ return control_flow_ops.cond(
+ condition,
+ lambda: second_graph_function(*inner_args, **inner_kwds),
+ lambda: first_concrete(*inner_args, **inner_kwds))
+
+ return function.defun(fn_with_cond)(*args, **kwds)
+
+ return decorated_fn
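
A minimal usage sketch for the decorator defined above (mirroring testCorrectVariableCreation in the new tests below; assumes eager execution is enabled): the function is traced once under a creator scope that builds UnliftedInitializerVariables, and later calls dispatch to a retraced version that forbids further variable creation.

# Illustrative sketch only; mirrors the tests added in def_function_test.py.
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import variables

state = []

@def_function.def_function
def scale(x):
  if not state:                       # the variable is created only on the first trace
    state.append(variables.Variable(2.0))
  return state[0] * x

scale(constant_op.constant(3.0))      # 6.0; initialization runs inside a tf.cond
scale(constant_op.constant(4.0))      # 8.0; reuses the already-initialized variable
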
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
new file mode 100644
index 0000000000..804436c4bb
--- /dev/null
+++ b/tensorflow/python/eager/def_function_test.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class DefFunctionTest(test.TestCase):
+
+ def testNoVariables(self):
+
+ @def_function.def_function
+ def fn(x):
+ return 2 * x
+
+ self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)
+
+ def testFailIfVariablesAreCreatedMoreThanOnce(self):
+
+ @def_function.def_function
+ def fn(x):
+ return variables.Variable(1.0) + x
+
+ with self.assertRaises(ValueError):
+ fn(1.0)
+
+ def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ state.append(variables.Variable(1.0))
+ return state[-1] + x
+
+ with self.assertRaises(ValueError):
+ fn(1.0)
+
+ def testCorrectVariableCreation(self):
+
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ if not state:
+ state.append(variables.Variable(2.0))
+ return state[0] * x
+
+ self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
+ self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
+
+ def testVariableInitializerNotConstant(self):
+
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ if not state:
+ state.append(variables.Variable(2.0 * x))
+ return state[0] * x
+
+ self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
+ self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
+
+
+if __name__ == '__main__':
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 03f12139f6..b28befeb62 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -23,10 +23,12 @@ import collections
import functools
import sys
import threading
+import weakref
import numpy as np
import six
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -34,6 +36,7 @@ from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
@@ -59,23 +62,47 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce
gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
-def _create_substitute_placeholder(value, name, dtype=None):
+# TODO(scottzhu): Update this to allow arbitrary attribute names in future.
+WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
+
+
+def _create_substitute_placeholder(value, name=None, dtype=None):
"""Creates a placeholder for `value` and propagates shape info to it."""
# Note: setting ops.control_dependencies(None) ensures we always put
# capturing placeholders outside of any control flow context.
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- if placeholder.dtype == dtypes_module.resource:
- if isinstance(value, ops.EagerTensor):
- handle_data = value._handle_data # pylint: disable=protected-access
+ _copy_handle_data(value, placeholder)
+ return placeholder
+
+
+def _copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do this
+ the element tensors will have unknown shapes, e.g., if a TensorList variant
+ tensor is captured as a placeholder, elements popped from that list would have
+ unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes_module.resource or
+ target_t.dtype == dtypes_module.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
else:
- handle_data = resource_variable_ops.get_resource_handle_data(value)
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
if handle_data is not None and handle_data.is_set:
# pylint: disable=protected-access
- pywrap_tensorflow.SetResourceHandleShapeAndType(
- placeholder.graph._c_graph, placeholder._as_tf_output(),
- handle_data.SerializeToString())
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
# pylint: enable=protected-access
# Ensure that shapes and dtypes are propagated.
shapes, types = zip(*[(pair.shape, pair.dtype)
@@ -84,12 +111,10 @@ def _create_substitute_placeholder(value, name, dtype=None):
shapes = [[d.size for d in s.dim]
if not s.unknown_rank else None for s in shapes]
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- placeholder._op._graph._c_graph, # pylint: disable=protected-access
- placeholder._as_tf_output(), # pylint: disable=protected-access
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
shapes, ranks, types)
- return placeholder
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
@@ -99,6 +124,44 @@ def _get_device_functions(ctx, graph):
return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+def _parse_func_attrs(attributes):
+ """Convert the keyword arguments into function_def attributes.
+
+  Currently only supports primitive types: bool, int, float and string.
+
+ Args:
+ attributes: the dictionary of attributes.
+ Returns:
+ A dict of attributes where the key is the name of attribute and the value
+ is the AttrValue proto.
+ Raises:
+    ValueError: If the kwargs contain an unwhitelisted name or an unsupported
+      value type.
+ """
+ attrs = {}
+ for key, value in attributes.items():
+ if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX):
+ raise ValueError("Attribute name is not whitelisted. "
+ "Whitelisted: prefix %s, got: %s" %
+ (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key))
+
+ if isinstance(value, attr_value_pb2.AttrValue):
+ attrs[key] = value
+ # bool type check has to happen before int since bool is a subclass of int.
+ elif isinstance(value, bool):
+ attrs[key] = attr_value_pb2.AttrValue(b=value)
+ elif isinstance(value, int):
+ attrs[key] = attr_value_pb2.AttrValue(i=value)
+ elif isinstance(value, float):
+ attrs[key] = attr_value_pb2.AttrValue(f=value)
+ elif isinstance(value, str):
+ attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+ else:
+ raise ValueError("Unsupported attribute type for %s with type %s" %
+ (key, type(value)))
+ return attrs
+
+
class FuncGraph(ops.Graph):
"""Graph representing a function body.
@@ -136,7 +199,7 @@ class FuncGraph(ops.Graph):
self.inputs = []
self.outputs = []
self.structured_outputs = None
- self.variables = []
+ self._weak_variables = []
self.outer_graph = ops.get_default_graph()
self.captures = collections.OrderedDict()
@@ -173,6 +236,31 @@ class FuncGraph(ops.Graph):
self._graph_key = graph._graph_key
# pylint: enable=protected-access
+ @property
+ def variables(self):
+ """A list of variables accessed by this FuncGraph.
+
+ Note that functions keep only weak references to variables. Calling the
+ function after a variable it accesses has been deleted is an error.
+
+ Yields:
+ Strong references to variables accessed by this FuncGraph.
+ """
+ for weak_v in self._weak_variables:
+ v = weak_v()
+ if v is None:
+ raise AssertionError(
+ "Called a function referencing variables which have been deleted. "
+ "This likely means that function-local variables were created and "
+ "not referenced elsewhere in the program. This is generally a "
+ "mistake; consider storing variables in an object attribute on "
+ "first call.")
+ yield v
+
+ @variables.setter
+ def variables(self, var_list):
+ self._weak_variables = [weakref.ref(v) for v in var_list]
+
def create_op(
self,
op_type,
@@ -365,6 +453,7 @@ class _EagerDefinedFunction(object):
self._num_outputs = len(self.signature.output_arg)
self._output_types = [o.type for o in self.signature.output_arg]
self._output_shapes = [o.shape for o in outputs]
+ self._func_graph_outputs = outputs
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -441,6 +530,8 @@ class _EagerDefinedFunction(object):
else:
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
+ for i, func_graph_output in enumerate(self._func_graph_outputs):
+ _copy_handle_data(func_graph_output, outputs[i])
return outputs
@@ -485,7 +576,7 @@ class Function(object):
self._num_outputs = len(self._func_graph.outputs)
self._output_shapes = tuple(
output.shape for output in self._func_graph.outputs)
- self._attrs = attrs or {}
+ self._attrs = _parse_func_attrs(attrs or {})
self._device_functions = tuple(
self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
@@ -506,7 +597,19 @@ class Function(object):
self._distributed_variables[component_variable.handle] = variable
def __call__(self, *args):
- """Executes the wrapped function."""
+ """Executes the wrapped function.
+
+ Args:
+ *args: a list of Tensors or Variables.
+
+ Returns:
+ The result of applying the TF function to `args`.
+
+ Raises:
+ ValueError: If the current device stack does not match the device stack
+ under which the function was created, or if `args` contains anything
+ other than Tensors or Variables.
+ """
ctx = context.context()
device_functions = _get_device_functions(ctx, ops.get_default_graph())
if device_functions != self._device_functions:
@@ -522,7 +625,18 @@ class Function(object):
tape.variable_accessed(v)
captures = self._resolve_captured_inputs()
- tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
+ tensor_inputs = []
+ for i, arg in enumerate(nest.flatten(args)):
+ if isinstance(arg, resource_variable_ops.ResourceVariable):
+ if arg.trainable:
+ tape.variable_accessed(arg)
+ tensor_inputs.append(arg.handle)
+ elif isinstance(arg, ops.Tensor):
+ tensor_inputs.append(arg)
+ else:
+ raise ValueError("All inputs to `Function`s must be Tensors; "
+ "on invocation of %s, the %d-th input (%s) was not a "
+ "Tensor." % (self._func_graph.name, i, str(arg)))
args = tensor_inputs + captures
if tape.should_record(tensor_inputs) or tape.should_record(captures):
@@ -537,11 +651,6 @@ class Function(object):
return self._func_graph
@property
- def variables(self):
- """Returns all variables touched by this function."""
- return self._func_graph.variables
-
- @property
def inputs(self):
"""Returns tensors in `self.graph` corresponding to arguments."""
return self._func_graph.inputs
@@ -738,7 +847,12 @@ def _get_defun_inputs_from_args(args):
return nest.pack_sequence_as(args, function_inputs)
-def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+def func_graph_from_py_func(name,
+ python_func,
+ args,
+ kwargs,
+ signature=None,
+ func_graph=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -746,13 +860,15 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
python_func: the Python function to trace.
args: the positional args with which the Python function should be called;
ignored if a signature is provided.
- kwds: the keyword args with which the Python function should be called;
+ kwargs: the keyword args with which the Python function should be called;
ignored if a signature is provided.
signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
and dtypes of the arguments. When a signature is provided, `args` and
- `kwds` are ignored, and `python_func` is traced with Tensors conforming
+ `kwargs` are ignored, and `python_func` is traced with Tensors conforming
to `signature`. If `None`, the shapes and dtypes are inferred from the
inputs.
+    func_graph: Optional. An instance of FuncGraph. If provided, this graph is
+      used; otherwise a new one is built and returned.
Returns:
A FuncGraph.
@@ -761,26 +877,25 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
- func_graph = FuncGraph(name)
+ if func_graph is None:
+ func_graph = FuncGraph(name)
+ assert isinstance(func_graph, FuncGraph)
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
if signature is None:
func_args = _get_defun_inputs_from_args(args)
- func_kwds = _get_defun_inputs_from_args(kwds)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
else:
func_args = _get_defun_inputs_from_signature(signature)
- func_kwds = {}
+ func_kwargs = {}
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
- func_graph.inputs.extend(
- x for x in nest.flatten(func_args) + nest.flatten(func_kwds)
- if isinstance(x, ops.Tensor))
-
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
- func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds))
+ func_kwargs_before = nest.pack_sequence_as(
+ func_kwargs, nest.flatten(func_kwargs))
def convert(x):
"""Converts an argument to a Tensor."""
@@ -799,7 +914,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwds)
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
@@ -819,10 +934,32 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
raise ValueError(errmsg)
check_mutation(func_args_before, func_args)
- check_mutation(func_kwds_before, func_kwds)
+ check_mutation(func_kwargs_before, func_kwargs)
finally:
tape.pop_tape(this_tape)
+ # Variables in `func_args`, `func_kwargs` should be explicit inputs
+ # to the function, not captured inputs.
+ tape_variables = this_tape.watched_variables()
+ arg_variables = set()
+ inputs = []
+ for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
+ if isinstance(arg, resource_variable_ops.ResourceVariable):
+ try:
+ resource_placeholder = func_graph.captures.pop(arg.handle)
+ arg_variables.add(arg)
+ except KeyError:
+ # This case occurs if a Variable among the inputs is not actually
+ # used by the function; we still add an explicit input for it
+ # because the user should presumably pass the Variable as an input
+ # to the corresponding graph function.
+ resource_placeholder = _create_substitute_placeholder(arg.handle)
+ inputs.append(resource_placeholder)
+ elif isinstance(arg, ops.Tensor):
+ inputs.append(arg)
+ variables = [v for v in tape_variables if v not in arg_variables]
+ func_graph.inputs = inputs + list(func_graph.captures.values())
+
func_graph.structured_outputs = func_outputs
# Returning a closed-over tensor does not trigger convert_to_tensor.
func_graph.outputs.extend(
@@ -834,7 +971,6 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
# Instead of storing non-distributed component variables, we
# store their distributed containers so we can retrieve the correct
# component variables at call-time.
- variables = list(this_tape.watched_variables())
strategy = distribution_strategy_context.get_distribution_strategy()
for i, variable in enumerate(variables):
# If variable is not distributed value_container returns itself.
@@ -879,9 +1015,6 @@ def _encode_arg(arg):
_TensorType(arg.values.dtype, arg.values._shape_tuple()),
_TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
])
- elif isinstance(arg, np.ndarray):
- tensor = ops.convert_to_tensor(arg)
- return _TensorType(tensor.dtype, tensor._shape_tuple())
# pylint: enable=protected-access
elif isinstance(arg, (list, tuple)):
return tuple([_encode_arg(elem) for elem in arg])
@@ -889,7 +1022,16 @@ def _encode_arg(arg):
return tuple(
(_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
else:
- return arg
+ try:
+ # If possible, keep only a weak reference to Python objects. Weak
+ # references hash to the same value as the original object.
+ # TODO(allenl): Clean up dead functions and their cache keys if the cache
+ # gets large. Right now creating objects with a defunned method, calling
+ # the method, and losing a reference to the object in a loop will leak
+ # memory here.
+ return weakref.ref(arg)
+ except TypeError:
+ return arg
def _deterministic_dict_values(dictionary):
@@ -911,7 +1053,8 @@ class PolymorphicFunction(object):
def __init__(self,
python_function,
name,
- input_signature=None):
+ input_signature=None,
+ attributes=None):
"""Initializes a polymorphic function.
Args:
@@ -920,6 +1063,8 @@ class PolymorphicFunction(object):
input_signature: a possibly nested sequence of `TensorSpec` objects
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
+      attributes: dict, extra keyword arguments that will be added as attributes
+        of the function.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -929,14 +1074,14 @@ class PolymorphicFunction(object):
if isinstance(python_function, functools.partial):
self._python_function = python_function.func
self._args_to_prepend = python_function.args or tuple()
- self._kwds_to_include = python_function.keywords or {}
+ self._kwargs_to_include = python_function.keywords or {}
else:
self._python_function = python_function
self._args_to_prepend = tuple()
- self._kwds_to_include = {}
+ self._kwargs_to_include = {}
self._name = name
self._function_cache = collections.OrderedDict()
- self._variables = []
+ self._function_attributes = attributes or {}
self._lock = threading.Lock()
@@ -971,9 +1116,9 @@ class PolymorphicFunction(object):
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
- def __call__(self, *args, **kwds):
+ def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""
- graph_function, inputs = self._maybe_define_function(*args, **kwds)
+ graph_function, inputs = self._maybe_define_function(args, kwargs)
return graph_function(*inputs)
@property
@@ -981,12 +1126,6 @@ class PolymorphicFunction(object):
"""Returns the wrapped Python function."""
return self._python_function
- # TODO(akshayka): Remove this property.
- @property
- def variables(self):
- """Returns the union of all variables referenced by cached `Function`s`."""
- return self._variables
-
def get_concrete_function(self, *args, **kwargs):
"""Returns a `Function` object specialized to inputs and execution context.
@@ -997,7 +1136,7 @@ class PolymorphicFunction(object):
*args: inputs to specialize on.
**kwargs: inputs to specialize on.
"""
- graph_function, _ = self._maybe_define_function(*args, **kwargs)
+ graph_function, _ = self._maybe_define_function(args, kwargs)
return graph_function
def __get__(self, instance, owner):
@@ -1018,33 +1157,37 @@ class PolymorphicFunction(object):
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
- def _cache_key(self, args, kwds, ctx, graph):
+ def _cache_key(self, args, kwargs):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
- inputs = (args, kwds) if kwds else args
+ inputs = (args, kwargs) if kwargs else args
cache_key = tuple(_encode_arg(arg) for arg in inputs)
else:
- del args, kwds
+ del args, kwargs
cache_key = self._flat_input_signature
- # The graph, or whether we're executing eagerly, should be a part of the
- # cache key so we don't improperly capture tensors such as variables.
- executing_eagerly = ctx.executing_eagerly()
- execution_context = executing_eagerly or graph
+ with ops.init_scope():
+ init_graph = ops.get_default_graph()
+
+ # The graph, or whether we're executing eagerly, should be a part of the
+ # cache key so we don't improperly capture tensors such as variables.
+ executing_eagerly = context.executing_eagerly()
+ execution_context = executing_eagerly or init_graph
+ default_graph = ops.get_default_graph()
# Putting the device in the cache key ensures that call-site device
# annotations are respected.
- device_functions = _get_device_functions(ctx, graph)
+ device_functions = _get_device_functions(context.context(), default_graph)
# `ops.colocate_with` directives translate into `ops.device` directives when
# eager execution is enabled.
- colocation_stack = (None if executing_eagerly else
- tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+ colocation_stack = (() if executing_eagerly else
+ tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
return cache_key + (execution_context, device_functions, colocation_stack)
- def _canonicalize_function_inputs(self, *args, **kwds):
- """Canonicalizes `args` and `kwds`.
+ def _canonicalize_function_inputs(self, *args, **kwargs):
+ """Canonicalizes `args` and `kwargs`.
Canonicalize the inputs to the Python function using its fullargspec. In
particular, we parse the varags and kwargs that this
@@ -1054,28 +1197,28 @@ class PolymorphicFunction(object):
Args:
*args: The varargs this object was called with.
- **kwds: The keyword args this function was called with.
+ **kwargs: The keyword args this function was called with.
Returns:
A canonicalized ordering of the inputs.
Raises:
- ValueError: If a keyword in `kwds` cannot be matched with a positional
+ ValueError: If a keyword in `kwargs` cannot be matched with a positional
argument when an input signature is specified, or when the inputs
do not conform to the input signature.
"""
args = self._args_to_prepend + args
- kwds = dict(kwds, **self._kwds_to_include)
+ kwargs = dict(kwargs, **self._kwargs_to_include)
# Maps from index of arg to its corresponding value, according to `args`
- # and `kwds`; seeded with the default values for the named args that aren't
- # in `args`.
+ # and `kwargs`; seeded with the default values for the named args that
+ # aren't in `args`.
arg_indices_to_values = {
index: default
for index, default in six.iteritems(self._arg_indices_to_default_values)
if index >= len(args)
}
consumed_args = []
- for arg, value in six.iteritems(kwds):
+ for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None)
if index is not None:
arg_indices_to_values[index] = value
@@ -1085,20 +1228,30 @@ class PolymorphicFunction(object):
"function with keyword arguments when "
"input_signature is provided.")
for arg in consumed_args:
- # After this loop, `kwds` will only contain true keyword arguments, as
+ # After this loop, `kwargs` will only contain true keyword arguments, as
# opposed to named arguments called in a keyword-like fashion.
- kwds.pop(arg)
+ kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
+ flat_inputs = nest.flatten(inputs)
+
+ # Check for NumPy arrays in arguments and convert them to Tensors.
+ need_packing = False
+ for index, value in enumerate(flat_inputs):
+ if isinstance(value, np.ndarray):
+ flat_inputs[index] = constant_op.constant(value)
+ need_packing = True
+ if need_packing:
+ inputs = nest.pack_sequence_as(structure=inputs,
+ flat_sequence=flat_inputs)
if self._input_signature is None:
- return inputs, kwds
+ return inputs, kwargs
else:
- assert not kwds
+ assert not kwargs
try:
nest.assert_same_structure(self._input_signature, inputs)
except (ValueError, TypeError):
raise ValueError("Structure of Python function inputs does not match "
"input_signature.")
- flat_inputs = nest.flatten(inputs)
if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
raise ValueError("When input_signature is provided, all inputs to "
"the Python function must be Tensors.")
@@ -1112,25 +1265,27 @@ class PolymorphicFunction(object):
(str(inputs), str(self._input_signature)))
return inputs, {}
- def _maybe_define_function(self, *args, **kwds):
+ def _maybe_define_function(self, args, kwargs):
"""Gets a function for these inputs, defining it if necessary.
+ `args` and `kwargs` can be None if this `PolymorphicFunction` was created
+ with an `input_signature`.
+
Args:
- *args: args for the Python function.
- **kwds: keywords for the Python function.
+ args: The varargs for the Python function.
+ kwargs: The keyword args for the Python function.
Returns:
A graph function corresponding to the input signature implied by args and
- kwds, as well as the inputs that the object should be called with.
+ kwargs, as well as the inputs that the object should be called with.
Raises:
ValueError: If inputs are incompatible with the input signature.
TypeError: If the function inputs include non-hashable objects
"""
-
- args, kwds = self._canonicalize_function_inputs(*args, **kwds)
- cache_key = self._cache_key(args, kwds, context.context(),
- ops.get_default_graph())
+ if self._input_signature is None or args is not None or kwargs is not None:
+ args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
+ cache_key = self._cache_key(args, kwargs)
with self._lock:
try:
graph_function = self._function_cache.get(cache_key, None)
@@ -1141,11 +1296,41 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature))
- self._variables.extend(
- [v for v in graph_function.variables if v not in self._variables])
+ kwargs, self._input_signature),
+ self._function_attributes)
self._function_cache[cache_key] = graph_function
- return graph_function, (args, kwds)
+ return graph_function, [
+ t for t in nest.flatten((args, kwargs))
+ if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
+ ]
+
+
+def register(func, *args, **kwargs):
+ """Register the defun function into the graph.
+
+  This won't actually call the function with the given inputs; it only puts the
+  function definition into the graph. Registering the function with different
+  input parameters will result in multiple versions being registered in the graph.
+
+ Args:
+    func: the PolymorphicFunction instance generated by a @defun decorator.
+ *args: input arguments for the Python function.
+ **kwargs: input keyword arguments for the Python function.
+
+ Returns:
+ a `Function` object specialized to inputs and execution context.
+
+ Raises:
+    ValueError: If the input function is not a defun-wrapped Python function.
+ """
+ if not isinstance(func, PolymorphicFunction):
+ raise ValueError("Only defun function is allowed to be registered. "
+ "Got type: %s" % type(func))
+ concrete_func = func.get_concrete_function(*args, **kwargs)
+ graph = ops.get_default_graph()
+ concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access
+ # TODO(scottzhu): support concrete_func._backward_graph_function in future.
+ return concrete_func
def _validate_signature(signature):
@@ -1271,6 +1456,11 @@ def defun(func=None, input_signature=None):
tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
input signature inferred from `(*args, **kwargs)` and cached for future reuse.
+ NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
+ before being passed to `f`, and are treated as Tensors for caching. This
+ allows a function to be called multiple times with NumPy arrays having
+ different values but the same shape and dtype without re-tracing each time.
+
`tf.contrib.eager.defun` caches graphs for your convenience, letting you
define TensorFlow functions without explicitly specifying their signatures.
However, this policy is conservative and potentially expensive; for example,
@@ -1470,7 +1660,29 @@ def defun(func=None, input_signature=None):
TypeError: If `input_signature` is neither `None` nor a sequence of
`tf.contrib.eager.TensorSpec` objects.
"""
+ return defun_with_attributes(func=func, input_signature=input_signature)
+
+def defun_with_attributes(func=None, input_signature=None, attributes=None):
+ """Compiles a Python function into a callable TensorFlow graph.
+
+ This function supports adding extra function attributes. See detailed
+  documentation in defun(). Currently this is not exposed in the public API
+  since we don't expect users to use attributes directly, and attributes won't
+  work by themselves. This assumption might change in the future.
+
+ Args:
+ func: function to be compiled.
+ input_signature: same as defun()'s input_signature.
+    attributes: A dictionary of arguments which will be added to the function
+      def as attributes. Currently only primitive types are supported as values,
+      and only whitelisted attribute names are allowed. An unwhitelisted
+      attribute name or unsupported value will result in a ValueError.
+
+ Returns:
+    Same as the return value of defun, with attributes added to the function in
+    the graph.
+ """
if input_signature is not None:
_validate_signature(input_signature)
@@ -1482,7 +1694,8 @@ def defun(func=None, input_signature=None):
name = "function"
return tf_decorator.make_decorator(
function,
- PolymorphicFunction(function, name, input_signature=input_signature))
+ PolymorphicFunction(function, name, input_signature=input_signature,
+ attributes=attributes))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1727,9 +1940,9 @@ def automatic_control_dependencies(f):
The wrapped function.
"""
- def wrapper(*args, **kwds):
+ def wrapper(*args, **kwargs):
with AutomaticControlDependencies() as a:
- result = f(*args, **kwds)
+ result = f(*args, **kwargs)
result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
return nest.pack_sequence_as(result, result_flat)
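
A rough usage sketch for the new register() helper documented above (names here are illustrative, not taken from this change): the defun-decorated function is traced for the given inputs and its definition is added to the default graph without ever being executed.

# Illustrative sketch only; assumes graph mode.
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops

@function.defun
def matmul(x, y):
  return math_ops.matmul(x, y)

with ops.Graph().as_default() as g:
  t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  concrete = function.register(matmul, t, t)
  # g now holds the registered function definition even though matmul never ran.
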
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 92254a2c00..59faf967c5 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -21,8 +21,13 @@ import collections
import functools
from multiprocessing.pool import ThreadPool
import sys
+import weakref
+
+import numpy
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python import keras
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -36,12 +41,14 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
@@ -55,6 +62,28 @@ from tensorflow.python.util import compat
from tensorflow.python.util import nest
+class MiniModel(keras_training.Model):
+ """Minimal model for mnist.
+
+ Useful for testing and debugging on slow TPU simulators.
+ """
+
+ def __init__(self):
+ super(MiniModel, self).__init__(name='')
+ self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
+ bias_initializer='ones')
+
+ def call(self, inputs, training=True):
+ return self.fc(inputs)
+
+
+class DefunnedMiniModel(MiniModel):
+
+ @function.defun
+ def call(self, inputs, training=True):
+ return super(DefunnedMiniModel, self).call(inputs, training=training)
+
+
@test_util.with_c_shapes
class FunctionTest(test.TestCase):
@@ -121,8 +150,8 @@ class FunctionTest(test.TestCase):
@function.defun
def f():
- v = resource_variable_ops.ResourceVariable(1.0)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ return self.v.read_value()
self.assertAllEqual(f(), 1.0)
@@ -314,6 +343,7 @@ class FunctionTest(test.TestCase):
def testDefunNumpyArraysConvertedToTensors(self):
def f(x):
+ self.assertIsInstance(x, ops.Tensor)
return x
x = random_ops.random_uniform([2, 2]).numpy()
@@ -327,6 +357,12 @@ class FunctionTest(test.TestCase):
# shouldn't trigger another function definition.
self.assertEqual(len(defined._function_cache), 1)
+ # Test that the numpy array is properly an argument to the graph function.
+ self.assertEqual(1., defined(numpy.ones([])).numpy())
+ self.assertEqual(0., defined(numpy.zeros([])).numpy())
+ self.assertEqual(1., defined(array_ops.ones([])).numpy())
+ self.assertEqual(0., defined(array_ops.zeros([])).numpy())
+
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
@@ -373,9 +409,9 @@ class FunctionTest(test.TestCase):
@function.defun
def tensor_init():
- v = resource_variable_ops.ResourceVariable(
+ self.v = resource_variable_ops.ResourceVariable(
lambda: constant_op.constant(2.0))
- return v.read_value()
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -389,8 +425,8 @@ class FunctionTest(test.TestCase):
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
- v = resource_variable_ops.ResourceVariable(const)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(const)
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -403,10 +439,17 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testVariableInLoopInFunction(self):
@@ -430,10 +473,17 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
with context.graph_mode():
@@ -442,23 +492,46 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# Check that shape inference works while creating the defun
compiled = function.defun(f)
compiled()
+ def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
+ with context.graph_mode():
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(1.0))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(2.0))
+
+ def f():
+ tl, value = list_ops.tensor_list_pop_back(
+ tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+ return tl
+
+ compiled = function.defun(f)
+ output_tensor_list = compiled()
+ _, value = list_ops.tensor_list_pop_back(
+ output_tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+
@test_util.run_in_graph_and_eager_modes
def testDefunForcesResourceVariables(self):
def variable_creator():
- return variables.Variable(0.0).read_value()
+ self.v = variables.Variable(0.0)
+ return self.v.read_value()
+ self.v = None
defined = function.defun(variable_creator)
defined() # Create the variable.
- self.assertEqual(len(defined.variables), 1)
self.assertIsInstance(
- defined.variables[0], resource_variable_ops.ResourceVariable)
+ self.v, resource_variable_ops.ResourceVariable)
def testDefunDifferentiable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -996,6 +1069,7 @@ class FunctionTest(test.TestCase):
with ops.get_default_graph().as_default():
create_variable()
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testLayerInDefun(self):
conv = convolutional.Conv2D(
filters=1,
@@ -1009,7 +1083,34 @@ class FunctionTest(test.TestCase):
x = array_ops.ones([1, 2, 2, 1])
y = model(x)
- self.assertAllEqual([[[[4.0]]]], y.numpy())
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
+
+ # Remove reference cycles in model
+ test_util.dismantle_polymorphic_function(model)
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testDefunKerasModelCall(self):
+ model = MiniModel()
+ model.call = function.defun(model.call)
+
+ x = array_ops.ones([1, 2])
+ y = model(x)
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllEqual([[3.0]], self.evaluate(y))
+
+ # Remove reference cycles in defun.
+ test_util.dismantle_polymorphic_function(model.call)
+ # Break the reference cycle between the MiniModel and the defun:
+ # MiniModel --(through its `call` method)--> PolymorphicFunction
+ # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
+ del model.call
# Note: The ConfigProto below unfortunately only configures graph
# construction. Eager's configuration is controlled in `__main__`.
@@ -1130,13 +1231,11 @@ class FunctionTest(test.TestCase):
defined = function.defun(foo)
x = constant_op.constant([1.0])
- self.assertAllEqual(defined.variables, [])
- _ = defined(x)
- self.assertAllEqual(defined.variables, [v])
+ self.assertEqual(1., self.evaluate(defined(x)))
+ v.assign(2.)
x = constant_op.constant([1.0, 2.0])
- _ = defined(x) # ensure the variables list remains the same
- self.assertAllEqual(defined.variables, [v])
+ self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
def testPythonFunctionWithDefaultArgs(self):
@@ -1492,6 +1591,257 @@ class FunctionTest(test.TestCase):
side_effecting_function.python_function()
self.assertAllEqual(state, [0, 0])
+ def testFunctionWithExtraAttributes(self):
+ @function.defun_with_attributes(attributes={'experimental_1': 'value1',
+ 'experimental_2': 2})
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+
+ def add(x, y):
+ return math_ops.add(x, y)
+ defun_add = function.defun_with_attributes(
+ add, attributes={'experimental_3': True, 'experimental_4': 1.0})
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ sq = matmul(t, t)
+ double = defun_add(t, t)
+ self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+ self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ self.assertRegexpMatches(
+ functions[0].definition.signature.name, '.*matmul.*')
+ attrs = functions[0].definition.attr
+ self.assertEqual(len(attrs), 2)
+ self.assertEqual(attrs['experimental_1'].s, b'value1')
+ self.assertEqual(attrs['experimental_2'].i, 2)
+
+ self.assertRegexpMatches(
+ functions[1].definition.signature.name, '.*add.*')
+ attrs = functions[1].definition.attr
+ self.assertEqual(len(attrs), 2)
+ self.assertEqual(attrs['experimental_3'].b, True)
+ self.assertEqual(attrs['experimental_4'].f, 1.0)
+ # pylint: enable=protected-access
+
+ def testFunctionWithInvalidAttribute(self):
+ @function.defun_with_attributes(attributes={'attr1': 'value1'})
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+
+ with self.assertRaisesRegexp(ValueError,
+ '.*Attribute name is not whitelisted.*'):
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ matmul(t, t)
+
+ @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
+ def add(x, y):
+ return math_ops.add(x, y)
+
+ with self.assertRaisesRegexp(ValueError,
+ '.*Unsupported attribute type.*'):
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ add(t, t)
+
+ def testRegisterFunction(self):
+ @function.defun
+ def add(x, y):
+ return math_ops.add(x, y)
+
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(matmul)
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ function.register(defun_matmul, t, t)
+ function.register(add, t, t)
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ pre_register_matmul_func_name = functions[0].definition.signature.name
+ self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*')
+ pre_register_add_func_name = functions[1].definition.signature.name
+ self.assertRegexpMatches(pre_register_add_func_name, '.*add.*')
+
+ sq = defun_matmul(t, t)
+ double = add(t, t)
+ self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+ self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+ # Make sure the pre-registered function is used, and no other function
+ # is added.
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ called_func_name = functions[0].definition.signature.name
+ self.assertEqual(pre_register_matmul_func_name, called_func_name)
+ called_func_name = functions[1].definition.signature.name
+ self.assertEqual(pre_register_add_func_name, called_func_name)
+
+ def testRegisterFunctionWithInputSignature(self):
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(
+ matmul,
+ input_signature=[
+ tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
+ tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
+ ])
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ function.register(defun_matmul, t, t)
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 1)
+
+ # Test input param shape mismatch
+ t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ with self.assertRaisesRegexp(
+ ValueError, 'Python inputs incompatible with input_signature'):
+ function.register(defun_matmul, t2, t2)
+
+ def testRegisterFunctionWithCache(self):
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(matmul)
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
+ function.register(defun_matmul, t, t)
+ function.register(defun_matmul, t2, t2)
+
+ graph = ops.get_default_graph()
+ # Only one function is registered since the input params are of the same type
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 1)
+
+ def testCallingFunctionWithDifferentVariables(self):
+
+ @function.defun
+ def foo(v):
+ v.assign_add(1.0)
+ return v.read_value()
+
+ v = resource_variable_ops.ResourceVariable(0.0)
+ graph_function = foo.get_concrete_function(v)
+ self.assertEqual(len(graph_function.inputs), 1)
+ self.assertEqual(len(graph_function.captured_inputs), 0)
+
+ self.assertEqual(float(graph_function(v)), 1.0)
+ self.assertEqual(float(graph_function(v)), 2.0)
+
+ w = resource_variable_ops.ResourceVariable(0.0)
+
+ @function.defun
+ def bar(v):
+ del v
+ return constant_op.constant(1.0)
+
+ graph_function = bar.get_concrete_function(v)
+ self.assertEqual(float(graph_function(v)), 1.0)
+ self.assertEqual(float(graph_function(w)), 1.0)
+
+ def testCallingFunctionWithNonTensorsFails(self):
+
+ @function.defun
+ def foo(x):
+ return x
+
+ graph_function = foo.get_concrete_function(constant_op.constant(1.0))
+ with self.assertRaisesRegexp(ValueError, 'All inputs to `Function`s must '
+ 'be Tensors;.*'):
+ graph_function('Not a Tensor.')
+
+ def testSwapImplementationWithGrapplerPlugin(self):
+ rewrites = rewriter_config_pb2.RewriterConfig()
+ # function_optimizer has to be turned off, otherwise it will delete the
+ # registered function if it does not get called.
+ # TODO(scottzhu): Move the ExperimentalImplementationSelector to be called
+ # before function_optimizer in the future.
+ rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF
+ customer_optimizer = rewrites.custom_optimizers.add()
+ customer_optimizer.name = 'ExperimentalImplementationSelector'
+ rewrites.min_graph_nodes = -1
+ graph_options = config_pb2.GraphOptions(
+ rewrite_options=rewrites, build_cost_model=1)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+
+ with context.graph_mode(), self.cached_session(
+ config=config, graph=ops.Graph(), use_gpu=True) as sess:
+
+ @function.defun_with_attributes(
+ attributes={
+ 'experimental_api_implements': 'random_boost',
+ 'experimental_api_preferred_device': 'CPU'
+ })
+ def cpu_boost(x):
+ return math_ops.add(x, 2.0)
+
+ @function.defun_with_attributes(
+ attributes={
+ 'experimental_api_implements': 'random_boost',
+ 'experimental_api_preferred_device': 'GPU'
+ })
+ def gpu_boost(x):
+ return math_ops.add(x, 4.0)
+
+ x = constant_op.constant(1.0)
+
+ function.register(cpu_boost, x)
+ y = gpu_boost(x)
+ y_value = sess.run(y)
+
+ if test.is_gpu_available():
+ self.assertEquals(y_value, 5.0)
+ else:
+ # Grappler falls back to the CPU impl even when called with the GPU function.
+ self.assertEquals(y_value, 3.0)
+
+ def testDefunFunctionSeparateGraphs(self):
+ with context.graph_mode():
+
+ @function.defun
+ def add(x):
+ return x + 5
+
+ @function.defun
+ def maybe_add(x, should_add):
+ if should_add:
+ return add(x)
+ else:
+ return x
+
+ with ops.Graph().as_default():
+ x = constant_op.constant(11)
+ maybe_add(x, True)
+ self.assertEqual(len(maybe_add._function_cache), 1)
+ self.assertEqual(len(add._function_cache), 1)
+
+ maybe_add(x, False)
+ self.assertEqual(len(maybe_add._function_cache), 2)
+ self.assertEqual(len(add._function_cache), 1)
+
+ with ops.Graph().as_default():
+ x = constant_op.constant(11)
+ maybe_add(x, True)
+ self.assertEqual(len(maybe_add._function_cache), 3)
+ self.assertEqual(len(add._function_cache), 2)
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):
@@ -1683,10 +2033,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
value = train()
self.assertEqual(value.numpy(), -1.0)
@@ -1713,10 +2063,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
train()
@@ -1903,6 +2253,13 @@ class AutomaticControlDependenciesTest(test.TestCase):
modify_same_flat(nested_input)
+ def testDecoratedMethodVariableCleanup(self):
+ m = DefunnedMiniModel()
+ m(array_ops.ones([1, 2]))
+ weak_variables = weakref.WeakSet(m.variables)
+ self.assertEqual(2, len(weak_variables))
+ del m
+ self.assertEqual([], list(weak_variables))
if __name__ == '__main__':
ops.enable_eager_execution(
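The function_test.py hunks above add coverage for function.defun_with_attributes and function.register. A minimal sketch of the pattern those tests exercise, assuming the module aliases imported by this test file (function, constant_op, math_ops, ops, context); attribute keys outside the whitelisted experimental_* prefix raise ValueError at graph-construction time:

    from tensorflow.python.eager import context
    from tensorflow.python.eager import function
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import math_ops

    @function.defun_with_attributes(attributes={'experimental_1': 'value1'})
    def matmul(x, y):
      return math_ops.matmul(x, y)

    @function.defun
    def add(x, y):
      return math_ops.add(x, y)

    with context.graph_mode(), ops.Graph().as_default():
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      function.register(add, t, t)  # pre-register the concrete function
      sq = matmul(t, t)             # FunctionDef carries the experimental_1 attr
      double = add(t, t)            # reuses the registered definition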
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f027d107c..5f5af4ab6c 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -23,8 +23,9 @@ import collections
from tensorflow.python import pywrap_tensorflow
-VSpace = collections.namedtuple(
- "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
+VSpace = collections.namedtuple("VSpace", [
+ "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
+])
def imperative_grad(
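The imperative_grad.py hunk above renames the VSpace gradient callables and adds a graph-shape helper, so any caller constructing a VSpace must now supply five fields. A hedged sketch of the updated construction; the placeholder callables below are illustrative only (the real ones are defined in backprop.py and registered with the C++ tape via TFE_Py_RegisterVSpace):

    from tensorflow.python.eager.imperative_grad import VSpace

    def _unimplemented(*args):  # placeholder for illustration only
      raise NotImplementedError

    vspace = VSpace(
        aggregate_fn=_unimplemented,
        num_elements_fn=_unimplemented,
        zeros_fn=_unimplemented,        # previously named "zeros"
        ones_fn=_unimplemented,         # previously named "ones"
        graph_shape_fn=_unimplemented)  # new field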
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index f34ce6af79..5f44bd4fec 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -516,25 +516,13 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
// Getter for `_num_elements`.
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
auto handle = self->handle;
- int n = TFE_TensorHandleNumDims(handle, self->status);
+ int n = TFE_TensorHandleNumElements(handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
return nullptr;
}
- tensorflow::int64 value = 1;
- if (PyErr_Occurred()) return nullptr;
- for (int i = 0; i < n; ++i) {
- int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
- if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
- // Cleanup self->status before returning.
- TF_SetStatus(self->status, TF_OK, "");
- PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
- return nullptr;
- }
- value *= dim;
- }
- return PyLong_FromLongLong(value);
+ return PyLong_FromLongLong(n);
}
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
@@ -777,17 +765,34 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
return reinterpret_cast<PyObject*>(t);
}
-tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return reinterpret_cast<const EagerTensor*>(tensor)->id;
}
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
reinterpret_cast<const EagerTensor*>(tensor)->handle));
}
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
+ const EagerTensor* as_c_eager_tensor =
+ reinterpret_cast<const EagerTensor*>(tensor);
+ tensorflow::int64 result = TFE_TensorHandleNumElements(
+ as_c_eager_tensor->handle, as_c_eager_tensor->status);
+
+ if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
+ PyExc_ValueError)) {
+ // Cleanup status before returning.
+ TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
+ return -1;
+ }
+
+ return result;
+}
+
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
if (!PyType_Check(base_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index bc042eb19e..4eaa1ba536 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -21,8 +21,9 @@ limitations under the License.
#include "tensorflow/python/lib/core/numpy.h"
bool EagerTensor_CheckExact(const PyObject* o);
-tensorflow::int64 EagerTensor_id(const PyObject* tensor);
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
namespace tensorflow {
TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 46dcf7c8a8..159b1c1218 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -860,7 +861,7 @@ static tensorflow::int64 MakeInt(PyObject* integer) {
static tensorflow::int64 FastTensorId(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_id(tensor);
+ return PyEagerTensor_ID(tensor);
}
PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
if (id_field == nullptr) {
@@ -873,7 +874,7 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_dtype(tensor);
+ return PyEagerTensor_Dtype(tensor);
}
PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
if (dtype_field == nullptr) {
@@ -889,12 +890,239 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
return static_cast<tensorflow::DataType>(id);
}
+class PyTapeTensor {
+ public:
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ const tensorflow::TensorShape& shape)
+ : id_(id), dtype_(dtype), shape_(shape) {}
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ PyObject* shape)
+ : id_(id), dtype_(dtype), shape_(shape) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ PyTapeTensor(const PyTapeTensor& other) {
+ id_ = other.id_;
+ dtype_ = other.dtype_;
+ shape_ = other.shape_;
+ if (shape_.index() == 1) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ }
+
+ ~PyTapeTensor() {
+ if (shape_.index() == 1) {
+ Py_DECREF(absl::get<1>(shape_));
+ }
+ }
+ PyObject* GetShape() const;
+ PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
+ tensorflow::int64 GetID() const { return id_; }
+
+ private:
+ tensorflow::int64 id_;
+ tensorflow::DataType dtype_;
+ absl::variant<tensorflow::TensorShape, PyObject*> shape_;
+};
+
+class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
+ public:
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+ Py_INCREF(py_vspace_);
+ }
+
+ tensorflow::Status Initialize() {
+ num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+ if (num_elements_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+ if (aggregate_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
+ if (zeros_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
+ if (ones_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
+ if (graph_shape_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ ~PyVSpace() override {
+ Py_XDECREF(num_elements_);
+ Py_XDECREF(aggregate_fn_);
+ Py_XDECREF(zeros_fn_);
+ Py_XDECREF(ones_fn_);
+ Py_XDECREF(graph_shape_fn_);
+
+ Py_DECREF(py_vspace_);
+ }
+
+ tensorflow::int64 NumElements(PyObject* tensor) const final {
+ if (EagerTensor_CheckExact(tensor)) {
+ return PyEagerTensor_NumElements(tensor);
+ }
+ PyObject* arglist =
+ Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+ PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ Py_DECREF(arglist);
+ if (result == nullptr) {
+ // The caller detects whether a python exception has been raised.
+ return -1;
+ }
+ tensorflow::int64 r = MakeInt(result);
+ Py_DECREF(result);
+ return r;
+ }
+
+ PyObject* AggregateGradients(
+ tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
+ PyObject* list = PyList_New(gradient_tensors.size());
+ for (int i = 0; i < gradient_tensors.size(); ++i) {
+ // Note: stealing a reference to the gradient tensors.
+ CHECK(gradient_tensors[i] != nullptr);
+ CHECK(gradient_tensors[i] != Py_None);
+ PyList_SET_ITEM(list, i,
+ reinterpret_cast<PyObject*>(gradient_tensors[i]));
+ }
+ PyObject* arglist = Py_BuildValue("(O)", list);
+ CHECK(arglist != nullptr);
+ PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+ Py_DECREF(arglist);
+ Py_DECREF(list);
+ return result;
+ }
+
+ void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
+
+ PyObject* Zeros(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return reinterpret_cast<PyObject*>(result);
+ }
+
+ PyObject* Ones(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return result;
+ }
+
+ PyObject* GraphShape(PyObject* tensor) const {
+ PyObject* arg_list = Py_BuildValue("(O)", tensor);
+ PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
+ Py_DECREF(arg_list);
+ return result;
+ }
+
+ tensorflow::Status CallBackwardFunction(
+ PyBackwardFunction* backward_function,
+ tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+ std::vector<PyObject*>* result) const final {
+ PyObject* grads = PyTuple_New(output_gradients.size());
+ for (int i = 0; i < output_gradients.size(); ++i) {
+ if (output_gradients[i] == nullptr) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(grads, i, Py_None);
+ } else {
+ PyTuple_SET_ITEM(grads, i,
+ reinterpret_cast<PyObject*>(output_gradients[i]));
+ }
+ }
+ PyObject* py_result = (*backward_function)(grads);
+ Py_DECREF(grads);
+ if (py_result == nullptr) {
+ return tensorflow::errors::Internal("gradient function threw exceptions");
+ }
+ result->clear();
+ PyObject* seq =
+ PySequence_Fast(py_result, "expected a sequence of gradients");
+ if (seq == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "gradient function did not return a list");
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ VLOG(1) << "Gradient length is " << len;
+ result->reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (item == Py_None) {
+ result->push_back(nullptr);
+ } else {
+ Py_INCREF(item);
+ result->push_back(item);
+ }
+ }
+ Py_DECREF(seq);
+ Py_DECREF(py_result);
+ return tensorflow::Status::OK();
+ }
+
+ void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
+
+ private:
+ PyObject* py_vspace_;
+
+ PyObject* num_elements_;
+ PyObject* aggregate_fn_;
+ PyObject* zeros_fn_;
+ PyObject* ones_fn_;
+ PyObject* graph_shape_fn_;
+};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+ if (py_vspace != nullptr) {
+ delete py_vspace;
+ }
+
+ py_vspace = new PyVSpace(e);
+ auto status = py_vspace->Initialize();
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ delete py_vspace;
+ return nullptr;
+ }
+
+ Py_RETURN_NONE;
+}
+
+PyObject* PyTapeTensor::GetShape() const {
+ if (shape_.index() == 0) {
+ auto& shape = absl::get<0>(shape_);
+ PyObject* py_shape = PyTuple_New(shape.dims());
+ for (int i = 0; i < shape.dims(); ++i) {
+ PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+ }
+
+ return py_shape;
+ }
+
+ return py_vspace->GraphShape(absl::get<1>(shape_));
+}
+
class GradientTape
- : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
+ : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
public:
explicit GradientTape(bool persistent, bool watch_accessed_variables)
- : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent),
+ : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor>(persistent),
watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
@@ -1175,24 +1403,41 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
-static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+bool ListContainsNone(PyObject* list) {
+ if (list == Py_None) return true;
+ tensorflow::Safe_PyObjectPtr seq(
+ PySequence_Fast(list, "expected a sequence"));
+ if (seq == nullptr) {
+ return false;
+ }
+
+ int len = PySequence_Size(list);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
+ if (item == Py_None) return true;
+ }
+
+ return false;
+}
+
+static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
- tensorflow::int64 id = EagerTensor_id(tensor);
+ tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::TensorShape tensor_shape;
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype,
- tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
} else {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
+ return PyTapeTensor(id, t->handle->dtype, tensor_shape);
}
}
tensorflow::int64 id = FastTensorId(tensor);
if (PyErr_Occurred()) {
- return tensorflow::eager::TapeTensor{
- id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
@@ -1200,16 +1445,21 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
Py_DECREF(dtype_enum);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
static char _shape_tuple[] = "_shape_tuple";
PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
+
+ if (ListContainsNone(shape_tuple)) {
+ return PyTapeTensor(id, dtype, tensor);
+ }
+
auto l = MakeIntList(shape_tuple);
Py_DECREF(shape_tuple);
// Replace -1, which represents accidental Nones which can occur in graph mode
@@ -1220,7 +1470,7 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
}
}
tensorflow::TensorShape shape(l);
- return tensorflow::eager::TapeTensor{id, dtype, shape};
+ return PyTapeTensor(id, dtype, shape);
}
std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
@@ -1286,7 +1536,7 @@ void TapeSetRecordOperation(
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
- std::vector<tensorflow::eager::TapeTensor> output_info;
+ std::vector<PyTapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
int len = PySequence_Size(output_tensors);
@@ -1362,173 +1612,6 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
}
}
-class PyVSpace
- : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
- public:
- explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
- Py_INCREF(py_vspace_);
- }
-
- tensorflow::Status Initialize() {
- num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
- if (num_elements_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
- if (aggregate_fn_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
- if (zeros_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- ones_ =
- PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
- if (ones_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- return tensorflow::Status::OK();
- }
-
- ~PyVSpace() override {
- Py_XDECREF(num_elements_);
- Py_XDECREF(aggregate_fn_);
- Py_XDECREF(zeros_);
- Py_XDECREF(ones_);
-
- Py_DECREF(py_vspace_);
- }
-
- tensorflow::int64 NumElements(PyObject* tensor) const final {
- PyObject* arglist =
- Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
- PyObject* result = PyEval_CallObject(num_elements_, arglist);
- tensorflow::int64 r = MakeInt(result);
- Py_DECREF(result);
- Py_DECREF(arglist);
- return r;
- }
-
- PyObject* AggregateGradients(
- tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
- PyObject* list = PyList_New(gradient_tensors.size());
- for (int i = 0; i < gradient_tensors.size(); ++i) {
- // Note: stealing a reference to the gradient tensors.
- CHECK(gradient_tensors[i] != nullptr);
- CHECK(gradient_tensors[i] != Py_None);
- PyList_SET_ITEM(list, i,
- reinterpret_cast<PyObject*>(gradient_tensors[i]));
- }
- PyObject* arglist = Py_BuildValue("(O)", list);
- CHECK(arglist != nullptr);
- PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
- Py_DECREF(arglist);
- Py_DECREF(list);
- return result;
- }
-
- void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
-
- PyObject* Zeros(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(zeros_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return reinterpret_cast<PyObject*>(result);
- }
-
- PyObject* Ones(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(ones_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return result;
- }
-
- tensorflow::Status CallBackwardFunction(
- PyBackwardFunction* backward_function,
- tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
- std::vector<PyObject*>* result) const final {
- PyObject* grads = PyTuple_New(output_gradients.size());
- for (int i = 0; i < output_gradients.size(); ++i) {
- if (output_gradients[i] == nullptr) {
- Py_INCREF(Py_None);
- PyTuple_SET_ITEM(grads, i, Py_None);
- } else {
- PyTuple_SET_ITEM(grads, i,
- reinterpret_cast<PyObject*>(output_gradients[i]));
- }
- }
- PyObject* py_result = (*backward_function)(grads);
- Py_DECREF(grads);
- if (py_result == nullptr) {
- return tensorflow::errors::Internal("gradient function threw exceptions");
- }
- result->clear();
- PyObject* seq =
- PySequence_Fast(py_result, "expected a sequence of gradients");
- if (seq == nullptr) {
- return tensorflow::errors::InvalidArgument(
- "gradient function did not return a list");
- }
- int len = PySequence_Fast_GET_SIZE(seq);
- VLOG(1) << "Gradient length is " << len;
- result->reserve(len);
- for (int i = 0; i < len; ++i) {
- PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
- if (item == Py_None) {
- result->push_back(nullptr);
- } else {
- Py_INCREF(item);
- result->push_back(item);
- }
- }
- Py_DECREF(seq);
- Py_DECREF(py_result);
- return tensorflow::Status::OK();
- }
-
- void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
-
- private:
- PyObject* py_vspace_;
-
- PyObject* num_elements_;
- PyObject* aggregate_fn_;
- PyObject* zeros_;
- PyObject* ones_;
-};
-PyVSpace* py_vspace = nullptr;
-
-PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
- if (py_vspace != nullptr) {
- delete py_vspace;
- }
-
- py_vspace = new PyVSpace(e);
- auto status = py_vspace->Initialize();
- if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- delete py_vspace;
- return nullptr;
- }
-
- Py_RETURN_NONE;
-}
-
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
@@ -1740,6 +1823,9 @@ PyObject* MaybeGetDTypeForAttr(const string& attr,
Py_RETURN_NONE;
}
+// TODO(agarwal): use an automatic mechanism for handling None arguments to
+// gradient functions.
+
// Returns a pair where the first value of the pair indicates whether or not all
// outputs are unused. If the first value is false, the second value is a
// set that identifies which of the output indices are unused.
@@ -2565,13 +2651,18 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
int num_retvals = 0;
for (int i = 0; i < op_def->output_arg_size(); i++) {
const auto& output_arg = op_def->output_arg(i);
+ int delta = 1;
if (!output_arg.number_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.number_attr()];
+ delta = attr_list_sizes[output_arg.number_attr()];
} else if (!output_arg.type_list_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.type_list_attr()];
- } else {
- num_retvals++;
+ delta = attr_list_sizes[output_arg.type_list_attr()];
+ }
+ if (delta < 0) {
+ RaiseFallbackException(
+ "Attributes suggest that the size of an output list is less than 0");
+ return nullptr;
}
+ num_retvals += delta;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
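In the new PyVSpace above, Zeros and Ones call the registered zeros_fn/ones_fn with the shape (a tuple of ints, or whatever graph_shape_fn returned) and the dtype as its integer enum value, while GraphShape hands the tensor itself to graph_shape_fn whenever the recorded static shape contained None. A hedged sketch of Python callables compatible with that calling convention (illustrative; not necessarily the implementations backprop.py registers):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    def zeros_fn(shape, dtype_enum):
      # `shape` is a tuple of ints or a dynamic shape tensor; `dtype_enum` is
      # the DataType enum integer produced by PyTapeTensor::GetDType.
      return array_ops.zeros(shape, dtype=dtypes.as_dtype(dtype_enum))

    def ones_fn(shape, dtype_enum):
      return array_ops.ones(shape, dtype=dtypes.as_dtype(dtype_enum))

    def graph_shape_fn(tensor):
      # Called for tensors whose _shape_tuple() contains None; returns a
      # dynamic shape usable by zeros_fn/ones_fn above.
      return array_ops.shape(tensor)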
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index fd8ab695b8..669fa08488 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import core
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -123,8 +124,8 @@ class Tests(test.TestCase):
def testFastpathExecute_MixedPrecisionVariableTapeWrite(self):
ctx = context.context()
with backprop.GradientTape(persistent=True) as tape:
- a_2_by_2 = constant_op.constant(
- [[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32)
+ a_2_by_2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]],
+ dtype=dtypes.float32)
a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16)
m1 = resource_variable_ops.ResourceVariable(a_2_by_2)
m2 = resource_variable_ops._MixedPrecisionVariable(
@@ -233,6 +234,26 @@ class Tests(test.TestCase):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
ctx_handle, None, [], a_2_by_2)
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastPathExecute_InvalidAttributes(self):
+ split_dim = constant_op.constant(0, dtype=dtypes.int32)
+ value = constant_op.constant([0, 1, 2, 3], dtype=dtypes.float32)
+ ctx = context.context()
+ ctx_handle = ctx._handle
+ with self.assertRaises(core._FallbackException):
+ pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
+ "Split", None, None, split_dim,
+ value, "num_split", -1)
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testInvalidNumOutputs(self):
+ with self.assertRaisesRegexp(
+ Exception,
+ "Value for attr 'num_split' of -1 must be at least minimum 1"):
+ array_ops.split(value=[1, 2, 3], num_or_size_splits=-1)
+
if __name__ == "__main__":
test.main()
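The delta < 0 guard added to TFE_Py_FastPathExecute_C surfaces in Python as core._FallbackException, which generated op wrappers catch in order to retry the slower execution path. A hedged sketch of that guard, reusing the Split arguments from testFastPathExecute_InvalidAttributes above (the fallback body is elided; real generated code dispatches to the op's eager-fallback helper):

    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.eager import context
    from tensorflow.python.eager import core
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes

    ctx = context.context()
    split_dim = constant_op.constant(0, dtype=dtypes.int32)
    value = constant_op.constant([0, 1, 2, 3], dtype=dtypes.float32)

    try:
      outputs = pywrap_tensorflow.TFE_Py_FastPathExecute(
          ctx._handle, ctx.device_name, "Split", None, None,
          split_dim, value, "num_split", -1)
    except core._FallbackException:
      # The fast path cannot size a negative output list; wrappers fall back
      # here and let the slow path report the real validation error.
      outputs = None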