author     Asim Shankar <ashankar@google.com>              2017-11-28 21:10:16 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-11-28 21:14:21 -0800
commit     05f57851d4657ec6c09a454b157cf17d89d0cfe2 (patch)
tree       105715c5c46481151153e37156b128bcfb5bf1e3
parent     a55ee58c89d5bf6a8cd70b706dc3af90d7d6efc4 (diff)
Bugfixes: Gather's gradient for 2+ dimensional indices with eager execution,
and the shape inference function for the VariableShape operation.
PiperOrigin-RevId: 177262783
-rw-r--r--  tensorflow/core/ops/resource_variable_ops.cc   |  5
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 21
-rw-r--r--  tensorflow/python/training/momentum_test.py    | 41
3 files changed, 39 insertions(+), 28 deletions(-)
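
Before the diffs themselves, a quick illustration of the gather-gradient bug being fixed. This is a repro sketch only, mirroring the new test added to momentum_test.py at the bottom of this patch; it assumes eager execution has already been enabled through the contrib preview of this era (the exact enable_eager_execution entry point moved between releases) and imports the same internal modules the test uses.

```python
# A minimal repro sketch, not part of the commit. Assumes eager execution
# is already on, e.g. via the contrib preview of this era (an assumption):
#   import tensorflow.contrib.eager as tfe
#   tfe.enable_eager_execution()
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import momentum as momentum_lib

var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))

def loss():
  # 2-D indices ([[1]]) are the failing case: the old eager branch of
  # _GatherGrad returned grad and the raw indices without flattening them
  # into the 1-D form that IndexedSlices expects.
  return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))

# momentum=0.0 makes this a single SGD step; after the fix only row 1 of
# var0 is touched, leaving [[1., 1.], [0., 0.]].
opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
opt.minimize(loss)
```

In the old code, the eager branch of _GatherGrad returned grad and op.inputs[1] as-is, so 2-D indices were never flattened; the rewritten version below reshapes both values and indices in graph and eager modes alike.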
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index cdfbec85cf..bf9e673e8e 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -204,7 +204,10 @@ Status VariableShapeShapeFn(InferenceContext* c) {
   if (handle_data == nullptr || handle_data->empty()) {
     return errors::InvalidArgument("Handle doesn't have shape information.");
   }
-  c->set_output(0, (*handle_data)[0].shape);
+  ShapeHandle var_shape = (*handle_data)[0].shape;
+  int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
+                                       : InferenceContext::kUnknownDim;
+  c->set_output(0, c->Vector(rank));
   return Status::OK();
 }
 
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 343e38f960..652bfa1ebc 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -887,26 +887,19 @@ def _ReadGrad(_, grad):
 def _GatherGrad(op, grad):
   """Gradient for gather op."""
   # Build appropriately shaped IndexedSlices
-  # Walk graph back until the original handle is found.
-  # TODO(apassos): more robust way of getting the shape.
-  # TODO(apassos): implement this for EAGER mode.
-  if context.in_eager_mode():
-    dense_shape = gen_resource_variable_ops.variable_shape(op.inputs[0])
-    return (ops.IndexedSlices(grad,
-                              op.inputs[1],
-                              dense_shape=dense_shape),
-            None)
   handle = op.inputs[0]
-  while handle.op.type != "VarHandleOp":
-    handle = handle.op.inputs[0]
-  params_shape = ops.convert_to_tensor(
-      tensor_shape.TensorShape(handle.op.get_attr("shape")))
   indices = op.inputs[1]
+  if context.in_graph_mode():
+    # Walk graph back until the original handle is found.
+    # TODO(apassos): implement this for EAGER mode.
+    while handle.op.type != "VarHandleOp":
+      handle = handle.op.inputs[0]
+  params_shape = gen_resource_variable_ops.variable_shape(handle)
   size = array_ops.expand_dims(array_ops.size(indices), 0)
   values_shape = array_ops.concat([size, params_shape[1:]], 0)
   values = array_ops.reshape(grad, values_shape)
   indices = array_ops.reshape(indices, size)
-  return [ops.IndexedSlices(values, indices, params_shape), None]
+  return (ops.IndexedSlices(values, indices, params_shape), None)
 
 
 def _to_proto_fn(v, export_scope=None):
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 7268b3abc9..6865513b0e 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -234,23 +234,38 @@ class MomentumOptimizerTest(test.TestCase):
       self.assertAllClose(var0_np, var0.eval())
       self.assertAllClose(var1_np, var1.eval())
 
+  @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
-        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+      var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+
+      # pylint: disable=cell-var-from-loop
+      def loss():
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
-        loss = pred * pred
-        sgd_op = momentum_lib.MomentumOptimizer(
-            learning_rate=1.0, momentum=0.0).minimize(loss)
-        variables.global_variables_initializer().run()
-        # Fetch params to validate initial values
-        self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
-        # Run 1 step of sgd
-        sgd_op.run()
-        # Validate updated params
-        self.assertAllCloseAccordingToType(
-            [[-111, -138]], var0.eval())
+        return pred * pred
+      # pylint: enable=cell-var-from-loop
+
+      opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
+      sgd_op = opt.minimize(loss if context.in_eager_mode() else loss())
+      self.evaluate(variables.global_variables_initializer())
+      # Run 1 step of sgd
+      self.evaluate(sgd_op)
+      # Validate updated params
+      self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))
+
+  @test_util.run_in_graph_and_eager_modes(reset_test=True)
+  def testMinimizeWith2DIndiciesForEmbeddingLookup(self):
+    var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
+
+    def loss():
+      return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))
+
+    opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
+    sgd_op = opt.minimize(loss if context.in_eager_mode() else loss())
+    self.evaluate(variables.global_variables_initializer())
+    self.evaluate(sgd_op)
+    self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0))
 
   def testTensorLearningRateAndMomentum(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
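
A closing note on the C++ hunk: VariableShape's output is a 1-D integer tensor holding the variable's shape, so the op's inferred output shape must be a vector of length rank (of unknown length when the rank is unknown), not the variable's own shape, which is what the old set_output call reported. A graph-mode sketch of the corrected inference, calling the same internal wrapper the Python diff uses (driving these internal modules directly is an assumption for illustration, not part of the commit):

```python
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable([[1.0, 2.0]])  # rank 2, shape [1, 2]
s = gen_resource_variable_ops.variable_shape(v.handle)
# The op's runtime output is the vector [1, 2]. With the fixed shape
# function, static inference reports s.shape as (2,), a vector of length
# rank(v); previously it wrongly reported (1, 2), the variable's shape.
print(s.shape)
```

This corrected static shape is also what lets the unified _GatherGrad above slice params_shape[1:] with a fully inferable result in both graph and eager modes.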