diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc | 139 |
1 files changed, 139 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc new file mode 100644 index 0000000000..2368e577c2 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h" + +#include <algorithm> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +constexpr char CpuDevice[] = "/device:CPU:0"; +constexpr char GpuDevice[] = "/device:GPU:0"; + +class ExperimentalImplementationSelectorTest : public GrapplerTest {}; + +TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + std::unique_ptr<CustomGraphOptimizer> optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull( + "ExperimentalImplementationSelector"); + ASSERT_NE(nullptr, optimizer); + TF_ASSERT_OK(optimizer->Init()); + + GraphDef output; + const Status status = optimizer->Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // This is a trivial graph so there is nothing to update. + EXPECT_EQ(item.graph.node_size(), output.node_size()); +} + +TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) { + using test::function::NDef; + auto cpu_def = test::function::XTimesTwo(); + auto* func_attr = cpu_def.mutable_attr(); + (*func_attr)["experimental_api_implements"].set_s("times_two"); + (*func_attr)["experimental_api_preferred_device"].set_s("CPU"); + + auto gpu_def = test::function::XAddX(); + auto* func2_attr = gpu_def.mutable_attr(); + (*func2_attr)["experimental_api_implements"].set_s("times_two"); + (*func2_attr)["experimental_api_preferred_device"].set_s("GPU"); + + ExperimentalImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice), + NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice), + NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice), + NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice), + NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)}, + // FunctionLib + {cpu_def, gpu_def}); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(output.node_size(), 5); + for (const NodeDef& node : output.node()) { + if (node.name() == "y1") { + // Make sure the implementation has been swapped to use the GPU version. + EXPECT_EQ("XAddX", node.op()); + } else if (node.name() == "y2") { + // Make sure the implementation is not changed. + EXPECT_EQ("XTimesTwo", node.op()); + } + } +} + +TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) { + using test::function::NDef; + auto cpu_def = test::function::XTimesTwo(); + auto* func_attr = cpu_def.mutable_attr(); + (*func_attr)["experimental_api_implements"].set_s("random_boost"); + (*func_attr)["experimental_api_preferred_device"].set_s("CPU"); + + auto gpu_def = test::function::XTimesFour(); + auto* func2_attr = gpu_def.mutable_attr(); + (*func2_attr)["experimental_api_implements"].set_s("random_boost"); + (*func2_attr)["experimental_api_preferred_device"].set_s("GPU"); + + ExperimentalImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice), + NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice), + NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)}, + // FunctionLib + {cpu_def, gpu_def}); + + const Tensor input = test::AsScalar<float>(1.0f); + item.fetch = {"z"}; + item.feed.emplace_back("x", input); + + const auto four_times_boosted_tensor = EvaluateFetchNodes(item); + test::ExpectTensorEqual<float>(four_times_boosted_tensor[0], + test::AsScalar<float>(4.0f)); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + GrapplerItem optimized(item, std::move(output)); + const auto twice_boosted_tensor = EvaluateFetchNodes(optimized); + test::ExpectTensorEqual<float>(twice_boosted_tensor[0], + test::AsScalar<float>(2.0f)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow |