aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/graph_constructor.cc8
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc9
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc5
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc4
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc2
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc4
-rw-r--r--tensorflow/core/graph/testlib.cc27
-rw-r--r--tensorflow/core/graph/testlib.h11
8 files changed, 61 insertions, 9 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index ee10194142..eeb5c14eaa 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -1042,12 +1042,12 @@ Status GraphConstructor::Convert() {
}
if (processed < node_defs_.size()) {
- LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed)
+ LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed)
<< " NODES IN A CYCLE";
for (int64 i = 0; i < node_defs_.size(); i++) {
if (pending_count_[i] != 0) {
LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i])
- << "WITH PENDING COUNT = " << pending_count_[i];
+ << " WITH PENDING COUNT = " << pending_count_[i];
}
}
return errors::InvalidArgument(node_defs_.size() - processed,
@@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
const NodeDef* node_def = node_defs_[pair->second.gdef_index];
const OpDef* op_def;
TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
- if (key.second >= op_def->output_arg_size()) {
+ int num_outputs;
+ TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
+ if (key.second >= num_outputs) {
// key's index out of bounds
missing_unused_input_map_keys_->push_back(key);
}
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 73142ebde7..3eef6bd2bd 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -199,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput")
.Output("y: T")
.Attr("T: {float, int64}")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("TestVariadicOutput")
+ .Output("outputs: N * int32")
+ .Attr("N: int >= 0")
+ .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("TestDefaultAttr")
.Attr("default_int: int=31415")
.SetShapeFn(shape_inference::NoOutputs);
@@ -1463,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) {
opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0);
// Unused but not missing
opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0);
+ // Unused but not missing
+ opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0);
ExpectOK(
R"EOF(
node { name: 'W2' op: 'TestParams' }
node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
- node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
+ node { name: 'variadic' op: 'TestVariadicOutput'
+ attr { key: "N" value { i: 5 } } }
)EOF",
opts, &refiner, &results);
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 50fd6bae12..06d3fefef1 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -976,7 +977,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -3155,7 +3158,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 7f96a18023..77640e287c 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -3606,4 +3606,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index b67a321fc1..8c5ffd71a3 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -133,7 +133,9 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
+#endif // ENABLE_MKL
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
std::unique_ptr<Graph>* g, Edge* e) {
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index ebcb6de551..319437a801 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);
} // namespace
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index ea7788f654..0a38aa1c91 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) {
return ret;
}
+Node* CheckNumerics(Graph* g, Node* in, const string& message) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
+ .Input(in)
+ .Attr("message", message)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Arg(Graph* g, int64 index, DataType type) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
+ .Attr("T", type)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Retval(Graph* g, int64 index, Node* in) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
+ .Input(in)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
} // end namespace graph
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 8585b35a19..b00196f587 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -32,7 +32,7 @@ namespace test {
namespace graph {
// Converts "g" into its corresponding GraphDef "def".
-// DEPRECATED: call g->ToGraphDef(def) instead.
+ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.")
void ToGraphDef(Graph* g, GraphDef* def);
// A few helpers to construct a graph.
@@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type);
// Add a DiagPart node in "g".
Node* DiagPart(Graph* g, Node* in, DataType type);
+// Add a CheckNumerics node in "g".
+Node* CheckNumerics(Graph* g, Node* in, const string& message);
+
+// Add an _Arg node in "g".
+Node* Arg(Graph* g, int64 index, DataType type);
+
+// Add a _Retval node in "g".
+Node* Retval(Graph* g, int64 index, Node* in);
+
} // end namespace graph
} // end namespace test
} // end namespace tensorflow