diff options
author | David Majnemer <majnemer@google.com> | 2018-10-06 10:04:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-06 10:09:36 -0700 |
commit | 5c0a6bdfeb1848b0146a36706d921dde06ba160a (patch) | |
tree | e549be74d1f90165865102536d45cc1b4a2a75a0 /tensorflow/compiler/tf2xla | |
parent | 262f22f9eeee1ee00a9a92318d9a567a25c76696 (diff) |
[XLA] Add base and window dilation support to ReduceWindow
PiperOrigin-RevId: 216041507
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc | 21 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/scan_ops.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/ops/xla_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/python/xla.py | 6 |
4 files changed, 30 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 8102faad28..8eee5b1299 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel { std::vector<int64> window_dimensions; std::vector<int64> window_strides; + std::vector<int64> base_dilations; + std::vector<int64> window_dilations; OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( "window_dimensions", &window_dimensions)); OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations", + &base_dilations)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dilations", &window_dilations)); const int rank = input_shape.dims(); OP_REQUIRES(context, rank == window_dimensions.size(), @@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel { "The size of window_strides must be equal to the input " "rank (", window_strides.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == base_dilations.size(), + errors::InvalidArgument( + "The size of base_dilations must be equal to the input " + "rank (", + base_dilations.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_dilations.size(), + errors::InvalidArgument( + "The size of window_dilations must be equal to the input " + "rank (", + window_dilations.size(), " vs. ", rank, ")")); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel { xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), *reducer.computation, - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); context->SetOutput(0, output); } @@ -115,6 +132,8 @@ class ReduceWindowOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaReduceWindow") .CompileTimeConstInput("window_dimensions") .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("base_dilations") + .CompileTimeConstInput("window_dilations") .CompileTimeConstInput("padding"), ReduceWindowOp); diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ab094d7dd1..57afd608de 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -104,7 +104,8 @@ class ScanOp : public XlaOpKernel { } auto output = xla::ReduceWindowWithGeneralPadding( XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, - *reducer, window_dims, window_strides, padding); + *reducer, window_dims, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding); output = XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 557911553d..bd2c0a5ee8 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -283,6 +283,8 @@ REGISTER_OP("XlaReduceWindow") .Input("init_value: T") .Input("window_dimensions: Tindices") .Input("window_strides: Tindices") + .Input("base_dilations: Tindices") + .Input("window_dilations: Tindices") .Input("padding: Tindices") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index bc7924c371..5e86b5d8ec 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -320,6 +320,8 @@ def reduce_window(operand, reducer, window_dimensions, window_strides=None, + base_dilations=None, + window_dilations=None, padding=None, name=None): """Wraps the XLA ReduceWindow operator. @@ -343,12 +345,16 @@ def reduce_window(operand, A tensor that represents the output of the reduce_window operator. """ window_strides = window_strides or [1] * len(window_dimensions) + base_dilations = base_dilations or [1] * len(window_dimensions) + window_dilations = window_dilations or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) return gen_xla_ops.xla_reduce_window( input=operand, init_value=init, window_dimensions=window_dimensions, window_strides=window_strides, + base_dilations=base_dilations, + window_dilations=window_dilations, padding=padding, computation=reducer, name=name) |