aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/compilation_cache_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/compilation_cache_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/compilation_cache_test.cc218
1 files changed, 218 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
new file mode 100644
index 0000000000..38ce007cb0
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -0,0 +1,218 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class CompilationCacheTest : public ClientLibraryTestBase {
+ public:
+ void ExecuteComputationR0F32(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, float expected_result,
+ bool expect_cache_hit) {
+ ExecutionProfile execution_profile;
+ std::unique_ptr<Literal> result =
+ client_
+ ->ExecuteAndTransfer(computation, arguments,
+ /*output_layout=*/nullptr, &execution_profile)
+ .ConsumeValueOrDie();
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0<float>(expected_result),
+ *result, error_spec_);
+ EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
+ }
+
+ void ExecuteComputationR2F32(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ std::initializer_list<std::initializer_list<float>> expected_result,
+ bool expect_cache_hit) {
+ ExecutionProfile execution_profile;
+ auto data_handle =
+ client_
+ ->Execute(computation, arguments, /*output_layout=*/nullptr,
+ &execution_profile)
+ .ConsumeValueOrDie();
+ std::unique_ptr<Literal> result =
+ client_->Transfer(*data_handle).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2<float>(expected_result),
+ *result, error_spec_);
+ EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
+ }
+
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<float>(42.0));
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) {
+ std::unique_ptr<GlobalData> data_42 =
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
+ .ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> data_123 =
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
+ .ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> data_456 =
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {data_123.get()}, -123.0,
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR0F32(computation, {data_456.get()}, -456.0,
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
+ /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, MultipleComputations) {
+ ComputationBuilder builder_neg(client_, TestName() + "_neg");
+ builder_neg.Neg(builder_neg.ConstantR0<float>(42.0));
+ Computation computation_neg = builder_neg.Build().ConsumeValueOrDie();
+
+ ComputationBuilder builder_exp(client_, TestName() + "_exp");
+ builder_exp.Exp(builder_exp.ConstantR0<float>(1.0));
+ Computation computation_exp = builder_exp.Build().ConsumeValueOrDie();
+
+ ComputationBuilder builder_add(client_, TestName() + "_add");
+ builder_add.Add(builder_add.ConstantR0<float>(2.0),
+ builder_add.ConstantR0<float>(3.0));
+ Computation computation_add = builder_add.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation_neg, {}, -42.0,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation_exp, {}, 2.7182817,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation_add, {}, 5.0,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation_neg, {}, -42.0,
+ /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
+ // Create two GlobalData arrays with the same shape but different
+ // layouts. Use these arrays as parameters to a simple computation. If the
+ // layout of the array changes then computation should be recompiled (cache
+ // miss).
+ auto rowmaj_array = test_utils::CreateR2LiteralWithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0});
+ auto rowmaj_handle =
+ client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
+
+ auto colmaj_array = test_utils::CreateR2LiteralWithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1});
+ auto colmaj_handle =
+ client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR2F32(computation, {colmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR2F32(computation, {colmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR2F32(computation, {colmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, MutatedComputation) {
+ // Build a computation, execute it, then mutate it. The mutated computation
+ // should not be in the cache until it is run once. This must be done through
+ // the stub interface because Computations built from ComputationBuilder are
+ // immutable.
+ ComputationBuilder builder(client_, TestName());
+ auto neg = builder.Neg(builder.ConstantR0<float>(42.0));
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+
+ BinaryOpRequest request;
+ request.set_binop(BINOP_ADD);
+ *request.mutable_lhs() = neg;
+ *request.mutable_rhs() = neg;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation.handle();
+ *op_request.mutable_binary_op_request() = request;
+ OpResponse response;
+ tensorflow::Status s = client_->stub()->Op(&op_request, &response);
+ ASSERT_TRUE(s.ok());
+
+ ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/true);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}