aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-08-15 16:46:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-15 16:50:32 -0700
commit477d49c9eaafc5e1e1667d454ce5883956180713 (patch)
tree02ef20e777b6d686a6fa97327f761b40ff91c738
parent9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff)
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code: AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) { if (!scope.ok()) return; auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs); if (!scope.ok()) return; ::tensorflow::Node* ret; const auto unique_name = scope.GetUniqueNameForOp("AddN"); auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN") .Input(_inputs) ; scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); this->sum = Output(ret, 0); } Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it. PiperOrigin-RevId: 165378429
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc8
-rw-r--r--tensorflow/cc/framework/scope.cc48
-rw-r--r--tensorflow/cc/framework/scope.h12
-rw-r--r--tensorflow/cc/framework/scope_internal.h7
-rw-r--r--tensorflow/cc/framework/test_op.cc18
-rw-r--r--tensorflow/cc/ops/const_op.cc2
-rw-r--r--tensorflow/cc/ops/const_op.h2
-rw-r--r--tensorflow/compiler/tf2xla/ops/functional_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc3
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc4
-rw-r--r--tensorflow/contrib/tpu/ops/replication_ops.cc2
-rw-r--r--tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc13
-rw-r--r--tensorflow/core/common_runtime/constant_folding_test.cc7
-rw-r--r--tensorflow/core/common_runtime/function_test.cc11
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc6
-rw-r--r--tensorflow/core/common_runtime/shape_refiner_test.cc10
-rw-r--r--tensorflow/core/graph/graph_partition_test.cc17
-rw-r--r--tensorflow/core/grappler/optimizers/auto_parallel_test.cc2
-rw-r--r--tensorflow/core/kernels/encode_wav_op_test.cc2
-rw-r--r--tensorflow/core/kernels/fuzzing/fuzz_session.h2
-rw-r--r--tensorflow/core/kernels/immutable_constant_op_test.cc2
-rw-r--r--tensorflow/core/kernels/mfcc_op_test.cc2
-rw-r--r--tensorflow/core/ops/sendrecv_ops.cc5
-rw-r--r--tensorflow/tools/graph_transforms/fake_quantize_training_test.cc2
-rw-r--r--tensorflow/tools/graph_transforms/quantize_weights_test.cc2
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils_test.cc2
26 files changed, 143 insertions, 50 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index 80dd272f6f..38a17598b8 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -812,12 +812,8 @@ string OpInfo::GetConstructorBody() const {
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
scope_str, ".graph(), &ret));\n");
strings::StrAppend(&body, " ", return_on_error, "\n");
-
- // TODO(b/28152992): Enable this code-path once we have converted
- // all python shape functions to call their C++ versions.
-
- // strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
- // ".refiner()->AddNode(ret));\n");
+ strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
+ ".DoShapeInference(ret));\n");
GetOutput(&body);
return body;
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 4705b6b7e8..7164249262 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -37,13 +37,14 @@ Scope& Scope::operator=(const Scope& other) {
}
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
- ShapeRefiner* refiner)
+ ShapeRefiner* refiner, bool disable_shape_inference)
: graph_(graph),
status_(status),
name_map_(name_map),
refiner_(refiner),
scope_used_(nullptr),
- colocation_constraints_() {}
+ colocation_constraints_(),
+ disable_shape_inference_(disable_shape_inference) {}
Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
const std::shared_ptr<Status>& status,
@@ -54,13 +55,23 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
name_map_(name_map),
refiner_(refiner),
scope_used_(nullptr),
- colocation_constraints_() {}
+ colocation_constraints_(),
+ disable_shape_inference_(false) {}
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
ShapeRefiner* refiner =
new ShapeRefiner(graph->versions(), graph->op_registry());
- return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
+ return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
+ /* disable_shape_inference */ false));
+}
+
+Scope Scope::DisabledShapeInferenceScope() {
+ Graph* graph = new Graph(OpRegistry::Global());
+ ShapeRefiner* refiner =
+ new ShapeRefiner(graph->versions(), graph->op_registry());
+ return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
+ /* disable_shape_inference */ true));
}
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
@@ -77,7 +88,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
const string& op_name)
@@ -92,7 +104,8 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
std::vector<Operation> control_deps, bool clear_control_deps)
@@ -113,7 +126,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
: graph_(other.impl()->graph_),
@@ -127,7 +141,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(device),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
const string& op_name)
@@ -142,7 +157,8 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
: graph_(other.impl()->graph_),
@@ -156,7 +172,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
exit_on_error_(true),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
const string& kernel_label)
@@ -171,7 +188,8 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::Colocate,
const Operation& colocate_with_op, bool clear_colocations)
@@ -189,7 +207,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
- : other.impl()->GetColocationConstraints(colocate_with_op)) {}
+ : other.impl()->GetColocationConstraints(colocate_with_op)),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
@@ -404,6 +423,11 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
}
}
+Status Scope::DoShapeInference(Node* node) const {
+ if (impl_->disable_shape_inference_) return Status::OK();
+ return impl_->refiner_->AddNode(node);
+}
+
class InternalScope {
public:
// NewScope doesn't take ownership of the inputs.
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index ec3543772d..5cae5c64ad 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -199,6 +199,18 @@ class Scope {
// edges from the source and to the sink node, resolves back edges
// by name), and makes sure the resulting graph is valid.
Status ToGraph(Graph* g) const;
+
+ // Calls AddNode() using this scope's ShapeRefiner. This exists in the public
+ // API to prevent custom op wrappers from needing access to shape_refiner.h or
+ // scope_internal.h.
+ // TODO(skyewm): remove this from public API
+ Status DoShapeInference(Node* node) const;
+
+ // Creates a new root scope that causes all DoShapeInference() calls to return
+ // Status::OK() (on the returned scope and any subscopes). Used for testing.
+ // TODO(skyewm): fix tests that still require this and eventually remove, or
+ // at least remove from public API
+ static Scope DisabledShapeInferenceScope();
// END_SKIP_DOXYGEN
const std::vector<Operation>& control_deps() const;
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 3656c0aecf..e2cc22af5d 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -58,7 +58,8 @@ class Scope::Impl {
enum class Colocate;
};
- Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
+ Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
+ bool disable_shape_inference);
Impl(const Scope& other, Tags::ScopeName, const string& name,
bool copy_names);
Impl(const Scope& other, Tags::OpName, const string& name,
@@ -103,6 +104,10 @@ class Scope::Impl {
const string kernel_label_ = "";
const string device_ = "";
const std::unordered_set<string> colocation_constraints_;
+
+ // If true, Scope::DoShapeInference() always returns Status:OK().
+ // TODO(skyewm): remove this when possible
+ const bool disable_shape_inference_;
};
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/test_op.cc b/tensorflow/cc/framework/test_op.cc
index fe0d907df0..b76842a9a0 100644
--- a/tensorflow/cc/framework/test_op.cc
+++ b/tensorflow/cc/framework/test_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
@@ -24,6 +25,7 @@ REGISTER_OP("ThrowAway1")
.Attr("scope: int")
.Attr("builder: int = 1")
.Attr("while: int")
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Op to test keywords and reserved words in input and attr names.
@@ -36,12 +38,20 @@ REGISTER_OP("ThrowAway2")
.Attr("scope: int = 2")
.Attr("throw_away2: int = 2")
.Attr("attrs: int = 4")
- .Attr("node: int = 4");
+ .Attr("node: int = 4")
+ .SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThrowAway3").Output("node: int32");
+REGISTER_OP("ThrowAway3")
+ .Output("node: int32")
+ .SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThrowAway4").Input("node: int32");
+REGISTER_OP("ThrowAway4")
+ .Input("node: int32")
+ .SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4");
+REGISTER_OP("ThrowAway5")
+ .Output("foo: int32")
+ .Attr("node: int = 4")
+ .SetShapeFn(shape_inference::UnknownShape);
} // namespace tensorflow
diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc
index b37b8b67d7..0030c2b2a7 100644
--- a/tensorflow/cc/ops/const_op.cc
+++ b/tensorflow/cc/ops/const_op.cc
@@ -34,7 +34,9 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
.Attr("dtype", val.tensor.dtype());
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(graph, &ret));
+ if (!scope.ok()) return Output();
+ scope.UpdateStatus(scope.DoShapeInference(ret));
if (!scope.ok()) return Output();
return Output(ret);
diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h
index e8cb6cf1dd..516800920f 100644
--- a/tensorflow/cc/ops/const_op.h
+++ b/tensorflow/cc/ops/const_op.h
@@ -56,6 +56,8 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
scope.UpdateBuilder(&cast_builder);
Node* ret;
scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
+ if (!scope.ok()) return Output();
+ scope.UpdateStatus(scope.DoShapeInference(ret));
return Output(ret, 0);
}
diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc
index 38bcaa3227..c1005405f9 100644
--- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/functional_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
@@ -26,6 +27,7 @@ REGISTER_OP("XlaWhile")
.Attr("cond: func")
.Attr("body: func")
.SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }
diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc
index d3a36868de..b6947bfe57 100644
--- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
@@ -22,6 +23,7 @@ REGISTER_OP("_XLASend")
.Attr("T: type")
.Attr("tensor_name: string")
.SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor to another XLA computation.
@@ -35,6 +37,7 @@ REGISTER_OP("_XLARecv")
.Attr("tensor_name: string")
.Attr("shape: shape")
.SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Receives the named tensor from another XLA computation.
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index b828c4b87b..a1e4dcb684 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h"
@@ -76,6 +77,8 @@ class DummyReadResourceCC {
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
+ scope.UpdateStatus(scope.DoShapeInference(ret));
+ if (!scope.ok()) return;
this->output_ = Output(ret, 0);
}
Node* node() const { return output_.node(); }
@@ -86,6 +89,7 @@ class DummyReadResourceCC {
REGISTER_OP("DummyReadResource")
.Input("input: int32")
.Output("output: int32")
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
A dummy Op.
diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc
index 282a00b52c..a40e2a7898 100644
--- a/tensorflow/contrib/tpu/ops/replication_ops.cc
+++ b/tensorflow/contrib/tpu/ops/replication_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -67,6 +68,7 @@ REGISTER_OP("TPUReplicate")
.Input("broadcast_inputs: Tbroadcast_inputs")
.Input("variables: NumVariables * resource")
.Output("outputs: output_types")
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Runs replicated computations on a distributed TPU system.
diff --git a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc
index 5dc564ed27..8a87a91056 100644
--- a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
@@ -154,7 +155,10 @@ global_tpu_array: A two-dimensional array. For each host (the outer
dimension) the array lists the global ids of the TPUs on that host.
)doc");
-REGISTER_OP("_ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
+REGISTER_OP("_ShutdownDistributedTPU")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running. This Op must be run on the same
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
@@ -184,6 +188,7 @@ tpu_ids: A vector containing the global TPU id of each TPU on the host.
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
.Output("number_of_tpu_chips: int32")
.SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
An op that disconnects the TPUs on a host from a running distributed
TPU system.
@@ -196,6 +201,7 @@ REGISTER_OP("ConfigureDistributedTPU")
.Output("global_tpu_array: int32")
.Attr("embedding_config: string = ''")
.SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
An op that sets up the centralized structures for a distributed TPU
system.
@@ -205,7 +211,10 @@ dimension) the array lists the global ids of the TPUs on that host.
embedding_config: Internal use.
)doc");
-REGISTER_OP("ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
+REGISTER_OP("ShutdownDistributedTPU")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running.
)doc");
diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc
index c76ad647a0..380e5dfbd7 100644
--- a/tensorflow/core/common_runtime/constant_folding_test.cc
+++ b/tensorflow/core/common_runtime/constant_folding_test.cc
@@ -282,6 +282,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
Status status;
Node* times_two = s.graph()->AddNode(def, &status);
TF_ASSERT_OK(status);
+ TF_ASSERT_OK(s.DoShapeInference(times_two));
s.graph()->AddEdge(c.node(), 0, times_two, 0);
auto times_two_send =
@@ -297,7 +298,10 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
EXPECT_FALSE(was_mutated);
}
-REGISTER_OP("ConstantFoldingTestOp").Input("a: int64").Output("b: int64");
+REGISTER_OP("ConstantFoldingTestOp")
+ .Input("a: int64")
+ .Output("b: int64")
+ .SetShapeFn(shape_inference::UnknownShape);
TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
Graph g(OpRegistry::Global());
@@ -312,6 +316,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
Status status;
Node* non_cpu = s.graph()->AddNode(def, &status);
TF_ASSERT_OK(status);
+ TF_ASSERT_OK(s.DoShapeInference(non_cpu));
auto non_cpu_send =
ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu),
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index b1d32a984a..3ca4457b00 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -284,6 +284,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name,
Status status;
Node* n = scope->graph()->AddNode(def, &status);
TF_CHECK_OK(status);
+ TF_CHECK_OK(scope->DoShapeInference(n));
for (int i = 0; i < inputs.size(); ++i) {
scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i);
}
@@ -989,7 +990,7 @@ TEST(OptimizationTest, RemoveDeadNodes) {
GraphDef expected;
{
- Scope s = Scope::NewRootScope();
+ Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
@@ -1070,7 +1071,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
{
- Scope s = Scope::NewRootScope();
+ Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x);
@@ -1087,7 +1088,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
}
{
- Scope s = Scope::NewRootScope();
+ Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x);
@@ -1137,7 +1138,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
{{"o", "o:sum"}});
{
- Scope scope = Scope::NewRootScope();
+ Scope scope = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
auto zero = ops::Const(scope.WithOpName("zero"), 0);
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
@@ -1222,7 +1223,7 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
{{"o", "o:sum"}});
{
- Scope s = Scope::NewRootScope();
+ Scope s = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
auto dummy = ops::Const(s.WithOpName("dummy"), 0);
auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
index 003e416bbe..7763a4f2e6 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
@@ -64,7 +64,7 @@ TEST_F(GpuStreamUtilTest, EmptyGraph) {
}
TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
- auto root = Scope::NewRootScope().ExitOnError();
+ auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
ops::MatMul(root, {}, {});
Graph g(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&g));
@@ -83,7 +83,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
}
TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
- auto root = Scope::NewRootScope().ExitOnError();
+ auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
ops::MatMul(root, {}, {});
Graph g(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&g));
@@ -104,7 +104,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
}
TEST_F(GpuStreamUtilTest, StreamOverrides) {
- auto root = Scope::NewRootScope().ExitOnError();
+ auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
"/device:GPU:0");
Output n = ops::MatMul(root, {}, {});
diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc
index 6d2a863106..4ef132486a 100644
--- a/tensorflow/core/common_runtime/shape_refiner_test.cc
+++ b/tensorflow/core/common_runtime/shape_refiner_test.cc
@@ -882,7 +882,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Shape) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
- Scope root = Scope::NewRootScope();
+ Scope root = Scope::DisabledShapeInferenceScope();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
.Finalize(root.graph(), &scalar_non_const));
@@ -914,7 +914,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
- Scope root = Scope::NewRootScope();
+ Scope root = Scope::DisabledShapeInferenceScope();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
.Finalize(root.graph(), &scalar_non_const));
@@ -997,7 +997,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
- Scope root = Scope::NewRootScope();
+ Scope root = Scope::DisabledShapeInferenceScope();
Graph* g = root.graph();
Node* partial_1;
Node* partial_2;
@@ -1034,7 +1034,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
- Scope root = Scope::NewRootScope();
+ Scope root = Scope::DisabledShapeInferenceScope();
Graph* g = root.graph();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
@@ -1077,7 +1077,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
- Scope root = Scope::NewRootScope();
+ Scope root = Scope::DisabledShapeInferenceScope();
Graph* g = root.graph();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
index d84c62d454..8dde7320ed 100644
--- a/tensorflow/core/graph/graph_partition_test.cc
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/random_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -141,9 +142,17 @@ void CheckLoopConstruction(const GraphDef& graph_def) {
}
}
-REGISTER_OP("FloatInput").Output("o: float");
-REGISTER_OP("BoolInput").Output("o: bool");
-REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_OP("FloatInput")
+ .Output("o: float")
+ .SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("BoolInput")
+ .Output("o: bool")
+ .SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("Combine")
+ .Input("a: float")
+ .Input("b: float")
+ .Output("o: float")
+ .SetShapeFn(shape_inference::UnknownShape);
Output ConstructOp(const Scope& scope, const string& op_type,
const gtl::ArraySlice<Input>& inputs) {
@@ -158,6 +167,8 @@ Output ConstructOp(const Scope& scope, const string& op_type,
Node* ret;
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return Output();
+ scope.UpdateStatus(scope.DoShapeInference(ret));
+ if (!scope.ok()) return Output();
return Output(ret);
}
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
index 9a41b5e0b5..1c3186f1ee 100644
--- a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
+++ b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
@@ -28,7 +28,7 @@ namespace {
class AutoParallelTest : public ::testing::Test {};
TEST_F(AutoParallelTest, SimpleParallel) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope();
Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
diff --git a/tensorflow/core/kernels/encode_wav_op_test.cc b/tensorflow/core/kernels/encode_wav_op_test.cc
index 2f92c13268..34138ac9a0 100644
--- a/tensorflow/core/kernels/encode_wav_op_test.cc
+++ b/tensorflow/core/kernels/encode_wav_op_test.cc
@@ -35,7 +35,7 @@ namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
TEST(EncodeWavOpTest, EncodeWavTest) {
- Scope root = Scope::NewRootScope();
+ Scope root = Scope::DisabledShapeInferenceScope();
Tensor audio_tensor(DT_FLOAT, {4, 2});
test::FillValues<float>(
diff --git a/tensorflow/core/kernels/fuzzing/fuzz_session.h b/tensorflow/core/kernels/fuzzing/fuzz_session.h
index fb518798b2..0c0e548a90 100644
--- a/tensorflow/core/kernels/fuzzing/fuzz_session.h
+++ b/tensorflow/core/kernels/fuzzing/fuzz_session.h
@@ -88,7 +88,7 @@ class FuzzSession {
}
initialized_ = true;
- Scope root = Scope::NewRootScope().ExitOnError();
+ Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();
SessionOptions options;
session_ = std::unique_ptr<Session>(NewSession(options));
diff --git a/tensorflow/core/kernels/immutable_constant_op_test.cc b/tensorflow/core/kernels/immutable_constant_op_test.cc
index d822e316ea..b318c9c79a 100644
--- a/tensorflow/core/kernels/immutable_constant_op_test.cc
+++ b/tensorflow/core/kernels/immutable_constant_op_test.cc
@@ -121,7 +121,7 @@ TEST(ImmutableConstantOpTest, ExecutionError) {
const TensorShape kBadTensorShape({40, 100});
const TensorShape kTestTensorShapeT({1, 4});
- auto root = Scope::NewRootScope().ExitOnError();
+ auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
auto node1 =
ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test:///2");
auto node2 =
diff --git a/tensorflow/core/kernels/mfcc_op_test.cc b/tensorflow/core/kernels/mfcc_op_test.cc
index d16171d526..57391128f9 100644
--- a/tensorflow/core/kernels/mfcc_op_test.cc
+++ b/tensorflow/core/kernels/mfcc_op_test.cc
@@ -35,7 +35,7 @@ namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
TEST(MfccOpTest, SimpleTest) {
- Scope root = Scope::NewRootScope();
+ Scope root = Scope::DisabledShapeInferenceScope();
Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513}));
test::FillIota<float>(&spectrogram_tensor, 1.0f);
diff --git a/tensorflow/core/ops/sendrecv_ops.cc b/tensorflow/core/ops/sendrecv_ops.cc
index 55f6585ade..7d0fda2f87 100644
--- a/tensorflow/core/ops/sendrecv_ops.cc
+++ b/tensorflow/core/ops/sendrecv_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
@@ -26,6 +27,7 @@ REGISTER_OP("_Send")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor from send_device to recv_device.
@@ -49,6 +51,7 @@ REGISTER_OP("_Recv")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Receives the named tensor from send_device on recv_device.
@@ -72,6 +75,7 @@ REGISTER_OP("_HostSend")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor from send_device to recv_device.
@@ -98,6 +102,7 @@ REGISTER_OP("_HostRecv")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Receives the named tensor from send_device on recv_device.
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
index 3ea7f512c6..5e4ab209e9 100644
--- a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
+++ b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
@@ -36,7 +36,7 @@ class FakeQuantizeTrainingTest : public ::testing::Test {};
// TODO(suharshs): Once we implement the fake_quantize_training transform
// using the GTT, write proper tests of the transform here.
TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
- auto root = tensorflow::Scope::NewRootScope();
+ auto root = tensorflow::Scope::DisabledShapeInferenceScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor a_data(DT_FLOAT, TensorShape());
diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc
index e1828831db..a58ce73453 100644
--- a/tensorflow/tools/graph_transforms/quantize_weights_test.cc
+++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc
@@ -40,7 +40,7 @@ class QuantizeWeightsTest : public ::testing::Test {
const TensorShape& weight_shape,
std::initializer_list<float> weight_values,
GraphDef* original_graph_def) {
- auto root = tensorflow::Scope::NewRootScope();
+ auto root = tensorflow::Scope::DisabledShapeInferenceScope();
Tensor input_data(DT_FLOAT, input_shape);
test::FillValues<float>(&input_data, input_values);
diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc
index b5bc2d75fd..abb25916c5 100644
--- a/tensorflow/tools/graph_transforms/transform_utils_test.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc
@@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test {
}
void TestRenameNodeInputsWithWildcard() {
- auto root = tensorflow::Scope::NewRootScope();
+ auto root = tensorflow::Scope::DisabledShapeInferenceScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;