aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
blob: 6795130cd10933d745171acc7c44fed90a6cb87d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

using ::testing::ContainsRegex;
using ::testing::HasSubstr;

class DeconstructTupleTest : public ClientLibraryTestBase {
 protected:
  // Build and execute the given computation then verify the results can be
  // transferred from the device successfully.
  std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
      XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    XlaComputation computation = builder->Build().ConsumeValueOrDie();
    auto global_data =
        client_->Execute(computation, arguments, &execution_options_)
            .ConsumeValueOrDie();
    TF_CHECK_OK(client_->Transfer(*global_data).status());
    return global_data;
  }
};

// Deconstructs a flat two-element tuple and checks that each element handle
// transfers back to the host with the expected contents.
TEST_F(DeconstructTupleTest, DeconstructTuple) {
  XlaBuilder builder(TestName());
  auto const1 = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
  auto const2 = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
  Tuple(&builder, {const1, const2});
  auto global_data = ExecuteAndCheckTransfer(&builder, {});

  // ASSERT rather than EXPECT: on failure the StatusOr must not be consumed
  // (that would crash the binary), and the element count must be verified
  // before `handles` is indexed.
  TF_ASSERT_OK_AND_ASSIGN(auto handles,
                          client_->DeconstructTuple(*global_data));
  ASSERT_THAT(handles, ::testing::SizeIs(2));

  // Copy the elements back and compare them against the original constants.
  std::unique_ptr<Literal> literal;
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
}

// Deconstructs the same tuple twice and checks that releasing the handles from
// the first deconstruction leaves the second set of handles readable.
TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
  XlaBuilder builder(TestName());
  auto const1 = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
  auto const2 = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
  Tuple(&builder, {const1, const2});
  auto global_data = ExecuteAndCheckTransfer(&builder, {});

  // ASSERT so that a failed deconstruction stops the test instead of crashing
  // in ConsumeValueOrDie, and so the indexing below is always in bounds.
  TF_ASSERT_OK_AND_ASSIGN(auto handles1,
                          client_->DeconstructTuple(*global_data));
  TF_ASSERT_OK_AND_ASSIGN(auto handles2,
                          client_->DeconstructTuple(*global_data));
  ASSERT_THAT(handles1, ::testing::SizeIs(2));
  ASSERT_THAT(handles2, ::testing::SizeIs(2));

  std::unique_ptr<Literal> literal;
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1]));
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);

  // Release the first set of element handles; the second set must still be
  // transferable afterwards.
  handles1[0].reset();
  handles1[1].reset();

  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0]));
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1]));
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
}

// Deconstructs a tuple whose elements repeat (const1, const2, const2, const1)
// and checks each returned handle against the expected element data.
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
  XlaBuilder builder(TestName());
  auto const1 = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
  auto const2 = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
  Tuple(&builder, {const1, const2, const2, const1});
  auto global_data = ExecuteAndCheckTransfer(&builder, {});

  // ASSERT so a failed deconstruction or a wrong element count stops the test
  // before `handles` is consumed or indexed.
  TF_ASSERT_OK_AND_ASSIGN(auto handles,
                          client_->DeconstructTuple(*global_data));
  ASSERT_THAT(handles, ::testing::SizeIs(4));

  // The returned handles mirror the tuple's repetition: handles[0] and
  // handles[3] must hold const1's data, and handles[1] and handles[2] must hold
  // const2's data. Only the transferred contents are compared here, not the
  // handle identities.
  std::unique_ptr<Literal> literal;
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3]));
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
}

// Deconstructs a tuple, releases the tuple handle, then releases one of two
// repeated element handles, checking that the remaining elements stay
// readable at each step.
TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
  XlaBuilder builder(TestName());
  auto const1 = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
  auto const2 = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
  Tuple(&builder, {const1, const2, const1});
  auto global_data = ExecuteAndCheckTransfer(&builder, {});

  // ASSERT so a failed deconstruction or a wrong element count stops the test
  // before `handles` is consumed or indexed.
  TF_ASSERT_OK_AND_ASSIGN(auto handles,
                          client_->DeconstructTuple(*global_data));
  ASSERT_THAT(handles, ::testing::SizeIs(3));

  // Deallocate the tuple, then copy the elements back. The elements should
  // not have been deallocated because of reference counting.
  global_data.reset();

  std::unique_ptr<Literal> literal;
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);

  // Deallocate one of the repeated elements, then copy back the other one.
  handles[0].reset();

  TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
}

// A plain rank-1 array result is not a tuple, so DeconstructTuple must fail
// with an error naming the offending handle.
TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
  std::unique_ptr<GlobalData> array_data =
      ExecuteAndCheckTransfer(&builder, {});

  const auto deconstructed = client_->DeconstructTuple(*array_data);
  EXPECT_FALSE(deconstructed.ok());
  EXPECT_THAT(deconstructed.status().error_message(),
              ContainsRegex("global data handle .* is not a tuple"));
}

// Wraps a parameter in a one-element tuple, deconstructs the result, and
// checks that the element handle is not the parameter's own handle.
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
  XlaBuilder builder(TestName());
  std::unique_ptr<Literal> param0_literal =
      LiteralUtil::CreateR1<float>({3.14f, -100.25f});
  // ASSERT-style unwrapping so a failed upload fails the test instead of
  // crashing the binary in ConsumeValueOrDie.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data,
                          client_->TransferToServer(*param0_literal));
  auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
  Tuple(&builder, {p});
  auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});

  TF_ASSERT_OK_AND_ASSIGN(auto handles,
                          client_->DeconstructTuple(*global_data));
  // Guard the handles[0] access below.
  ASSERT_THAT(handles, ::testing::SizeIs(1));
  EXPECT_NE(handles[0]->handle().handle(), param0_data->handle().handle());
}

// A tuple containing another tuple cannot be deconstructed; the client must
// report that nested deconstruction is unimplemented.
XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) {
  XlaBuilder builder(TestName());
  auto first = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
  auto second = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
  auto inner = Tuple(&builder, {first, second});
  Tuple(&builder, {inner, first});
  std::unique_ptr<GlobalData> nested_data =
      ExecuteAndCheckTransfer(&builder, {});

  const auto deconstructed = client_->DeconstructTuple(*nested_data);
  EXPECT_FALSE(deconstructed.ok());
  EXPECT_THAT(deconstructed.status().error_message(),
              HasSubstr("Deconstructing nested tuples is not implemented"));
}

}  // namespace
}  // namespace xla