diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_graph_util.h')
-rw-r--r-- | tensorflow/core/graph/mkl_graph_util.h | 179 |
1 files changed, 90 insertions, 89 deletions
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index cb32d64334..880e4e712e 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -21,107 +21,108 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { -// Since our ops are going to produce and also consume N addition tensors -// (Mkl) for N Tensorflow tensors, we can have following different -// orderings among these 2N tensors. -// -// E.g., for Tensorflow tensors A, B, and C, our ops will produce and -// consume A_m, B_m, and C_m additionally. -// -// INTERLEAVED: in this case 2N tensors are interleaved. So for above -// example, the ordering looks like: A, A_m, B, B_m, C, C_m. -// -// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed -// by N Mkl tensors. So for above example, the ordering looks -// like: A, B, C, A_m, B_m, C_m -// -// Following APIs map index of original Tensorflow tensors to their -// appropriate position based on selected ordering. For contiguous ordering, -// we need to know the total number of tensors (parameter total). -// -typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; -// NOTE: Currently, we use contiguous ordering. If you change this, then you -// would need to change Mkl op definitions in nn_ops.cc. -static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; + // Since our ops are going to produce and also consume N addition tensors + // (Mkl) for N Tensorflow tensors, we can have following different + // orderings among these 2N tensors. + // + // E.g., for Tensorflow tensors A, B, and C, our ops will produce and + // consume A_m, B_m, and C_m additionally. + // + // INTERLEAVED: in this case 2N tensors are interleaved. So for above + // example, the ordering looks like: A, A_m, B, B_m, C, C_m. + // + // CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed + // by N Mkl tensors. So for above example, the ordering looks + // like: A, B, C, A_m, B_m, C_m + // + // Following APIs map index of original Tensorflow tensors to their + // appropriate position based on selected ordering. For contiguous ordering, + // we need to know the total number of tensors (parameter total). + // + typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; + // NOTE: Currently, we use contiguous ordering. If you change this, then you + // would need to change Mkl op definitions in nn_ops.cc. + static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; -// Get index of MetaData tensor from index 'n' of Data tensor. -inline int DataIndexToMetaDataIndex(int n, int total_tensors) { - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - // For interleaved ordering, Mkl tensor follows immediately after - // Tensorflow tensor. - return n + 1; - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away. - return n + total_tensors / 2; + // Get index of MetaData tensor from index 'n' of Data tensor. + inline int DataIndexToMetaDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + // For interleaved ordering, Mkl tensor follows immediately after + // Tensorflow tensor. + return n + 1; + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away. + return n + total_tensors / 2; + } } -} -int inline GetTensorDataIndex(int n, int total_tensors) { - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - return 2 * n; // index corresponding to nth input/output tensor - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - return n; - } -} + int inline GetTensorDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + return 2 * n; // index corresponding to nth input/output tensor + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + return n; + } + } -int inline GetTensorMetaDataIndex(int n, int total_tensors) { - // Get index for TensorData first and then use mapping function - // to get TensorMetaData index from TensorData index. - int tidx = GetTensorDataIndex(n, total_tensors); - return DataIndexToMetaDataIndex(tidx, total_tensors); -} + int inline GetTensorMetaDataIndex(int n, int total_tensors) { + // Get index for TensorData first and then use mapping function + // to get TensorMetaData index from TensorData index. + int tidx = GetTensorDataIndex(n, total_tensors); + return DataIndexToMetaDataIndex(tidx, total_tensors); + } namespace mkl_op_registry { -static const char* kMklOpLabel = "MklOp"; -static const char* kMklOpLabelPattern = "label='MklOp'"; - -// Get the name of Mkl op from original TensorFlow op -// We prefix 'Mkl' to the original op to get Mkl op. -inline string GetMklOpName(const string& name) { - // Prefix that we add to Tensorflow op name to construct Mkl op name. - const char* const kMklOpPrefix = "_Mkl"; - return string(kMklOpPrefix) + name; -} + static const char* kMklOpLabel = "MklOp"; + static const char* kMklOpLabelPattern = "label='MklOp'"; -// Check whether opname with type T is registered as MKL-compliant. -// -// @input: name of the op -// @input: T datatype to be used for checking op -// @return: true if opname is registered as Mkl op; false otherwise -static inline bool IsMklOp(const std::string& op_name, DataType T) { - string kernel = KernelsRegisteredForOp(op_name); - bool result = - kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); - if (result) { - VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel; + // Get the name of Mkl op from original TensorFlow op + // We prefix 'Mkl' to the original op to get Mkl op. + inline string GetMklOpName(const string& name) { + // Prefix that we add to Tensorflow op name to construct Mkl op name. + const char* const kMklOpPrefix = "_Mkl"; + return string(kMklOpPrefix) + name; } - return result; -} -// Check whether opname with type T is registered as MKL-compliant and -// is element-wise. -// -// @input: name of the op -// @input: T datatype to be used for checking op -// @return: true if opname is registered as element-wise Mkl op; -// false otherwise -static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) { - if (!IsMklOp(op_name, T)) { - return false; + // Check whether opname with type T is registered as MKL-compliant. + // + // @input: name of the op + // @input: T datatype to be used for checking op + // @return: true if opname is registered as Mkl op; false otherwise + static inline bool IsMklOp(const std::string& op_name, DataType T) { + string kernel = KernelsRegisteredForOp(op_name); + bool result = + kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); + if (result) { + VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel; + } + return result; } - bool result = (0 == op_name.compare(GetMklOpName("Add")) || - 0 == op_name.compare(GetMklOpName("Sub")) || - 0 == op_name.compare(GetMklOpName("Mul")) || - 0 == op_name.compare(GetMklOpName("Maximum")) || - 0 == op_name.compare(GetMklOpName("SquaredDifference"))); + // Check whether opname with type T is registered as MKL-compliant and + // is element-wise. + // + // @input: name of the op + // @input: T datatype to be used for checking op + // @return: true if opname is registered as element-wise Mkl op; + // false otherwise + static inline bool IsMklElementWiseOp(const std::string& op_name, + DataType T) { + if (!IsMklOp(op_name, T)) { + return false; + } - VLOG(1) << "mkl_op_registry::" << op_name - << " is elementwise MKL op: " << result; - return result; -} + bool result = (0 == op_name.compare(GetMklOpName("Add")) || + 0 == op_name.compare(GetMklOpName("Sub")) || + 0 == op_name.compare(GetMklOpName("Mul")) || + 0 == op_name.compare(GetMklOpName("Maximum")) || + 0 == op_name.compare(GetMklOpName("SquaredDifference"))); + + VLOG(1) << "mkl_op_registry::" << op_name + << " is elementwise MKL op: " << result; + return result; + } } // namespace mkl_op_registry } // namespace tensorflow #endif // INTEL_MKL |