diff options
author | Alexandre Passos <apassos@google.com> | 2018-09-21 16:56:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 16:59:39 -0700 |
commit | 305a392904e6981e935c2a3514394379ba7083b1 (patch) | |
tree | d64fab8487bf1389b62323f5f23ea74f323f18c2 /tensorflow/python/eager | |
parent | f32c678543fcee2950e7ac6a84022e929df3acd7 (diff) |
Prototype for the functions-not-sessions implementation.
PiperOrigin-RevId: 214065999
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/BUILD | 28 | ||||
-rw-r--r-- | tensorflow/python/eager/def_function.py | 235 | ||||
-rw-r--r-- | tensorflow/python/eager/def_function_test.py | 87 |
3 files changed, 350 insertions, 0 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index a2686c68a9..f571da308e 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -46,6 +46,7 @@ py_library( ":backprop", ":context", ":core", + ":def_function", ":execute", ":function", ":graph_only_ops", @@ -380,3 +381,30 @@ cuda_py_test( "optonly", # The test is too slow in non-opt mode ], ) + +py_library( + name = "def_function", + srcs = ["def_function.py"], + srcs_version = "PY2AND3", + deps = [ + ":context", + ":function", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/training/checkpointable:base", + ], +) + +py_test( + name = "def_function_test", + srcs = ["def_function_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":def_function", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + ], +) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py new file mode 100644 index 0000000000..8dcacd5c99 --- /dev/null +++ b/tensorflow/python/eager/def_function.py @@ -0,0 +1,235 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Prototype decorator for defining graph-mode functions with eager semantics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training.checkpointable import base as checkpointable


class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
  """Variable which does not lift its initializer out of function context.

  Instances of this variable, when created, build a graph which runs their
  initializer inside a tf.cond(is_initialized) block.

  This can only be created inside a defun called from (eventually) eager
  mode. That is, non-function-building graphs are not supported.
  """

  def __init__(self,  # pylint: disable=super-init-not-called
               initial_value=None,
               trainable=True,
               caching_device=None,
               name=None,
               dtype=None,
               constraint=None,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
        to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
    # Must be called while tracing a function: creation in plain eager mode
    # (no surrounding graph) and legacy graph mode are both rejected.
    if context.executing_eagerly():
      raise RuntimeError(
          "UnliftedInitializerVariable should not be created "
          "outside of functions.")
    with ops.init_scope():
      if not context.executing_eagerly():
        raise RuntimeError(
            "UnliftedInitializerVariable does not support legacy graph mode.")
    self._in_graph_mode = False
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, checkpointable.CheckpointInitialValue):
      self._maybe_initialize_checkpointable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    self._save_slice_info = None
    self._initial_value = None
    self._initializer_op = None
    self._is_initialized_op = None
    self._graph_element = None
    self._cached_value = None
    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph. Guaranteed to be the same as the eager graph_key.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as name:
      # pylint: disable=protected-access
      with ops.init_scope():
        assert context.executing_eagerly()
        shared_name = ops._name_from_scope_name(name)
        shared_name = "%s_%d" % (shared_name, ops.uid())
      # Use attr_scope and device(None) to simulate the behavior of
      # colocate_with when the variable we want to colocate with doesn't
      # yet exist.
      with ops.name_scope("Initializer"), ops.device(None):
        initial_value = ops.convert_to_tensor(
            initial_value() if init_from_fn else initial_value,
            name="initial_value", dtype=dtype)
      # The handle is created in the eager context (init_scope) so it outlives
      # any single traced function and can be shared across retraces.
      with ops.init_scope():
        self._handle = resource_variable_ops.eager_safe_variable_handle(
            shape=initial_value.get_shape(),
            dtype=initial_value.dtype.base_dtype,
            shared_name=shared_name,
            name=name,
            graph_mode=False)
      self._shape = initial_value.shape
      self._unique_id = shared_name
      self._handle_name = shared_name + ":0"
      self._dtype = initial_value.dtype.base_dtype
      self._constraint = constraint
      assert initial_value is not None
      def assign_fn():
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          resource_variable_ops.assign_variable_op(
              self._handle,
              initial_value,
              name=n)
        # Returning values to keep tf.cond happy.
        return ops.convert_to_tensor(1)
      def not_assign_fn():
        return ops.convert_to_tensor(0)
      # Note: this cond is always guaranteed to run because we're inside a defun
      # which will insert automatic control dependencies.
      control_flow_ops.cond(
          resource_variable_ops.var_is_initialized_op(self._handle),
          not_assign_fn, assign_fn)

    # After the handle has been created, set up a way to clean it up when
    # executing eagerly. We'll hold the only reference to the deleter, so that
    # when this object is garbage collected the deleter will be too. This
    # means ResourceVariables can be part of reference cycles without those
    # cycles being uncollectable.
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._handle, handle_device=self._handle.device)
    self._cached_shape_as_list = None


def _defun_with_scope(scope, fn):
  """Returns a defun of `fn` whose body runs under variable creator `scope`."""

  def wrapped_fn(*args, **kwds):
    with variable_scope.variable_creator_scope(scope):
      return fn(*args, **kwds)

  return function.defun(wrapped_fn)


def def_function(fn):
  """Defines a function as per the "functions, not sessions" document.

  The first call traces `fn` while capturing variable creations as
  UnliftedInitializerVariables; subsequent calls dispatch to a retrace of `fn`
  in which variable creation is forbidden, guarded (when variables exist) by a
  cond that falls back to the initializing trace until all variables report
  initialized.
  """

  # Wrapping the values in lists to bypass python's lack of way to mutate
  # symbols from an outer scope.
  first_call = [True]
  function_to_call = []

  # TODO(apassos) represent this as an object and not as a closure.
  def decorated_fn(*args, **kwds):
    """Graph function for fn."""
    if not first_call[0]:
      return function_to_call[0](*args, **kwds)

    first_call[0] = False
    created_variables = []

    def variable_creator_scope(unused_next_creator, **kwds):
      """Creates UnliftedInitializerVariables and saves references to them."""
      v = UnliftedInitializerVariable(**kwds)
      created_variables.append(v)
      return v

    first_graph_function = _defun_with_scope(variable_creator_scope, fn)

    # Force the definition of the function for these arguments
    first_concrete = first_graph_function.get_concrete_function(*args, **kwds)

    def invalid_creator_scope(*unused_args, **unused_kwds):
      """Disables variable creation."""
      raise ValueError(
          "def_function-decorated function tried to create "
          "variables on second call.")

    second_graph_function = _defun_with_scope(invalid_creator_scope, fn)

    function_to_call.append(second_graph_function)
    if not created_variables:
      # Note: this retracing might be unnecessary, but running the function
      # forever in the scope which disallows variable creation is safer than not
      # doing so.
      return second_graph_function(*args, **kwds)

    def fn_with_cond(*inner_args, **inner_kwds):
      """Conditionally runs initialization if it's needed."""
      # Accumulate the predicate with logical_and rather than Python's `and`:
      # inside this defun the per-variable checks are symbolic Tensors, and
      # `tensor and tensor` would invoke Tensor.__bool__, which raises in
      # graph mode as soon as there is more than one created variable.
      condition = ops.convert_to_tensor(True)
      for variable in created_variables:
        condition = math_ops.logical_and(
            condition,
            resource_variable_ops.var_is_initialized_op(variable.handle))
      # We want to call second_graph_function if possible because it avoids
      # recomputing potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: second_graph_function(*inner_args, **inner_kwds),
          lambda: first_concrete(*inner_args, **inner_kwds))

    return function.defun(fn_with_cond)(*args, **kwds)

  return decorated_fn
# ==============================================================================
"""Tests for the def_function prototype decorator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class DefFunctionTest(test.TestCase):
  """Exercises variable-capture semantics of def_function."""

  def testNoVariables(self):
    """A stateless function is traced once and simply executed."""

    @def_function.def_function
    def fn(x):
      return 2 * x

    result = fn(constant_op.constant(4.0))
    self.assertAllEqual(result, 8.0)

  def testFailIfVariablesAreCreatedMoreThanOnce(self):
    """Unconditionally creating a variable per call is rejected."""

    @def_function.def_function
    def fn(x):
      return variables.Variable(1.0) + x

    with self.assertRaises(ValueError):
      fn(1.0)

  def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
    """Rejection holds even when strong references keep the variables alive."""
    created = []

    @def_function.def_function
    def fn(x):
      created.append(variables.Variable(1.0))
      return created[-1] + x

    with self.assertRaises(ValueError):
      fn(1.0)

  def testCorrectVariableCreation(self):
    """A variable created only on the first call is reused afterwards."""
    holder = []

    @def_function.def_function
    def fn(x):
      if not holder:
        holder.append(variables.Variable(2.0))
      return holder[0] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
    self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)

  def testVariableInitializerNotConstant(self):
    """Initializers may depend on the traced function's arguments."""
    holder = []

    @def_function.def_function
    def fn(x):
      if not holder:
        holder.append(variables.Variable(2.0 * x))
      return holder[0] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
    self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()