aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc16
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index 3af6bab409..f445157531 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -19,13 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class CastVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
@@ -35,15 +35,17 @@ class CastVectorizer : public Vectorizer {
auto new_cast_node = outer_scope->AddNode(node.def(), &s);
TF_RETURN_IF_ERROR(s);
- // Add input and output mappings
- input_ports->push_back({new_cast_node, 0});
- output_ports->push_back({new_cast_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node,
+ 0);
+
+ // Add output mappings
+ outputs->push_back({new_cast_node, 0, true});
return Status::OK();
}
};
REGISTER_VECTORIZER("Cast", CastVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow