diff options
author | 2017-09-06 17:57:04 -0700 | |
---|---|---|
committer | 2017-09-06 18:01:27 -0700 | |
commit | e722358e7e96dd2aa20d7e2c56336e76845daa6a (patch) | |
tree | a74960670ce4bacad0909fc913097bcc3e27ed18 /tensorflow/core/kernels/mkl_reshape_op.cc | |
parent | f8a43f9d63ce90f10852d69e40fbb9fe849fc190 (diff) |
Merge changes from github.
END_PUBLIC
---
Commit 607816029 authored by Eugene Brevdo<ebrevdo@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Extended ScratchSpace to expose its underlying scratch tensor object.
PiperOrigin-RevId: 167649551
---
Commit db43fe68e authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add fast math attributes to all generated methods when fast math enabled.
RELNOTES: n/a
PiperOrigin-RevId: 167646637
---
Commit aebe8cc6f authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Call HloComputation.Accept instead of HloInstruction.Accept to get all instructions profiled.
RELNOTES: n/a
PiperOrigin-RevId: 167640259
---
Commit 0ab137cd8 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
BEGIN_PUBLIC
Automated g4 rollback of changelist 167604306
PiperOrigin-RevId: 167800256
Diffstat (limited to 'tensorflow/core/kernels/mkl_reshape_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_reshape_op.cc | 68 |
1 files changed, 50 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc index b3763f17bc..5e98582475 100644 --- a/tensorflow/core/kernels/mkl_reshape_op.cc +++ b/tensorflow/core/kernels/mkl_reshape_op.cc @@ -43,30 +43,26 @@ class MklReshapeOp : public OpKernel { OP_REQUIRES(context, IsLegacyVector(sizes.shape()), errors::InvalidArgument("sizes input must be 1-D, not shape ", sizes.shape().DebugString())); - const int64 num_dims = sizes.NumElements(); // Compute the output shape. Determine product of specified // dimensions, and find the index of the unspecified one. TensorShape shape; int64 product = 1; int unknown_index = -1; - auto vec_size = sizes.flat<int32>(); - for (int d = 0; d < num_dims; ++d) { - const int32 size = vec_size(d); - if (size == -1) { - OP_REQUIRES( - context, unknown_index == -1, - errors::InvalidArgument("only one input size may be -1, not both ", - unknown_index, " and ", d)); - unknown_index = d; - shape.AddDim(1); - } else { - OP_REQUIRES(context, size >= 0, - errors::InvalidArgument( - "size ", d, " must be non-negative, not ", size)); - shape.AddDim(size); - product *= size; - } + switch (sizes.dtype()) { + case DT_INT32: + OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product, + &unknown_index, &shape)); + break; + case DT_INT64: + OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product, + &unknown_index, &shape)); + break; + default: + context->CtxFailure(errors::InvalidArgument( + "desired shape must be a DT_INT32 or DT_INT64 vector, not a ", + DataTypeString(sizes.dtype()))); + return; } if (unknown_index != -1) { OP_REQUIRES( @@ -132,6 +128,35 @@ class MklReshapeOp : public OpKernel { CopyTfTensorInToOutWithShape(context, 0, 0, shape); } } + + private: + template <typename Tshape> + Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index, + TensorShape* shape) { + *product = 1; + *unknown_index = -1; + const int64 num_dims = sizes.NumElements(); + auto Svec = sizes.flat<Tshape>(); + for (int d = 0; d < num_dims; ++d) { + const Tshape size = Svec(d); + if (size == -1) { + if (*unknown_index != -1) { + return errors::InvalidArgument( + "Only one input size may be -1, not both ", *unknown_index, + " and ", d); + } + *unknown_index = d; + shape->AddDim(1); + } else if (size < 0) { + return errors::InvalidArgument("Size ", d, + " must be non-negative, not ", size); + } else { + shape->AddDim(size); + (*product) *= size; + } + } + return Status::OK(); + } }; #define REGISTER_MKL_CPU(T) \ @@ -141,6 +166,13 @@ class MklReshapeOp : public OpKernel { .TypeConstraint<T>("T") \ .TypeConstraint<int32>("Tshape") \ .Label(mkl_op_registry::kMklOpLabel), \ + MklReshapeOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("_MklReshape") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint<T>("T") \ + .TypeConstraint<int64>("Tshape") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklReshapeOp<CPUDevice, T>); TF_CALL_float(REGISTER_MKL_CPU); #undef REGISTER_MKL_CPU |