diff options
-rw-r--r-- | tensorflow/core/framework/function.h | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/op.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/framework/op.h | 82 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 14 | ||||
-rw-r--r-- | tensorflow/core/framework/selective_registration.h | 55 |
5 files changed, 138 insertions, 29 deletions
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 9ce2d1fe8b..fcfcf162ad 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/selective_registration.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -366,16 +367,6 @@ class FunctionLibraryRuntime { // // TODO(zhifengc): Better documentation somewhere. -#ifdef SELECTIVE_REGISTRATION -// Experimental selective registration support to reduce binary size. -// If kRequiresSymbolicGradients is false, then no gradient ops are registered -// and their code will be stripped out during the link phase. -#include "ops_to_register.h" -#define SHOULD_REGISTER_OP_GRADIENT kRequiresSymbolicGradients -#else -#define SHOULD_REGISTER_OP_GRADIENT true -#endif - // Macros to define a gradient function factory for a primitive // operation. #define REGISTER_OP_GRADIENT(name, fn) \ diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index f90b0a555a..06d45a40d2 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -176,9 +176,10 @@ const OpDef* OpListOpRegistry::LookUp(const string& op_type_name, // Other registration --------------------------------------------------------- namespace register_op { -OpDefBuilderReceiver::OpDefBuilderReceiver(const OpDefBuilder& builder) { +OpDefBuilderReceiver::OpDefBuilderReceiver( + const OpDefBuilderWrapper<true>& wrapper) { OpDef op_def; - builder.Finalize(&op_def); + wrapper.builder().Finalize(&op_def); OpRegistry::Global()->Register(op_def); } } // namespace register_op diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index 1bc822e1d8..bf887982fd 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/selective_registration.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -166,20 +167,91 @@ extern "C" void RegisterOps(void* registry_ptr); // For details, see the OpDefBuilder class in op_def_builder.h. namespace register_op { + +// OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP +// calls. This allows the result of REGISTER_OP to be used in chaining, as in +// REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective +// registration to turn the entire call-chain into a no-op. +template <bool should_register> +class OpDefBuilderWrapper; + +// Template specialization that forwards all calls to the contained builder. +template <> +class OpDefBuilderWrapper<true> { + public: + typedef OpDefBuilderWrapper<true> WrapperType; + OpDefBuilderWrapper(const char name[]) : builder_(name) {} + OpDefBuilderWrapper<true>& Attr(StringPiece spec) { + builder_.Attr(spec); + return *this; + } + OpDefBuilderWrapper<true>& Input(StringPiece spec) { + builder_.Input(spec); + return *this; + } + OpDefBuilderWrapper<true>& Output(StringPiece spec) { + builder_.Output(spec); + return *this; + } + OpDefBuilderWrapper<true>& SetIsCommutative() { + builder_.SetIsCommutative(); + return *this; + } + OpDefBuilderWrapper<true>& SetIsAggregate() { + builder_.SetIsAggregate(); + return *this; + } + OpDefBuilderWrapper<true>& SetIsStateful() { + builder_.SetIsStateful(); + return *this; + } + OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() { + builder_.SetAllowsUninitializedInput(); + return *this; + } + OpDefBuilderWrapper<true>& Doc(StringPiece text) { + builder_.Doc(text); + return *this; + } + const ::tensorflow::OpDefBuilder& builder() const { return builder_; } + + private: + mutable ::tensorflow::OpDefBuilder builder_; +}; + +// Template specialization that turns all calls into no-ops. +template <> +class OpDefBuilderWrapper<false> { + public: + constexpr OpDefBuilderWrapper(const char name[]) {} + OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; } + OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; } + OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; } + OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; } + OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; } + OpDefBuilderWrapper<false>& SetIsStateful() { return *this; } + OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; } + OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; } +}; + struct OpDefBuilderReceiver { // To call OpRegistry::Global()->Register(...), used by the // REGISTER_OP macro below. - // Note: This is an implicitly converting constructor. + // Note: These are implicitly converting constructors. OpDefBuilderReceiver( - const OpDefBuilder& builder); // NOLINT(runtime/explicit) + const OpDefBuilderWrapper<true>& wrapper); // NOLINT(runtime/explicit) + constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) { + } // NOLINT(runtime/explicit) }; } // namespace register_op #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name) #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name) -#define REGISTER_OP_UNIQ(ctr, name) \ - static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \ - TF_ATTRIBUTE_UNUSED = ::tensorflow::OpDefBuilder(name) +#define REGISTER_OP_UNIQ(ctr, name) \ + static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \ + TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \ + name)>(name) } // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index edf3b8a415..46bae0a0b6 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/selective_registration.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -1022,17 +1023,6 @@ namespace register_kernel { typedef ::tensorflow::KernelDefBuilder Name; } // namespace register_kernel -#ifdef SELECTIVE_REGISTRATION -// Experimental selective registration support to reduce binary size. -// Files which are not included in the whitelist provided by this -// graph-specific header file will not be allowed to register their -// operators, thus resulting in them being stripped out during the link phase. -#include "ops_to_register.h" -#define SHOULD_REGISTER_OP(filename) \ - (strstr(kNecessaryOpFiles, filename) != nullptr) -#else -#define SHOULD_REGISTER_OP(filename) true -#endif #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__) @@ -1043,7 +1033,7 @@ typedef ::tensorflow::KernelDefBuilder Name; #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ static ::tensorflow::kernel_factory::OpKernelRegistrar \ registrar__body__##ctr##__object( \ - SHOULD_REGISTER_OP(__FILE__) \ + SHOULD_REGISTER_OP_KERNEL(__FILE__) \ ? ::tensorflow::register_kernel::kernel_builder.Build() \ : nullptr, \ [](::tensorflow::OpKernelConstruction* context) \ diff --git a/tensorflow/core/framework/selective_registration.h b/tensorflow/core/framework/selective_registration.h new file mode 100644 index 0000000000..78fe5b7d9f --- /dev/null +++ b/tensorflow/core/framework/selective_registration.h @@ -0,0 +1,55 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_ +#define TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_ + +#include <string.h> + +#ifdef SELECTIVE_REGISTRATION + +// Experimental selective registration support to reduce binary size. +// +// To use selective registration, when building: +// 1. define SELECTIVE_REGISTRATION, e.g. in gcc by passing +// -DSELECTIVE_REGISTRATION to compilation. +// 2. Provide ops_to_register.h. This file is not included in the repo and must +// be placed by the user or a tool where the compiler can find it. It must +// define the constants and functions used in the macros below. The +// functions should be defined as valid constexpr functions, so that they are +// evaluated at compile time: this is needed to make symbols referenced by +// un-registered objects unused, and therefore allow the linker to strip them +// out. +#include "ops_to_register.h" + +// Files which are not included in the whitelist provided by this +// graph-specific header file will not be allowed to register their +// operator kernels. +#define SHOULD_REGISTER_OP_KERNEL(filename) \ + (strstr(kNecessaryOpFiles, filename) != nullptr) + +// Ops for which ShouldRegisterOp return false will no be registered. +#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op) + +// If kRequiresSymbolicGradients is false, then no gradient ops are registered. +#define SHOULD_REGISTER_OP_GRADIENT kRequiresSymbolicGradients + +#else +#define SHOULD_REGISTER_OP_KERNEL(filename) true +#define SHOULD_REGISTER_OP(op) true +#define SHOULD_REGISTER_OP_GRADIENT true +#endif + +#endif // TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_ |