aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/call_test.cc
blob: 8b31e53707eee456e09adfe9fb76f03a8855056d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

// Fixture for tests that exercise the XLA Call operation. It supplies small
// callee computations (identity, element-wise addition, tuple wrapping) and the
// F32 shapes those callees are parameterized with.
class CallOpTest : public ClientLibraryTestBase {
 protected:
  // Returns a computation named "Identity" that declares a single scalar F32
  // parameter "x" and adds no further operations.
  XlaComputation CreateR0F32IdentityComputation() {
    XlaBuilder builder("Identity");
    Parameter(&builder, 0, r0f32_, "x");
    return BuildAndCheck(&builder);
  }

  // Returns a computation named "Addition" that adds two F32 vectors of
  // length 0.
  XlaComputation CreateR1S0F32AdditionComputation() {
    return CreateAdditionComputation(r1s0f32_);
  }

  // Returns a computation named "Addition" that adds two F32 vectors of
  // length 2.
  XlaComputation CreateR1S2F32AdditionComputation() {
    return CreateAdditionComputation(r1s2f32_);
  }

  // Returns a computation named "Tuple" that wraps its single scalar F32
  // parameter "x" in a one-element tuple.
  XlaComputation CreateR0F32TupleComputation() {
    XlaBuilder builder("Tuple");
    Tuple(&builder, {Parameter(&builder, 0, r0f32_, "x")});
    return BuildAndCheck(&builder);
  }

  // Scalar F32 shape.
  Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
  // Rank-1 F32 shape with zero elements.
  Shape r1s0f32_ = ShapeUtil::MakeShape(F32, {0});
  // Rank-1 F32 shape with two elements.
  Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});

 private:
  // Builds the "Addition" computation shared by the rank-1 factories: two
  // parameters "x" and "y" of `shape`, combined with Add. Only the operand
  // shape differs between callers.
  static XlaComputation CreateAdditionComputation(const Shape& shape) {
    XlaBuilder builder("Addition");
    auto x = Parameter(&builder, 0, shape, "x");
    auto y = Parameter(&builder, 1, shape, "y");
    Add(x, y);
    return BuildAndCheck(&builder);
  }

  // Finalizes `builder`, records a test failure if the build status is not
  // OK, and returns the resulting computation. ConsumeValueOrDie is still
  // invoked after a failed expectation, matching what every factory did
  // individually before this helper existed.
  static XlaComputation BuildAndCheck(XlaBuilder* builder) {
    auto build_status = builder->Build();
    EXPECT_IS_OK(build_status.status());
    return build_status.ConsumeValueOrDie();
  }
};

// Calls the scalar identity computation on the constant 42 and expects the
// same value back.
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
  XlaBuilder builder(TestName());
  XlaComputation identity = CreateR0F32IdentityComputation();

  // Materialize the argument as a literal first, then feed it to the callee.
  const auto input = LiteralUtil::CreateR0<float>(42.0);
  XlaOp operand = ConstantLiteral(&builder, input);
  Call(&builder, identity, {operand});

  ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
}

// Calls the zero-length vector addition computation on two empty constants
// and expects an empty result.
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
  XlaBuilder builder(TestName());
  XlaComputation adder = CreateR1S0F32AdditionComputation();

  // Both operands are the same empty F32 vector constant, created in order.
  const auto make_empty_operand = [&builder]() -> XlaOp {
    return ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
  };
  XlaOp lhs = make_empty_operand();
  XlaOp rhs = make_empty_operand();
  Call(&builder, adder, {lhs, rhs});

  ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
}

// Calls the two-element vector addition computation on {1, 2} and {2, 3} and
// expects {3, 5}.
XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
  XlaBuilder builder(TestName());
  XlaComputation adder = CreateR1S2F32AdditionComputation();

  const auto lhs_values = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  const auto rhs_values = LiteralUtil::CreateR1<float>({2.0f, 3.0f});

  // Constants are added to the builder in the same order as before: lhs, rhs.
  XlaOp lhs = ConstantLiteral(&builder, lhs_values);
  XlaOp rhs = ConstantLiteral(&builder, rhs_values);
  Call(&builder, adder, {lhs, rhs});

  ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
}

// Builds a call tree: "outermost" calls "outer" three times in sequence, and
// "outer" calls "inner" three times in sequence. "inner" adds 1 to its scalar
// argument, so an input of 1 goes through nine increments and yields 10.
XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
  constexpr int kBranchFactor = 3;

  // Declares scalar parameter "x" on `b` and threads it through
  // kBranchFactor chained calls to `callee`.
  const auto chain_calls = [this, kBranchFactor](XlaBuilder* b,
                                                 const XlaComputation& callee) {
    (void)kBranchFactor;  // Silences unused-capture warnings on some compilers.
    XlaOp value = Parameter(b, 0, r0f32_, "x");
    for (int i = 0; i < kBranchFactor; ++i) {
      value = Call(b, callee, {value});
    }
  };

  // Leaf computation: x + 1.
  XlaBuilder inner_builder("inner");
  {
    XlaOp arg = Parameter(&inner_builder, 0, r0f32_, "x");
    Add(arg, ConstantR0<float>(&inner_builder, 1.0));
  }
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation inner, inner_builder.Build());

  // Middle level: three chained calls to the leaf.
  XlaBuilder outer_builder("outer");
  chain_calls(&outer_builder, inner);
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation outer, outer_builder.Build());

  // Top level: three chained calls to the middle level. This builder is left
  // unbuilt here and handed straight to the compare helper below.
  XlaBuilder outermost_builder("outermost");
  chain_calls(&outermost_builder, outer);

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> start,
      client_->TransferToServer(LiteralUtil::CreateR0<float>(1.0f)));
  ComputeAndCompareR0<float>(&outermost_builder, 10.0f, {start.get()},
                             ErrorSpec(0.0f));
}

// Calls the tuple-wrapping computation on the scalar constant 42 and expects
// a one-element tuple holding that scalar.
XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
  XlaBuilder builder(TestName());
  XlaComputation tuple_maker = CreateR0F32TupleComputation();

  auto scalar = LiteralUtil::CreateR0<float>(42.0);
  XlaOp operand = ConstantLiteral(&builder, scalar);
  Call(&builder, tuple_maker, {operand});

  // The expected result refers to the same scalar literal used as the input.
  auto expected = LiteralUtil::MakeTuple({&scalar});
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01f));
}

}  // namespace
}  // namespace xla