author     Cao Zongyan <zongyan.cao@alibaba-inc.com>   2018-09-26 11:54:30 +0800
committer  Cao Zongyan <zongyan.cao@alibaba-inc.com>   2018-09-26 11:54:30 +0800
commit     35174f46b973c66a2e6894a12b3018d60e8414ec (patch)
tree       5bdae0172159bc02ec3a470722bf959b14dd47ba /tensorflow/python/eager
parent     f0886f7269de900d226455d4831722f6fc94a71b (diff)
parent     6666516f390f125ed70ddbd4e6f89b83d953c408 (diff)
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/BUILD                | 36
-rw-r--r-- | tensorflow/python/eager/backprop.py          | 43
-rw-r--r-- | tensorflow/python/eager/backprop_test.py     | 12
-rw-r--r-- | tensorflow/python/eager/def_function.py      | 235
-rw-r--r-- | tensorflow/python/eager/def_function_test.py | 87
-rw-r--r-- | tensorflow/python/eager/function.py          | 393
-rw-r--r-- | tensorflow/python/eager/function_test.py     | 409
-rw-r--r-- | tensorflow/python/eager/imperative_grad.py   | 5
-rw-r--r-- | tensorflow/python/eager/pywrap_tensor.cc     | 41
-rw-r--r-- | tensorflow/python/eager/pywrap_tensor.h      | 5
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc    | 473
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_test.py   | 25
12 files changed, 1404 insertions, 360 deletions
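Before the full diff, a quick orientation on the largest addition: the new def_function.py introduces a prototype `def_function` decorator (the "functions, not sessions" design) that traces the Python function into a graph function and only permits variable creation on the first call, wrapping initialization in a `tf.cond` on the variables' initialized status. The sketch below mirrors the usage exercised by the new def_function_test.py at this revision; the internal import paths and the API are those of this commit only (it is explicitly a prototype), so treat it as illustrative rather than a stable interface.

```python
# Illustrative sketch based on testCorrectVariableCreation in the
# def_function_test.py added by this commit (internal TF import paths,
# prototype API at this revision).
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables

ops.enable_eager_execution()

state = []

@def_function.def_function
def fn(x):
  # Variables may only be created during the first call's trace;
  # subsequent calls must reuse them.
  if not state:
    state.append(variables.Variable(2.0))
  return state[0] * x

print(fn(constant_op.constant(1.0)))  # 2.0
print(fn(constant_op.constant(3.0)))  # 6.0
```

After the first call, the decorator swaps in a second traced function whose variable creator scope raises, which is why the added tests expect a `ValueError` from any function that tries to create a fresh variable on every invocation.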
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 85da1baaf0..d3d997e6df 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -17,7 +17,10 @@ cc_library( "pywrap_tensor.h", "pywrap_tfe.h", ], - visibility = ["//tensorflow:internal"], + visibility = [ + "//learning/deepmind/courier:__pkg__", + "//tensorflow:internal", + ], deps = [ "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", @@ -34,6 +37,7 @@ cc_library( "//tensorflow/python:safe_ptr", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", + "@com_google_absl//absl/types:variant", ], ) @@ -45,6 +49,7 @@ py_library( ":backprop", ":context", ":core", + ":def_function", ":execute", ":function", ":graph_only_ops", @@ -146,6 +151,7 @@ cuda_py_test( "//tensorflow/python:clip_ops", "//tensorflow/python:init_ops", "//tensorflow/python:layers", + "//tensorflow/python:list_ops", "//tensorflow/python:math_ops", "//tensorflow/python:resource_variable_ops", ], @@ -345,6 +351,7 @@ py_test( deps = [ ":backprop", ":context", + ":core", ":test", "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", @@ -377,3 +384,30 @@ cuda_py_test( "optonly", # The test is too slow in non-opt mode ], ) + +py_library( + name = "def_function", + srcs = ["def_function.py"], + srcs_version = "PY2AND3", + deps = [ + ":context", + ":function", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/training/checkpointable:base", + ], +) + +py_test( + name = "def_function_test", + srcs = ["def_function_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":def_function", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + ], +) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index be392c7a0f..78f3198011 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -120,27 +120,6 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function) -_tracing = False - - -# TODO(agarwal): use an automatic mechanism for handling None arguments to -# gradient functions. -# Some gradient functions can accept None arguments for gradients. The following -# maps the operation name to the indices at which the corresponding gradient -# function can accept None values. -# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values -# during backprop. However the gradient function uses only the first of those -# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4], -# indicates that only the gradient corresponding to index 0 is used, and the -# gradient values at indices 1-4 are ignored (and hence can be None). The -# backprop algorithm can then leverage this by not constructing zeros to -# pass for those indices. 
-_grad_fn_accepts_none_for_indices = { - "SoftmaxCrossEntropyWithLogits": [1], - "FusedBatchNorm": [1, 2, 3, 4] -} - - def _record_gradient(op_name, inputs, attrs, results, name): return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs, results, name) @@ -585,7 +564,10 @@ def _aggregate_grads(gradients): def _num_elements(grad): """The number of elements in the `grad` tensor.""" if isinstance(grad, ops.Tensor): - return functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access + shape_tuple = grad._shape_tuple() # pylint: disable=protected-access + if shape_tuple is None or None in shape_tuple: + return 0 + return functools.reduce(operator.mul, shape_tuple, 1) if isinstance(grad, ops.IndexedSlices): return functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access raise ValueError("`grad` not a Tensor or IndexedSlices.") @@ -629,8 +611,9 @@ def _ones(shape, dtype): _default_vspace = imperative_grad.VSpace( num_elements_fn=_num_elements, aggregate_fn=_aggregate_grads, - zeros=_zeros, - ones=_ones) + zeros_fn=_zeros, + ones_fn=_ones, + graph_shape_fn=gen_array_ops.shape) pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace) @@ -648,8 +631,8 @@ class GradientTape(object): Operations are recorded if they are executed within this context manager and at least one of their inputs is being "watched". - Trainable variables (created by `tf.Variable` or `tf.get_variable`, - trainable=True is default in both cases) are automatically watched. Tensors + Trainable variables (created by `tf.Variable` or `tf.get_variable`, where + `trainable=True` is default in both cases) are automatically watched. Tensors can be manually watched by invoking the `watch` method on this context manager. @@ -669,6 +652,7 @@ class GradientTape(object): ```python x = tf.constant(3.0) with tf.GradientTape() as g: + g.watch(x) with tf.GradientTape() as gg: gg.watch(x) y = x * x @@ -745,7 +729,9 @@ class GradientTape(object): self._persistent = persistent self._watch_accessed_variables = watch_accessed_variables self._recording = False - context.context().start_step() + self._created_eagerly = context.executing_eagerly() + if self._created_eagerly: + context.context().start_step() def __enter__(self): """Enters a context inside which operations are recorded on this tape.""" @@ -775,7 +761,8 @@ class GradientTape(object): self._recording = False def __del__(self): - context.context().end_step() + if self._created_eagerly: + context.context().end_step() def watch(self, tensor): """Ensures that `tensor` is being traced by this tape. 
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index f938ed5df8..32731747b7 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -1022,6 +1022,18 @@ class BackpropTest(test.TestCase): resource_variable_ops.ResourceVariable(2.0)) self.assertAllEqual(gradients_constants, gradients_variables) + def testUnknownShapes(self): + with context.graph_mode(): + with backprop.GradientTape() as tape: + a = array_ops.placeholder(dtype=dtypes.float32, shape=None) + tape.watch(a) + b = a**3 + + db_da = tape.gradient(b, a) + + with self.cached_session() as sess: + self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0})) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py new file mode 100644 index 0000000000..8dcacd5c99 --- /dev/null +++ b/tensorflow/python/eager/def_function.py @@ -0,0 +1,235 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=unidiomatic-typecheck +"""Prototype decorator for defining graph-mode functions with eager semantics.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training.checkpointable import base as checkpointable + + +class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable): + """Variable which does not lift its initializer out of function context. + + Instances of this variable, when created, build a graph which runs their + initializer inside a tf.cond(is_initialized) block. + + This can only be created inside a defun called from (eventually) eager + mode. That is, non-function-building graphs are not supported. + """ + + def __init__(self, # pylint: disable=super-init-not-called + initial_value=None, + trainable=True, + caching_device=None, + name=None, + dtype=None, + constraint=None, + **unused_kwargs): + """Creates a variable. + + Args: + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. + (Note that initializer functions from init_ops.py must first be bound + to a shape before being used here.) + trainable: If `True`, GradientTapes automatically watch uses of this + Variable. 
+ caching_device: Optional device string or function describing where the + Variable should be cached for reading. Defaults to the Variable's + device. If not `None`, caches on another device. Typical use is to + cache on the device where the Ops using the Variable reside, to + deduplicate copying through `Switch` and other conditional statements. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + dtype: If set, initial_value will be converted to the given type. + If None, either the datatype will be kept (if initial_value is + a Tensor) or float32 will be used (if it is a Python object convertible + to a Tensor). + constraint: An optional projection function to be applied to the variable + after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected Tensor representing the value of the + variable and return the Tensor for the projected value + (which must have the same shape). Constraints are not safe to + use when doing asynchronous distributed training. + + Raises: + ValueError: If the initial value is not specified, or does not have a + shape and `validate_shape` is `True`. + RuntimeError: If called outside of a function definition. + """ + if context.executing_eagerly(): + raise RuntimeError( + "UnliftedInitializerVariable should not be created " + "outside of functions.") + with ops.init_scope(): + if not context.executing_eagerly(): + raise RuntimeError( + "UnliftedInitializerVariable does not support legacy graph mode.") + self._in_graph_mode = False + if initial_value is None: + raise ValueError("initial_value must be specified.") + init_from_fn = callable(initial_value) + + if constraint is not None and not callable(constraint): + raise ValueError("The `constraint` argument must be a callable.") + + if isinstance(initial_value, checkpointable.CheckpointInitialValue): + self._maybe_initialize_checkpointable() + self._update_uid = initial_value.checkpoint_position.restore_uid + initial_value = initial_value.wrapped_value + + self._trainable = trainable + self._save_slice_info = None + self._initial_value = None + self._initializer_op = None + self._is_initialized_op = None + self._graph_element = None + self._cached_value = None + # Store the graph key so optimizers know how to only retrieve variables from + # this graph. Guaranteed to be the same as the eager graph_key. + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + with ops.name_scope(name, "Variable", [] + if init_from_fn else [initial_value]) as name: + # pylint: disable=protected-access + with ops.init_scope(): + assert context.executing_eagerly() + shared_name = ops._name_from_scope_name(name) + shared_name = "%s_%d" % (shared_name, ops.uid()) + # Use attr_scope and device(None) to simulate the behavior of + # colocate_with when the variable we want to colocate with doesn't + # yet exist. 
+ with ops.name_scope("Initializer"), ops.device(None): + initial_value = ops.convert_to_tensor( + initial_value() if init_from_fn else initial_value, + name="initial_value", dtype=dtype) + with ops.init_scope(): + self._handle = resource_variable_ops.eager_safe_variable_handle( + shape=initial_value.get_shape(), + dtype=initial_value.dtype.base_dtype, + shared_name=shared_name, + name=name, + graph_mode=False) + self._shape = initial_value.shape + self._unique_id = shared_name + self._handle_name = shared_name + ":0" + self._dtype = initial_value.dtype.base_dtype + self._constraint = constraint + assert initial_value is not None + def assign_fn(): + with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): + resource_variable_ops.assign_variable_op( + self._handle, + initial_value, + name=n) + # Returning values to keep tf.cond happy. + return ops.convert_to_tensor(1) + def not_assign_fn(): + return ops.convert_to_tensor(0) + # Note: this cond is always guaranteed to run because we're inside a defun + # which will insert automatic control dependencies. + control_flow_ops.cond( + resource_variable_ops.var_is_initialized_op(self._handle), + not_assign_fn, assign_fn) + + # After the handle has been created, set up a way to clean it up when + # executing eagerly. We'll hold the only reference to the deleter, so that + # when this object is garbage collected the deleter will be too. This + # means ResourceVariables can be part of reference cycles without those + # cycles being uncollectable. + self._handle_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._handle, handle_device=self._handle.device) + self._cached_shape_as_list = None + + +def _defun_with_scope(scope, fn): + + def wrapped_fn(*args, **kwds): + with variable_scope.variable_creator_scope(scope): + return fn(*args, **kwds) + + return function.defun(wrapped_fn) + + +def def_function(fn): + """Defines a function as per the "functions, not sessions" document.""" + + # Wrapping the values in lists to bypass python's lack of way to mutate + # symbols from an outer scope. + first_call = [True] + function_to_call = [] + + # TODO(apassos) represent this as an object and not as a closure. + def decorated_fn(*args, **kwds): + """Graph function for fn.""" + if not first_call[0]: + return function_to_call[0](*args, **kwds) + + first_call[0] = False + created_variables = [] + + def variable_creator_scope(unused_next_creator, **kwds): + """Creates UnliftedInitializerVariables and saves references to them.""" + v = UnliftedInitializerVariable(**kwds) + created_variables.append(v) + return v + + first_graph_function = _defun_with_scope(variable_creator_scope, fn) + + # Force the definition of the function for these arguments + first_concrete = first_graph_function.get_concrete_function(*args, **kwds) + + def invalid_creator_scope(*unused_args, **unused_kwds): + """Disables variable creation.""" + raise ValueError( + "def_function-decorated function tried to create " + "variables on second call.") + + second_graph_function = _defun_with_scope(invalid_creator_scope, fn) + + function_to_call.append(second_graph_function) + if not created_variables: + # Note: this retracing might be unnecessary, but running the function + # forever in the scope which disallows variable creation is safer than not + # doing so. 
+ return second_graph_function(*args, **kwds) + + def fn_with_cond(*inner_args, **inner_kwds): + """Conditionally runs initialization if it's needed.""" + condition = True + for variable in created_variables: + condition = condition and resource_variable_ops.var_is_initialized_op( + variable.handle) + # We want to call second_graph_function if possible because it avoids + # recomputing potentially expensive initializers. + return control_flow_ops.cond( + condition, + lambda: second_graph_function(*inner_args, **inner_kwds), + lambda: first_concrete(*inner_args, **inner_kwds)) + + return function.defun(fn_with_cond)(*args, **kwds) + + return decorated_fn diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py new file mode 100644 index 0000000000..804436c4bb --- /dev/null +++ b/tensorflow/python/eager/def_function_test.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class DefFunctionTest(test.TestCase): + + def testNoVariables(self): + + @def_function.def_function + def fn(x): + return 2 * x + + self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0) + + def testFailIfVariablesAreCreatedMoreThanOnce(self): + + @def_function.def_function + def fn(x): + return variables.Variable(1.0) + x + + with self.assertRaises(ValueError): + fn(1.0) + + def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self): + state = [] + + @def_function.def_function + def fn(x): + state.append(variables.Variable(1.0)) + return state[-1] + x + + with self.assertRaises(ValueError): + fn(1.0) + + def testCorrectVariableCreation(self): + + state = [] + + @def_function.def_function + def fn(x): + if not state: + state.append(variables.Variable(2.0)) + return state[0] * x + + self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0) + self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0) + + def testVariableInitializerNotConstant(self): + + state = [] + + @def_function.def_function + def fn(x): + if not state: + state.append(variables.Variable(2.0 * x)) + return state[0] * x + + self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0) + self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 03f12139f6..b28befeb62 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -23,10 +23,12 @@ import collections import functools import sys import threading +import 
weakref import numpy as np import six +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context @@ -34,6 +36,7 @@ from tensorflow.python.eager import execute from tensorflow.python.eager import tape from tensorflow.python.eager.graph_only_ops import graph_placeholder from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops @@ -59,23 +62,47 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access -def _create_substitute_placeholder(value, name, dtype=None): +# TODO(scottzhu): Update this to allow arbitrary attribute names in future. +WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_" + + +def _create_substitute_placeholder(value, name=None, dtype=None): """Creates a placeholder for `value` and propagates shape info to it.""" # Note: setting ops.control_dependencies(None) ensures we always put # capturing placeholders outside of any control flow context. with ops.control_dependencies(None): placeholder = graph_placeholder( dtype=dtype or value.dtype, shape=value.shape, name=name) - if placeholder.dtype == dtypes_module.resource: - if isinstance(value, ops.EagerTensor): - handle_data = value._handle_data # pylint: disable=protected-access + _copy_handle_data(value, placeholder) + return placeholder + + +def _copy_handle_data(source_t, target_t): + """Copies HandleData for variant and resource type tensors if available. + + The CppShapeInferenceResult::HandleData proto contains information about the + shapes and types of the element tensors of resource/variant type tensors. + We need to copy this across function boundaries, i.e., when capturing a + placeholder or when returning a function tensor as output. If we don't do this + the element tensors will have unknown shapes, e.g., if a TensorList variant + tensor is captured as a placeholder, elements popped from that list would have + unknown shape. + + Args: + source_t: The tensor to copy HandleData from. + target_t: The tensor to copy HandleData to. + """ + if (target_t.dtype == dtypes_module.resource or + target_t.dtype == dtypes_module.variant): + if isinstance(source_t, ops.EagerTensor): + handle_data = source_t._handle_data # pylint: disable=protected-access else: - handle_data = resource_variable_ops.get_resource_handle_data(value) + handle_data = resource_variable_ops.get_resource_handle_data(source_t) if handle_data is not None and handle_data.is_set: # pylint: disable=protected-access - pywrap_tensorflow.SetResourceHandleShapeAndType( - placeholder.graph._c_graph, placeholder._as_tf_output(), - handle_data.SerializeToString()) + pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph, + target_t._as_tf_output(), + handle_data.SerializeToString()) # pylint: enable=protected-access # Ensure that shapes and dtypes are propagated. 
shapes, types = zip(*[(pair.shape, pair.dtype) @@ -84,12 +111,10 @@ def _create_substitute_placeholder(value, name, dtype=None): shapes = [[d.size for d in s.dim] if not s.unknown_rank else None for s in shapes] pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( - placeholder._op._graph._c_graph, # pylint: disable=protected-access - placeholder._as_tf_output(), # pylint: disable=protected-access + target_t._op._graph._c_graph, # pylint: disable=protected-access + target_t._as_tf_output(), # pylint: disable=protected-access shapes, ranks, types) - return placeholder - def _get_device_functions(ctx, graph): """Returns a tuple of device functions representing the device stack.""" @@ -99,6 +124,44 @@ def _get_device_functions(ctx, graph): return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access +def _parse_func_attrs(attributes): + """Convert the keyword arguments into function_def attributes. + + Currently only support primitive types: bool, int, float and string. + + Args: + attributes: the dictionary of attributes. + Returns: + A dict of attributes where the key is the name of attribute and the value + is the AttrValue proto. + Raises: + ValueError: If the kwargs contains unwhitelisted name or unsupported value + types. + """ + attrs = {} + for key, value in attributes.items(): + if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX): + raise ValueError("Attribute name is not whitelisted. " + "Whitelisted: prefix %s, got: %s" % + (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key)) + + if isinstance(value, attr_value_pb2.AttrValue): + attrs[key] = value + # bool type check has to happen before int since bool is a subclass of int. + elif isinstance(value, bool): + attrs[key] = attr_value_pb2.AttrValue(b=value) + elif isinstance(value, int): + attrs[key] = attr_value_pb2.AttrValue(i=value) + elif isinstance(value, float): + attrs[key] = attr_value_pb2.AttrValue(f=value) + elif isinstance(value, str): + attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value)) + else: + raise ValueError("Unsupported attribute type for %s with type %s" % + (key, type(value))) + return attrs + + class FuncGraph(ops.Graph): """Graph representing a function body. @@ -136,7 +199,7 @@ class FuncGraph(ops.Graph): self.inputs = [] self.outputs = [] self.structured_outputs = None - self.variables = [] + self._weak_variables = [] self.outer_graph = ops.get_default_graph() self.captures = collections.OrderedDict() @@ -173,6 +236,31 @@ class FuncGraph(ops.Graph): self._graph_key = graph._graph_key # pylint: enable=protected-access + @property + def variables(self): + """A list of variables accessed by this FuncGraph. + + Note that functions keep only weak references to variables. Calling the + function after a variable it accesses has been deleted is an error. + + Yields: + Strong references to variables accessed by this FuncGraph. + """ + for weak_v in self._weak_variables: + v = weak_v() + if v is None: + raise AssertionError( + "Called a function referencing variables which have been deleted. " + "This likely means that function-local variables were created and " + "not referenced elsewhere in the program. 
This is generally a " + "mistake; consider storing variables in an object attribute on " + "first call.") + yield v + + @variables.setter + def variables(self, var_list): + self._weak_variables = [weakref.ref(v) for v in var_list] + def create_op( self, op_type, @@ -365,6 +453,7 @@ class _EagerDefinedFunction(object): self._num_outputs = len(self.signature.output_arg) self._output_types = [o.type for o in self.signature.output_arg] self._output_shapes = [o.shape for o in outputs] + self._func_graph_outputs = outputs self.grad_func_name = None self.python_grad_func = None self._c_func = c_api_util.ScopedTFFunction(fn) @@ -441,6 +530,8 @@ class _EagerDefinedFunction(object): else: for i, shape in enumerate(self._output_shapes): outputs[i].set_shape(shape) + for i, func_graph_output in enumerate(self._func_graph_outputs): + _copy_handle_data(func_graph_output, outputs[i]) return outputs @@ -485,7 +576,7 @@ class Function(object): self._num_outputs = len(self._func_graph.outputs) self._output_shapes = tuple( output.shape for output in self._func_graph.outputs) - self._attrs = attrs or {} + self._attrs = _parse_func_attrs(attrs or {}) self._device_functions = tuple( self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access @@ -506,7 +597,19 @@ class Function(object): self._distributed_variables[component_variable.handle] = variable def __call__(self, *args): - """Executes the wrapped function.""" + """Executes the wrapped function. + + Args: + *args: a list of Tensors or Variables. + + Returns: + The result of applying the TF function to `args`. + + Raises: + ValueError: If the current device stack does not match the device stack + under which the function was created, or if `args` contains anything + other than Tensors or Variables. + """ ctx = context.context() device_functions = _get_device_functions(ctx, ops.get_default_graph()) if device_functions != self._device_functions: @@ -522,7 +625,18 @@ class Function(object): tape.variable_accessed(v) captures = self._resolve_captured_inputs() - tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] + tensor_inputs = [] + for i, arg in enumerate(nest.flatten(args)): + if isinstance(arg, resource_variable_ops.ResourceVariable): + if arg.trainable: + tape.variable_accessed(arg) + tensor_inputs.append(arg.handle) + elif isinstance(arg, ops.Tensor): + tensor_inputs.append(arg) + else: + raise ValueError("All inputs to `Function`s must be Tensors; " + "on invocation of %s, the %d-th input (%s) was not a " + "Tensor." % (self._func_graph.name, i, str(arg))) args = tensor_inputs + captures if tape.should_record(tensor_inputs) or tape.should_record(captures): @@ -537,11 +651,6 @@ class Function(object): return self._func_graph @property - def variables(self): - """Returns all variables touched by this function.""" - return self._func_graph.variables - - @property def inputs(self): """Returns tensors in `self.graph` corresponding to arguments.""" return self._func_graph.inputs @@ -738,7 +847,12 @@ def _get_defun_inputs_from_args(args): return nest.pack_sequence_as(args, function_inputs) -def func_graph_from_py_func(name, python_func, args, kwds, signature=None): +def func_graph_from_py_func(name, + python_func, + args, + kwargs, + signature=None, + func_graph=None): """Returns a `FuncGraph` generated from `python_func`. Args: @@ -746,13 +860,15 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): python_func: the Python function to trace. 
args: the positional args with which the Python function should be called; ignored if a signature is provided. - kwds: the keyword args with which the Python function should be called; + kwargs: the keyword args with which the Python function should be called; ignored if a signature is provided. signature: a possibly nested sequence of `TensorSpecs` specifying the shapes and dtypes of the arguments. When a signature is provided, `args` and - `kwds` are ignored, and `python_func` is traced with Tensors conforming + `kwargs` are ignored, and `python_func` is traced with Tensors conforming to `signature`. If `None`, the shapes and dtypes are inferred from the inputs. + func_graph: Optional. An instance of FuncGraph. If provided, we will use + this graph else a new one is built and returned. Returns: A FuncGraph. @@ -761,26 +877,25 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): TypeError: If any of `python_func`'s return values is neither `None` nor a `Tensor`. """ - func_graph = FuncGraph(name) + if func_graph is None: + func_graph = FuncGraph(name) + assert isinstance(func_graph, FuncGraph) with func_graph.as_default(), AutomaticControlDependencies() as a: variable_scope.get_variable_scope().set_use_resource(True) if signature is None: func_args = _get_defun_inputs_from_args(args) - func_kwds = _get_defun_inputs_from_args(kwds) + func_kwargs = _get_defun_inputs_from_args(kwargs) else: func_args = _get_defun_inputs_from_signature(signature) - func_kwds = {} + func_kwargs = {} # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. - func_graph.inputs.extend( - x for x in nest.flatten(func_args) + nest.flatten(func_kwds) - if isinstance(x, ops.Tensor)) - # Variables to help check whether mutation happens in calling the function # Copy the recursive list, tuple and map structure, but not base objects func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args)) - func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds)) + func_kwargs_before = nest.pack_sequence_as( + func_kwargs, nest.flatten(func_kwargs)) def convert(x): """Converts an argument to a Tensor.""" @@ -799,7 +914,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): this_tape = tape.push_new_tape() try: - func_outputs = python_func(*func_args, **func_kwds) + func_outputs = python_func(*func_args, **func_kwargs) # invariant: `func_outputs` contains only Tensors and `None`s. func_outputs = nest.map_structure(convert, func_outputs) @@ -819,10 +934,32 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): raise ValueError(errmsg) check_mutation(func_args_before, func_args) - check_mutation(func_kwds_before, func_kwds) + check_mutation(func_kwargs_before, func_kwargs) finally: tape.pop_tape(this_tape) + # Variables in `func_args`, `func_kwargs` should be explicit inputs + # to the function, not captured inputs. + tape_variables = this_tape.watched_variables() + arg_variables = set() + inputs = [] + for arg in nest.flatten(func_args) + nest.flatten(func_kwargs): + if isinstance(arg, resource_variable_ops.ResourceVariable): + try: + resource_placeholder = func_graph.captures.pop(arg.handle) + arg_variables.add(arg) + except KeyError: + # This case occurs if a Variable among the inputs is not actually + # used by the function; we still add an explicit input for it + # because the user should presumably pass the Variable as an input + # to the corresponding graph function. 
+ resource_placeholder = _create_substitute_placeholder(arg.handle) + inputs.append(resource_placeholder) + elif isinstance(arg, ops.Tensor): + inputs.append(arg) + variables = [v for v in tape_variables if v not in arg_variables] + func_graph.inputs = inputs + list(func_graph.captures.values()) + func_graph.structured_outputs = func_outputs # Returning a closed-over tensor does not trigger convert_to_tensor. func_graph.outputs.extend( @@ -834,7 +971,6 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): # Instead of storing non-distributed component variables, we # store their distributed containers so we can retrieve the correct # component variables at call-time. - variables = list(this_tape.watched_variables()) strategy = distribution_strategy_context.get_distribution_strategy() for i, variable in enumerate(variables): # If variable is not distributed value_container returns itself. @@ -879,9 +1015,6 @@ def _encode_arg(arg): _TensorType(arg.values.dtype, arg.values._shape_tuple()), _TensorType(arg.indices.dtype, arg.indices._shape_tuple()), ]) - elif isinstance(arg, np.ndarray): - tensor = ops.convert_to_tensor(arg) - return _TensorType(tensor.dtype, tensor._shape_tuple()) # pylint: enable=protected-access elif isinstance(arg, (list, tuple)): return tuple([_encode_arg(elem) for elem in arg]) @@ -889,7 +1022,16 @@ def _encode_arg(arg): return tuple( (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg)) else: - return arg + try: + # If possible, keep only a weak reference to Python objects. Weak + # references hash to the same value as the original object. + # TODO(allenl): Clean up dead functions and their cache keys if the cache + # gets large. Right now creating objects with a defunned method, calling + # the method, and losing a reference to the object in a loop will leak + # memory here. + return weakref.ref(arg) + except TypeError: + return arg def _deterministic_dict_values(dictionary): @@ -911,7 +1053,8 @@ class PolymorphicFunction(object): def __init__(self, python_function, name, - input_signature=None): + input_signature=None, + attributes=None): """Initializes a polymorphic function. Args: @@ -920,6 +1063,8 @@ class PolymorphicFunction(object): input_signature: a possibly nested sequence of `TensorSpec` objects specifying the input signature of this function. If `None`, a separate function is instantiated for each inferred input signature. + attributes: dict, extra keyword arguments that will be added as attribute + of the function. 
Raises: ValueError: if `input_signature` is not None and the `python_function`'s @@ -929,14 +1074,14 @@ class PolymorphicFunction(object): if isinstance(python_function, functools.partial): self._python_function = python_function.func self._args_to_prepend = python_function.args or tuple() - self._kwds_to_include = python_function.keywords or {} + self._kwargs_to_include = python_function.keywords or {} else: self._python_function = python_function self._args_to_prepend = tuple() - self._kwds_to_include = {} + self._kwargs_to_include = {} self._name = name self._function_cache = collections.OrderedDict() - self._variables = [] + self._function_attributes = attributes or {} self._lock = threading.Lock() @@ -971,9 +1116,9 @@ class PolymorphicFunction(object): self._input_signature = tuple(input_signature) self._flat_input_signature = tuple(nest.flatten(input_signature)) - def __call__(self, *args, **kwds): + def __call__(self, *args, **kwargs): """Calls a graph function specialized to the inputs.""" - graph_function, inputs = self._maybe_define_function(*args, **kwds) + graph_function, inputs = self._maybe_define_function(args, kwargs) return graph_function(*inputs) @property @@ -981,12 +1126,6 @@ class PolymorphicFunction(object): """Returns the wrapped Python function.""" return self._python_function - # TODO(akshayka): Remove this property. - @property - def variables(self): - """Returns the union of all variables referenced by cached `Function`s`.""" - return self._variables - def get_concrete_function(self, *args, **kwargs): """Returns a `Function` object specialized to inputs and execution context. @@ -997,7 +1136,7 @@ class PolymorphicFunction(object): *args: inputs to specialize on. **kwargs: inputs to specialize on. """ - graph_function, _ = self._maybe_define_function(*args, **kwargs) + graph_function, _ = self._maybe_define_function(args, kwargs) return graph_function def __get__(self, instance, owner): @@ -1018,33 +1157,37 @@ class PolymorphicFunction(object): # then `instance` will be `foo` (and `owner` will be `Foo`). return functools.partial(self.__call__, instance) - def _cache_key(self, args, kwds, ctx, graph): + def _cache_key(self, args, kwargs): """Computes the cache key given inputs and execution context.""" if self._input_signature is None: - inputs = (args, kwds) if kwds else args + inputs = (args, kwargs) if kwargs else args cache_key = tuple(_encode_arg(arg) for arg in inputs) else: - del args, kwds + del args, kwargs cache_key = self._flat_input_signature - # The graph, or whether we're executing eagerly, should be a part of the - # cache key so we don't improperly capture tensors such as variables. - executing_eagerly = ctx.executing_eagerly() - execution_context = executing_eagerly or graph + with ops.init_scope(): + init_graph = ops.get_default_graph() + + # The graph, or whether we're executing eagerly, should be a part of the + # cache key so we don't improperly capture tensors such as variables. + executing_eagerly = context.executing_eagerly() + execution_context = executing_eagerly or init_graph + default_graph = ops.get_default_graph() # Putting the device in the cache key ensures that call-site device # annotations are respected. - device_functions = _get_device_functions(ctx, graph) + device_functions = _get_device_functions(context.context(), default_graph) # `ops.colocate_with` directives translate into `ops.device` directives when # eager execution is enabled. 
- colocation_stack = (None if executing_eagerly else - tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access + colocation_stack = (() if executing_eagerly else + tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access return cache_key + (execution_context, device_functions, colocation_stack) - def _canonicalize_function_inputs(self, *args, **kwds): - """Canonicalizes `args` and `kwds`. + def _canonicalize_function_inputs(self, *args, **kwargs): + """Canonicalizes `args` and `kwargs`. Canonicalize the inputs to the Python function using its fullargspec. In particular, we parse the varags and kwargs that this @@ -1054,28 +1197,28 @@ class PolymorphicFunction(object): Args: *args: The varargs this object was called with. - **kwds: The keyword args this function was called with. + **kwargs: The keyword args this function was called with. Returns: A canonicalized ordering of the inputs. Raises: - ValueError: If a keyword in `kwds` cannot be matched with a positional + ValueError: If a keyword in `kwargs` cannot be matched with a positional argument when an input signature is specified, or when the inputs do not conform to the input signature. """ args = self._args_to_prepend + args - kwds = dict(kwds, **self._kwds_to_include) + kwargs = dict(kwargs, **self._kwargs_to_include) # Maps from index of arg to its corresponding value, according to `args` - # and `kwds`; seeded with the default values for the named args that aren't - # in `args`. + # and `kwargs`; seeded with the default values for the named args that + # aren't in `args`. arg_indices_to_values = { index: default for index, default in six.iteritems(self._arg_indices_to_default_values) if index >= len(args) } consumed_args = [] - for arg, value in six.iteritems(kwds): + for arg, value in six.iteritems(kwargs): index = self._args_to_indices.get(arg, None) if index is not None: arg_indices_to_values[index] = value @@ -1085,20 +1228,30 @@ class PolymorphicFunction(object): "function with keyword arguments when " "input_signature is provided.") for arg in consumed_args: - # After this loop, `kwds` will only contain true keyword arguments, as + # After this loop, `kwargs` will only contain true keyword arguments, as # opposed to named arguments called in a keyword-like fashion. - kwds.pop(arg) + kwargs.pop(arg) inputs = args + _deterministic_dict_values(arg_indices_to_values) + flat_inputs = nest.flatten(inputs) + + # Check for NumPy arrays in arguments and convert them to Tensors. 
+ need_packing = False + for index, value in enumerate(flat_inputs): + if isinstance(value, np.ndarray): + flat_inputs[index] = constant_op.constant(value) + need_packing = True + if need_packing: + inputs = nest.pack_sequence_as(structure=inputs, + flat_sequence=flat_inputs) if self._input_signature is None: - return inputs, kwds + return inputs, kwargs else: - assert not kwds + assert not kwargs try: nest.assert_same_structure(self._input_signature, inputs) except (ValueError, TypeError): raise ValueError("Structure of Python function inputs does not match " "input_signature.") - flat_inputs = nest.flatten(inputs) if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs): raise ValueError("When input_signature is provided, all inputs to " "the Python function must be Tensors.") @@ -1112,25 +1265,27 @@ class PolymorphicFunction(object): (str(inputs), str(self._input_signature))) return inputs, {} - def _maybe_define_function(self, *args, **kwds): + def _maybe_define_function(self, args, kwargs): """Gets a function for these inputs, defining it if necessary. + `args` and `kwargs` can be None if this `PolymorphicFunction` was created + with an `input_signature`. + Args: - *args: args for the Python function. - **kwds: keywords for the Python function. + args: The varargs for the Python function. + kwargs: The keyword args for the Python function. Returns: A graph function corresponding to the input signature implied by args and - kwds, as well as the inputs that the object should be called with. + kwargs, as well as the inputs that the object should be called with. Raises: ValueError: If inputs are incompatible with the input signature. TypeError: If the function inputs include non-hashable objects """ - - args, kwds = self._canonicalize_function_inputs(*args, **kwds) - cache_key = self._cache_key(args, kwds, context.context(), - ops.get_default_graph()) + if self._input_signature is None or args is not None or kwargs is not None: + args, kwargs = self._canonicalize_function_inputs(*args, **kwargs) + cache_key = self._cache_key(args, kwargs) with self._lock: try: graph_function = self._function_cache.get(cache_key, None) @@ -1141,11 +1296,41 @@ class PolymorphicFunction(object): if graph_function is None: graph_function = Function( func_graph_from_py_func(self._name, self._python_function, args, - kwds, self._input_signature)) - self._variables.extend( - [v for v in graph_function.variables if v not in self._variables]) + kwargs, self._input_signature), + self._function_attributes) self._function_cache[cache_key] = graph_function - return graph_function, (args, kwds) + return graph_function, [ + t for t in nest.flatten((args, kwargs)) + if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable)) + ] + + +def register(func, *args, **kwargs): + """Register the defun function into the graph. + + This won't actually call the function with the inputs, and only put the + function definition into graph. Register function with different input param + will result into multiple version of functions registered in graph. + + Args: + func: the PolymorphicFunction instance that generated by a @defun + *args: input arguments for the Python function. + **kwargs: input keyword arguments for the Python function. + + Returns: + a `Function` object specialized to inputs and execution context. + + Raises: + ValueError: When the input function is not a defun wrapped python function. 
+ """ + if not isinstance(func, PolymorphicFunction): + raise ValueError("Only defun function is allowed to be registered. " + "Got type: %s" % type(func)) + concrete_func = func.get_concrete_function(*args, **kwargs) + graph = ops.get_default_graph() + concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access + # TODO(scottzhu): support concrete_func._backward_graph_function in future. + return concrete_func def _validate_signature(signature): @@ -1271,6 +1456,11 @@ def defun(func=None, input_signature=None): tracing the execution of `f(*args, **kwargs)`; this graph is bound to an input signature inferred from `(*args, **kwargs)` and cached for future reuse. + NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects + before being passed to `f`, and are treated as Tensors for caching. This + allows a function to be called multiple times with NumPy arrays having + different values but the same shape and dtype without re-tracing each time. + `tf.contrib.eager.defun` caches graphs for your convenience, letting you define TensorFlow functions without explicitly specifying their signatures. However, this policy is conservative and potentially expensive; for example, @@ -1470,7 +1660,29 @@ def defun(func=None, input_signature=None): TypeError: If `input_signature` is neither `None` nor a sequence of `tf.contrib.eager.TensorSpec` objects. """ + return defun_with_attributes(func=func, input_signature=input_signature) + +def defun_with_attributes(func=None, input_signature=None, attributes=None): + """Compiles a Python function into a callable TensorFlow graph. + + This function supports adding extra function attributes. See detailed + documentation in defun(). Currently this is not exposed in public API since we + don't expect user to directly use attributes, and attribute won't work by + itself. This assumption might change in future. + + Args: + func: function to be compiled. + input_signature: same as defun()'s input_signature. + attributes: A dictionary of arguments which will be added to function def as + attributes. Currently only support primitive types as value, and only + whitelisted attribute name is allowed. Unwhitelisted attribute name or + unsupported value will result into ValueError. + + Returns: + Same as the return value of defun, with attributes added to the function in + graph. + """ if input_signature is not None: _validate_signature(input_signature) @@ -1482,7 +1694,8 @@ def defun(func=None, input_signature=None): name = "function" return tf_decorator.make_decorator( function, - PolymorphicFunction(function, name, input_signature=input_signature)) + PolymorphicFunction(function, name, input_signature=input_signature, + attributes=attributes)) # This code path is for the `foo = tfe.defun(foo, ...)` use case if func is not None: @@ -1727,9 +1940,9 @@ def automatic_control_dependencies(f): The wrapped function. 
""" - def wrapper(*args, **kwds): + def wrapper(*args, **kwargs): with AutomaticControlDependencies() as a: - result = f(*args, **kwds) + result = f(*args, **kwargs) result_flat = [a.mark_as_return(t) for t in nest.flatten(result)] return nest.pack_sequence_as(result, result_flat) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 92254a2c00..59faf967c5 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -21,8 +21,13 @@ import collections import functools from multiprocessing.pool import ThreadPool import sys +import weakref + +import numpy from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python import keras from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -36,12 +41,14 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops @@ -55,6 +62,28 @@ from tensorflow.python.util import compat from tensorflow.python.util import nest +class MiniModel(keras_training.Model): + """Minimal model for mnist. + + Useful for testing and debugging on slow TPU simulators. + """ + + def __init__(self): + super(MiniModel, self).__init__(name='') + self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones', + bias_initializer='ones') + + def call(self, inputs, training=True): + return self.fc(inputs) + + +class DefunnedMiniModel(MiniModel): + + @function.defun + def call(self, inputs, training=True): + return super(DefunnedMiniModel, self).call(inputs, training=training) + + @test_util.with_c_shapes class FunctionTest(test.TestCase): @@ -121,8 +150,8 @@ class FunctionTest(test.TestCase): @function.defun def f(): - v = resource_variable_ops.ResourceVariable(1.0) - return v.read_value() + self.v = resource_variable_ops.ResourceVariable(1.0) + return self.v.read_value() self.assertAllEqual(f(), 1.0) @@ -314,6 +343,7 @@ class FunctionTest(test.TestCase): def testDefunNumpyArraysConvertedToTensors(self): def f(x): + self.assertIsInstance(x, ops.Tensor) return x x = random_ops.random_uniform([2, 2]).numpy() @@ -327,6 +357,12 @@ class FunctionTest(test.TestCase): # shouldn't trigger another function definition. self.assertEqual(len(defined._function_cache), 1) + # Test that the numpy array is properly an argument to the graph function. 
+ self.assertEqual(1., defined(numpy.ones([])).numpy()) + self.assertEqual(0., defined(numpy.zeros([])).numpy()) + self.assertEqual(1., defined(array_ops.ones([])).numpy()) + self.assertEqual(0., defined(array_ops.zeros([])).numpy()) + def testDefunCapturedInt32(self): x = constant_op.constant(1, dtype=dtypes.int32) @@ -373,9 +409,9 @@ class FunctionTest(test.TestCase): @function.defun def tensor_init(): - v = resource_variable_ops.ResourceVariable( + self.v = resource_variable_ops.ResourceVariable( lambda: constant_op.constant(2.0)) - return v.read_value() + return self.v.read_value() value = tensor_init() if not context.executing_eagerly(): @@ -389,8 +425,8 @@ class FunctionTest(test.TestCase): def tensor_init(): with ops.init_scope(): const = constant_op.constant(2.0) - v = resource_variable_ops.ResourceVariable(const) - return v.read_value() + self.v = resource_variable_ops.ResourceVariable(const) + return self.v.read_value() value = tensor_init() if not context.executing_eagerly(): @@ -403,10 +439,17 @@ class FunctionTest(test.TestCase): def f(): x = constant_op.constant([[1, 2], [3, 4]]) out = math_ops.matmul(v, x) - self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2])) + self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) + # We do not return v directly since the tensor conversion function of + # ResourceVariable returns the read value and not the resource itself. + return v._handle compiled = function.defun(f) - compiled() + var_handle = compiled() + self.assertEqual(var_handle.dtype, dtypes.resource) + self.assertEqual(var_handle.shape, tensor_shape.scalar()) + var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) + self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) def testVariableInLoopInFunction(self): @@ -430,10 +473,17 @@ class FunctionTest(test.TestCase): def f(): x = constant_op.constant([[1, 2], [3, 4]]) out = math_ops.matmul(v, x) - self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2])) + self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) + # We do not return v directly since the tensor conversion function of + # ResourceVariable returns the read value and not the resource itself. 
+ return v._handle compiled = function.defun(f) - compiled() + var_handle = compiled() + self.assertEqual(var_handle.dtype, dtypes.resource) + self.assertEqual(var_handle.shape, tensor_shape.scalar()) + var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) + self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) def testDefunShapeInferenceWithCapturedVariableInGraphMode(self): with context.graph_mode(): @@ -442,23 +492,46 @@ class FunctionTest(test.TestCase): def f(): x = constant_op.constant([[1, 2], [3, 4]]) out = math_ops.matmul(v, x) - self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2])) + self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) # Check that shape inference works while creating the defun compiled = function.defun(f) compiled() + def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self): + with context.graph_mode(): + tensor_list = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) + tensor_list = list_ops.tensor_list_push_back(tensor_list, + constant_op.constant(1.0)) + tensor_list = list_ops.tensor_list_push_back(tensor_list, + constant_op.constant(2.0)) + + def f(): + tl, value = list_ops.tensor_list_pop_back( + tensor_list, element_dtype=dtypes.float32) + self.assertEqual(value.shape, tensor_shape.scalar()) + return tl + + compiled = function.defun(f) + output_tensor_list = compiled() + _, value = list_ops.tensor_list_pop_back( + output_tensor_list, element_dtype=dtypes.float32) + self.assertEqual(value.shape, tensor_shape.scalar()) + @test_util.run_in_graph_and_eager_modes def testDefunForcesResourceVariables(self): def variable_creator(): - return variables.Variable(0.0).read_value() + self.v = variables.Variable(0.0) + return self.v.read_value() + self.v = None defined = function.defun(variable_creator) defined() # Create the variable. - self.assertEqual(len(defined.variables), 1) self.assertIsInstance( - defined.variables[0], resource_variable_ops.ResourceVariable) + self.v, resource_variable_ops.ResourceVariable) def testDefunDifferentiable(self): v = resource_variable_ops.ResourceVariable(1.0) @@ -996,6 +1069,7 @@ class FunctionTest(test.TestCase): with ops.get_default_graph().as_default(): create_variable() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testLayerInDefun(self): conv = convolutional.Conv2D( filters=1, @@ -1009,7 +1083,34 @@ class FunctionTest(test.TestCase): x = array_ops.ones([1, 2, 2, 1]) y = model(x) - self.assertAllEqual([[[[4.0]]]], y.numpy()) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + + self.assertAllEqual([[[[4.0]]]], self.evaluate(y)) + + # Remove reference cycles in model + test_util.dismantle_polymorphic_function(model) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testDefunKerasModelCall(self): + model = MiniModel() + model.call = function.defun(model.call) + + x = array_ops.ones([1, 2]) + y = model(x) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + + self.assertAllEqual([[3.0]], self.evaluate(y)) + + # Remove reference cycles in defun. 
+ test_util.dismantle_polymorphic_function(model.call) + # Break the reference cycle between the MiniModel and the defun: + # MiniModel --(through its `call` method)--> PolymorphicFunction + # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel + del model.call # Note: The ConfigProto below unfortunately only configures graph # construction. Eager's configuration is controlled in `__main__`. @@ -1130,13 +1231,11 @@ class FunctionTest(test.TestCase): defined = function.defun(foo) x = constant_op.constant([1.0]) - self.assertAllEqual(defined.variables, []) - _ = defined(x) - self.assertAllEqual(defined.variables, [v]) + self.assertEqual(1., self.evaluate(defined(x))) + v.assign(2.) x = constant_op.constant([1.0, 2.0]) - _ = defined(x) # ensure the variables list remains the same - self.assertAllEqual(defined.variables, [v]) + self.assertAllEqual([2., 4.], self.evaluate(defined(x))) def testPythonFunctionWithDefaultArgs(self): @@ -1492,6 +1591,257 @@ class FunctionTest(test.TestCase): side_effecting_function.python_function() self.assertAllEqual(state, [0, 0]) + def testFunctionWithExtraAttributes(self): + @function.defun_with_attributes(attributes={'experimental_1': 'value1', + 'experimental_2': 2}) + def matmul(x, y): + return math_ops.matmul(x, y) + + def add(x, y): + return math_ops.add(x, y) + defun_add = function.defun_with_attributes( + add, attributes={'experimental_3': True, 'experimental_4': 1.0}) + + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq = matmul(t, t) + double = defun_add(t, t) + self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) + self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) + + graph = ops.get_default_graph() + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 2) + functions = list(graph._functions.values()) + self.assertRegexpMatches( + functions[0].definition.signature.name, '.*matmul.*') + attrs = functions[0].definition.attr + self.assertEqual(len(attrs), 2) + self.assertEqual(attrs['experimental_1'].s, b'value1') + self.assertEqual(attrs['experimental_2'].i, 2) + + self.assertRegexpMatches( + functions[1].definition.signature.name, '.*add.*') + attrs = functions[1].definition.attr + self.assertEqual(len(attrs), 2) + self.assertEqual(attrs['experimental_3'].b, True) + self.assertEqual(attrs['experimental_4'].f, 1.0) + # pylint: enable=protected-access + + def testFunctionWithInvalidAttribute(self): + @function.defun_with_attributes(attributes={'attr1': 'value1'}) + def matmul(x, y): + return math_ops.matmul(x, y) + + with self.assertRaisesRegexp(ValueError, + '.*Attribute name is not whitelisted.*'): + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + matmul(t, t) + + @function.defun_with_attributes(attributes={'experimental_1': ['value1']}) + def add(x, y): + return math_ops.add(x, y) + + with self.assertRaisesRegexp(ValueError, + '.*Unsupported attribute type.*'): + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + add(t, t) + + def testRegisterFunction(self): + @function.defun + def add(x, y): + return math_ops.add(x, y) + + def matmul(x, y): + return math_ops.matmul(x, y) + defun_matmul = function.defun(matmul) + + with context.graph_mode(), self.cached_session(): + with 
ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + function.register(defun_matmul, t, t) + function.register(add, t, t) + + graph = ops.get_default_graph() + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 2) + functions = list(graph._functions.values()) + pre_register_matmul_func_name = functions[0].definition.signature.name + self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*') + pre_register_add_func_name = functions[1].definition.signature.name + self.assertRegexpMatches(pre_register_add_func_name, '.*add.*') + + sq = defun_matmul(t, t) + double = add(t, t) + self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) + self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) + # Make sure the pre registered function is used, and no other function + # is added. + self.assertEqual(len(graph._functions), 2) + functions = list(graph._functions.values()) + called_func_name = functions[0].definition.signature.name + self.assertEqual(pre_register_matmul_func_name, called_func_name) + called_func_name = functions[1].definition.signature.name + self.assertEqual(pre_register_add_func_name, called_func_name) + + def testRegisterFunctionWithInputSignature(self): + def matmul(x, y): + return math_ops.matmul(x, y) + defun_matmul = function.defun( + matmul, + input_signature=[ + tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), + tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32) + ]) + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + function.register(defun_matmul, t, t) + + graph = ops.get_default_graph() + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 1) + + # Test input param shape mismatch + t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + with self.assertRaisesRegexp( + ValueError, 'Python inputs incompatible with input_signature'): + function.register(defun_matmul, t2, t2) + + def testRegisterFunctionWithCache(self): + def matmul(x, y): + return math_ops.matmul(x, y) + defun_matmul = function.defun(matmul) + + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]]) + function.register(defun_matmul, t, t) + function.register(defun_matmul, t2, t2) + + graph = ops.get_default_graph() + # Only one function is registered since the input param are in same type + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 1) + + def testCallingFunctionWithDifferentVariables(self): + + @function.defun + def foo(v): + v.assign_add(1.0) + return v.read_value() + + v = resource_variable_ops.ResourceVariable(0.0) + graph_function = foo.get_concrete_function(v) + self.assertEqual(len(graph_function.inputs), 1) + self.assertEqual(len(graph_function.captured_inputs), 0) + + self.assertEqual(float(graph_function(v)), 1.0) + self.assertEqual(float(graph_function(v)), 2.0) + + w = resource_variable_ops.ResourceVariable(0.0) + + @function.defun + def bar(v): + del v + return constant_op.constant(1.0) + + graph_function = bar.get_concrete_function(v) + self.assertEqual(float(graph_function(v)), 1.0) + self.assertEqual(float(graph_function(w)), 1.0) + + def testCallingFunctionWithNonTensorsFails(self): + + @function.defun + def foo(x): + return x + + graph_function = 
foo.get_concrete_function(constant_op.constant(1.0)) + with self.assertRaisesRegexp(ValueError, 'All inputs to `Function`s must ' + 'be Tensors;.*'): + graph_function('Not a Tensor.') + + def testSwapImplementationWithGrapplerPlugin(self): + rewrites = rewriter_config_pb2.RewriterConfig() + # function_optimizer has to be turn off, otherwise it will delete the + # registered function if it does not get called. + # TODO(scottzhu): Move the ExperimentalImplementationSelector to be called + # before function_optimizer in future. + rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF + customer_optimizer = rewrites.custom_optimizers.add() + customer_optimizer.name = 'ExperimentalImplementationSelector' + rewrites.min_graph_nodes = -1 + graph_options = config_pb2.GraphOptions( + rewrite_options=rewrites, build_cost_model=1) + config = config_pb2.ConfigProto(graph_options=graph_options) + + with context.graph_mode(), self.cached_session( + config=config, graph=ops.Graph(), use_gpu=True) as sess: + + @function.defun_with_attributes( + attributes={ + 'experimental_api_implements': 'random_boost', + 'experimental_api_preferred_device': 'CPU' + }) + def cpu_boost(x): + return math_ops.add(x, 2.0) + + @function.defun_with_attributes( + attributes={ + 'experimental_api_implements': 'random_boost', + 'experimental_api_preferred_device': 'GPU' + }) + def gpu_boost(x): + return math_ops.add(x, 4.0) + + x = constant_op.constant(1.0) + + function.register(cpu_boost, x) + y = gpu_boost(x) + y_value = sess.run(y) + + if test.is_gpu_available(): + self.assertEquals(y_value, 5.0) + else: + # Grappler fallback to use the CPU impl even called with GPU function. + self.assertEquals(y_value, 3.0) + + def testDefunFunctionSeparateGraphs(self): + with context.graph_mode(): + + @function.defun + def add(x): + return x + 5 + + @function.defun + def maybe_add(x, should_add): + if should_add: + return add(x) + else: + return x + + with ops.Graph().as_default(): + x = constant_op.constant(11) + maybe_add(x, True) + self.assertEqual(len(maybe_add._function_cache), 1) + self.assertEqual(len(add._function_cache), 1) + + maybe_add(x, False) + self.assertEqual(len(maybe_add._function_cache), 2) + self.assertEqual(len(add._function_cache), 1) + + with ops.Graph().as_default(): + x = constant_op.constant(11) + maybe_add(x, True) + self.assertEqual(len(maybe_add._function_cache), 3) + self.assertEqual(len(add._function_cache), 2) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): @@ -1683,10 +2033,10 @@ class AutomaticControlDependenciesTest(test.TestCase): @function.defun def train(): - v = resource_variable_ops.ResourceVariable(1.0) - grad = backprop.implicit_grad(loss)(v) + self.v = resource_variable_ops.ResourceVariable(1.0) + grad = backprop.implicit_grad(loss)(self.v) optimizer.apply_gradients(grad) - return v.read_value() + return self.v.read_value() value = train() self.assertEqual(value.numpy(), -1.0) @@ -1713,10 +2063,10 @@ class AutomaticControlDependenciesTest(test.TestCase): @function.defun def train(): - v = resource_variable_ops.ResourceVariable(1.0) - grad = backprop.implicit_grad(loss)(v) + self.v = resource_variable_ops.ResourceVariable(1.0) + grad = backprop.implicit_grad(loss)(self.v) optimizer.apply_gradients(grad) - return v.read_value() + return self.v.read_value() train() @@ -1903,6 +2253,13 @@ class AutomaticControlDependenciesTest(test.TestCase): modify_same_flat(nested_input) + def testDecoratedMethodVariableCleanup(self): + m = DefunnedMiniModel() + 
m(array_ops.ones([1, 2])) + weak_variables = weakref.WeakSet(m.variables) + self.assertEqual(2, len(weak_variables)) + del m + self.assertEqual([], list(weak_variables)) if __name__ == '__main__': ops.enable_eager_execution( diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py index 5f027d107c..5f5af4ab6c 100644 --- a/tensorflow/python/eager/imperative_grad.py +++ b/tensorflow/python/eager/imperative_grad.py @@ -23,8 +23,9 @@ import collections from tensorflow.python import pywrap_tensorflow -VSpace = collections.namedtuple( - "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"]) +VSpace = collections.namedtuple("VSpace", [ + "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn" +]) def imperative_grad( diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index f34ce6af79..5f44bd4fec 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -516,25 +516,13 @@ static PyObject* EagerTensor_rank(EagerTensor* self) { // Getter for `_num_elements`. static PyObject* EagerTensor_num_elements(EagerTensor* self) { auto handle = self->handle; - int n = TFE_TensorHandleNumDims(handle, self->status); + int n = TFE_TensorHandleNumElements(handle, self->status); if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { // Cleanup self->status before returning. TF_SetStatus(self->status, TF_OK, ""); return nullptr; } - tensorflow::int64 value = 1; - if (PyErr_Occurred()) return nullptr; - for (int i = 0; i < n; ++i) { - int64_t dim = TFE_TensorHandleDim(handle, i, self->status); - if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { - // Cleanup self->status before returning. - TF_SetStatus(self->status, TF_OK, ""); - PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions"); - return nullptr; - } - value *= dim; - } - return PyLong_FromLongLong(value); + return PyLong_FromLongLong(n); } static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) { @@ -777,17 +765,34 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { return reinterpret_cast<PyObject*>(t); } -tensorflow::int64 EagerTensor_id(const PyObject* tensor) { - CHECK(EagerTensor_CheckExact(tensor)); +tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) { + DCHECK(EagerTensor_CheckExact(tensor)); return reinterpret_cast<const EagerTensor*>(tensor)->id; } -tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) { - CHECK(EagerTensor_CheckExact(tensor)); +tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) { + DCHECK(EagerTensor_CheckExact(tensor)); return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType( reinterpret_cast<const EagerTensor*>(tensor)->handle)); } +tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) { + DCHECK(EagerTensor_CheckExact(tensor)); + const EagerTensor* as_c_eager_tensor = + reinterpret_cast<const EagerTensor*>(tensor); + tensorflow::int64 result = TFE_TensorHandleNumElements( + as_c_eager_tensor->handle, as_c_eager_tensor->status); + + if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status, + PyExc_ValueError)) { + // Cleanup status before returning. 
+ TF_SetStatus(as_c_eager_tensor->status, TF_OK, ""); + return -1; + } + + return result; +} + PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { if (!PyType_Check(base_class)) { PyErr_SetString( diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h index bc042eb19e..4eaa1ba536 100644 --- a/tensorflow/python/eager/pywrap_tensor.h +++ b/tensorflow/python/eager/pywrap_tensor.h @@ -21,8 +21,9 @@ limitations under the License. #include "tensorflow/python/lib/core/numpy.h" bool EagerTensor_CheckExact(const PyObject* o); -tensorflow::int64 EagerTensor_id(const PyObject* tensor); -tensorflow::DataType EagerTensor_dtype(const PyObject* tensor); +tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor); +tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor); +tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor); namespace tensorflow { TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 46dcf7c8a8..159b1c1218 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/python/eager/pywrap_tfe.h" +#include "absl/types/variant.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" @@ -860,7 +861,7 @@ static tensorflow::int64 MakeInt(PyObject* integer) { static tensorflow::int64 FastTensorId(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { - return EagerTensor_id(tensor); + return PyEagerTensor_ID(tensor); } PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); if (id_field == nullptr) { @@ -873,7 +874,7 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) { static tensorflow::DataType FastTensorDtype(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { - return EagerTensor_dtype(tensor); + return PyEagerTensor_Dtype(tensor); } PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype"); if (dtype_field == nullptr) { @@ -889,12 +890,239 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) { return static_cast<tensorflow::DataType>(id); } +class PyTapeTensor { + public: + PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype, + const tensorflow::TensorShape& shape) + : id_(id), dtype_(dtype), shape_(shape) {} + PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype, + PyObject* shape) + : id_(id), dtype_(dtype), shape_(shape) { + Py_INCREF(absl::get<1>(shape_)); + } + PyTapeTensor(const PyTapeTensor& other) { + id_ = other.id_; + dtype_ = other.dtype_; + shape_ = other.shape_; + if (shape_.index() == 1) { + Py_INCREF(absl::get<1>(shape_)); + } + } + + ~PyTapeTensor() { + if (shape_.index() == 1) { + Py_DECREF(absl::get<1>(shape_)); + } + } + PyObject* GetShape() const; + PyObject* GetDType() const { return PyLong_FromLong(dtype_); } + tensorflow::int64 GetID() const { return id_; } + + private: + tensorflow::int64 id_; + tensorflow::DataType dtype_; + absl::variant<tensorflow::TensorShape, PyObject*> shape_; +}; + +class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction, + PyTapeTensor> { + public: + explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { + Py_INCREF(py_vspace_); + } + + tensorflow::Status Initialize() { + num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); + if (num_elements_ == nullptr) { + return 
tensorflow::errors::InvalidArgument("invalid vspace"); + } + aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); + if (aggregate_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn"); + if (zeros_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn"); + if (ones_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn"); + if (graph_shape_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + return tensorflow::Status::OK(); + } + + ~PyVSpace() override { + Py_XDECREF(num_elements_); + Py_XDECREF(aggregate_fn_); + Py_XDECREF(zeros_fn_); + Py_XDECREF(ones_fn_); + Py_XDECREF(graph_shape_fn_); + + Py_DECREF(py_vspace_); + } + + tensorflow::int64 NumElements(PyObject* tensor) const final { + if (EagerTensor_CheckExact(tensor)) { + return PyEagerTensor_NumElements(tensor); + } + PyObject* arglist = + Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor)); + PyObject* result = PyEval_CallObject(num_elements_, arglist); + Py_DECREF(arglist); + if (result == nullptr) { + // The caller detects whether a python exception has been raised. + return -1; + } + tensorflow::int64 r = MakeInt(result); + Py_DECREF(result); + return r; + } + + PyObject* AggregateGradients( + tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final { + PyObject* list = PyList_New(gradient_tensors.size()); + for (int i = 0; i < gradient_tensors.size(); ++i) { + // Note: stealing a reference to the gradient tensors. + CHECK(gradient_tensors[i] != nullptr); + CHECK(gradient_tensors[i] != Py_None); + PyList_SET_ITEM(list, i, + reinterpret_cast<PyObject*>(gradient_tensors[i])); + } + PyObject* arglist = Py_BuildValue("(O)", list); + CHECK(arglist != nullptr); + PyObject* result = PyEval_CallObject(aggregate_fn_, arglist); + Py_DECREF(arglist); + Py_DECREF(list); + return result; + } + + void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); } + + PyObject* Zeros(const PyTapeTensor& tensor) const final { + PyObject* py_shape = tensor.GetShape(); + PyObject* py_dtype = tensor.GetDType(); + PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); + PyObject* result = PyEval_CallObject(zeros_fn_, arg_list); + Py_DECREF(arg_list); + Py_DECREF(py_dtype); + Py_DECREF(py_shape); + return reinterpret_cast<PyObject*>(result); + } + + PyObject* Ones(const PyTapeTensor& tensor) const final { + PyObject* py_shape = tensor.GetShape(); + PyObject* py_dtype = tensor.GetDType(); + PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); + PyObject* result = PyEval_CallObject(ones_fn_, arg_list); + Py_DECREF(arg_list); + Py_DECREF(py_dtype); + Py_DECREF(py_shape); + return result; + } + + PyObject* GraphShape(PyObject* tensor) const { + PyObject* arg_list = Py_BuildValue("(O)", tensor); + PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list); + Py_DECREF(arg_list); + return result; + } + + tensorflow::Status CallBackwardFunction( + PyBackwardFunction* backward_function, + tensorflow::gtl::ArraySlice<PyObject*> output_gradients, + std::vector<PyObject*>* result) const final { + PyObject* grads = PyTuple_New(output_gradients.size()); + for (int i = 0; i < output_gradients.size(); ++i) { + if (output_gradients[i] == nullptr) { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(grads, 
i, Py_None); + } else { + PyTuple_SET_ITEM(grads, i, + reinterpret_cast<PyObject*>(output_gradients[i])); + } + } + PyObject* py_result = (*backward_function)(grads); + Py_DECREF(grads); + if (py_result == nullptr) { + return tensorflow::errors::Internal("gradient function threw exceptions"); + } + result->clear(); + PyObject* seq = + PySequence_Fast(py_result, "expected a sequence of gradients"); + if (seq == nullptr) { + return tensorflow::errors::InvalidArgument( + "gradient function did not return a list"); + } + int len = PySequence_Fast_GET_SIZE(seq); + VLOG(1) << "Gradient length is " << len; + result->reserve(len); + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(seq, i); + if (item == Py_None) { + result->push_back(nullptr); + } else { + Py_INCREF(item); + result->push_back(item); + } + } + Py_DECREF(seq); + Py_DECREF(py_result); + return tensorflow::Status::OK(); + } + + void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } + + private: + PyObject* py_vspace_; + + PyObject* num_elements_; + PyObject* aggregate_fn_; + PyObject* zeros_fn_; + PyObject* ones_fn_; + PyObject* graph_shape_fn_; +}; +PyVSpace* py_vspace = nullptr; + +PyObject* TFE_Py_RegisterVSpace(PyObject* e) { + if (py_vspace != nullptr) { + delete py_vspace; + } + + py_vspace = new PyVSpace(e); + auto status = py_vspace->Initialize(); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { + delete py_vspace; + return nullptr; + } + + Py_RETURN_NONE; +} + +PyObject* PyTapeTensor::GetShape() const { + if (shape_.index() == 0) { + auto& shape = absl::get<0>(shape_); + PyObject* py_shape = PyTuple_New(shape.dims()); + for (int i = 0; i < shape.dims(); ++i) { + PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); + } + + return py_shape; + } + + return py_vspace->GraphShape(absl::get<1>(shape_)); +} + class GradientTape - : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> { + : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, + PyTapeTensor> { public: explicit GradientTape(bool persistent, bool watch_accessed_variables) - : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>( - persistent), + : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, + PyTapeTensor>(persistent), watch_accessed_variables_(watch_accessed_variables) {} virtual ~GradientTape() { @@ -1175,24 +1403,41 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) { reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id); } -static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { +bool ListContainsNone(PyObject* list) { + if (list == Py_None) return true; + tensorflow::Safe_PyObjectPtr seq( + PySequence_Fast(list, "expected a sequence")); + if (seq == nullptr) { + return false; + } + + int len = PySequence_Size(list); + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i); + if (item == Py_None) return true; + } + + return false; +} + +static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); - tensorflow::int64 id = EagerTensor_id(tensor); + tensorflow::int64 id = PyEagerTensor_ID(tensor); tensorflow::TensorShape tensor_shape; const tensorflow::Status status = t->handle->Shape(&tensor_shape); if (MaybeRaiseExceptionFromStatus(status, nullptr)) { - return tensorflow::eager::TapeTensor{id, t->handle->dtype, - tensorflow::TensorShape({})}; + return PyTapeTensor(id, 
static_cast<tensorflow::DataType>(0), + tensorflow::TensorShape({})); } else { - return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape}; + return PyTapeTensor(id, t->handle->dtype, tensor_shape); } } tensorflow::int64 id = FastTensorId(tensor); if (PyErr_Occurred()) { - return tensorflow::eager::TapeTensor{ - id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})}; + return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), + tensorflow::TensorShape({})); } PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype"); PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum"); @@ -1200,16 +1445,21 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { tensorflow::DataType dtype = static_cast<tensorflow::DataType>(MakeInt(dtype_enum)); Py_DECREF(dtype_enum); - if (PyErr_Occurred() != nullptr) { - return tensorflow::eager::TapeTensor{id, dtype, - tensorflow::TensorShape({})}; + if (PyErr_Occurred()) { + return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), + tensorflow::TensorShape({})); } static char _shape_tuple[] = "_shape_tuple"; PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr); - if (PyErr_Occurred() != nullptr) { - return tensorflow::eager::TapeTensor{id, dtype, - tensorflow::TensorShape({})}; + if (PyErr_Occurred()) { + return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), + tensorflow::TensorShape({})); } + + if (ListContainsNone(shape_tuple)) { + return PyTapeTensor(id, dtype, tensor); + } + auto l = MakeIntList(shape_tuple); Py_DECREF(shape_tuple); // Replace -1, which represents accidental Nones which can occur in graph mode @@ -1220,7 +1470,7 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { } } tensorflow::TensorShape shape(l); - return tensorflow::eager::TapeTensor{id, dtype, shape}; + return PyTapeTensor(id, dtype, shape); } std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) { @@ -1286,7 +1536,7 @@ void TapeSetRecordOperation( const std::vector<tensorflow::DataType>& input_dtypes, const std::function<PyBackwardFunction*()>& backward_function_getter, const std::function<void(PyBackwardFunction*)>& backward_function_killer) { - std::vector<tensorflow::eager::TapeTensor> output_info; + std::vector<PyTapeTensor> output_info; PyObject* seq = PySequence_Fast(output_tensors, "expected a sequence of integer tensor ids"); int len = PySequence_Size(output_tensors); @@ -1362,173 +1612,6 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { } } -class PyVSpace - : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> { - public: - explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { - Py_INCREF(py_vspace_); - } - - tensorflow::Status Initialize() { - num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); - if (num_elements_ == nullptr) { - return tensorflow::errors::InvalidArgument("invalid vspace"); - } - aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); - if (aggregate_fn_ == nullptr) { - return tensorflow::errors::InvalidArgument("invalid vspace"); - } - zeros_ = PyObject_GetAttrString(py_vspace_, "zeros"); - if (zeros_ == nullptr) { - return tensorflow::errors::InvalidArgument("invalid vspace"); - } - ones_ = - PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones"); - if (ones_ == nullptr) { - return tensorflow::errors::InvalidArgument("invalid vspace"); - } - return tensorflow::Status::OK(); - } - - ~PyVSpace() override { - 
Py_XDECREF(num_elements_); - Py_XDECREF(aggregate_fn_); - Py_XDECREF(zeros_); - Py_XDECREF(ones_); - - Py_DECREF(py_vspace_); - } - - tensorflow::int64 NumElements(PyObject* tensor) const final { - PyObject* arglist = - Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor)); - PyObject* result = PyEval_CallObject(num_elements_, arglist); - tensorflow::int64 r = MakeInt(result); - Py_DECREF(result); - Py_DECREF(arglist); - return r; - } - - PyObject* AggregateGradients( - tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final { - PyObject* list = PyList_New(gradient_tensors.size()); - for (int i = 0; i < gradient_tensors.size(); ++i) { - // Note: stealing a reference to the gradient tensors. - CHECK(gradient_tensors[i] != nullptr); - CHECK(gradient_tensors[i] != Py_None); - PyList_SET_ITEM(list, i, - reinterpret_cast<PyObject*>(gradient_tensors[i])); - } - PyObject* arglist = Py_BuildValue("(O)", list); - CHECK(arglist != nullptr); - PyObject* result = PyEval_CallObject(aggregate_fn_, arglist); - Py_DECREF(arglist); - Py_DECREF(list); - return result; - } - - void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); } - - PyObject* Zeros(tensorflow::TensorShape shape, - tensorflow::DataType dtype) const final { - PyObject* py_shape = PyTuple_New(shape.dims()); - for (int i = 0; i < shape.dims(); ++i) { - PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); - } - PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype)); - PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); - PyObject* result = PyEval_CallObject(zeros_, arg_list); - Py_DECREF(arg_list); - Py_DECREF(py_dtype); - Py_DECREF(py_shape); - return reinterpret_cast<PyObject*>(result); - } - - PyObject* Ones(tensorflow::TensorShape shape, - tensorflow::DataType dtype) const final { - PyObject* py_shape = PyTuple_New(shape.dims()); - for (int i = 0; i < shape.dims(); ++i) { - PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); - } - PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype)); - PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); - PyObject* result = PyEval_CallObject(ones_, arg_list); - Py_DECREF(arg_list); - Py_DECREF(py_dtype); - Py_DECREF(py_shape); - return result; - } - - tensorflow::Status CallBackwardFunction( - PyBackwardFunction* backward_function, - tensorflow::gtl::ArraySlice<PyObject*> output_gradients, - std::vector<PyObject*>* result) const final { - PyObject* grads = PyTuple_New(output_gradients.size()); - for (int i = 0; i < output_gradients.size(); ++i) { - if (output_gradients[i] == nullptr) { - Py_INCREF(Py_None); - PyTuple_SET_ITEM(grads, i, Py_None); - } else { - PyTuple_SET_ITEM(grads, i, - reinterpret_cast<PyObject*>(output_gradients[i])); - } - } - PyObject* py_result = (*backward_function)(grads); - Py_DECREF(grads); - if (py_result == nullptr) { - return tensorflow::errors::Internal("gradient function threw exceptions"); - } - result->clear(); - PyObject* seq = - PySequence_Fast(py_result, "expected a sequence of gradients"); - if (seq == nullptr) { - return tensorflow::errors::InvalidArgument( - "gradient function did not return a list"); - } - int len = PySequence_Fast_GET_SIZE(seq); - VLOG(1) << "Gradient length is " << len; - result->reserve(len); - for (int i = 0; i < len; ++i) { - PyObject* item = PySequence_Fast_GET_ITEM(seq, i); - if (item == Py_None) { - result->push_back(nullptr); - } else { - Py_INCREF(item); - result->push_back(item); - } - } - Py_DECREF(seq); - Py_DECREF(py_result); - return 
tensorflow::Status::OK(); - } - - void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } - - private: - PyObject* py_vspace_; - - PyObject* num_elements_; - PyObject* aggregate_fn_; - PyObject* zeros_; - PyObject* ones_; -}; -PyVSpace* py_vspace = nullptr; - -PyObject* TFE_Py_RegisterVSpace(PyObject* e) { - if (py_vspace != nullptr) { - delete py_vspace; - } - - py_vspace = new PyVSpace(e); - auto status = py_vspace->Initialize(); - if (MaybeRaiseExceptionFromStatus(status, nullptr)) { - delete py_vspace; - return nullptr; - } - - Py_RETURN_NONE; -} - std::vector<PyObject*> MakeTensorList(PyObject* tensors) { PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); if (seq == nullptr) { @@ -1740,6 +1823,9 @@ PyObject* MaybeGetDTypeForAttr(const string& attr, Py_RETURN_NONE; } +// TODO(agarwal): use an automatic mechanism for handling None arguments to +// gradient functions. + // Returns a pair where the first value of the pair indicates whether or not all // outputs are unused. If the first value is false, the second value is a // set that identifies which of the output indices are unused. @@ -2565,13 +2651,18 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { int num_retvals = 0; for (int i = 0; i < op_def->output_arg_size(); i++) { const auto& output_arg = op_def->output_arg(i); + int delta = 1; if (!output_arg.number_attr().empty()) { - num_retvals += attr_list_sizes[output_arg.number_attr()]; + delta = attr_list_sizes[output_arg.number_attr()]; } else if (!output_arg.type_list_attr().empty()) { - num_retvals += attr_list_sizes[output_arg.type_list_attr()]; - } else { - num_retvals++; + delta = attr_list_sizes[output_arg.type_list_attr()]; + } + if (delta < 0) { + RaiseFallbackException( + "Attributes suggest that the size of an output list is less than 0"); + return nullptr; } + num_retvals += delta; } tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals); diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index fd8ab695b8..669fa08488 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import core from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -123,8 +124,8 @@ class Tests(test.TestCase): def testFastpathExecute_MixedPrecisionVariableTapeWrite(self): ctx = context.context() with backprop.GradientTape(persistent=True) as tape: - a_2_by_2 = constant_op.constant( - [[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32) + a_2_by_2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]], + dtype=dtypes.float32) a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16) m1 = resource_variable_ops.ResourceVariable(a_2_by_2) m2 = resource_variable_ops._MixedPrecisionVariable( @@ -233,6 +234,26 @@ class Tests(test.TestCase): pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, ctx_handle, None, [], a_2_by_2) + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastPathExecute_InvalidAttributes(self): + split_dim = constant_op.constant(0, dtype=dtypes.int32) + value = constant_op.constant([0, 1, 2, 3], dtype=dtypes.float32) + ctx = context.context() + ctx_handle = ctx._handle + with 
self.assertRaises(core._FallbackException): + pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, + "Split", None, None, split_dim, + value, "num_split", -1) + + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testInvalidNumOutputs(self): + with self.assertRaisesRegexp( + Exception, + "Value for attr 'num_split' of -1 must be at least minimum 1"): + array_ops.split(value=[1, 2, 3], num_or_size_splits=-1) + if __name__ == "__main__": test.main() |
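
Note on the reworked VSpace plumbing in this commit: the Python-side namedtuple in imperative_grad.py and the C++ PyVSpace must agree on attribute names, since PyVSpace::Initialize resolves each callback with PyObject_GetAttrString ("num_elements_fn", "aggregate_fn", "zeros_fn", "ones_fn", "graph_shape_fn"). Below is a minimal sketch of a conforming vspace object; the NumPy-backed lambdas and the np_vspace name are hypothetical stand-ins used only to illustrate the call signatures implied by the C++ side (the real callbacks live in backprop.py and build TensorFlow tensors).

import collections

import numpy as np

# Field names must match the attributes PyVSpace::Initialize looks up.
VSpace = collections.namedtuple("VSpace", [
    "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
])

# Hypothetical stand-ins; signatures mirror how the C++ tape calls them:
#   aggregate_fn(list_of_grads)        -> single aggregated gradient
#   num_elements_fn(tensor)            -> int (used only for non-EagerTensors,
#                                         since EagerTensors take the C++ fast path)
#   zeros_fn(shape_tuple, dtype_enum)  -> zeros gradient for that shape/dtype
#   ones_fn(shape_tuple, dtype_enum)   -> ones gradient for that shape/dtype
#   graph_shape_fn(tensor)             -> shape object for tensors whose static
#                                         shape contains None (graph mode)
np_vspace = VSpace(
    aggregate_fn=lambda grads: sum(grads),
    num_elements_fn=lambda t: int(np.asarray(t).size),
    zeros_fn=lambda shape, dtype: np.zeros(shape),
    ones_fn=lambda shape, dtype: np.ones(shape),
    graph_shape_fn=lambda t: np.asarray(t).shape,
)

# Registering the object hands these callbacks to the C++ tape via the global
# py_vspace pointer; with the real backprop.py callbacks this is:
# pywrap_tensorflow.TFE_Py_RegisterVSpace(np_vspace)

The graph_shape_fn entry is what lets the new PyTapeTensor hold either a concrete tensorflow::TensorShape or a Python tensor reference: when a captured tensor's _shape_tuple() contains None, TapeTensorFromTensor stores the tensor itself and Zeros/Ones later recover a usable shape through this callback instead of guessing a rank-0 shape.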