aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/conv_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc12
1 files changed, 4 insertions, 8 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 5d41fc708a..48ac4867ed 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_op.h"
@@ -96,14 +97,9 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
// Create a M sized linspace and an M*N sized linspace that will be
// broadcasted into perpendicular dimensions and compared.
- xla::XlaOp input_feature_iota;
- // DT_INT32 Iota will always return status::OK().
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature,
- &input_feature_iota));
- xla::XlaOp expanded_feature_iota;
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
- input_feature * depthwise_multiplier,
- &expanded_feature_iota));
+ xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
+ xla::XlaOp expanded_feature_iota =
+ xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
// Divide the M*N sized linspace by the depthwise_multiplier to create
// [0 0 1 1 2 2] in the example in the function comment.