diff options
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/graph_partition.cc | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 1b1941f9c1..ea0a814ab8 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -214,6 +214,14 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, cast_builder.Attr("_start_time", start_time); } cast_builder.Attr("DstT", cast_dtype); + + if (cast_dtype == DT_BFLOAT16) { + // the below attribute specifies that the cast to bfloat16 should use + // truncation. This is needed to retain legacy behavior when we change + // the default bfloat16 casts to use rounding instead of truncation + cast_builder.Attr("Truncate", true); + } + NodeDef* cast = gdef->add_node(); *status = cast_builder.Finalize(cast); if (!status->ok()) return nullptr; |