diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h | 19 |
1 files changed, 8 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h index 56eb88c95e..8d4676aae0 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h @@ -18,15 +18,12 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace grappler { -namespace vectorization_utils { - -// Describes a tensor with its operation Node and output position -typedef std::pair<Node*, int> Port; // Interface for vectorization of TensorFlow operations. See `CastVectorizer` // for an example. @@ -36,17 +33,17 @@ class Vectorizer { // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope` // that produce the same vector output(s) as executing `node`'s op - // on elements of the vector inputs. The new Node(s) collectively have the + // on elements of `inputs`. The new Node(s) collectively have the // same number of input and output ports as the node being converted. - // Adds mappings for the new nodes' input and output ports to `inputs` and - // `outputs` respectively, where the i'th Port in inputs/outputs - // corresponds to the i'th input/output port of the node to be converted. + // Adds edges between the newly created nodes and nodes in `inputs`, and adds + // mappings to the new nodes' output ports to `outputs`, where the i'th + // value in `outputs` corresponds to the i'th output port of the node + // to be converted. virtual Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* input_ports, - std::vector<Port>* output_ports) = 0; + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) = 0; }; -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ |