Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/BUILD  6
-rw-r--r--  tensorflow/core/common_runtime/mkl_cpu_allocator.h  4
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc  5
-rw-r--r--  tensorflow/core/distributed_runtime/worker_cache_partial.cc  18
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc  167
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass_test.cc  288
-rw-r--r--  tensorflow/core/graph/mkl_tfconversion_pass.cc  159
-rw-r--r--  tensorflow/core/kernels/BUILD  45
-rw-r--r--  tensorflow/core/kernels/bias_op_gpu.cu.cc  23
-rw-r--r--  tensorflow/core/kernels/cuda_solvers.cc  38
-rw-r--r--  tensorflow/core/kernels/cuda_solvers.h  16
-rw-r--r--  tensorflow/core/kernels/cwise_ops.h  4
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.cc  2
-rw-r--r--  tensorflow/core/kernels/decode_raw_op.cc  1
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc  30
-rw-r--r--  tensorflow/core/kernels/fill_functor.cc  2
-rw-r--r--  tensorflow/core/kernels/mkl_aggregate_ops.cc  273
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc  2
-rw-r--r--  tensorflow/core/kernels/mkl_cwise_ops_common.cc  88
-rw-r--r--  tensorflow/core/kernels/mkl_identity_op.cc  4
-rw-r--r--  tensorflow/core/kernels/mkl_input_conversion_op.cc  259
-rw-r--r--  tensorflow/core/kernels/mkl_tfconv_op.h  136
-rw-r--r--  tensorflow/core/kernels/svd_op_gpu.cu.cc  413
-rw-r--r--  tensorflow/core/kernels/tensor_array_ops.cc  2
-rw-r--r--  tensorflow/core/ops/math_ops.cc  103
-rw-r--r--  tensorflow/core/ops/nn_ops.cc  23
-rw-r--r--  tensorflow/core/ops/ops.pbtxt  19
-rw-r--r--  tensorflow/core/ops/parsing_ops.cc  2
-rw-r--r--  tensorflow/core/ops/string_ops.cc  2
-rw-r--r--  tensorflow/core/platform/cuda_libdevice_path_test.cc  2
-rw-r--r--  tensorflow/core/public/version.h  4
-rw-r--r--  tensorflow/core/util/cuda_kernel_helper.h  76
-rw-r--r--  tensorflow/core/util/mkl_util.h  131
33 files changed, 2090 insertions, 257 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6c1896d7ab..188036b7aa 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -790,13 +790,16 @@ cc_library(
]) + if_mkl([
"//tensorflow/core/kernels:mkl_concat_op",
"//tensorflow/core/kernels:mkl_conv_op",
+ "//tensorflow/core/kernels:mkl_cwise_ops_common",
"//tensorflow/core/kernels:mkl_fused_batch_norm_op",
"//tensorflow/core/kernels:mkl_identity_op",
+ "//tensorflow/core/kernels:mkl_input_conversion_op",
"//tensorflow/core/kernels:mkl_lrn_op",
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
+ "//tensorflow/core/kernels:mkl_aggregate_ops",
]),
)
@@ -2481,10 +2484,13 @@ tf_cc_test_mkl(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
+ "//tensorflow/core/kernels:mkl_aggregate_ops",
"//tensorflow/core/kernels:mkl_concat_op",
"//tensorflow/core/kernels:mkl_conv_op",
+ "//tensorflow/core/kernels:mkl_cwise_ops_common",
"//tensorflow/core/kernels:mkl_fused_batch_norm_op",
"//tensorflow/core/kernels:mkl_identity_op",
+ "//tensorflow/core/kernels:mkl_input_conversion_op",
"//tensorflow/core/kernels:mkl_lrn_op",
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 005aabf9b8..f16da10d7a 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -75,12 +75,12 @@ class MklCPUAllocator : public Allocator {
// Hooks provided by this allocator for memory allocation routines from MKL
static inline void* MallocHook(size_t size) {
- VLOG(2) << "MklCPUAllocator: In MallocHook";
+ VLOG(3) << "MklCPUAllocator: In MallocHook";
return cpu_allocator()->AllocateRaw(kAlignment, size);
}
static inline void FreeHook(void* ptr) {
- VLOG(2) << "MklCPUAllocator: In FreeHook";
+ VLOG(3) << "MklCPUAllocator: In FreeHook";
cpu_allocator()->DeallocateRaw(ptr);
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
index 29acad34e9..06695db779 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
@@ -69,9 +69,8 @@ class GrpcWorkerCache : public WorkerCachePartial {
} else {
SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
if (!channel) return nullptr;
- WorkerInterface* ret = NewGrpcRemoteWorker(&live_rpc_counter_, channel,
- &completion_queue_, &logger_);
- return ret;
+ return NewGrpcRemoteWorker(&live_rpc_counter_, channel,
+ &completion_queue_, &logger_);
}
}
diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.cc b/tensorflow/core/distributed_runtime/worker_cache_partial.cc
index 90d5e78884..61e5416234 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_partial.cc
+++ b/tensorflow/core/distributed_runtime/worker_cache_partial.cc
@@ -29,7 +29,7 @@ namespace tensorflow {
bool WorkerCachePartial::GetDeviceLocalityNonBlocking(
const string& device_name, DeviceLocality* locality) {
mutex_lock lock(mu_); // could use reader lock
- const auto& iter = device_status_cache_.find(device_name);
+ auto iter = device_status_cache_.find(device_name);
if (iter != device_status_cache_.end()) {
*locality = iter->second.locality();
return true;
@@ -44,16 +44,8 @@ void WorkerCachePartial::GetDeviceLocalityAsync(const string& device_name,
// If cache entry was empty, make one try to fill it by RPC.
SchedClosure([this, &device_name, locality, done]() {
Status s = RefreshDeviceStatus(device_name);
- if (s.ok()) {
- if (!GetDeviceLocalityNonBlocking(device_name, locality)) {
- mutex_lock lock(mu_);
- const auto& iter = device_status_cache_.find(device_name);
- if (iter == device_status_cache_.end()) {
- s = errors::Unavailable("No known remote device: ", device_name);
- } else {
- s = errors::Internal("Failed to find locality for ", device_name);
- }
- }
+ if (s.ok() && !GetDeviceLocalityNonBlocking(device_name, locality)) {
+ s = errors::Unavailable("No known remote device: ", device_name);
}
done(s);
});
@@ -70,7 +62,9 @@ Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) {
s = errors::InvalidArgument("Bad device name to RefreshDeviceStatus: ",
device_name);
}
- auto deleter = [this, task](WorkerInterface* wi) { ReleaseWorker(task, wi); };
+ auto deleter = [this, &task](WorkerInterface* wi) {
+ ReleaseWorker(task, wi);
+ };
std::unique_ptr<WorkerInterface, decltype(deleter)> rwi(CreateWorker(task),
deleter);
if (s.ok() && !rwi.get()) {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index cf5d6e8baa..90377e54c7 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -256,6 +256,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
public:
MklLayoutRewritePass() {
// NOTE: names are alphabetically sorted.
+ csinfo_.addn = "AddN";
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
csinfo_.bias_add = "BiasAdd";
@@ -279,17 +280,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_with_bias_backprop_bias =
"_MklConv2DWithBiasBackpropBias";
- csinfo_.relu = "Relu";
- csinfo_.relu_grad = "ReluGrad";
- csinfo_.reshape = "Reshape";
- csinfo_.split = "Split";
+ csinfo_.relu = "Relu";
+ csinfo_.relu_grad = "ReluGrad";
+ csinfo_.reshape = "Reshape";
+ csinfo_.split = "Split";
+ // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
+ // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
+ // MklInputConversion op is added before it.
+ csinfo_.add = "Add";
+ csinfo_.maximum = "Maximum";
+ csinfo_.mul = "Mul";
+ csinfo_.squared_difference = "SquaredDifference";
+ csinfo_.sub = "Sub";
+ // End - element-wise ops. See note above.
// NOTE: names are alphabetically sorted.
+ rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN,
+ AddNRewrite, nullptr});
+ rinfo_.push_back({csinfo_.add,
+ mkl_op_registry::GetMklOpName(csinfo_.add),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.avg_pool,
- GetMklOpName(csinfo_.avg_pool),
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
CopyAttrsPooling, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.avg_pool_grad,
- GetMklOpName(csinfo_.avg_pool_grad),
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
CopyAttrsPooling, AlwaysRewrite, nullptr});
// BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending
// on if context contains Conv2D.
@@ -303,50 +318,62 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsBiasAddGrad, ContextMatchRewrite,
&biasaddgrad_matmul_context_});
rinfo_.push_back({csinfo_.concat,
- GetMklOpName(csinfo_.concat),
+ mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsConcat, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.concatv2,
- GetMklOpName(csinfo_.concatv2),
+ mkl_op_registry::GetMklOpName(csinfo_.concatv2),
CopyAttrsConcatV2, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d,
- GetMklOpName(csinfo_.conv2d),
+ mkl_op_registry::GetMklOpName(csinfo_.conv2d),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
- GetMklOpName(csinfo_.conv2d_grad_filter),
+ mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d_grad_input,
- GetMklOpName(csinfo_.conv2d_grad_input),
+ mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.fused_batch_norm,
- GetMklOpName(csinfo_.fused_batch_norm),
+ mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.fused_batch_norm_grad,
- GetMklOpName(csinfo_.fused_batch_norm_grad),
+ mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.identity,
- GetMklOpName(csinfo_.identity),
+ mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsIdentity, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.lrn,
- GetMklOpName(csinfo_.lrn),
+ mkl_op_registry::GetMklOpName(csinfo_.lrn),
CopyAttrsLRN, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.lrn_grad,
- GetMklOpName(csinfo_.lrn_grad),
+ mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
CopyAttrsLRN, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.max_pool,
- GetMklOpName(csinfo_.max_pool),
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr});
rinfo_.push_back({csinfo_.max_pool_grad,
- GetMklOpName(csinfo_.max_pool_grad),
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.maximum,
+ mkl_op_registry::GetMklOpName(csinfo_.maximum),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.mul,
+ mkl_op_registry::GetMklOpName(csinfo_.mul),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.relu,
- GetMklOpName(csinfo_.relu),
- CopyAttrsRelu, AlwaysRewrite, nullptr});
+ mkl_op_registry::GetMklOpName(csinfo_.relu),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.relu_grad,
- GetMklOpName(csinfo_.relu_grad),
- CopyAttrsRelu, AlwaysRewrite, nullptr});
+ mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.reshape,
- GetMklOpName(csinfo_.reshape),
+ mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsReshape, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.squared_difference,
+ mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.sub,
+ mkl_op_registry::GetMklOpName(csinfo_.sub),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
@@ -429,6 +456,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// Structure to store all constant strings
/// NOTE: names are alphabetically sorted.
typedef struct {
+ string addn;
+ string add;
string avg_pool;
string avg_pool_grad;
string bias_add;
@@ -446,15 +475,19 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string matmul;
string max_pool;
string max_pool_grad;
+ string maximum;
string mkl_conv2d;
string mkl_conv2d_grad_input;
string mkl_conv2d_grad_filter;
string mkl_conv2d_with_bias;
string mkl_conv2d_with_bias_backprop_bias;
+ string mul;
string relu;
string relu_grad;
string reshape;
string split;
+ string squared_difference;
+ string sub;
} ConstStringsInfo;
private:
@@ -502,15 +535,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return N;
}
- // Get the name of Mkl op from original TensorFlow op
- // We prefix 'Mkl' to the original op to get Mkl op.
- // TODO(nhasabni) We should move this to mkl_util.h.
- inline string GetMklOpName(const string& name) const {
- // Prefix that we add to Tensorflow op name to construct Mkl op name.
- const char* const kMklOpPrefix = "_Mkl";
- return string(kMklOpPrefix) + name;
- }
-
// Can op represented by node 'n' run on DEVICE_CPU?
// Op can run on CPU with MKL if the runtime assigned device or the
// user requested device contains device CPU, or both are empty.
@@ -604,6 +628,19 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}
+ static bool AddNRewrite(const Node* n, const ContextInfo* c) {
+ CHECK_NOTNULL(n);
+
+ int num;
+ CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true);
+
+ // Rewrite AddN only when it has exactly two inputs.
+ if (num == 2) {
+ return true;
+ }
+
+ return false;
+ }
// Is BiasAddGrad node in 'n' is associated with Conv2DWithBias node
// specified in contextinfo 'ci'. Function updates fwd_node to point
// to Conv2DWithBias node if 'n' is associated with Conv2DWithBias.
@@ -907,15 +944,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// We need operator-specific function to copy attributes because the framework
// does not provide any generic function for it.
// NOTE: names are alphabetically sorted.
+ static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
@@ -1334,7 +1372,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
for (auto ws : wsinfo_) {
if (orig_node->type_string() == ws.fwd_op &&
- mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), T)) {
+ mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
// If this op is a fwd op, then we need to check if there is an
// edge from this node's fwd_slot to bwdop's bwd_slot. If there is
// an edge, then we just add an attribute on this node for setting
@@ -1360,7 +1398,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
nb->Attr("workspace_enabled", false);
}
} else if (orig_node->type_string() == ws.bwd_op &&
- mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()),
+ mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()),
T)) {
// If this op is a bwd op, then we need to add workspace edge and
// it's Mkl tensor edge between its corresponding fwd op and this
@@ -1376,7 +1414,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
if (e->src_output() == ws.fwd_slot &&
// We would have rewritten the forward op, so we need to use
// GetMklOpName call to get its Mkl name.
- e->src()->type_string() == GetMklOpName(ws.fwd_op) &&
+ e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) &&
e->dst_input() == ws.bwd_slot) {
nb->Attr("workspace_enabled", true);
CHECK_NOTNULL(ws_tensors);
@@ -1455,6 +1493,20 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
}
+void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ int N;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("N", N);
+}
+
void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
@@ -1527,8 +1579,8 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
nb->Attr("data_format", data_format);
}
-void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node,
- NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
+ NodeBuilder* nb) {
DataType T;
// Get all attributes from old node.
@@ -1894,7 +1946,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
}
// Get all inputs.
- const int num_inputs = orig_node->in_edges().size();
+ int num_inputs = orig_node->in_edges().size();
+
+ // Drop count for control edges from inputs
+ for (const Edge* e : orig_node->in_edges()) {
+ if (e->IsControlEdge()) {
+ num_inputs--;
+ }
+ }
+
gtl::InlinedVector<Node*, 4> control_edges;
gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
FillInputs(orig_node, &control_edges, &inputs);
@@ -2008,7 +2068,34 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// BiasAddGrad is not an Mkl layer, so we make an exception for it.
if (n->type_string() != csinfo_.bias_add_grad) {
- if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
+ if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) {
+ return nullptr;
+ }
+ }
+
+ // For elementwise node, we reuse the Eigen implementation and pass the MKL
+ // metadata tensor through so we can avoid conversions. However, if all
+ // incoming edges are in TF format, we don't need all this overhead, so
+ // replace the elementwise node only if at least one of its parents is a MKL
+ // node.
+ //
+ // TODO(vrane): Add implementation for element-wise ops that doesn't reuse
+ // eigen code to reduce cross-library dependency.
+ if (mkl_op_registry::IsMklElementWiseOp(
+ mkl_op_registry::GetMklOpName(n->type_string()), T)) {
+ bool incoming_mkl_edge = false;
+ for (auto parent : n->in_edges()) {
+ if (mkl_op_registry::IsMklOp(
+ mkl_op_registry::GetMklOpName(parent->src()->type_string()), T)) {
+ incoming_mkl_edge = true;
+ break;
+ } else {
+ VLOG(1) << "Non-MKL parent is: " << parent->src()->type_string();
+ }
+ }
+ if (incoming_mkl_edge == false) {
+ VLOG(1) << "Skipping replacement of elementwise node which has no MKL "
+ "parents.";
return nullptr;
}
}
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index bd1d74368e..6a41e3965a 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -133,19 +133,19 @@ TEST_F(MklLayoutPassTest, Basic) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Mul);D(Mul)|"
+ "A(Input);B(Input);C(Zeta);D(Zeta)|"
"A->C;A->D;B->C:1;B->D:1");
}
// Test set 1: Conv2D + AddBias
-// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering)
-// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering)
+// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved ordering)
+// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous ordering)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
@@ -166,18 +166,18 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Sub'"
+ "node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
- "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;"
+ "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->E;"
"A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;"
"N->E:4;Y->Z:1");
}
-// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
-// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous)
+// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved)
+// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous)
// Test for correct output slots selected
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
@@ -199,17 +199,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Sub'"
+ "node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
- "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;"
+ "M(_MklInput2);N(_MklInput2);Y(Input);Z(Zeta)|A->E;"
"A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;"
"M:1->E:3;N:1->E:4;Y->Z:1");
}
-// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
+// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y);
// This is a case of node rewrite followed by node merge.
// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D
// with BiasAdd to produce _MklConv2DWithBias.
@@ -231,12 +231,12 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Sub'"
+ "node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
+ "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Zeta)|"
"A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
"DMT/_2->E:5;E->Z;Y->Z:1");
@@ -286,7 +286,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
"M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
}
-// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
+// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Zeta).
// Merge should not be done in such case.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
InitGraph(
@@ -308,12 +308,12 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D', 'E'] }" // Conv2D has two outputs.
// No merge should happen.
- "node { name: 'G' op: 'Add'"
+ "node { name: 'G' op: 'Zeta'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
- "G(Add);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;"
+ "G(Zeta);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;"
"E->F:1;E->G:1;M->C:2;N->C:3");
}
@@ -362,7 +362,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Sub'"
+ "node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'Int32Input'}"
@@ -387,7 +387,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
- "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+ "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
"I(_MklConv2DBackpropInput);J(_MklConv2DWithBiasBackpropBias);"
"M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G;B->D:1;"
"B->I:1;C->D:2;D->E;DMT/_0->J:1;E->G:2;E->I:2;E->J;"
@@ -413,7 +413,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Sub'"
+ "node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'Int32Input'}"
@@ -438,7 +438,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) {
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+ "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
"I(_MklConv2DBackpropInput);J(BiasAddGrad);"
"M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
"B->I:1;C->D:2;D->E;E->G;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
@@ -463,7 +463,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['B', 'A', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Sub'"
+ "node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'Int32Input'}"
@@ -488,7 +488,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) {
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+ "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
"I(_MklConv2DBackpropInput);J(BiasAddGrad);"
"M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
"B->I:1;C->D:2;D->E;E->G:2;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
@@ -512,7 +512,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Sub'"
+ "node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'Int32Input'}"
@@ -529,7 +529,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) {
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
- "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);"
+ "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);"
"H(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
"O(_MklInput)|A->D;A->E:1;A->G;B->D:1;C->D:2;D->E;DMT/_0->H:1;"
"E->G:2;E->H;E:control->DMT/_0:control;F->G:1;M->D:3;M->G:3;"
@@ -553,7 +553,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Sub'"
+ "node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'Int32Input'}"
@@ -570,7 +570,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) {
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
+ "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
"M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
"C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
"O->G:5");
@@ -593,7 +593,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['B', 'A', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Sub'"
+ "node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'Int32Input'}"
@@ -610,7 +610,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) {
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
+ "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
"M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
"C->D:2;D->E;E->G:2;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
"O->G:5");
@@ -618,8 +618,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) {
// No _MklConv2DWithBias in context, but _MklConv2D in context.
// No rewrite for BiasAddGrad should happen.
-// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
-// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
+// C=_MklConv2D(A,M,B,N); D=Zeta(C,A); E=BiasAddGrad(D) (for interleaved)
+// C=_MklConv2D(A,B,M,N); D=Zeta(C,A); E=BiasAddGrad(D) (for contiguous)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
@@ -633,7 +633,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
- "node { name: 'D' op: 'Sub'"
+ "node { name: 'D' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A']}"
"node { name: 'E' op: 'BiasAddGrad'"
@@ -641,21 +641,21 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(Sub);E(BiasAddGrad);"
+ "A(Input);B(Input);C(_MklConv2D);D(Zeta);E(BiasAddGrad);"
"M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
"M->C:2;N->C:3");
}
// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
-// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
+// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Add'"
+ "node { name: 'C' op: 'Polygamma'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B']}"
- "node { name: 'D' op: 'Sub'"
+ "node { name: 'D' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A']}"
"node { name: 'E' op: 'BiasAddGrad'"
@@ -663,13 +663,13 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
+ "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|"
"A->C;A->D:1;B->C:1;C->D;D->E");
}
// No Conv2D in the context for BiasAddGrad, but MatMul in context.
// Rewrite should happen, but name of BiasAddGrad does not change.
-// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
+// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) {
InitGraph(
"node { name: 'A' op: 'Input'}"
@@ -679,7 +679,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) {
" attr { key: 'transpose_a' value { b: false } }"
" attr { key: 'transpose_b' value { b: false } }"
" input: ['A', 'B']}"
- "node { name: 'D' op: 'Sub'"
+ "node { name: 'D' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A']}"
"node { name: 'E' op: 'BiasAddGrad'"
@@ -687,12 +687,12 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
+ "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|"
"A->C;A->D:1;B->C:1;C->D;D->E");
}
// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests
-// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
+// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
@@ -702,7 +702,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) {
" attr { key: 'transpose_a' value { b: false } }"
" attr { key: 'transpose_b' value { b: false } }"
" input: ['A', 'B']}"
- "node { name: 'D' op: 'Sub'"
+ "node { name: 'D' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A']}"
"node { name: 'E' op: 'BiasAddGrad'"
@@ -710,20 +710,20 @@ TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
+ "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|"
"A->C;A->D:1;B->C:1;C->D;D->E");
}
// No MatMul in the context for BiasAddGrad. No rewrite should happen.
-// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
+// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Negative_NoMatMul) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Add'"
+ "node { name: 'C' op: 'Polygamma'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B']}"
- "node { name: 'D' op: 'Sub'"
+ "node { name: 'D' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A']}"
"node { name: 'E' op: 'BiasAddGrad'"
@@ -731,7 +731,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Negative_NoMatMul) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
+ "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|"
"A->C;A->D:1;B->C:1;C->D;D->E");
}
@@ -752,10 +752,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B']}"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);"
+ "A(Input);B(Input);C(_MklConv2D);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
"DMT/_1->C:3");
@@ -781,11 +781,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'C']}"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;"
+ "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
"C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
@@ -803,10 +803,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B']}"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_HALF } }"
+ "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(HalfInput);B(HalfInput);C(Conv2D);D(Mul)|"
+ "A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|"
"A->C;B->C:1;B->D;C->D:1");
}
@@ -822,11 +822,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C']}"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);"
- "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
+ "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:4;DMT/_2->D:5");
@@ -844,11 +844,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['B', 'A', 'C']}"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);"
- "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
+ "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D:1;A->E;B->D;B:control->DMT/_0:control;"
"B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
@@ -869,11 +869,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['A', 'B:0', 'B:1']}"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;A:control->DMT/_0:control;"
+ "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
"B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
@@ -908,12 +908,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['G', 'E', 'F']}"
- "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
- "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;"
+ "F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
@@ -935,7 +935,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B']}"
- "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}"
"node { name: 'G' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
@@ -946,12 +946,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['G', 'E', 'F']}"
- "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
- "H(_MklConcat);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
+ "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
+ "H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;"
"G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
@@ -973,11 +973,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['B:0', 'B:1', 'A']}"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;"
+ "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;B:1->D:1;"
"B:control->DMT/_0:control;B:control->DMT/_1:control;"
"B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:4;DMT/_2->D:5");
@@ -1014,12 +1014,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['E', 'F', 'G']}"
- "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
- "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;"
+ "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
"C:control->DMT/_0:control;C:control->DMT/_1:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
@@ -1041,7 +1041,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B']}"
- "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}"
"node { name: 'G' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
@@ -1053,12 +1053,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['E', 'F', 'G']}"
- "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
- "H(_MklConcatV2);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
+ "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
+ "H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;"
"E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
@@ -1071,10 +1071,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) {
"node { name: 'B' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklRelu);C(Mul);DMT/_0(Const)|A->B;A->C;"
+ "A(Input);B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;"
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
@@ -1085,10 +1085,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) {
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
+ "A(Input);B(Input);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}
@@ -1102,10 +1102,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) {
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklRelu);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
+ "A(Input);B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
"DMT/_1->C:2");
@@ -1121,10 +1121,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklAvgPool);C(Mul);DMT/_0(Const)|A->B;A->C;"
+ "A(Input);B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;"
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
@@ -1139,10 +1139,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A', 'B'] }"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
+ "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
"DMT/_1->C:3");
@@ -1166,10 +1166,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['I', 'B'] }"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
+ "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;"
"B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;"
"I:control->DMT/_1:control");
@@ -1188,12 +1188,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) {
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
- "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
- "F(_MklFusedBatchNormGrad);G(Mul)|A->F;A->G;"
+ "F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
@@ -1214,12 +1214,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) {
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
- "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
- "F(_MklFusedBatchNorm);G(Mul)|A->F;A->G;"
+ "F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
@@ -1268,12 +1268,12 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['E', 'F', 'B'] }"
"node { name: 'H' op: 'Input'}"
- "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['H', 'G'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
- "I(Mul)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
+ "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
"B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;"
"C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;"
"E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I");
@@ -1301,11 +1301,11 @@ TEST_F(MklLayoutPassTest, LRN_Positive) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['C', 'D', 'B'] }"
- "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
+ "DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|"
"A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;"
"C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
"D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
@@ -1323,10 +1323,10 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
+ "A(Input);B(_MklLRN);C(Zeta);DMT/_0(Const)|"
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
@@ -1344,11 +1344,11 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['A', 'B', 'C'] }"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
+ "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
@@ -1386,12 +1386,12 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['C', 'B', 'D'] }"
- "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
- "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;"
+ "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;"
"A:control->DMT/_0:control;B->E:2;"
"B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
"C:control->DMT/_1:control;C:control->DMT/_2:control;"
@@ -1421,11 +1421,11 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['C', 'B', 'D'] }"
- "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
+ "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Zeta)|"
"A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
"C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
"D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
@@ -1444,10 +1444,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
+ "A(Input);B(_MklMaxPool);C(Zeta);DMT/_0(Const)|"
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
@@ -1466,11 +1466,11 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A', 'B', 'C'] }"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
+ "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
@@ -1489,10 +1489,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for batch-wise pooling (NCHW)
@@ -1507,10 +1507,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NHWC)
@@ -1525,10 +1525,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NCHW)
@@ -1543,10 +1543,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for batch-wise pooling (NHWC)
@@ -1561,10 +1561,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for batch-wise pooling (NHWC)
@@ -1579,10 +1579,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NHWC)
@@ -1597,10 +1597,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NHWC)
@@ -1615,10 +1615,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
/////////////////////////////////////////////////////////////////////
@@ -1636,10 +1636,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B']}"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Conv2D);D(Mul)|A->C;B->C:1;B->D;C->D:1");
+ "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
@@ -1657,7 +1657,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Sub'"
+ "node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'BiasAddGrad'"
@@ -1666,7 +1666,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
" input: ['E'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Sub);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
+ "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
"M->D:3;N->D:4;O->D:5");
}
@@ -1683,10 +1683,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C']}"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Mul)|"
+ "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1");
}
@@ -1696,10 +1696,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
"node { name: 'B' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Relu);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
@@ -1709,10 +1709,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(ReluGrad);D(Mul)|A->C;A->D;B->C:1;C->D:1");
+ "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
@@ -1725,10 +1725,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
@@ -1741,10 +1741,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(AvgPool);C(Mul)|A->B;A->C;B->C:1");
+ "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Concat Op test: Concat with no Mkl layer feeding it
@@ -1762,10 +1762,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['A', 'B:0', 'B:1']}"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Const);B(InputList);C(Input);D(Concat);E(Mul)|A->D;"
+ "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
"B->D:1;B:1->D:2;C->E;D->E:1");
}
@@ -1784,10 +1784,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['B:0', 'B:1', 'A']}"
- "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Const);B(InputList);C(Input);D(ConcatV2);E(Mul)|"
+ "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
}
@@ -1804,11 +1804,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
- "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);E(Input);"
- "F(FusedBatchNorm);G(Mul)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
+ "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
"E->F:4;F->G:1");
}
@@ -1832,12 +1832,12 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Sub'"
+ "node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
- "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->C;"
+ "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
@@ -1853,7 +1853,7 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
random::SimplePhilox rnd(&philox);
for (int op = 0; op < op_nodes; op++) {
s += strings::Printf(
- "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
+ "node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { "
"type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
op, rnd.Uniform(10), rnd.Uniform(10));
}
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 590b3d030f..3f8b0e86d0 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -64,6 +64,15 @@ namespace tensorflow {
// in the Mkl format. Non-compliant ops accept inputs and outputs in the
// TensorFlow format.
//
+// ADDENDUM: For element-wise ops, a conversion may or may not be needed
+// before the op executes. To handle this, we insert a new op, called
+// _MklInputConversion, before each element-wise MKL op to sanitize its
+// inputs. This pass has been enhanced to add this capability.
+//
+// The _MklInputConversion op checks the inputs of the element-wise op and
+// ensures that either both are in MKL format or both are in TF format,
+// depending on their initial state and on whether broadcasting is needed.
+
class MklToTfConversionPass : public GraphOptimizationPass {
public:
MklToTfConversionPass() {}
@@ -87,6 +96,16 @@ class MklToTfConversionPass : public GraphOptimizationPass {
return mkl_op_registry::IsMklOp(op_name, T);
}
+ // Is the input Op supported by Mkl-specific layout AND
+ // is it element-wise?
+ //
+ // @input op_name string of the op
+ // @input T Datatype to use for checking input op
+  // @return true if the op is an Mkl-supported element-wise op; false otherwise.
+ inline bool IsMklElementWiseOp(const string& op_name, DataType T) const {
+ return mkl_op_registry::IsMklElementWiseOp(op_name, T);
+ }
+
// Insert layout conversion node on the edge pointed by 'e' from graph 'g'.
//
// Edge will be deleted once a call to this function is successful.
@@ -96,6 +115,17 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// @return Success:OK() if insertion is successful, otherwise returns
// appropriate error status code.
Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);
+
+ // For element-wise ops, we need to sanitize the inputs. For this, we add a
+ // new node at the input of the replacement element-wise node that checks
+ // the inputs and converts one/both of them as required. See the op code
+ // comments for details.
+ //
+ // Insert input conversion node as parent of 'n' from graph 'g'.
+ //
+ // @return Success:OK() if insertion is successful, otherwise returns
+ // appropriate error status code.
+ Status InsertInputConversionNode(std::unique_ptr<Graph>* g, Node*);
};
// We register MklToTf insertion for phase 2 in post-partition grouping
@@ -171,6 +201,92 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
return Status::OK();
}
+Status MklToTfConversionPass::InsertInputConversionNode(
+ std::unique_ptr<Graph>* g, Node* n) {
+ CHECK_NOTNULL(n);
+
+ // Get the input nodes and edges
+ std::vector<const Edge*> edges;
+ TF_CHECK_OK(n->input_edges(&edges));
+ if (edges.size() != 4) {
+ return Status(error::Code::INVALID_ARGUMENT,
+ "MKL Binary Element-wise op should have exactly 2 data"
+ " inputs and 2 metadata inputs");
+ }
+
+ // Sanity check: ensure that both inputs are of the expected type, and the
+ // same type as input type
+ CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
+ BaseType(edges[1]->src()->output_type(edges[1]->src_output())));
+ CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
+ BaseType(n->input_type(0)));
+
+ // Check ordering of edges
+ for (uint i = 0; i < 4; i++) {
+ CHECK_EQ((edges[i]->dst_input() == i), true);
+ }
+
+ // Build the conversion node and specify src as input.
+ Node* conversion_node = nullptr;
+
+ TF_CHECK_OK(
+ NodeBuilder((*g)->NewName("MklInputConversion"), "_MklInputConversion")
+ .Input(edges[0]->src(), edges[0]->src_output())
+ .Input(edges[1]->src(), edges[1]->src_output())
+ .Input(edges[2]->src(), edges[2]->src_output())
+ .Input(edges[3]->src(), edges[3]->src_output())
+ .Device(n->def().device())
+ .Attr("T", n->input_type(0))
+ .Finalize(&**g, &conversion_node));
+
+ CHECK_NOTNULL(conversion_node);
+
+ // Change the destination of any control edges to the InputConversion node
+ if (edges.size() != n->in_edges().size()) {
+ std::vector<const Edge*> edges_to_remove;
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node));
+ edges_to_remove.push_back(e);
+ }
+ }
+ for (const Edge* e : edges_to_remove) {
+ (*g)->RemoveEdge(e);
+ }
+ }
+
+ string data_format;
+ if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
+ Status::OK()) {
+ conversion_node->AddAttr("data_format", data_format);
+ }
+
+ // Get assigned device from destination node and apply it to conversion node.
+ // We want conversion node to be on the same device as the destination node.
+ conversion_node->set_assigned_device_name(n->assigned_device_name());
+
+ // Set the Mkl op label for this op.
+ conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
+
+  // Now that we have added edges from src->conversion_node, add edges from the
+  // outputs of conversion_node to the element-wise node.
+ CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input()));
+ CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input()));
+ CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input()));
+ CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input()));
+
+ VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input "
+ << "conversion node on: " << n->type_string() << " successful.";
+
+ // Remove src->dst edge now.
+ (*g)->RemoveEdge(edges[0]);
+ (*g)->RemoveEdge(edges[1]);
+ (*g)->RemoveEdge(edges[2]);
+ (*g)->RemoveEdge(edges[3]);
+
+ return Status::OK();
+}
+
bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
bool result = false;
@@ -239,6 +355,49 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
DumpGraph("After MklToTfConversionPass", &**g);
+ //---------------------------------------------------------------------------
+ // Check all nodes and add an input-conversion-node if the node is an mkl
+ // element-wise node.
+ VLOG(1) << "Before running MklToTfConversionPass - InputConversion";
+
+ std::vector<Node*> candidate_nodes;
+ std::vector<Node*> order;
+ GetReversePostOrder(**g, &order); // This will give us topological sort.
+
+ for (Node* n : order) {
+ // If node is not an op or it does not have a datatype, then skip.
+ DataType datatype;
+ if (!n->IsOp() || (GetNodeAttr(n->def(), "T", &datatype) != Status::OK())) {
+ continue;
+ }
+ if (IsMklElementWiseOp(n->type_string(), datatype)) {
+ // If the input node is an input-conversion op, skip
+ Node* input_node = nullptr;
+ TF_CHECK_OK(n->input_node(0, &input_node));
+ DataType input_datatype;
+ if ((GetNodeAttr(n->def(), "T", &input_datatype) == Status::OK()) &&
+ (input_node->type_string().compare("_MklInputConversion") == 0)) {
+ continue;
+ }
+
+ VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node "
+ << n->name() << " for inserting input conversion node";
+ candidate_nodes.push_back(const_cast<Node*>(n));
+ }
+ }
+
+ // Process all candidate edges and insert conversion nodes on them.
+ for (Node* n : candidate_nodes) {
+    // Even if we insert a conversion node for just one node, we still
+    // need to return true.
+ if (InsertInputConversionNode(g, n) == Status::OK()) {
+ VLOG(1) << "MklToTfConversionPass: Inserted conversion "
+ << "on node " << n->name();
+ result = true;
+ }
+ }
+ DumpGraph("After MklToTfConversionPass - InputConversion", &**g);
+
// We need to return true even if we insert one conversion node
// anywhere in the graph.
return result;
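
For context, InsertInputConversionNode reroutes the element-wise node's four inputs (two data, two metadata) through a single _MklInputConversion node. Below is a minimal standalone sketch of that rewiring on plain structs; it is not TensorFlow Graph code, and the node names are hypothetical.

// insert_conversion_sketch.cc -- toy model of the edge rewiring above.
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::pair<Node*, int>> inputs;  // (source node, source port)
};

void InsertInputConversion(Node* elementwise, Node* conversion) {
  // The conversion node consumes the element-wise node's original inputs...
  conversion->inputs = elementwise->inputs;
  // ...and its four output ports become the element-wise node's new inputs.
  elementwise->inputs.clear();
  for (int port = 0; port < 4; ++port)
    elementwise->inputs.push_back({conversion, port});
}

int main() {
  Node a{"A"}, b{"B"}, meta_a{"A:meta"}, meta_b{"B:meta"};
  Node add{"_MklAdd", {{&a, 0}, {&b, 0}, {&meta_a, 0}, {&meta_b, 0}}};
  Node conv{"MklInputConversion"};
  InsertInputConversion(&add, &conv);
  for (auto& in : add.inputs)
    printf("%s <- %s:%d\n", add.name.c_str(), in.first->name.c_str(), in.second);
  return 0;
}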
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index b6d7e3b4b2..cff6e30c04 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2340,7 +2340,10 @@ tf_kernel_library(
tf_kernel_library(
name = "svd_op",
prefix = "svd_op",
- deps = LINALG_DEPS,
+ deps = LINALG_DEPS + if_cuda([
+ ":cuda_solvers",
+ ":transpose_functor",
+ ]),
)
cc_library(
@@ -2938,7 +2941,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
- ],
+ ] + if_cuda(["@cub_archive//:cub"]),
)
tf_kernel_library(
@@ -5502,6 +5505,22 @@ tf_mkl_kernel_library(
)
tf_mkl_kernel_library(
+ name = "mkl_input_conversion_op",
+ hdrs = ["mkl_tfconv_op.h"],
+ prefix = "mkl_input_conversion",
+ deps = [
+ ":bounds_check",
+ ":ops_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:nn_ops_op_lib",
+ "//third_party/mkl:intel_binary_blob",
+ ],
+)
+
+tf_mkl_kernel_library(
name = "mkl_pooling_ops",
srcs = [
"mkl_avgpooling_op.cc",
@@ -5544,6 +5563,14 @@ tf_mkl_kernel_library(
)
tf_mkl_kernel_library(
+ name = "mkl_aggregate_ops",
+ prefix = "mkl_aggregate_ops",
+ deps = MATH_DEPS + [
+ "//third_party/mkl:intel_binary_blob",
+ ],
+)
+
+tf_mkl_kernel_library(
name = "mkl_concat_op",
prefix = "mkl_concat_op",
deps = ARRAY_DEPS + [
@@ -5575,6 +5602,20 @@ tf_mkl_kernel_library(
],
)
+tf_mkl_kernel_library(
+ name = "mkl_cwise_ops_common",
+ hdrs = [
+ "cwise_ops.h",
+ "cwise_ops_common.h",
+ "cwise_ops_gradients.h",
+ ],
+ prefix = "mkl_cwise_ops_common",
+ deps = NN_DEPS + [
+ "cwise_op",
+ "//third_party/mkl:intel_binary_blob",
+ ],
+)
+
cc_library(
name = "dataset",
srcs = ["dataset.cc"],
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index ddc2d457b0..42f3db1d79 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -173,15 +173,20 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
// Accumulate the results in the shared memory into the first element.
// No syncthreads is needed since this is only in the same warp.
int32 thread_index = threadIdx.x;
- if (thread_index < 16) s_data[thread_index] += s_data[thread_index + 16];
- if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
- if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
- if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
- if (thread_index < 1) s_data[thread_index] += s_data[thread_index + 1];
-
- // The first thread writes out the accumulated result to the global location.
- if (thread_index == 0) {
- CudaAtomicAdd(bias_backprop + bias_index, T(s_data[0]));
+ if (thread_index < 16) {
+ s_data[thread_index] += s_data[thread_index + 16];
+ __syncwarp(0xFFFF);
+ if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
+ __syncwarp(0xFF);
+ if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
+ __syncwarp(0xF);
+ if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
+ __syncwarp(0x3);
+ if (thread_index == 0) {
+ T val = T(s_data[0] + s_data[1]);
+ // The first thread writes out the accumulated result to global location.
+ CudaAtomicAdd(bias_backprop + bias_index, val);
+ }
}
}
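
The change above replaces a tree reduction that relied on implicit warp-synchronous execution with one that issues an explicit __syncwarp between steps. A minimal standalone CUDA sketch of the same pattern (not TensorFlow code; a full-warp mask is used here for simplicity):

// warp_tree_reduce_sketch.cu
#include <cstdio>
#include <cuda_runtime.h>

__global__ void WarpTreeReduce(const float* in, float* out) {
  __shared__ float s_data[32];
  int tid = threadIdx.x;          // one warp: 32 threads
  s_data[tid] = in[tid];
  __syncwarp();                   // make the loads visible within the warp
  // Tree reduction; each step is separated by an explicit warp sync instead
  // of relying on implicit lockstep execution.
  for (int offset = 16; offset > 0; offset /= 2) {
    if (tid < offset) s_data[tid] += s_data[tid + offset];
    __syncwarp();
  }
  if (tid == 0) out[0] = s_data[0];
}

int main() {
  float h_in[32], h_out = 0.f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpTreeReduce<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);    // expect 32
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}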
diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc
index 5c6b5eec82..43197d8cf4 100644
--- a/tensorflow/core/kernels/cuda_solvers.cc
+++ b/tensorflow/core/kernels/cuda_solvers.cc
@@ -174,7 +174,7 @@ Status CudaSolver::CopyLapackInfoToHostAsync(
}
info_checker_callback(status, host_lapack_infos);
};
-
+
auto cb =
std::bind(wrapped_info_checker_callback, context_,
std::move(info_checker_callback), std::move(host_lapack_infos));
@@ -188,6 +188,7 @@ Status CudaSolver::CopyLapackInfoToHostAsync(
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
+#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)
// Macros to construct cusolverDn method names.
#define DN_SOLVER_FN(method, lapack_prefix) cusolverDn##lapack_prefix##method
@@ -327,6 +328,41 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);
+template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
+static inline Status GesvdImpl(BufSizeFnT bufsize, SolverFnT solver,
+ OpKernelContext* context,
+ cusolverDnHandle_t cusolver_dn_handle,
+ signed char jobu, signed char jobvt, int m,
+ int n, Scalar* A, int lda, Scalar* S, Scalar* U,
+ int ldu, Scalar* VT, int ldvt,
+ int* dev_lapack_info) {
+ /* Get amount of workspace memory required. */
+ int lwork;
+ TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
+ /* Allocate device memory for workspace. */
+ ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
+ /* Launch the solver kernel. */
+ TF_RETURN_IF_CUSOLVER_ERROR(solver(
+ cusolver_dn_handle, jobu, jobvt, m, n, CUDAComplex(A), lda, S,
+ CUDAComplex(U), ldu, CUDAComplex(VT), ldvt,
+ CUDAComplex(dev_workspace.mutable_data()), lwork, NULL, dev_lapack_info));
+ return Status::OK();
+}
+
+#define GESVD_INSTANCE(Scalar, lapack_prefix) \
+ template <> \
+ Status CudaSolver::Gesvd<Scalar>( \
+ signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, \
+ int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, \
+ int ldvt, int* dev_lapack_info) const { \
+ return GesvdImpl(DN_BUFSIZE_FN(gesvd, lapack_prefix), \
+ DN_SOLVER_FN(gesvd, lapack_prefix), context_, \
+ cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
+ dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info); \
+ }
+
+TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);
+
//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
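
GesvdImpl follows the usual cuSOLVER pattern: query the workspace size, allocate scratch, then launch the solver. A hedged standalone sketch of that call sequence using the raw float API (not TensorFlow code; error checking is omitted, and rwork is passed as nullptr to mirror GesvdImpl above):

// gesvd_sketch.cu
#include <cstdio>
#include <cuda_runtime.h>
#include <cusolverDn.h>

int main() {
  const int m = 3, n = 3, lda = m;
  const float h_A[9] = {1, 0, 0,  0, 2, 0,  0, 0, 3};  // column-major diag(1,2,3)

  cusolverDnHandle_t handle;
  cusolverDnCreate(&handle);

  float *d_A, *d_S, *d_U, *d_VT, *d_work;
  int* d_info;
  cudaMalloc(&d_A, sizeof(h_A));
  cudaMalloc(&d_S, n * sizeof(float));
  cudaMalloc(&d_U, m * m * sizeof(float));
  cudaMalloc(&d_VT, n * n * sizeof(float));
  cudaMalloc(&d_info, sizeof(int));
  cudaMemcpy(d_A, h_A, sizeof(h_A), cudaMemcpyHostToDevice);

  // 1. Workspace query (the bufsize call in GesvdImpl).
  int lwork = 0;
  cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
  cudaMalloc(&d_work, lwork * sizeof(float));

  // 2. Launch the solver; 'A' asks for all columns of U and rows of VT.
  cusolverDnSgesvd(handle, 'A', 'A', m, n, d_A, lda, d_S, d_U, m, d_VT, n,
                   d_work, lwork, /*rwork=*/nullptr, d_info);

  float h_S[3];
  int h_info = 0;
  cudaMemcpy(h_S, d_S, sizeof(h_S), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
  printf("info=%d  singular values: %g %g %g\n", h_info, h_S[0], h_S[1], h_S[2]);

  cudaFree(d_A); cudaFree(d_S); cudaFree(d_U); cudaFree(d_VT);
  cudaFree(d_work); cudaFree(d_info);
  cusolverDnDestroy(handle);
  return 0;
}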
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 0fd6450f98..7cbdc895dd 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -258,13 +258,23 @@ class CudaSolver {
Status Syevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar*
dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const;
+*/
// Singular value decomposition.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
template <typename Scalar>
Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
- int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
- int ldvt, int* dev_lapack_info);
- */
+ int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
+ int ldvt, int* dev_lapack_info) const;
+ /*
+ // Batched linear solver using LU factorization from getrfBatched.
+ // See:
+ http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
+ template <typename Scalar>
+ Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
+ const Scalar* dev_Aarray[], int lda, const int* devIpiv,
+ Scalar* dev_Barray[], int ldb, int* info, int batch_size)
+ const;
+ */
private:
OpKernelContext* context_; // not owned.
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index d935331904..ada39eae38 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -139,7 +139,7 @@ struct scalar_left : private Binary {
typedef Tout result_type;
const Tin* left;
- EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default;
+ inline scalar_left(const scalar_left& other) = default;
template <typename... Args>
EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
@@ -169,7 +169,7 @@ struct scalar_right : private Binary {
typedef Tout result_type;
const Tin* right;
- EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default;
+ inline scalar_right(const scalar_right& other) = default;
template <typename... Args>
EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 192a4f732e..693c6467ac 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -20,7 +20,9 @@ namespace tensorflow {
BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
DataType in)
: OpKernel(ctx) {
+#ifndef INTEL_MKL
OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
+#endif
}
void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc
index 9492a4e26d..1c0085cfea 100644
--- a/tensorflow/core/kernels/decode_raw_op.cc
+++ b/tensorflow/core/kernels/decode_raw_op.cc
@@ -105,6 +105,7 @@ REGISTER(Eigen::half);
REGISTER(float);
REGISTER(double);
REGISTER(int32);
+REGISTER(uint16);
REGISTER(uint8);
REGISTER(int16);
REGISTER(int8);
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index fcfcd188d2..ecfe51d599 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "external/cub_archive/cub/util_ptx.cuh"
#if !defined(_MSC_VER)
#define UNROLL _Pragma("unroll")
@@ -1015,6 +1016,21 @@ __global__ void __launch_bounds__(640, 2)
}
}
+// Device function to compute sub-warp sum reduction for a power-of-two group of
+// neighboring threads.
+template<int kWidth, typename T>
+__device__ __forceinline__ T WarpSumReduce(T val) {
+ // support only power-of-two widths.
+ assert(__popc(kWidth) == 1);
+ int sub_warp = cub::LaneId() / kWidth;
+ int zeros = sub_warp * kWidth;
+ unsigned mask = ((1UL << kWidth) - 1) << zeros;
+ for (int delta = kWidth / 2; delta > 0; delta /= 2) {
+ val += CudaShuffleXor(mask, val, delta);
+ }
+ return val;
+}
+
// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
// NHWC format, tailored for small images up to 32x32. Stride and depth
// multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
@@ -1127,6 +1143,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
+ unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range);
if (depth_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1140,7 +1157,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
- val += CudaShuffleDown(val, delta);
+ val += CudaShuffleDown(active_threads, val, delta);
}
if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
*accum_ptr = val;
@@ -1164,9 +1181,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
if (filter_depth < in_depth) {
T val = accum_data[i];
// Warp-accumulate the pixels of the same depth from the accumulator.
- for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
- val += CudaShuffleDown(val, delta);
- }
+ val = WarpSumReduce<kAccumPixels>(val);
if (!(thread_idx & kAccumPixels - 1)) {
CudaAtomicAdd(filter_offset + filter, val);
}
@@ -1382,6 +1397,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
+ unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range);
if (slice_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1395,7 +1411,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
- val += CudaShuffleDown(val, delta);
+ val += CudaShuffleDown(active_threads, val, delta);
}
if (!(thread_idx & 32 / kBlockSlices - 1)) {
*accum_ptr = val;
@@ -1419,9 +1435,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
if (filter_depth < in_depth) {
T val = accum_data[i];
// Warp-accumulate pixels of the same depth from the accumulator.
- for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
- val += CudaShuffleDown(val, delta);
- }
+ val = WarpSumReduce<kAccumPixels>(val);
if (!(thread_idx & kAccumPixels - 1)) {
CudaAtomicAdd(filter_offset + filter, val);
}
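
WarpSumReduce above reduces within power-of-two groups of neighboring lanes using XOR shuffles restricted to each group's mask. A minimal standalone CUDA sketch of that idea (not TensorFlow code; it uses __shfl_xor_sync directly rather than the CudaShuffleXor wrapper):

// subwarp_reduce_sketch.cu
#include <cstdio>
#include <cuda_runtime.h>

// Reduce within groups of kWidth neighboring lanes (kWidth a power of two).
template <int kWidth>
__device__ float SubWarpSum(float val) {
  const int lane = threadIdx.x & 31;
  const int group = lane / kWidth;
  const unsigned mask = ((1u << kWidth) - 1u) << (group * kWidth);
  for (int delta = kWidth / 2; delta > 0; delta /= 2) {
    val += __shfl_xor_sync(mask, val, delta, kWidth);
  }
  return val;  // every lane in the group now holds the group sum
}

__global__ void Kernel(float* out) {
  float v = 1.0f;                        // each lane contributes 1
  float sum = SubWarpSum<8>(v);          // groups of 8 lanes
  if ((threadIdx.x & 7) == 0) out[threadIdx.x / 8] = sum;
}

int main() {
  float* d_out;
  float h_out[4];
  cudaMalloc(&d_out, sizeof(h_out));
  Kernel<<<1, 32>>>(d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (float s : h_out) printf("%g ", s);  // expect 8 8 8 8
  printf("\n");
  cudaFree(d_out);
  return 0;
}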
diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc
index 8a0a558eef..ea0cc139f3 100644
--- a/tensorflow/core/kernels/fill_functor.cc
+++ b/tensorflow/core/kernels/fill_functor.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
namespace tensorflow {
namespace functor {
@@ -50,6 +51,7 @@ DEFINE_SETZERO_CPU(int32);
DEFINE_SETZERO_CPU(int64);
DEFINE_SETZERO_CPU(complex64);
DEFINE_SETZERO_CPU(complex128);
+DEFINE_SETZERO_CPU(Variant);
#undef DEFINE_SETZERO_CPU
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
new file mode 100644
index 0000000000..51ba127def
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -0,0 +1,273 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/math_ops.cc.
+
+#ifdef INTEL_MKL
+#define EIGEN_USE_THREADS
+
+#include <numeric>
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename T>
+class MklAddNOp : public OpKernel {
+ public:
+ explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const int num = ctx->num_inputs();
+ OP_REQUIRES(ctx, num / 2 == 2,
+ errors::InvalidArgument("Only additions of two arguments "
+ "supported by MKL. Num inputs: ",
+ num));
+
+ MklAddNOpContext mkl_context;
+ const Tensor& input0 = MklGetInput(ctx, 0);
+ GetMklShape(ctx, 0, &(mkl_context.input1_shape));
+ bool input1_in_mkl_format = mkl_context.input1_shape.IsMklTensor();
+
+ const Tensor& input1 = MklGetInput(ctx, 1);
+ GetMklShape(ctx, 1, &(mkl_context.input2_shape));
+ bool input2_in_mkl_format = mkl_context.input2_shape.IsMklTensor();
+
+ mkl_context.in_dims = input1_in_mkl_format
+ ? mkl_context.input1_shape.GetDimension()
+ : input0.dims();
+ mkl_context.in_dims = input2_in_mkl_format
+ ? mkl_context.input2_shape.GetDimension()
+ : input1.dims();
+ // Generate size, stride for input if input is in MKL format.
+ ExtractMklOpParams(&mkl_context.in1_sizes,
+ &mkl_context.in1_strides, input0, &mkl_context.input1_shape);
+ ExtractMklOpParams(&mkl_context.in2_sizes,
+ &mkl_context.in2_strides, input1, &mkl_context.input2_shape);
+
+ std::vector<float> coeff(2, 1.0);
+ mkl_context.MklCreateInputLayouts(ctx);
+ CHECK_EQ(dnnSumCreate_F32(&mkl_context.Eltwise, mkl_context.attributes, 2,
+ mkl_context.lt_input1, &coeff[0]),
+ E_SUCCESS);
+
+ Tensor mkl_tmp_input1_buf_tensor, mkl_tmp_input2_buf_tensor;
+ mkl_context.MklPrepareAddNInputs(ctx, &mkl_tmp_input1_buf_tensor,
+ &mkl_tmp_input2_buf_tensor);
+ Tensor* output = nullptr;
+ if (input1_in_mkl_format || input2_in_mkl_format) {
+ TensorShape tf_shape;
+ mkl_context.output_shape.SetMklTensor(true);
+ mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise, dnnResourceDst);
+
+ mkl_context.output_shape.SetTfLayout(
+ mkl_context.in_dims, mkl_context.in1_sizes, mkl_context.in1_strides);
+ if (input1_in_mkl_format == true) {
+ mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims,
+ mkl_context.input1_shape.GetTfToMklDimMap());
+ } else {
+ mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims,
+ mkl_context.input2_shape.GetTfToMklDimMap());
+ }
+ tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+ mkl_context.output_shape.GetMklLayout())) /
+ sizeof(T));
+
+ AllocateOutputSetMklShape(ctx, 0, &output, tf_shape,
+ mkl_context.output_shape);
+ } else {
+ const TensorShape& o_shape = input1.shape();
+ mkl_context.output_shape.SetMklTensor(false);
+ AllocateOutputSetMklShape(ctx, 0, &output, o_shape,
+ mkl_context.output_shape);
+ }
+
+ mkl_context.Eltwise_res[dnnResourceDst] =
+ static_cast<void*>(output->flat<T>().data());
+
+    // Execute the sum primitive.
+ CHECK_EQ(dnnExecute_F32(mkl_context.Eltwise, mkl_context.Eltwise_res),
+ E_SUCCESS);
+
+ mkl_context.MklCleanup();
+ }
+
+ void ExtractMklOpParams(size_t** out_sizes, size_t** out_strides,
+ const Tensor& input, const MklShape* input_shape) {
+ bool input_in_mkl_format = input_shape->IsMklTensor();
+ int in_dims = input_in_mkl_format
+ ? input_shape->GetDimension()
+ : input.dims();
+ size_t* in_sizes = new size_t[in_dims];
+ size_t* in_strides = new size_t[in_dims];
+
+ if (input_in_mkl_format) {
+ for (int i = 0; i < in_dims; i++) {
+ in_sizes[i] = input_shape->GetSizes()[i];
+ in_strides[i] = input_shape->GetStrides()[i];
+ }
+ } else {
+ for (int i = 0; i < in_dims; i++) {
+ in_sizes[i] =
+ input.dim_size((in_dims - 1) - i);
+ }
+ in_strides[0] = 1;
+ for (int i = 1; i < in_dims; i++) {
+ in_strides[i] =
+ in_strides[i - 1] * in_sizes[i - 1];
+ }
+ }
+ *out_sizes = in_sizes;
+ *out_strides = in_strides;
+ }
+
+
+ private:
+ typedef struct {
+ int in_dims;
+ size_t* in1_sizes;
+ size_t* in1_strides;
+
+ size_t* in2_sizes;
+ size_t* in2_strides;
+ dnnPrimitive_t Eltwise = nullptr;
+ dnnPrimitiveAttributes_t attributes = nullptr;
+ void* Eltwise_res[dnnResourceNumber];
+ dnnLayout_t lt_input1 = nullptr, lt_input2 = nullptr;
+ MklShape input1_shape, input2_shape, output_shape;
+
+ void MklCreateInputLayouts(OpKernelContext* context) {
+ bool input1_in_mkl_format = input1_shape.IsMklTensor();
+ if (!input1_in_mkl_format) {
+ CHECK_EQ(
+ dnnLayoutCreate_F32(&lt_input1, in_dims, in1_sizes, in1_strides),
+ E_SUCCESS);
+ } else {
+ lt_input1 = static_cast<dnnLayout_t>(input1_shape.GetCurLayout());
+ }
+
+ bool input2_in_mkl_format = input2_shape.IsMklTensor();
+ if (!input2_in_mkl_format) {
+ CHECK_EQ(
+ dnnLayoutCreate_F32(&lt_input2, in_dims, in2_sizes, in2_strides),
+ E_SUCCESS);
+ } else {
+ lt_input2 = static_cast<dnnLayout_t>(input2_shape.GetCurLayout());
+ }
+ }
+
+ void MklPrepareAddNInputs(OpKernelContext* context,
+ Tensor* mkl_tmp_input1_buf_tensor,
+ Tensor* mkl_tmp_input2_buf_tensor) {
+ bool mkl_convert_input1, mkl_convert_input2;
+ dnnPrimitive_t mkl_prim_convert_input1 = nullptr,
+ mkl_prim_convert_input2 = nullptr;
+ dnnLayout_t mkl_lt_internal_input1 = nullptr,
+ mkl_lt_internal_input2 = nullptr;
+ void *mkl_buf_convert_input1 = nullptr, *mkl_buf_convert_input2 = nullptr;
+ dnnResourceType_t dnnResourceMultipleSrc2 =
+ (dnnResourceType_t)(dnnResourceMultipleSrc + 1);
+ // Compare with internal layouts and convert if needed
+ const Tensor& input1 = MklGetInput(context, 0);
+
+ void* mkl_buf_input1 =
+ const_cast<void*>(static_cast<const void*>(input1.flat<T>().data()));
+
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
+ &mkl_lt_internal_input1, Eltwise, dnnResourceMultipleSrc),
+ E_SUCCESS);
+ mkl_convert_input1 =
+ !dnnLayoutCompare_F32(mkl_lt_internal_input1, lt_input1);
+ if (mkl_convert_input1) {
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input1, lt_input1,
+ mkl_lt_internal_input1),
+ E_SUCCESS);
+ AllocTmpBuffer(context, mkl_tmp_input1_buf_tensor,
+ mkl_lt_internal_input1, &mkl_buf_convert_input1);
+ CHECK_EQ(
+ dnnConversionExecute_F32(mkl_prim_convert_input1, mkl_buf_input1,
+ mkl_buf_convert_input1),
+ E_SUCCESS);
+ dnnDelete_F32(mkl_prim_convert_input1);
+ }
+ dnnLayoutDelete_F32(mkl_lt_internal_input1);
+
+ Eltwise_res[dnnResourceMultipleSrc] =
+ (mkl_convert_input1) ? mkl_buf_convert_input1 : mkl_buf_input1;
+
+ const Tensor& input2 = MklGetInput(context, 1);
+ void* mkl_buf_input2 =
+ const_cast<void*>(static_cast<const void*>(input2.flat<T>().data()));
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
+ &mkl_lt_internal_input2, Eltwise, dnnResourceMultipleSrc2),
+ E_SUCCESS);
+ mkl_convert_input2 =
+ !dnnLayoutCompare_F32(mkl_lt_internal_input2, lt_input2);
+ if (mkl_convert_input2) {
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input2, lt_input2,
+ mkl_lt_internal_input2),
+ E_SUCCESS);
+ AllocTmpBuffer(context, mkl_tmp_input2_buf_tensor,
+ mkl_lt_internal_input2, &mkl_buf_convert_input2);
+ CHECK_EQ(
+ dnnConversionExecute_F32(mkl_prim_convert_input2, mkl_buf_input2,
+ mkl_buf_convert_input2),
+ E_SUCCESS);
+ dnnDelete_F32(mkl_prim_convert_input2);
+ }
+ dnnLayoutDelete_F32(mkl_lt_internal_input2);
+
+ Eltwise_res[dnnResourceMultipleSrc2] =
+ (mkl_convert_input2) ? mkl_buf_convert_input2 : mkl_buf_input2;
+ }
+
+ void MklCleanup() {
+ bool input1_in_mkl_format = input1_shape.IsMklTensor();
+ bool input2_in_mkl_format = input2_shape.IsMklTensor();
+ dnnDelete_F32(Eltwise);
+ if (!input1_in_mkl_format) {
+ dnnLayoutDelete_F32(lt_input1);
+ delete [] in1_sizes;
+ delete [] in1_strides;
+ }
+ if (!input2_in_mkl_format) {
+ dnnLayoutDelete_F32(lt_input2);
+ delete [] in2_sizes;
+ delete [] in2_strides;
+ }
+ }
+ } MklAddNOpContext;
+};
+
+#define REGISTER_MKL_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklAddN") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklAddNOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_MKL_CPU);
+#undef REGISTER_MKL_CPU
+} // namespace tensorflow
+#endif // INTEL_MKL
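
MklAddNOp drives the MKL sum primitive through the dnnSumCreate_F32 / dnnExecute_F32 sequence shown above. A hedged standalone sketch of that sequence on two plain 1-D buffers (not TensorFlow code; it assumes the contiguous layout already matches the primitive's expected layout, so the conversion step in MklPrepareAddNInputs is skipped, and return codes are ignored for brevity):

// mkl_sum_sketch.cc
#include <cstdio>
#include "mkl_dnn.h"

int main() {
  const size_t dims = 1;
  const size_t sizes[1] = {4};
  const size_t strides[1] = {1};
  float in1[4] = {1, 2, 3, 4};
  float in2[4] = {10, 20, 30, 40};
  float out[4] = {0, 0, 0, 0};
  float coeffs[2] = {1.0f, 1.0f};  // plain addition, as in MklAddNOp

  // Describe the contiguous buffers and build the sum primitive.
  dnnLayout_t layout = nullptr;
  dnnLayoutCreate_F32(&layout, dims, sizes, strides);
  dnnPrimitive_t sum = nullptr;
  dnnSumCreate_F32(&sum, nullptr, 2, layout, coeffs);

  // Bind the two sources and the destination, then execute.
  void* res[dnnResourceNumber] = {};
  res[dnnResourceMultipleSrc] = in1;
  res[dnnResourceMultipleSrc + 1] = in2;
  res[dnnResourceDst] = out;
  dnnExecute_F32(sum, res);

  for (float v : out) printf("%g ", v);  // expect 11 22 33 44
  printf("\n");

  dnnDelete_F32(sum);
  dnnLayoutDelete_F32(layout);
  return 0;
}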
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 5dfce5d5c6..7f1555d325 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -406,8 +406,10 @@ class MklConv2DOp : public OpKernel {
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_filter, lt_filter,
mkl_lt_internal_filter),
E_SUCCESS);
+
mkl_buf_convert_filter = const_cast<void*>(
static_cast<const void*>(output_filter->flat<T>().data()));
+
CHECK_EQ(
dnnConversionExecute_F32(mkl_prim_convert_filter, mkl_buf_filter,
mkl_buf_convert_filter),
diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
new file mode 100644
index 0000000000..7fc633c254
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
@@ -0,0 +1,88 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+#include <iostream>
+#include <vector>
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename Functor>
+class MklBinaryOp : public BinaryOp<Device, Functor> {
+ public:
+ explicit MklBinaryOp(OpKernelConstruction* context)
+ : BinaryOp<Device, Functor>(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ auto in0 = context->input(0);
+ auto in1 = context->input(1);
+ VLOG(1) << "Shapes (start mklbinaryop compute): "
+ << in0.shape().DebugString() << " _and_ "
+ << in1.shape().DebugString();
+
+ // Call the TensorFlow BinaryOp Compute method
+ BinaryOp<Device, Functor>::Compute(context);
+
+ auto out = context->mutable_output(0);
+ VLOG(1) << "Shapes (output): " << out->shape().DebugString();
+
+    // Pass input shape through to output shape
+ ForwardMklMetaDataInToOut(context, 0, 0);
+
+ out = context->mutable_output(0);
+ VLOG(1) << "Shapes (output): " << out->shape().DebugString();
+ }
+};
+
+//---------- Registration macros for various element-wise ops -----------
+// We will need to redefine "REGISTER" to include the mkl_op_registry flag
+#pragma push_macro("REGISTER")
+#undef REGISTER
+#define REGISTER(OP, D, N, F, T) \
+ REGISTER_KERNEL_BUILDER(Name(N) \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ OP<D##Device, F<T>>);
+
+REGISTER5(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half, double,
+ int32, int64);
+REGISTER7(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half, double,
+ int32, int64, complex64, complex128);
+REGISTER5(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double,
+ uint8, int32);
+REGISTER5(MklBinaryOp, CPU, "_MklMaximum", functor::maximum, float, Eigen::half,
+ double, int32, int64);
+REGISTER5(MklBinaryOp, CPU, "_MklSquaredDifference",
+ functor::squared_difference, float, Eigen::half, double, int32,
+ int64);
+
+#undef REGISTER
+#pragma pop_macro("REGISTER")
+//-----------------------------------------------------------------------
+
+} // end namespace tensorflow
+
+#endif // INTEL_MKL
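
The registration block above temporarily redefines REGISTER with push_macro/pop_macro so the REGISTERn helpers emit MKL-labelled kernels, then restores the original macro. A minimal standalone illustration of that preprocessor pattern (hypothetical names, not TensorFlow code):

// register_macro_sketch.cc
#include <cstdio>

#define REGISTER(NAME) void NAME() { printf("default " #NAME "\n"); }

REGISTER(foo)  // uses the default definition

#pragma push_macro("REGISTER")
#undef REGISTER
#define REGISTER(NAME) void NAME() { printf("mkl-labelled " #NAME "\n"); }

REGISTER(bar)  // uses the temporary definition

#undef REGISTER
#pragma pop_macro("REGISTER")

REGISTER(baz)  // the default definition is restored

int main() {
  foo(); bar(); baz();
  return 0;
}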
diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc
index ca20294a26..f31e7afd46 100644
--- a/tensorflow/core/kernels/mkl_identity_op.cc
+++ b/tensorflow/core/kernels/mkl_identity_op.cc
@@ -41,9 +41,9 @@ class MklIdentityOp : public OpKernel {
bool input_in_mkl_format = mkl_shape_input.IsMklTensor();
if (input_in_mkl_format) {
- ForwarMklTensorInToOut(context, 0, 0);
+ ForwardMklTensorInToOut(context, 0, 0);
} else {
- FowardTfTensorInToOut(context, 0, 0);
+ ForwardTfTensorInToOut(context, 0, 0);
}
}
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
new file mode 100644
index 0000000000..b58e44e398
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -0,0 +1,259 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#include <algorithm>
+#include <vector>
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#include "tensorflow/core/kernels/mkl_tfconv_op.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+///////////////////////////////////////////////////////////
+// Op kernel
+// Ensures that the two inputs are compatible for MKL binary ops.
+// Here's the basic logic:
+//
+// if both inputs are in TF format:
+// pass the inputs through to the output
+// else if both inputs are in mkl format:
+// if both have the same shape:
+// pass the inputs through to the output
+// else:
+// convert both to TF
+// else if one is TF and one is MKL:
+// if broadcast is needed:
+// convert the MKL format input to TF format
+// else:
+// convert the TF format input to MKL format
+///////////////////////////////////////////////////////////
+
+template <typename Device, typename T>
+class MklInputConversionOp : public OpKernel {
+ public:
+ explicit MklInputConversionOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
+ OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
+ has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
+ }
+
+ private:
+ void Compute(OpKernelContext* context) override {
+ // Check if input tensors are in MKL format.
+ const Tensor& input_tensor_0 = MklGetInput(context, 0);
+ MklShape input_shape_0;
+ GetMklShape(context, 0, &input_shape_0);
+
+ const Tensor& input_tensor_1 = MklGetInput(context, 1);
+ MklShape input_shape_1;
+ GetMklShape(context, 1, &input_shape_1);
+
+ bool tf_shapes_are_same = MklCompareShapes(&context->input(0).shape(),
+ &context->input(1).shape());
+
+ VLOG(1) << "MklInputConversionOp: Input shapes are "
+ << (tf_shapes_are_same ? "*same*" : "*different*") << ": "
+ << context->input(0).shape().DebugString() << " and "
+ << context->input(1).shape().DebugString();
+
+ // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+ // if both inputs are in TF format, just copy input tensors to output.
+ if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
+ VLOG(1) << "MklInputConversionOp: No conversion needed, "
+ << "copying TF inputs to output";
+
+ ForwardTfTensorInToOut(context, 0, 0);
+ ForwardTfTensorInToOut(context, 1, 1);
+ return;
+ }
+
+ // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+ // If both inputs are in MKL format
+ if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
+ // If both have the same shape, pass them through
+ if (tf_shapes_are_same) {
+ VLOG(1) << "MklInputConversionOp: No conversion needed, "
+ << "copying MKL inputs with identical shapes to output";
+
+ ForwardMklTensorInToOut(context, 0, 0);
+ ForwardMklTensorInToOut(context, 1, 1);
+ return;
+ }
+
+ // Sanity check
+ bool mkl_shapes_are_same =
+ MklCompareShapes(&input_shape_0, &input_shape_1);
+ if (mkl_shapes_are_same) {
+        CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
+                        "different but MKL shapes are the same";
+ }
+
+ // Both have different shapes, so broadcast will be necessary.
+ // Convert to TF and pass both tensors through (we can't do broadcast
+ // with MKL tensors)
+ VLOG(1) << "MklInputConversionOp: Broadcast needed, "
+ << "converted MKL inputs to TF format";
+
+ MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
+ op_data_type, has_avx512f_, 0);
+ MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
+ op_data_type, has_avx512f_, 1);
+ SetDummyMklShapeOutput(context, 0);
+ SetDummyMklShapeOutput(context, 1);
+ return;
+ }
+
+ // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+ // One input is MKL and one is TF. If no broadcast is needed, convert
+ // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
+ VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";
+
+ const Tensor* mkl_tensor;
+ const MklShape* mkl_shape;
+ const Tensor* tf_tensor;
+ MklShape* tf_mkl_shape;
+ uint mkl_tensor_index;
+ uint tf_tensor_index;
+ if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
+ mkl_tensor = &input_tensor_0;
+ mkl_shape = &input_shape_0;
+ mkl_tensor_index = 0;
+ tf_tensor = &input_tensor_1;
+ tf_mkl_shape = &input_shape_1;
+ tf_tensor_index = 1;
+ } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
+ mkl_tensor = &input_tensor_1;
+ mkl_shape = &input_shape_1;
+ mkl_tensor_index = 1;
+ tf_tensor = &input_tensor_0;
+ tf_mkl_shape = &input_shape_0;
+ tf_tensor_index = 0;
+ } else {
+ CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
+ "shapes for MKL "
+ << "element-wise op";
+ }
+
+ // Broadcast is needed if the shapes are not the same
+ bool broadcast_needed;
+
+ size_t in0_size = 1;
+ for (size_t i = 0; i < mkl_shape->GetDimension(); ++i)
+ in0_size *= mkl_shape->tf_dim_size(i);
+
+ size_t in1_size = 1;
+ for (size_t i = 0; i < tf_tensor->shape().dims(); ++i)
+ in1_size *= tf_tensor->shape().dim_size(i);
+
+ broadcast_needed = (in0_size != in1_size);
+
+ if (!broadcast_needed) {
+      // Both shapes are the same, so convert the TF input to MKL
+ VLOG(1) << "MklInputConversionOp: No broadcast needed.";
+ VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
+ << " to MKL format";
+
+ // Create MklShape
+ Tensor* tensor_out;
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(true);
+ mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
+ mkl_shape->GetSizes(),
+ mkl_shape->GetStrides());
+ mkl_output_mkl_shape.SetTfDimOrder(mkl_shape->GetDimension());
+
+ // ** Temporarily borrow the layout from the MKL input **
+ mkl_output_mkl_shape.SetMklLayout(mkl_shape->GetCurLayout());
+
+ // Create output tensor
+ AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
+ mkl_tensor->shape(), mkl_output_mkl_shape);
+
+ // Since the shapes are the same, use information from the other tensor
+ tf_mkl_shape->SetTfLayout(mkl_shape->GetDimension(),
+ mkl_shape->GetSizes(), mkl_shape->GetStrides());
+ // Convert the data format
+ tf_mkl_shape->GetConvertedFlatData(
+ mkl_shape->GetCurLayout(),
+ const_cast<T*>(tf_tensor->flat<T>().data()),
+ const_cast<T*>(tensor_out->flat<T>().data()));
+
+ // ** Release the borrowed layout to avoid double deletion
+ // in the destructor call **
+ mkl_output_mkl_shape.SetMklLayout(nullptr);
+
+ // -- The tensor in MKL format passes through --
+ ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
+ } else {
+ // Broadcast is needed, so convert the MKL input to TF
+ VLOG(1) << "MklInputConversionOp: Broadcast needed.";
+ VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
+ << " to TF format";
+ MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
+ op_data_type, has_avx512f_,
+ mkl_tensor_index);
+ SetDummyMklShapeOutput(context, mkl_tensor_index);
+
+ // The tensor in TF format passes through
+ ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
+ }
+
+ VLOG(1) << "MklInputConversionOp: Shapes (output): "
+ << context->mutable_output(0)->shape().DebugString() << " and "
+ << context->mutable_output(1)->shape().DebugString();
+
+ VLOG(1) << "MklInputConversion completed successfully.";
+ }
+
+ private:
+ /// Data format of the operation
+ string data_format_str;
+
+ /// Data type of the operation
+ DataType op_data_type;
+
+  /// True if the CPU supports AVX512F instructions
+ bool has_avx512f_ = false;
+};
+
+///////////////////////////////////////////////////////////
+// Register kernel
+///////////////////////////////////////////////////////////
+
+#define REGISTER_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklInputConversion") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklInputConversionOp<CPUDevice, T>);
+
+TF_CALL_NUMBER_TYPES(REGISTER_CPU);
+#undef REGISTER_CPU
+} // namespace tensorflow
+#endif // INTEL_MKL
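
The kernel above picks one of four outcomes based only on the two input formats and whether broadcasting is needed. A hedged standalone model of that decision table (plain structs, not TensorFlow types; shape comparison is simplified to element counts):

// input_conversion_decision_sketch.cc
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

struct Input {
  bool mkl_format;              // true: MKL layout, false: TF layout
  std::vector<long> shape;      // TF-style dimensions
};

enum class Action { kForwardBoth, kConvertBothToTf, kConvertTfInputToMkl,
                    kConvertMklInputToTf };

long NumElements(const std::vector<long>& s) {
  return std::accumulate(s.begin(), s.end(), 1L, std::multiplies<long>());
}

Action Decide(const Input& a, const Input& b) {
  const bool broadcast_needed = NumElements(a.shape) != NumElements(b.shape);
  if (!a.mkl_format && !b.mkl_format) return Action::kForwardBoth;
  if (a.mkl_format && b.mkl_format)
    return broadcast_needed ? Action::kConvertBothToTf : Action::kForwardBoth;
  // Mixed formats: prefer MKL when no broadcast is required.
  return broadcast_needed ? Action::kConvertMklInputToTf
                          : Action::kConvertTfInputToMkl;
}

int main() {
  Input mkl_4d{true, {8, 3, 32, 32}}, tf_4d{false, {8, 3, 32, 32}};
  Input tf_scalar{false, {1}};
  printf("%d\n", static_cast<int>(Decide(mkl_4d, tf_4d)));     // TF -> MKL
  printf("%d\n", static_cast<int>(Decide(mkl_4d, tf_scalar))); // MKL -> TF
  printf("%d\n", static_cast<int>(Decide(tf_4d, tf_scalar)));  // forward both
  return 0;
}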
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
new file mode 100644
index 0000000000..a240ee44fb
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -0,0 +1,136 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#ifndef TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_
+
+#include <algorithm>
+#include <vector>
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+///////////////////////////////////////////////////////////
+// Op kernel
+///////////////////////////////////////////////////////////
+
+template <typename Device, typename T>
+class MklToTfOp : public OpKernel {
+ public:
+ explicit MklToTfOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
+ OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
+ has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ ConvertMklToTf(this, context, data_format_str, op_data_type, has_avx512f_,
+ 0);
+ VLOG(1) << "MKLToTFConversion completed successfully.";
+ }
+
+ static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
+ string data_format_str, DataType op_data_type,
+ bool has_avx512f, uint input_number) {
+ // Check that input tensor is in MKL format.
+ const Tensor& input_tensor = MklGetInput(context, input_number);
+ MklShape input_shape;
+ GetMklShape(context, input_number, &input_shape);
+
+ // If the input is already in TF format, just copy the input tensor to the output.
+ if (!input_shape.IsMklTensor()) {
+ context->set_output(input_number, input_tensor);
+ VLOG(1) << "MKLToTFConversion: No conversion needed, "
+ << "copying input to output";
+ return;
+ }
+
+ // Check that input data type is same as operator data type and that it is
+ // same as output data type.
+ DataType input_data_type = op_kernel->input_type(input_number);
+ DataType output_data_type = op_kernel->output_type(input_number);
+ CHECK_EQ(op_data_type, input_data_type);
+ CHECK_EQ(op_data_type, output_data_type);
+
+ TensorShape output_shape;
+ size_t ndims = input_shape.GetDimension();
+ size_t* in_sizes = new size_t[ndims];
+ for (size_t i = 0; i < ndims; i++) {
+ // Outermost to innermost dimension
+ output_shape.AddDim(input_shape.GetSizes()[input_shape.tf_dim_idx(i)]);
+ in_sizes[i] = input_shape.GetSizes()[i];
+ }
+
+ // Allocate output tensor.
+ Tensor* output_tensor = NULL;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(input_number, output_shape, &output_tensor));
+
+ dnnLayout_t output_layout =
+ static_cast<dnnLayout_t>(input_shape.GetTfLayout());
+ // Execute DNNConversion.
+ void* input_buffer =
+ static_cast<void*>(const_cast<T*>(input_tensor.flat<T>().data()));
+ delete[] in_sizes;
+ void* output_buffer =
+ static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data()));
+ input_shape.GetConvertedFlatData(output_layout, input_buffer,
+ output_buffer);
+ VLOG(1) << "MKLToTFConversion completed successfully.";
+ }
+
+ private:
+ /// Data format of the operation
+ string data_format_str;
+
+ /// Data type of the operation
+ DataType op_data_type;
+
+ /// True if the CPU supports AVX512F (queried via CPUIDInfo)
+ bool has_avx512f_ = false;
+};
+
+///////////////////////////////////////////////////////////
+// Register kernel
+///////////////////////////////////////////////////////////
+
+#define REGISTER_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklToTf") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklToTfOp<CPUDevice, T>);
+
+TF_CALL_NUMBER_TYPES(REGISTER_CPU);
+#undef REGISTER_CPU
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_
+#endif // INTEL_MKL
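ConvertMklToTf above rebuilds the TensorFlow shape by walking the MKL sizes in TF dimension order through tf_dim_idx(), and then lets GetConvertedFlatData() reorder the data into that layout. The shape-recovery step in isolation looks roughly like this (illustrative sketch only, using the MklShape accessors from mkl_util.h):

// Illustrative sketch only; not part of the patch.
static TensorShape TfShapeFromMklShape(const MklShape& mkl_shape) {
  TensorShape shape;
  for (size_t i = 0; i < mkl_shape.GetDimension(); ++i) {
    // tf_dim_idx(i) maps the i-th TF dimension to its index in the MKL layout.
    shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(i)]);
  }
  return shape;
}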
diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc
new file mode 100644
index 0000000000..c8b307a2e4
--- /dev/null
+++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc
@@ -0,0 +1,413 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/linalg_ops.cc.
+// TODO(shamanDevel): Enable complex inputs. This will require a specialization
+// of Gesvd for complex inputs as well as a new kernel
+// definition to output the singular values as reals
+// instead of complex values. The current CPU implementation
+// outputs the singular values as complex values and then
+// casts them to reals in the python wrapper.
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/cuda_solvers.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/kernels/transpose_functor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+static const char kErrMsg[] =
+ "Singular Value Decomposition was not successful. The input might not be "
+ "valid.";
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
+// This kernel computes the reduction
+// V' = sum_i (M_i * U_i,1 * S_1).
+// The result is stored in V[batch] and has the same sign as the
+// true value of V (which is what we ultimately want to recover).
+template <class Scalar>
+__global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m,
+ int64 ldu, const Scalar* M,
+ const Scalar* U, const Scalar* S,
+ Scalar* V) {
+ CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count, y) {
+ Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
+ CudaAtomicAdd(V + batch, v);
+ }
+ }
+}
+
+// Extracts the sign of V
+// V[i] = V[i] >= 0 ? 1 : -1
+template <class Scalar>
+__global__ void ExtractSignOfVKernel(CudaLaunchConfig config, Scalar* V) {
+ CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
+ V[i] = V[i] >= 0 ? Scalar(1) : Scalar(-1);
+ }
+}
+}  // namespace
+
+// Scalar: The input scalar type (can be complex)
+template <class Scalar>
+class SvdOpGpu : public AsyncOpKernel {
+ public:
+ using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
+
+ explicit SvdOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("compute_uv", &compute_uv_));
+ OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
+ }
+
+ void RunSVD(OpKernelContext* context, DoneCallback done, int64 m, int64 n,
+ int64 p, int64 batch_size, Scalar* input_ptr,
+ RealScalar* outputS_ptr, Scalar* outputU_ptr,
+ Scalar* outputVT_ptr, int* dev_info_ptr, CudaSolver& solver) {
+ // Save the input matrix
+ // Needed for the n=1 fix, see below, since SVD destroys the input
+ Tensor input_copy;
+ if (compute_uv_ && n == 1) {
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_temp(DataTypeToEnum<Scalar>::v(),
+ TensorShape({batch_size, m}), &input_copy),
+ done);
+ const GPUDevice& d = context->eigen_device<GPUDevice>();
+ d.memcpy(input_copy.flat<Scalar>().data(), input_ptr,
+ batch_size * m * sizeof(Scalar));
+ }
+
+ for (int64 batch = 0; batch < batch_size; ++batch) {
+ Scalar* input = input_ptr + batch * m * n;
+ RealScalar* outputS = outputS_ptr + batch * p;
+ Scalar* outputU = NULL;
+ Scalar* outputVT = NULL;
+ char jobu = 'N';
+ char jobvt = 'N';
+
+ if (compute_uv_) {
+ if (full_matrices_) {
+ outputU = outputU_ptr + batch * m * m;
+ outputVT = outputVT_ptr + batch * n * n;
+ jobu = 'A';
+ jobvt = 'A';
+ } else {
+ outputU = outputU_ptr + batch * m * p;
+ outputVT = outputVT_ptr + batch * n * p;
+ jobu = 'S';
+ jobvt = 'S';
+ }
+ }
+
+ OP_REQUIRES_OK_ASYNC(
+ context, solver.Gesvd(jobu, jobvt, m, n, input, m, outputS, outputU,
+ m, outputVT, n, dev_info_ptr + batch),
+ done);
+ }
+
+ // This is a bug in cuSolver:
+ // If n is one, then outputVT only contains zeros instead of ones.
+ // Hence, I need to fill outputVT manually
+ // The question is: +1 or -1?
+ // -> Compute U*S and compare sign against M
+ // But because S is zero except for the first entry, the multiplication
+ // simplifies a lot.
+ // However, what happens if M contains zeros? At those indices it is
+ // impossible to determine the value of V from that row alone.
+ // -> Sum over all rows of M to cope with zeros.
+ // 1. V' = sum_i (M_i * U_i,1 * S_1)
+ // 2. V = +1 if V' >= 0, -1 if V' < 0
+ // TODO: what about complex values?
+ if (compute_uv_ && n == 1) {
+ // 1. compute the (batched) sum
+ const GPUDevice& d = context->eigen_device<GPUDevice>();
+ d.memset(outputVT_ptr, 0, batch_size * sizeof(Scalar));
+ Cuda2DLaunchConfig cfg2D = GetCuda2DLaunchConfig(batch_size, m, d);
+ ComputeValueOfVKernel<<<cfg2D.block_count, cfg2D.thread_per_block, 0,
+ d.stream()>>>(
+ cfg2D, m, full_matrices_ ? m : p, input_copy.flat<Scalar>().data(),
+ outputU_ptr, outputS_ptr, outputVT_ptr);
+ // 2. clamp V to -1 or +1
+ CudaLaunchConfig cfg1D = GetCudaLaunchConfig(batch_size, d);
+ ExtractSignOfVKernel<<<cfg1D.block_count, cfg1D.thread_per_block, 0,
+ d.stream()>>>(cfg1D, outputVT_ptr);
+ }
+ }
+
+ void CheckResult(OpKernelContext* context, DoneCallback done,
+ const std::vector<DeviceLapackInfo>& dev_info,
+ CudaSolver& solver, Tensor& catch1, Tensor& catch2) {
+ auto info_checker = [context, dev_info, done, catch1, catch2](
+ const Status& status, const std::vector<HostLapackInfo>& /* unused */) {
+ Status full_status = status;
+ if (!full_status.ok()) {
+ full_status.Update(errors::InvalidArgument(kErrMsg));
+ }
+ OP_REQUIRES_OK_ASYNC(context, full_status, done);
+ done();
+ };
+
+ OP_REQUIRES_OK_ASYNC(context, solver.CopyLapackInfoToHostAsync(
+ dev_info, std::move(info_checker)),
+ done);
+ }
+
+ // The SVD if m >= n
+ // TODO: can the two cases (MgeqN and MlessN) be simplified,
+ // common boilerplate be reduced, or even combined in one method?
+ void PerformSVD_MgeqN(OpKernelContext* context, DoneCallback done, int64 m,
+ int64 n, int64 p, const gtl::ArraySlice<int32>& perm,
+ const Tensor& M, Tensor* S, Tensor* U, Tensor* V) {
+ TensorShape shapeRaw = M.shape();
+ shapeRaw.RemoveDim(shapeRaw.dims() - 1);
+ shapeRaw.RemoveDim(shapeRaw.dims() - 1);
+
+ // Transpose M, because cuSolver expects it to be column-major
+ TensorShape input_shape = shapeRaw;
+ input_shape.AddDim(n);
+ input_shape.AddDim(m);
+ Tensor input_copy;
+ OP_REQUIRES_OK_ASYNC(
+ context, context->allocate_temp(M.dtype(), input_shape, &input_copy),
+ done);
+ auto device = context->eigen_device<GPUDevice>();
+ OP_REQUIRES_OK_ASYNC(context, DoTranspose(device, M, perm, &input_copy),
+ done);
+
+ // U needs to be transposed at the end.
+ // Not V, because cuSolver works in column-major order.
+ Tensor u_copy;
+ if (compute_uv_) {
+ TensorShape u_shape;
+ if (full_matrices_) {
+ u_shape = U->shape();
+ } else {
+ u_shape = shapeRaw;
+ u_shape.AddDim(p);
+ u_shape.AddDim(m);
+ }
+ OP_REQUIRES_OK_ASYNC(
+ context, context->allocate_temp(U->dtype(), u_shape, &u_copy), done);
+ }
+
+ // get the pointers to the data
+ Scalar* input_ptr;
+ RealScalar* outputS_ptr;
+ Scalar* outputU_ptr = NULL;
+ Scalar* outputV_ptr = NULL;
+ auto input_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
+ input_ptr = input_reshaped.data();
+ outputS_ptr = S->template flat_inner_dims<RealScalar, 2>().data();
+ if (compute_uv_) {
+ outputU_ptr = u_copy.template flat_inner_dims<Scalar, 3>().data();
+ outputV_ptr = V->template flat_inner_dims<Scalar, 3>().data();
+ }
+
+ // call the SVD
+ const int64 batch_size = input_reshaped.dimension(0);
+ std::vector<DeviceLapackInfo> dev_info;
+ dev_info.emplace_back(context, batch_size, "gesvd");
+ CudaSolver solver(context);
+ RunSVD(context, done, m, n, p, batch_size, input_ptr, outputS_ptr,
+ outputU_ptr, outputV_ptr, dev_info.back().mutable_data(), solver);
+
+ // Transpose U
+ if (compute_uv_) {
+ OP_REQUIRES_OK_ASYNC(context, DoTranspose(device, u_copy, perm, U), done);
+ }
+
+ // now check if the SVD operation succeeded or not
+ CheckResult(context, done, dev_info, solver, input_copy, u_copy);
+ }
+
+ // The SVD if m < n
+ void PerformSVD_MlessN(OpKernelContext* context, DoneCallback done, int64 m,
+ int64 n, int64 p, const gtl::ArraySlice<int32>& perm,
+ const Tensor& M, Tensor* S, Tensor* U, Tensor* V) {
+ // Perform the SVD on M'
+
+ // Reuse the input buffer or make a copy for the SVD depending on whether
+ // this op owns the input buffer exclusively. This is needed because the
+ // SVD modifies the input.
+ Tensor input_copy;
+ OP_REQUIRES_OK_ASYNC(context, context->forward_input_or_allocate_temp(
+ {0}, DataTypeToEnum<Scalar>::value,
+ M.shape(), &input_copy),
+ done);
+
+ if (!M.SharesBufferWith(input_copy)) {
+ const GPUDevice& d = context->eigen_device<GPUDevice>();
+ d.memcpy(input_copy.flat<Scalar>().data(), M.flat<Scalar>().data(),
+ M.NumElements() * sizeof(Scalar));
+ }
+
+ // V needs to be transposed at the end
+ Tensor v_copy;
+ if (compute_uv_) {
+ TensorShape v_shape;
+ if (full_matrices_) {
+ v_shape = V->shape();
+ } else {
+ TensorShape shapeRaw = M.shape();
+ shapeRaw.RemoveDim(shapeRaw.dims() - 1);
+ shapeRaw.RemoveDim(shapeRaw.dims() - 1);
+ v_shape = shapeRaw;
+ v_shape.AddDim(p);
+ v_shape.AddDim(n);
+ }
+ OP_REQUIRES_OK_ASYNC(
+ context, context->allocate_temp(V->dtype(), v_shape, &v_copy), done);
+ }
+
+ // get the pointers to the data
+ Scalar* input_ptr;
+ RealScalar* outputS_ptr;
+ Scalar* outputU_ptr = NULL;
+ Scalar* outputV_ptr = NULL;
+ auto input_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
+ input_ptr = input_reshaped.data();
+ outputS_ptr = S->template flat_inner_dims<RealScalar, 2>().data();
+ if (compute_uv_) {
+ // Note that U and V are flipped
+ outputU_ptr = v_copy.template flat_inner_dims<Scalar, 3>().data();
+ outputV_ptr = U->template flat_inner_dims<Scalar, 3>().data();
+ }
+
+ // call the SVD
+ const int64 batch_size = input_reshaped.dimension(0);
+ std::vector<DeviceLapackInfo> dev_info;
+ dev_info.emplace_back(context, batch_size, "gesvd");
+ CudaSolver solver(context);
+ // Note that m and n are flipped
+ RunSVD(context, done, n, m, p, batch_size, input_ptr, outputS_ptr,
+ outputU_ptr, outputV_ptr, dev_info.back().mutable_data(), solver);
+
+ // Transpose V
+ if (compute_uv_) {
+ auto device = context->eigen_device<GPUDevice>();
+ OP_REQUIRES_OK_ASYNC(context, DoTranspose(device, v_copy, perm, V), done);
+ }
+
+ // now check if the SVD operation succeeded or not
+ CheckResult(context, done, dev_info, solver, input_copy, v_copy);
+ }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
+ const Tensor& input = context->input(0);
+ const int ndims = input.dims();
+
+ // Validate inputs before indexing the last two dimensions.
+ OP_REQUIRES_ASYNC(
+ context, ndims >= 2,
+ errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
+ done);
+
+ const int64 m = input.dim_size(ndims - 2);
+ const int64 n = input.dim_size(ndims - 1);
+ const int64 p = std::min(m, n);
+
+ // output tensors.
+ Tensor* outputU = NULL;
+ Tensor* outputS = NULL;
+ Tensor* outputV = NULL;
+
+ // compute shapes
+ TensorShape shapeRaw = input.shape();
+ shapeRaw.RemoveDim(shapeRaw.dims() - 1);
+ shapeRaw.RemoveDim(shapeRaw.dims() - 1);
+ TensorShape shapeS = shapeRaw;
+ TensorShape shapeU = shapeRaw;
+ TensorShape shapeV = shapeRaw;
+ shapeS.AddDim(p);
+ if (compute_uv_) {
+ if (full_matrices_) {
+ shapeU.AddDim(m);
+ shapeU.AddDim(m);
+ shapeV.AddDim(n);
+ shapeV.AddDim(n);
+ } else {
+ shapeU.AddDim(m);
+ shapeU.AddDim(p);
+ shapeV.AddDim(n);
+ shapeV.AddDim(p);
+ }
+ } else {
+ shapeU = TensorShape({0});
+ shapeV = TensorShape({0});
+ }
+
+ // allocate output
+ OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, shapeS, &outputS),
+ done);
+ OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, shapeU, &outputU),
+ done);
+ OP_REQUIRES_OK_ASYNC(context, context->allocate_output(2, shapeV, &outputV),
+ done);
+
+ if (n == 0 || m == 0) {
+ // If X is an empty matrix (0 rows or 0 cols), there is nothing to
+ // compute, so we return the (empty) outputs allocated above.
+ done();
+ return;
+ }
+
+ // Prepare permutation
+ std::vector<int32> perm;
+ for (size_t i = 0; i < ndims - 2; ++i) perm.push_back(i);
+ perm.push_back(ndims - 1); // transpose last two dimensions
+ perm.push_back(ndims - 2);
+ gtl::ArraySlice<int32> permAS(perm);
+
+ // call implementations
+ if (m >= n) {
+ PerformSVD_MgeqN(context, done, m, n, p, permAS, input, outputS, outputU,
+ outputV);
+ } else {
+ PerformSVD_MlessN(context, done, m, n, p, permAS, input, outputS, outputU,
+ outputV);
+ }
+ }
+
+ private:
+ bool compute_uv_;
+ bool full_matrices_;
+};
+
+// TODO: add support for complex types
+REGISTER_LINALG_OP_GPU("Svd", (SvdOpGpu<float>), float);
+REGISTER_LINALG_OP_GPU("Svd", (SvdOpGpu<double>), double);
+REGISTER_LINALG_OP_GPU("BatchSvd", (SvdOpGpu<float>), float);
+REGISTER_LINALG_OP_GPU("BatchSvd", (SvdOpGpu<double>), double);
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
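The sign-recovery trick used for the n == 1 case can be justified in one line. With a single column, M = U S V^T reduces to M_i = U_{i,1} S_1 V with V a per-batch sign, so (assuming exact arithmetic and a unit-norm first column of U):

\[
\sum_i M_i \, U_{i,1} \, S_1 = V \, S_1^2 \sum_i U_{i,1}^2 = V \, S_1^2 ,
\]

which is nonnegative exactly when V = +1. Clamping the reduction to plus or minus one therefore reconstructs V even when individual entries of M are zero, which is why the kernel sums over every row instead of inspecting a single one.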
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 075bacb432..2191e4e8c5 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -1069,7 +1069,7 @@ class TensorArrayUnpackOrScatterOp : public OpKernel {
} else {
OP_REQUIRES(
ctx, max_index < array_size,
- errors::InvalidArgument("Max scatter index must be <= array size (",
+ errors::InvalidArgument("Max scatter index must be < array size (",
max_index, " vs. ", array_size, ")"));
}
element_shape.RemoveDim(0);
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 2a59282fa5..ef4737cafe 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -498,6 +498,24 @@ Returns x + y element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("_MklAdd")
+ .Input("x: T")
+ .Input("y: T")
+ .Input("mkl_x: uint8")
+ .Input("mkl_y: uint8")
+ .Output("z: T")
+ .Output("mkl_z: uint8")
+ .Attr(
+ "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
+ "complex128, string}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns x + y element-wise.
+
+*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
REGISTER_OP("Sub")
.BINARY_MORE()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
@@ -508,6 +526,19 @@ Returns x - y element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("_MklSub")
+ .BINARY_FEWER()
+ .Input("mkl_x: uint8")
+ .Input("mkl_y: uint8")
+ .Output("mkl_z: uint8")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns x - y element-wise.
+
+*NOTE*: `Sub` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
REGISTER_OP("Mul")
.BINARY_MORE()
.SetIsCommutative()
@@ -519,6 +550,20 @@ Returns x * y element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("_MklMul")
+ .BINARY_MORE()
+ .Input("mkl_x: uint8")
+ .Input("mkl_y: uint8")
+ .Output("mkl_z: uint8")
+ .SetIsCommutative()
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns x * y element-wise.
+
+*NOTE*: `Mul` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
REGISTER_OP("Div")
.BINARY_MORE()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
@@ -577,6 +622,20 @@ Returns (x - y)(x - y) element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("_MklSquaredDifference")
+ .BINARY_FEWER()
+ .Input("mkl_x: uint8")
+ .Input("mkl_y: uint8")
+ .Output("mkl_z: uint8")
+ .SetIsCommutative()
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns (x - y)(x - y) element-wise.
+
+*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
#undef BINARY_FEWER
#undef BINARY_MORE
@@ -594,6 +653,23 @@ Returns the max of x and y (i.e. x > y ? x : y) element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("_MklMaximum")
+ .Input("x: T")
+ .Input("y: T")
+ .Input("mkl_x: uint8")
+ .Input("mkl_y: uint8")
+ .Output("z: T")
+ .Output("mkl_z: uint8")
+ .Attr("T: {half, float, double, int32, int64}")
+ .SetIsCommutative()
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns the max of x and y (i.e. x > y ? x : y) element-wise.
+
+*NOTE*: `Maximum` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
REGISTER_OP("Minimum")
.Input("x: T")
.Input("y: T")
@@ -2604,4 +2680,31 @@ Equivalent to np.digitize.
@end_compatibility
)doc");
+#ifdef INTEL_MKL
+REGISTER_OP("_MklAddN")
+ .Input("inputs: N * T")
+ .Input("mkl_input: N * uint8")
+ .Output("sum: T")
+ .Output("mkl_sum: uint8")
+ .Attr("N: int >= 1")
+ .Attr("T: numbertype")
+ .SetIsCommutative()
+ .SetIsAggregate()
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle cur = c->input(c->num_inputs() - 1);
+ for (int i = c->num_inputs() - 2; i >= 0; --i) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
+ "From merging shape ", i,
+ " with other shapes.");
+ }
+ c->set_output(0, cur);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Add N input tensors element-wise using the MKL sum kernel.
+inputs: Must all be the same size and shape.
+)doc");
+
+#endif // INTEL_MKL
+
} // namespace tensorflow
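The shape function of `_MklAddN` above folds all data-input shapes together with Merge, so partially known shapes refine one another and any known-dimension conflict is reported as an error. A small worked example of the merge (illustrative only):

// inputs[0]: [?, 3]   inputs[1]: [2, ?]
// Merge([?, 3], [2, ?]) -> [2, 3]   // conflicting known dims would fail
// sum:                      [2, 3]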
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index fd0b785b8f..22afa4db9a 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -3241,6 +3241,29 @@ MKL operator to convert a tensor from MKL layout to TensorFlow layout.
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+
+REGISTER_OP("_MklInputConversion")
+ .Input("input_0: T")
+ .Input("input_1: T")
+ .Input("mkl_input_0: uint8")
+ .Input("mkl_input_1: uint8")
+ .Output("output_0: T")
+ .Output("output_1: T")
+ .Output("mkl_output_0: uint8")
+ .Output("mkl_output_1: uint8")
+ // All datatypes supported by element-wise ops
+ .Attr(
+ "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, "
+ "complex64, complex128}")
+ .Attr(GetConvnetDataFormatAttrString())
+ .Doc(R"doc(
+MKL operator to process the inputs to an element-wise MKL op. It converts the
+inputs so that both are handed to the MKL op in the same format (either both
+in TF or both in MKL layout). This op is added before every element-wise MKL op.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
#endif // INTEL_MKL
} // namespace tensorflow
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index c4bc57fd77..1b07f4ecf8 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -15866,6 +15866,25 @@ op {
summary: "Transforms a serialized tensorflow.TensorProto proto into a Tensor."
}
op {
+ name: "SerializeTensor"
+ input_arg {
+ name: "tensor"
+ description: "A Tensor of type `T`."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "serialized"
+ description: "A serialized TensorProto proto of the input tensor."
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "The type of the input tensor."
+ }
+ summary: "Transforms a Tensor into a serialized TensorProto proto."
+}
+op {
name: "Placeholder"
output_arg {
name: "output"
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index 1f7ebe91cf..f23ff083af 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -26,7 +26,7 @@ using shape_inference::ShapeHandle;
REGISTER_OP("DecodeRaw")
.Input("bytes: string")
.Output("output: out_type")
- .Attr("out_type: {half,float,double,int32,uint8,int16,int8,int64}")
+ .Attr("out_type: {half,float,double,int32,uint16,uint8,int16,int8,int64}")
.Attr("little_endian: bool = true")
.SetShapeFn([](InferenceContext* c) {
// Note: last dimension is data dependent.
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 5e99187d50..aebd14c7e5 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -381,7 +381,7 @@ input = b'thirteen'
position = [1, 5, 7]
length = [3, 2, 1]
-output = [b'hir', b'ee', b'n"]
+output = [b'hir', b'ee', b'n']
```
input: Tensor of strings
diff --git a/tensorflow/core/platform/cuda_libdevice_path_test.cc b/tensorflow/core/platform/cuda_libdevice_path_test.cc
index 86295592a8..639f6804ea 100644
--- a/tensorflow/core/platform/cuda_libdevice_path_test.cc
+++ b/tensorflow/core/platform/cuda_libdevice_path_test.cc
@@ -27,7 +27,7 @@ TEST(CudaLibdevicePathTest, LibdevicePath) {
VLOG(2) << "Libdevice root = " << LibdeviceRoot();
std::vector<string> libdevice_files;
TF_EXPECT_OK(Env::Default()->GetMatchingPaths(
- io::JoinPath(LibdeviceRoot(), "libdevice.compute_*.bc"),
+ io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"),
&libdevice_files));
EXPECT_LT(0, libdevice_files.size());
}
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index ccb861c93a..9ba3a509c3 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,12 +19,12 @@ limitations under the License.
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 1
-#define TF_MINOR_VERSION 3
+#define TF_MINOR_VERSION 4
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-dev"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index af727c3d2b..f8eddbb2a9 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -25,6 +25,29 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
+#include "cuda/include/cuda.h"
+
+// Mask for all 32 threads in a warp.
+#define CUDA_WARP_ALL 0xFFFFFFFF
+
+#if defined(CUDA_VERSION) && CUDA_VERSION < 9000
+// CUDA 9.0 introduces a new, light-weight barrier synchronization primitive
+// that operates at the warp-scope. This is required to ensure visibility of
+// reads/writes among threads that can make independent progress on Volta.
+// For previous CUDA versions these synchronizations are not necessary, so we
+// define an empty function as a convenience for backward compatibility.
+__device__ inline void __syncwarp(unsigned mask=CUDA_WARP_ALL) {}
+
+// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in
+// favor of synchronizing versions. These ensure that all warp lanes specified
+// in mask execute the intrinsic in convergence. Here we provide legacy mappings
+// to the less-verbose routines provided in previous versions of CUDA.
+#define __ballot_sync(mask, predicate) __ballot(predicate)
+#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width)
+#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width)
+#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width)
+#define __shfl_xor_sync(mask, val, laneMask, width) __shfl_xor(val, laneMask, width)
+#endif
// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
// GetCuda3DLaunchConfig:
@@ -613,82 +636,95 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) {
return x < y ? y : x;
}
+__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask,
+ int predicate) {
+ return __ballot_sync(mask, predicate);
+}
+
template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(T value, int srcLane,
+__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value,
+ int srcLane,
int width = warpSize) {
- return __shfl(value, srcLane, width);
+ return __shfl_sync(mask, value, srcLane, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(double value, int srcLane,
+__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask,
+ double value, int srcLane,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl(hi, srcLane, width);
- lo = __shfl(lo, srcLane, width);
+ hi = __shfl_sync(mask, hi, srcLane, width);
+ lo = __shfl_sync(mask, lo, srcLane, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(T value, int delta,
+__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask,
+ T value, int delta,
int width = warpSize) {
- return __shfl_up(value, delta, width);
+ return __shfl_up_sync(mask, value, delta, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(double value, int delta,
+__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask,
+ double value, int delta,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_up(hi, delta, width);
- lo = __shfl_up(lo, delta, width);
+ hi = __shfl_up_sync(mask, hi, delta, width);
+ lo = __shfl_up_sync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(T value, int delta,
+__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask,
+ T value, int delta,
int width = warpSize) {
- return __shfl_down(value, delta, width);
+ return __shfl_down_sync(mask, value, delta, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(double value, int delta,
+__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask,
+ double value, int delta,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_down(hi, delta, width);
- lo = __shfl_down(lo, delta, width);
+ hi = __shfl_down_sync(mask, hi, delta, width);
+ lo = __shfl_down_sync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(T value, int laneMask,
+__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask,
+ T value, int laneMask,
int width = warpSize) {
- return __shfl_xor(value, laneMask, width);
+ return __shfl_xor_sync(mask, value, laneMask, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(double value, int laneMask,
+__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask,
+ double value, int laneMask,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_xor(hi, laneMask, width);
- lo = __shfl_xor(lo, laneMask, width);
+ hi = __shfl_xor_sync(mask, hi, laneMask, width);
+ lo = __shfl_xor_sync(mask, lo, laneMask, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
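With the mask-taking wrappers above, warp-level reductions can state their participation mask explicitly instead of relying on implicit warp convergence. A minimal sketch of a full-warp sum built on CudaShuffleDown (illustrative only; assumes all 32 lanes of the warp are active):

// Illustrative sketch only; not part of the patch.
template <typename T>
__device__ T CudaWarpSumAllLanes(T value) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    // All lanes participate, so the full-warp mask is passed through.
    value += CudaShuffleDown(CUDA_WARP_ALL, value, offset);
  }
  return value;  // lane 0 now holds the sum over the whole warp
}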
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index cb22a50e8f..f4bec9524a 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -65,6 +65,8 @@ class MklShape {
void SetDimensions(const size_t dimension) { dimension_ = dimension; }
+ void SetMklLayout(dnnLayout_t mklLayout) { mklLayout_ = mklLayout; }
+
void SetMklLayout(const void* primitive, size_t resourceType) {
CHECK_EQ(
dnnLayoutCreateFromPrimitive_F32(&mklLayout_, (dnnPrimitive_t)primitive,
@@ -135,6 +137,7 @@ class MklShape {
size_t GetDimension() const { return dimension_; }
const size_t* GetSizes() const { return sizes_; }
int64 dim_size(int index) const { return sizes_[index]; }
+ int64 tf_dim_size(int index) const { return sizes_[tf_to_mkl_dim_map_[index]]; }
const size_t* GetStrides() const { return strides_; }
const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; }
size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; }
@@ -581,7 +584,7 @@ inline void CopyTfTensorInToOutWithShape(OpKernelContext* context,
context->set_output(idx_data_out, output);
}
-inline void FowardTfTensorInToOut(OpKernelContext* context,
+inline void ForwardTfTensorInToOut(OpKernelContext* context,
int idx_in, int idx_out) {
int num_inputs = context->num_inputs();
int num_outputs = context->num_outputs();
@@ -598,7 +601,7 @@ inline void FowardTfTensorInToOut(OpKernelContext* context,
}
}
-inline void ForwarMklTensorInToOut(OpKernelContext* context,
+inline void ForwardMklTensorInToOut(OpKernelContext* context,
int idx_in, int idx_out) {
int num_inputs = context->num_inputs();
int num_outputs = context->num_outputs();
@@ -616,6 +619,98 @@ inline void ForwarMklTensorInToOut(OpKernelContext* context,
}
}
+// Forward the MKL shape ONLY (used in elementwise and other ops where
+// we call the eigen implementation and MKL shape is not used)
+inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
+ uint idx_data_in, uint idx_data_out) {
+ uint idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
+ uint idx_meta_out =
+ GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
+
+ if (IsRefType(context->input_dtype(idx_data_in))) {
+ context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
+ } else {
+ context->set_output(idx_meta_out, context->input(idx_meta_in));
+ }
+}
+
+// Set a dummy MKL shape (called when the output is in TF format)
+inline void SetDummyMklShapeOutput(OpKernelContext* context,
+ uint idx_data_out) {
+ MklShape mkl_shape_output;
+ mkl_shape_output.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
+}
+
+// Checks whether the TF shapes of the two MKL tensors are the same.
+// Returns: true if both TF shapes are the same, false otherwise
+inline bool MklCompareShapes(const MklShape* input_shape_0,
+ const MklShape* input_shape_1) {
+ // Check for number of dimensions
+ if (input_shape_0->GetDimension() != input_shape_1->GetDimension()) {
+ return false;
+ }
+
+ // Check size of each dimension
+ size_t ndims = input_shape_0->GetDimension();
+ for (size_t i = 0; i < ndims; i++) {
+ if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+// Checks whether the TF shapes of the two tensors are the same.
+// Returns: true if the TF shapes of both are the same, false otherwise
+inline bool MklCompareShapes(const MklShape* input_shape_0,
+ const TensorShape* input_shape_1) {
+ // Check for number of dimensions
+ if (input_shape_0->GetDimension() != input_shape_1->dims()) {
+ return false;
+ }
+
+ // Check size of each dimension
+ size_t ndims = input_shape_0->GetDimension();
+ for (size_t i = 0; i < ndims; i++) {
+ if (input_shape_0->tf_dim_size(i) != input_shape_1->dim_size(i)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+// Checks whether the TF shapes of the two tensors are the same.
+// Returns: true if the TF shapes of both are the same, false otherwise
+inline bool MklCompareShapes(const TensorShape* input_shape_0,
+ const MklShape* input_shape_1) {
+ return MklCompareShapes(input_shape_1, input_shape_0);
+}
+
+// Checks whether the TF shapes of the two tensors are the same.
+// Returns: true if the TF shapes of both are the same, false otherwise
+inline bool MklCompareShapes(const TensorShape* input_shape_0,
+ const TensorShape* input_shape_1) {
+ // Check for number of dimensions
+ if (input_shape_0->dims() != input_shape_1->dims()) {
+ return false;
+ }
+
+ // Check size of each dimension
+ size_t ndims = input_shape_0->dims();
+ for (size_t i = 0; i < ndims; i++) {
+ if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+// TODO(intel_tf): Remove this routine when faster MKL layout conversion is
+// out.
inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) {
const float* buf_in = input.flat<float>().data();
float* buf_out = (*output)->flat<float>().data();
@@ -652,11 +747,19 @@ namespace mkl_op_registry {
static const char* kMklOpLabel = "MklOp";
static const char* kMklOpLabelPattern = "label='MklOp'";
+// Get the name of the Mkl op from the original TensorFlow op name.
+// We prefix the original op name with '_Mkl' to construct the Mkl op name.
+inline string GetMklOpName(const string& name) {
+ // Prefix that we add to Tensorflow op name to construct Mkl op name.
+ const char* const kMklOpPrefix = "_Mkl";
+ return string(kMklOpPrefix) + name;
+}
+
// Check whether opname with type T is registered as MKL-compliant.
//
// @input: name of the op
// @input: T datatype to be used for checking op
-// @return: true if opname is registered as Mkl op
+// @return: true if opname is registered as Mkl op; false otherwise
static inline bool IsMklOp(const std::string& op_name, DataType T) {
string kernel = KernelsRegisteredForOp(op_name);
bool result =
@@ -667,6 +770,28 @@ static inline bool IsMklOp(const std::string& op_name, DataType T) {
return result;
}
+// Check whether opname with type T is registered as MKL-compliant and
+// is element-wise.
+//
+// @input: name of the op
+// @input: T datatype to be used for checking op
+// @return: true if opname is registered as element-wise Mkl op; false otherwise
+static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
+ if (!IsMklOp(op_name, T)) {
+ return false;
+ }
+
+ bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
+ 0 == op_name.compare(GetMklOpName("Sub")) ||
+ 0 == op_name.compare(GetMklOpName("Mul")) ||
+ 0 == op_name.compare(GetMklOpName("Maximum")) ||
+ 0 == op_name.compare(GetMklOpName("SquaredDifference")));
+
+ VLOG(1) << "mkl_op_registry::" << op_name
+ << " is elementwise MKL op: " << result;
+ return result;
+}
+
} // namespace mkl_op_registry
} // namespace tensorflow
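A graph-rewrite pass can combine the two helpers above to decide whether a candidate TensorFlow node maps to an element-wise MKL kernel. A minimal sketch (illustrative only; the wrapper name is hypothetical):

// Illustrative sketch only; not part of the patch.
static bool RewritesToElementWiseMklOp(const string& tf_op_name, DataType T) {
  const string mkl_op_name = mkl_op_registry::GetMklOpName(tf_op_name);
  return mkl_op_registry::IsMklElementWiseOp(mkl_op_name, T);
}

For example, RewritesToElementWiseMklOp("Add", DT_FLOAT) looks up "_MklAdd" and returns true only if that kernel is registered with the MklOp label for float.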