aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-20 14:49:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 14:52:53 -0700
commit424f0556ad8acde8f912a67e46421957a71dcef2 (patch)
tree053a1ec1c0993d4d6778ad3dda79dfd9425a0d7f /tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
parent800cc654de0bb99c5753fc4ab26a9293547ee0b3 (diff)
[tf.data] Some vectorization cleanup
PiperOrigin-RevId: 213886813
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc82
1 files changed, 41 insertions, 41 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 5dd9d00511..bfca63b820 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
@@ -89,20 +90,13 @@ void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
->ExtractSubrange(output_position, 1, nullptr);
}
-Status ConvertCastOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, const NodeDef& cast_node,
- const FunctionDefTensorDesc& output_desc,
+Status ConvertCastOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
+ const NodeDef& cast_node,
std::map<string, string>* conversion_map) {
- if (output_desc.node_output != "y" || output_desc.position != 0) {
- // We expect the Cast node to have only one output, with the name "y".
- return errors::Internal("Cannot convert Cast op output.");
+ if (inputs.size() != 1) {
+ return errors::Internal("Cast op should only have one input.");
}
- // Promote Cast inputs to outputs of MapDefun
- DCHECK_EQ(cast_node.input_size(), 1);
- AddMapDefunOutput(map_defun_fn, map_defun_node, cast_node.input(0),
- cast_node.attr().at("SrcT").type());
-
// Add new Cast node
NodeDef* new_cast_node = outer_scope->add_node_def();
*new_cast_node = cast_node;
@@ -110,29 +104,22 @@ Status ConvertCastOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
function_utils::SetUniqueFunctionNodeName(
strings::StrCat("vectorized/", cast_node.name()), outer_scope,
new_cast_node);
- new_cast_node->set_input(
- 0, strings::StrCat(map_defun_node->name(), ":output:",
- map_defun_fn->signature().output_arg_size() - 1));
+ new_cast_node->set_input(0, inputs[0]);
// Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(output_desc.node_name, ":y:0")] =
+ (*conversion_map)[strings::StrCat(cast_node.name(), ":y:0")] =
strings::StrCat(new_cast_node->name(), ":y:0");
return Status::OK();
}
-Status ConvertUnpackOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, const NodeDef& unpack_node,
- const FunctionDefTensorDesc& output_desc,
+Status ConvertUnpackOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
+ const NodeDef& unpack_node,
std::map<string, string>* conversion_map) {
- if (output_desc.node_output != "output") {
- return errors::Internal("Cannot convert Unpack op output.");
+ if (inputs.size() != 1) {
+ return errors::Internal("Unpack op should only have one input.");
}
- // Promote Unpack inputs to outputs of MapDefun
- AddMapDefunOutput(map_defun_fn, map_defun_node, unpack_node.input(0),
- unpack_node.attr().at("T").type());
-
// Add new Unpack node
NodeDef* new_unpack_node = outer_scope->add_node_def();
*new_unpack_node = unpack_node;
@@ -144,14 +131,12 @@ Status ConvertUnpackOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
// Increment "axis" attr by 1:
(*new_unpack_node->mutable_attr())["axis"].set_i(
unpack_node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(
- 0, strings::StrCat(map_defun_node->name(), ":output:",
- map_defun_fn->signature().output_arg_size() - 1));
+ new_unpack_node->set_input(0, inputs[0]);
// Add the output mappings to conversion map
int num = new_unpack_node->attr().at("num").i();
for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(output_desc.node_name, ":output:", i)] =
+ (*conversion_map)[strings::StrCat(unpack_node.name(), ":output:", i)] =
strings::StrCat(new_unpack_node->name(), ":output:", i);
}
@@ -241,17 +226,37 @@ Status Vectorization::AddConversionMappingFromOp(
// TODO(rachelim): Have some mechanism for registering converters and some
// uniform, simpler way to represent them.
- // TODO(rachelim): Do step (1) outside of the individual op converters, when
- // we know how to find out the type of the input.
+ DataTypeVector types;
+ const OpDef* op_def = nullptr;
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
+ TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
+
+ std::vector<string> promoted_inputs;
+ promoted_inputs.reserve(node.input_size());
+ for (int i = 0; i < node.input_size(); ++i) {
+ promoted_inputs.push_back(strings::StrCat(
+ map_defun_node_->name(),
+ ":output:", map_defun_fn_->signature().output_arg_size() + i));
+ }
+
if (node.op() == "Cast") {
- return ConvertCastOp(outer_scope_, map_defun_fn_, map_defun_node_, node,
- output_desc, &conversion_map_);
+ TF_RETURN_IF_ERROR(
+ ConvertCastOp(outer_scope_, promoted_inputs, node, &conversion_map_));
} else if (node.op() == "Unpack") {
- return ConvertUnpackOp(outer_scope_, map_defun_fn_, map_defun_node_, node,
- output_desc, &conversion_map_);
+ TF_RETURN_IF_ERROR(
+ ConvertUnpackOp(outer_scope_, promoted_inputs, node, &conversion_map_));
+ } else {
+ return errors::Unimplemented("Op converter for \"", node.op(),
+ "\" not implemented yet");
}
- return errors::Unimplemented("Op converter for \"", node.op(),
- "\" not implemented yet");
+
+ // If we get here, the conversion was successful, so we promote the inputs
+ // of the ops to MapDefun outputs.
+ for (int i = 0; i < types.size(); ++i) {
+ AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ }
+
+ return Status::OK();
}
Status Vectorization::AddConversionMappingFromInput(
@@ -333,11 +338,6 @@ void Vectorization::Vectorize() {
void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
NodeDef* map_defun_node) {
- if (map_defun_node->attr().at("f").func().name() !=
- map_defun_fn->signature().name()) {
- LOG(ERROR) << "`map_defun_fn` and `map_defun_node` do not match";
- return;
- }
Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
}