diff options
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_tfconversion_pass.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_tfconversion_pass_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/graph/testlib.cc | 27 | ||||
-rw-r--r-- | tensorflow/core/graph/testlib.h | 11 |
8 files changed, 61 insertions, 9 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index ee10194142..eeb5c14eaa 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -1042,12 +1042,12 @@ Status GraphConstructor::Convert() { } if (processed < node_defs_.size()) { - LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed) + LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed) << " NODES IN A CYCLE"; for (int64 i = 0; i < node_defs_.size(); i++) { if (pending_count_[i] != 0) { LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i]) - << "WITH PENDING COUNT = " << pending_count_[i]; + << " WITH PENDING COUNT = " << pending_count_[i]; } } return errors::InvalidArgument(node_defs_.size() - processed, @@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { const NodeDef* node_def = node_defs_[pair->second.gdef_index]; const OpDef* op_def; TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); - if (key.second >= op_def->output_arg_size()) { + int num_outputs; + TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs)); + if (key.second >= num_outputs) { // key's index out of bounds missing_unused_input_map_keys_->push_back(key); } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 73142ebde7..3eef6bd2bd 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -199,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput") .Output("y: T") .Attr("T: {float, int64}") .SetShapeFn(shape_inference::UnchangedShape); +REGISTER_OP("TestVariadicOutput") + .Output("outputs: N * int32") + .Attr("N: int >= 0") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("TestDefaultAttr") .Attr("default_int: int=31415") .SetShapeFn(shape_inference::NoOutputs); @@ -1463,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) { opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0); // Unused but not missing opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0); + // Unused but not missing + opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0); ExpectOK( R"EOF( node { name: 'W2' op: 'TestParams' } node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] } node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] } - node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] } + node { name: 'variadic' op: 'TestVariadicOutput' + attr { key: "N" value { i: 5 } } } )EOF", opts, &refiner, &results); diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 50fd6bae12..06d3fefef1 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" @@ -976,7 +977,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_; // nodes. Do not change the ordering of the Mkl passes. const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = OptimizationPassRegistry::POST_PARTITIONING; +#ifdef ENABLE_MKL REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); +#endif // ENABLE_MKL ////////////////////////////////////////////////////////////////////////// // Helper functions for creating new node @@ -3155,7 +3158,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; // nodes. Do not change the ordering of the Mkl passes. const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = OptimizationPassRegistry::POST_PARTITIONING; +#ifdef ENABLE_MKL REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); +#endif // ENABLE_MKL ////////////////////////////////////////////////////////////////////////// // Helper functions for creating new node diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index 7f96a18023..77640e287c 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" @@ -3606,4 +3606,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); } // namespace tensorflow -#endif /* INTEL_MKL */ +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index b67a321fc1..8c5ffd71a3 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -133,7 +133,9 @@ class MklToTfConversionPass : public GraphOptimizationPass { // complete picture of inputs and outputs of the nodes in the graphs. const OptimizationPassRegistry::Grouping kMklTfConvPassGroup = OptimizationPassRegistry::POST_PARTITIONING; +#ifdef ENABLE_MKL REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass); +#endif // ENABLE_MKL Status MklToTfConversionPass::InsertConversionNodeOnEdge( std::unique_ptr<Graph>* g, Edge* e) { diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc index ebcb6de551..319437a801 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #include "tensorflow/core/graph/mkl_tfconversion_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" @@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000); } // namespace } // namespace tensorflow -#endif /* INTEL_MKL */ +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index ea7788f654..0a38aa1c91 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) { return ret; } +Node* CheckNumerics(Graph* g, Node* in, const string& message) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics") + .Input(in) + .Attr("message", message) + .Finalize(g, &ret)); + return ret; +} + +Node* Arg(Graph* g, int64 index, DataType type) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg") + .Attr("T", type) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + +Node* Retval(Graph* g, int64 index, Node* in) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval") + .Input(in) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } } // end namespace graph diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h index 8585b35a19..b00196f587 100644 --- a/tensorflow/core/graph/testlib.h +++ b/tensorflow/core/graph/testlib.h @@ -32,7 +32,7 @@ namespace test { namespace graph { // Converts "g" into its corresponding GraphDef "def". -// DEPRECATED: call g->ToGraphDef(def) instead. +ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.") void ToGraphDef(Graph* g, GraphDef* def); // A few helpers to construct a graph. @@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type); // Add a DiagPart node in "g". Node* DiagPart(Graph* g, Node* in, DataType type); +// Add a CheckNumerics node in "g". +Node* CheckNumerics(Graph* g, Node* in, const string& message); + +// Add an _Arg node in "g". +Node* Arg(Graph* g, int64 index, DataType type); + +// Add a _Retval node in "g". +Node* Retval(Graph* g, int64 index, Node* in); + } // end namespace graph } // end namespace test } // end namespace tensorflow |