aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-08-15 16:46:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-15 16:50:32 -0700
commit477d49c9eaafc5e1e1667d454ce5883956180713 (patch)
tree02ef20e777b6d686a6fa97327f761b40ff91c738 /tensorflow/cc/framework
parent9ba0abc2f06fe2d09c96487d5170b2faec79d2a4 (diff)
C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code: AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) { if (!scope.ok()) return; auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs); if (!scope.ok()) return; ::tensorflow::Node* ret; const auto unique_name = scope.GetUniqueNameForOp("AddN"); auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN") .Input(_inputs) ; scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); this->sum = Output(ret, 0); } Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it. PiperOrigin-RevId: 165378429
Diffstat (limited to 'tensorflow/cc/framework')
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc8
-rw-r--r--tensorflow/cc/framework/scope.cc48
-rw-r--r--tensorflow/cc/framework/scope.h12
-rw-r--r--tensorflow/cc/framework/scope_internal.h7
-rw-r--r--tensorflow/cc/framework/test_op.cc18
5 files changed, 70 insertions, 23 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index 80dd272f6f..38a17598b8 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -812,12 +812,8 @@ string OpInfo::GetConstructorBody() const {
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
scope_str, ".graph(), &ret));\n");
strings::StrAppend(&body, " ", return_on_error, "\n");
-
- // TODO(b/28152992): Enable this code-path once we have converted
- // all python shape functions to call their C++ versions.
-
- // strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
- // ".refiner()->AddNode(ret));\n");
+ strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
+ ".DoShapeInference(ret));\n");
GetOutput(&body);
return body;
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 4705b6b7e8..7164249262 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -37,13 +37,14 @@ Scope& Scope::operator=(const Scope& other) {
}
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
- ShapeRefiner* refiner)
+ ShapeRefiner* refiner, bool disable_shape_inference)
: graph_(graph),
status_(status),
name_map_(name_map),
refiner_(refiner),
scope_used_(nullptr),
- colocation_constraints_() {}
+ colocation_constraints_(),
+ disable_shape_inference_(disable_shape_inference) {}
Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
const std::shared_ptr<Status>& status,
@@ -54,13 +55,23 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
name_map_(name_map),
refiner_(refiner),
scope_used_(nullptr),
- colocation_constraints_() {}
+ colocation_constraints_(),
+ disable_shape_inference_(false) {}
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
ShapeRefiner* refiner =
new ShapeRefiner(graph->versions(), graph->op_registry());
- return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
+ return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
+ /* disable_shape_inference */ false));
+}
+
+Scope Scope::DisabledShapeInferenceScope() {
+ Graph* graph = new Graph(OpRegistry::Global());
+ ShapeRefiner* refiner =
+ new ShapeRefiner(graph->versions(), graph->op_registry());
+ return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
+ /* disable_shape_inference */ true));
}
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
@@ -77,7 +88,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
const string& op_name)
@@ -92,7 +104,8 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
std::vector<Operation> control_deps, bool clear_control_deps)
@@ -113,7 +126,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
: graph_(other.impl()->graph_),
@@ -127,7 +141,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(device),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
const string& op_name)
@@ -142,7 +157,8 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
: graph_(other.impl()->graph_),
@@ -156,7 +172,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
exit_on_error_(true),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
const string& kernel_label)
@@ -171,7 +188,8 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label),
device_(other.impl()->device_),
- colocation_constraints_(other.impl()->colocation_constraints_) {}
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::Colocate,
const Operation& colocate_with_op, bool clear_colocations)
@@ -189,7 +207,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
- : other.impl()->GetColocationConstraints(colocate_with_op)) {}
+ : other.impl()->GetColocationConstraints(colocate_with_op)),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
@@ -404,6 +423,11 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
}
}
+Status Scope::DoShapeInference(Node* node) const {
+ if (impl_->disable_shape_inference_) return Status::OK();
+ return impl_->refiner_->AddNode(node);
+}
+
class InternalScope {
public:
// NewScope doesn't take ownership of the inputs.
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index ec3543772d..5cae5c64ad 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -199,6 +199,18 @@ class Scope {
// edges from the source and to the sink node, resolves back edges
// by name), and makes sure the resulting graph is valid.
Status ToGraph(Graph* g) const;
+
+ // Calls AddNode() using this scope's ShapeRefiner. This exists in the public
+ // API to prevent custom op wrappers from needing access to shape_refiner.h or
+ // scope_internal.h.
+ // TODO(skyewm): remove this from public API
+ Status DoShapeInference(Node* node) const;
+
+ // Creates a new root scope that causes all DoShapeInference() calls to return
+ // Status::OK() (on the returned scope and any subscopes). Used for testing.
+ // TODO(skyewm): fix tests that still require this and eventually remove, or
+ // at least remove from public API
+ static Scope DisabledShapeInferenceScope();
// END_SKIP_DOXYGEN
const std::vector<Operation>& control_deps() const;
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 3656c0aecf..e2cc22af5d 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -58,7 +58,8 @@ class Scope::Impl {
enum class Colocate;
};
- Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
+ Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
+ bool disable_shape_inference);
Impl(const Scope& other, Tags::ScopeName, const string& name,
bool copy_names);
Impl(const Scope& other, Tags::OpName, const string& name,
@@ -103,6 +104,10 @@ class Scope::Impl {
const string kernel_label_ = "";
const string device_ = "";
const std::unordered_set<string> colocation_constraints_;
+
+ // If true, Scope::DoShapeInference() always returns Status:OK().
+ // TODO(skyewm): remove this when possible
+ const bool disable_shape_inference_;
};
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/test_op.cc b/tensorflow/cc/framework/test_op.cc
index fe0d907df0..b76842a9a0 100644
--- a/tensorflow/cc/framework/test_op.cc
+++ b/tensorflow/cc/framework/test_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
@@ -24,6 +25,7 @@ REGISTER_OP("ThrowAway1")
.Attr("scope: int")
.Attr("builder: int = 1")
.Attr("while: int")
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Op to test keywords and reserved words in input and attr names.
@@ -36,12 +38,20 @@ REGISTER_OP("ThrowAway2")
.Attr("scope: int = 2")
.Attr("throw_away2: int = 2")
.Attr("attrs: int = 4")
- .Attr("node: int = 4");
+ .Attr("node: int = 4")
+ .SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThrowAway3").Output("node: int32");
+REGISTER_OP("ThrowAway3")
+ .Output("node: int32")
+ .SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThrowAway4").Input("node: int32");
+REGISTER_OP("ThrowAway4")
+ .Input("node: int32")
+ .SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4");
+REGISTER_OP("ThrowAway5")
+ .Output("foo: int32")
+ .Attr("node: int = 4")
+ .SetShapeFn(shape_inference::UnknownShape);
} // namespace tensorflow