author    Asim Shankar <ashankar@google.com>  2017-11-28 21:10:16 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-11-28 21:14:21 -0800
commit 05f57851d4657ec6c09a454b157cf17d89d0cfe2 (patch)
tree   105715c5c46481151153e37156b128bcfb5bf1e3
parent a55ee58c89d5bf6a8cd70b706dc3af90d7d6efc4 (diff)
Bugfixes: Gather's gradient for 2+ dimensional indices with eager execution,
and the shape inference function for the VariableShape operation.

PiperOrigin-RevId: 177262783
-rw-r--r--  tensorflow/core/ops/resource_variable_ops.cc     5
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py  21
-rw-r--r--  tensorflow/python/training/momentum_test.py     41
3 files changed, 39 insertions(+), 28 deletions(-)
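
For context: the gather bug only appeared when the lookup indices had rank >= 2 under eager execution. A minimal repro sketch, assuming the TF 1.4-era tf.contrib.eager API (the tfe module path and implicit_gradients are assumptions, not part of this change):

    # Hypothetical repro, assuming the TF 1.4-era contrib eager API.
    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()
    v = tfe.Variable(tf.ones([2, 2]))

    def loss():
      # Rank-2 indices ([[1]]) hit the old eager branch of _GatherGrad, which
      # built an IndexedSlices without flattening grad and indices first.
      return tf.reduce_sum(tf.nn.embedding_lookup(v, [[1]]))

    # implicit_gradients differentiates loss w.r.t. every variable it touches.
    grads_and_vars = tfe.implicit_gradients(loss)()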
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index cdfbec85cf..bf9e673e8e 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -204,7 +204,10 @@ Status VariableShapeShapeFn(InferenceContext* c) {
   auto* handle_data = c->input_handle_shapes_and_types(0);
   if (handle_data == nullptr || handle_data->empty()) {
     return errors::InvalidArgument("Handle doesn't have shape information.");
   }
-  c->set_output(0, (*handle_data)[0].shape);
+  ShapeHandle var_shape = (*handle_data)[0].shape;
+  int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
+                                       : InferenceContext::kUnknownDim;
+  c->set_output(0, c->Vector(rank));
   return Status::OK();
 }
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 343e38f960..652bfa1ebc 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -887,26 +887,19 @@ def _ReadGrad(_, grad):
 def _GatherGrad(op, grad):
   """Gradient for gather op."""
   # Build appropriately shaped IndexedSlices
-  # Walk graph back until the original handle is found.
-  # TODO(apassos): more robust way of getting the shape.
-  # TODO(apassos): implement this for EAGER mode.
-  if context.in_eager_mode():
-    dense_shape = gen_resource_variable_ops.variable_shape(op.inputs[0])
-    return (ops.IndexedSlices(grad,
-                              op.inputs[1],
-                              dense_shape=dense_shape),
-            None)
   handle = op.inputs[0]
-  while handle.op.type != "VarHandleOp":
-    handle = handle.op.inputs[0]
-  params_shape = ops.convert_to_tensor(
-      tensor_shape.TensorShape(handle.op.get_attr("shape")))
   indices = op.inputs[1]
+  if context.in_graph_mode():
+    # Walk graph back until the original handle is found.
+    # TODO(apassos): implement this for EAGER mode.
+    while handle.op.type != "VarHandleOp":
+      handle = handle.op.inputs[0]
+  params_shape = gen_resource_variable_ops.variable_shape(handle)
   size = array_ops.expand_dims(array_ops.size(indices), 0)
   values_shape = array_ops.concat([size, params_shape[1:]], 0)
   values = array_ops.reshape(grad, values_shape)
   indices = array_ops.reshape(indices, size)
-  return [ops.IndexedSlices(values, indices, params_shape), None]
+  return (ops.IndexedSlices(values, indices, params_shape), None)


 def _to_proto_fn(v, export_scope=None):
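
The rewritten gradient now computes params_shape via the VariableShape op in both graph and eager modes, and flattens grad and indices so the pair forms a valid IndexedSlices even for 2+ dimensional indices. A NumPy mirror of that reshape arithmetic (the array values here are illustrative only):

    # NumPy mirror of the reshape logic in _GatherGrad.
    import numpy as np

    params_shape = np.array([2, 2])  # as returned by variable_shape(handle)
    indices = np.array([[1]])        # rank-2 indices, shape [1, 1]
    grad = np.ones([1, 1, 2])        # shape = indices.shape + params_shape[1:]

    size = np.expand_dims(indices.size, 0)                      # [1]
    values_shape = np.concatenate([size, params_shape[1:]], 0)  # [1, 2]
    values = grad.reshape(values_shape)   # one row of slice values per index
    flat_indices = indices.reshape(size)  # [1]
    # (values, flat_indices, params_shape) now describe a valid IndexedSlices.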
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 7268b3abc9..6865513b0e 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -234,23 +234,38 @@ class MomentumOptimizerTest(test.TestCase):
         self.assertAllClose(var0_np, var0.eval())
         self.assertAllClose(var1_np, var1.eval())

+  @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
-        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+      var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+
+      # pylint: disable=cell-var-from-loop
+      def loss():
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
-        loss = pred * pred
-        sgd_op = momentum_lib.MomentumOptimizer(
-            learning_rate=1.0, momentum=0.0).minimize(loss)
-        variables.global_variables_initializer().run()
-        # Fetch params to validate initial values
-        self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
-        # Run 1 step of sgd
-        sgd_op.run()
-        # Validate updated params
-        self.assertAllCloseAccordingToType(
-            [[-111, -138]], var0.eval())
+        return pred * pred
+      # pylint: enable=cell-var-from-loop
+
+      opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
+      sgd_op = opt.minimize(loss if context.in_eager_mode() else loss())
+      self.evaluate(variables.global_variables_initializer())
+      # Run 1 step of sgd
+      self.evaluate(sgd_op)
+      # Validate updated params
+      self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))
+
+  @test_util.run_in_graph_and_eager_modes(reset_test=True)
+  def testMinimizeWith2DIndicesForEmbeddingLookup(self):
+    var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
+
+    def loss():
+      return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))
+
+    opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
+    sgd_op = opt.minimize(loss if context.in_eager_mode() else loss())
+    self.evaluate(variables.global_variables_initializer())
+    self.evaluate(sgd_op)
+    self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0))

   def testTensorLearningRateAndMomentum(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
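
Both new tests rely on one idiom to run under run_in_graph_and_eager_modes: eager-mode Optimizer.minimize expects a callable it can re-evaluate while recording gradients, whereas graph mode expects the already-built loss Tensor, hence "loss if context.in_eager_mode() else loss()". A condensed sketch of that idiom outside the test harness (import paths as in the diff; the toy loss is an assumption):

    # Sketch of the dual-mode minimize idiom used by the new tests.
    from tensorflow.python.eager import context
    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.training import momentum as momentum_lib

    var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]])

    def loss():
      return math_ops.reduce_sum(var0 * var0)

    opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
    # Eager: minimize re-evaluates the callable while taping gradients;
    # graph: minimize consumes the Tensor built by calling loss() once.
    sgd_op = opt.minimize(loss if context.in_eager_mode() else loss())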