diff options
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 9ebf251574..f0f438a33d 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1087,7 +1087,14 @@ def matmul(a, b, with ops.op_scope([a, b], name, "MatMul") as name: a = ops.convert_to_tensor(a, name="a") b = ops.convert_to_tensor(b, name="b") - if a.dtype == dtypes.float32 and (a_is_sparse or b_is_sparse): + sparse_matmul_types = [dtypes.bfloat16, dtypes.float32] + use_sparse_matmul = (a.dtype in sparse_matmul_types and + b.dtype in sparse_matmul_types and + (a_is_sparse or b_is_sparse)) + if dtypes.bfloat16 in (a.dtype, b.dtype): + # matmul currently doesn't handle bfloat16 inputs. + use_sparse_matmul = True + if use_sparse_matmul: return sparse_matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b, |