aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-26 12:16:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-26 12:20:19 -0700
commitb29b839215fa9bf5a00ca97e19673cfa5f780314 (patch)
tree26355f17565a9e49bf5e107493b1de51a32f25a6 /tensorflow/compiler/xla/service/shape_inference.h
parentf97fd78f7ef585215d13b39980319b8cad13ddd3 (diff)
[XLA] Map API change to enable mapping over an arbitrary set of dimensions.
PiperOrigin-RevId: 170090055
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 379feef5e4..d5d497176d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -78,7 +78,8 @@ class ShapeInference {
// to the given operand shapes.
static StatusOr<Shape> InferMapShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply);
+ const ProgramShape& to_apply,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
// Infers the shape produced by InferBatchNormTraining with the given
// operands.