diff options
Diffstat (limited to 'tensorflow/core/util/mkl_util.h')
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 50 |
1 files changed, 42 insertions, 8 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index dffc965b14..90b6533690 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -42,6 +42,7 @@ limitations under the License. #ifndef INTEL_MKL_ML #include "mkldnn.hpp" +#include "tensorflow/core/lib/core/stringpiece.h" using mkldnn::engine; using mkldnn::memory; @@ -712,15 +713,48 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor, return output_tensor; } #else +using mkldnn::stream; +template <typename T> class MklDnnData; + template <typename T> inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor, const MklDnnShape& mkl_shape) { Tensor output_tensor; - TensorShape output_shape; - - TF_CHECK_OK( - Status(error::Code::UNIMPLEMENTED, "Unimplemented conversion function")); - + try { + if (!mkl_shape.IsMklTensor()) + return mkl_tensor; // return input since it is already TF tensor + + TensorShape output_shape = mkl_shape.GetTfShape();; + + // Allocate output tensor. + context->allocate_temp(DataTypeToEnum<T>::v(), + output_shape, &output_tensor); + + auto cpu_engine = engine(engine::cpu, 0); + MklDnnData<T> input(&cpu_engine); + + // Get Mkl layout of input tensor. + auto input_mkl_md = mkl_shape.GetMklLayout(); + auto output_tf_md = mkl_shape.GetTfLayout(); + auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine); + input.SetUsrMem(input_mkl_md, &mkl_tensor); + + // reorder + if (input.IsReorderNeeded(output_tf_pd)) { + std::vector<primitive> net; + CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, &output_tensor, &net), + true); + stream(stream::kind::eager).submit(net).wait(); + } else { + // If not, just forward input tensor to output tensor. + CHECK(output_tensor.CopyFrom(mkl_tensor, output_shape)); + } + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + LOG(FATAL) << "Operation received an exception: " << error_msg; + } return output_tensor; } #endif @@ -1843,7 +1877,7 @@ class FactoryKeyCreator { template <typename T> void AddAsKey(const T data) { auto buffer = reinterpret_cast<const char *>(&data); - Append(absl::string_view(buffer, sizeof(T))); + Append(StringPiece(buffer, sizeof(T))); } std::string GetKey() { @@ -1854,8 +1888,8 @@ class FactoryKeyCreator { string key_; const char delimiter = 'x'; const int kMaxKeyLength = 256; - void Append(absl::string_view s) { - key_.append(string(s)); + void Append(StringPiece s) { + key_.append(s.ToString()); key_.append(1, delimiter); } }; |