aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r--tensorflow/python/ops/math_ops.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 9ebf251574..f0f438a33d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1087,7 +1087,14 @@ def matmul(a, b,
with ops.op_scope([a, b], name, "MatMul") as name:
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
- if a.dtype == dtypes.float32 and (a_is_sparse or b_is_sparse):
+ sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
+ use_sparse_matmul = (a.dtype in sparse_matmul_types and
+ b.dtype in sparse_matmul_types and
+ (a_is_sparse or b_is_sparse))
+ if dtypes.bfloat16 in (a.dtype, b.dtype):
+ # matmul currently doesn't handle bfloat16 inputs.
+ use_sparse_matmul = True
+ if use_sparse_matmul:
return sparse_matmul(a, b,
transpose_a=transpose_a,
transpose_b=transpose_b,