aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-09-28 12:49:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 12:58:17 -0700
commit6d02ee8e581bf5211f362b80175122e3782fb37a (patch)
tree5fa8251b451e23926ff8d3a74091468a1ad23f4a
parent5e66d25666aad9fa76ed8cc0d2b162db76ea0cc8 (diff)
Simplify batch_dot logic
Remove dead logical branch. PiperOrigin-RevId: 214980627
-rw-r--r--tensorflow/python/keras/backend.py8
1 files changed, 2 insertions, 6 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 4589c821e5..584facc859 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -1511,12 +1511,8 @@ def batch_dot(x, y, axes=None):
out = math_ops.reduce_sum(
math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
else:
- if axes is not None:
- adj_x = None if axes[0] == ndim(x) - 1 else True
- adj_y = True if axes[1] == ndim(y) - 1 else None
- else:
- adj_x = None
- adj_y = None
+ adj_x = None if axes[0] == ndim(x) - 1 else True
+ adj_y = True if axes[1] == ndim(y) - 1 else None
out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
if diff:
if x_ndim > y_ndim: