about summary refs log tree commit diff homepage
path: root/tensorflow/core/common_runtime/constant_folding.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/constant_folding.cc')
-rw-r--r-- tensorflow/core/common_runtime/constant_folding.cc | 35
1 file changed, 8 insertions(+), 27 deletions(-)
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index db137f1a19..e81e61b633 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -466,23 +466,23 @@ Graph* GetConstantGraph(
bool ReplaceTensorWithConstant(
Graph* graph, Device* partition_device, NodeAndOutput tensor,
const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
- int64 max_constant_size_in_bytes, bool disable_memory_output_type_check,
+ int64 max_constant_size_in_bytes,
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
// 1) Do not replace another constant.
// 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 3) If the size of the constant in bytes is too large (>
+ // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
+ // constraint, do not replace it.
+ // 4) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
- // 4) If the constant op created does not have a kernel implementation
+ // 5) If the constant op created does not have a kernel implementation
// for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
- // 5) If the constant op for the device has different output memory type
- // from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) {
return false;
}
@@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant(
return false;
}
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
- if (memory_type == HOST_MEMORY && !is_int32) {
+ if ((memory_type == HOST_MEMORY && !is_int32) ||
+ (memory_type == DEVICE_MEMORY && is_int32)) {
return false;
}
}
@@ -535,25 +536,6 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
- if (!disable_memory_output_type_check) {
- if (partition_device && device_type != DEVICE_CPU) {
- MemoryType original_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
- &original_output_memory_type)
- .ok()) {
- return false;
- }
- MemoryType const_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
- &const_output_memory_type)
- .ok()) {
- return false;
- }
- if (original_output_memory_type != const_output_memory_type) {
- return false;
- }
- }
- }
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
@@ -660,8 +642,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
constant_control_deps[tensors_to_replace[c].first];
if (ReplaceTensorWithConstant(
graph, partition_device, tensors_to_replace[c], outputs[c],
- control_deps, opts.max_constant_size_in_bytes,
- opts.disable_memory_output_type_check, generate_new_name)) {
+ control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
++num_nodes_replaced;
}
}