diff options
author | Tong Shen <endlessroad@google.com> | 2018-09-26 22:43:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 22:46:28 -0700 |
commit | 5df53ab7eb81c67459e2a95e8fbcb71999c703ad (patch) | |
tree | dc3707652d47bcf4c424189c0bd426b9e6472b2c /tensorflow/core/common_runtime | |
parent | a40cfd42e20d7e4520c1306666c9dfee97eb0a2e (diff) |
Enable constant folding for device memory tensors.
PiperOrigin-RevId: 214723970
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r-- | tensorflow/core/common_runtime/constant_folding.cc | 34 |
1 files changed, 25 insertions, 9 deletions
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 99cb9ac6a0..419867ff58 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -470,19 +470,19 @@ bool ReplaceTensorWithConstant( const ConstantFoldNameGenerator& generate_new_name) { // Be conservative when replacing a tensor with a constant, when not // running on CPU. - // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY + // 1) Do not replace another constant. + // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY // constraint, do not replace it. - // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY - // constraint, do not replace it. - // 3) If the constant op created does not have a kernel implementation - // for the device, do not use it. - // 4) If the size of the constant in bytes is too large (> + // 3) If the size of the constant in bytes is too large (> // max_constant_in_bytes), do not replace it. This prevents the size of the // Graph from growing too large. + // 4) If the constant op created does not have a kernel implementation + // for the device, do not use it. // TODO(keveman): Consider adding a new constant op that has a kernel // implementation for all types, but with HostMemory constraint on it's // output. - // 5) Do not replace another constant. + // 5) If the constant op for the device has different output memory type + // from the original op output memory type, do not replace it. if (tensor.first->IsConstant()) { return false; } @@ -497,8 +497,7 @@ bool ReplaceTensorWithConstant( return false; } bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32; - if ((memory_type == HOST_MEMORY && !is_int32) || - (memory_type == DEVICE_MEMORY && is_int32)) { + if (memory_type == HOST_MEMORY && !is_int32) { return false; } } @@ -536,6 +535,23 @@ bool ReplaceTensorWithConstant( if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) { return false; } + if (partition_device && device_type != DEVICE_CPU) { + MemoryType original_output_memory_type; + if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second, + &original_output_memory_type) + .ok()) { + return false; + } + MemoryType const_output_memory_type; + if (!MemoryTypeForOutput(device_type, graph, constant_node, 0, + &const_output_memory_type) + .ok()) { + return false; + } + if (original_output_memory_type != const_output_memory_type) { + return false; + } + } for (auto edge : edges_to_remove) { graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input()); graph->RemoveEdge(edge); |