diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-08-15 16:46:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-15 16:50:32 -0700 |
commit | 477d49c9eaafc5e1e1667d454ce5883956180713 (patch) | |
tree | 02ef20e777b6d686a6fa97327f761b40ff91c738 /tensorflow/cc/framework | |
parent | 9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff) |
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code:
AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) {
if (!scope.ok()) return;
auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs);
if (!scope.ok()) return;
::tensorflow::Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("AddN");
auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN")
.Input(_inputs)
;
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
scope.UpdateStatus(scope.DoShapeInference(ret));
this->sum = Output(ret, 0);
}
Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it.
PiperOrigin-RevId: 165378429
Diffstat (limited to 'tensorflow/cc/framework')
-rw-r--r-- | tensorflow/cc/framework/cc_op_gen.cc | 8 | ||||
-rw-r--r-- | tensorflow/cc/framework/scope.cc | 48 | ||||
-rw-r--r-- | tensorflow/cc/framework/scope.h | 12 | ||||
-rw-r--r-- | tensorflow/cc/framework/scope_internal.h | 7 | ||||
-rw-r--r-- | tensorflow/cc/framework/test_op.cc | 18 |
5 files changed, 70 insertions, 23 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 80dd272f6f..38a17598b8 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -812,12 +812,8 @@ string OpInfo::GetConstructorBody() const { strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(", scope_str, ".graph(), &ret));\n"); strings::StrAppend(&body, " ", return_on_error, "\n"); - - // TODO(b/28152992): Enable this code-path once we have converted - // all python shape functions to call their C++ versions. - - // strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str, - // ".refiner()->AddNode(ret));\n"); + strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str, + ".DoShapeInference(ret));\n"); GetOutput(&body); return body; diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 4705b6b7e8..7164249262 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -37,13 +37,14 @@ Scope& Scope::operator=(const Scope& other) { } Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, - ShapeRefiner* refiner) + ShapeRefiner* refiner, bool disable_shape_inference) : graph_(graph), status_(status), name_map_(name_map), refiner_(refiner), scope_used_(nullptr), - colocation_constraints_() {} + colocation_constraints_(), + disable_shape_inference_(disable_shape_inference) {} Scope::Impl::Impl(const std::shared_ptr<Graph>& graph, const std::shared_ptr<Status>& status, @@ -54,13 +55,23 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph, name_map_(name_map), refiner_(refiner), scope_used_(nullptr), - colocation_constraints_() {} + colocation_constraints_(), + disable_shape_inference_(false) {} Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); ShapeRefiner* refiner = new ShapeRefiner(graph->versions(), graph->op_registry()); - return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner)); + return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner, + /* disable_shape_inference */ false)); +} + +Scope Scope::DisabledShapeInferenceScope() { + Graph* graph = new Graph(OpRegistry::Global()); + ShapeRefiner* refiner = + new ShapeRefiner(graph->versions(), graph->op_registry()); + return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner, + /* disable_shape_inference */ true)); } Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, @@ -77,7 +88,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, const string& op_name) @@ -92,7 +104,8 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::ControlDeps, std::vector<Operation> control_deps, bool clear_control_deps) @@ -113,7 +126,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device) : graph_(other.impl()->graph_), @@ -127,7 +141,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device) exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(device), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, const string& op_name) @@ -142,7 +157,8 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) : graph_(other.impl()->graph_), @@ -156,7 +172,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) exit_on_error_(true), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label) @@ -171,7 +188,8 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(kernel_label), device_(other.impl()->device_), - colocation_constraints_(other.impl()->colocation_constraints_) {} + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} Scope::Impl::Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op, bool clear_colocations) @@ -189,7 +207,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, colocation_constraints_( clear_colocations ? std::unordered_set<string>() - : other.impl()->GetColocationConstraints(colocate_with_op)) {} + : other.impl()->GetColocationConstraints(colocate_with_op)), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} std::unordered_set<string> Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { @@ -404,6 +423,11 @@ CompositeOpScopes Scope::GetCompositeOpScopes( } } +Status Scope::DoShapeInference(Node* node) const { + if (impl_->disable_shape_inference_) return Status::OK(); + return impl_->refiner_->AddNode(node); +} + class InternalScope { public: // NewScope doesn't take ownership of the inputs. diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index ec3543772d..5cae5c64ad 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -199,6 +199,18 @@ class Scope { // edges from the source and to the sink node, resolves back edges // by name), and makes sure the resulting graph is valid. Status ToGraph(Graph* g) const; + + // Calls AddNode() using this scope's ShapeRefiner. This exists in the public + // API to prevent custom op wrappers from needing access to shape_refiner.h or + // scope_internal.h. + // TODO(skyewm): remove this from public API + Status DoShapeInference(Node* node) const; + + // Creates a new root scope that causes all DoShapeInference() calls to return + // Status::OK() (on the returned scope and any subscopes). Used for testing. + // TODO(skyewm): fix tests that still require this and eventually remove, or + // at least remove from public API + static Scope DisabledShapeInferenceScope(); // END_SKIP_DOXYGEN const std::vector<Operation>& control_deps() const; diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 3656c0aecf..e2cc22af5d 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -58,7 +58,8 @@ class Scope::Impl { enum class Colocate; }; - Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner); + Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, + bool disable_shape_inference); Impl(const Scope& other, Tags::ScopeName, const string& name, bool copy_names); Impl(const Scope& other, Tags::OpName, const string& name, @@ -103,6 +104,10 @@ class Scope::Impl { const string kernel_label_ = ""; const string device_ = ""; const std::unordered_set<string> colocation_constraints_; + + // If true, Scope::DoShapeInference() always returns Status:OK(). + // TODO(skyewm): remove this when possible + const bool disable_shape_inference_; }; } // namespace tensorflow diff --git a/tensorflow/cc/framework/test_op.cc b/tensorflow/cc/framework/test_op.cc index fe0d907df0..b76842a9a0 100644 --- a/tensorflow/cc/framework/test_op.cc +++ b/tensorflow/cc/framework/test_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" namespace tensorflow { @@ -24,6 +25,7 @@ REGISTER_OP("ThrowAway1") .Attr("scope: int") .Attr("builder: int = 1") .Attr("while: int") + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Op to test keywords and reserved words in input and attr names. @@ -36,12 +38,20 @@ REGISTER_OP("ThrowAway2") .Attr("scope: int = 2") .Attr("throw_away2: int = 2") .Attr("attrs: int = 4") - .Attr("node: int = 4"); + .Attr("node: int = 4") + .SetShapeFn(shape_inference::UnknownShape); -REGISTER_OP("ThrowAway3").Output("node: int32"); +REGISTER_OP("ThrowAway3") + .Output("node: int32") + .SetShapeFn(shape_inference::UnknownShape); -REGISTER_OP("ThrowAway4").Input("node: int32"); +REGISTER_OP("ThrowAway4") + .Input("node: int32") + .SetShapeFn(shape_inference::UnknownShape); -REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4"); +REGISTER_OP("ThrowAway5") + .Output("foo: int32") + .Attr("node: int = 4") + .SetShapeFn(shape_inference::UnknownShape); } // namespace tensorflow |