aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--tensorflow/python/ops/math_grad.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 55dd0c0e0d..e2ee9e4fe4 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -52,14 +52,14 @@ def _SumGrad(op, grad):
if axes is not None:
rank = len(input_0_shape)
if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
- if context.in_graph_mode():
- new_shape = [1] * rank
- else:
+ if context.executing_eagerly():
ctx = context.context()
new_shape = ctx.ones_rank_cache().get(rank)
if new_shape is None:
new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
ctx.ones_rank_cache().put(rank, new_shape)
+ else:
+ new_shape = [1] * rank
grad = array_ops.reshape(grad, new_shape)
# If shape is not fully defined (but rank is), we use Shape.
if None not in input_0_shape:
@@ -997,7 +997,7 @@ def _SparseMatMulGrad(op, grad):
op.inputs[0]: op.get_attr("a_is_sparse"),
op.inputs[1]: op.get_attr("b_is_sparse"),
# Use heuristic to figure out if grad might be sparse
- grad: context.in_graph_mode() and (grad.op.type == "ReluGrad")
+ grad: not context.executing_eagerly() and (grad.op.type == "ReluGrad")
}
def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):