aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
blob: 9ad98e50386c5ffd6f85292c270892eb9c91f14d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {
// Records that the output buffer at `output_index` aliases the buffer of entry
// parameter `param_number` at `param_index`.
//
// Returns an error (instead of touching `alias_`) when `output_index` does not
// address a subshape of the aliased output shape, when `param_number` is
// negative, or when `output_index` is already aliased.
Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
                                             int64 param_number,
                                             const ShapeIndex& param_index) {
  // Indices may come from an untrusted proto (see CreateFromProto); looking up
  // an index that is not valid for the shape tree must not be attempted.
  TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
      << "Trying to set up alias at " << output_index.ToString()
      << " which is an invalid index for shape "
      << ShapeUtil::HumanString(alias_.shape());
  TF_RET_CHECK(param_number >= 0) << param_number;
  // Output can't be aliased with multiple parameters.
  TF_RET_CHECK(!alias_.element(output_index));
  (*alias_.mutable_element(output_index)) =
      std::make_pair(param_number, param_index);
  return Status::OK();
}

// Serializes every populated alias into an HloInputOutputAliasProto, emitting
// one entry per aliased output index in shape-tree traversal order.
HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
  HloInputOutputAliasProto proto;
  alias_.ForEachElement(
      [&proto](const ShapeIndex& output_index,
               const absl::optional<std::pair<int64, ShapeIndex>>& aliased) {
        // Output indices without an alias are not serialized.
        if (!aliased) {
          return;
        }
        HloInputOutputAliasProto::AliasEntryProto* entry = proto.add_entries();
        for (int64 dim : output_index) {
          entry->add_output_shape_index(dim);
        }
        entry->set_parameter_number(aliased->first);
        for (int64 dim : aliased->second) {
          entry->add_parameter_shape_index(dim);
        }
      });
  return proto;
}

// Rebuilds an alias config from its proto form. The config is shaped after the
// root instruction of `module`'s entry computation; each proto entry is
// registered through SetUpAlias, whose first error is propagated.
StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
    const HloModule* module, const HloInputOutputAliasProto& proto) {
  const Shape& output_shape =
      module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig config(output_shape);
  for (const auto& alias_entry : proto.entries()) {
    const auto& out_dims = alias_entry.output_shape_index();
    const auto& param_dims = alias_entry.parameter_shape_index();
    const ShapeIndex output_index(out_dims.begin(), out_dims.end());
    const ShapeIndex param_index(param_dims.begin(), param_dims.end());
    TF_RETURN_IF_ERROR(config.SetUpAlias(
        output_index, alias_entry.parameter_number(), param_index));
  }
  return config;
}

// Produces a multi-line, human-readable description: a header line followed by
// one indented line per alias.
string HloInputOutputAliasConfig::ToString() const {
  std::vector<string> lines = {"HloInputOutputAliasConfig"};
  ForEachAlias([&lines](const ShapeIndex& output_index, int64 param_number,
                        const ShapeIndex& param_index) {
    string line = absl::StrFormat(
        "  OutputIndex %s is aliased with parameter %lld at %s:",
        output_index.ToString(), param_number, param_index.ToString());
    lines.push_back(std::move(line));
  });
  return absl::StrJoin(lines, "\n");
}

// Returns true if any output index is aliased with (some sub-buffer of) entry
// parameter `param_number`.
bool HloInputOutputAliasConfig::ParameterHasAlias(int64 param_number) const {
  bool found = false;
  ForEachAlias([&found, param_number](const ShapeIndex& /*output_index*/,
                                      int64 aliased_param,
                                      const ShapeIndex& /*param_index*/) {
    found = found || aliased_param == param_number;
  });
  return found;
}

// Looks up the output index aliased with parameter `param_number` at
// `param_index`. Returns nullopt when there is none; if several output indices
// match, the one visited last in shape-tree traversal order is returned.
absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
    int64 param_number, const ShapeIndex& param_index) const {
  absl::optional<ShapeIndex> match;
  ForEachAlias([&](const ShapeIndex& output_index, int64 aliased_param_number,
                   const ShapeIndex& aliased_param_index) {
    const bool same_param = aliased_param_number == param_number;
    if (same_param && aliased_param_index == param_index) {
      // No early exit: a later match overwrites an earlier one.
      match = output_index;
    }
  });
  return match;
}

// Returns the (parameter number, parameter index) pair aliased with the output
// at `output_index`, or nullopt if that output is not aliased. CHECK-fails if
// `output_index` is not a valid index into the aliased output shape.
absl::optional<std::pair<int64, ShapeIndex>>
HloInputOutputAliasConfig::GetAliasedParameter(
    const ShapeIndex& output_index) const {
  CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
  const absl::optional<std::pair<int64, ShapeIndex>>& aliased_param =
      alias_.element(output_index);
  return aliased_param;
}

// Invokes `fn(output_index, param_number, param_index)` for every aliased
// output index, in shape-tree traversal order. Unaliased indices are skipped.
void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
  alias_.ForEachElement(
      [&fn](const ShapeIndex& output_index,
            absl::optional<std::pair<int64, ShapeIndex>> aliased) {
        if (!aliased) {
          return;
        }
        fn(output_index, aliased->first, aliased->second);
      });
}

// Status-returning variant of ForEachAlias: visits every aliased output index
// in shape-tree traversal order and stops at the first non-OK status from
// `fn`, returning it.
Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
    AliasFnWithStatus fn) const {
  return alias_.ForEachElementWithStatus(
      [&fn](const ShapeIndex& output_index,
            absl::optional<std::pair<int64, ShapeIndex>> aliased) -> Status {
        if (!aliased) {
          return Status::OK();
        }
        TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second));
        return Status::OK();
      });
}

// Checks that every alias in this config is consistent with `module`'s entry
// computation:
//   * the parameter number refers to an existing entry parameter,
//   * the parameter index is valid for that parameter's shape,
//   * the output index is valid for the root instruction's shape, and
//   * no (parameter number, parameter index) pair is aliased more than once.
// Returns the first violation as an error status.
Status HloInputOutputAliasConfig::Verify(const HloModule& module) const {
  const HloComputation* entry = module.entry_computation();
  const Shape& output_shape = entry->root_instruction()->shape();

  // One tree per entry parameter recording which parameter sub-buffers have
  // already been claimed by an alias.
  std::vector<ShapeTree<bool>> param_has_seen;
  for (int64 i = 0; i < entry->num_parameters(); ++i) {
    HloInstruction* param = entry->parameter_instruction(i);
    param_has_seen.emplace_back(param->shape());
  }
  return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
                                    int64 param_number,
                                    const ShapeIndex& param_index) -> Status {
    // Validate the parameter number *before* using it to fetch the parameter
    // instruction or to index param_has_seen.
    TF_RET_CHECK(param_number >= 0);
    TF_RET_CHECK(entry->num_parameters() > param_number);

    const Shape& param_shape =
        entry->parameter_instruction(param_number)->shape();
    TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index));
    TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));

    // Check each param_number and param_index pair only show up once. No
    // input can be aliased with multiple output buffers.
    ShapeTree<bool>& seen = param_has_seen[param_number];
    TF_RET_CHECK(!seen.element(param_index));
    *seen.mutable_element(param_index) = true;

    return Status::OK();
  });
}

// Streams the human-readable ToString() form of `config` into `out`.
std::ostream& operator<<(std::ostream& out,
                         const HloInputOutputAliasConfig& config) {
  return out << config.ToString();
}
}  // namespace xla