aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
blob: d977ff319865a5d371ad127f726880958ae652e2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"

#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace tensorflow {
namespace grappler {
namespace vectorization_utils {

namespace {

// Describes a tensor with its operation Node and output position
typedef std::pair<Node*, int> TensorDesc;

const char* const kRetValOp = "_Retval";

void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
                        Graph* graph) {
  // NOTE: We need two for loops here because we can't mutate the set of output
  // edges as we iterate over them.
  std::vector<const Edge*> edges_to_replace;
  for (auto edge : old_src.first->out_edges()) {
    if (edge->src_output() == old_src.second) {
      edges_to_replace.push_back(edge);
    }
  }
  for (auto edge : edges_to_replace) {
    graph->AddEdge(new_src.first, new_src.second, edge->dst(),
                   edge->dst_input());
    graph->RemoveEdge(edge);
  }
}

Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
                         const TensorDesc& output) {
  // Note that we don't update MapDefun attrs as we go, only when we are done
  DataType type = output.first->output_type(output.second);
  int index = map_defun_fn->ret_nodes.size();

  NodeDef ret_node_def;
  ret_node_def.set_name("map_out");
  ret_node_def.set_op(kRetValOp);
  AddNodeAttr("T", type, &ret_node_def);
  AddNodeAttr("index", index, &ret_node_def);

  Status s;
  Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
  TF_RETURN_IF_ERROR(s);

  map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
  map_defun_fn->ret_nodes.push_back(ret_node);
  map_defun_fn->ret_types.push_back(type);

  return s;
}

void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
                          FunctionBody* map_defun_fn, Node* map_defun_node) {
  // Note that we don't update MapDefun attrs as we go, only when we are done
  DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
      << "Trying to remove output that doesn't exist. Output number: "
      << output_position;

  int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;

  // Modify map_defun_fn's signature and remove the output node from its graph
  map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
  map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
                                output_position);
  map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
                                output_position);

  // Renumber the nodes and edges that come after
  for (int i = 0; i < num_later_outputs; ++i) {
    ReplaceEdgeSources({map_defun_node, output_position + i + 1},
                       {map_defun_node, output_position + i}, outer_scope);
    // Each ret node has an "index" attr that has to be updated
    map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
                                                          output_position + i);
  }
}

// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
// This class transforms the input FunctionDefs into their corresponding
// Graph objects and works on the graphs directly, then converts them back
// to FunctionDefs when GetResult is called.
class Vectorization {
 public:
  explicit Vectorization(FunctionDefLibrary* lib)
      : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}

  // Adds the vectorized function and new map_defun_fn to lib, and points
  // vectorized_function to the former. Returns an error status if
  // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
  // along the way.
  Status Vectorize(const FunctionDef& outer_scope,
                   const NodeDef& map_defun_node, FunctionDef** result);

 private:
  // Converts FunctionDefs to Graphs and adds mappings from
  // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
  Status Initialize(const FunctionDef& outer_scope,
                    const NodeDef& map_defun_node);

  // Converts Graphs back to FunctionDefs and adds them to `lib_`.
  Status GetResult(FunctionDef** vectorized_function);

  // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
  // `outer_scope_`, until there are no convertible outputs remaining.
  void VectorizeHelper();

  // Vectorizes map_defun_fn's output at output_position.
  Status ConvertOutput(int output_position);

  // Adds mappings from node's outputs tensors to converted output tensors,
  // creating the necessary new node(s). Generally, the steps to convert an op
  // are:
  // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
  //    These operations collectively compute the same value as what running
  //    the original operation on slices of the input tensors would produce.
  //    For example, a Cast op in MapDefun translates to a Cast op in
  //    `outer_scope_`, since the vectorized version of Cast is itself.
  // 2) Promote the inputs of the op inputs to outputs of the
  //    `map_defun_node_` and `map_defun_fn_`.
  // 3) Add edges between the promoted inputs (that are now outputs of
  //    `map_defun_node`) and the inputs ports of the new node(s).
  // 4) For each output of the old node, add the mapping of output tensors to
  //    the conversion map.
  Status AddConversionMapping(Node* op_node);

  // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
  // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
  // inputs to `map_defun_node_`. This stacked tensor will be compatible with
  // the expected output shape of `map_defun_node_`.
  // This is equivalent to the _stack function in python Pfor.
  Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);

  // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
  // doing a depth-first search from the ret nodes. Lifts nodes that are
  // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly
  // and add mappings to `conversion_map_`.
  Status AddUnstackedNodeMappings();

  // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor
  // is unstacked.
  bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status);

  // Add mappings from `map_defun_fn_` arg nodes to `map_defun_node_` input
  // nodes to `conversion_map_`.
  Status AddArgNodeMappings();

  // Maps a tensor to the corresponding WrappedTensor. For example,
  // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
  std::map<TensorDesc, WrappedTensor> conversion_map_;

  // Unconvertible ret nodes
  std::set<Node*> unconvertible_;

  FunctionDefLibrary* lib_;  // Not owned
  FunctionLibraryDefinition lib_def_;
  // Note that FunctionBody has a pointer to a Graph object that corresponds
  // to the function's subgraph, with additional kArgOp and kRetValOp nodes
  // that denote that function arguments and return values. These nodes have the
  // attrs "T" for the type, and "index" for the argument / retval index
  // respectively. FunctionBody also keeps track of arg/ret_nodes and
  // arg/ret_types, that should be ordered according to argument/output indices.
  std::unique_ptr<Graph> outer_scope_;
  std::unique_ptr<FunctionBody> map_defun_fn_;
  Node* map_defun_node_ = nullptr;  // Owned by `outer_scope`

  // Caches the loop_len_node_ needed for tiling unstacked output. This
  // corresponds to a vector with one element.
  Node* loop_len_node_ = nullptr;  // Owned by `outer_scope`
  Status status_;
};

Status Vectorization::AddConversionMapping(Node* op_node) {
  for (auto edge : op_node->in_edges()) {
    if (edge->IsControlEdge()) {
      return errors::InvalidArgument(
          "Vectorizing outputs with control inputs is currently not "
          "supported.");
    }
  }

  auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
  if (vectorizer == nullptr) {
    return errors::Unimplemented("No vectorizer registered for op: ",
                                 op_node->type_string());
  }
  std::vector<WrappedTensor> inputs, outputs;
  inputs.reserve(op_node->num_inputs());
  outputs.reserve(op_node->num_outputs());

  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));

  // The inputs for the node to be converted may already have been converted
  // themselves. For those that are not, we promote them to MapDefun outputs.
  for (size_t i = 0; i < op_node->num_inputs(); ++i) {
    auto edge = input_edges[i];
    if (auto found = gtl::FindOrNull(conversion_map_,
                                     {edge->src(), edge->src_output()})) {
      inputs.push_back(*found);
    } else {
      // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
      // We assume that all unconverted inputs will be stacked, since we
      // converted all unstacked nodes in `Initialize`. However, it's actually
      // possible that yet-unconverted nodes may produce unstacked outputs after
      // they are vectorized. (For example, see the "Shape" converter in
      // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
      // an unstacked input but receives a stacked one, vectorizer->Vectorize
      // will return an error.
      TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
                                           {edge->src(), edge->src_output()}));
      int output_index = map_defun_fn_->ret_nodes.size() - 1;
      inputs.push_back({map_defun_node_, output_index, true});
    }
  }

  Status s = vectorizer->Vectorize(*op_node, outer_scope_.get(),
                                   std::move(inputs), &outputs);
  if (!s.ok()) {
    VLOG(2) << "Vectorizer for op \"" << op_node->type_string()
            << "\" failed with error: " << s;
    return s;
  }

  if (op_node->num_outputs() != outputs.size()) {
    return errors::Internal(
        "Number of vectorizer outputs does not match. Expected: ",
        op_node->num_outputs(), " Actual: ", outputs.size());
  }

  // Add output mappings.
  for (size_t i = 0; i < op_node->num_outputs(); ++i) {
    conversion_map_.insert({{op_node, i}, outputs[i]});
  }

  return Status::OK();
}

Status Vectorization::ConvertOutput(int output_position) {
  // ret_edge->src() is the actual op that generated the retval, and
  // ret_edge->dst() is the retval node whose op is "_Retval"
  const Edge* ret_edge;
  TF_RETURN_IF_ERROR(
      map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));

  TensorDesc output({ret_edge->src(), ret_edge->src_output()});
  TensorDesc converted_output;

  // It's possible the output already has a mapping, if it comes from a node
  // that has already been converted.
  auto found = gtl::FindOrNull(conversion_map_, output);
  if (!found) {
    TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
    found = &conversion_map_.at(output);
  }

  if (found->stacked) {
    converted_output = {found->node, found->output_index};
  } else {
    // Some outputs may be unstacked if they don't derive from arg nodes
    // (for example, if a function returns a constant). For these, we
    // have to add extra nodes to tile it in the 0th dimension.
    TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
  }

  ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
                     outer_scope_.get());
  RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
                       map_defun_node_);

  return Status::OK();
}

Status Vectorization::Vectorize(const FunctionDef& outer_scope,
                                const NodeDef& map_defun_node,
                                FunctionDef** result) {
  TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
  VectorizeHelper();
  return GetResult(result);
}

void Vectorization::VectorizeHelper() {
  while (true) {
    int output_position = graph_utils::GetFirstElementIndexWithPredicate(
        [this](Node* n) {
          return this->unconvertible_.find(n) == this->unconvertible_.end();
        },
        map_defun_fn_->ret_nodes);

    // No outputs left to convert
    if (output_position == -1) break;

    Status s = ConvertOutput(output_position);
    if (!s.ok()) {
      Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
      VLOG(2) << "Could not convert the output at node: "
              << output_node->DebugString() << "\nError: " << s;
      unconvertible_.insert(output_node);
    }
  }

  // If we've converted all the outputs of the MapDefun function, we no longer
  // need the MapDefun node and can delete it.
  if (map_defun_fn_->ret_nodes.empty()) {
    outer_scope_->RemoveNode(map_defun_node_);
  } else {
    // Update MapDefun node attrs accordingly
    DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
    map_defun_node_->AddAttr(
        "output_shapes",
        std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
    map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
  }
}

Status Vectorization::Initialize(const FunctionDef& outer_scope,
                                 const NodeDef& map_defun_node) {
  // Convert outer_scope and map_defun_fn to FunctionBodys so we can
  // work on Graphs directly.
  const FunctionDef* map_defun_fn =
      lib_def_.Find(map_defun_node.attr().at("f").func().name());

  if (map_defun_fn == nullptr) {
    return errors::NotFound("Could not find function with name ",
                            map_defun_node.attr().at("f").func().name(),
                            " in function library.");
  }

  auto get_func_sig = [this](const string& op, const OpDef** sig) {
    return this->lib_def_.LookUpOpDef(op, sig);
  };

  FunctionBody* outer_fn;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
                                             get_func_sig, &outer_fn));
  // We don't need outer_fn, just the graph
  outer_scope_.reset(outer_fn->graph);
  outer_fn->graph = nullptr;
  delete outer_fn;

  FunctionBody* tmp;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
                                             get_func_sig, &tmp));
  map_defun_fn_.reset(tmp);

  // Find the MapDefun node in outer_scope_
  int node_id = graph_utils::GetFirstElementIndexWithPredicate(
      [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
      outer_scope_->nodes());
  if (node_id == -1) {
    return errors::NotFound("Could not find node with name ",
                            map_defun_node.name(), " in outer_scope.");
  }
  map_defun_node_ = outer_scope_->FindNodeId(node_id);

  TF_RETURN_IF_ERROR(AddArgNodeMappings());

  TF_RETURN_IF_ERROR(AddUnstackedNodeMappings());
  loop_len_node_ = nullptr;

  return Status::OK();
}

// TODO(rachelim): It might be profitable to use the C++ API for this instead of
// NodeBuilder
Status Vectorization::StackTensor(WrappedTensor* unstacked,
                                  TensorDesc* result) {
  // Note that all these nodes are necessary as the size of the batch may not be
  // constant.
  if (unstacked->stacked) {
    return errors::Internal("Can only stack unstacked tensor.");
  }

  Graph* g = outer_scope_.get();
  auto node_builder = [](StringPiece op) {
    return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
  };

  auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
                                    Node** result) {
    TF_RETURN_IF_ERROR(val.status);
    return node_builder("Const")
        .Attr("value", val.tensor)
        .Attr("dtype", val.tensor.dtype())
        .Finalize(graph, result);
  };

  // If loop_len_node_ hasn't been created yet, add the node and cache it.
  if (loop_len_node_ == nullptr) {
    Node* input_node;
    TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));

    Node* shape_node;
    TF_RETURN_IF_ERROR(
        node_builder("Shape").Input(input_node).Finalize(g, &shape_node));

    Node* const_vec_0;
    TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
    Node* const_vec_1;
    TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));

    Node* strided_slice_node;
    TF_RETURN_IF_ERROR(node_builder("StridedSlice")
                           .Input(shape_node)   // input
                           .Input(const_vec_0)  // begin
                           .Input(const_vec_1)  // end
                           .Input(const_vec_1)  // strides
                           .Finalize(g, &strided_slice_node));

    // Produces a vector of length 1
    TF_RETURN_IF_ERROR(node_builder("Reshape")
                           .Input(strided_slice_node)  // tensor
                           .Input(const_vec_1)         // shape
                           .Finalize(g, &loop_len_node_));
  }

  Node* ones_shape;
  TF_RETURN_IF_ERROR(node_builder("Shape")
                         .Input(unstacked->node)  // input
                         .Finalize(g, &ones_shape));

  Node* ones;
  TF_RETURN_IF_ERROR(
      node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));

  Node* const_0;
  TF_RETURN_IF_ERROR(make_const(0, g, &const_0));

  Node* multiples;
  TF_RETURN_IF_ERROR(node_builder("Concat")
                         .Input(const_0)                           // concat_dim
                         .Input({{loop_len_node_, 0}, {ones, 0}})  // values
                         .Finalize(g, &multiples));

  Node* expand_dims;
  TF_RETURN_IF_ERROR(node_builder("ExpandDims")
                         .Input(unstacked->node)  // input
                         .Input(const_0)          // dim
                         .Finalize(g, &expand_dims));

  TF_RETURN_IF_ERROR(node_builder("Tile")
                         .Input(expand_dims)  // input
                         .Input(multiples)    // multiples
                         .Finalize(g, &result->first));
  result->second = 0;
  return Status::OK();
}

Status Vectorization::AddArgNodeMappings() {
  // Note that inputs to map_defun_fn_ are either regular arguments (for which
  // the operations are mapped across their 0th dimension) or captured inputs
  // (for which the operations apply to the argument wholesale).
  int num_args =
      map_defun_node_->attrs().Find("Targuments")->list().type_size();

  auto add_conversion = [this](Node* arg_node, bool stacked) {
    Node* input_node;
    TF_RETURN_IF_ERROR(map_defun_node_->input_node(
        arg_node->attrs().Find("index")->i(), &input_node));

    conversion_map_.insert({{arg_node, 0}, {input_node, 0, stacked}});

    // Control inputs
    conversion_map_.insert({{arg_node, Graph::kControlSlot},
                            {input_node, Graph::kControlSlot, stacked}});

    return Status::OK();
  };

  // Regular arguments
  for (int i = 0; i < num_args; ++i) {
    TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], true));
  }

  // Captured inputs. These are applied (without slicing) to every iteration of
  // the map function, hence are mapped to unstacked nodes.
  for (int i = num_args; i < map_defun_fn_->arg_nodes.size(); ++i) {
    TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], false));
  }

  return Status::OK();
}

bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
                                                   Status* status) {
  if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
    return !found->stacked;
  }

  if (tensor.first->op_def().is_stateful()) {
    // We don't lift stateful nodes directly out of the MapDefun, since they may
    // have to be executed N times.
    return false;
  }

  bool is_unstacked = true;
  for (auto edge : tensor.first->in_edges()) {
    // Ignore Source nodes. Note that these are also ignored in the
    // GraphToFunctionDef conversion.
    if (edge->src()->IsSource()) continue;

    // A node is unstacked if all of its inputs are unstacked
    is_unstacked &= AddUnstackedNodeMappingsHelper(
        {edge->src(), edge->src_output()}, status);
  }

  if (!is_unstacked) {
    return false;
  }

  // If the node is unstacked, we copy it into outer_scope_ and
  // add it to the map. Note that we don't clean up the nodes that are copied
  // in map_defun_fn_, and rely on them being pruned out later.
  Node* node = outer_scope_->AddNode(tensor.first->def(), status);
  if (!status->ok()) return true;

  // Add input edges to nodes that should already have been lifted.
  for (auto edge : tensor.first->in_edges()) {
    // Ignore Source nodes. Note that these are also ignored in the
    // GraphToFunctionDef conversion.
    if (edge->src()->IsSource()) continue;

    if (auto found = gtl::FindOrNull(conversion_map_,
                                     {edge->src(), edge->src_output()})) {
      outer_scope_->AddEdge(found->node, found->output_index, node,
                            edge->dst_input());
    } else {
      status->Update(errors::Internal(
          "Could not find input conversion even though we did depth first "
          "conversion."));
    }
  }

  // Add output mappings
  for (int i = 0; i < tensor.first->num_outputs(); ++i) {
    conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
  }
  conversion_map_.insert({{tensor.first, Graph::kControlSlot},
                          WrappedTensor(node, Graph::kControlSlot, false)});

  return true;
}

Status Vectorization::AddUnstackedNodeMappings() {
  SetVector<Node*> unstacked_nodes;
  Status s;
  for (const auto& ret_node : map_defun_fn_->ret_nodes) {
    const Edge* in_edge = nullptr;
    TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
    AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s);
    TF_RETURN_IF_ERROR(s);
  }
  return Status::OK();
}

Status Vectorization::GetResult(FunctionDef** vectorized_function) {
  TF_RETURN_IF_ERROR(status_);
  TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get()));
  TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph));

  if (!map_defun_fn_->ret_nodes.empty()) {
    FunctionDef* map_defun_fn = lib_->add_function();
    graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
    TF_RETURN_IF_ERROR(GraphToFunctionDef(
        *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));

    AttrValue func_attr;
    func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
    map_defun_node_->AddAttr("f", func_attr);
  }

  *vectorized_function = lib_->add_function();
  graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
                                          *vectorized_function);
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *outer_scope_, (*vectorized_function)->signature().name(),
      *vectorized_function));
  return Status::OK();
}

}  // namespace

Status VectorizeMapDefun(const FunctionDef& outer_scope,
                         const NodeDef& map_defun_node, FunctionDefLibrary* lib,
                         FunctionDef** result) {
  *result = nullptr;
  return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}

}  // end namespace vectorization_utils
}  // end namespace grappler
}  // end namespace tensorflow