aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-20 14:49:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 14:52:53 -0700
commit424f0556ad8acde8f912a67e46421957a71dcef2 (patch)
tree053a1ec1c0993d4d6778ad3dda79dfd9425a0d7f
parent800cc654de0bb99c5753fc4ab26a9293547ee0b3 (diff)
[tf.data] Some vectorization cleanup
PiperOrigin-RevId: 213886813
-rw-r--r--tensorflow/core/framework/node_def_util.cc12
-rw-r--r--tensorflow/core/framework/node_def_util.h4
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc42
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc82
5 files changed, 98 insertions, 46 deletions
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 42ec315a32..43ac1d0ada 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -372,6 +372,14 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
node_def.name());
}
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs) {
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+ }
+ return Status::OK();
+}
+
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int output_port, DataType* output_type) {
DataTypeVector output_types;
@@ -397,9 +405,7 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs) {
- for (const auto& arg : op_def.input_arg()) {
- TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
- }
+ TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
return OutputTypesForNode(node_def, op_def, outputs);
}
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 7528d3d306..187bfa2c88 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -249,6 +249,10 @@ const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int input_port, DataType* input_type);
+// Computes the input types for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs);
// Computes the output type for a specific node output.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 74cc594863..d9d437024a 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) {
"Illegal op input name 'a:00");
}
+TEST(InputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_FLOAT);
+ EXPECT_EQ(types[1], DT_INT32);
+
+ DataType type;
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_FLOAT);
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_INT32);
+ EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
+TEST(OutputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_STRING);
+ EXPECT_EQ(types[1], DT_BOOL);
+
+ DataType type;
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_STRING);
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_BOOL);
+ EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
TEST(NameRangesForNodeTest, Simple) {
const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
.Input("a: float")
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index ad6722a3ae..7a2f1910da 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -86,8 +86,8 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
const FunctionDef& orig_func,
FunctionDefLibrary* library) {
- // Vectorizes orig_func naively by wrapping in a MapDefun op, then tries to
- // do true vectorization with Vectorize.
+ // Vectorizes orig_func naively by wrapping in a MapDefun op, then performing
+ // efficient vectorization with VectorizeMapDefun.
FunctionDef* vectorized_func =
CreateMapDefunWrapper(map_node, orig_func, library);
NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 5dd9d00511..bfca63b820 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
@@ -89,20 +90,13 @@ void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
->ExtractSubrange(output_position, 1, nullptr);
}
-Status ConvertCastOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, const NodeDef& cast_node,
- const FunctionDefTensorDesc& output_desc,
+Status ConvertCastOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
+ const NodeDef& cast_node,
std::map<string, string>* conversion_map) {
- if (output_desc.node_output != "y" || output_desc.position != 0) {
- // We expect the Cast node to have only one output, with the name "y".
- return errors::Internal("Cannot convert Cast op output.");
+ if (inputs.size() != 1) {
+ return errors::Internal("Cast op should only have one input.");
}
- // Promote Cast inputs to outputs of MapDefun
- DCHECK_EQ(cast_node.input_size(), 1);
- AddMapDefunOutput(map_defun_fn, map_defun_node, cast_node.input(0),
- cast_node.attr().at("SrcT").type());
-
// Add new Cast node
NodeDef* new_cast_node = outer_scope->add_node_def();
*new_cast_node = cast_node;
@@ -110,29 +104,22 @@ Status ConvertCastOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
function_utils::SetUniqueFunctionNodeName(
strings::StrCat("vectorized/", cast_node.name()), outer_scope,
new_cast_node);
- new_cast_node->set_input(
- 0, strings::StrCat(map_defun_node->name(), ":output:",
- map_defun_fn->signature().output_arg_size() - 1));
+ new_cast_node->set_input(0, inputs[0]);
// Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(output_desc.node_name, ":y:0")] =
+ (*conversion_map)[strings::StrCat(cast_node.name(), ":y:0")] =
strings::StrCat(new_cast_node->name(), ":y:0");
return Status::OK();
}
-Status ConvertUnpackOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, const NodeDef& unpack_node,
- const FunctionDefTensorDesc& output_desc,
+Status ConvertUnpackOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
+ const NodeDef& unpack_node,
std::map<string, string>* conversion_map) {
- if (output_desc.node_output != "output") {
- return errors::Internal("Cannot convert Unpack op output.");
+ if (inputs.size() != 1) {
+ return errors::Internal("Unpack op should only have one input.");
}
- // Promote Unpack inputs to outputs of MapDefun
- AddMapDefunOutput(map_defun_fn, map_defun_node, unpack_node.input(0),
- unpack_node.attr().at("T").type());
-
// Add new Unpack node
NodeDef* new_unpack_node = outer_scope->add_node_def();
*new_unpack_node = unpack_node;
@@ -144,14 +131,12 @@ Status ConvertUnpackOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
// Increment "axis" attr by 1:
(*new_unpack_node->mutable_attr())["axis"].set_i(
unpack_node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(
- 0, strings::StrCat(map_defun_node->name(), ":output:",
- map_defun_fn->signature().output_arg_size() - 1));
+ new_unpack_node->set_input(0, inputs[0]);
// Add the output mappings to conversion map
int num = new_unpack_node->attr().at("num").i();
for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(output_desc.node_name, ":output:", i)] =
+ (*conversion_map)[strings::StrCat(unpack_node.name(), ":output:", i)] =
strings::StrCat(new_unpack_node->name(), ":output:", i);
}
@@ -241,17 +226,37 @@ Status Vectorization::AddConversionMappingFromOp(
// TODO(rachelim): Have some mechanism for registering converters and some
// uniform, simpler way to represent them.
- // TODO(rachelim): Do step (1) outside of the individual op converters, when
- // we know how to find out the type of the input.
+ DataTypeVector types;
+ const OpDef* op_def = nullptr;
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
+ TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
+
+ std::vector<string> promoted_inputs;
+ promoted_inputs.reserve(node.input_size());
+ for (int i = 0; i < node.input_size(); ++i) {
+ promoted_inputs.push_back(strings::StrCat(
+ map_defun_node_->name(),
+ ":output:", map_defun_fn_->signature().output_arg_size() + i));
+ }
+
if (node.op() == "Cast") {
- return ConvertCastOp(outer_scope_, map_defun_fn_, map_defun_node_, node,
- output_desc, &conversion_map_);
+ TF_RETURN_IF_ERROR(
+ ConvertCastOp(outer_scope_, promoted_inputs, node, &conversion_map_));
} else if (node.op() == "Unpack") {
- return ConvertUnpackOp(outer_scope_, map_defun_fn_, map_defun_node_, node,
- output_desc, &conversion_map_);
+ TF_RETURN_IF_ERROR(
+ ConvertUnpackOp(outer_scope_, promoted_inputs, node, &conversion_map_));
+ } else {
+ return errors::Unimplemented("Op converter for \"", node.op(),
+ "\" not implemented yet");
}
- return errors::Unimplemented("Op converter for \"", node.op(),
- "\" not implemented yet");
+
+ // If we get here, the conversion was successful, so we promote the inputs
+ // of the ops to MapDefun outputs.
+ for (int i = 0; i < types.size(); ++i) {
+ AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ }
+
+ return Status::OK();
}
Status Vectorization::AddConversionMappingFromInput(
@@ -333,11 +338,6 @@ void Vectorization::Vectorize() {
void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
NodeDef* map_defun_node) {
- if (map_defun_node->attr().at("f").func().name() !=
- map_defun_fn->signature().name()) {
- LOG(ERROR) << "`map_defun_fn` and `map_defun_node` do not match";
- return;
- }
Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
}