aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler_test.cc')
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc48
1 files changed, 3 insertions, 45 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 80889afc86..99272dd7e9 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -19,12 +19,14 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
+
// Class for testing virtual scheduler.
class TestVirtualScheduler : public VirtualScheduler {
public:
@@ -33,7 +35,6 @@ class TestVirtualScheduler : public VirtualScheduler {
: VirtualScheduler(grappler_item, use_static_shapes, cluster,
&ready_node_manager_) {}
- FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
@@ -1034,17 +1035,6 @@ versions {
}
}
- // Helper method for converting shape vector to TensorProperty.
- OpInfo::TensorProperties ShapeToTensorProperty(
- const std::vector<int> shape, const DataType& data_type) const {
- OpInfo::TensorProperties tensor_property;
- tensor_property.set_dtype(data_type);
- for (const auto& x : shape) {
- tensor_property.mutable_shape()->add_dim()->set_size(x);
- }
- return tensor_property;
- }
-
// SetUp() inits cluster_ and placer_.
std::unique_ptr<VirtualCluster> cluster_;
std::unique_ptr<VirtualPlacer> placer_;
@@ -1729,38 +1719,6 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
}
-TEST_F(VirtualSchedulerTest, CalculateOutputSize) {
- // Init.
- CreateGrapplerItemWithAddN();
- InitScheduler();
-
- // Create a set of tensor properties.
- std::vector<OpInfo::TensorProperties> output;
- output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0
- output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1
- output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2
- output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3
- output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4
- output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)); // 4
-
- // port_num -1 is for control dependency: hard coded 4B.
- EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1));
-
- // Test valid outputs.
- EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0));
- EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1));
- EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2));
- EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3));
-
- // Any unknown shape (-1) shall yield zero output size.
- EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4));
- EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5));
-
- // Invalid port_num (though it may be an error) shall yield zero
- // output size.
- EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6));
-}
-
TEST_F(VirtualSchedulerTest, MemoryUsage) {
// Init.
CreateGrapplerItemWithAddN();
@@ -2041,7 +1999,7 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
for (const auto& output_property : output_properties_) {
output_properties.push_back(output_property);
}
- return scheduler_->CalculateOutputSize(output_properties, 0);
+ return CalculateOutputSize(output_properties, 0);
};
// Validate transfer size.