aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_graph_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/mkl_graph_util.h')
-rw-r--r--tensorflow/core/graph/mkl_graph_util.h179
1 files changed, 90 insertions, 89 deletions
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index cb32d64334..880e4e712e 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -21,107 +21,108 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
-// Since our ops are going to produce and also consume N addition tensors
-// (Mkl) for N Tensorflow tensors, we can have following different
-// orderings among these 2N tensors.
-//
-// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
-// consume A_m, B_m, and C_m additionally.
-//
-// INTERLEAVED: in this case 2N tensors are interleaved. So for above
-// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
-//
-// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
-// by N Mkl tensors. So for above example, the ordering looks
-// like: A, B, C, A_m, B_m, C_m
-//
-// Following APIs map index of original Tensorflow tensors to their
-// appropriate position based on selected ordering. For contiguous ordering,
-// we need to know the total number of tensors (parameter total).
-//
-typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
-// NOTE: Currently, we use contiguous ordering. If you change this, then you
-// would need to change Mkl op definitions in nn_ops.cc.
-static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
 + // Since our ops are going to produce and also consume N additional tensors
+ // (Mkl) for N Tensorflow tensors, we can have following different
+ // orderings among these 2N tensors.
+ //
+ // E.g., for Tensorflow tensors A, B, and C, our ops will produce and
+ // consume A_m, B_m, and C_m additionally.
+ //
+ // INTERLEAVED: in this case 2N tensors are interleaved. So for above
+ // example, the ordering looks like: A, A_m, B, B_m, C, C_m.
+ //
 + // CONTIGUOUS: in this case N Tensorflow tensors are contiguous followed
+ // by N Mkl tensors. So for above example, the ordering looks
+ // like: A, B, C, A_m, B_m, C_m
+ //
+ // Following APIs map index of original Tensorflow tensors to their
+ // appropriate position based on selected ordering. For contiguous ordering,
+ // we need to know the total number of tensors (parameter total).
+ //
+ typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
+ // NOTE: Currently, we use contiguous ordering. If you change this, then you
+ // would need to change Mkl op definitions in nn_ops.cc.
+ static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
-// Get index of MetaData tensor from index 'n' of Data tensor.
-inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
- if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
- // For interleaved ordering, Mkl tensor follows immediately after
- // Tensorflow tensor.
- return n + 1;
- } else {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
- return n + total_tensors / 2;
+ // Get index of MetaData tensor from index 'n' of Data tensor.
+ inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ // For interleaved ordering, Mkl tensor follows immediately after
+ // Tensorflow tensor.
+ return n + 1;
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
+ return n + total_tensors / 2;
+ }
}
-}
-int inline GetTensorDataIndex(int n, int total_tensors) {
- if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
- return 2 * n; // index corresponding to nth input/output tensor
- } else {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- return n;
- }
-}
+ int inline GetTensorDataIndex(int n, int total_tensors) {
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ return 2 * n; // index corresponding to nth input/output tensor
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ return n;
+ }
+ }
-int inline GetTensorMetaDataIndex(int n, int total_tensors) {
- // Get index for TensorData first and then use mapping function
- // to get TensorMetaData index from TensorData index.
- int tidx = GetTensorDataIndex(n, total_tensors);
- return DataIndexToMetaDataIndex(tidx, total_tensors);
-}
+ int inline GetTensorMetaDataIndex(int n, int total_tensors) {
+ // Get index for TensorData first and then use mapping function
+ // to get TensorMetaData index from TensorData index.
+ int tidx = GetTensorDataIndex(n, total_tensors);
+ return DataIndexToMetaDataIndex(tidx, total_tensors);
+ }
namespace mkl_op_registry {
-static const char* kMklOpLabel = "MklOp";
-static const char* kMklOpLabelPattern = "label='MklOp'";
-
-// Get the name of Mkl op from original TensorFlow op
-// We prefix 'Mkl' to the original op to get Mkl op.
-inline string GetMklOpName(const string& name) {
- // Prefix that we add to Tensorflow op name to construct Mkl op name.
- const char* const kMklOpPrefix = "_Mkl";
- return string(kMklOpPrefix) + name;
-}
+ static const char* kMklOpLabel = "MklOp";
+ static const char* kMklOpLabelPattern = "label='MklOp'";
-// Check whether opname with type T is registered as MKL-compliant.
-//
-// @input: name of the op
-// @input: T datatype to be used for checking op
-// @return: true if opname is registered as Mkl op; false otherwise
-static inline bool IsMklOp(const std::string& op_name, DataType T) {
- string kernel = KernelsRegisteredForOp(op_name);
- bool result =
- kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
- if (result) {
- VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
+ // Get the name of Mkl op from original TensorFlow op
+ // We prefix 'Mkl' to the original op to get Mkl op.
+ inline string GetMklOpName(const string& name) {
+ // Prefix that we add to Tensorflow op name to construct Mkl op name.
+ const char* const kMklOpPrefix = "_Mkl";
+ return string(kMklOpPrefix) + name;
}
- return result;
-}
-// Check whether opname with type T is registered as MKL-compliant and
-// is element-wise.
-//
-// @input: name of the op
-// @input: T datatype to be used for checking op
-// @return: true if opname is registered as element-wise Mkl op;
-// false otherwise
-static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
- if (!IsMklOp(op_name, T)) {
- return false;
+ // Check whether opname with type T is registered as MKL-compliant.
+ //
+ // @input: name of the op
+ // @input: T datatype to be used for checking op
+ // @return: true if opname is registered as Mkl op; false otherwise
+ static inline bool IsMklOp(const std::string& op_name, DataType T) {
+ string kernel = KernelsRegisteredForOp(op_name);
+ bool result =
+ kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
+ if (result) {
+ VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
+ }
+ return result;
}
- bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
- 0 == op_name.compare(GetMklOpName("Sub")) ||
- 0 == op_name.compare(GetMklOpName("Mul")) ||
- 0 == op_name.compare(GetMklOpName("Maximum")) ||
- 0 == op_name.compare(GetMklOpName("SquaredDifference")));
+ // Check whether opname with type T is registered as MKL-compliant and
+ // is element-wise.
+ //
+ // @input: name of the op
+ // @input: T datatype to be used for checking op
+ // @return: true if opname is registered as element-wise Mkl op;
+ // false otherwise
+ static inline bool IsMklElementWiseOp(const std::string& op_name,
+ DataType T) {
+ if (!IsMklOp(op_name, T)) {
+ return false;
+ }
- VLOG(1) << "mkl_op_registry::" << op_name
- << " is elementwise MKL op: " << result;
- return result;
-}
+ bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
+ 0 == op_name.compare(GetMklOpName("Sub")) ||
+ 0 == op_name.compare(GetMklOpName("Mul")) ||
+ 0 == op_name.compare(GetMklOpName("Maximum")) ||
+ 0 == op_name.compare(GetMklOpName("SquaredDifference")));
+
+ VLOG(1) << "mkl_op_registry::" << op_name
+ << " is elementwise MKL op: " << result;
+ return result;
+ }
} // namespace mkl_op_registry
} // namespace tensorflow
#endif // INTEL_MKL