aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
blob: 0a0426adcbc1b5b89be0841fa2c4204e2b65abf4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

// Tests that ensure outfeed instructions that are contained in nested
// computations in non-root positions are executed. "Nested" here means the
// condition and body of a while loop, and the branches of a conditional.

// Test fixture for the cases below. It adds no members of its own; everything
// the tests use from the fixture comes from LocalClientTestBase.
class OutfeedInNestedComputationTest : public LocalClientTestBase {};

// Builds a while loop whose condition and body each contain an outfeed that is
// not the last instruction added to the computation, runs it on a separate
// thread, and drives it from this thread through infeed/outfeed transfers.
//
// Host-side sequence exercised for a trip count of 1:
//   1. infeed  <- trip count (1), consumed by the outer computation
//   2. outfeed -> loop counter emitted by the condition (1)
//   3. infeed  <- body input ({10, 20})
//   4. outfeed -> body input + loop counter ({11, 21})
//   5. outfeed -> loop counter emitted by the last condition check (0)
// Finally the computation result (element 0 of the loop state) must be 0.
XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
  XlaBuilder b(TestName());

  // Loop state is a tuple of (remaining iterations, s32[10,5] buffer).
  Shape state_tuple_array_shape = ShapeUtil::MakeShape(xla::S32, {10, 5});
  Shape int_shape = ShapeUtil::MakeShape(xla::S32, {});
  Shape state_tuple_shape =
      ShapeUtil::MakeTupleShape({int_shape, state_tuple_array_shape});
  // Shape used for both the infeed and the outfeed inside the loop body.
  Shape xfeed_shape = ShapeUtil::MakeShape(xla::S32, {2});

  // Initial state: the trip count comes from the infeed, the buffer is zeros.
  XlaOp some_buffer = Broadcast(ConstantR0<int32_t>(&b, 0), {10, 5});
  XlaOp num_iter = Infeed(&b, int_shape);
  XlaOp init_tuple = Tuple(&b, {num_iter, some_buffer});

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_cond, [&] {
    // Condition: iteration variable > 0
    XlaBuilder cond_builder("loop_condition");
    XlaOp state_tuple = Parameter(&cond_builder, 0, state_tuple_shape, "state");
    XlaOp loop_counter = GetTupleElement(state_tuple, 0);
    // The outfeed under test: emitted before the comparison, so it is not the
    // last instruction of the condition.
    Outfeed(loop_counter, int_shape, "");
    Gt(loop_counter, ConstantR0<int32_t>(&cond_builder, 0));
    return cond_builder.Build();
  }());

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_body, [&] {
    XlaBuilder body_builder("loop_body");
    XlaOp state_tuple = Parameter(&body_builder, 0, state_tuple_shape, "state");
    XlaOp loop_counter = GetTupleElement(state_tuple, 0);
    XlaOp buffer_inside = GetTupleElement(state_tuple, 1);

    // Read some stuff from Infeed.
    XlaOp some_input = Infeed(&body_builder, xfeed_shape);
    XlaOp sum = Add(some_input, Broadcast(loop_counter, {2}));
    // The outfeed under test: followed by the counter update and the new
    // state tuple, so it is not the last instruction of the body.
    Outfeed(sum, xfeed_shape, "");

    XlaOp iter_left = Sub(loop_counter, ConstantR0<int32_t>(&body_builder, 1));

    // Next state: decremented counter, buffer passed through unchanged.
    Tuple(&body_builder, {iter_left, buffer_inside});
    return body_builder.Build();
  }());

  // Build loop.
  XlaOp result_tuple = While(loop_cond, loop_body, init_tuple);
  GetTupleElement(result_tuple, 0);
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());

  // Execute on a separate thread so that this thread remains free to feed the
  // infeed and drain the outfeed while the computation is running.
  std::unique_ptr<xla::Literal> comp_result;
  std::unique_ptr<tensorflow::Thread> thread(
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "execute_thread", [&] {
            comp_result = local_client_->ExecuteAndTransfer(computation, {})
                              .ConsumeValueOrDie();
          }));

  VLOG(1) << "Transferring trip count to computation";
  // Transfer number of iterations to Infeed.
  TF_ASSERT_OK(
      local_client_->TransferToInfeed(*LiteralUtil::CreateR0<int32_t>(1)));

  // Pick up value from outfeed
  {
    VLOG(1) << "Reading from condition outfeed";
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
                            local_client_->TransferFromOutfeed(&int_shape));
    // First condition check sees the initial trip count.
    EXPECT_EQ(r->Get<int32_t>({}), 1);
  }

  VLOG(1) << "Writing data to infeed";
  // Transfer some stuff to Infeed for use inside of loop.
  TF_ASSERT_OK(local_client_->TransferToInfeed(
      *LiteralUtil::CreateR1<int32_t>({10, 20})));

  // Pick up value from outfeed
  {
    VLOG(1) << "Reading from body outfeed";
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
                            local_client_->TransferFromOutfeed(&xfeed_shape));
    // Body adds the loop counter (1) to each infed element.
    EXPECT_EQ(r->Get<int32_t>({0}), 11);
    EXPECT_EQ(r->Get<int32_t>({1}), 21);
  }

  {
    VLOG(1) << "Reading from condition outfeed";
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
                            local_client_->TransferFromOutfeed(&int_shape));
    // Second condition check sees the decremented counter and ends the loop.
    EXPECT_EQ(r->Get<int32_t>({}), 0);
  }

  // Joins the thread
  thread.reset();

  // comp_result is only read after the join above, once the execution thread
  // has finished writing it.
  EXPECT_EQ(comp_result->Get<int32_t>({}), 0);
}

// Builds a conditional whose true branch contains an outfeed that is not the
// last instruction added to that branch, runs it on a separate thread, feeds
// `true` as the predicate through the infeed, and checks both the value that
// arrives on the outfeed and the value returned by the computation.
XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
  XlaBuilder b(TestName());

  Shape condition_shape = ShapeUtil::MakeShape(xla::PRED, {});
  Shape result_shape = ShapeUtil::MakeShape(xla::PRED, {});

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation true_computation, [&] {
    XlaBuilder inner_builder("true_computation");
    XlaOp param = Parameter(&inner_builder, 0, result_shape, "param");
    // The outfeed under test: the Or below is added after it, so the outfeed
    // is not the last instruction of the branch.
    Outfeed(param, result_shape, "");
    Or(param, param);
    return inner_builder.Build();
  }());

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation false_computation, [&] {
    // The false branch contains only its parameter and no outfeed.
    XlaBuilder inner_builder("false_computation");
    Parameter(&inner_builder, 0, result_shape, "param");
    return inner_builder.Build();
  }());

  // The infed predicate selects the branch and is also the operand passed to
  // both branches.
  XlaOp pred = Infeed(&b, condition_shape);
  Conditional(/*predicate=*/pred, /*true_operand=*/pred,
              /*true_computation=*/true_computation, /*false_operand=*/pred,
              /*false_computation=*/false_computation);

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());

  // Execute on a separate thread so that this thread remains free to feed the
  // infeed and drain the outfeed while the computation is running.
  std::unique_ptr<xla::Literal> comp_result;
  std::unique_ptr<tensorflow::Thread> thread(
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "execute_thread", [&] {
            comp_result = local_client_->ExecuteAndTransfer(computation, {})
                              .ConsumeValueOrDie();
          }));

  // Select the true branch.
  TF_ASSERT_OK(
      local_client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
                          local_client_->TransferFromOutfeed(&result_shape));

  // The true branch outfeeds its operand, which is the predicate itself.
  EXPECT_EQ(r->Get<bool>({}), true);

  // Join the thread
  thread.reset();

  // comp_result is only read after the join above, once the execution thread
  // has finished writing it. The true branch ends in Or(param, param) with
  // param == true, so the conditional is expected to return true.
  EXPECT_EQ(comp_result->Get<bool>({}), true);
}

}  // namespace
}  // namespace xla