aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/function.h11
-rw-r--r--tensorflow/core/framework/op.cc5
-rw-r--r--tensorflow/core/framework/op.h82
-rw-r--r--tensorflow/core/framework/op_kernel.h14
-rw-r--r--tensorflow/core/framework/selective_registration.h55
5 files changed, 138 insertions, 29 deletions
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 9ce2d1fe8b..fcfcf162ad 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -366,16 +367,6 @@ class FunctionLibraryRuntime {
//
// TODO(zhifengc): Better documentation somewhere.
-#ifdef SELECTIVE_REGISTRATION
-// Experimental selective registration support to reduce binary size.
-// If kRequiresSymbolicGradients is false, then no gradient ops are registered
-// and their code will be stripped out during the link phase.
-#include "ops_to_register.h"
-#define SHOULD_REGISTER_OP_GRADIENT kRequiresSymbolicGradients
-#else
-#define SHOULD_REGISTER_OP_GRADIENT true
-#endif
-
// Macros to define a gradient function factory for a primitive
// operation.
#define REGISTER_OP_GRADIENT(name, fn) \
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
index f90b0a555a..06d45a40d2 100644
--- a/tensorflow/core/framework/op.cc
+++ b/tensorflow/core/framework/op.cc
@@ -176,9 +176,10 @@ const OpDef* OpListOpRegistry::LookUp(const string& op_type_name,
// Other registration ---------------------------------------------------------
namespace register_op {
-OpDefBuilderReceiver::OpDefBuilderReceiver(const OpDefBuilder& builder) {
+OpDefBuilderReceiver::OpDefBuilderReceiver(
+ const OpDefBuilderWrapper<true>& wrapper) {
OpDef op_def;
- builder.Finalize(&op_def);
+ wrapper.builder().Finalize(&op_def);
OpRegistry::Global()->Register(op_def);
}
} // namespace register_op
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index 1bc822e1d8..bf887982fd 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -166,20 +167,91 @@ extern "C" void RegisterOps(void* registry_ptr);
// For details, see the OpDefBuilder class in op_def_builder.h.
namespace register_op {
+
+// OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP
+// calls. This allows the result of REGISTER_OP to be used in chaining, as in
+// REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective
+// registration to turn the entire call-chain into a no-op.
+template <bool should_register>
+class OpDefBuilderWrapper;
+
+// Template specialization that forwards all calls to the contained builder.
+template <>
+class OpDefBuilderWrapper<true> {
+ public:
+ typedef OpDefBuilderWrapper<true> WrapperType;
+ OpDefBuilderWrapper(const char name[]) : builder_(name) {}
+ OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
+ builder_.Attr(spec);
+ return *this;
+ }
+ OpDefBuilderWrapper<true>& Input(StringPiece spec) {
+ builder_.Input(spec);
+ return *this;
+ }
+ OpDefBuilderWrapper<true>& Output(StringPiece spec) {
+ builder_.Output(spec);
+ return *this;
+ }
+ OpDefBuilderWrapper<true>& SetIsCommutative() {
+ builder_.SetIsCommutative();
+ return *this;
+ }
+ OpDefBuilderWrapper<true>& SetIsAggregate() {
+ builder_.SetIsAggregate();
+ return *this;
+ }
+ OpDefBuilderWrapper<true>& SetIsStateful() {
+ builder_.SetIsStateful();
+ return *this;
+ }
+ OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
+ builder_.SetAllowsUninitializedInput();
+ return *this;
+ }
+ OpDefBuilderWrapper<true>& Doc(StringPiece text) {
+ builder_.Doc(text);
+ return *this;
+ }
+ const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
+
+ private:
+ mutable ::tensorflow::OpDefBuilder builder_;
+};
+
+// Template specialization that turns all calls into no-ops.
+template <>
+class OpDefBuilderWrapper<false> {
+ public:
+ constexpr OpDefBuilderWrapper(const char name[]) {}
+ OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; }
+ OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; }
+ OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; }
+ OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; }
+ OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
+ OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
+ OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
+ OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
+};
+
struct OpDefBuilderReceiver {
// To call OpRegistry::Global()->Register(...), used by the
// REGISTER_OP macro below.
- // Note: This is an implicitly converting constructor.
+ // Note: These are implicitly converting constructors.
OpDefBuilderReceiver(
- const OpDefBuilder& builder); // NOLINT(runtime/explicit)
+ const OpDefBuilderWrapper<true>& wrapper); // NOLINT(runtime/explicit)
+ constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) {
+ } // NOLINT(runtime/explicit)
};
} // namespace register_op
#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
-#define REGISTER_OP_UNIQ(ctr, name) \
- static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \
- TF_ATTRIBUTE_UNUSED = ::tensorflow::OpDefBuilder(name)
+#define REGISTER_OP_UNIQ(ctr, name) \
+ static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \
+ TF_ATTRIBUTE_UNUSED = \
+ ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
+ name)>(name)
} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index edf3b8a415..46bae0a0b6 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -1022,17 +1023,6 @@ namespace register_kernel {
typedef ::tensorflow::KernelDefBuilder Name;
} // namespace register_kernel
-#ifdef SELECTIVE_REGISTRATION
-// Experimental selective registration support to reduce binary size.
-// Files which are not included in the whitelist provided by this
-// graph-specific header file will not be allowed to register their
-// operators, thus resulting in them being stripped out during the link phase.
-#include "ops_to_register.h"
-#define SHOULD_REGISTER_OP(filename) \
- (strstr(kNecessaryOpFiles, filename) != nullptr)
-#else
-#define SHOULD_REGISTER_OP(filename) true
-#endif
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
@@ -1043,7 +1033,7 @@ typedef ::tensorflow::KernelDefBuilder Name;
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \
- SHOULD_REGISTER_OP(__FILE__) \
+ SHOULD_REGISTER_OP_KERNEL(__FILE__) \
? ::tensorflow::register_kernel::kernel_builder.Build() \
: nullptr, \
[](::tensorflow::OpKernelConstruction* context) \
diff --git a/tensorflow/core/framework/selective_registration.h b/tensorflow/core/framework/selective_registration.h
new file mode 100644
index 0000000000..78fe5b7d9f
--- /dev/null
+++ b/tensorflow/core/framework/selective_registration.h
@@ -0,0 +1,55 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_
+#define TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_
+
+#include <string.h>
+
+#ifdef SELECTIVE_REGISTRATION
+
+// Experimental selective registration support to reduce binary size.
+//
+// To use selective registration, when building:
+// 1. define SELECTIVE_REGISTRATION, e.g. in gcc by passing
+// -DSELECTIVE_REGISTRATION to compilation.
+// 2. Provide ops_to_register.h. This file is not included in the repo and must
+// be placed by the user or a tool where the compiler can find it. It must
+// define the constants and functions used in the macros below. The
+// functions should be defined as valid constexpr functions, so that they are
+// evaluated at compile time: this is needed to make symbols referenced by
+// un-registered objects unused, and therefore allow the linker to strip them
+// out.
+#include "ops_to_register.h"
+
+// Files which are not included in the whitelist provided by this
+// graph-specific header file will not be allowed to register their
+// operator kernels.
+#define SHOULD_REGISTER_OP_KERNEL(filename) \
+ (strstr(kNecessaryOpFiles, filename) != nullptr)
+
+// Ops for which ShouldRegisterOp return false will no be registered.
+#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
+
+// If kRequiresSymbolicGradients is false, then no gradient ops are registered.
+#define SHOULD_REGISTER_OP_GRADIENT kRequiresSymbolicGradients
+
+#else
+#define SHOULD_REGISTER_OP_KERNEL(filename) true
+#define SHOULD_REGISTER_OP(op) true
+#define SHOULD_REGISTER_OP_GRADIENT true
+#endif
+
+#endif // TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_