diff options
author | 2017-09-26 12:16:20 -0700 | |
---|---|---|
committer | 2017-09-26 12:20:19 -0700 | |
commit | b29b839215fa9bf5a00ca97e19673cfa5f780314 (patch) | |
tree | 26355f17565a9e49bf5e107493b1de51a32f25a6 /tensorflow/compiler/xla/service/shape_inference.h | |
parent | f97fd78f7ef585215d13b39980319b8cad13ddd3 (diff) |
[XLA] Map API change to enable mapping over an arbitrary set of dimensions.
PiperOrigin-RevId: 170090055
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.h | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 379feef5e4..d5d497176d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -78,7 +78,8 @@ class ShapeInference { // to the given operand shapes. static StatusOr<Shape> InferMapShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, - const ProgramShape& to_apply); + const ProgramShape& to_apply, + tensorflow::gtl::ArraySlice<int64> dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. |