aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/shape_refiner.h
blob: 2d04ea15055015e0c4e59260ec0f0ef22551dd4b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_

#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// ShapeRefiner performs shape inference for TensorFlow Graphs.  It is
// responsible for instantiating InferenceContext objects for each
// Node in the Graph, and providing/storing the 'input_tensor' Tensors
// used by Shape Inference functions, when available at graph
// construction time.
//
// The class is neither copyable nor assignable (see
// TF_DISALLOW_COPY_AND_ASSIGN below).
class ShapeRefiner {
 public:
  // 'graph_def_version' is the initial value reported by graph_def_version().
  // 'ops' is stored as a raw, non-owning pointer.
  // NOTE(review): 'ops' presumably must outlive this object -- confirm
  // against the implementation.
  ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
  ~ShapeRefiner();

  // Performs validation of 'node' and runs 'node's shape function,
  // storing its shape outputs.
  //
  // All inputs of 'node' must be added to ShapeRefiner prior to
  // adding 'node'.
  //
  // Returns an error if:
  //  - the shape function for 'node' was not registered.
  //  - 'node' was added before its inputs.
  //  - The shape inference function returns an error.
  Status AddNode(const Node* node);

  // Sets 'node's 'output_port' output to have shape 'shape'.
  //
  // Returns an error if 'node' was not previously added to this
  // object, if 'output_port' is invalid, or if 'shape' is
  // not compatible with the existing shape of the output.
  Status SetShape(const Node* node, int output_port,
                  shape_inference::ShapeHandle shape);

  // Returns the InferenceContext for 'node', or nullptr if no context is
  // stored for 'node'.  The returned pointer is non-owning: the context
  // remains owned by this object.
  shape_inference::InferenceContext* GetContext(const Node* node) const {
    auto it = node_to_context_.find(node);
    if (it == node_to_context_.end()) {
      return nullptr;
    }
    return it->second.get();
  }

  // Getters and setters for graph_def_version_.
  //
  // The getter is const-qualified so that it can be called through a const
  // ShapeRefiner reference or pointer.
  int32 graph_def_version() const { return graph_def_version_; }
  void set_graph_def_version(int32 version) { graph_def_version_ = version; }

  // Sets the require_shape_inference_fns_ flag, which defaults to true.
  // NOTE(review): presumably this controls whether AddNode() fails for ops
  // that have no registered shape function -- confirm against the
  // implementation.
  void set_require_shape_inference_fns(bool require_shape_inference_fns) {
    require_shape_inference_fns_ = require_shape_inference_fns;
  }

 private:
  // Extracts the subgraph ending at 'node' that is statically
  // computable and inserts into 'out_graph'. If statically computable,
  // 'is_constant_graph' will be true.
  //
  // NOTE(review): 'const_inputs' looks like an output list of
  // (name, tensor) pairs to feed when running 'out_graph' -- confirm
  // against the implementation.
  //
  // The returned Status must not be ignored (TF_MUST_USE_RESULT).
  Status ExtractConstantSubgraph(
      Node* node, Graph* out_graph, bool* is_constant_graph,
      std::vector<std::pair<string, Tensor>>* const_inputs) TF_MUST_USE_RESULT;

  // Attempts to evaluate the 'dst_idx' input of 'node' as a constant tensor.
  // NOTE(review): presumably '*evaluated' reports whether evaluation
  // succeeded and '*result' holds the tensor only in that case -- confirm
  // against the implementation.
  Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
                                       bool* evaluated, Tensor* result);

  // This function tries to materialize as much information about the 'node''s
  // dst_idx input as a statically computable shape, and the result may be
  // partially known, depending on what is statically inferable.
  //
  // This is called when node.input[dst_idx] is a tensor that is used to define
  // the shape of some other tensor (e.g., the second argument to Reshape is a
  // <shape> tensor, where each element of the shape tensor is a dimension of
  // the target tensor).  It returns in <result> a shape for that input.
  //
  // Unlike simply resolving node.input[dst_idx] to a constant and then
  // converting that to a shape, this function can return a partial shape. This
  // is useful for cases where the shape tensor is only partially defined, such
  // as with calls for: reshape(x, shape(y)) where shape(y) is partially
  // defined.
  //
  // The implementation has op implementations for ops commonly called on shape
  // tensors, and the implementations are specialized to shape tensors (namely,
  // the output is a vector).
  //
  // <target_context> is used when creating new DimensionHandle and ShapeHandle
  // objects.
  Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
                              const Node* node, int dst_idx,
                              shape_inference::ShapeHandle* result);

  int32 graph_def_version_;
  const OpRegistryInterface* const ops_registry_;

  // The lifetime of the tensors is bound to the runner, so the runner must be
  // destroyed after the tensors.  Non-static data members are destroyed in
  // reverse order of declaration, so this member must stay declared before
  // any member that holds such tensors (e.g. const_tensor_map_).
  GraphRunner graph_runner_;

  // Stores a map from a node to its InferenceContext.
  //
  // Owns values.
  std::unordered_map<const Node*,
                     std::unique_ptr<shape_inference::InferenceContext>>
      node_to_context_;

  // Holds a cache from 'tensor name' to the tensor that is
  // evaluatable as a constant expression.  This reduces repeated
  // execution of the entire constant subgraph as a graph is being
  // built up.  This could be changed to some kind of size-based LRU
  // cache to avoid consuming too much memory, if that eventually
  // becomes a concern.
  //
  // Only tensors less than 1KiB are currently stored in the cache.
  static constexpr int64 kMaxTensorSize = 1024;
  std::unordered_map<string, Tensor> const_tensor_map_;

  bool require_shape_inference_fns_ = true;

  TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner);
};

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_