aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
blob: 4d82442f7e3630c115eff1f17544e2b892c5e7eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"

namespace xla {
namespace {

class HloMetadataTest : public LocalClientTestBase {
 protected:
  HloMetadataTest() {
    metadata_.set_op_type("add");
    metadata_.set_op_name("my_sum_op");
  }

  void BuildAddComputation(XlaBuilder* builder) {
    auto x = Parameter(builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
    auto y = Parameter(builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
    Add(x, y);
  }

  OpMetadata metadata_;
};

// Ops emitted while metadata_ is set on the builder should still carry that
// metadata on the root instruction of the compiled entry computation.
TEST_F(HloMetadataTest, MetadataPropagation) {
  XlaBuilder builder("add");
  builder.SetOpMetadata(metadata_);
  BuildAddComputation(&builder);
  builder.ClearOpMetadata();

  // Report a Build() error as a test failure rather than CHECK-failing via
  // ValueOrDie(), which would abort every remaining test in the binary.
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build());

  const Shape argument_layout = ShapeUtil::MakeShape(F32, {});
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<LocalExecutable> executable,
      local_client_->Compile(computation, {&argument_layout, &argument_layout},
                             ExecutableBuildOptions()));

  const auto* instruction = executable->executable()
                                ->module()
                                .entry_computation()
                                ->root_instruction();
  EXPECT_EQ("add", instruction->metadata().op_type());
  EXPECT_EQ("my_sum_op", instruction->metadata().op_name());
}

// Metadata that is set and then cleared before any ops are emitted must not
// leak onto the compiled root instruction.
TEST_F(HloMetadataTest, MetadataClearing) {
  XlaBuilder builder("add");
  builder.SetOpMetadata(metadata_);
  // Some other pretend computation here.
  builder.ClearOpMetadata();
  BuildAddComputation(&builder);

  // Same status handling as MetadataPropagation: failures in Build() or
  // Compile() are reported as test failures instead of crashing the binary.
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build());

  const Shape argument_layout = ShapeUtil::MakeShape(F32, {});
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<LocalExecutable> executable,
      local_client_->Compile(computation, {&argument_layout, &argument_layout},
                             ExecutableBuildOptions()));

  const auto* instruction = executable->executable()
                                ->module()
                                .entry_computation()
                                ->root_instruction();
  // No metadata was active when the add was emitted, so both fields are
  // expected to be empty.
  EXPECT_EQ("", instruction->metadata().op_type());
  EXPECT_EQ("", instruction->metadata().op_name());
}

}  // namespace
}  // namespace xla