aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-09-26 22:43:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 22:46:28 -0700
commit5df53ab7eb81c67459e2a95e8fbcb71999c703ad (patch)
treedc3707652d47bcf4c424189c0bd426b9e6472b2c /tensorflow/core/common_runtime
parenta40cfd42e20d7e4520c1306666c9dfee97eb0a2e (diff)
Enable constant folding for device memory tensors.
PiperOrigin-RevId: 214723970
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc34
1 files changed, 25 insertions, 9 deletions
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 99cb9ac6a0..419867ff58 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -470,19 +470,19 @@ bool ReplaceTensorWithConstant(
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
- // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
+ // 1) Do not replace another constant.
+ // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY
- // constraint, do not replace it.
- // 3) If the constant op created does not have a kernel implementation
- // for the device, do not use it.
- // 4) If the size of the constant in bytes is too large (>
+ // 3) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
+ // 4) If the constant op created does not have a kernel implementation
+ // for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
- // 5) Do not replace another constant.
+ // 5) If the constant op for the device has different output memory type
+ // from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) {
return false;
}
@@ -497,8 +497,7 @@ bool ReplaceTensorWithConstant(
return false;
}
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
- if ((memory_type == HOST_MEMORY && !is_int32) ||
- (memory_type == DEVICE_MEMORY && is_int32)) {
+ if (memory_type == HOST_MEMORY && !is_int32) {
return false;
}
}
@@ -536,6 +535,23 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
+ if (partition_device && device_type != DEVICE_CPU) {
+ MemoryType original_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
+ &original_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ MemoryType const_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
+ &const_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ if (original_output_memory_type != const_output_memory_type) {
+ return false;
+ }
+ }
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);