author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-27 14:28:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-04-27 14:31:07 -0700
commit a4dbc33512adb3705345b093a0aafec151e7e32d (patch)
tree   7922e1cf71084d592008299f99a0cf509a0714bc
parent 6da711a50c3ef98aebacd6a909596a0f74b783e1 (diff)
If two identical functions are given different grad funcs, they should be
named differently. Otherwise, tf.gradients gets confused.

PiperOrigin-RevId: 194593519
-rw-r--r--  tensorflow/python/framework/function.py        37
-rw-r--r--  tensorflow/python/framework/function_test.py  114
-rw-r--r--  tensorflow/python/ops/gradients_impl.py        32
3 files changed, 129 insertions, 54 deletions
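
To make the scenario in the commit message concrete, here is a minimal, self-contained sketch adapted from the testSameFunctionDifferentGrads test added in this change. It assumes a TF 1.x graph-mode environment; the names FooDefault, FooCustom, and CustomGrad are illustrative only and not part of the change. Before this fix, two Defun-wrapped functions with identical bodies could be registered under the same generated function name even if their gradient overrides differed, so tf.gradients could resolve the wrong gradient.

# Sketch only: same body (x * 2), different gradient overrides. With this
# change, the function with an explicit grad_func gets a distinct generated
# name (suffixed with the grad func's name), so each call resolves its own
# gradient.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl


@function.Defun(dtypes.float32, dtypes.float32)
def CustomGrad(x, dy):
  # Used as grad_func below: dx = x + dy instead of the default dy * 2.
  return x + dy


@function.Defun(dtypes.float32)
def FooDefault(x):
  # Relies on the default gradient of the ops in the body: dx = dy * 2.
  return x * 2


@function.Defun(dtypes.float32, grad_func=CustomGrad)
def FooCustom(x):
  # Identical body, but an explicit gradient function.
  return x * 2


g = ops.Graph()
with g.as_default():
  x = constant_op.constant(100.)
  dx_default, = gradients_impl.gradients(FooDefault(x), [x])  # expect 2.
  dx_custom, = gradients_impl.gradients(FooCustom(x), [x])    # expect 101.

Running both gradient computations in a session should yield 2. and 101., matching the assertions in the new test below (which additionally covers python_grad_func, expecting 50.).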
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 2432ab378c..e7f9e590af 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -353,8 +353,10 @@ class _DefinedFunction(object):
raise ValueError("Function can not return None.")
# Ensures each output is a Tensor in the function graph.
outputs = [ops.convert_to_tensor(t) for t in outputs]
- outputs = [temp_graph.capture(t) if t.graph is not temp_graph else t
- for t in outputs]
+ outputs = [
+ temp_graph.capture(t) if t.graph is not temp_graph else t
+ for t in outputs
+ ]
self._extra_inputs = temp_graph.extra_inputs
inputs.extend(temp_graph.extra_args)
# pylint: disable=protected-access
@@ -362,9 +364,13 @@ class _DefinedFunction(object):
# pylint: enable=protected-access
# Extra kwargs are treated as attrs on the function def.
- base_func_name = self._func_name or _get_func_name(self._func)
- kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
- **self._extra_kwargs)
+ if self._func_name:
+ base_func_name = self._func_name
+ else:
+ base_func_name = _get_func_name(self._func)
+ if self._grad_func:
+ base_func_name += ("_%s" % self._grad_func.name)
+ kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)
if not temp_graph._c_graph: # pylint: disable=protected-access
# Build the FunctionDef
@@ -503,6 +509,12 @@ class _DefinedFunction(object):
self.add_to_graph(ops.get_default_graph())
args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
ret, op = _call(self._signature, *args, **kwargs)
+
+ # Set a hidden attr in 'op' so that gradients_impl can refer back
+ # to this _DefinedFunction instance to access python_grad_func.
+ assert isinstance(op, ops.Operation)
+ setattr(op, "__defun", self)
+
if self._shape_func is not None:
shapes = self._shape_func(op)
if len(shapes) != len(op.outputs):
@@ -591,12 +603,11 @@ class _OverloadedFunction(object):
# _OverloadedFunction. We need to instantiate it with the
# right input types.
output_types = [
- dtypes.DType(_.type)
- for _ in defined._signature.output_arg # pylint: disable=protected-access
+ dtypes.DType(_.type) for _ in defined._signature.output_arg # pylint: disable=protected-access
]
# pylint: disable=protected-access
- defined._grad_func = self._grad_func.instantiate(
- input_types + output_types)
+ defined._grad_func = self._grad_func.instantiate(input_types +
+ output_types)
# pylint: enable=protected-access
self._overload[key] = defined
return defined
@@ -833,8 +844,8 @@ def _call(sig, *inputs, **kwargs):
ValueError: if the arguments are invalid.
"""
if len(inputs) != len(sig.input_arg):
- raise ValueError("Expected number of arguments: %d, received: %d" %
- (len(sig.input_arg), len(inputs)))
+ raise ValueError("Expected number of arguments: %d, received: %d" % (len(
+ sig.input_arg), len(inputs)))
name = kwargs.pop("name", None)
g = ops.get_default_graph()
func_name = sig.name
@@ -950,8 +961,8 @@ def _from_library(lib):
fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
]
if not ready:
- raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n"
- + str(lib))
+ raise ValueError(
+ "FunctionDefLibrary contains cyclic gradient functions!\n" + str(lib))
# function name -> _DefinedFunction
initialized = {}
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 594596ec1e..a5c19f189e 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -136,7 +136,8 @@ class FunctionTest(test.TestCase):
def testTooManyOutputNames(self):
@function.Defun(
- dtypes.float32, func_name="MyIdentity",
+ dtypes.float32,
+ func_name="MyIdentity",
out_names=["my_result1", "my_result2"])
def MyIdentityFunc(a):
return a
@@ -239,10 +240,11 @@ class FunctionTest(test.TestCase):
inp = np.array([-1, 1, 2, -2], dtype=np.float32)
feed = {x: inp}
- cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
- optimizer_options=config_pb2.OptimizerOptions(
- opt_level=config_pb2.OptimizerOptions.L1,
- do_function_inlining=True)))
+ cfg = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L1,
+ do_function_inlining=True)))
with session.Session(graph=g, config=cfg) as sess:
out, = sess.run(dx, feed)
self.assertAllClose(1 - np.square(np.tanh(inp)), out)
@@ -334,18 +336,20 @@ class FunctionTest(test.TestCase):
y = Foo(x)
dx, = gradients_impl.gradients(y, [x])
- cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
- optimizer_options=config_pb2.OptimizerOptions(
- opt_level=config_pb2.OptimizerOptions.L0,
- do_common_subexpression_elimination=True,
- do_function_inlining=True,
- do_constant_folding=True)))
+ cfg = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L0,
+ do_common_subexpression_elimination=True,
+ do_function_inlining=True,
+ do_constant_folding=True)))
with self.test_session(graph=g, config=cfg):
self.assertAllClose(y.eval(), 6.)
self.assertAllClose(dx.eval(), 2.)
def _testZNoDepOnY(self, use_const_grad_ys):
+
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(x, y): # pylint: disable=unused-argument
return x * 2
@@ -775,9 +779,9 @@ class FunctionTest(test.TestCase):
@function.Defun()
def Foo():
- return control_flow_ops.while_loop(lambda i: i < 10,
- lambda i: i + x,
+ return control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + x,
[0])
+
y = Foo()
with self.test_session(graph=g) as sess:
@@ -790,9 +794,8 @@ class FunctionTest(test.TestCase):
@function.Defun(dtypes.bool)
def Foo(pred):
- return control_flow_ops.cond(pred,
- lambda: x,
- lambda: x + 1)
+ return control_flow_ops.cond(pred, lambda: x, lambda: x + 1)
+
y = Foo(True)
z = Foo(False)
@@ -945,6 +948,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(len(f.signature.input_arg), 3)
def testGradientWithIntegerFunctionArgument(self):
+
@function.Defun(dtypes.int32, dtypes.float32)
def Foo(t, x):
return x[t]
@@ -959,8 +963,7 @@ class FunctionTest(test.TestCase):
x = np.zeros((2,)).astype(np.float32)
with session.Session(graph=g) as sess:
self.assertAllClose(
- np.array([1.0, 0.0]).astype(np.float32),
- sess.run(dinp, {inp: x}))
+ np.array([1.0, 0.0]).astype(np.float32), sess.run(dinp, {inp: x}))
def testFunctionMarkedStateful(self):
@@ -1073,6 +1076,60 @@ class FunctionTest(test.TestCase):
sess.run(var.initializer)
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
+ def testSameFunctionDifferentGrads(self):
+
+ def PartOne(x):
+
+ # Default grad is dx = dy * 2
+ @function.Defun(dtypes.float32)
+ def Foo(x):
+ return x * 2
+
+ return Foo(x)
+
+ def PartTwo(x):
+
+ @function.Defun(dtypes.float32, dtypes.float32)
+ def Bar(x, dy):
+ return x + dy # crazy backprop
+
+ @function.Defun(dtypes.float32, grad_func=Bar)
+ def Foo(x):
+ return x * 2
+
+ return Foo(x)
+
+ def PartThree(x):
+
+ def Bar(op, dy):
+ return op.inputs[0] * dy / 2 # crazy backprop
+
+ @function.Defun(dtypes.float32, python_grad_func=Bar)
+ def Foo(x):
+ return x * 2
+
+ return Foo(x)
+
+ g = ops.Graph()
+ with g.as_default():
+ x = constant_op.constant(100.)
+ x0 = x
+ y0 = PartOne(x0)
+ dx0, = gradients_impl.gradients(ys=[y0], xs=[x0])
+ x1 = x
+ y1 = PartTwo(x1)
+ dx1, = gradients_impl.gradients(ys=[y1], xs=[x1])
+ x2 = x
+ y2 = PartThree(x2)
+ dx2, = gradients_impl.gradients(ys=[y2], xs=[x2])
+
+ with self.test_session(graph=g) as sess:
+ v0, v1, v2 = sess.run([dx0, dx1, dx2])
+
+ self.assertAllEqual(v0, 2.)
+ self.assertAllEqual(v1, 101.)
+ self.assertAllEqual(v2, 50.)
+
@test_util.with_c_shapes
class FunctionsFromProtos(test.TestCase):
@@ -1271,9 +1328,10 @@ class FunctionsFromProtos(test.TestCase):
@function.Defun(dtypes.int32, experimental_tag="tag_value")
def FunctionWithAttr(i):
return array_ops.identity(i)
+
self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
- self.assertEqual(
- FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value")
+ self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s,
+ b"tag_value")
@test_util.with_c_shapes
@@ -1401,7 +1459,8 @@ class UnrollLSTMTest(test.TestCase):
return Loop(cell, weights, inp)
cell = function.Defun(dtypes.float32, dtypes.float32, dtypes.float32,
- dtypes.float32)(cell)
+ dtypes.float32)(
+ cell)
if mode == "cell":
# Just represent the LSTM as a function.
return Loop(cell, weights, inp)
@@ -1500,12 +1559,13 @@ class FunctionInlineControlTest(test.TestCase):
def testFoo(self):
dtype = dtypes.float32
- cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
- optimizer_options=config_pb2.OptimizerOptions(
- opt_level=config_pb2.OptimizerOptions.L0,
- do_common_subexpression_elimination=True,
- do_function_inlining=True,
- do_constant_folding=True)))
+ cfg = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L0,
+ do_common_subexpression_elimination=True,
+ do_function_inlining=True,
+ do_constant_folding=True)))
cell_func_call_pattern = re.compile(r"Cell[^/]*\(")
for noinline in [False, True]:
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 581ba7de48..1448151fef 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -256,21 +256,21 @@ def _DefaultGradYs(grad_ys,
continue
if y.dtype.is_floating or y.dtype.is_integer:
if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
- raise TypeError("Gradient type %s generated for real or "
- "integer-valued tensor %s with type %s must be "
- "real or integer" %
- (dtypes.as_dtype(grad_y.dtype).name, y,
- dtypes.as_dtype(y.dtype).name))
+ raise TypeError(
+ "Gradient type %s generated for real or "
+ "integer-valued tensor %s with type %s must be "
+ "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
+ dtypes.as_dtype(y.dtype).name))
elif y.dtype.is_complex:
if not grad_y.dtype.is_complex:
- raise TypeError("Gradient type %s generated for complex-valued "
- "tensor %s with type %s must be real" %
- (dtypes.as_dtype(grad_y.dtype).name, y,
- dtypes.as_dtype(y.dtype).name))
+ raise TypeError(
+ "Gradient type %s generated for complex-valued "
+ "tensor %s with type %s must be real" % (dtypes.as_dtype(
+ grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
else:
- raise TypeError("Tensor %s with type %s must be numeric "
- "to obtain a default gradient" %
- (y, dtypes.as_dtype(y.dtype).name))
+ raise TypeError(
+ "Tensor %s with type %s must be numeric "
+ "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
# Create a grad_y tensor in the name scope of the gradient.
# Required for TensorArrays to identify which gradient call a
# grad_y value is coming from.
@@ -605,15 +605,19 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
loop_state.ExitGradWhileContext(op, before=True)
grad_fn = None
- # pylint: disable=protected-access
func_call = None
+ # pylint: disable=protected-access
is_func_call = ops.get_default_graph()._is_function(op.type)
+ # pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op._id not in stop_ops):
if is_func_call:
func_call = ops.get_default_graph()._get_function(op.type)
+ # Note that __defun is not set if the graph is
+ # imported. If it's set, we prefer to access the original
+ # defun.
+ func_call = getattr(op, "__defun", func_call)
grad_fn = func_call.python_grad_func
- # pylint: enable=protected-access
else:
# A grad_fn must be defined, either as a function or as None
# for ops that do not have gradients.