diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 14506d65c4..bb8dd3ac90 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -20,7 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -32,6 +33,8 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } // Unless BuildFinalizer is overridden the reduction has no |