aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/mkl_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/mkl_util.h')
-rw-r--r--tensorflow/core/util/mkl_util.h50
1 files changed, 42 insertions, 8 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index dffc965b14..90b6533690 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -42,6 +42,7 @@ limitations under the License.
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
+#include "tensorflow/core/lib/core/stringpiece.h"
using mkldnn::engine;
using mkldnn::memory;
@@ -712,15 +713,48 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
return output_tensor;
}
#else
+using mkldnn::stream;
+template <typename T> class MklDnnData;
+
template <typename T>
inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
const MklDnnShape& mkl_shape) {
Tensor output_tensor;
- TensorShape output_shape;
-
- TF_CHECK_OK(
- Status(error::Code::UNIMPLEMENTED, "Unimplemented conversion function"));
-
+ try {
+ if (!mkl_shape.IsMklTensor())
+ return mkl_tensor; // return input since it is already TF tensor
+
+ TensorShape output_shape = mkl_shape.GetTfShape();;
+
+ // Allocate output tensor.
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ output_shape, &output_tensor);
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> input(&cpu_engine);
+
+ // Get Mkl layout of input tensor.
+ auto input_mkl_md = mkl_shape.GetMklLayout();
+ auto output_tf_md = mkl_shape.GetTfLayout();
+ auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
+ input.SetUsrMem(input_mkl_md, &mkl_tensor);
+
+ // reorder
+ if (input.IsReorderNeeded(output_tf_pd)) {
+ std::vector<primitive> net;
+ CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, &output_tensor, &net),
+ true);
+ stream(stream::kind::eager).submit(net).wait();
+ } else {
+ // If not, just forward input tensor to output tensor.
+ CHECK(output_tensor.CopyFrom(mkl_tensor, output_shape));
+ }
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ LOG(FATAL) << "Operation received an exception: " << error_msg;
+ }
return output_tensor;
}
#endif
@@ -1843,7 +1877,7 @@ class FactoryKeyCreator {
template <typename T>
void AddAsKey(const T data) {
auto buffer = reinterpret_cast<const char *>(&data);
- Append(absl::string_view(buffer, sizeof(T)));
+ Append(StringPiece(buffer, sizeof(T)));
}
std::string GetKey() {
@@ -1854,8 +1888,8 @@ class FactoryKeyCreator {
string key_;
const char delimiter = 'x';
const int kMaxKeyLength = 256;
- void Append(absl::string_view s) {
- key_.append(string(s));
+ void Append(StringPiece s) {
+ key_.append(s.ToString());
key_.append(1, delimiter);
}
};