aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
blob: 86a78ee429e8913edb4a948727fa692083c472f4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h"

#include <memory>
#include <vector>

#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {
// Locates the unique top-level output slice of `buffer_assignment` and returns
// the index of the buffer allocation that holds it (i.e. which of the temp
// buffers contains the computation's result). An error from
// GetUniqueTopLevelOutputSlice() is propagated to the caller unchanged.
xla::StatusOr<size_t> ComputeResultIndex(
    const xla::BufferAssignment& buffer_assignment) {
  auto output_slice_or = buffer_assignment.GetUniqueTopLevelOutputSlice();
  if (!output_slice_or.ok()) {
    return output_slice_or.status();
  }
  const xla::BufferAllocation::Slice& output_slice =
      output_slice_or.ValueOrDie();
  return output_slice.index();
}

// Gathers the names of `entries`, where T is a repeated field of
// tf2xla::{Feed,Fetch}, into two outputs:
//   * `nonempty_names` receives a copy of every non-empty name, in order. It
//     owns the character data that `name_ptrs` points into.
//   * `name_ptrs` receives one C string per entry, in the same order as
//     `entries`: a pointer into `nonempty_names` for a named entry, or "" for
//     an unnamed one, followed by a trailing nullptr terminator.
// Every name is copied before any pointer is taken, so growth of
// `nonempty_names` during collection cannot invalidate `name_ptrs`.
// NOTE(review): the pointer pass walks `nonempty_names` from its first
// element, so callers are expected to pass empty output vectors.
template <typename T>
void CollectNames(const T& entries, std::vector<string>* nonempty_names,
                  std::vector<const char*>* name_ptrs) {
  // Pass 1: take ownership of the non-empty names.
  for (const auto& entry : entries) {
    if (entry.name().empty()) continue;
    nonempty_names->push_back(entry.name());
  }

  // Pass 2: emit one pointer per entry; the extra slot holds the terminator.
  name_ptrs->reserve(entries.size() + 1);
  auto owned_name = nonempty_names->cbegin();
  for (const auto& entry : entries) {
    if (entry.name().empty()) {
      name_ptrs->push_back("");
      continue;
    }
    name_ptrs->push_back(owned_name->c_str());
    ++owned_name;
  }
  name_ptrs->push_back(nullptr);  // array terminator
}

}  // namespace

// Compiles `graph_def`, with feeds and fetches described by `config`, into a
// CPU function that can be run through the XlaCompiledCpuFunction interface.
//
// Steps:
//   1. Convert the graph into an xla::XlaComputation using a LocalClient bound
//      explicitly to the host (CPU) platform.
//   2. Verify that the computation's result is a tuple.
//   3. Compile it with `build_options`, then extract the raw function, the
//      buffer assignment and (if enabled) HLO profiling data from the
//      resulting CpuExecutable.
//   4. Fill in the returned object's static data. Every pointer stored in
//      `static_data_` refers to storage owned by the returned object.
//
// Returns an error status if any step fails.
/*static*/ xla::StatusOr<std::unique_ptr<XlaJitCompiledCpuFunction>>
XlaJitCompiledCpuFunction::Compile(
    const GraphDef& graph_def, const tf2xla::Config& config,
    const xla::ExecutableBuildOptions& build_options) {
  // Bind the client to the host platform explicitly. Calling
  // GetOrCreateLocalClient() without a platform selects the process's default
  // platform. That choice is ambiguous (an error) when another platform such
  // as a GPU is linked in, and it need not be the CPU. The static_cast to
  // CpuExecutable below is only valid for executables built for the host.
  constexpr char kHostPlatform[] = "Host";
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      xla::PlatformUtil::GetPlatform(kHostPlatform));
  TF_ASSIGN_OR_RETURN(xla::LocalClient * client,
                      xla::ClientLibrary::GetOrCreateLocalClient(platform));

  // Convert the graph_def into an xla::XlaComputation.
  xla::XlaComputation computation;
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client,
                                                      &computation));

  // Get and verify the program shape.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> program_shape,
                      client->GetComputationShape(computation));
  if (program_shape->result().element_type() != xla::TUPLE) {
    // The XlaCompiler we use to build the xla computation always generates a
    // tuple result, and XlaCompiledCpuFunction relies on this for simpler
    // calling semantics.
    return errors::Internal(
        "XlaJitCompiledCpuFunction requires the XLA result to be a tuple");
  }
  // The parameter names are currently meaningless, and redundant with the rest
  // of our metadata, so clear them out to avoid confusion and save space.
  program_shape->clear_parameter_names();

  // Compute arg shapes, needed to compile the executable. These pointers refer
  // into *program_shape. Moving the unique_ptr into `jit` below transfers only
  // the owning pointer, so the ProgramShape itself does not move.
  std::vector<const xla::Shape*> arg_shapes;
  arg_shapes.reserve(program_shape->parameters_size());
  for (int i = 0; i < program_shape->parameters_size(); ++i) {
    arg_shapes.push_back(&program_shape->parameters(i));
  }

  // Compile the executable. The static_cast to the CpuExecutable subclass is
  // necessary since the raw function and buffer assignments are only available
  // there. It relies on `client` being bound to the host platform above.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                      client->Compile(computation, arg_shapes, build_options));
  const xla::cpu::CpuExecutable* cpu_executable =
      static_cast<xla::cpu::CpuExecutable*>(executable->executable());
  XlaCompiledCpuFunction::RawFunction raw_function =
      cpu_executable->compute_function();
  const xla::BufferAssignment& buffer_assignment =
      cpu_executable->buffer_assignment();

  // Compute buffer infos and the result index, needed to run the raw function.
  std::vector<cpu_function_runtime::BufferInfo> buffer_infos =
      xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment);
  std::vector<int32> arg_index_table =
      xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
  TF_ASSIGN_OR_RETURN(size_t result_index,
                      ComputeResultIndex(buffer_assignment));

  // Hand all owned state to the new object first. The data()/get() pointers
  // passed to static_data_ below must be taken from the object's own members,
  // after the moves, so that they remain valid for the object's lifetime.
  std::unique_ptr<XlaJitCompiledCpuFunction> jit_unique_ptr(
      new XlaJitCompiledCpuFunction);
  XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get();
  jit->executable_ = std::move(executable);
  jit->buffer_infos_ = std::move(buffer_infos);
  jit->arg_index_table_ = std::move(arg_index_table);
  jit->program_shape_ = std::move(program_shape);
  jit->static_data_.set_raw_function(raw_function);
  jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
  jit->static_data_.set_num_buffers(jit->buffer_infos_.size());
  jit->static_data_.set_arg_index_table(jit->arg_index_table_.data());
  jit->static_data_.set_num_args(jit->arg_index_table_.size());
  jit->static_data_.set_result_index(result_index);
  // Optional metadata is collected and set below. The name arrays must be
  // fully populated before their data() pointers are published.
  CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_);
  CollectNames(config.fetch(), &jit->nonempty_result_names_,
               &jit->result_names_);
  jit->static_data_.set_arg_names(jit->arg_names_.data());
  jit->static_data_.set_result_names(jit->result_names_.data());
  jit->static_data_.set_program_shape(jit->program_shape_.get());

  if (cpu_executable->hlo_profiling_enabled()) {
    // The printer data lives inside the executable, which `jit` now owns.
    jit->static_data_.set_hlo_profile_printer_data(
        &cpu_executable->hlo_profile_printer_data());
    jit->static_data_.set_profile_counters_size(
        cpu_executable->hlo_profile_printer_data().profile_counters_size());
  }

  // Explicit std::move: the return type (StatusOr) differs from the local's
  // type, so pre-C++14 rules would not apply an implicit move here.
  return std::move(jit_unique_ptr);
}

}  // namespace tensorflow