aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt2
-rw-r--r--tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc8
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc4
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.cc5
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc4
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc4
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc2
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc4
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_complex.cc10
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_real.cc9
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.cc4
-rw-r--r--tensorflow/core/kernels/gather_nd_op_cpu_impl.h6
-rw-r--r--tensorflow/core/kernels/matmul_op.cc8
-rw-r--r--tensorflow/core/kernels/mkl_batch_matmul_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_matmul_op.cc6
-rw-r--r--tensorflow/core/kernels/slice_op.cc14
-rw-r--r--tensorflow/core/kernels/transpose_op.cc10
-rw-r--r--tensorflow/core/util/port.cc4
-rw-r--r--tensorflow/tensorflow.bzl3
-rw-r--r--third_party/mkl/BUILD23
-rw-r--r--third_party/mkl/build_defs.bzl41
-rw-r--r--third_party/mkl_dnn/BUILD6
-rw-r--r--third_party/mkl_dnn/build_defs.bzl2
-rw-r--r--tools/bazel.rc5
24 files changed, 117 insertions, 69 deletions
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index ebcabb4223..c6d6f04168 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -353,7 +353,7 @@ endif()
# MKL Support
if (tensorflow_ENABLE_MKL_SUPPORT)
- add_definitions(-DINTEL_MKL -DEIGEN_USE_VML)
+ add_definitions(-DINTEL_MKL -DEIGEN_USE_VML -DENABLE_MKL)
include(mkl)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination)
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 2ed4f69f90..efd6185f8b 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -108,7 +108,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
EXPECT_EQ(2, shape.dim(0).size());
EXPECT_EQ(1, shape.dim(1).size());
if (node->name() == y->name()) {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// if MKL is used, it goes through various additional
// graph rewrite pass. In TF, everytime a graph pass
// happens, "constant" nodes are allocated
@@ -120,13 +120,13 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
EXPECT_EQ(29, cm->AllocationId(node, 0));
#else
EXPECT_EQ(21, cm->AllocationId(node, 0));
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
} else {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
EXPECT_EQ(30, cm->AllocationId(node, 0));
#else
EXPECT_EQ(22, cm->AllocationId(node, 0));
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
}
}
EXPECT_LE(0, cm->MaxExecutionTime(node));
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
index a67411cd2e..e08ab57638 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"
@@ -50,4 +50,4 @@ TEST(MKLBFCAllocatorTest, TestMaxLimit) {
} // namespace tensorflow
-#endif // INTEL_MKL
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 0fbc20b34b..8587d1783a 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -113,8 +113,11 @@ class MklCPUAllocatorFactory : public AllocatorFactory {
}
};
+#ifdef ENABLE_MKL
REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory);
+#endif // ENABLE_MKL
+
} // namespace
-#endif
+#endif // INTEL_MKL
} // namespace tensorflow
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index f5b0105862..37b88f1728 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -977,7 +977,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -3150,7 +3152,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index e8bac847e5..f42a4ee98b 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -3586,4 +3586,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index b67a321fc1..8c5ffd71a3 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -133,7 +133,9 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
+#endif // ENABLE_MKL
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
std::unique_ptr<Graph>* g, Edge* e) {
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index ebcb6de551..319437a801 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);
} // namespace
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 54c45bfe63..f48bd0c318 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -17,14 +17,18 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own complex64/128 kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
-#endif
+#endif // GOOGLE_CUDA
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 584b507c70..25ae795d8e 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -21,10 +21,15 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own float and double kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
+
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 980edffceb..8ad3b4d1fc 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -20,9 +20,9 @@ namespace tensorflow {
BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
DataType in)
: OpKernel(ctx) {
-#ifndef INTEL_MKL
+#if !defined(INTEL_MKL) || !defined(ENABLE_MKL)
OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
-#endif
+#endif // !INTEL_MKL || !ENABLE_MKL
}
void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index 277ee2be02..1c78de253e 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -114,7 +114,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// Eigen implementation below is not highly performant. gather_nd_generator
// does not seem to be called in parallel, leading to very poor performance.
// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
@@ -126,12 +126,12 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
const Eigen::array<Eigen::DenseIndex, 1> loc{i};
gather_nd_generator(loc);
}
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 79967aab38..4ad390a411 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -578,7 +578,7 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// MKL does not support half, bfloat16 and int32 types for
// matrix-multiplication, so register the kernel to use default Eigen based
@@ -606,9 +606,9 @@ TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
-#endif
+#endif // INTEL_MKL_DNN_ONLY
-#else // INTEL MKL
+#else // INTEL_MKL && ENABLE_MKL
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
@@ -616,7 +616,7 @@ TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 0841395dc3..bc135de11e 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -223,10 +223,12 @@ class BatchMatMulMkl : public OpKernel {
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulMkl<CPUDevice, TYPE>)
+#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+#endif // ENABLE_MKL
} // end namespace tensorflow
#endif
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 077d62ce32..f4788f4851 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -217,7 +217,7 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex16*>(c), ldc);
}
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
};
#define REGISTER_CPU(T) \
@@ -225,6 +225,7 @@ class MklMatMulOp : public OpKernel {
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
+#ifdef ENABLE_MKL
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
@@ -233,7 +234,8 @@ TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
+#endif // ENABLE_MKL
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 77594479cb..97f77e45b6 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -411,7 +411,7 @@ class MklSliceOp : public OpKernel {
context->input(0).tensor<T, NDIM>(), indices, sizes);
}
};
-#endif
+#endif // INTEL_MKL
// Forward declarations of the functor specializations for declared in the
// sharded source files.
@@ -440,18 +440,14 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
#undef DECLARE_CPU_SPEC
} // namespace functor
-#ifndef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.HostMemory("begin") \
.HostMemory("size"), \
- SliceOp<CPUDevice, type>)
-
-TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
-TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-#undef REGISTER_SLICE
+ MklSliceOp<CPUDevice, type>)
#else
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
@@ -459,12 +455,12 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
.TypeConstraint<type>("T") \
.HostMemory("begin") \
.HostMemory("size"), \
- MklSliceOp<CPUDevice, type>)
+ SliceOp<CPUDevice, type>)
+#endif // INTEL_MKL && ENABLE_MKL
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
#undef REGISTER_SLICE
-#endif // INTEL_MKL
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 0f0f65c5a3..48e392c070 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
perm, out);
}
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -230,11 +230,8 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
MklConjugateTransposeCpuOp);
-TF_CALL_ALL_TYPES(REGISTER);
-#undef REGISTER
-
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -246,9 +243,10 @@ TF_CALL_ALL_TYPES(REGISTER);
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
ConjugateTransposeCpuOp);
+#endif // INTEL_MKL && ENABLE_MKL
+
TF_CALL_ALL_TYPES(REGISTER)
#undef REGISTER
-#endif // INTEL_MKL
#if GOOGLE_CUDA
Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index c081ceae57..e01058dff6 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -38,10 +38,10 @@ bool CudaSupportsHalfMatMulAndConv() {
}
bool IsMklEnabled() {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
return true;
#else
return false;
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
}
} // end namespace tensorflow
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 7ddaf7806e..d6c75d675c 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -22,6 +22,7 @@ load(
)
load(
"//third_party/mkl:build_defs.bzl",
+ "if_enable_mkl",
"if_mkl",
"if_mkl_lnx_x64",
"if_mkl_ml",
@@ -237,6 +238,7 @@ def tf_copts(android_optimization_level_override = "-O2", is_external = False):
if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+ if_enable_mkl(["-DENABLE_MKL"]) +
if_ngraph(["-DINTEL_NGRAPH=1"]) +
if_mkl_lnx_x64(["-fopenmp"]) +
if_android_arm(["-mfpu=neon"]) +
@@ -1082,6 +1084,7 @@ def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs)
]),
copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+ if_enable_mkl(["-DENABLE_MKL"]) +
if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs
)
diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD
index efff7fd51b..15a3e5cfa7 100644
--- a/third_party/mkl/BUILD
+++ b/third_party/mkl/BUILD
@@ -1,26 +1,26 @@
licenses(["notice"]) # 3-Clause BSD
config_setting(
- name = "using_mkl",
+ name = "build_with_mkl",
define_values = {
- "using_mkl": "true",
+ "build_with_mkl": "true",
},
visibility = ["//visibility:public"],
)
config_setting(
- name = "using_mkl_ml_only",
+ name = "build_with_mkl_ml_only",
define_values = {
- "using_mkl": "true",
- "using_mkl_ml_only": "true",
+ "build_with_mkl": "true",
+ "build_with_mkl_ml_only": "true",
},
visibility = ["//visibility:public"],
)
config_setting(
- name = "using_mkl_lnx_x64",
+ name = "build_with_mkl_lnx_x64",
define_values = {
- "using_mkl": "true",
+ "build_with_mkl": "true",
},
values = {
"cpu": "k8",
@@ -28,6 +28,15 @@ config_setting(
visibility = ["//visibility:public"],
)
+config_setting(
+ name = "enable_mkl",
+ define_values = {
+ "enable_mkl": "true",
+ "build_with_mkl": "true",
+ },
+ visibility = ["//visibility:public"],
+)
+
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl
index b645c0fc5c..bb798e715a 100644
--- a/third_party/mkl/build_defs.bzl
+++ b/third_party/mkl/build_defs.bzl
@@ -1,9 +1,11 @@
# -*- Python -*-
"""Skylark macros for MKL.
-if_mkl is a conditional to check if MKL is enabled or not.
-if_mkl_ml is a conditional to check if MKL-ML is enabled.
+
+if_mkl is a conditional to check if we are building with MKL.
+if_mkl_ml is a conditional to check if we are building with MKL-ML.
if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode.
if_mkl_lnx_x64 is a conditional to check for MKL
+if_enable_mkl is a conditional to check if building with MKL and MKL is enabled.
mkl_repository is a repository rule for creating MKL repository rule that can
be pointed to either a local folder, or download it from the internet.
@@ -24,7 +26,7 @@ def if_mkl(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl")): if_true,
+ str(Label("//third_party/mkl:build_with_mkl")): if_true,
"//conditions:default": if_false,
})
@@ -40,8 +42,8 @@ def if_mkl_ml(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_false,
- str(Label("//third_party/mkl:using_mkl")): if_true,
+ str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_false,
+ str(Label("//third_party/mkl:build_with_mkl")): if_true,
"//conditions:default": if_false,
})
@@ -56,12 +58,12 @@ def if_mkl_ml_only(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl_ml_only")): if_true,
+ str(Label("//third_party/mkl:build_with_mkl_ml_only")): if_true,
"//conditions:default": if_false,
})
def if_mkl_lnx_x64(if_true, if_false = []):
- """Shorthand to select() on if MKL is on and the target is Linux x86-64.
+ """Shorthand to select() if building with MKL and the target is Linux x86-64.
Args:
if_true: expression to evaluate if building with MKL is enabled and the
@@ -73,7 +75,24 @@ def if_mkl_lnx_x64(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl_lnx_x64")): if_true,
+ str(Label("//third_party/mkl:build_with_mkl_lnx_x64")): if_true,
+ "//conditions:default": if_false,
+ })
+
+def if_enable_mkl(if_true, if_false = []):
+ """Shorthand to select() if we are building with MKL and MKL is enabled.
+
+ This is only effective when built with MKL.
+
+ Args:
+ if_true: expression to evaluate if building with MKL and MKL is enabled
+ if_false: expression to evaluate if building without MKL or MKL is not enabled.
+
+ Returns:
+ A select evaluating to either if_true or if_false as appropriate.
+ """
+ return select({
+ "//third_party/mkl:enable_mkl": if_true,
"//conditions:default": if_false,
})
@@ -87,9 +106,9 @@ def mkl_deps():
inclusion in the deps attribute of rules.
"""
return select({
- str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): ["@mkl_dnn"],
- str(Label("//third_party/mkl:using_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
- str(Label("//third_party/mkl:using_mkl")): [
+ str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["@mkl_dnn"],
+ str(Label("//third_party/mkl:build_with_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
+ str(Label("//third_party/mkl:build_with_mkl")): [
"//third_party/mkl:intel_binary_blob",
"@mkl_dnn",
],
diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD
index 3e567fa9fc..58ecda55e6 100644
--- a/third_party/mkl_dnn/BUILD
+++ b/third_party/mkl_dnn/BUILD
@@ -3,10 +3,10 @@ licenses(["notice"])
exports_files(["LICENSE"])
config_setting(
- name = "using_mkl_dnn_only",
+ name = "build_with_mkl_dnn_only",
define_values = {
- "using_mkl": "true",
- "using_mkl_dnn_only": "true",
+ "build_with_mkl": "true",
+ "build_with_mkl_dnn_only": "true",
},
visibility = ["//visibility:public"],
)
diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl
index 7ce2a7d9b0..6388f31971 100644
--- a/third_party/mkl_dnn/build_defs.bzl
+++ b/third_party/mkl_dnn/build_defs.bzl
@@ -8,6 +8,6 @@ def if_mkl_open_source_only(if_true, if_false = []):
"""
return select({
- str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_true,
+ str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_true,
"//conditions:default": if_false,
})
diff --git a/tools/bazel.rc b/tools/bazel.rc
index ccf62629d1..6747c7e795 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -24,12 +24,13 @@ build --define framework_shared_object=true
# Please note that MKL on MacOS or windows is still not supported.
# If you would like to use a local MKL instead of downloading, please set the
# environment variable "TF_MKL_ROOT" every time before build.
-build:mkl --define=using_mkl=true
+build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
-build:mkl_open_source_only --define=using_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
build:download_clang --define=using_clang=true