diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-06 17:57:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-06 18:01:27 -0700 |
commit | e722358e7e96dd2aa20d7e2c56336e76845daa6a (patch) | |
tree | a74960670ce4bacad0909fc913097bcc3e27ed18 /tensorflow/core/graph/mkl_layout_pass.cc | |
parent | f8a43f9d63ce90f10852d69e40fbb9fe849fc190 (diff) |
Merge changes from github.
END_PUBLIC
---
Commit 607816029 authored by Eugene Brevdo<ebrevdo@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Extended ScratchSpace to expose its underlying scratch tensor object.
PiperOrigin-RevId: 167649551
---
Commit db43fe68e authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add fast math attributes to all generated methods when fast math enabled.
RELNOTES: n/a
PiperOrigin-RevId: 167646637
---
Commit aebe8cc6f authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Call HloComputation.Accept instead of HloInstruction.Accept to get all instructions profiled.
RELNOTES: n/a
PiperOrigin-RevId: 167640259
---
Commit 0ab137cd8 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
BEGIN_PUBLIC
Automated g4 rollback of changelist 167604306
PiperOrigin-RevId: 167800256
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 60 |
1 files changed, 56 insertions, 4 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 2f9ceaa3bd..cf5d6e8baa 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -1099,6 +1099,44 @@ int MklLayoutRewritePass::SetUpContiguousInputs( CHECK_NOTNULL(workspace_tensors); CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + // TODO(nhasabni): Temporary solution to connect filter input of + // BackpropInput with the converted filter from Conv2D. + bool do_connect_conv2d_backprop_input_filter = false; + Node* conv2d_node = nullptr; + // Filter node is 2nd input (slot index 1) of Conv2D. + int kConv2DFilterInputSlotIdx = 1; + int kConv2DBackpropInputFilterInputSlotIdx = 1; + int kConv2DFilterOutputSlotIdx = 1; + if (old_node->type_string() == csinfo_.conv2d_grad_input) { + // We need to find Conv2D node from Conv2DBackpropInput. + // For that let's first find filter node that is 2nd input (slot 1) + // of BackpropInput. + Node* filter_node = nullptr; + old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, &filter_node); + CHECK_NOTNULL(filter_node); + + // Now check which nodes receive from filter_node. Filter feeds as + // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias. + for (const Edge* e : filter_node->out_edges()) { + if (e->dst()->type_string() == csinfo_.mkl_conv2d && + e->dst_input() == kConv2DFilterInputSlotIdx + /* filter is 2nd input of Conv2D and _MklConv2D. */) { + if (conv2d_node != nullptr) { + VLOG(1) << "MklLayoutRewritePass: unusual case of same filter" + << " feeding multiple Conv2D nodes: " + << filter_node->DebugString(); + // We will not connect filter input of Conv2DBackpropInput + // to be safe here. + do_connect_conv2d_backprop_input_filter = false; + break; + } else { + conv2d_node = e->dst(); + do_connect_conv2d_backprop_input_filter = true; + } + } + } + } + // Number of input slots to original op // Input slots are represented by .Input() calls in REGISTER_OP. int old_node_input_slots = old_node->op_def().input_arg_size(); @@ -1122,7 +1160,13 @@ int MklLayoutRewritePass::SetUpContiguousInputs( nb->Input(new_node_inputs); nn_slot_idx++; } else { - nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); + // Special case for connecting filter input of Conv2DBackpropInput + if (do_connect_conv2d_backprop_input_filter && + iidx == kConv2DBackpropInputFilterInputSlotIdx) { + nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx); + } else { + nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); + } iidx++; nn_slot_idx++; } @@ -1157,9 +1201,17 @@ int MklLayoutRewritePass::SetUpContiguousInputs( } else { Node* mkl_node = nullptr; int mkl_node_output_slot = 0; - GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, - old_node_inputs[iidx].second, - &mkl_node, &mkl_node_output_slot); + // Special case for connecting filter input of Conv2DBackpropInput + if (do_connect_conv2d_backprop_input_filter && + iidx == kConv2DBackpropInputFilterInputSlotIdx) { + GetNodeProducingMklTensor(g, old_node, conv2d_node, + kConv2DFilterOutputSlotIdx, &mkl_node, + &mkl_node_output_slot); + } else { + GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, + old_node_inputs[iidx].second, &mkl_node, + &mkl_node_output_slot); + } nb->Input(mkl_node, mkl_node_output_slot); iidx++; nn_slot_idx++; |