diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_identity_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_identity_op.cc | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc new file mode 100644 index 0000000000..e138cc2e95 --- /dev/null +++ b/tensorflow/core/kernels/mkl_identity_op.cc @@ -0,0 +1,63 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/array_ops.cc. +#ifdef INTEL_MKL + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +#include "tensorflow/core/util/mkl_util.h" +#include "third_party/mkl/include/mkl_dnn.h" +#include "third_party/mkl/include/mkl_dnn_types.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; +template <typename Device, typename T> +class MklIdentityOp : public OpKernel { + public: + explicit MklIdentityOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + MklShape mkl_shape_input; + GetMklShape(context, 0, &mkl_shape_input); + bool input_in_mkl_format = mkl_shape_input.IsMklTensor(); + + if (input_in_mkl_format) { + ForwarMklTensorInToOut(context, 0, 0); + } else { + FowardTfTensorInToOut(context, 0, 0); + } + } + + bool IsExpensive() override { return false; } +}; + +#define REGISTER_MKL_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklIdentity") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklIdentityOp<CPUDevice, T>); \ + +TF_CALL_float(REGISTER_MKL_CPU); +#undef REGISTER_MKL_CPU +} // namespace tensorflow +#endif // INTEL_MKL |