aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils_test.cc
blob: 64d9e2031eb42878c3533205f95243645d289f74 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/tests/test_utils.h"

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

// Test fixture shared by the MakeFakeArguments tests in this file. It derives
// from LocalClientTestBase so that tests can use TestName() and local_client_
// (the UnusedParam test needs the client to compile its computation).
class TestUtilsTest : public LocalClientTestBase {};

// Checks that MakeFakeArguments succeeds on a compiled module containing a
// Reduce whose reducer computation has a parameter with no users.
XLA_TEST_F(TestUtilsTest, UnusedParam) {
  const Shape single_float = ShapeUtil::MakeShape(F32, {});
  const Shape pair_float = ShapeUtil::MakeShape(F32, {2});

  // Reduction lambda with two scalar parameters. Parameter 0 ("unused") is
  // never consumed by any other instruction. A dedicated builder is used so
  // the test does not depend on reusing a builder after Build().
  XlaBuilder reducer_builder("reducer");
  Parameter(&reducer_builder, 0, single_float, "unused");
  Parameter(&reducer_builder, 1, single_float, "used");
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation reducer, reducer_builder.Build());

  // Entry computation: reduce an f32[2] operand along dimension 0.
  XlaBuilder builder(TestName());
  Reduce(Parameter(&builder, 0, pair_float, "operand"),
         Parameter(&builder, 1, single_float, "init"), reducer, {0});
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(
      auto executable,
      local_client_->Compile(computation, {&pair_float, &single_float},
                             ExecutableBuildOptions()));
  // MakeFakeArguments takes a non-const HloModule*, but the compiled
  // executable's module is only reachable through a const reference, hence
  // the const_cast.
  HloModule& module =
      const_cast<HloModule&>(executable->executable()->module());
  TF_ASSERT_OK(MakeFakeArguments(&module).status());
}

// Checks that MakeFakeArguments succeeds on a module whose entry parameter is
// a token (token[]) threaded through infeed/outfeed instructions.
XLA_TEST_F(TestUtilsTest, Token) {
  // Kept in a named constant so the multi-line raw string is not passed as a
  // macro argument.
  constexpr char kModuleStr[] = R"(HloModule outfeed_module

    ENTRY InfeedToOutfeed {
      token = token[] parameter(0)
      infeed = ((u32[3]{0}, pred[]), token[]) infeed(token)
      infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
      outfeed = token[] outfeed(infeed.data, token)
      ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token)
      infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
      infeed.1.token = token[] get-tuple-element(infeed.1), index=1
      outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
    })";
  // A parse failure is reported as a test failure instead of aborting the
  // whole binary via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kModuleStr));
  TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
}

// Checks that when a single index parameter feeds two dynamic-slices over
// differently shaped operands, MakeFakeArguments produces index values that
// are in bounds for both slices at once.
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
  constexpr char kModuleStr[] = R"(HloModule index_space_module

    ENTRY IndexSpace {
      index_param = s32[3]{0} parameter(0)
      array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
      array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
      dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3}
      ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kModuleStr));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
                          MakeFakeArguments(module.get()));
  // Unsigned literal: avoids a signed/unsigned comparison inside ASSERT_EQ.
  ASSERT_EQ(args.size(), 3u);
  const Literal& index_arg = *args[0];

  // A start index is in bounds when 0 <= start <= operand_dim - slice_size.
  // Intersecting the ranges of both slices, per dimension:
  //   dim 0: [0, 123-1] and [0, 3-3]    -> exactly 0
  //   dim 1: [0, 4-2]   and [0, 3000-2] -> [0, 2]
  //   dim 2: [0, 789-3] and [0, 5-2]    -> [0, 3]
  EXPECT_EQ(index_arg.Get<int32>({0}), 0);

  EXPECT_GE(index_arg.Get<int32>({1}), 0);
  EXPECT_LE(index_arg.Get<int32>({1}), 2);

  EXPECT_GE(index_arg.Get<int32>({2}), 0);
  EXPECT_LE(index_arg.Get<int32>({2}), 3);
}

// Checks that when a single index parameter feeds two dynamic-update-slices
// over differently shaped operands, MakeFakeArguments produces index values
// that are in bounds for both updates at once.
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
  constexpr char kModuleStr[] = R"(HloModule index_space_module

    ENTRY IndexSpace {
      index_param = s32[3]{0} parameter(0)
      array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
      array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
      update_param.1 = f32[1,2,3]{0,1,2} parameter(3)
      update_param.2 = f32[3,2,2]{0,1,2} parameter(4)

      dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param)
      ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kModuleStr));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
                          MakeFakeArguments(module.get()));
  // Unsigned literal: avoids a signed/unsigned comparison inside ASSERT_EQ.
  ASSERT_EQ(args.size(), 5u);
  const Literal& index_arg = *args[0];

  // A start index is in bounds when 0 <= start <= operand_dim - update_dim.
  // Intersecting the ranges of both updates, per dimension:
  //   dim 0: [0, 123-1] and [0, 3-3]    -> exactly 0
  //   dim 1: [0, 4-2]   and [0, 3000-2] -> [0, 2]
  //   dim 2: [0, 789-3] and [0, 5-2]    -> [0, 3]
  EXPECT_EQ(index_arg.Get<int32>({0}), 0);

  EXPECT_GE(index_arg.Get<int32>({1}), 0);
  EXPECT_LE(index_arg.Get<int32>({1}), 2);

  EXPECT_GE(index_arg.Get<int32>({2}), 0);
  EXPECT_LE(index_arg.Get<int32>({2}), 3);
}

}  // namespace
}  // namespace xla