aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-10-20 14:49:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-20 14:56:07 -0700
commit37fd951790d7ad27c679c925c28b01ca73875738 (patch)
treeed585b7348da216e11af4b174d687701b147c42d
parentd1e7382af7b99dad2455d9b7eaf34989a75f26d2 (diff)
Simplifies capturing code in graph_callable to use recent function improvements.
PiperOrigin-RevId: 172937003
-rw-r--r--tensorflow/python/eager/graph_callable.py57
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py15
2 files changed, 14 insertions, 58 deletions
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 0ec83636a0..7f7a8c4a88 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -45,28 +45,6 @@ def _default_initializer(name, shape, dtype):
return initializer[0]
-class _VariableFromResource(resource_variable_ops.ResourceVariable):
- """Variable object from a preexisting resource.
-
- Required because the ResourceVariable constructor creates the resource handle,
- and here we want to use a preexisting one.
- """
-
- def __init__(self, resource, dtype, name, shape):
- self._handle = resource
- self._graph_shape = tensor_shape.as_shape(shape)
- self._handle_device = resource.device
- self._handle_name = name
- self._cached_value = None
- self._initializer_op = None
- self._caching_device = None
- self._dtype = dtype
- self._constraint = None
- self._in_graph_mode = context.in_graph_mode()
- if self._in_graph_mode:
- self._graph_element = self.read_value()
-
-
class _CapturedVariable(object):
"""Variable captured by graph_callable.
@@ -137,17 +115,11 @@ class _VariableCapturingScope(object):
trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name
partitioner=None, validate_shape=True,
use_resource=None):
- del getter, regularizer, partitioner, validate_shape, use_resource
- del collections, initializer, trainable, reuse, caching_device
+ del getter, regularizer, partitioner, validate_shape, use_resource, dtype
+ del collections, initializer, trainable, reuse, caching_device, shape,
assert name in self.variables
v = self.variables[name]
- v.placeholder = array_ops.placeholder(dtype=dtypes.resource, shape=shape)
- # TODO(apassos) remove the need for this by correctly dealing with shape
- # inference.
- v.placeholder._handle_data = v.variable.handle._handle_data # pylint: disable=protected-access
- return _VariableFromResource(
- v.placeholder, dtype=dtypes.as_dtype(dtype), name=name,
- shape=v.shape)
+ return v.variable
scope = variable_scope.get_variable_scope()
with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
@@ -181,14 +153,12 @@ class _VariableCapturingScope(object):
v = _CapturedVariable(name, initializer, shape, dtype, trainable)
self.variables[name] = v
- graph_mode_resource = resource_variable_ops.var_handle_op(
- shared_name=name, shape=shape, dtype=dtype)
+ graph_mode_resource = v.variable.handle
if initializer is None:
initializer = _default_initializer(name, shape, dtype)
resource_variable_ops.assign_variable_op(
graph_mode_resource, initializer(shape, dtype))
- return _VariableFromResource(
- graph_mode_resource, dtype, name, shape=v.shape)
+ return v.variable
scope = variable_scope.get_variable_scope()
with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
@@ -220,13 +190,6 @@ class _FunctionObject(function._GraphModeFunction): # pylint: disable=protected
def variables(self):
return [x.variable for x in self._variables]
- def __call__(self, *args, **kwds):
- kwds.pop("want_gradients", False)
- if kwds:
- raise ValueError("graph_callable functions do not take keyword args")
- values = [x.variable.handle for x in self._variables]
- return super(_FunctionObject, self).__call__(*(values + list(args)))
-
class _InitializingFunctionObject(object):
"""Responsible for deciding which version of func-to-object to call.
@@ -318,7 +281,8 @@ def _graph_callable_internal(func, shape_and_dtypes):
# This graph will store both the initialization and the call version of the
# wrapped function. It will later be used by the backprop code to build the
# backprop graph, if necessary.
- tmp_graph = tf_ops.Graph()
+ captures = {}
+ tmp_graph = function.CapturingGraph(captures)
# Inherit the container from the original graph to create resources at user
# expected containers. Also inherits the container prefix, since this is
# used for error checking when isolating Eager execution (the container
@@ -342,7 +306,6 @@ def _graph_callable_internal(func, shape_and_dtypes):
# variables. As a side-effect this will populate the variable capturing
# scope's view of which variables exist.
variable_captures = _VariableCapturingScope()
- captures = {}
with variable_captures.initializing_scope(), function.capture_tensors(
captures):
func_outputs = func(*func_inputs)
@@ -366,7 +329,6 @@ def _graph_callable_internal(func, shape_and_dtypes):
sorted_variables = sorted(variable_captures.variables.values(),
key=lambda x: x.name)
- variable_placeholders = [x.placeholder for x in sorted_variables]
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
@@ -377,7 +339,6 @@ def _graph_callable_internal(func, shape_and_dtypes):
flat_inputs = [x for x in nest.flatten(func_inputs)
if isinstance(x, tf_ops.Tensor)]
placeholder_inputs = flat_inputs+ list(extra_placeholders)
- all_inputs = variable_placeholders + placeholder_inputs
func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
initializer_function_def = function.make_function_def(
@@ -407,13 +368,13 @@ def _graph_callable_internal(func, shape_and_dtypes):
captured_function_def = function.make_function_def(
tmp_graph,
capturing_operations,
- all_inputs,
+ placeholder_inputs,
capture_func_def_outputs)
function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access
captured_function_def)
captured_function = _FunctionObject(
sorted_variables,
- all_inputs,
+ placeholder_inputs,
extra_inputs,
captured_function_def,
tmp_graph,
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index c94ddb0627..71e1fb0297 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -26,7 +26,6 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
@@ -315,7 +314,7 @@ class ResourceVariable(variables.Variable):
self._handle_device = (
self._handle.device if self._in_graph_mode else
context.get_default_context().device_name)
- self._graph_shape = initial_value.get_shape()
+ self._shape = initial_value.get_shape()
else:
initial_value = initial_value()
with ops.name_scope("Initializer"):
@@ -330,7 +329,7 @@ class ResourceVariable(variables.Variable):
self._handle_device = (
self._handle.device if self._in_graph_mode else
context.get_default_context().device_name)
- self._graph_shape = initial_value.get_shape()
+ self._shape = initial_value.get_shape()
# pylint: enable=protected-access
# Or get the initial value from a Tensor or Python object.
@@ -355,7 +354,7 @@ class ResourceVariable(variables.Variable):
graph_mode=self._in_graph_mode)
self._handle_device = (self._handle.device if self._in_graph_mode else
context.get_default_context().device_name)
- self._graph_shape = initial_value.get_shape()
+ self._shape = initial_value.get_shape()
self._initial_value = initial_value if self._in_graph_mode else None
self._handle_name = handle_name + ":0"
@@ -422,7 +421,7 @@ class ResourceVariable(variables.Variable):
self._handle = g.as_graph_element(
ops.prepend_name_scope(
variable_def.variable_name, import_scope=import_scope))
- self._graph_shape = tensor_shape.TensorShape(
+ self._shape = tensor_shape.TensorShape(
self._handle.op.get_attr("shape"))
self._handle_device = self._handle.device
self._handle_name = self._handle.name
@@ -502,11 +501,7 @@ class ResourceVariable(variables.Variable):
@property
def shape(self):
"""The shape of this variable."""
- if self._in_graph_mode:
- return self._graph_shape
- return tensor_shape.TensorShape(
- tensor_util.constant_value(
- gen_resource_variable_ops.variable_shape(self._handle)))
+ return self._shape
@property
def create(self):