diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc index 3af6bab409..f445157531 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc @@ -19,13 +19,13 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class CastVectorizer : public Vectorizer { public: Status Vectorize(const Node& node, Graph* outer_scope, - std::vector<Port>* input_ports, - std::vector<Port>* output_ports) override { + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { Status s; if (node.num_inputs() != 1) { return errors::Internal("Cast op should only have one input."); @@ -35,15 +35,17 @@ class CastVectorizer : public Vectorizer { auto new_cast_node = outer_scope->AddNode(node.def(), &s); TF_RETURN_IF_ERROR(s); - // Add input and output mappings - input_ports->push_back({new_cast_node, 0}); - output_ports->push_back({new_cast_node, 0}); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node, + 0); + + // Add output mappings + outputs->push_back({new_cast_node, 0, true}); return Status::OK(); } }; REGISTER_VECTORIZER("Cast", CastVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow |