aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_layout_pass.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-06 17:57:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-06 18:01:27 -0700
commite722358e7e96dd2aa20d7e2c56336e76845daa6a (patch)
treea74960670ce4bacad0909fc913097bcc3e27ed18 /tensorflow/core/graph/mkl_layout_pass.cc
parentf8a43f9d63ce90f10852d69e40fbb9fe849fc190 (diff)
Merge changes from github.
END_PUBLIC --- Commit 607816029 authored by Eugene Brevdo<ebrevdo@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Extended ScratchSpace to expose its underlying scratch tensor object. PiperOrigin-RevId: 167649551 --- Commit db43fe68e authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add fast math attributes to all generated methods when fast math enabled. RELNOTES: n/a PiperOrigin-RevId: 167646637 --- Commit aebe8cc6f authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Call HloComputation.Accept instead of HloInstruction.Accept to get all instructions profiled. RELNOTES: n/a PiperOrigin-RevId: 167640259 --- Commit 0ab137cd8 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 167604306 PiperOrigin-RevId: 167800256
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc60
1 files changed, 56 insertions, 4 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 2f9ceaa3bd..cf5d6e8baa 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -1099,6 +1099,44 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
CHECK_NOTNULL(workspace_tensors);
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ // TODO(nhasabni): Temporary solution to connect filter input of
+ // BackpropInput with the converted filter from Conv2D.
+ bool do_connect_conv2d_backprop_input_filter = false;
+ Node* conv2d_node = nullptr;
+ // Filter node is 2nd input (slot index 1) of Conv2D.
+ int kConv2DFilterInputSlotIdx = 1;
+ int kConv2DBackpropInputFilterInputSlotIdx = 1;
+ int kConv2DFilterOutputSlotIdx = 1;
+ if (old_node->type_string() == csinfo_.conv2d_grad_input) {
+ // We need to find Conv2D node from Conv2DBackpropInput.
+ // For that let's first find filter node that is 2nd input (slot 1)
+ // of BackpropInput.
+ Node* filter_node = nullptr;
+ old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, &filter_node);
+ CHECK_NOTNULL(filter_node);
+
+ // Now check which nodes receive from filter_node. Filter feeds as
+ // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias.
+ for (const Edge* e : filter_node->out_edges()) {
+ if (e->dst()->type_string() == csinfo_.mkl_conv2d &&
+ e->dst_input() == kConv2DFilterInputSlotIdx
+ /* filter is 2nd input of Conv2D and _MklConv2D. */) {
+ if (conv2d_node != nullptr) {
+ VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
+ << " feeding multiple Conv2D nodes: "
+ << filter_node->DebugString();
+ // We will not connect filter input of Conv2DBackpropInput
+ // to be safe here.
+ do_connect_conv2d_backprop_input_filter = false;
+ break;
+ } else {
+ conv2d_node = e->dst();
+ do_connect_conv2d_backprop_input_filter = true;
+ }
+ }
+ }
+ }
+
// Number of input slots to original op
// Input slots are represented by .Input() calls in REGISTER_OP.
int old_node_input_slots = old_node->op_def().input_arg_size();
@@ -1122,7 +1160,13 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
- nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
+ // Special case for connecting filter input of Conv2DBackpropInput
+ if (do_connect_conv2d_backprop_input_filter &&
+ iidx == kConv2DBackpropInputFilterInputSlotIdx) {
+ nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
+ } else {
+ nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
+ }
iidx++;
nn_slot_idx++;
}
@@ -1157,9 +1201,17 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
} else {
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
- GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
- old_node_inputs[iidx].second,
- &mkl_node, &mkl_node_output_slot);
+ // Special case for connecting filter input of Conv2DBackpropInput
+ if (do_connect_conv2d_backprop_input_filter &&
+ iidx == kConv2DBackpropInputFilterInputSlotIdx) {
+ GetNodeProducingMklTensor(g, old_node, conv2d_node,
+ kConv2DFilterOutputSlotIdx, &mkl_node,
+ &mkl_node_output_slot);
+ } else {
+ GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
+ old_node_inputs[iidx].second, &mkl_node,
+ &mkl_node_output_slot);
+ }
nb->Input(mkl_node, mkl_node_output_slot);
iidx++;
nn_slot_idx++;