diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-20 14:49:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 14:52:53 -0700 |
commit | 424f0556ad8acde8f912a67e46421957a71dcef2 (patch) | |
tree | 053a1ec1c0993d4d6778ad3dda79dfd9425a0d7f /tensorflow | |
parent | 800cc654de0bb99c5753fc4ab26a9293547ee0b3 (diff) |
[tf.data] Some vectorization cleanup
PiperOrigin-RevId: 213886813
Diffstat (limited to 'tensorflow')
5 files changed, 98 insertions, 46 deletions
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 42ec315a32..43ac1d0ada 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -372,6 +372,14 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, node_def.name()); } +Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs) { + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); + } + return Status::OK(); +} + Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, int output_port, DataType* output_type) { DataTypeVector output_types; @@ -397,9 +405,7 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, DataTypeVector* inputs, DataTypeVector* outputs) { - for (const auto& arg : op_def.input_arg()) { - TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); - } + TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs)); return OutputTypesForNode(node_def, op_def, outputs); } diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 7528d3d306..187bfa2c88 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -249,6 +249,10 @@ const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name); // REQUIRES: ValidateOpDef(op_def).ok() Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, int input_port, DataType* input_type); +// Computes the input types for a specific node. +// REQUIRES: ValidateOpDef(op_def).ok() +Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs); // Computes the output type for a specific node output. // REQUIRES: ValidateOpDef(op_def).ok() Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 74cc594863..d9d437024a 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) { "Illegal op input name 'a:00"); } +TEST(InputTypesForNode, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + DataTypeVector types; + EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok()); + EXPECT_EQ(types[0], DT_FLOAT); + EXPECT_EQ(types[1], DT_INT32); + + DataType type; + EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok()); + EXPECT_EQ(type, DT_FLOAT); + EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok()); + EXPECT_EQ(type, DT_INT32); + EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok()); +} + +TEST(OutputTypesForNode, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + DataTypeVector types; + EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok()); + EXPECT_EQ(types[0], DT_STRING); + EXPECT_EQ(types[1], DT_BOOL); + + DataType type; + EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok()); + EXPECT_EQ(type, DT_STRING); + EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok()); + EXPECT_EQ(type, DT_BOOL); + EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok()); +} + TEST(NameRangesForNodeTest, Simple) { const OpDef op_def = ToOpDef(OpDefBuilder("Simple") .Input("a: float") diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index ad6722a3ae..7a2f1910da 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -86,8 +86,8 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, FunctionDef* AddVectorizedFunction(const NodeDef& map_node, const FunctionDef& orig_func, FunctionDefLibrary* library) { - // Vectorizes orig_func naively by wrapping in a MapDefun op, then tries to - // do true vectorization with Vectorize. + // Vectorizes orig_func naively by wrapping in a MapDefun op, then performing + // efficient vectorization with VectorizeMapDefun. FunctionDef* vectorized_func = CreateMapDefunWrapper(map_node, orig_func, library); NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0); diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index 5dd9d00511..bfca63b820 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/mutable_graph_view.h" @@ -89,20 +90,13 @@ void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn, ->ExtractSubrange(output_position, 1, nullptr); } -Status ConvertCastOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node, const NodeDef& cast_node, - const FunctionDefTensorDesc& output_desc, +Status ConvertCastOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs, + const NodeDef& cast_node, std::map<string, string>* conversion_map) { - if (output_desc.node_output != "y" || output_desc.position != 0) { - // We expect the Cast node to have only one output, with the name "y". - return errors::Internal("Cannot convert Cast op output."); + if (inputs.size() != 1) { + return errors::Internal("Cast op should only have one input."); } - // Promote Cast inputs to outputs of MapDefun - DCHECK_EQ(cast_node.input_size(), 1); - AddMapDefunOutput(map_defun_fn, map_defun_node, cast_node.input(0), - cast_node.attr().at("SrcT").type()); - // Add new Cast node NodeDef* new_cast_node = outer_scope->add_node_def(); *new_cast_node = cast_node; @@ -110,29 +104,22 @@ Status ConvertCastOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn, function_utils::SetUniqueFunctionNodeName( strings::StrCat("vectorized/", cast_node.name()), outer_scope, new_cast_node); - new_cast_node->set_input( - 0, strings::StrCat(map_defun_node->name(), ":output:", - map_defun_fn->signature().output_arg_size() - 1)); + new_cast_node->set_input(0, inputs[0]); // Add the output mapping to conversion map - (*conversion_map)[strings::StrCat(output_desc.node_name, ":y:0")] = + (*conversion_map)[strings::StrCat(cast_node.name(), ":y:0")] = strings::StrCat(new_cast_node->name(), ":y:0"); return Status::OK(); } -Status ConvertUnpackOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node, const NodeDef& unpack_node, - const FunctionDefTensorDesc& output_desc, +Status ConvertUnpackOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs, + const NodeDef& unpack_node, std::map<string, string>* conversion_map) { - if (output_desc.node_output != "output") { - return errors::Internal("Cannot convert Unpack op output."); + if (inputs.size() != 1) { + return errors::Internal("Unpack op should only have one input."); } - // Promote Unpack inputs to outputs of MapDefun - AddMapDefunOutput(map_defun_fn, map_defun_node, unpack_node.input(0), - unpack_node.attr().at("T").type()); - // Add new Unpack node NodeDef* new_unpack_node = outer_scope->add_node_def(); *new_unpack_node = unpack_node; @@ -144,14 +131,12 @@ Status ConvertUnpackOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn, // Increment "axis" attr by 1: (*new_unpack_node->mutable_attr())["axis"].set_i( unpack_node.attr().at("axis").i() + 1); - new_unpack_node->set_input( - 0, strings::StrCat(map_defun_node->name(), ":output:", - map_defun_fn->signature().output_arg_size() - 1)); + new_unpack_node->set_input(0, inputs[0]); // Add the output mappings to conversion map int num = new_unpack_node->attr().at("num").i(); for (int i = 0; i < num; ++i) { - (*conversion_map)[strings::StrCat(output_desc.node_name, ":output:", i)] = + (*conversion_map)[strings::StrCat(unpack_node.name(), ":output:", i)] = strings::StrCat(new_unpack_node->name(), ":output:", i); } @@ -241,17 +226,37 @@ Status Vectorization::AddConversionMappingFromOp( // TODO(rachelim): Have some mechanism for registering converters and some // uniform, simpler way to represent them. - // TODO(rachelim): Do step (1) outside of the individual op converters, when - // we know how to find out the type of the input. + DataTypeVector types; + const OpDef* op_def = nullptr; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def)); + TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types)); + + std::vector<string> promoted_inputs; + promoted_inputs.reserve(node.input_size()); + for (int i = 0; i < node.input_size(); ++i) { + promoted_inputs.push_back(strings::StrCat( + map_defun_node_->name(), + ":output:", map_defun_fn_->signature().output_arg_size() + i)); + } + if (node.op() == "Cast") { - return ConvertCastOp(outer_scope_, map_defun_fn_, map_defun_node_, node, - output_desc, &conversion_map_); + TF_RETURN_IF_ERROR( + ConvertCastOp(outer_scope_, promoted_inputs, node, &conversion_map_)); } else if (node.op() == "Unpack") { - return ConvertUnpackOp(outer_scope_, map_defun_fn_, map_defun_node_, node, - output_desc, &conversion_map_); + TF_RETURN_IF_ERROR( + ConvertUnpackOp(outer_scope_, promoted_inputs, node, &conversion_map_)); + } else { + return errors::Unimplemented("Op converter for \"", node.op(), + "\" not implemented yet"); } - return errors::Unimplemented("Op converter for \"", node.op(), - "\" not implemented yet"); + + // If we get here, the conversion was successful, so we promote the inputs + // of the ops to MapDefun outputs. + for (int i = 0; i < types.size(); ++i) { + AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]); + } + + return Status::OK(); } Status Vectorization::AddConversionMappingFromInput( @@ -333,11 +338,6 @@ void Vectorization::Vectorize() { void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, NodeDef* map_defun_node) { - if (map_defun_node->attr().at("f").func().name() != - map_defun_fn->signature().name()) { - LOG(ERROR) << "`map_defun_fn` and `map_defun_node` do not match"; - return; - } Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize(); } |