aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/control_flow.cc37
-rw-r--r--tensorflow/core/graph/control_flow_test.cc17
-rw-r--r--tensorflow/core/graph/graph_partition.cc8
-rw-r--r--tensorflow/core/graph/mkl_graph_util.h1
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc13
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc1
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc1
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc1
8 files changed, 52 insertions, 27 deletions
diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc
index 1778e48ef6..8e1e56d29b 100644
--- a/tensorflow/core/graph/control_flow.cc
+++ b/tensorflow/core/graph/control_flow.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -54,10 +55,11 @@ Status ValidateControlFlowInfo(const Graph* graph,
frame.parent = parent;
frame.name = cf.frame_name;
} else if (frame.parent != parent) {
- return errors::InvalidArgument(
+ return errors::Internal(
"Invalid loop structure: Mismatched parent frames for \"",
cf.frame_name, "\": \"", parent->name, "\" vs \"", frame.parent->name,
- "\". This is an internal bug, please file a bug report with "
+ "\". The node giving this error: ", FormatNodeForError(*node),
+ "This is an internal bug, please file a bug report with "
"instructions on how to reproduce the error.");
}
if (IsLoopCond(node)) {
@@ -69,9 +71,9 @@ Status ValidateControlFlowInfo(const Graph* graph,
!str_util::StrContains(node->name(), "LoopCounter")) {
return errors::InvalidArgument(
"Invalid loop structure: Loop \"", cf.frame_name,
- "\" has more than one LoopCond node: \"", node->name(), "\" and \"",
- frame.loop_cond->name(),
- "\". This is an internal bug, please file a bug report with "
+ "\" has more than one LoopCond node: ", FormatNodeForError(*node),
+ " and ", FormatNodeForError(*frame.loop_cond),
+ ". This is an internal bug, please file a bug report with "
"instructions on how to reproduce the error.");
}
frame.loop_cond = node;
@@ -135,12 +137,11 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
const string& parent_frame = (*info)[out_parent->id()].frame_name;
if (parent_frame != frame_name) {
return errors::InvalidArgument(
- "The node '", out->name(),
- "' has inputs from different "
- "frames. The input '",
- curr_node->name(), "' is in frame '", frame_name,
- "'. The input '", parent_nodes[out->id()]->name(),
- "' is in frame '", parent_frame, "'.");
+ FormatNodeForError(*out),
+ " has inputs from different frames. The input ",
+ FormatNodeForError(*curr_node), " is in frame '", frame_name,
+ "'. The input ", FormatNodeForError(*parent_nodes[out->id()]),
+ " is in frame '", parent_frame, "'.");
}
} else {
out_info->frame = out;
@@ -148,7 +149,8 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
TF_RETURN_IF_ERROR(
GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name));
if (out_info->frame_name.empty()) {
- return errors::InvalidArgument("The Enter node ", out->name(),
+ return errors::InvalidArgument("The Enter ",
+ FormatNodeForError(*out),
" must have a frame name.");
}
}
@@ -156,12 +158,11 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
if (is_visited) {
if (out_info->frame_name != frame_name) {
return errors::InvalidArgument(
- "The node '", out->name(),
- "' has inputs from different "
- "frames. The input '",
- curr_node->name(), "' is in frame '", frame_name,
- "'. The input '", parent_nodes[out->id()]->name(),
- "' is in frame '", out_info->frame_name, "'.");
+ FormatNodeForError(*out),
+ " has inputs from different frames. The input ",
+ FormatNodeForError(*curr_node), " is in frame '", frame_name,
+ "'. The input ", FormatNodeForError(*parent_nodes[out->id()]),
+ " is in frame '", out_info->frame_name, "'.");
}
} else {
out_info->frame = frame;
diff --git a/tensorflow/core/graph/control_flow_test.cc b/tensorflow/core/graph/control_flow_test.cc
index eb7937400f..803c757c3f 100644
--- a/tensorflow/core/graph/control_flow_test.cc
+++ b/tensorflow/core/graph/control_flow_test.cc
@@ -63,6 +63,15 @@ TEST(ValidateControlFlowTest, InputsFromDifferentFrames) {
EXPECT_TRUE(str_util::StrContains(status.error_message(),
"has inputs from different frames"))
<< status.error_message();
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "{{node outer/body/inner/Merge}}"))
+ << status.error_message();
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "{{node outer/body/inner/Enter}}"))
+ << status.error_message();
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "{{node outer/Switch}}"))
+ << status.error_message();
}
TEST(ValidateControlFlowTest, MismatchedParentFrames) {
@@ -102,6 +111,8 @@ TEST(ValidateControlFlowTest, MismatchedParentFrames) {
EXPECT_TRUE(
str_util::StrContains(status.error_message(), "Mismatched parent frames"))
<< status.error_message();
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Enter2}}"))
+ << status.error_message();
}
TEST(ValidateControlFlowTest, TwoLoopCond) {
@@ -125,6 +136,12 @@ TEST(ValidateControlFlowTest, TwoLoopCond) {
EXPECT_TRUE(str_util::StrContains(status.error_message(),
"more than one LoopCond node"))
<< status.error_message();
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "{{node sub/LoopCond}}"))
+ << status.error_message();
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "{{node LoopCond}}"))
+ << status.error_message();
}
} // namespace
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 1b1941f9c1..ea0a814ab8 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -214,6 +214,14 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
cast_builder.Attr("_start_time", start_time);
}
cast_builder.Attr("DstT", cast_dtype);
+
+ if (cast_dtype == DT_BFLOAT16) {
+ // the below attribute specifies that the cast to bfloat16 should use
+ // truncation. This is needed to retain legacy behavior when we change
+ // the default bfloat16 casts to use rounding instead of truncation
+ cast_builder.Attr("Truncate", true);
+ }
+
NodeDef* cast = gdef->add_node();
*status = cast_builder.Finalize(cast);
if (!status->ok()) return nullptr;
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 5f51d6083b..333bf761b0 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
#ifdef INTEL_MKL
-#include <string>
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 836ccae9b7..c22e0a3872 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <memory>
#include <queue>
#include <set>
-#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
@@ -2894,7 +2893,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
for (const Edge* e : n->in_edges()) {
// Rewrite only if there is corresponding LRN, i.e workspace is available
if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
- e->src()->type_string() == mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
+ e->src()->type_string() ==
+ mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
e->src_output() == 0) {
do_rewrite = true;
break;
@@ -2907,9 +2907,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CHECK_NOTNULL(n);
bool do_rewrite = false;
for (const Edge* e : n->in_edges()) {
- // Rewrite only if there is corresponding Maxpool, i.e workspace is available
- if (e->dst()->type_string() == csinfo_.max_pool_grad && e->dst_input() == 1 &&
- e->src()->type_string() == mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
+ // Rewrite only if there is corresponding Maxpool, i.e workspace is
+ // available
+ if (e->dst()->type_string() == csinfo_.max_pool_grad &&
+ e->dst_input() == 1 &&
+ e->src()->type_string() ==
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
e->src_output() == 0) {
do_rewrite = true;
break;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index b8f5492f7c..a41f5861af 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm>
-#include <string>
#include <vector>
#include "tensorflow/core/framework/op.h"
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index e9ced4d2b6..aa39af637f 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <memory>
#include <queue>
#include <set>
-#include <string>
#include <utility>
#include <vector>
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index bbdbe78bbd..ebcb6de551 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm>
-#include <string>
#include <vector>
#include "tensorflow/core/framework/op.h"