aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_reshape_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-06 17:57:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-06 18:01:27 -0700
commite722358e7e96dd2aa20d7e2c56336e76845daa6a (patch)
treea74960670ce4bacad0909fc913097bcc3e27ed18 /tensorflow/core/kernels/mkl_reshape_op.cc
parentf8a43f9d63ce90f10852d69e40fbb9fe849fc190 (diff)
Merge changes from github.
END_PUBLIC --- Commit 607816029 authored by Eugene Brevdo<ebrevdo@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Extended ScratchSpace to expose its underlying scratch tensor object. PiperOrigin-RevId: 167649551 --- Commit db43fe68e authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add fast math attributes to all generated methods when fast math enabled. RELNOTES: n/a PiperOrigin-RevId: 167646637 --- Commit aebe8cc6f authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Call HloComputation.Accept instead of HloInstruction.Accept to get all instructions profiled. RELNOTES: n/a PiperOrigin-RevId: 167640259 --- Commit 0ab137cd8 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 167604306 PiperOrigin-RevId: 167800256
Diffstat (limited to 'tensorflow/core/kernels/mkl_reshape_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_reshape_op.cc68
1 files changed, 50 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index b3763f17bc..5e98582475 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -43,30 +43,26 @@ class MklReshapeOp : public OpKernel {
OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
errors::InvalidArgument("sizes input must be 1-D, not shape ",
sizes.shape().DebugString()));
- const int64 num_dims = sizes.NumElements();
// Compute the output shape. Determine product of specified
// dimensions, and find the index of the unspecified one.
TensorShape shape;
int64 product = 1;
int unknown_index = -1;
- auto vec_size = sizes.flat<int32>();
- for (int d = 0; d < num_dims; ++d) {
- const int32 size = vec_size(d);
- if (size == -1) {
- OP_REQUIRES(
- context, unknown_index == -1,
- errors::InvalidArgument("only one input size may be -1, not both ",
- unknown_index, " and ", d));
- unknown_index = d;
- shape.AddDim(1);
- } else {
- OP_REQUIRES(context, size >= 0,
- errors::InvalidArgument(
- "size ", d, " must be non-negative, not ", size));
- shape.AddDim(size);
- product *= size;
- }
+ switch (sizes.dtype()) {
+ case DT_INT32:
+ OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product,
+ &unknown_index, &shape));
+ break;
+ case DT_INT64:
+ OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product,
+ &unknown_index, &shape));
+ break;
+ default:
+ context->CtxFailure(errors::InvalidArgument(
+ "desired shape must be a DT_INT32 or DT_INT64 vector, not a ",
+ DataTypeString(sizes.dtype())));
+ return;
}
if (unknown_index != -1) {
OP_REQUIRES(
@@ -132,6 +128,35 @@ class MklReshapeOp : public OpKernel {
CopyTfTensorInToOutWithShape(context, 0, 0, shape);
}
}
+
+ private:
+ template <typename Tshape>
+ Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
+ TensorShape* shape) {
+ *product = 1;
+ *unknown_index = -1;
+ const int64 num_dims = sizes.NumElements();
+ auto Svec = sizes.flat<Tshape>();
+ for (int d = 0; d < num_dims; ++d) {
+ const Tshape size = Svec(d);
+ if (size == -1) {
+ if (*unknown_index != -1) {
+ return errors::InvalidArgument(
+ "Only one input size may be -1, not both ", *unknown_index,
+ " and ", d);
+ }
+ *unknown_index = d;
+ shape->AddDim(1);
+ } else if (size < 0) {
+ return errors::InvalidArgument("Size ", d,
+ " must be non-negative, not ", size);
+ } else {
+ shape->AddDim(size);
+ (*product) *= size;
+ }
+ }
+ return Status::OK();
+ }
};
#define REGISTER_MKL_CPU(T) \
@@ -141,6 +166,13 @@ class MklReshapeOp : public OpKernel {
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tshape") \
.Label(mkl_op_registry::kMklOpLabel), \
+ MklReshapeOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklReshape") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tshape") \
+ .Label(mkl_op_registry::kMklOpLabel), \
MklReshapeOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU