diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc index 74ce520ce1..f1ba741821 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc @@ -19,15 +19,15 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class UnpackVectorizer : public Vectorizer { public: Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* input_ports, - std::vector<Port>* output_ports) override { + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { Status s; - if (node.num_inputs() != 1) { + if (node.num_inputs() != 1 || inputs.size() != 1) { return errors::Internal("Unpack op should only have one input."); } @@ -39,13 +39,13 @@ class UnpackVectorizer : public Vectorizer { int new_axis = node.def().attr().at("axis").i() + 1; new_unpack_node->AddAttr("axis", new_axis); - // Add the input mappings - input_ports->push_back({new_unpack_node, 0}); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, + new_unpack_node, 0); // Add the output mappings int num = node.def().attr().at("num").i(); for (int i = 0; i < num; ++i) { - output_ports->push_back({new_unpack_node, i}); + outputs->push_back({new_unpack_node, i, true}); } return Status::OK(); @@ -54,6 +54,6 @@ class UnpackVectorizer : public Vectorizer { REGISTER_VECTORIZER("Unpack", UnpackVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow |