From e27ee15fa45a5f4e43e10ed1fe0eb3a1feb4253a Mon Sep 17 00:00:00 2001 From: Peter Ma Date: Mon, 8 Oct 2018 23:12:08 -0700 Subject: Refactor CalculateOutputSize() from VirtualScheduler protected member function to utils; Refactor EstimateSize() from memory_optimizer.cc to utils; some small changes for readability improvement PiperOrigin-RevId: 216307257 --- tensorflow/core/grappler/costs/BUILD | 1 + tensorflow/core/grappler/costs/utils.cc | 40 +++++++- tensorflow/core/grappler/costs/utils.h | 11 ++ tensorflow/core/grappler/costs/utils_test.cc | 112 +++++++++++++++------ .../core/grappler/costs/virtual_scheduler.cc | 48 ++------- tensorflow/core/grappler/costs/virtual_scheduler.h | 22 ++-- .../core/grappler/costs/virtual_scheduler_test.cc | 48 +-------- tensorflow/core/grappler/optimizers/BUILD | 1 + .../core/grappler/optimizers/memory_optimizer.cc | 26 ++--- 9 files changed, 161 insertions(+), 148 deletions(-) diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index f3dc2c2091..46eacd3a06 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -236,6 +236,7 @@ tf_cc_test( name = "virtual_scheduler_test", srcs = ["virtual_scheduler_test.cc"], deps = [ + ":utils", ":virtual_placer", ":virtual_scheduler", "//tensorflow/cc:cc_ops", diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 5415324b48..2fcadf1de3 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -74,7 +74,8 @@ static std::vector ExtractTensors(const AttrValue& attr_value) { } break; } - default: {} + default: { + } } return tensors; } @@ -201,6 +202,43 @@ std::vector FindInputFeatures( return inputs; } +int64 CalculateTensorSize(const OpInfo::TensorProperties& prop) { + int64 size = DataTypeSize(BaseType(prop.dtype())); + TensorShapeProto shape = prop.shape(); + + // Can't infer the size if the rank is unknown. It has to be at least a + // scalar though. + if (shape.unknown_rank()) { + LOG(WARNING) << "CalculateTensorSize() -- unknown rank"; + return size; + } + + // If one of the dimensions is unknown statically, assume it's at least one. + for (int i = 0; i < shape.dim_size(); ++i) { + if (shape.dim(i).size() < 0) { + shape.mutable_dim(i)->set_size(1); + LOG(WARNING) << "CalculateTensorSize() -- unknown dim: " << i; + } + } + + int64 num_elems = TensorShape(shape).num_elements(); + return num_elems * size; +} + +int64 CalculateOutputSize( + const std::vector& output_properties, + const int port_num) { + if (port_num < 0) return 4; // 4B for control dependency. + + if (port_num >= output_properties.size()) { + LOG(ERROR) << "CalculateOutputSize() -- port_num: " << port_num + << " >= output_properties.size(): " << output_properties.size(); + return 0; + } + + return CalculateTensorSize(output_properties[port_num]); +} + DeviceProperties GetDeviceInfo(const string& device_str) { DeviceProperties unknown; unknown.set_type("UNKNOWN"); diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h index 5fd6717712..ea64e5a41d 100644 --- a/tensorflow/core/grappler/costs/utils.h +++ b/tensorflow/core/grappler/costs/utils.h @@ -43,6 +43,17 @@ std::vector FindInputFeatures( const std::unordered_map& name_to_cost, const std::unordered_map& name_to_node); +// Returns the size of tensor (unit: bytes). For tensor shape with unknown rank, +// it assumes the tensor to be scalar. For any unknown dimension, it assumes +// size one. +int64 CalculateTensorSize(const OpInfo::TensorProperties& prop); + +// Returns the size of output at port_num (unit: bytes). A special case is +// port_num -1, which is for control dependency and assumed to be 4 bytes. +int64 CalculateOutputSize( + const std::vector& output_properties, + int port_num); + // Returns the DeviceProperties of the device on which 'node' runs. DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node); DeviceProperties GetDeviceInfo(const string& device_str); diff --git a/tensorflow/core/grappler/costs/utils_test.cc b/tensorflow/core/grappler/costs/utils_test.cc index baa654f475..db5c11f0fe 100644 --- a/tensorflow/core/grappler/costs/utils_test.cc +++ b/tensorflow/core/grappler/costs/utils_test.cc @@ -26,36 +26,42 @@ limitations under the License. namespace tensorflow { namespace grappler { -class UtilsTest : public ::testing::Test { - public: - void CreateConstOp(const string& name, std::initializer_list dims, - NodeDef* node) { - Tensor tensor(DT_FLOAT, TensorShape(dims)); - for (int64 i = 0; i < tensor.NumElements(); ++i) { - tensor.flat()(i) = i / 10.0f; - } - TF_CHECK_OK(NodeDefBuilder(name, "Const") - .Attr("dtype", DT_FLOAT) - .Attr("value", tensor) - .Finalize(node)); - } +namespace { - void CreateConstSizesOp(const string& name, const std::vector& sizes, - NodeDef* node) { - TensorShape shape; - shape.AddDim(sizes.size()); - Tensor tensor(DT_INT32, shape); - for (int64 i = 0; i < tensor.NumElements(); ++i) { - tensor.flat()(i) = sizes[i]; - } - TF_CHECK_OK(NodeDefBuilder(name, "Const") - .Attr("dtype", DT_INT32) - .Attr("value", tensor) - .Finalize(node)); - } -}; +void CreateConstOp(const string& name, std::initializer_list dims, + NodeDef* node) { + Tensor tensor(DT_FLOAT, TensorShape(dims)); + for (int64 i = 0; i < tensor.NumElements(); ++i) + tensor.flat()(i) = i / 10.0f; + TF_CHECK_OK(NodeDefBuilder(name, "Const") + .Attr("dtype", DT_FLOAT) + .Attr("value", tensor) + .Finalize(node)); +} -TEST_F(UtilsTest, ConvOpInfo) { +void CreateConstSizesOp(const string& name, const std::vector& sizes, + NodeDef* node) { + TensorShape shape; + shape.AddDim(sizes.size()); + Tensor tensor(DT_INT32, shape); + for (int64 i = 0; i < tensor.NumElements(); ++i) + tensor.flat()(i) = sizes[i]; + TF_CHECK_OK(NodeDefBuilder(name, "Const") + .Attr("dtype", DT_INT32) + .Attr("value", tensor) + .Finalize(node)); +} + +// Helper method for converting shapes vector to TensorProperty. +OpInfo::TensorProperties ShapeToTensorProperty(const std::vector& shapes, + const DataType& data_type) { + OpInfo::TensorProperties prop; + prop.set_dtype(data_type); + for (int shape : shapes) prop.mutable_shape()->add_dim()->set_size(shape); + return prop; +} + +TEST(UtilsTest, ConvOpInfo) { int batch = 32; int rows = 7; int cols = 9; @@ -146,7 +152,7 @@ TEST_F(UtilsTest, ConvOpInfo) { } } -TEST_F(UtilsTest, TestSkipControlInput) { +TEST(UtilsTest, TestSkipControlInput) { GraphDef graph; TF_CHECK_OK(NodeDefBuilder("constant", "Const") .Attr("dtype", DT_INT32) @@ -172,6 +178,52 @@ TEST_F(UtilsTest, TestSkipControlInput) { EXPECT_TRUE(node_found); } +TEST(UtilsTest, CalculateTensorSize) { + // Test normal usage. + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1, + CalculateTensorSize(ShapeToTensorProperty({1}, DT_FLOAT))); + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4, + CalculateTensorSize(ShapeToTensorProperty({4, 4}, DT_FLOAT))); + EXPECT_EQ(DataTypeSize(DT_HALF) * 10 * 10 * 10, + CalculateTensorSize(ShapeToTensorProperty({10, 10, 10}, DT_HALF))); + EXPECT_EQ( + DataTypeSize(DT_FLOAT) * 100 * 7 * 8 * 99, + CalculateTensorSize(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT))); + + // Test unknown rank: assumes the tensor to be a scalar. + OpInfo::TensorProperties t = ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT); + t.mutable_shape()->set_unknown_rank(true); + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1, CalculateTensorSize(t)); + + // Test unknown shape: assumes unknown shape (-1) to have size 1. + EXPECT_EQ( + DataTypeSize(DT_FLOAT) * 1 * 7 * 8 * 99, + CalculateTensorSize(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT))); + EXPECT_EQ( + DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99, + CalculateTensorSize(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT))); +} + +TEST(UtilsTest, CalculateOutputSize) { + // Create a set of tensor properties. + std::vector output = { + ShapeToTensorProperty({4, 4}, DT_FLOAT), // 0 + ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT) // 1 + }; + + // Test valid outputs. + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4, CalculateOutputSize(output, 0)); + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99, + CalculateOutputSize(output, 1)); + + // port_num -1 is for control dependency: hard coded 4B. + EXPECT_EQ(4, CalculateOutputSize(output, -1)); + + // Invalid port_num (though it may be an error) shall yield zero + // output size. + EXPECT_EQ(0, CalculateOutputSize(output, 2)); +} + // Class for testing TensorSizeHistogram. class TestTensorSizeHistogram : public TensorSizeHistogram { public: @@ -285,5 +337,7 @@ TEST(DeviceClassTest, GetDeviceClassForNonChannelDevice) { EXPECT_EQ("//GPU", GetDeviceClassForNonChannelDevice("/device:GPU:7")); } +} // namespace + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 037a823096..5b93fb128f 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -473,6 +473,7 @@ Status VirtualScheduler::Init() { VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: " << str_util::Join(feed_nodes, ","); } + initialized_ = true; return Status::OK(); } @@ -695,38 +696,6 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) { return it->second; } -int64 VirtualScheduler::CalculateOutputSize( - const std::vector& output_properties, - const int port_num) const { - if (port_num < 0) { - return 4; // 4B for control dependency. - } - - if (port_num >= output_properties.size()) { - VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- " - << "port_num: " << port_num - << " >= output_properties.size(): " << output_properties.size(); - return 0; - } - - const auto& output = output_properties[port_num]; - int64 output_size = DataTypeSize(BaseType(output.dtype())); - - for (const auto& dim : output.shape().dim()) { - auto dim_size = dim.size(); - if (dim_size < 0) { - // Zero output size if there's any unknown dim. - output_size = 0; - VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- " - << "unknown dim: " << output_size; - break; - } - output_size *= dim_size; - } - - return output_size; -} - Costs& VirtualScheduler::FindOrCreateZero(const string& op_name, std::map* op_cost) { auto it = op_cost->find(op_name); @@ -744,7 +713,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { const NodeDef* node = ready_nodes_->GetCurrNode(); const string& op_name = node->op(); - // Also keep track of op counts and times per op (with their shapes). + auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_); + op_cost = CombineCosts(op_cost, node_costs); + + // Also keep track of op counts and costs per op (with their shapes). OpContext op_context = GetCurrNode(); string node_description = GetOpDescription(op_context.op_info); op_counts_[node_description] += 1; @@ -752,9 +724,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { std::make_pair(node_costs.execution_time.asMicroSeconds().count(), !node_costs.inaccurate); - auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_); - op_cost = CombineCosts(op_cost, node_costs); - // Update node and device states. auto& node_state = node_map_[node]; auto& device = device_[node_state.device_name]; @@ -795,7 +764,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { << ", scheduled: " << node_state.time_scheduled.count() << ", finished: " << node_state.time_finished.count(); - // Increment num_inputs_ready of the output nodes + // Increment num_inputs_ready of the output nodes and maybe add to ready nodes for (const auto& port_num_output_pair : node_state.outputs) { for (auto* output_node : port_num_output_pair.second) { auto& output_state = node_map_[output_node]; @@ -812,7 +781,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { } } - // Increment num_outputs_executed of the input nodes. + // Increment num_outputs_executed of the input nodes and maybe update memory. for (const auto& input_port : node_state.inputs) { auto* input = input_port.first; auto port = input_port.second; @@ -841,7 +810,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { } } - // Remove the current node; assume FIFO. ready_nodes_->RemoveCurrNode(); return !ready_nodes_->Empty(); @@ -1007,7 +975,7 @@ Costs VirtualScheduler::Summary(RunMetadata* metadata) { return Summary(); } - // Fill RunMetadata. + // Fill RunMetadata's step_stats and partition_graphs fields. StepStats* stepstats = metadata->mutable_step_stats(); for (const auto& device : device_) { GraphDef* device_partition_graph = metadata->add_partition_graphs(); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 0e66e8a463..bead84af29 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -107,10 +107,10 @@ struct DeviceState { mem_usage_snapshot_at_peak; Costs device_costs; - std::map op_to_cost; // Per-op cost. - std::map op_to_memory; // Per-op memory usage at peak usage. - int64 memory_usage; - int64 max_memory_usage; + std::map op_to_cost; // Per-op cost. + + int64 memory_usage; // Current temporary memory usage + int64 max_memory_usage; // Max temporary memory usage DeviceState() { device_costs = Costs::ZeroCosts(); @@ -283,13 +283,6 @@ class VirtualScheduler { return &node_map_; } - protected: - // Returns the size of output at port_num (unit: bytes). A special case is - // port_num -1, which is for control dependency and assumed to be 4 bytes. - int64 CalculateOutputSize( - const std::vector& output_properties, - const int port_num) const; - private: // Constants. const string kAttrInputSrc = "input_source_"; @@ -321,8 +314,11 @@ class VirtualScheduler { std::vector> additional_nodes_; // Stats: - std::map op_counts_; // Op counts with key with input shape. - // Individual op costs (with input shapes). + // Op counts with key with input shape. + // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]" + std::map op_counts_; + // Individual op costs with key with input shape. + // Integer field for execution time in micro seconds. // Boolean field for whether the cost is accurate. std::map> op_costs_; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 80889afc86..99272dd7e9 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include "tensorflow/core/framework/tensor_description.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/costs/virtual_placer.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace grappler { + // Class for testing virtual scheduler. class TestVirtualScheduler : public VirtualScheduler { public: @@ -33,7 +35,6 @@ class TestVirtualScheduler : public VirtualScheduler { : VirtualScheduler(grappler_item, use_static_shapes, cluster, &ready_node_manager_) {} - FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize); FRIEND_TEST(VirtualSchedulerTest, MemoryUsage); FRIEND_TEST(VirtualSchedulerTest, ControlDependency); FRIEND_TEST(VirtualSchedulerTest, ComplexDependency); @@ -1034,17 +1035,6 @@ versions { } } - // Helper method for converting shape vector to TensorProperty. - OpInfo::TensorProperties ShapeToTensorProperty( - const std::vector shape, const DataType& data_type) const { - OpInfo::TensorProperties tensor_property; - tensor_property.set_dtype(data_type); - for (const auto& x : shape) { - tensor_property.mutable_shape()->add_dim()->set_size(x); - } - return tensor_property; - } - // SetUp() inits cluster_ and placer_. std::unique_ptr cluster_; std::unique_ptr placer_; @@ -1729,38 +1719,6 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size()); } -TEST_F(VirtualSchedulerTest, CalculateOutputSize) { - // Init. - CreateGrapplerItemWithAddN(); - InitScheduler(); - - // Create a set of tensor properties. - std::vector output; - output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0 - output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1 - output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2 - output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3 - output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4 - output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)); // 4 - - // port_num -1 is for control dependency: hard coded 4B. - EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1)); - - // Test valid outputs. - EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0)); - EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1)); - EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2)); - EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3)); - - // Any unknown shape (-1) shall yield zero output size. - EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4)); - EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5)); - - // Invalid port_num (though it may be an error) shall yield zero - // output size. - EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6)); -} - TEST_F(VirtualSchedulerTest, MemoryUsage) { // Init. CreateGrapplerItemWithAddN(); @@ -2041,7 +1999,7 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { for (const auto& output_property : output_properties_) { output_properties.push_back(output_property); } - return scheduler_->CalculateOutputSize(output_properties, 0); + return CalculateOutputSize(output_properties, 0); }; // Validate transfer size. diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index c708f84948..e898377ded 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -423,6 +423,7 @@ cc_library( "//tensorflow/core/grappler/clusters:virtual_cluster", "//tensorflow/core/grappler/costs:graph_memory", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/costs:utils", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/utils:traversal", ], diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index c775a26914..73f0977242 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/graph_memory.h" #include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" @@ -43,6 +44,8 @@ limitations under the License. namespace tensorflow { namespace grappler { +namespace { + // Prefix added to nodes which are recomputed. const char* kRecomputedNodePrefix = "Recomputed"; const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger"; @@ -744,25 +747,6 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap, return Status::OK(); } -static int64 EstimateSize(const OpInfo::TensorProperties& t) { - DataType dtype = t.dtype(); - int64 size = DataTypeSize(dtype); - TensorShapeProto shape = t.shape(); - if (shape.unknown_rank()) { - // Can't infer the size if the rank is unknown. It has to be at least a - // scalar though. - return size; - } - // If one of the dimensions is unknown statically, assume it's at least one. - for (int i = 0; i < shape.dim_size(); ++i) { - if (shape.dim(i).size() < 0) { - shape.mutable_dim(i)->set_size(1); - } - } - int64 num_elems = TensorShape(shape).num_elements(); - return num_elems * size; -} - struct SwapInfo { std::vector inputs_to_swap; Costs::NanoSeconds time_to_swap = 0; @@ -1149,7 +1133,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level, int64 bytes_to_swap = 0; for (int64 input_id : swap_info.inputs_to_swap) { const OpInfo::TensorProperties& t = props[input_id]; - bytes_to_swap += EstimateSize(t); + bytes_to_swap += CalculateTensorSize(t); } // Let's assume we're going to swap over PCIe running at 16 GBps. swap_info.time_to_swap = bytes_to_swap / 16; @@ -1299,6 +1283,8 @@ Status RelaxAllocatorConstraints(GraphDef* optimized_graph) { return Status::OK(); } +} // namespace + Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { *optimized_graph = item.graph; -- cgit v1.2.3