aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 14506d65c4..bb8dd3ac90 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -20,7 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@@ -32,6 +33,8 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
}
// Unless BuildFinalizer is overridden the reduction has no