diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler_test.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler_test.cc | 48 |
1 files changed, 3 insertions, 45 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 80889afc86..99272dd7e9 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include "tensorflow/core/framework/tensor_description.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/costs/virtual_placer.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace grappler { + // Class for testing virtual scheduler. class TestVirtualScheduler : public VirtualScheduler { public: @@ -33,7 +35,6 @@ class TestVirtualScheduler : public VirtualScheduler { : VirtualScheduler(grappler_item, use_static_shapes, cluster, &ready_node_manager_) {} - FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize); FRIEND_TEST(VirtualSchedulerTest, MemoryUsage); FRIEND_TEST(VirtualSchedulerTest, ControlDependency); FRIEND_TEST(VirtualSchedulerTest, ComplexDependency); @@ -1034,17 +1035,6 @@ versions { } } - // Helper method for converting shape vector to TensorProperty. - OpInfo::TensorProperties ShapeToTensorProperty( - const std::vector<int> shape, const DataType& data_type) const { - OpInfo::TensorProperties tensor_property; - tensor_property.set_dtype(data_type); - for (const auto& x : shape) { - tensor_property.mutable_shape()->add_dim()->set_size(x); - } - return tensor_property; - } - // SetUp() inits cluster_ and placer_. std::unique_ptr<VirtualCluster> cluster_; std::unique_ptr<VirtualPlacer> placer_; @@ -1729,38 +1719,6 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size()); } -TEST_F(VirtualSchedulerTest, CalculateOutputSize) { - // Init. - CreateGrapplerItemWithAddN(); - InitScheduler(); - - // Create a set of tensor properties. - std::vector<OpInfo::TensorProperties> output; - output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0 - output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1 - output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2 - output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3 - output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4 - output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)); // 4 - - // port_num -1 is for control dependency: hard coded 4B. - EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1)); - - // Test valid outputs. - EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0)); - EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1)); - EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2)); - EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3)); - - // Any unknown shape (-1) shall yield zero output size. - EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4)); - EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5)); - - // Invalid port_num (though it may be an error) shall yield zero - // output size. - EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6)); -} - TEST_F(VirtualSchedulerTest, MemoryUsage) { // Init. CreateGrapplerItemWithAddN(); @@ -2041,7 +1999,7 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { for (const auto& output_property : output_properties_) { output_properties.push_back(output_property); } - return scheduler_->CalculateOutputSize(output_properties, 0); + return CalculateOutputSize(output_properties, 0); }; // Validate transfer size. |