diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-08-15 16:46:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-15 16:50:32 -0700 |
commit | 477d49c9eaafc5e1e1667d454ce5883956180713 (patch) | |
tree | 02ef20e777b6d686a6fa97327f761b40ff91c738 | |
parent | 9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff) |
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code:
AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) {
if (!scope.ok()) return;
auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs);
if (!scope.ok()) return;
::tensorflow::Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("AddN");
auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN")
.Input(_inputs)
;
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
scope.UpdateStatus(scope.DoShapeInference(ret));
this->sum = Output(ret, 0);
}
Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it.
PiperOrigin-RevId: 165378429
26 files changed, 143 insertions, 50 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 80dd272f6f..38a17598b8 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -812,12 +812,8 @@ string OpInfo::GetConstructorBody() const { strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(", scope_str, ".graph(), &ret));\n"); strings::StrAppend(&body, " ", return_on_error, "\n"); - - // TODO(b/28152992): Enable this code-path once we have converted - // all python shape functions to call their C++ versions. - - // strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str, - // ".refiner()->AddNode(ret));\n"); + strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str, + ".DoShapeInference(ret));\n"); GetOutput(&body); return body; diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 4705b6b7e8..7164249262 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -37,13 +37,14 @@ Scope& Scope::operator=(const Scope& other) { } Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, - ShapeRefiner* refiner) + ShapeRefiner* refiner, bool disable_shape_inference) : graph_(graph), status_(status), name_map_(name_map), refiner_(refiner), scope_used_(nullptr), - colocation_constraints_() {} + colocation_constraints_(), + disable_shape_inference_(disable_shape_inference) {} Scope::Impl::Impl(const std::shared_ptr<Graph>& graph, const std::shared_ptr<Status>& status, @@ -54,13 +55,23 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph, name_map_(name_map), refiner_(refiner), scope_used_(nullptr), - colocation_constraints_() {} + colocation_constraints_(), + disable_shape_inference_(false) {} Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); ShapeRefiner* refiner = new ShapeRefiner(graph->versions(), graph->op_registry()); - return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner)); + return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner, + /* disable_shape_inference */ false)); +} + +Scope Scope::DisabledShapeInferenceScope() { + Graph* graph = new Graph(OpRegistry::Global()); + ShapeRefiner* refiner = + new ShapeRefiner(graph->versions(), graph->op_registry()); + return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner, + /* disable_shape_inference */ true)); } Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, @@ -77,7 +88,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, const string& op_name) @@ -92,7 +104,8 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::ControlDeps, std::vector<Operation> control_deps, bool clear_control_deps) @@ -113,7 +126,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device) : graph_(other.impl()->graph_), @@ -127,7 +141,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device) exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(device), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, const string& op_name) @@ -142,7 +157,8 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) : graph_(other.impl()->graph_), @@ -156,7 +172,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) exit_on_error_(true), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label) @@ -171,7 +188,8 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(kernel_label), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op, bool clear_colocations) @@ -189,7 +207,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, colocation_constraints_( clear_colocations ? std::unordered_set<string>() - : other.impl()->GetColocationConstraints(colocate_with_op)) {} + : other.impl()->GetColocationConstraints(colocate_with_op)), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} std::unordered_set<string> Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { @@ -404,6 +423,11 @@ CompositeOpScopes Scope::GetCompositeOpScopes( } } +Status Scope::DoShapeInference(Node* node) const { + if (impl_->disable_shape_inference_) return Status::OK(); + return impl_->refiner_->AddNode(node); +} + class InternalScope { public: // NewScope doesn't take ownership of the inputs. diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index ec3543772d..5cae5c64ad 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -199,6 +199,18 @@ class Scope { // edges from the source and to the sink node, resolves back edges // by name), and makes sure the resulting graph is valid. Status ToGraph(Graph* g) const; + + // Calls AddNode() using this scope's ShapeRefiner. This exists in the public + // API to prevent custom op wrappers from needing access to shape_refiner.h or + // scope_internal.h. + // TODO(skyewm): remove this from public API + Status DoShapeInference(Node* node) const; + + // Creates a new root scope that causes all DoShapeInference() calls to return + // Status::OK() (on the returned scope and any subscopes). Used for testing. + // TODO(skyewm): fix tests that still require this and eventually remove, or + // at least remove from public API + static Scope DisabledShapeInferenceScope(); // END_SKIP_DOXYGEN const std::vector<Operation>& control_deps() const; diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 3656c0aecf..e2cc22af5d 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -58,7 +58,8 @@ class Scope::Impl { enum class Colocate; }; - Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner); + Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, + bool disable_shape_inference); Impl(const Scope& other, Tags::ScopeName, const string& name, bool copy_names); Impl(const Scope& other, Tags::OpName, const string& name, @@ -103,6 +104,10 @@ class Scope::Impl { const string kernel_label_ = ""; const string device_ = ""; const std::unordered_set<string> colocation_constraints_; + + // If true, Scope::DoShapeInference() always returns Status:OK(). + // TODO(skyewm): remove this when possible + const bool disable_shape_inference_; }; } // namespace tensorflow diff --git a/tensorflow/cc/framework/test_op.cc b/tensorflow/cc/framework/test_op.cc index fe0d907df0..b76842a9a0 100644 --- a/tensorflow/cc/framework/test_op.cc +++ b/tensorflow/cc/framework/test_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" namespace tensorflow { @@ -24,6 +25,7 @@ REGISTER_OP("ThrowAway1") .Attr("scope: int") .Attr("builder: int = 1") .Attr("while: int") + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Op to test keywords and reserved words in input and attr names. @@ -36,12 +38,20 @@ REGISTER_OP("ThrowAway2") .Attr("scope: int = 2") .Attr("throw_away2: int = 2") .Attr("attrs: int = 4") - .Attr("node: int = 4"); + .Attr("node: int = 4") + .SetShapeFn(shape_inference::UnknownShape); -REGISTER_OP("ThrowAway3").Output("node: int32"); +REGISTER_OP("ThrowAway3") + .Output("node: int32") + .SetShapeFn(shape_inference::UnknownShape); -REGISTER_OP("ThrowAway4").Input("node: int32"); +REGISTER_OP("ThrowAway4") + .Input("node: int32") + .SetShapeFn(shape_inference::UnknownShape); -REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4"); +REGISTER_OP("ThrowAway5") + .Output("foo: int32") + .Attr("node: int = 4") + .SetShapeFn(shape_inference::UnknownShape); } // namespace tensorflow diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc index b37b8b67d7..0030c2b2a7 100644 --- a/tensorflow/cc/ops/const_op.cc +++ b/tensorflow/cc/ops/const_op.cc @@ -34,7 +34,9 @@ Output Const(const Scope& scope, const Input::Initializer& val) { .Attr("dtype", val.tensor.dtype()); scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(graph, &ret)); + if (!scope.ok()) return Output(); + scope.UpdateStatus(scope.DoShapeInference(ret)); if (!scope.ok()) return Output(); return Output(ret); diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index e8cb6cf1dd..516800920f 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -56,6 +56,8 @@ Output Const(const Scope& scope, const Input::Initializer& val) { scope.UpdateBuilder(&cast_builder); Node* ret; scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret)); + if (!scope.ok()) return Output(); + scope.UpdateStatus(scope.DoShapeInference(ret)); return Output(ret, 0); } diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc index 38bcaa3227..c1005405f9 100644 --- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/functional_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" namespace tensorflow { @@ -26,6 +27,7 @@ REGISTER_OP("XlaWhile") .Attr("cond: func") .Attr("body: func") .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( output = input; While (Cond(output)) { output = Body(output) } diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc index d3a36868de..b6947bfe57 100644 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" namespace tensorflow { @@ -22,6 +23,7 @@ REGISTER_OP("_XLASend") .Attr("T: type") .Attr("tensor_name: string") .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Sends the named tensor to another XLA computation. @@ -35,6 +37,7 @@ REGISTER_OP("_XLARecv") .Attr("tensor_name: string") .Attr("shape: shape") .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Receives the named tensor from another XLA computation. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index b828c4b87b..a1e4dcb684 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" @@ -76,6 +77,8 @@ class DummyReadResourceCC { scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); if (!scope.ok()) return; + scope.UpdateStatus(scope.DoShapeInference(ret)); + if (!scope.ok()) return; this->output_ = Output(ret, 0); } Node* node() const { return output_.node(); } @@ -86,6 +89,7 @@ class DummyReadResourceCC { REGISTER_OP("DummyReadResource") .Input("input: int32") .Output("output: int32") + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( A dummy Op. diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index 282a00b52c..a40e2a7898 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -67,6 +68,7 @@ REGISTER_OP("TPUReplicate") .Input("broadcast_inputs: Tbroadcast_inputs") .Input("variables: NumVariables * resource") .Output("outputs: output_types") + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Runs replicated computations on a distributed TPU system. diff --git a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc index 5dc564ed27..8a87a91056 100644 --- a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/status.h" @@ -154,7 +155,10 @@ global_tpu_array: A two-dimensional array. For each host (the outer dimension) the array lists the global ids of the TPUs on that host. )doc"); -REGISTER_OP("_ShutdownDistributedTPU").SetIsStateful().Doc(R"doc( +REGISTER_OP("_ShutdownDistributedTPU") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( An op that shuts down a running distributed TPU system. The Op returns an error if no system is running. This Op must be run on the same TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run @@ -184,6 +188,7 @@ tpu_ids: A vector containing the global TPU id of each TPU on the host. REGISTER_OP("_DisconnectHostFromDistributedTPUSystem") .Output("number_of_tpu_chips: int32") .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( An op that disconnects the TPUs on a host from a running distributed TPU system. @@ -196,6 +201,7 @@ REGISTER_OP("ConfigureDistributedTPU") .Output("global_tpu_array: int32") .Attr("embedding_config: string = ''") .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( An op that sets up the centralized structures for a distributed TPU system. @@ -205,7 +211,10 @@ dimension) the array lists the global ids of the TPUs on that host. embedding_config: Internal use. )doc"); -REGISTER_OP("ShutdownDistributedTPU").SetIsStateful().Doc(R"doc( +REGISTER_OP("ShutdownDistributedTPU") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( An op that shuts down a running distributed TPU system. The Op returns an error if no system is running. )doc"); diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index c76ad647a0..380e5dfbd7 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -282,6 +282,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) { Status status; Node* times_two = s.graph()->AddNode(def, &status); TF_ASSERT_OK(status); + TF_ASSERT_OK(s.DoShapeInference(times_two)); s.graph()->AddEdge(c.node(), 0, times_two, 0); auto times_two_send = @@ -297,7 +298,10 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) { EXPECT_FALSE(was_mutated); } -REGISTER_OP("ConstantFoldingTestOp").Input("a: int64").Output("b: int64"); +REGISTER_OP("ConstantFoldingTestOp") + .Input("a: int64") + .Output("b: int64") + .SetShapeFn(shape_inference::UnknownShape); TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) { Graph g(OpRegistry::Global()); @@ -312,6 +316,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) { Status status; Node* non_cpu = s.graph()->AddNode(def, &status); TF_ASSERT_OK(status); + TF_ASSERT_OK(s.DoShapeInference(non_cpu)); auto non_cpu_send = ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu), diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index b1d32a984a..3ca4457b00 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -284,6 +284,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name, Status status; Node* n = scope->graph()->AddNode(def, &status); TF_CHECK_OK(status); + TF_CHECK_OK(scope->DoShapeInference(n)); for (int i = 0; i < inputs.size(); ++i) { scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i); } @@ -989,7 +990,7 @@ TEST(OptimizationTest, RemoveDeadNodes) { GraphDef expected; { - Scope s = Scope::NewRootScope(); + Scope s = Scope::DisabledShapeInferenceScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); auto o = ops::Const(s.WithOpName("o"), 1); auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT); @@ -1070,7 +1071,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) { {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); { - Scope s = Scope::NewRootScope(); + Scope s = Scope::DisabledShapeInferenceScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); auto o = ops::Const(s.WithOpName("o"), 1); auto a = ops::Square(s.WithOpName("a"), x); @@ -1087,7 +1088,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) { } { - Scope s = Scope::NewRootScope(); + Scope s = Scope::DisabledShapeInferenceScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); auto o = ops::Const(s.WithOpName("o"), 1); auto a = ops::Square(s.WithOpName("a"), x); @@ -1137,7 +1138,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) { {{"o", "o:sum"}}); { - Scope scope = Scope::NewRootScope(); + Scope scope = Scope::DisabledShapeInferenceScope(); auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0); auto zero = ops::Const(scope.WithOpName("zero"), 0); auto s = ops::Split(scope.WithOpName("s"), zero, i, 4); @@ -1222,7 +1223,7 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) { {{"o", "o:sum"}}); { - Scope s = Scope::NewRootScope(); + Scope s = Scope::DisabledShapeInferenceScope(); auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0); auto dummy = ops::Const(s.WithOpName("dummy"), 0); auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy), diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc index 003e416bbe..7763a4f2e6 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc @@ -64,7 +64,7 @@ TEST_F(GpuStreamUtilTest, EmptyGraph) { } TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) { - auto root = Scope::NewRootScope().ExitOnError(); + auto root = Scope::DisabledShapeInferenceScope().ExitOnError(); ops::MatMul(root, {}, {}); Graph g(OpRegistry::Global()); TF_ASSERT_OK(root.ToGraph(&g)); @@ -83,7 +83,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) { } TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) { - auto root = Scope::NewRootScope().ExitOnError(); + auto root = Scope::DisabledShapeInferenceScope().ExitOnError(); ops::MatMul(root, {}, {}); Graph g(OpRegistry::Global()); TF_ASSERT_OK(root.ToGraph(&g)); @@ -104,7 +104,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) { } TEST_F(GpuStreamUtilTest, StreamOverrides) { - auto root = Scope::NewRootScope().ExitOnError(); + auto root = Scope::DisabledShapeInferenceScope().ExitOnError(); ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0, "/device:GPU:0"); Output n = ops::MatMul(root, {}, {}); diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 6d2a863106..4ef132486a 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -882,7 +882,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Shape) { } TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) { - Scope root = Scope::NewRootScope(); + Scope root = Scope::DisabledShapeInferenceScope(); Node* scalar_non_const; TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32") .Finalize(root.graph(), &scalar_non_const)); @@ -914,7 +914,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) { } TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) { - Scope root = Scope::NewRootScope(); + Scope root = Scope::DisabledShapeInferenceScope(); Node* scalar_non_const; TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64") .Finalize(root.graph(), &scalar_non_const)); @@ -997,7 +997,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) { } TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) { - Scope root = Scope::NewRootScope(); + Scope root = Scope::DisabledShapeInferenceScope(); Graph* g = root.graph(); Node* partial_1; Node* partial_2; @@ -1034,7 +1034,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) { } TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) { - Scope root = Scope::NewRootScope(); + Scope root = Scope::DisabledShapeInferenceScope(); Graph* g = root.graph(); Node* scalar_non_const; TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32") @@ -1077,7 +1077,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) { } TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { - Scope root = Scope::NewRootScope(); + Scope root = Scope::DisabledShapeInferenceScope(); Graph* g = root.graph(); Node* scalar_non_const; TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32") diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index d84c62d454..8dde7320ed 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/cc/ops/random_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/versions.pb.h" @@ -141,9 +142,17 @@ void CheckLoopConstruction(const GraphDef& graph_def) { } } -REGISTER_OP("FloatInput").Output("o: float"); -REGISTER_OP("BoolInput").Output("o: bool"); -REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float"); +REGISTER_OP("FloatInput") + .Output("o: float") + .SetShapeFn(shape_inference::UnknownShape); +REGISTER_OP("BoolInput") + .Output("o: bool") + .SetShapeFn(shape_inference::UnknownShape); +REGISTER_OP("Combine") + .Input("a: float") + .Input("b: float") + .Output("o: float") + .SetShapeFn(shape_inference::UnknownShape); Output ConstructOp(const Scope& scope, const string& op_type, const gtl::ArraySlice<Input>& inputs) { @@ -158,6 +167,8 @@ Output ConstructOp(const Scope& scope, const string& op_type, Node* ret; scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); if (!scope.ok()) return Output(); + scope.UpdateStatus(scope.DoShapeInference(ret)); + if (!scope.ok()) return Output(); return Output(ret); } diff --git a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc index 9a41b5e0b5..1c3186f1ee 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc @@ -28,7 +28,7 @@ namespace { class AutoParallelTest : public ::testing::Test {}; TEST_F(AutoParallelTest, SimpleParallel) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope(); Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1}); Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1}); Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT); diff --git a/tensorflow/core/kernels/encode_wav_op_test.cc b/tensorflow/core/kernels/encode_wav_op_test.cc index 2f92c13268..34138ac9a0 100644 --- a/tensorflow/core/kernels/encode_wav_op_test.cc +++ b/tensorflow/core/kernels/encode_wav_op_test.cc @@ -35,7 +35,7 @@ namespace tensorflow { using namespace ops; // NOLINT(build/namespaces) TEST(EncodeWavOpTest, EncodeWavTest) { - Scope root = Scope::NewRootScope(); + Scope root = Scope::DisabledShapeInferenceScope(); Tensor audio_tensor(DT_FLOAT, {4, 2}); test::FillValues<float>( diff --git a/tensorflow/core/kernels/fuzzing/fuzz_session.h b/tensorflow/core/kernels/fuzzing/fuzz_session.h index fb518798b2..0c0e548a90 100644 --- a/tensorflow/core/kernels/fuzzing/fuzz_session.h +++ b/tensorflow/core/kernels/fuzzing/fuzz_session.h @@ -88,7 +88,7 @@ class FuzzSession { } initialized_ = true; - Scope root = Scope::NewRootScope().ExitOnError(); + Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); SessionOptions options; session_ = std::unique_ptr<Session>(NewSession(options)); diff --git a/tensorflow/core/kernels/immutable_constant_op_test.cc b/tensorflow/core/kernels/immutable_constant_op_test.cc index d822e316ea..b318c9c79a 100644 --- a/tensorflow/core/kernels/immutable_constant_op_test.cc +++ b/tensorflow/core/kernels/immutable_constant_op_test.cc @@ -121,7 +121,7 @@ TEST(ImmutableConstantOpTest, ExecutionError) { const TensorShape kBadTensorShape({40, 100}); const TensorShape kTestTensorShapeT({1, 4}); - auto root = Scope::NewRootScope().ExitOnError(); + auto root = Scope::DisabledShapeInferenceScope().ExitOnError(); auto node1 = ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test:///2"); auto node2 = diff --git a/tensorflow/core/kernels/mfcc_op_test.cc b/tensorflow/core/kernels/mfcc_op_test.cc index d16171d526..57391128f9 100644 --- a/tensorflow/core/kernels/mfcc_op_test.cc +++ b/tensorflow/core/kernels/mfcc_op_test.cc @@ -35,7 +35,7 @@ namespace tensorflow { using namespace ops; // NOLINT(build/namespaces) TEST(MfccOpTest, SimpleTest) { - Scope root = Scope::NewRootScope(); + Scope root = Scope::DisabledShapeInferenceScope(); Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513})); test::FillIota<float>(&spectrogram_tensor, 1.0f); diff --git a/tensorflow/core/ops/sendrecv_ops.cc b/tensorflow/core/ops/sendrecv_ops.cc index 55f6585ade..7d0fda2f87 100644 --- a/tensorflow/core/ops/sendrecv_ops.cc +++ b/tensorflow/core/ops/sendrecv_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" namespace tensorflow { @@ -26,6 +27,7 @@ REGISTER_OP("_Send") .Attr("recv_device: string") .Attr("client_terminated: bool = false") .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Sends the named tensor from send_device to recv_device. @@ -49,6 +51,7 @@ REGISTER_OP("_Recv") .Attr("recv_device: string") .Attr("client_terminated: bool = false") .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Receives the named tensor from send_device on recv_device. @@ -72,6 +75,7 @@ REGISTER_OP("_HostSend") .Attr("recv_device: string") .Attr("client_terminated: bool = false") .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Sends the named tensor from send_device to recv_device. @@ -98,6 +102,7 @@ REGISTER_OP("_HostRecv") .Attr("recv_device: string") .Attr("client_terminated: bool = false") .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Receives the named tensor from send_device on recv_device. diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc index 3ea7f512c6..5e4ab209e9 100644 --- a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc +++ b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc @@ -36,7 +36,7 @@ class FakeQuantizeTrainingTest : public ::testing::Test {}; // TODO(suharshs): Once we implement the fake_quantize_training transform // using the GTT, write proper tests of the transform here. TEST_F(FakeQuantizeTrainingTest, TransformOccurred) { - auto root = tensorflow::Scope::NewRootScope(); + auto root = tensorflow::Scope::DisabledShapeInferenceScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) Tensor a_data(DT_FLOAT, TensorShape()); diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc index e1828831db..a58ce73453 100644 --- a/tensorflow/tools/graph_transforms/quantize_weights_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc @@ -40,7 +40,7 @@ class QuantizeWeightsTest : public ::testing::Test { const TensorShape& weight_shape, std::initializer_list<float> weight_values, GraphDef* original_graph_def) { - auto root = tensorflow::Scope::NewRootScope(); + auto root = tensorflow::Scope::DisabledShapeInferenceScope(); Tensor input_data(DT_FLOAT, input_shape); test::FillValues<float>(&input_data, input_values); diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc index b5bc2d75fd..abb25916c5 100644 --- a/tensorflow/tools/graph_transforms/transform_utils_test.cc +++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc @@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test { } void TestRenameNodeInputsWithWildcard() { - auto root = tensorflow::Scope::NewRootScope(); + auto root = tensorflow::Scope::DisabledShapeInferenceScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) const int width = 10; |