aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_partition.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/graph_partition.cc')
-rw-r--r--tensorflow/core/graph/graph_partition.cc8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 1b1941f9c1..ea0a814ab8 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -214,6 +214,14 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
cast_builder.Attr("_start_time", start_time);
}
cast_builder.Attr("DstT", cast_dtype);
+
+ if (cast_dtype == DT_BFLOAT16) {
+ // the below attribute specifies that the cast to bfloat16 should use
+ // truncation. This is needed to retain legacy behavior when we change
+ // the default bfloat16 casts to use rounding instead of truncation
+ cast_builder.Attr("Truncate", true);
+ }
+
NodeDef* cast = gdef->add_node();
*status = cast_builder.Finalize(cast);
if (!status->ok()) return nullptr;