aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-10-06 10:04:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-06 10:09:36 -0700
commit5c0a6bdfeb1848b0146a36706d921dde06ba160a (patch)
treee549be74d1f90165865102536d45cc1b4a2a75a0 /tensorflow/compiler/tf2xla
parent262f22f9eeee1ee00a9a92318d9a567a25c76696 (diff)
[XLA] Add base and window dilation support to ReduceWindow
PiperOrigin-RevId: 216041507
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc21
-rw-r--r--tensorflow/compiler/tf2xla/kernels/scan_ops.cc3
-rw-r--r--tensorflow/compiler/tf2xla/ops/xla_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/python/xla.py6
4 files changed, 30 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index 8102faad28..8eee5b1299 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel {
std::vector<int64> window_dimensions;
std::vector<int64> window_strides;
+ std::vector<int64> base_dilations;
+ std::vector<int64> window_dilations;
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
"window_dimensions", &window_dimensions));
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
&window_strides));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations",
+ &base_dilations));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dilations", &window_dilations));
const int rank = input_shape.dims();
OP_REQUIRES(context, rank == window_dimensions.size(),
@@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel {
"The size of window_strides must be equal to the input "
"rank (",
window_strides.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == base_dilations.size(),
+ errors::InvalidArgument(
+ "The size of base_dilations must be equal to the input "
+ "rank (",
+ base_dilations.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_dilations.size(),
+ errors::InvalidArgument(
+ "The size of window_dilations must be equal to the input "
+ "rank (",
+ window_dilations.size(), " vs. ", rank, ")"));
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel {
xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
context->Input(0), context->Input(1), *reducer.computation,
- window_dimensions, window_strides, padding);
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding);
context->SetOutput(0, output);
}
@@ -115,6 +132,8 @@ class ReduceWindowOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("XlaReduceWindow")
.CompileTimeConstInput("window_dimensions")
.CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("base_dilations")
+ .CompileTimeConstInput("window_dilations")
.CompileTimeConstInput("padding"),
ReduceWindowOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index ab094d7dd1..57afd608de 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -104,7 +104,8 @@ class ScanOp : public XlaOpKernel {
}
auto output = xla::ReduceWindowWithGeneralPadding(
XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init,
- *reducer, window_dims, window_strides, padding);
+ *reducer, window_dims, window_strides,
+ /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
output =
XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0));
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 557911553d..bd2c0a5ee8 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -283,6 +283,8 @@ REGISTER_OP("XlaReduceWindow")
.Input("init_value: T")
.Input("window_dimensions: Tindices")
.Input("window_strides: Tindices")
+ .Input("base_dilations: Tindices")
+ .Input("window_dilations: Tindices")
.Input("padding: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index bc7924c371..5e86b5d8ec 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -320,6 +320,8 @@ def reduce_window(operand,
reducer,
window_dimensions,
window_strides=None,
+ base_dilations=None,
+ window_dilations=None,
padding=None,
name=None):
"""Wraps the XLA ReduceWindow operator.
@@ -343,12 +345,16 @@ def reduce_window(operand,
A tensor that represents the output of the reduce_window operator.
"""
window_strides = window_strides or [1] * len(window_dimensions)
+ base_dilations = base_dilations or [1] * len(window_dimensions)
+ window_dilations = window_dilations or [1] * len(window_dimensions)
padding = padding or [(0, 0)] * len(window_dimensions)
return gen_xla_ops.xla_reduce_window(
input=operand,
init_value=init,
window_dimensions=window_dimensions,
window_strides=window_strides,
+ base_dilations=base_dilations,
+ window_dilations=window_dilations,
padding=padding,
computation=reducer,
name=name)