diff options
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/control_flow.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/graph/control_flow_test.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_partition.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_graph_util.h | 1 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass_test.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_tfconversion_pass.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_tfconversion_pass_test.cc | 1 |
8 files changed, 52 insertions, 27 deletions
diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc index 1778e48ef6..8e1e56d29b 100644 --- a/tensorflow/core/graph/control_flow.cc +++ b/tensorflow/core/graph/control_flow.cc @@ -18,6 +18,7 @@ limitations under the License. #include <deque> #include <vector> +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" @@ -54,10 +55,11 @@ Status ValidateControlFlowInfo(const Graph* graph, frame.parent = parent; frame.name = cf.frame_name; } else if (frame.parent != parent) { - return errors::InvalidArgument( + return errors::Internal( "Invalid loop structure: Mismatched parent frames for \"", cf.frame_name, "\": \"", parent->name, "\" vs \"", frame.parent->name, - "\". This is an internal bug, please file a bug report with " + "\". The node giving this error: ", FormatNodeForError(*node), + "This is an internal bug, please file a bug report with " "instructions on how to reproduce the error."); } if (IsLoopCond(node)) { @@ -69,9 +71,9 @@ Status ValidateControlFlowInfo(const Graph* graph, !str_util::StrContains(node->name(), "LoopCounter")) { return errors::InvalidArgument( "Invalid loop structure: Loop \"", cf.frame_name, - "\" has more than one LoopCond node: \"", node->name(), "\" and \"", - frame.loop_cond->name(), - "\". This is an internal bug, please file a bug report with " + "\" has more than one LoopCond node: ", FormatNodeForError(*node), + " and ", FormatNodeForError(*frame.loop_cond), + ". This is an internal bug, please file a bug report with " "instructions on how to reproduce the error."); } frame.loop_cond = node; @@ -135,12 +137,11 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info, const string& parent_frame = (*info)[out_parent->id()].frame_name; if (parent_frame != frame_name) { return errors::InvalidArgument( - "The node '", out->name(), - "' has inputs from different " - "frames. The input '", - curr_node->name(), "' is in frame '", frame_name, - "'. The input '", parent_nodes[out->id()]->name(), - "' is in frame '", parent_frame, "'."); + FormatNodeForError(*out), + " has inputs from different frames. The input ", + FormatNodeForError(*curr_node), " is in frame '", frame_name, + "'. The input ", FormatNodeForError(*parent_nodes[out->id()]), + " is in frame '", parent_frame, "'."); } } else { out_info->frame = out; @@ -148,7 +149,8 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info, TF_RETURN_IF_ERROR( GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name)); if (out_info->frame_name.empty()) { - return errors::InvalidArgument("The Enter node ", out->name(), + return errors::InvalidArgument("The Enter ", + FormatNodeForError(*out), " must have a frame name."); } } @@ -156,12 +158,11 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info, if (is_visited) { if (out_info->frame_name != frame_name) { return errors::InvalidArgument( - "The node '", out->name(), - "' has inputs from different " - "frames. The input '", - curr_node->name(), "' is in frame '", frame_name, - "'. The input '", parent_nodes[out->id()]->name(), - "' is in frame '", out_info->frame_name, "'."); + FormatNodeForError(*out), + " has inputs from different frames. The input ", + FormatNodeForError(*curr_node), " is in frame '", frame_name, + "'. The input ", FormatNodeForError(*parent_nodes[out->id()]), + " is in frame '", out_info->frame_name, "'."); } } else { out_info->frame = frame; diff --git a/tensorflow/core/graph/control_flow_test.cc b/tensorflow/core/graph/control_flow_test.cc index eb7937400f..803c757c3f 100644 --- a/tensorflow/core/graph/control_flow_test.cc +++ b/tensorflow/core/graph/control_flow_test.cc @@ -63,6 +63,15 @@ TEST(ValidateControlFlowTest, InputsFromDifferentFrames) { EXPECT_TRUE(str_util::StrContains(status.error_message(), "has inputs from different frames")) << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "{{node outer/body/inner/Merge}}")) + << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "{{node outer/body/inner/Enter}}")) + << status.error_message(); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "{{node outer/Switch}}")) + << status.error_message(); } TEST(ValidateControlFlowTest, MismatchedParentFrames) { @@ -102,6 +111,8 @@ TEST(ValidateControlFlowTest, MismatchedParentFrames) { EXPECT_TRUE( str_util::StrContains(status.error_message(), "Mismatched parent frames")) << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Enter2}}")) + << status.error_message(); } TEST(ValidateControlFlowTest, TwoLoopCond) { @@ -125,6 +136,12 @@ TEST(ValidateControlFlowTest, TwoLoopCond) { EXPECT_TRUE(str_util::StrContains(status.error_message(), "more than one LoopCond node")) << status.error_message(); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "{{node sub/LoopCond}}")) + << status.error_message(); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "{{node LoopCond}}")) + << status.error_message(); } } // namespace diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 1b1941f9c1..ea0a814ab8 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -214,6 +214,14 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, cast_builder.Attr("_start_time", start_time); } cast_builder.Attr("DstT", cast_dtype); + + if (cast_dtype == DT_BFLOAT16) { + // the below attribute specifies that the cast to bfloat16 should use + // truncation. This is needed to retain legacy behavior when we change + // the default bfloat16 casts to use rounding instead of truncation + cast_builder.Attr("Truncate", true); + } + NodeDef* cast = gdef->add_node(); *status = cast_builder.Finalize(cast); if (!status->ok()) return nullptr; diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index 5f51d6083b..333bf761b0 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_ #ifdef INTEL_MKL -#include <string> #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 836ccae9b7..c22e0a3872 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -22,7 +22,6 @@ limitations under the License. #include <memory> #include <queue> #include <set> -#include <string> #include <unordered_set> #include <utility> #include <vector> @@ -2894,7 +2893,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { for (const Edge* e : n->in_edges()) { // Rewrite only if there is corresponding LRN, i.e workspace is available if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 && - e->src()->type_string() == mkl_op_registry::GetMklOpName(csinfo_.lrn) && + e->src()->type_string() == + mkl_op_registry::GetMklOpName(csinfo_.lrn) && e->src_output() == 0) { do_rewrite = true; break; @@ -2907,9 +2907,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CHECK_NOTNULL(n); bool do_rewrite = false; for (const Edge* e : n->in_edges()) { - // Rewrite only if there is corresponding Maxpool, i.e workspace is available - if (e->dst()->type_string() == csinfo_.max_pool_grad && e->dst_input() == 1 && - e->src()->type_string() == mkl_op_registry::GetMklOpName(csinfo_.max_pool) && + // Rewrite only if there is corresponding Maxpool, i.e workspace is + // available + if (e->dst()->type_string() == csinfo_.max_pool_grad && + e->dst_input() == 1 && + e->src()->type_string() == + mkl_op_registry::GetMklOpName(csinfo_.max_pool) && e->src_output() == 0) { do_rewrite = true; break; diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index b8f5492f7c..a41f5861af 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/graph/mkl_graph_util.h" #include <algorithm> -#include <string> #include <vector> #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index e9ced4d2b6..aa39af637f 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -18,7 +18,6 @@ limitations under the License. #include <memory> #include <queue> #include <set> -#include <string> #include <utility> #include <vector> diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc index bbdbe78bbd..ebcb6de551 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/graph/mkl_graph_util.h" #include <algorithm> -#include <string> #include <vector> #include "tensorflow/core/framework/op.h" |