diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/memory_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/memory_optimizer.cc | 26 |
1 files changed, 6 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index c775a26914..73f0977242 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/graph_memory.h" #include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" @@ -43,6 +44,8 @@ limitations under the License. namespace tensorflow { namespace grappler { +namespace { + // Prefix added to nodes which are recomputed. const char* kRecomputedNodePrefix = "Recomputed"; const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger"; @@ -744,25 +747,6 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap, return Status::OK(); } -static int64 EstimateSize(const OpInfo::TensorProperties& t) { - DataType dtype = t.dtype(); - int64 size = DataTypeSize(dtype); - TensorShapeProto shape = t.shape(); - if (shape.unknown_rank()) { - // Can't infer the size if the rank is unknown. It has to be at least a - // scalar though. - return size; - } - // If one of the dimensions is unknown statically, assume it's at least one. - for (int i = 0; i < shape.dim_size(); ++i) { - if (shape.dim(i).size() < 0) { - shape.mutable_dim(i)->set_size(1); - } - } - int64 num_elems = TensorShape(shape).num_elements(); - return num_elems * size; -} - struct SwapInfo { std::vector<int> inputs_to_swap; Costs::NanoSeconds time_to_swap = 0; @@ -1149,7 +1133,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level, int64 bytes_to_swap = 0; for (int64 input_id : swap_info.inputs_to_swap) { const OpInfo::TensorProperties& t = props[input_id]; - bytes_to_swap += EstimateSize(t); + bytes_to_swap += CalculateTensorSize(t); } // Let's assume we're going to swap over PCIe running at 16 GBps. swap_info.time_to_swap = bytes_to_swap / 16; @@ -1299,6 +1283,8 @@ Status RelaxAllocatorConstraints(GraphDef* optimized_graph) { return Status::OK(); } +} // namespace + Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { *optimized_graph = item.graph; |