aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Ma <pcma@google.com>2018-10-08 23:12:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 23:16:17 -0700
commite27ee15fa45a5f4e43e10ed1fe0eb3a1feb4253a (patch)
tree2588e0531141c95d8c443fa4923d2df20b4970fc
parentd1f0494b89a31298df7743018c0a3fa388ac16a2 (diff)
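Refactor CalculateOutputSize() from a protected member function of VirtualScheduler into a free function in costs/utils; move EstimateSize() from memory_optimizer.cc into costs/utils as CalculateTensorSize(), which CalculateOutputSize() now calls (so an unknown dimension is counted as size 1 instead of making the output size 0); plus some small changes for readability improvement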
PiperOrigin-RevId: 216307257
-rw-r--r--tensorflow/core/grappler/costs/BUILD1
-rw-r--r--tensorflow/core/grappler/costs/utils.cc40
-rw-r--r--tensorflow/core/grappler/costs/utils.h11
-rw-r--r--tensorflow/core/grappler/costs/utils_test.cc112
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc48
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.h22
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc48
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc26
9 files changed, 161 insertions, 148 deletions
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index f3dc2c2091..46eacd3a06 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -236,6 +236,7 @@ tf_cc_test(
name = "virtual_scheduler_test",
srcs = ["virtual_scheduler_test.cc"],
deps = [
+ ":utils",
":virtual_placer",
":virtual_scheduler",
"//tensorflow/cc:cc_ops",
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 5415324b48..2fcadf1de3 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -74,7 +74,8 @@ static std::vector<TensorProto> ExtractTensors(const AttrValue& attr_value) {
}
break;
}
- default: {}
+ default: {
+ }
}
return tensors;
}
@@ -201,6 +202,43 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
return inputs;
}
+int64 CalculateTensorSize(const OpInfo::TensorProperties& prop) {  // In bytes.
+ int64 size = DataTypeSize(BaseType(prop.dtype()));  // Per-element size.
+ TensorShapeProto shape = prop.shape();  // Local copy; `prop` is not modified.
+
+ // Can't infer the size if the rank is unknown. It has to be at least a
+ // scalar though, so the size of a single element is returned.
+ if (shape.unknown_rank()) {
+ LOG(WARNING) << "CalculateTensorSize() -- unknown rank";
+ return size;
+ }
+
+ // Any dim with a negative (statically unknown) size is assumed to be one.
+ for (int i = 0; i < shape.dim_size(); ++i) {
+ if (shape.dim(i).size() < 0) {
+ shape.mutable_dim(i)->set_size(1);  // Patch the local copy only.
+ LOG(WARNING) << "CalculateTensorSize() -- unknown dim: " << i;
+ }
+ }
+
+ int64 num_elems = TensorShape(shape).num_elements();  // Uses the patched dims.
+ return num_elems * size;
+}
+
+int64 CalculateOutputSize(
+ const std::vector<OpInfo::TensorProperties>& output_properties,
+ const int port_num) {
+ // A negative port_num denotes a control dependency, which is counted as 4B.
+ if (port_num < 0) return 4;
+
+ // port_num is non-negative here, so converting it to size_t is lossless and
+ // avoids comparing a signed int against the unsigned size().
+ const size_t port_index = static_cast<size_t>(port_num);
+ if (port_index >= output_properties.size()) {
+ // Out-of-range port: log it and report a zero-sized output.
+ LOG(ERROR) << "CalculateOutputSize() -- port_num: " << port_num
+ << " >= output_properties.size(): " << output_properties.size();
+ return 0;
+ }
+
+ return CalculateTensorSize(output_properties[port_index]);
+}
+
DeviceProperties GetDeviceInfo(const string& device_str) {
DeviceProperties unknown;
unknown.set_type("UNKNOWN");
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index 5fd6717712..ea64e5a41d 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -43,6 +43,17 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost,
const std::unordered_map<string, const NodeDef*>& name_to_node);
+// Returns the size of tensor (unit: bytes). For tensor shape with unknown rank,
+// it assumes the tensor to be scalar. For any unknown dimension, it assumes
+// size one.
+int64 CalculateTensorSize(const OpInfo::TensorProperties& prop);
+
+// Returns the byte size of the output at port_num: 4 for a control dependency
+// (any negative port_num) and 0 if port_num >= output_properties.size().
+int64 CalculateOutputSize(
+ const std::vector<OpInfo::TensorProperties>& output_properties,
+ int port_num);
+
// Returns the DeviceProperties of the device on which 'node' runs.
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node);
DeviceProperties GetDeviceInfo(const string& device_str);
diff --git a/tensorflow/core/grappler/costs/utils_test.cc b/tensorflow/core/grappler/costs/utils_test.cc
index baa654f475..db5c11f0fe 100644
--- a/tensorflow/core/grappler/costs/utils_test.cc
+++ b/tensorflow/core/grappler/costs/utils_test.cc
@@ -26,36 +26,42 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-class UtilsTest : public ::testing::Test {
- public:
- void CreateConstOp(const string& name, std::initializer_list<int64> dims,
- NodeDef* node) {
- Tensor tensor(DT_FLOAT, TensorShape(dims));
- for (int64 i = 0; i < tensor.NumElements(); ++i) {
- tensor.flat<float>()(i) = i / 10.0f;
- }
- TF_CHECK_OK(NodeDefBuilder(name, "Const")
- .Attr("dtype", DT_FLOAT)
- .Attr("value", tensor)
- .Finalize(node));
- }
+namespace {
- void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes,
- NodeDef* node) {
- TensorShape shape;
- shape.AddDim(sizes.size());
- Tensor tensor(DT_INT32, shape);
- for (int64 i = 0; i < tensor.NumElements(); ++i) {
- tensor.flat<int32>()(i) = sizes[i];
- }
- TF_CHECK_OK(NodeDefBuilder(name, "Const")
- .Attr("dtype", DT_INT32)
- .Attr("value", tensor)
- .Finalize(node));
- }
-};
+void CreateConstOp(const string& name, std::initializer_list<int64> dims,
+ NodeDef* node) {
+ Tensor tensor(DT_FLOAT, TensorShape(dims));  // Float tensor of shape `dims`.
+ for (int64 i = 0; i < tensor.NumElements(); ++i)
+ tensor.flat<float>()(i) = i / 10.0f;  // Flat element i holds i / 10.
+ TF_CHECK_OK(NodeDefBuilder(name, "Const")  // "Const" node named `name`.
+ .Attr("dtype", DT_FLOAT)
+ .Attr("value", tensor)
+ .Finalize(node));  // Result is written into *node.
+}
-TEST_F(UtilsTest, ConvOpInfo) {
+void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes,
+ NodeDef* node) {
+ TensorShape shape;
+ shape.AddDim(sizes.size());  // One dimension of length sizes.size().
+ Tensor tensor(DT_INT32, shape);
+ for (int64 i = 0; i < tensor.NumElements(); ++i)
+ tensor.flat<int32>()(i) = sizes[i];  // Copy `sizes` element by element.
+ TF_CHECK_OK(NodeDefBuilder(name, "Const")  // "Const" node named `name`.
+ .Attr("dtype", DT_INT32)
+ .Attr("value", tensor)
+ .Finalize(node));  // Result is written into *node.
+}
+
+// Builds a TensorProperties of `data_type` with one dim per entry of `shapes`.
+OpInfo::TensorProperties ShapeToTensorProperty(const std::vector<int>& shapes,
+ const DataType& data_type) {
+ OpInfo::TensorProperties prop;
+ prop.set_dtype(data_type);
+ for (int shape : shapes) prop.mutable_shape()->add_dim()->set_size(shape);
+ return prop;
+}
+
+TEST(UtilsTest, ConvOpInfo) {
int batch = 32;
int rows = 7;
int cols = 9;
@@ -146,7 +152,7 @@ TEST_F(UtilsTest, ConvOpInfo) {
}
}
-TEST_F(UtilsTest, TestSkipControlInput) {
+TEST(UtilsTest, TestSkipControlInput) {
GraphDef graph;
TF_CHECK_OK(NodeDefBuilder("constant", "Const")
.Attr("dtype", DT_INT32)
@@ -172,6 +178,52 @@ TEST_F(UtilsTest, TestSkipControlInput) {
EXPECT_TRUE(node_found);
}
+TEST(UtilsTest, CalculateTensorSize) {
+ // Known shapes: expect the element size times the number of elements.
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1,
+ CalculateTensorSize(ShapeToTensorProperty({1}, DT_FLOAT)));
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4,
+ CalculateTensorSize(ShapeToTensorProperty({4, 4}, DT_FLOAT)));
+ EXPECT_EQ(DataTypeSize(DT_HALF) * 10 * 10 * 10,
+ CalculateTensorSize(ShapeToTensorProperty({10, 10, 10}, DT_HALF)));
+ EXPECT_EQ(
+ DataTypeSize(DT_FLOAT) * 100 * 7 * 8 * 99,
+ CalculateTensorSize(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)));
+
+ // Test unknown rank: the dims are ignored and one element's size is expected.
+ OpInfo::TensorProperties t = ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT);
+ t.mutable_shape()->set_unknown_rank(true);
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1, CalculateTensorSize(t));
+
+ // Test unknown dims: every unknown dim (-1) is assumed to have size 1.
+ EXPECT_EQ(
+ DataTypeSize(DT_FLOAT) * 1 * 7 * 8 * 99,
+ CalculateTensorSize(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)));
+ EXPECT_EQ(
+ DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99,
+ CalculateTensorSize(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)));
+}
+
+TEST(UtilsTest, CalculateOutputSize) {
+ // Two output ports: one fully known shape and one with unknown dims.
+ std::vector<OpInfo::TensorProperties> output = {
+ ShapeToTensorProperty({4, 4}, DT_FLOAT), // 0
+ ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT) // 1
+ };
+
+ // Test valid outputs; the unknown dims of port 1 are counted as size 1.
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4, CalculateOutputSize(output, 0));
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99,
+ CalculateOutputSize(output, 1));
+
+ // port_num -1 is for control dependency: hard coded 4B.
+ EXPECT_EQ(4, CalculateOutputSize(output, -1));
+
+ // Invalid port_num (though it may be an error) shall yield zero
+ // output size; here 2 is one past the last valid port.
+ EXPECT_EQ(0, CalculateOutputSize(output, 2));
+}
+
// Class for testing TensorSizeHistogram.
class TestTensorSizeHistogram : public TensorSizeHistogram {
public:
@@ -285,5 +337,7 @@ TEST(DeviceClassTest, GetDeviceClassForNonChannelDevice) {
EXPECT_EQ("//GPU", GetDeviceClassForNonChannelDevice("/device:GPU:7"));
}
+} // namespace
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 037a823096..5b93fb128f 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -473,6 +473,7 @@ Status VirtualScheduler::Init() {
VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
<< str_util::Join(feed_nodes, ",");
}
+
initialized_ = true;
return Status::OK();
}
@@ -695,38 +696,6 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
return it->second;
}
-int64 VirtualScheduler::CalculateOutputSize(
- const std::vector<OpInfo::TensorProperties>& output_properties,
- const int port_num) const {
- if (port_num < 0) {
- return 4; // 4B for control dependency.
- }
-
- if (port_num >= output_properties.size()) {
- VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
- << "port_num: " << port_num
- << " >= output_properties.size(): " << output_properties.size();
- return 0;
- }
-
- const auto& output = output_properties[port_num];
- int64 output_size = DataTypeSize(BaseType(output.dtype()));
-
- for (const auto& dim : output.shape().dim()) {
- auto dim_size = dim.size();
- if (dim_size < 0) {
- // Zero output size if there's any unknown dim.
- output_size = 0;
- VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
- << "unknown dim: " << output_size;
- break;
- }
- output_size *= dim_size;
- }
-
- return output_size;
-}
-
Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
std::map<string, Costs>* op_cost) {
auto it = op_cost->find(op_name);
@@ -744,7 +713,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
const NodeDef* node = ready_nodes_->GetCurrNode();
const string& op_name = node->op();
- // Also keep track of op counts and times per op (with their shapes).
+ auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
+ op_cost = CombineCosts(op_cost, node_costs);
+
+ // Also keep track of op counts and costs per op (with their shapes).
OpContext op_context = GetCurrNode();
string node_description = GetOpDescription(op_context.op_info);
op_counts_[node_description] += 1;
@@ -752,9 +724,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
std::make_pair(node_costs.execution_time.asMicroSeconds().count(),
!node_costs.inaccurate);
- auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
- op_cost = CombineCosts(op_cost, node_costs);
-
// Update node and device states.
auto& node_state = node_map_[node];
auto& device = device_[node_state.device_name];
@@ -795,7 +764,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
<< ", scheduled: " << node_state.time_scheduled.count()
<< ", finished: " << node_state.time_finished.count();
- // Increment num_inputs_ready of the output nodes
+ // Increment num_inputs_ready of the output nodes and maybe add to ready nodes
for (const auto& port_num_output_pair : node_state.outputs) {
for (auto* output_node : port_num_output_pair.second) {
auto& output_state = node_map_[output_node];
@@ -812,7 +781,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
}
}
- // Increment num_outputs_executed of the input nodes.
+ // Increment num_outputs_executed of the input nodes and maybe update memory.
for (const auto& input_port : node_state.inputs) {
auto* input = input_port.first;
auto port = input_port.second;
@@ -841,7 +810,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
}
}
- // Remove the current node; assume FIFO.
ready_nodes_->RemoveCurrNode();
return !ready_nodes_->Empty();
@@ -1007,7 +975,7 @@ Costs VirtualScheduler::Summary(RunMetadata* metadata) {
return Summary();
}
- // Fill RunMetadata.
+ // Fill RunMetadata's step_stats and partition_graphs fields.
StepStats* stepstats = metadata->mutable_step_stats();
for (const auto& device : device_) {
GraphDef* device_partition_graph = metadata->add_partition_graphs();
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 0e66e8a463..bead84af29 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -107,10 +107,10 @@ struct DeviceState {
mem_usage_snapshot_at_peak;
Costs device_costs;
- std::map<string, Costs> op_to_cost; // Per-op cost.
- std::map<string, int64> op_to_memory; // Per-op memory usage at peak usage.
- int64 memory_usage;
- int64 max_memory_usage;
+ std::map<string, Costs> op_to_cost; // Per-op cost.
+
+ int64 memory_usage; // Current temporary memory usage
+ int64 max_memory_usage; // Max temporary memory usage
DeviceState() {
device_costs = Costs::ZeroCosts();
@@ -283,13 +283,6 @@ class VirtualScheduler {
return &node_map_;
}
- protected:
- // Returns the size of output at port_num (unit: bytes). A special case is
- // port_num -1, which is for control dependency and assumed to be 4 bytes.
- int64 CalculateOutputSize(
- const std::vector<OpInfo::TensorProperties>& output_properties,
- const int port_num) const;
-
private:
// Constants.
const string kAttrInputSrc = "input_source_";
@@ -321,8 +314,11 @@ class VirtualScheduler {
std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
// Stats:
- std::map<string, int> op_counts_; // Op counts with key with input shape.
- // Individual op costs (with input shapes).
+ // Op counts with key with input shape.
+ // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]"
+ std::map<string, int> op_counts_;
+ // Individual op costs with key with input shape.
+ // Integer field for execution time in micro seconds.
// Boolean field for whether the cost is accurate.
std::map<string, std::pair<int, bool>> op_costs_;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 80889afc86..99272dd7e9 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -19,12 +19,14 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
+
// Class for testing virtual scheduler.
class TestVirtualScheduler : public VirtualScheduler {
public:
@@ -33,7 +35,6 @@ class TestVirtualScheduler : public VirtualScheduler {
: VirtualScheduler(grappler_item, use_static_shapes, cluster,
&ready_node_manager_) {}
- FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
@@ -1034,17 +1035,6 @@ versions {
}
}
- // Helper method for converting shape vector to TensorProperty.
- OpInfo::TensorProperties ShapeToTensorProperty(
- const std::vector<int> shape, const DataType& data_type) const {
- OpInfo::TensorProperties tensor_property;
- tensor_property.set_dtype(data_type);
- for (const auto& x : shape) {
- tensor_property.mutable_shape()->add_dim()->set_size(x);
- }
- return tensor_property;
- }
-
// SetUp() inits cluster_ and placer_.
std::unique_ptr<VirtualCluster> cluster_;
std::unique_ptr<VirtualPlacer> placer_;
@@ -1729,38 +1719,6 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
}
-TEST_F(VirtualSchedulerTest, CalculateOutputSize) {
- // Init.
- CreateGrapplerItemWithAddN();
- InitScheduler();
-
- // Create a set of tensor properties.
- std::vector<OpInfo::TensorProperties> output;
- output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0
- output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1
- output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2
- output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3
- output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4
- output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)); // 4
-
- // port_num -1 is for control dependency: hard coded 4B.
- EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1));
-
- // Test valid outputs.
- EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0));
- EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1));
- EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2));
- EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3));
-
- // Any unknown shape (-1) shall yield zero output size.
- EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4));
- EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5));
-
- // Invalid port_num (though it may be an error) shall yield zero
- // output size.
- EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6));
-}
-
TEST_F(VirtualSchedulerTest, MemoryUsage) {
// Init.
CreateGrapplerItemWithAddN();
@@ -2041,7 +1999,7 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
for (const auto& output_property : output_properties_) {
output_properties.push_back(output_property);
}
- return scheduler_->CalculateOutputSize(output_properties, 0);
+ return CalculateOutputSize(output_properties, 0);
};
// Validate transfer size.
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index c708f84948..e898377ded 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -423,6 +423,7 @@ cc_library(
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/costs:graph_memory",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/costs:utils",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/utils:traversal",
],
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index c775a26914..73f0977242 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
@@ -43,6 +44,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+namespace {
+
// Prefix added to nodes which are recomputed.
const char* kRecomputedNodePrefix = "Recomputed";
const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
@@ -744,25 +747,6 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap,
return Status::OK();
}
-static int64 EstimateSize(const OpInfo::TensorProperties& t) {
- DataType dtype = t.dtype();
- int64 size = DataTypeSize(dtype);
- TensorShapeProto shape = t.shape();
- if (shape.unknown_rank()) {
- // Can't infer the size if the rank is unknown. It has to be at least a
- // scalar though.
- return size;
- }
- // If one of the dimensions is unknown statically, assume it's at least one.
- for (int i = 0; i < shape.dim_size(); ++i) {
- if (shape.dim(i).size() < 0) {
- shape.mutable_dim(i)->set_size(1);
- }
- }
- int64 num_elems = TensorShape(shape).num_elements();
- return num_elems * size;
-}
-
struct SwapInfo {
std::vector<int> inputs_to_swap;
Costs::NanoSeconds time_to_swap = 0;
@@ -1149,7 +1133,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
int64 bytes_to_swap = 0;
for (int64 input_id : swap_info.inputs_to_swap) {
const OpInfo::TensorProperties& t = props[input_id];
- bytes_to_swap += EstimateSize(t);
+ bytes_to_swap += CalculateTensorSize(t);
}
// Let's assume we're going to swap over PCIe running at 16 GBps.
swap_info.time_to_swap = bytes_to_swap / 16;
@@ -1299,6 +1283,8 @@ Status RelaxAllocatorConstraints(GraphDef* optimized_graph) {
return Status::OK();
}
+} // namespace
+
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
*optimized_graph = item.graph;