aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc139
1 files changed, 139 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
new file mode 100644
index 0000000000..2368e577c2
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
@@ -0,0 +1,139 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char CpuDevice[] = "/device:CPU:0";
+constexpr char GpuDevice[] = "/device:GPU:0";
+
+class ExperimentalImplementationSelectorTest : public GrapplerTest {};
+
+TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ std::unique_ptr<CustomGraphOptimizer> optimizer =
+ CustomGraphOptimizerRegistry::CreateByNameOrNull(
+ "ExperimentalImplementationSelector");
+ ASSERT_NE(nullptr, optimizer);
+ TF_ASSERT_OK(optimizer->Init());
+
+ GraphDef output;
+ const Status status = optimizer->Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // This is a trivial graph so there is nothing to update.
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("times_two");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XAddX();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("times_two");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice),
+ NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(output.node_size(), 5);
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "y1") {
+ // Make sure the implementation has been swapped to use the GPU version.
+ EXPECT_EQ("XAddX", node.op());
+ } else if (node.name() == "y2") {
+ // Make sure the implementation is not changed.
+ EXPECT_EQ("XTimesTwo", node.op());
+ }
+ }
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XTimesFour();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice),
+ NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ const Tensor input = test::AsScalar<float>(1.0f);
+ item.fetch = {"z"};
+ item.feed.emplace_back("x", input);
+
+ const auto four_times_boosted_tensor = EvaluateFetchNodes(item);
+ test::ExpectTensorEqual<float>(four_times_boosted_tensor[0],
+ test::AsScalar<float>(4.0f));
+
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+ GrapplerItem optimized(item, std::move(output));
+ const auto twice_boosted_tensor = EvaluateFetchNodes(optimized);
+ test::ExpectTensorEqual<float>(twice_boosted_tensor[0],
+ test::AsScalar<float>(2.0f));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow