about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler/optimizers/memory_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/memory_optimizer.cc')
-rw-r--r-- tensorflow/core/grappler/optimizers/memory_optimizer.cc 26
1 files changed, 6 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index c775a26914..73f0977242 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
@@ -43,6 +44,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+namespace {
+
// Prefix added to nodes which are recomputed.
const char* kRecomputedNodePrefix = "Recomputed";
const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
@@ -744,25 +747,6 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap,
return Status::OK();
}
-static int64 EstimateSize(const OpInfo::TensorProperties& t) {  // Estimates the size of `t` as per-element size times element count, counting unknown dimensions as 1.
- DataType dtype = t.dtype();
- int64 size = DataTypeSize(dtype);  // Per-element size as reported by DataTypeSize for this dtype.
- TensorShapeProto shape = t.shape();  // Local copy: the unknown dimensions are overwritten below, `t` itself is left untouched.
- if (shape.unknown_rank()) {
- // The size can't be inferred when the rank is unknown. The tensor has to be
- // at least a scalar though, so the size of a single element is returned.
- return size;
- }
- // A dimension that is statically unknown (negative size) is assumed to be at least one.
- for (int i = 0; i < shape.dim_size(); ++i) {
- if (shape.dim(i).size() < 0) {
- shape.mutable_dim(i)->set_size(1);
- }
- }
- int64 num_elems = TensorShape(shape).num_elements();  // Element count of the patched (fully defined) shape.
- return num_elems * size;  // NOTE(review): the product is not checked for int64 overflow here.
-}
-
struct SwapInfo {
std::vector<int> inputs_to_swap;
Costs::NanoSeconds time_to_swap = 0;
@@ -1149,7 +1133,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
int64 bytes_to_swap = 0;
for (int64 input_id : swap_info.inputs_to_swap) {
const OpInfo::TensorProperties& t = props[input_id];
- bytes_to_swap += EstimateSize(t);
+ bytes_to_swap += CalculateTensorSize(t);
}
// Let's assume we're going to swap over PCIe running at 16 GBps.
swap_info.time_to_swap = bytes_to_swap / 16;
@@ -1299,6 +1283,8 @@ Status RelaxAllocatorConstraints(GraphDef* optimized_graph) {
return Status::OK();
}
+} // namespace
+
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
*optimized_graph = item.graph;