author    | avijit-nervana <avijit.chakraborty@intel.com> | 2018-09-07 18:54:26 -0700
committer | avijit-nervana <avijit.chakraborty@intel.com> | 2018-09-07 18:54:26 -0700
commit    | d9a738d5fff96ecb6db62d67e049ab12202dcb42 (patch)
tree      | 292539c9ca4036ea55ae4763d3029f32829c9722
parent    | 18b80bbd4b8db8bd35afad7264258c1c5c269226 (diff)
parent    | 3e1b06ee93d7a638db1fdd5f733d66064c1acf59 (diff)
Merge branch 'master' into avijit/add-cpu-backend
474 files changed, 10937 insertions, 4827 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 661cba5ff0..386e0096ff 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -12,6 +12,7 @@ exports_files([ # The leakr files are used by //third_party/cloud_tpu. "leakr_badwords.dic", "leakr_badfiles.dic", + "leakr_file_type_recipe.ftrcp", ]) load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") @@ -23,6 +24,11 @@ load( "//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files", # @unused ) +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files") +load( + "//tensorflow/python/tools/api/generator:api_init_files.bzl", + "TENSORFLOW_API_INIT_FILES", # @unused +) load( "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "TENSORFLOW_API_INIT_FILES_V1", # @unused @@ -32,6 +38,11 @@ load( "if_ngraph", ) +# @unused +TENSORFLOW_API_INIT_FILES_V2 = ( + TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) + # Config setting used when building for products # which requires restricted licenses to be avoided. config_setting( @@ -427,6 +438,13 @@ config_setting( visibility = ["//visibility:public"], ) +# This flag specifies whether TensorFlow 2.0 API should be built instead +# of 1.* API. Note that TensorFlow 2.0 API is currently under development. +config_setting( + name = "api_version_2", + define_values = {"tf_api_version": "2"}, +) + package_group( name = "internal", packages = [ @@ -591,13 +609,39 @@ exports_files( ) gen_api_init_files( - name = "tensorflow_python_api_gen", + name = "tf_python_api_gen_v1", srcs = ["api_template.__init__.py"], api_version = 1, + output_dir = "_api/v1/", output_files = TENSORFLOW_API_INIT_FILES_V1, + output_package = "tensorflow._api.v1", + root_init_template = "api_template.__init__.py", +) + +gen_api_init_files( + name = "tf_python_api_gen_v2", + srcs = ["api_template.__init__.py"], + api_version = 2, + compat_api_versions = [1], + output_dir = "_api/v2/", + output_files = TENSORFLOW_API_INIT_FILES_V2, + output_package = "tensorflow._api.v2", root_init_template = "api_template.__init__.py", ) +genrule( + name = "root_init_gen", + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }), + outs = ["__init__.py"], + cmd = select({ + "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + }), +) + py_library( name = "tensorflow_py", srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], @@ -612,7 +656,10 @@ py_library( py_library( name = "tensorflow_py_no_contrib", - srcs = [":tensorflow_python_api_gen"], + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }) + [":root_init_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 779f65d5b1..53a72b8443 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -18,11 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os as _os + # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import try: - import os # pylint: disable=g-import-not-at-top # Add `estimator` attribute to allow access to estimator APIs via # "tf.estimator..." 
from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top @@ -30,9 +31,8 @@ try: # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." # style imports. from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top - __path__ += [os.path.dirname(estimator_api.__file__)] + __path__ += [_os.path.dirname(estimator_api.__file__)] del estimator_api - del os except (ImportError, AttributeError): print('tf.estimator package not installed.') @@ -45,6 +45,12 @@ del LazyLoader from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +if _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) + del absolute_import del division del print_function @@ -54,6 +60,12 @@ del print_function # must come from this module. So python adds these symbols for the # resolution to succeed. # pylint: disable=undefined-variable -del python -del core +try: + del python + del core +except NameError: + # Don't fail if these modules are not available. + # For e.g. we are using this file for compat.v1 module as well and + # 'python', 'core' directories are not under compat/v1. + pass # pylint: enable=undefined-variable diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 109b3b37aa..43c279bd80 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -204,6 +204,7 @@ tf_cuda_cc_test( "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], }), + tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 69b3ffe2a1..c046bd66cd 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -79,6 +79,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, auto* gpu_options = config.mutable_gpu_options(); gpu_options->set_allow_growth(gpu_memory_allow_growth); + // TODO(b/113217601): This is needed for EagerContext::runner_ to use a + // threadpool, so that we avoid the possibility of running the runner_ in the + // threadpool of GPU event mgr, as that can trigger more callbacks to be + // scheduled on that same threadpool, causing a deadlock in cases where the + // caller of event_mgr->ThenExecute() blocks on the completion of the callback + // (as in the case of ConstOp kernel creation on GPU, which involves copying a + // CPU tensor to GPU). + // Setting a larger thread pool does not help with the Swift caller, as we use + // a different TFE context for each thread of execution (for running graph + // functions, and their send/recvs corountines). 
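An editorial aside before the hunk continues with the `config.set_inter_op_parallelism_threads(1)` line itself: this is roughly how a host program would consume `TF_CreateConfig`. A minimal sketch using only APIs visible in this diff; the function name and the bare-bones error handling are illustrative, not part of the commit.

```cpp
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"

// Builds an eager context whose ConfigProto came from TF_CreateConfig, and
// therefore carries inter_op_parallelism_threads = 1 as added in this hunk.
TFE_Context* MakeEagerContextFromConfig(TF_Status* status) {
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/0,
                                      /*gpu_memory_allow_growth=*/1);
  TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
  TFE_Context* ctx = nullptr;
  if (TF_GetCode(status) == TF_OK) {
    ctx = TFE_NewContext(opts, status);
  }
  TF_DeleteBuffer(config);
  TFE_DeleteContextOptions(opts);
  return ctx;
}
```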
+ config.set_inter_op_parallelism_threads(1); + TF_Buffer* ret = TF_NewBuffer(); TF_CHECK_OK(MessageToBuffer(config, ret)); return ret; @@ -8494,3 +8506,201 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, /*run_metadata*/ nullptr, status); VLOG(1) << "Enqueuing is done."; } + +TFE_Context* TFE_CreateContextFromSession(TF_Session* session, + TF_Status* status) { + auto* opts = TFE_NewContextOptions(); + + // Reduce GPU memory allocation, and set appropriate config options for TFE + // context. + auto* config = + TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true); + TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); + if (!status->status.ok()) { + CHECK(!config); + TFE_DeleteContextOptions(opts); + return nullptr; + } + + auto* ctx = TFE_NewContextFromSession(opts, session, status); + TF_DeleteBuffer(config); + TFE_DeleteContextOptions(opts); + return ctx; +} + +// TODO: retrieve the device string via TFE_ContextListDevices() +static const char DEFAULT_CPU_DEVICE[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + +static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType, + int tensor_id, TF_Status* status) { + std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp( + TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp); + TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler. + TFE_OpSetAttrInt(queueOp.get(), "capacity", 1); + TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1); + auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id); + TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(), + shared_name.size()); + TFE_OpSetAttrString(queueOp.get(), "container", "", 0); + + // TODO: consider making this an unknown shape. 
+ const int64_t* dims_ptr = nullptr; + int num_dims = 0; + TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims, + /*num_values*/ 0, status); + if (!status->status.ok()) return nullptr; + + int num_retvals = 1; + TFE_TensorHandle* queue = nullptr; + TFE_Execute(queueOp.get(), &queue, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + + return queue; +} + +static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType, + TFE_TensorHandle* queue, TFE_TensorHandle* tensor, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, tensor, status); + if (!status->status.ok()) return; + TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + + int num_retvals = 0; + TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status); + if (!status->status.ok()) return; + CHECK_EQ(num_retvals, 0); +} + +static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx, + TF_DataType inputType, + TFE_TensorHandle* queue, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + + TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return nullptr; + TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + TFE_TensorHandle* ret; + int num_retvals = 1; + TFE_Execute(op, &ret, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + return ret; +} + +TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + assert(session); + VLOG(1) << "Dequeuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + return ret; +} + +TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + + return ret; +} + +void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + assert(session); + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_Context, 
decltype(&TFE_DeleteContext)> ctx_deleter( + ctx, TFE_DeleteContext); + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status) { + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + VLOG(1) << "Enqueuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status); +} + +TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, + TF_Status* status) { + VLOG(1) << "Dequeuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + return createTFEDequeue(ctx, TF_VARIANT, queue, status); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 09d482d6df..522c91f67e 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -132,9 +132,48 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, TF_Tensor* tensor, TF_Status* status); +// TODO: remove this API in favor of the next one. TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); +// Creates from `session` a new eager context to run a graph function or +// sends/recvs, so that these concurrent TFE executions can share (via +// `session` and its associated device mgr) the same set of fifo queue resource +// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and +// graph function execution can access the same fifo queue resource handles +// (associated with devices managed by the device manager, which can be obtained +// from `session`). +// +// TODO: Remove this function once we migrate away from using session. 
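The header comment above explains the design choice behind `TFE_CreateContextFromSession` (declared just below): contexts built from the same session share its device manager, so every transfer resolves to the same FIFO queue resources. The `*FromCtx` variants exist so a caller can create the context once and reuse it rather than paying for a new context per call. A hedged sketch of that usage, assuming a producer has already enqueued a float tensor under id 1:

```cpp
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"

void ConsumeWithSharedContext(TF_Session* session) {
  TF_Status* status = TF_NewStatus();
  TFE_Context* ctx = TFE_CreateContextFromSession(session, status);
  if (TF_GetCode(status) == TF_OK) {
    // One context for all transfers; the queue resources live with the
    // session's device manager, so tensor ids resolve to the same queues.
    TFE_TensorHandle* recv =
        TFE_DequeueNamedTensorFromCtx(ctx, /*tensor_id=*/1, TF_FLOAT, status);
    if (TF_GetCode(status) == TF_OK) TFE_DeleteTensorHandle(recv);
    TFE_DeleteContext(ctx);
  }
  TF_DeleteStatus(status);
}
```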
+TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession( + TF_Session* session, TF_Status* status); + +// TODO: Retire this API in favor of the next one. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor( + TF_Session* session, int tensor_id, TF_DataType inputType, + TF_Status* status); + +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx( + TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session, + int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx( + TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor, + TF_Status* status); + +// TODO: consider folding the 2 APIs below into the ones above. +TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session, + int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status); + +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( + TF_Session* session, int tensor_id, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 77e3878a94..349d9bcd7c 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -399,6 +399,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { : d->name().c_str(); } +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + + h->handle->Ref(); + + return new TFE_TensorHandle(h->handle); +} + TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index eec2750d6e..337447eec9 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -171,6 +171,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor +// with `h`. On success, `status` is set to OK. On failure, `status` reflects +// the error and a nullptr is returned. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status); + // This function will block till the operation that produces `h` has // completed. The memory returned might alias the internal memory used by // TensorFlow. 
Hence, callers should not mutate this memory (for example by diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 7126227cf5..55331022b9 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1528,4 +1528,29 @@ TEST(CAPI, StringAttributes) { TFE_DeleteContext(ctx); TF_DeleteStatus(status); } + +TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { + TFE_TensorHandle* h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); + + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( + TF_NewStatus(), TF_DeleteStatus); + + TFE_TensorHandle* h_shares_tensor = + TFE_TensorHandleCopySharingTensor(h, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get()); + ASSERT_EQ(16, TF_TensorByteSize(t)); + float data[4] = {0}; + memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(1.0, data[0]); + EXPECT_EQ(2.0, data[1]); + EXPECT_EQ(3.0, data[2]); + EXPECT_EQ(4.0, data[3]); + TF_DeleteTensor(t); + + TFE_DeleteTensorHandle(h); + TFE_DeleteTensorHandle(h_shares_tensor); +} } // namespace diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index cf5c04ac4b..bd270045e3 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -20,6 +20,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index b95b063348..1c9d30d7b0 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -92,9 +93,8 @@ Status Main(const MainFlags& flags) { // Write output files. 
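The test above pins down the sharing semantics of `TFE_TensorHandleCopySharingTensor`: the implementation calls `Ref()` on the underlying handle, so the copy owns its own reference. A minimal lifetime sketch distilled from that test; deleting the original handle early (which the test itself does not do) is an extrapolation from the ref-counting, not something this commit asserts.

```cpp
TF_Status* s = TF_NewStatus();
TFE_TensorHandle* h = TestMatrixTensorHandle();  // test helper from above
TFE_TensorHandle* alias = TFE_TensorHandleCopySharingTensor(h, s);
CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TFE_DeleteTensorHandle(h);  // alias still holds a reference
TF_Tensor* t = TFE_TensorHandleResolve(alias, s);  // tensor is still alive
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(alias);
TF_DeleteStatus(s);
```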
Env* env = Env::Default(); const std::vector<char>& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR( - WriteStringToFile(env, flags.out_function_object, - absl::string_view(obj.data(), obj.size()))); + TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, + StringPiece(obj.data(), obj.size()))); CodegenOpts codegen_opts; codegen_opts.gen_name_to_index = flags.gen_name_to_index; codegen_opts.gen_program_shape = flags.gen_program_shape; diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 5b6692f523..07c5b23188 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -29,18 +29,6 @@ cc_library( ) cc_library( - name = "parallel_check_op_flags", - srcs = ["parallel_check_op_flags.cc"], - hdrs = ["parallel_check_op_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( name = "xla_device_flags", srcs = ["xla_device_flags.cc"], hdrs = ["xla_device_flags.h"], diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc deleted file mode 100644 index a61694b494..0000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's parallel_check_op module. - -#include <mutex> -#include <vector> - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static ParallelCheckOpFlags* flags; -static std::vector<Flag>* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new ParallelCheckOpFlags; - flags->parallel_check_failfast = true; - flags->parallel_check_atol = "1e-5"; - flags->parallel_check_rtol = "1e-5"; - flag_list = new std::vector<Flag>({ - Flag("parallel_check_failfast", &flags->parallel_check_failfast, - "Fail immediately on first parallel-check comparison error."), - Flag("parallel_check_atol", &flags->parallel_check_atol, - "Absolute error tolerance for parallel-check comparison."), - Flag("parallel_check_rtol", &flags->parallel_check_rtol, - "Relative error tolerance for parallel-check comparison."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// parallel_check_op module. 
-void AppendParallelCheckOpFlags(std::vector<Flag>* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ParallelCheckOpFlags* GetParallelCheckOpFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h deleted file mode 100644 index 156a2a2a71..0000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ - -// Legacy flags for the XLA bridge's parallel_check_op module. - -#include <vector> - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// parallel_check_op module. -void AppendParallelCheckOpFlags(std::vector<tensorflow::Flag>* flag_list); - -// The values of flags associated with the XLA bridge's -// parallel_check_op module. -typedef struct { - bool parallel_check_failfast; // Fail immediately on first parallel-check - // comparison error. - string parallel_check_atol; // Absolute error tolerance for parallel-check - // comparison. - string parallel_check_rtol; // Relative error tolerance for parallel-check - // comparison. -} ParallelCheckOpFlags; - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
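For readers skimming the deletion above: the removed module was an instance of the usual once-only lazy-initialization idiom, which is worth a generic rendering since the same shape recurs throughout the legacy_flags directory. Names here are illustrative, not TensorFlow's:

```cpp
#include <mutex>
#include <string>

struct CheckFlags {
  bool failfast = true;
  std::string atol = "1e-5";
  std::string rtol = "1e-5";
};

static CheckFlags* flags = nullptr;
static std::once_flag flags_init;

// Repeated calls return the same pointer; allocation runs exactly once, even
// if the first calls race on multiple threads.
CheckFlags* GetCheckFlags() {
  std::call_once(flags_init, [] { flags = new CheckFlags; });
  return flags;
}
```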
-ParallelCheckOpFlags* GetParallelCheckOpFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 9473ac0a4c..807ab51fd3 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,7 +633,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); { - auto BuildNoopNode = [](absl::string_view name, Graph* graph) { + auto BuildNoopNode = [](StringPiece name, Graph* graph) { NodeDefBuilder builder(name, "NoOp"); NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 17ae510a0e..debd9038c7 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index af83c792e5..6d4160a968 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -339,11 +339,11 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, + StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { - manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, - done); + manager_.CopyDeviceTensorToCPU(device_tensor, absl::string_view(tensor_name), + device, cpu_tensor, done); } void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index df82421294..1effd6628f 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { @@ -110,9 +111,12 @@ class XlaDeviceContext : public DeviceContext { void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const override; + // TODO(rlahaye): Replace StringPiece with absl::string_view when the + // StringPiece->absl::string_view change is rolled forward. 
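Several hunks in this commit, like the `CopyDeviceTensorToCPU` signature below and the resource-op table further on, temporarily trade `absl::string_view` back for `StringPiece` (marked "non-ABSL OK") until the rename is rolled forward. Both are (pointer, length) views, so bridging them costs nothing; a short illustration, with function names of my choosing:

```cpp
#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/core/stringpiece.h"

// Neither direction copies the underlying characters.
absl::string_view ToAbsl(tensorflow::StringPiece sp) {
  return absl::string_view(sp.data(), sp.size());
}

tensorflow::StringPiece ToStringPiece(absl::string_view sv) {
  return tensorflow::StringPiece(sv.data(), sv.size());
}
```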
void CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, Device* device, - Tensor* cpu_tensor, StatusCallback done) override; + StringPiece tensor_name, // non-ABSL OK + Device* device, Tensor* cpu_tensor, + StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 22be7f048f..3821dced63 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -191,6 +191,7 @@ cc_library( ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", + ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", @@ -214,6 +215,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -359,6 +361,7 @@ tf_cc_test( name = "xla_compiler_test", srcs = ["xla_compiler_test.cc"], deps = [ + ":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", @@ -370,6 +373,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", @@ -631,3 +635,12 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "side_effect_util", + srcs = ["side_effect_util.cc"], + hdrs = ["side_effect_util.h"], + deps = [ + "//tensorflow/core:core_cpu", + ], +) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 4c776fb178..46794f7b50 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -115,9 +115,6 @@ tf_kernel_library( deps = [ ":if_op", ":while_op", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", @@ -168,14 +165,11 @@ tf_kernel_library( "//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:training_ops", - ] + if_mkl( - [ - "//tensorflow/core/kernels:mkl_transpose_op", - ], - [ - "//tensorflow/core/kernels:transpose_op", - ], - ), + "//tensorflow/core/kernels:transpose_op", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], ) tf_kernel_library( @@ -184,6 +178,7 @@ tf_kernel_library( hdrs = ["while_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", @@ -201,6 +196,7 @@ tf_kernel_library( hdrs = ["if_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc 
index 6e1dbf5472..56da50f140 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/if_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } // TODO(b/35949885): There is duplication here with the handling of the @@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { options.resolve_compile_time_constants = false; options.return_updated_values_for_all_resources = true; options.is_entry_computation = false; + options.add_token_input_output = has_token_input_output_; XlaCompiler* compiler = ctx->compiler(); XlaCompiler::CompilationResult then_result; @@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { std::vector<xla::XlaOp> inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = then_result.input_mapping[i] + 1; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "if" op. + std::vector<xla::XlaOp> token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(b, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); @@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } ctx->SetOutput(i, output_handle); } + if (has_token_input_output_) { + // Set token output for this "if" op. + xla::XlaOp token_output = + xla::GetTupleElement(outputs, output_types_.size()); + auto shape_or = b->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the conditional // bodies. 
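The token plumbing just added to the "if" kernel (and mirrored for "while" below) follows one pattern: look up the token of every declared dependency, merge them with `xla::AfterAll`, feed the merged token in as a trailing operand, and afterwards register the token the computation produced. Condensed from the hunk above, with `ctx`, `compiler`, `b`, and `token_input_nodes_` as defined there; this is a fragment of kernel code, not a standalone program.

```cpp
// Merge the dependencies' tokens into one ordering point.
std::vector<xla::XlaOp> token_inputs;
for (const string& node_name : token_input_nodes_) {
  auto token_or = compiler->GetNodeToken(node_name);
  OP_REQUIRES_OK(ctx, token_or.status());
  token_inputs.push_back(token_or.ValueOrDie());
}
xla::XlaOp token_input = xla::AfterAll(b, token_inputs);
// ... the conditional runs with token_input as its last operand; its result
// tuple carries a trailing token, extracted as token_output in the hunk ...
OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
```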
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index f9bc98a198..7783e13a8a 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel { DataType cond_type_; DataTypeVector input_types_; DataTypeVector output_types_; + bool has_token_input_output_; + std::vector<string> token_input_nodes_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 296518229e..559414eeaa 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/while_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { cond_name_attr_ = *name_attr; OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr)); body_name_attr_ = *name_attr; + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { @@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { body_options.return_updated_values_for_all_resources = true; body_options.resolve_compile_time_constants = false; body_options.is_entry_computation = false; + body_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult body; OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, arguments, &body)); @@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { cond_options.use_tuple_arg = true; cond_options.resolve_compile_time_constants = false; cond_options.is_entry_computation = false; + cond_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult cond; OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); @@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { std::vector<xla::XlaOp> inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = body.input_mapping[i]; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "while" op. + std::vector<xla::XlaOp> token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(builder, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder)); @@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(while_result, i)); } } + if (has_token_input_output_) { + // Set token output for this "while" op. 
+ xla::XlaOp token_output = + xla::GetTupleElement(while_result, ctx->num_outputs()); + auto shape_or = builder->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the loop. for (int i = 0; i < body.resource_updates.size(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 67edebabf9..aeeff40e68 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel { private: NameAttrList cond_name_attr_; NameAttrList body_name_attr_; + bool has_token_input_output_; + std::vector<string> token_input_nodes_; TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); }; diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 20f2ce2919..92577b5bc8 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { @@ -30,11 +31,10 @@ namespace tensorflow { } } -static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* -CreateResourceOpInfoMap() { - auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>; +static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() { + auto* result = new gtl::FlatMap<StringPiece, XlaResourceOpInfo>; - auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, + auto add = [&](StringPiece op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { auto insert_result = result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); @@ -103,17 +103,17 @@ CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& +static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& GetStaticResourceOpInfoMap() { - static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map = + static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { - const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos = + const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos = GetStaticResourceOpInfoMap(); - auto it = op_infos.find(op); + auto it = op_infos.find(StringPiece(op.data(), op.length())); return it == op_infos.end() ? 
nullptr : &it->second; } @@ -121,7 +121,7 @@ namespace resource_op_table_internal { std::vector<absl::string_view> GetKnownResourceOps() { std::vector<absl::string_view> result; for (const auto& p : GetStaticResourceOpInfoMap()) { - result.push_back(p.first); + result.push_back(absl::string_view(p.first)); } absl::c_sort(result); return result; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc new file mode 100644 index 0000000000..6cd7b24592 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/side_effect_util.h" + +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; + +const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; + +std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) { + std::set<std::string> results; + Node* first_side_effecting_node_on_path = nullptr; + ReverseDFS(g, + [&](Node* n) { + std::vector<string> token_input_nodes; + if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, + &token_input_nodes) + .ok() || + token_input_nodes.empty()) { + return; + } + + if (first_side_effecting_node_on_path != nullptr) { + return; + } + + first_side_effecting_node_on_path = n; + results.insert(n->name()); + }, + [&](Node* n) { + if (first_side_effecting_node_on_path == n) { + first_side_effecting_node_on_path = nullptr; + } + }, + NodeComparatorName()); + return results; +} + +bool HasSideEffectingNodes(const Graph& g) { + for (Node* n : g.nodes()) { + std::vector<string> token_input_nodes; + if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes) + .ok() && + !token_input_nodes.empty()) { + return true; + } + } + return false; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h new file mode 100644 index 0000000000..ad07624729 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ + +#include <vector> + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Side-effecting nodes will have this attribute set. Its value is the list of +// node names which this node has side-effect dependencies on. +// +// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute, +// because they always have side-effect. +// If and While nodes may or may not have this attribute, depending on whether +// their bodies have side-effecting nodes. +extern const char kXlaTokenInputNodesAttrName[]; + +// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a +// node has side-effect dependency on current graph's token input. +extern const char kXlaTokenArgNodeName[]; + +// Calculates side-effect dependencies for the graph's token output. +// Returns a set of node names representing these dependencies. +std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g); + +// Returns whether a graph contains side-effecting nodes. +bool HasSideEffectingNodes(const Graph& g); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index a29e764466..dcddef8418 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -18,6 +18,7 @@ limitations under the License. #include <unordered_map> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 41d305d461..dcb455779d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" @@ -291,6 +292,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, "Invalid resource type in XLAShapeForArgument()"); } } + case XlaCompiler::Argument::kToken: { + *xla_shape = xla::ShapeUtil::MakeTokenShape(); + return Status::OK(); + } case XlaCompiler::Argument::kInvalid: return errors::Internal("Invalid argument type in XLAShapeForArgument()"); } @@ -489,7 +494,8 @@ Status XlaCompiler::BuildArguments( } break; - case XlaCompiler::Argument::kParameter: { + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kToken: { input_mapping->push_back(i); break; } @@ -616,6 +622,10 @@ Status XlaCompiler::BuildArguments( arg_expression.set_handle(arg_handles[i]); } break; + case XlaCompiler::Argument::kToken: { + arg_expression.set_handle(arg_handles[i]); + break; + } case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: return errors::Internal( @@ -757,23 +767,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, &options_.shape_representation_fn); core::ScopedUnref context_unref(context); + std::vector<XlaCompiler::Argument> real_args(args); + int token_input_index = -1; + if (options.add_token_input_output) { + // Add extra token input. + token_input_index = real_args.size(); + + XlaCompiler::Argument token_arg; + token_arg.kind = XlaCompiler::Argument::kToken; + real_args.push_back(token_arg); + } + std::vector<XlaExpression> arg_expressions; std::vector<int> arg_cores; - TF_RETURN_IF_ERROR( - BuildArguments(*graph, args, options.use_tuple_arg, &builder, context, - &arg_cores, &arg_expressions, &result->input_mapping, - &result->xla_input_shapes, options.is_entry_computation)); + TF_RETURN_IF_ERROR(BuildArguments( + *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + &arg_expressions, &result->input_mapping, &result->xla_input_shapes, + options.is_entry_computation)); context->set_args(std::move(arg_expressions)); + PushNodeTokenMapping(); + // Use std::set instead of std::unordered_set to ensure determinism. + std::set<std::string> output_node_token_inputs; + if (token_input_index != -1) { + // Original token comes from input. + auto arg_expression = context->args()[token_input_index]; + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle())); + + // Calculate token inputs for output token. + output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph); + + // If there's no side-effecting op in the graph, use token input as token + // output. + if (output_node_token_inputs.empty()) { + output_node_token_inputs.insert(kXlaTokenArgNodeName); + } + } else if (options.is_entry_computation) { + // Original token is manually created. + if (HasSideEffectingNodes(*graph)) { + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder))); + } + } + TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, flib_runtime_, NextStepId())); + if (token_input_index != -1) { + // Add extra token output. 
+ std::vector<xla::XlaOp> token_inputs; + for (const auto& node_name : output_node_token_inputs) { + auto token_or = GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + TF_RETURN_IF_ERROR( + context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs))); + } + TF_RETURN_IF_ERROR(PopNodeTokenMapping()); int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared<xla::XlaComputation>(); result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( - args, arg_cores, context->retvals(), context->resources(), + real_args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, @@ -912,4 +970,47 @@ Status XlaCompiler::SetHostComputeControlDependency( return Status::OK(); } +void XlaCompiler::PushNodeTokenMapping() { + node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{}); +} + +Status XlaCompiler::PopNodeTokenMapping() { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is " + "empty."); + } + node_token_mapping_stack_.pop(); + return Status::OK(); +} + +Status XlaCompiler::SetNodeToken(const string& node_name, + const xla::XlaOp& op) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling SetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto insert_result = node_token_mapping_stack_.top().insert({node_name, op}); + if (!insert_result.second) { + return errors::FailedPrecondition("Token mapping already exists for node ", + node_name); + } + return Status::OK(); +} + +xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling GetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto iter = node_token_mapping_stack_.top().find(node_name); + if (iter == node_token_mapping_stack_.top().end()) { + return errors::FailedPrecondition("Cannot find token mapping for node ", + node_name); + } + return iter->second; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 8f4a9858ed..2cc603a580 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include <stack> + #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -106,6 +109,9 @@ class XlaCompiler { // Argument is a run-time parameter. kParameter, + + // Argument is an XLA token. 
+ kToken, }; Kind kind = kInvalid; @@ -179,6 +185,9 @@ class XlaCompiler { // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; + + // True when we should add XLA input & output to the graph/function. + bool add_token_input_output = false; }; struct OutputDescription { @@ -384,6 +393,11 @@ class XlaCompiler { xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } + void PushNodeTokenMapping(); + Status PopNodeTokenMapping(); + Status SetNodeToken(const string& node_name, const xla::XlaOp& op); + xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name); + private: // Sets the function body `fbody` to the one registered as `function`. Status FindFunctionBody(const NameAttrList& function, @@ -448,6 +462,15 @@ class XlaCompiler { std::unordered_map<string, xla::XlaOp> host_compute_control_output_; + // This is used to store <node name, token output> mapping. Side-effecting + // ops call SetNodeToken() to record its token output, so later side-effecting + // ops can use GetNodeToken() to get it and use it as token input. + // + // It's a stack because we need a mapping like this for each level of nested + // CompileGraph() call. In CompileGraph(), we will push a new mapping to the + // stack, and pop the mapping before returning. + std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index be3c93ae47..40ce9fb41c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -20,10 +20,12 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -32,6 +34,7 @@ limitations under the License. 
#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -1274,5 +1277,70 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { } } +class DummySideEffectingOp : public XlaOpKernel { + public: + explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken( + name(), xla::CreateToken(ctx->builder()))); + } +}; + +REGISTER_OP("DummySideEffectingOp"); + +REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp); + +TEST_F(XlaCompilerTest, TokenInputAndOutput) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + NodeDef side_effecting_op; + side_effecting_op.set_name("DummySideEffectingOp"); + side_effecting_op.set_op("DummySideEffectingOp"); + AddNodeAttr(kXlaTokenInputNodesAttrName, + std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op); + Status status; + graph->AddNode(side_effecting_op, &status); + TF_ASSERT_OK(status); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get())); + + const std::vector<XlaCompiler::Argument> empty_args; + { + // The case for entry computation: we don't add token input/output. Instead, + // we use CreateToken HLO to create the entry token. + XlaCompiler::CompileOptions options; + options.is_entry_computation = true; + options.add_token_input_output = false; + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 0); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0); + } + { + // The case for non-entry computation (e.g. while loop body). We add token + // input/output. 
+ XlaCompiler::CompileOptions options; + options.is_entry_computation = false; + options.add_token_input_output = true; + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0])); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken( + xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0))); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index e8b4b0eb36..f247570d72 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -119,6 +119,17 @@ Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { return Status::OK(); } +Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) { + VLOG(1) << "Adding retval index " << retvals_.size() + << " with token to XLA computation"; + XlaExpression e; + e.set_handle(token); + // We use DT_INVALID because there is no TF DataType which corresponds to an + // XLA token. XlaCompiler handles this case separately, so putting it here is OK. + retvals_.push_back(Retval{DT_INVALID, TensorShape(), e}); + return Status::OK(); +} + xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 4da891634e..d7dbdc957f 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -89,6 +89,9 @@ class XlaContext : public ResourceBase { // As for Retval, but for return values that are resource handles. Status AddResourceRetval(int retval_index, XlaResource* resource); + // As for Retval, but for return values that are XLA tokens. + Status AppendTokenRetval(const xla::XlaOp& token); + // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` // constructor for a description of the remaining arguments.
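Taken together, the methods above give each CompileGraph() invocation its own map from side-effecting node names to XLA tokens. A minimal sketch of the intended call sequence, assuming a live XlaCompiler and xla::XlaBuilder; the node name "send1" and the helper DemoNodeTokenMapping are hypothetical, not part of this change:

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace tensorflow {

Status DemoNodeTokenMapping(XlaCompiler* compiler, xla::XlaBuilder* builder) {
  // One <node name, token> map per (possibly nested) CompileGraph() level.
  compiler->PushNodeTokenMapping();
  // A side-effecting op records the token it produced under its node name.
  TF_RETURN_IF_ERROR(
      compiler->SetNodeToken("send1", xla::CreateToken(builder)));
  // A later op whose kXlaTokenInputNodesAttrName attribute lists "send1"
  // fetches that token to sequence itself after the first op.
  TF_ASSIGN_OR_RETURN(xla::XlaOp token, compiler->GetNodeToken("send1"));
  // Multiple token inputs are joined with AfterAll, as CompileGraph() does
  // above before calling AppendTokenRetval().
  xla::XlaOp joined = xla::AfterAll(builder, {token});
  (void)joined;
  // The map is popped before CompileGraph() returns.
  return compiler->PopNodeTokenMapping();
}

}  // namespace tensorflow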
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index d67e50375b..636cb71e21 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -102,7 +102,8 @@ Status XlaOpKernelContext::ConstantInput(int index, static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context, absl::string_view name) { int start, stop; - TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(context->op_kernel().InputRange( + StringPiece(name.data(), name.length()), &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued input name '", name, @@ -365,7 +366,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name, std::vector<xla::XlaOp>* handles, std::vector<TensorShape>* shapes) { OpInputList inputs; - TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); + TF_RETURN_IF_ERROR( + context_->input_list(StringPiece(name.data(), name.size()), &inputs)); handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { @@ -378,7 +380,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name, Status XlaOpKernelContext::ConstantInputList( absl::string_view name, std::vector<xla::Literal>* outputs) { int start, stop; - TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(op_kernel().InputRange( + StringPiece(name.data(), name.size()), &start, &stop)); outputs->resize(stop - start); for (int i = start; i < stop; ++i) { TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i])); @@ -612,7 +615,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; - CHECK(context_->input(name, &tensor).ok()); + CHECK(context_->input(StringPiece(name.data(), name.length()), &tensor).ok()); return *tensor; } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 74a4885f1f..5d53169f68 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index f9473d372b..bddb664149 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -64,7 +65,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read( absl::Span<const float> field = result->data<float>(); char* data = absl::bit_cast<char*>(field.data()); uint64 bytes = elements * sizeof(float); - absl::string_view sp; + tensorflow::StringPiece sp; auto s = file_->Read(offset_, bytes, &sp, data); offset_ += sp.size(); if (!s.ok()) { @@ -85,7 +86,7 @@ bool PackedLiteralReader::IsExhausted() const { // Try to read a single byte from offset_. If we can't, we've // exhausted the data. char single_byte[1]; - absl::string_view sp; + tensorflow::StringPiece sp; auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); return !s.ok(); } diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 76c09512d8..450d3fe5af 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -109,12 +109,12 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" -#include "third_party/absl/strings/str_cat.h" -#include "third_party/absl/strings/str_format.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "third_party/absl/types/span.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/python/numpy_bridge.h" #include "tensorflow/compiler/xla/python/local_computation_builder.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ab86dce510..e784663ff6 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -159,6 +159,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -291,6 +292,7 @@ cc_library( "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", + "hlo_schedule.cc", "hlo_sharding.cc", ], hdrs = [ @@ -303,6 +305,7 @@ cc_library( "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", + "hlo_schedule.h", "hlo_sharding.h", ], deps = [ @@ -331,6 +334,8 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -1037,7 +1042,6 @@ tf_cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1065,7 +1069,6 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", - ":hlo_schedule", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1086,7 +1089,6 @@ tf_cc_test( ":hlo", ":hlo_dataflow_analysis", ":hlo_ordering", - 
":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1108,7 +1110,6 @@ cc_library( ":hlo", ":hlo_ordering", ":hlo_proto", - ":hlo_schedule", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1177,22 +1178,6 @@ cc_library( ], ) -cc_library( - name = "hlo_schedule", - srcs = ["hlo_schedule.cc"], - hdrs = ["hlo_schedule.h"], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib_internal", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - ], -) - tf_cc_test( name = "hlo_schedule_test", srcs = ["hlo_schedule_test.cc"], @@ -1202,7 +1187,6 @@ tf_cc_test( ":hlo_dce", ":hlo_ordering", ":hlo_parser", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1222,7 +1206,6 @@ cc_library( ":heap_simulator", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1969,6 +1952,8 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":hlo", + ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1977,6 +1962,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -2413,7 +2399,6 @@ cc_library( ":hlo", ":hlo_dce", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", @@ -2587,6 +2572,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index aa40fba9bb..a0db4563fb 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2369,20 +2369,20 @@ TEST_P(ConvFilterPaddingTest, DoIt) { rhs_pad->shape().dimensions(3), testcase.orig_conv_window)) .ValueOrDie(); - auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - /*feature_group_count=*/1, window, - dnums) - .ValueOrDie(), - input, rhs_pad, /*feature_group_count=*/1, window, dnums, - DefaultPrecisionConfig(2))); // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place // after the transformation. 
PrecisionConfig precision_config; precision_config.add_operand_precision(PrecisionConfig::HIGH); precision_config.add_operand_precision(PrecisionConfig::HIGHEST); - orig_conv->set_precision_config(precision_config); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), + /*feature_group_count=*/1, window, + dnums) + .ValueOrDie(), + input, rhs_pad, /*feature_group_count=*/1, window, dnums, + precision_config)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -2401,7 +2401,9 @@ TEST_P(ConvFilterPaddingTest, DoIt) { conv->operand(1)->shape().dimensions(2), conv->operand(1)->shape().dimensions(3), testcase.expected_conv_window)); - EXPECT_THAT(conv->precision_config().operand_precision(), + EXPECT_THAT(Cast<HloConvolutionInstruction>(conv) + ->precision_config() + .operand_precision(), ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST)); } } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 69b654d30e..388fd5df99 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,8 +55,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloTestBase { +class BFloat16PropagationTest : public HloVerifiedTestBase { protected: + BFloat16PropagationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. 
bool PropagatePrecision(HloModule* module) { @@ -77,6 +81,16 @@ class BFloat16PropagationTest : public HloTestBase { inst->users()[0]->opcode() == HloOpcode::kConvert && inst->users()[0]->shape().element_type() == BF16; } + + std::unique_ptr<HloInstruction> CreateDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + DefaultPrecisionConfig(2)); + } }; // Tests that BF16 can propagate through select over non-tuple buffers, but not @@ -95,22 +109,22 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); - HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b)); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b)); HloInstruction* sel = builder.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a)); - HloInstruction* root = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a)); + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(xpose)); @@ -136,13 +150,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a))); HloInstruction* b = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(dot->operand(0))); @@ -189,8 +202,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple0->shape(), tuple1, 0)), 0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); HloInstruction* output_tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); @@ -198,7 +211,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { auto module = CreateNewModule(); 
auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), output_tuple); EXPECT_TRUE(OutputsBF16(xpose)); @@ -231,13 +244,13 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1)); // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1. - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(add1)); @@ -249,7 +262,7 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { // Tests that a non-fusion computation's root should not be changed. TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* a = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); @@ -258,8 +271,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add)); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); @@ -267,7 +279,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_FALSE(OutputsBF16(add)); @@ -277,7 +289,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -303,15 +315,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { HloInstruction::CreateGetTupleElement(shape, p_f1, 0)); HloInstruction* b_f1 = builder_f1.AddInstruction( HloInstruction::CreateGetTupleElement(shape, p_f1, 1)); - HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1)); + HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1)); auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build()); auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion( dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + 
EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion1); EXPECT_TRUE(OutputsBF16(add)); @@ -326,7 +337,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -340,15 +351,15 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); HloInstruction* add_f = builder_f.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); - HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f)); + HloInstruction* dot_f = + builder_f.AddInstruction(CreateDot(shape, add_f, add_f)); auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion); } @@ -390,12 +401,11 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { HloInstruction::CreateGetTupleElement(shape, fusion, 0)); HloInstruction* gte1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, fusion, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(gte0)); @@ -440,12 +450,12 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add0)); @@ -472,31 +482,36 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { auto builder_cond = HloComputation::Builder("cond"); auto cond_param = builder_cond.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_param, cond_param)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param)); auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - 
ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); auto body_param = builder_body.AddInstruction( HloInstruction::CreateParameter(0, shape, "body_param")); - auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, body_param, body_param)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_param, body_param)); auto body = module->AddEmbeddedComputation(builder_body.Build()); auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE( @@ -528,10 +543,16 @@ TEST_F(BFloat16PropagationTest, HloInstruction::CreateParameter(0, shape, "cond_param")); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1}, + {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -552,11 +573,10 @@ TEST_F(BFloat16PropagationTest, auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add)); EXPECT_FALSE(OutputsBF16(body_fusion)); @@ -593,14 +613,20 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { // This add should prevent RHS from using BF16 auto cond_add_rhs = builder_cond.AddInstruction( 
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_lhs, cond_add_rhs)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs)); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -610,10 +636,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot1 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); - auto body_dot2 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs)); + auto body_dot1 = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); + auto body_dot2 = + builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs)); auto body_transpose = builder_body.AddInstruction( HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); builder_body.AddInstruction( @@ -627,11 +653,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, while_hlo, 0)); auto rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, while_hlo, 1)); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); @@ -683,14 +708,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond0_add_rhs = builder_cond0.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); - auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs)); + auto cond0_dot = + builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs)); builder_cond0.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1})))); + 
builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond0_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); // Condition computation for the second while. @@ -705,14 +736,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond1_add_lhs = builder_cond1.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); - auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs)); + auto cond1_dot = + builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs)); builder_cond1.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond1_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); // Body computation shared by both whiles. 
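The cond-computation hunks above all apply one pattern: HloInstruction::CreateSlice preserves the operand's rank, so reducing a 2x2 dot result to a scalar takes a {1, 1} slice followed by a reshape to the scalar shape {}, replacing the earlier rank-reducing slices that the verified test base rejects. A standalone sketch of that pattern; the helper name ScalarAt is made up:

// Extracts the scalar at (i, j) from a rank-2 F32 instruction `mat`.
HloInstruction* ScalarAt(HloComputation::Builder* builder, HloInstruction* mat,
                         int64 i, int64 j) {
  HloInstruction* slice = builder->AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {1, 1}), mat,
      /*start_indices=*/{i, j}, /*limit_indices=*/{i + 1, j + 1},
      /*strides=*/{1, 1}));
  return builder->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {}), slice));
}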
@@ -723,8 +760,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); builder_body.AddInstruction( HloInstruction::CreateTuple({body_dot, body_rhs})); auto body = module->AddEmbeddedComputation(builder_body.Build()); @@ -734,23 +771,22 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1)); - auto lhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 1)))); - auto rhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 1)))); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto lhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 1)))); + auto rhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 1)))); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_FALSE(OutputsBF16(body_dot)); EXPECT_FALSE(OutputsBF16(body_rhs)); EXPECT_FALSE(OutputsBF16(body_lhs)); @@ -792,7 +828,7 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), add2); EXPECT_EQ(add2->operand(0), add0); @@ -821,15 +857,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { HloInstruction::CreateGetTupleElement(shape, domain, 0)); HloInstruction* b_gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, domain, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); // test BF16 propagated through domain @@ -867,15 +902,15 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { 
HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); HloInstruction* b_trans = builder.AddInstruction( HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans)); + HloInstruction* dot = + builder.AddInstruction(CreateDot(shape, a_trans, b_trans)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(a_trans)); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index d412578619..2368ac8c6a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -670,6 +670,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 0fea462c85..7d99b914d4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" namespace op = xla::testing::opcode_matchers; @@ -696,8 +697,8 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, auto* addend = builder.AddInstruction( HloInstruction::CreateParameter(2, dot_shape, "param2")); - auto* dot = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + auto* dot = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); builder.AddInstruction( HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 9363af3b89..4668f3872d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -70,7 +70,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -107,9 +107,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - 
HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); @@ -151,9 +151,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); @@ -189,7 +189,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateParameter(0, rhs_shape, "param0")); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -229,7 +229,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1)); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -276,8 +276,8 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion( HloInstruction::CreateParameter(1, dot_shape, "param1")); HloInstruction* dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape))); - HloInstruction* dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + HloInstruction* dot_result = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); HloInstruction* add_result; if (dot_operand_idx_in_add == 0) { add_result = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index a84ee78b19..fad76338a5 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -35,9 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false), - target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD 
b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 2384166fd2..f11aff0573 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -121,6 +121,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index fcd87b36b3..18ee25ba91 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) { HloInstruction* rhs = builder.AddInstruction( HloInstruction::CreateParameter(1, param_shape, "input")); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } @@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) { HloInstruction* lhs_transposed = builder.AddInstruction( HloInstruction::CreateTranspose(param_shape, lhs, {1, 0})); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 13ccff35f8..6791e15ee0 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -108,6 +108,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -480,6 +481,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -813,7 +815,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_schedule", "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], @@ -831,6 +832,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", diff --git 
a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 0922e44a12..59ade96f7d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -73,10 +74,10 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -201,12 +202,12 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* add = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -269,23 +270,23 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + 
builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index bca775c475..96bfe0c12e 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index ffca5d6549..b7c37bcf3c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -764,5 +764,20 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement( return Load(return_buffer); } +std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs( + const HloInstruction& hlo) { + std::vector<llvm_ir::IrArray> output_arrays; + if (ShapeUtil::IsTuple(hlo.shape())) { + int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_arrays.reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + } else { + output_arrays.push_back(GetIrArray(hlo, hlo)); + } + return output_arrays; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h 
b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 579268f071..8805201480 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -124,6 +124,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } + + // Generates the IrArray for each output of an hlo instruction and returns + // a vector containing such IrArrays. + std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs( + const HloInstruction& hlo); + // A convenient helper for calling BufferAssignment::GetUniqueSlice. BufferAllocation::Slice GetAllocationSlice( const HloInstruction& hlo, const ShapeIndex& index = {}) const { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5c827e5f9c..66c65f6975 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -119,21 +119,11 @@ Status IrEmitterNested::EmitTargetElementLoop( // For MOF we give the loop emitter an array for every output it should // generate. if (hlo.IsMultiOutputFusion()) { - const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape()); - std::vector<llvm_ir::IrArray> target_arrays; - target_arrays.reserve(num_elems); - for (int64 i = 0; i != num_elems; ++i) { - target_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector<llvm_ir::IrArray> target_arrays = + ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); - - std::vector<llvm::Value*> tuple_operand_ptrs; - tuple_operand_ptrs.reserve(num_elems); - for (const llvm_ir::IrArray& array : target_arrays) { - tuple_operand_ptrs.push_back(array.GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_); return Status::OK(); } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 389a98facb..f91cc00d71 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2521,15 +2521,15 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk( } StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index) { + HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value_operand = [&] { + HloInstruction* inst = fused ? 
hlo->fused_expression_root() : hlo; + HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: - return inst->operand(2); + return inst->mutable_operand(2); case HloOpcode::kReduce: - return inst->operand(1); + return inst->mutable_operand(1); case HloOpcode::kTuple: CHECK(hlo->IsMultiOutputFusion()) << ": " << hlo->ToString() << " is not a multi-output fusion."; @@ -2537,7 +2537,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( << ": Found '" << inst->operand(index.back())->opcode() << "' in " << inst->ToString() << " but expected 'reduce'."; // For multi-output fusion look through the tuple. - return inst->operand(index.back())->operand(1); + return inst->mutable_operand(index.back())->mutable_operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; @@ -2609,28 +2609,35 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); - // If the init_value was fused into this reduce we have to generate it first. - if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { - CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - const Literal& literal = init_value_operand->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); + if (fused) { + // If init_value was fused into this reduce we have to generate it first. + std::vector<IrArray> parameter_arrays; + for (HloInstruction* operand : hlo->operands()) { + parameter_arrays.push_back(GetIrArray(*operand, *hlo)); + } + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - global_for_const->setAlignment(kConstantBufferAlignBytes); - bindings_.BindHloToIrValue(*init_value_operand, global_for_const); + FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter)); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand), + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); + } else { + // In the unfused case the element is already there, just read from it. + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &b_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); } - TF_RETURN_IF_ERROR(ParallelLoopEmitter( - [=](const IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &b_); - }, - GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_) - .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) @@ -2819,10 +2826,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( } // For multioutput fusion, we need to emit each operand and the root. 
- std::vector<IrArray> output_arrays; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { - output_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -2830,12 +2834,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( GetIndexTypeForKernel( &hlo, launch_dimensions.launch_bound(), &b_))); - std::vector<llvm::Value*> tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + return Status::OK(); } @@ -2847,29 +2848,14 @@ Status IrEmitterUnnested::EmitTargetElementLoop( static_cast<KernelThunk*>(LastThunk())); } -int IrEmitterUnnested::ConstructIrArrayForOutputs( - const HloInstruction& hlo, std::vector<IrArray>* output_arrays) { - int64 num_outputs = 1; - if (hlo.IsMultiOutputFusion()) { - num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); - output_arrays->reserve(num_outputs); - for (int64 i = 0; i < num_outputs; ++i) { - output_arrays->push_back(GetIrArray(hlo, hlo, {i})); - } - } else { - output_arrays->push_back(GetIrArray(hlo, hlo)); - } - return num_outputs; -} - -int IrEmitterUnnested::ConstructIrArrayForInputs( - const HloInstruction& hlo, std::vector<IrArray>* param_arrays) { - int64 num_params = hlo.operands().size(); - param_arrays->reserve(num_params); +std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs( + const HloInstruction& hlo) { + std::vector<IrArray> param_arrays; + param_arrays.reserve(hlo.operands().size()); for (const HloInstruction* param : hlo.operands()) { - param_arrays->push_back(GetIrArray(*param, hlo)); + param_arrays.push_back(GetIrArray(*param, hlo)); } - return num_params; + return param_arrays; } int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( @@ -3050,10 +3036,10 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( constexpr int64 kThreadsPerTile = kTileSize * kNumRows; // Construct IrArrays for the inputs and outputs. - std::vector<IrArray> output_arrays; - int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays); - std::vector<IrArray> param_arrays; - int64 num_params = ConstructIrArrayForInputs(*hlo, ¶m_arrays); + std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo); + int64 num_outputs = output_arrays.size(); + std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*hlo); + int64 num_params = param_arrays.size(); // Allocate shared memory buffers to store the tiled inputs. std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr); @@ -3251,12 +3237,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // For multioutput fusion, emit a tuple with all the individual outputs. 
if (hlo->IsMultiOutputFusion()) { - std::vector<llvm::Value*> tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_, - module_); + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_); } return launch_dimensions; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 084462330e..bd5db72051 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -193,14 +193,12 @@ class IrEmitterUnnested : public IrEmitter { LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, absl::Span<const int64> reduced_output_dims, absl::Span<const int64> tiled_param_ids); - // Generates the IrArray for each output of hlo and returns the number of - // outputs. - int ConstructIrArrayForOutputs(const HloInstruction& hlo, - std::vector<llvm_ir::IrArray>* output_arrays); - // Generates the IrArray for each input of hlo and returns the number of - // inputs. - int ConstructIrArrayForInputs(const HloInstruction& hlo, - std::vector<llvm_ir::IrArray>* param_arrays); + + // Generates the IrArray for each input of an hlo and returns a vector that + // contains such IrArrays. + std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs( + const HloInstruction& hlo); + // For each output of the `hlo` instruction, constructs the reduced shape for // the output with the given `reduced_output_dims` and cast the original // output IrArray element in `output_arrays` to the reduced shape. Returns @@ -244,7 +242,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index = {}); + HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 091aca23e5..8f0dedfa40 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -49,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -68,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -101,23 +102,23 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + 
builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 99d0cf50ca..93ec2c9438 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -199,6 +199,17 @@ message HloComputationProto { int64 root_id = 6; } +// Serialization of an HLO schedule. An HLO schedule contains a total order of +// instructions for each non-fusion computation in the module. +message HloScheduleProto { + message InstructionSequence { + repeated int64 instruction_ids = 1; + } + + // Map from computation id to sequence. + map<int64, InstructionSequence> sequences = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -214,16 +225,9 @@ message HloModuleProto { // The id of this module. int64 id = 5; -} -// Serialization of HloOrdering. -message HloOrderingProto { - // NOTE: currently only sequential orderings are serialized. - message SequentialComputation { - string computation_name = 1; - repeated string instruction_names = 2; - } - repeated SequentialComputation sequential_computations = 1; + // The schedule for this module. + HloScheduleProto schedule = 7; } // Serialization of LogicalBuffer. @@ -322,8 +326,10 @@ message BufferAssignmentProto { // Grouping message that contains all of the information above. message HloProto { + reserved 2; + reserved "hlo_ordering"; + HloModuleProto hlo_module = 1; - HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index fe7f2be888..233d2199d1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -464,6 +464,14 @@ std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList() } string HloComputation::ToString(const HloPrintOptions& options) const { + return ToString(options, MakeInstructionPostOrder()); +} + +string HloComputation::ToString( + const HloPrintOptions& options, + absl::Span<const HloInstruction* const> instruction_order) const { + CHECK_EQ(instruction_order.size(), instruction_count()); + std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { s << " "; @@ -486,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const { new_options.set_indent_amount(options.indent_amount() + 1) .set_is_in_nested_computation(true); CanonicalNameMap name_map; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (const HloInstruction* instruction : instruction_order) { + CHECK_EQ(this, instruction->parent()); + for (int i = 0; i < new_options.indent_amount(); i++) { s << " "; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fe2d3bbbe5..91c5234a6f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -170,6 +170,11 @@ class HloComputation { string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; + // Overload which accepts an order to emit the instructions in. 
+ string ToString( + const HloPrintOptions& options, + absl::Span<const HloInstruction* const> instruction_order) const; + // Returns a serialized representation of this computation. HloComputationProto ToProto() const; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 939b5114c3..a502fff9a0 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { + // Domain does not have any computation or data transfer. + current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; + return Status::OK(); +} + Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); @@ -507,8 +515,9 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { valid_position_counts.push_back(valid_position_count); } - const int64 fma_count = - input_feature * output_feature * batch * Product(valid_position_counts); + const int64 fma_count = (input_feature / convolution->feature_group_count()) * + output_feature * batch * + Product(valid_position_counts); current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 9bb3f12ee2..46b4bbeef2 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleRecvDone(const HloInstruction* recv_done) override; Status HandleConvert(const HloInstruction* convert) override; Status HandleCopy(const HloInstruction* copy) override; + Status HandleDomain(const HloInstruction* domain) override; Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 2c854eea18..d76ce9ecbc 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) { sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18)); } +TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) { + XlaBuilder builder("convolution"); + auto input = Parameter( + &builder, 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = Parameter( + &builder, 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Output shape is [1x120x8x18] and each output element requires (3x3) + // FMAs and one FMA is 2 flops. 
+ EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18)); +} + TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = @@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - XlaBuilder builder("matmul"); + XlaBuilder builder("tuple"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); Tuple(&builder, {x, y}); @@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +using DomainCostAnalysis = HloTestBase; +TEST_F(DomainCostAnalysis, DomainCost) { + HloCostAnalysis analysis(ShapeSize); + + HloComputation::Builder builder("domain"); + auto x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {123}), "x")); + auto y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y")); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y})); + auto domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); + ASSERT_IS_OK(domain->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(*domain), 0); + EXPECT_EQ(analysis.transcendental_count(*domain), 0); + EXPECT_EQ(analysis.bytes_accessed(*domain), 0); +} + TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); auto input = Parameter( diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 406d712ec6..e09d5868f2 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloTestBase { +class HloCseTest : public HloVerifiedTestBase { protected: HloCseTest() {} }; @@ -65,13 +65,13 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(3, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = *computation->instructions().begin(); EXPECT_EQ(42.0f, constant->literal().Get<float>({})); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR0<float>(84.0); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -96,14 +96,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); auto first_operand = add->operand(0); EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2)); EXPECT_THAT(add, op::Add(first_operand, first_operand)); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -128,12 +128,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -177,7 +177,7 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); // CSE will remove both the second float(42.0f) and the corresponding // convert/cast. 
@@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) { op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -240,7 +240,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test two identical while loops with same inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -278,21 +278,20 @@ f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); } // Test two while loops with same conditions, same inputs, but different // bodies TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -329,20 +328,19 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body2 } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } // Test two identical while loops with different inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -373,21 +371,20 @@ f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(8, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(8, computation->instruction_count()); } // Test two while loops with identical bodies and same inputs, but different // conditions TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule 
WhileLoopsIdenticalBodiesAndInputDifferntConditions %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -414,14 +411,13 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - })") - .ValueOrDie(); + })"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } @@ -450,7 +446,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); @@ -481,7 +477,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -516,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) { EXPECT_EQ(5, fused_computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, fused_computation->instruction_count()); auto root = fused_computation->root_instruction(); @@ -565,7 +561,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); auto operand = tuple->operand(0); @@ -599,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); @@ -653,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { VLOG(3) << "before: " << module->ToString(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); VLOG(3) << "after: " << module->ToString(); @@ -663,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule m add_computation { @@ -684,12 +680,11 @@ TEST_F(HloCseTest, CompareComputations) { r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 ROOT f2 = (f32[],f32[]) tuple(r1, r2) - })") - .ValueOrDie(); + })"); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + HloInstruction* root = 
module().entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0), root->operand(1)); } @@ -708,13 +703,13 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { EXPECT_EQ(2, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); } TEST_F(HloCseTest, Domain) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule module ENTRY %entry { %param = f32[] parameter(0), sharding={maximal device=0} @@ -735,13 +730,11 @@ ENTRY %entry { domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} %add = f32[] add(%domain.3, %domain.4) ROOT %sub = f32[] subtract(%add, %domain.5) -})") - .ValueOrDie(); +})"); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - LOG(INFO) << "AAAAA " << module->ToString(); - const HloInstruction* sub = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + const HloInstruction* sub = module().entry_computation()->root_instruction(); const HloInstruction* add = sub->operand(0); EXPECT_EQ(add->operand(0), add->operand(1)); EXPECT_NE(add->operand(0), sub->operand(1)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index abd4bb1f73..102ebb24ab 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -52,10 +52,7 @@ static std::array<bool, 2> use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface<bool>, public HloVerifiedTestBase { protected: - HloEvaluatorTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false), - use_bfloat16_(GetParam()) { + HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique<HloEvaluator>(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 6a09bb08f4..63303aef1e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1052,7 +1052,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, rhs_literal_data, - feature_group_count](absl::Span<const int64> out_index) { + feature_group_count](const absl::Span<const int64> out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1063,9 +1063,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 output_batch_dim = dnums.output_batch_dimension(); const int64 output_z_dim = dnums.output_feature_dimension(); - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + const int64 input_z_size = + ShapeUtil::GetDimension(lhs_shape, input_z_dim); + // The size of an input feature group. 
+ const int64 input_feature_group_size = input_z_size / feature_group_count; + const int64 output_z_size = ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); + // The output feature dimension is a concatenation of convolution results + // from the different groups. + const int64 output_feature_group_size = + output_z_size / feature_group_count; + + // Calculate the group index to which the current output index + // belongs. + const int64 feature_group_index = + out_index[output_z_dim] / output_feature_group_size; ElementwiseT result_val = static_cast<ElementwiseT>(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), // Convolve input feature with kernel. do { - for (int64 iz = 0; iz < z_size; ++iz) { - int64 rhs_iz = iz; - // Handle grouped convolutions. - if (feature_group_count > 1) { - // The size of a feature group. - int64 feature_group_size = z_size / feature_group_count; - rhs_iz = iz % feature_group_size; - - // The output feature dimension is a concatenation of convolution - // results from the different groups. - int64 output_feature_group_size = - output_z_size / feature_group_count; - - // Calculate the group index to which the current input feature - // index belongs. - int64 input_group_index = iz / feature_group_size; - - // Calculate the group index to which the current output index - // belongs. - int64 output_group_index = - out_index[output_z_dim] / output_feature_group_size; - if (input_group_index != output_group_index) { - // If the current output index does not belong to the current - // feature group, skip it. - continue; - } - } + for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { + const int64 iz = + feature_group_index * input_feature_group_size + rhs_iz; int64 lhs_linear_index = 0; lhs_linear_index += out_index[output_batch_dim] * diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 471a12d6aa..25ae344ea5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -451,6 +451,28 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( << proto.dimensions_size(); instruction = CreateIota(proto.shape(), proto.dimensions(0)); break; + case HloOpcode::kDot: { + TF_RET_CHECK(proto.has_dot_dimension_numbers()) + << "Dot instruction should have dot_dimension_numbers."; + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Dot instruction should have 2 operands but sees " + << proto.operand_ids_size(); + PrecisionConfig precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfig::DEFAULT); + instruction = absl::make_unique<HloDotInstruction>( + proto.shape(), operands(0), operands(1), + proto.dot_dimension_numbers(), precision_config); + break; + } + case HloOpcode::kDomain: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Domain instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = absl::make_unique<HloDomainInstruction>( + proto.shape(), operands(0), /*operand_side_metadata=*/nullptr, + /*user_side_metadata=*/nullptr); + break; default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -472,20 +494,9 @@ StatusOr<std::unique_ptr<HloInstruction>> 
HloInstruction::CreateFromProto( computation_map.at(computation_id)); } } - if (instruction->opcode() == HloOpcode::kDot) { - instruction->precision_config_ = proto.precision_config(); - instruction->precision_config_.mutable_operand_precision()->Resize( - instruction->operand_count(), PrecisionConfig::DEFAULT); - TF_RET_CHECK(proto.has_dot_dimension_numbers()); - instruction->dot_dimension_numbers_ = - absl::make_unique<DotDimensionNumbers>( - proto.dot_dimension_numbers()); - } else { - TF_RET_CHECK(!proto.has_precision_config()) - << instruction->opcode() << proto.DebugString(); - TF_RET_CHECK(!proto.has_dot_dimension_numbers()) - << instruction->opcode(); - } + TF_RET_CHECK(!proto.has_precision_config()) + << instruction->opcode() << proto.DebugString(); + TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode(); break; } } @@ -564,7 +575,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: - case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -596,7 +606,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kComplex: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kGe: case HloOpcode::kGt: @@ -674,30 +683,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - absl::make_unique<DotDimensionNumbers>(dimension_numbers); - instruction->set_precision_config(precision_config); - return instruction; -} - -/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - absl::make_unique<DotDimensionNumbers>(); - instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); - instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); - return instruction; + return absl::make_unique<HloDotInstruction>( + shape, lhs, rhs, dimension_numbers, precision_config); } /* static */ std::unique_ptr<HloInstruction> @@ -1157,12 +1144,9 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, std::unique_ptr<DomainMetadata> operand_side_metadata, std::unique_ptr<DomainMetadata> user_side_metadata) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); - instruction->operand_side_metadata_ = std::move(operand_side_metadata); - instruction->user_side_metadata_ = std::move(user_side_metadata); - instruction->AppendOperand(operand); - return instruction; + return absl::make_unique<HloDomainInstruction>( + shape, operand, std::move(operand_side_metadata), + std::move(user_side_metadata)); } std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( @@ -1218,6 +1202,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kGather: case 
HloOpcode::kScatter: case HloOpcode::kIota: + case HloOpcode::kDot: + case HloOpcode::kDomain: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1290,11 +1276,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kDot: - CHECK_EQ(new_operands.size(), 2); - clone = CreateDot(shape, new_operands[0], new_operands[1], - *dot_dimension_numbers_, precision_config()); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); @@ -1319,12 +1300,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kDomain: - CHECK_EQ(new_operands.size(), 1); - clone = - CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), - user_side_metadata_->Clone()); - break; case HloOpcode::kAfterAll: if (new_operands.empty()) { clone = CreateToken(); @@ -1620,11 +1595,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAfterAll: return false; - // Check dot dimension numbers. - case HloOpcode::kDot: - return protobuf_util::ProtobufEquals(dot_dimension_numbers(), - other.dot_dimension_numbers()); - // Remaining instructions with special values. case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); @@ -1640,10 +1610,6 @@ bool HloInstruction::IdenticalSlowPath( return false; } - case HloOpcode::kDomain: - return operand_side_metadata().Matches(other.operand_side_metadata()) && - user_side_metadata().Matches(other.user_side_metadata()); - // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. 
case HloOpcode::kBatchNormTraining: @@ -1683,6 +1649,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kScatter: + case HloOpcode::kDot: + case HloOpcode::kDomain: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -2052,15 +2020,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { std::vector<string> extra = ExtraAttributesToStringImpl(options); - if (dot_dimension_numbers_ != nullptr) { - extra.push_back(DotDimensionNumbersToString()); - } - - string precision_config_string = PrecisionConfigToString(); - if (!precision_config_string.empty()) { - extra.push_back(precision_config_string); - } - if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { @@ -2146,11 +2105,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString( }), "}")); } - if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { - extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), - "\", entry=", user_side_metadata_->ToString(), - ", exit=", operand_side_metadata_->ToString(), "}")); - } return extra; } @@ -2182,19 +2136,12 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); - if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) { - *proto.mutable_precision_config() = precision_config_; - } if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); } } - if (dot_dimension_numbers_ != nullptr) { - *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; - } - if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } @@ -2921,31 +2868,6 @@ string ConvolutionDimensionNumbersToString( StrJoin(output_dims, "")); } -string HloInstruction::DotDimensionNumbersToString() const { - std::vector<string> result; - if (dot_dimension_numbers_ == nullptr) { - return ""; - } - const DotDimensionNumbers& dnums = *dot_dimension_numbers_; - if (!dnums.lhs_batch_dimensions().empty()) { - result.push_back(StrCat("lhs_batch_dims={", - StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("lhs_contracting_dims={", - StrJoin(dnums.lhs_contracting_dimensions(), ","), - "}")); - - if (!dnums.rhs_batch_dimensions().empty()) { - result.push_back(StrCat("rhs_batch_dims={", - StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("rhs_contracting_dims={", - StrJoin(dnums.rhs_contracting_dimensions(), ","), - "}")); - - return StrJoin(result, ", "); -} - StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { static std::unordered_map<string, RandomDistribution>* map = [] { static auto* map = new std::unordered_map<string, RandomDistribution>; @@ -2964,27 +2886,6 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { return found->second; } -string HloInstruction::PrecisionConfigToString() const { - if (absl::c_all_of( - precision_config_.operand_precision(), [](int32 precision) { - return static_cast<PrecisionConfig::Precision>(precision) == - PrecisionConfig::DEFAULT; - })) { - return ""; - } - return StrCat( - "operand_precision={", - StrJoin( - precision_config_.operand_precision(), ",", - [](string* out, int32 precision) { - 
CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; - StrAppend(out, - PrecisionToString( - static_cast<PrecisionConfig::Precision>(precision))); - }), - "}"); -} - StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) { static std::unordered_map<string, PrecisionConfig::Precision>* map = [] { static auto* map = @@ -3044,6 +2945,16 @@ Status HloInstruction::set_backend_config( return ret; } +const PrecisionConfig& HloInstruction::precision_config() const { + if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) { + return convolution->precision_config(); + } + if (auto* dot = DynCast<HloDotInstruction>(this)) { + return dot->precision_config(); + } + LOG(FATAL) << "Unimplemented method."; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3348,4 +3259,15 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers(); } +const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const { + return Cast<HloDotInstruction>(this)->dot_dimension_numbers(); +} + +const DomainMetadata& HloInstruction::operand_side_metadata() const { + return Cast<HloDomainInstruction>(this)->operand_side_metadata(); +} + +const DomainMetadata& HloInstruction::user_side_metadata() const { + return Cast<HloDomainInstruction>(this)->user_side_metadata(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 691f8155f9..5581c17c2d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -421,12 +421,6 @@ class HloInstruction { const DotDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); - // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 - // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS - // and the RHS must be of rank 2. - static std::unique_ptr<HloInstruction> CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs); - // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to // reduce it to. @@ -866,11 +860,6 @@ class HloInstruction { return false; } - if (!absl::c_equal(precision_config_.operand_precision(), - other.precision_config_.operand_precision())) { - return false; - } - return IdenticalSlowPath(other, eq_computations); } @@ -1085,15 +1074,6 @@ class HloInstruction { return other->has_sharding() ? sharding() == other->sharding() : false; } - // Retrieves the operand side metadata of a kDomain instruction. - const DomainMetadata& operand_side_metadata() const { - return *operand_side_metadata_; - } - // Retrieves the user side metadata of a kDomain instruction. - const DomainMetadata& user_side_metadata() const { - return *user_side_metadata_; - } - // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain // properties of the new instruction are copied into the derived one. As of @@ -1101,18 +1081,6 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Returns data on the dimension numbers used for a dot operation. 
- const DotDimensionNumbers& dot_dimension_numbers() const { - CHECK(dot_dimension_numbers_ != nullptr); - return *dot_dimension_numbers_; - } - - // Returns the dump string of the dot dimension numbers. - string DotDimensionNumbersToString() const; - - // Returns the dump string of the precision configuration. - string PrecisionConfigToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1262,10 +1230,8 @@ class HloInstruction { // information. Transformations to other HLOs will not preserve this // information but it is presumed that the alternate lowering is strictly // superior. - const PrecisionConfig& precision_config() const { return precision_config_; } - void set_precision_config(const PrecisionConfig& precision_config) { - precision_config_ = precision_config; - } + // Precondition: opcode must be kConvolution or kDot. + const PrecisionConfig& precision_config() const; // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1508,6 +1474,15 @@ class HloInstruction { // Delegates to HloScatterInstruction::scatter_dimension_numbers(). const ScatterDimensionNumbers& scatter_dimension_numbers() const; + // Delegates to HloDotInstruction::dot_dimension_numbers(). + const DotDimensionNumbers& dot_dimension_numbers() const; + + // Delegates to HloDomainInstruction::operand_side_metadata(). + const DomainMetadata& operand_side_metadata() const; + + // Delegates to HloDomainInstruction::user_side_metadata(). + const DomainMetadata& user_side_metadata() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1647,22 +1622,12 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Describes the dimension numbers used for a dot. - std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_; - - // Used to tag kCopy instructions that are eligible for copy elision. - bool copy_elision_allowed_ = true; - // The sharding, if one exists. // Uses std::shared_ptr to allow reuse of the same sharding object between // HloInstructions and other components as HloSharding can be very large for // many element tuples. std::shared_ptr<const HloSharding> sharding_; - // Fields used by the kDomain instruction. - std::unique_ptr<DomainMetadata> operand_side_metadata_; - std::unique_ptr<DomainMetadata> user_side_metadata_; - // Computations called by this instruction. std::vector<HloComputation*> called_computations_; @@ -1676,10 +1641,6 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; - // Information used to communicate to the implementation about the algorithm - // used to produce results. See the documentation on precision_config(). - PrecisionConfig precision_config_; - // String identifier for instruction. 
string name_; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index ad87aa1123..fb7345a2ad 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -47,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, return instruction->IsElementwiseOnOperand(operand_index); }); } + +string PrecisionConfigToString(const PrecisionConfig& precision_config) { + if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) { + return static_cast<PrecisionConfig::Precision>(precision) == + PrecisionConfig::DEFAULT; + })) { + return ""; + } + + return StrCat( + "operand_precision={", + StrJoin( + precision_config.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; + StrAppend(out, + PrecisionToString( + static_cast<PrecisionConfig::Precision>(precision))); + }), + "}"); +} } // namespace HloBatchNormInstruction::HloBatchNormInstruction( @@ -1634,7 +1655,8 @@ HloConvolutionInstruction::HloConvolutionInstruction( : HloInstruction(HloOpcode::kConvolution, shape), feature_group_count_(feature_group_count), window_(window), - convolution_dimension_numbers_(dimension_numbers) { + convolution_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1643,7 +1665,6 @@ HloConvolutionInstruction::HloConvolutionInstruction( } AppendOperand(lhs); AppendOperand(rhs); - set_precision_config(precision_config); } string HloConvolutionInstruction::ToCategory() const { @@ -1663,6 +1684,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; proto.set_feature_group_count(feature_group_count_); + *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1677,6 +1699,12 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl( if (feature_group_count_ != 1) { extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; } @@ -1692,7 +1720,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), - casted_other.convolution_dimension_numbers()); + casted_other.convolution_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); } std::unique_ptr<HloInstruction> @@ -1702,7 +1732,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 2); return absl::make_unique<HloConvolutionInstruction>( shape, new_operands[0], new_operands[1], feature_group_count_, window(), - convolution_dimension_numbers_, precision_config()); + convolution_dimension_numbers_, precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -2161,4 +2191,113 @@ std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl( return absl::make_unique<HloIotaInstruction>(shape, iota_dimension()); } +HloDotInstruction::HloDotInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + 
const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) + : HloInstruction(HloOpcode::kDot, shape), + dot_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { + AppendOperand(lhs); + AppendOperand(rhs); +} + +HloInstructionProto HloDotInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_; + *proto.mutable_precision_config() = precision_config_; + return proto; +} + +std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector<string> extra = {DotDimensionNumbersToString()}; + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; +} + +bool HloDotInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloDotInstruction&>(other); + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + casted_other.dot_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); +} + +std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique<HloDotInstruction>( + shape, new_operands[0], new_operands[1], dot_dimension_numbers_, + precision_config_); +} + +string HloDotInstruction::DotDimensionNumbersToString() const { + std::vector<string> result; + const DotDimensionNumbers& dnums = dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result.push_back(StrCat("lhs_batch_dims={", + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("lhs_contracting_dims={", + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); + + if (!dnums.rhs_batch_dimensions().empty()) { + result.push_back(StrCat("rhs_batch_dims={", + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("rhs_contracting_dims={", + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); + + return StrJoin(result, ", "); +} + +HloDomainInstruction::HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr<DomainMetadata> operand_side_metadata, + std::unique_ptr<DomainMetadata> user_side_metadata) + : HloInstruction(HloOpcode::kDomain, shape), + operand_side_metadata_(std::move(operand_side_metadata)), + user_side_metadata_(std::move(user_side_metadata)) { + AppendOperand(operand); +} + +std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")}; + } + return {}; +} + +bool HloDomainInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloDomainInstruction&>(other); + return operand_side_metadata().Matches( + casted_other.operand_side_metadata()) && 
+ user_side_metadata().Matches(casted_other.user_side_metadata()); +} + +std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique<HloDomainInstruction>( + shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e1215a7566..c3a7801164 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -957,6 +957,16 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count() const { return feature_group_count_; } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -979,6 +989,9 @@ class HloConvolutionInstruction : public HloInstruction { Window window_; // Describes the dimension numbers used for a convolution. ConvolutionDimensionNumbers convolution_dimension_numbers_; + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; }; class HloReduceWindowInstruction : public HloInstruction { @@ -1271,6 +1284,85 @@ class HloIotaInstruction : public HloInstruction { const int64 iota_dimension_; }; +class HloDotInstruction : public HloInstruction { + public: + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); + + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + return dot_dimension_numbers_; + } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + + // Returns a serialized representation of this instruction. 
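The new HloDotInstruction subclass owns both the dot dimension numbers and the precision config, so building one by hand means populating two protos first. A sketch assuming `shape`, `lhs`, and `rhs` are an existing result shape and two rank-2 operand instructions; it mirrors the CreateCanonicalDot test helper added to test_utils further down in this change:

    // Contract dimension 1 of the LHS with dimension 0 of the RHS,
    // with no batch dimensions.
    xla::DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(1);
    dnums.add_rhs_contracting_dimensions(0);

    // One precision entry per operand; DEFAULT leaves backend behavior as-is.
    xla::PrecisionConfig precision;
    precision.mutable_operand_precision()->Resize(2, xla::PrecisionConfig::DEFAULT);

    auto dot = absl::make_unique<xla::HloDotInstruction>(
        shape, lhs, rhs, dnums, precision);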
+ HloInstructionProto ToProto() const override; + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const override; + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + + // Describes the dimension numbers used for a dot. + DotDimensionNumbers dot_dimension_numbers_; + + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; +}; + +class HloDomainInstruction : public HloInstruction { + public: + explicit HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr<DomainMetadata> operand_side_metadata, + std::unique_ptr<DomainMetadata> user_side_metadata); + + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const override; + + std::unique_ptr<DomainMetadata> operand_side_metadata_; + std::unique_ptr<DomainMetadata> user_side_metadata_; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 3a1bc4e328..cfe906d9c5 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -50,6 +51,13 @@ StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule( return const_cast<HloInstruction*>(hlo); } +Status HloModule::set_schedule(HloSchedule schedule) { + TF_RET_CHECK(schedule.module() == this); + TF_RETURN_IF_ERROR(schedule.Verify()); + schedule_ = std::move(schedule); + return Status::OK(); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, bool uniquify_names) { @@ -198,12 +206,23 @@ void HloModule::ReplaceComputations( string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << "\n\n"; + s << "HloModule " << name(); + if (has_schedule()) { + TF_CHECK_OK(schedule().Verify()); + s << ", is_scheduled=true"; + } + s << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString(options) << "\n\n"; + if (has_schedule() && schedule().is_computation_scheduled(computation)) { + s << computation->ToString( + options, schedule().sequence(computation).instructions()) + << "\n\n"; + } else { + s << computation->ToString(options) << "\n\n"; + } } return s.str(); } @@ -221,6 +240,9 @@ HloModuleProto HloModule::ToProto() const { } proto.add_computations()->Swap(&computation_proto); } + if (has_schedule()) { + *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); + } return proto; } @@ -309,6 +331,13 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( } } + if (proto.has_schedule()) { + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + HloSchedule::CreateFromProto(module.get(), proto.schedule())); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + } + return std::move(module); } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 3c3371426b..26fd1b2438 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -25,6 +25,7 @@ limitations under the License. #include <vector> #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -235,6 +237,19 @@ class HloModule { StatusOr<HloInstruction*> LaunderConstInstructionFromModule( const HloInstruction* hlo); + // Sets the schedule of the module to the given schedule. + Status set_schedule(HloSchedule schedule); + + // Clears the schedule of the module. + void clear_schedule() { schedule_.reset(); } + + // Returns true if the module has a schedule set. + bool has_schedule() const { return schedule_.has_value(); } + + // Returns the schedue of the module. 
CHECK fails if no schedule is set. + const HloSchedule& schedule() const { return *schedule_; } + HloSchedule& schedule() { return *schedule_; } + private: HloComputation* AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, @@ -262,6 +277,11 @@ class HloModule { static std::atomic<int> next_unique_module_id_; // A unique id to label modules with. int unique_id_; + + // The HloSchedule of the module. The schedule if it exists contains a + // sequential order of instructions for each non-fusion computation in the + // module. + absl::optional<HloSchedule> schedule_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 3f1e1cc73e..68c18836eb 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -106,9 +106,6 @@ class HloModuleConfig { absl::optional<ComputationLayout> entry_computation_layout_; - // Whether this is a 'host module'. - bool is_host_module_ = false; - // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 4bc1bacd7d..400bd4d947 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -19,9 +19,12 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" @@ -30,6 +33,8 @@ namespace xla { namespace { +namespace op = ::xla::testing::opcode_matchers; + class HloModuleTest : public HloTestBase { protected: HloModuleTest() {} @@ -194,6 +199,60 @@ TEST_F(HloModuleTest, UniqueModuleId) { EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } +TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_FALSE(module_copy->has_schedule()); +} + +TEST_F(HloModuleTest, ProtoSerializationWithSchedule) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, 
f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_TRUE(module_copy->has_schedule()); + TF_ASSERT_OK(module_copy->schedule().Verify()); + EXPECT_EQ(module_copy->schedule().sequences().size(), 1); + ASSERT_TRUE(module_copy->schedule().is_computation_scheduled( + module_copy->entry_computation())); + EXPECT_THAT( + module_copy->schedule() + .sequence(module_copy->entry_computation()) + .instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 2105f7a349..f1dc08bafa 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -293,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b, !LiveRangeStrictlyBefore(b, a, dataflow); } -HloOrderingProto HloOrdering::ToProto() const { - HloOrderingProto proto; - for (const auto& computation : module_->computations()) { - const std::vector<const HloInstruction*>* sequence = - SequentialOrder(*computation); - if (sequence != nullptr) { - HloOrderingProto::SequentialComputation* proto_computation = - proto.add_sequential_computations(); - proto_computation->set_computation_name(computation->name()); - for (const HloInstruction* instruction : *sequence) { - *proto_computation->add_instruction_names() = instruction->name(); - } - } - } - return proto; -} - PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) : HloOrdering(module) {} diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b21071c4b2..b0361c3f02 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -72,10 +72,6 @@ class HloOrdering { virtual string ToString() const = 0; - // Returns the serialized representation of this ordering. - // Only sequential computation orders are represented. - HloOrderingProto ToProto() const; - protected: // Returns true if instruction 'a' executes before instruction 'b'. // Precondition: 'a' and 'b' are in the same computation. diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 0f26ed4235..c54360b063 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -44,6 +45,20 @@ using absl::StrJoin; const double kF16max = 65504; +// Creates and returns a schedule created using the order of the instructions in +// the HloComputation::instructions() vectors in the module. 
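The helper whose body follows builds a schedule straight from HloComputation::instructions() order; together with the new HloModule plumbing above, that lets a schedule survive a proto round trip, which is exactly what the two module tests verify. A condensed sketch of that flow, assuming `module` is a parsed std::unique_ptr<HloModule>:

    // Attach an instruction-order schedule, serialize, and rebuild.
    TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module.get())));
    xla::HloModuleProto proto = module->ToProto();  // embeds the HloScheduleProto

    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<xla::HloModule> copy,
        xla::HloModule::CreateFromProto(proto, module->config()));
    ASSERT_TRUE(copy->has_schedule());
    TF_ASSERT_OK(copy->schedule().Verify());

Note that ScheduleFromInstructionOrder is local to hlo_parser.cc; outside the parser you would build the HloSchedule yourself via GetOrCreateSequence.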
+HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { + HloSchedule schedule(module); + for (const HloComputation* computation : module->computations()) { + if (!computation->IsFusionComputation()) { + for (const HloInstruction* instruction : computation->instructions()) { + schedule.GetOrCreateSequence(computation).push_back(instruction); + } + } + } + return schedule; +} + // Parser for the HloModule::ToString() format text. class HloParser { public: @@ -366,9 +381,25 @@ bool HloParser::ParseHloModule() { return false; } + absl::optional<bool> is_scheduled; + std::unordered_map<string, AttrConfig> attrs; + attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; + if (!ParseAttributes(attrs)) { + return false; + } + module_ = absl::make_unique<HloModule>(name, config_); - return ParseComputations(); + if (!ParseComputations()) { + return false; + } + + if (is_scheduled.has_value() && *is_scheduled) { + TF_CHECK_OK( + module_->set_schedule(ScheduleFromInstructionOrder(module_.get()))); + } + + return true; } // computations ::= (computation)+ @@ -1248,11 +1279,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional<string> custom_call_target; optional<Window> window; optional<ConvolutionDimensionNumbers> dnums; + optional<int64> feature_group_count; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1264,6 +1298,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (dnums.has_value()) { instruction->set_convolution_dimension_numbers(*dnums); } + if (feature_group_count.has_value()) { + instruction->set_feature_group_count(*feature_group_count); + } break; } case HloOpcode::kDot: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 0dfc0a4d1c..cca50fab54 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1123,18 +1123,31 @@ ENTRY Iota { )" }, -// custom-call with window and dim_labels +// custom-call with window, dim_labels and feature_group_count { -"CustomCallWithWindowAndDimLabels", -R"(HloModule CustomCallWithWindowAndDimLabels +"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount", +R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount ENTRY Computation { - ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target" } )" + }, +// is_scheduled=true attribute +{ +"ScheduledModule", +R"(HloModule scheduled_module, is_scheduled=true + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} } - }); + +)" +} +}); // clang-format on } @@ -1790,5 +1803,94 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { EXPECT_EQ(convolution->feature_group_count(), 1); } +TEST_F(HloParserTest, IsScheduledIsFalse) { + const string text = R"( +HloModule axpy_module, is_scheduled=false + +ENTRY %axpy.v5 
(alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledNotPresent) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledIsTrue) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), + op::Multiply(), op::Parameter(), op::Add())); +} + +TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { + // As above but in with a different schedule order. 
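The four parser tests in this block pin down the attribute's semantics: is_scheduled absent or false leaves the module unscheduled, while is_scheduled=true records the textual instruction order as the schedule, as the test that follows demonstrates by permuting the same program. A minimal scheduled module for quick experiments, assuming the same ParseHloString entry point used by these tests:

    const string text = R"(
    HloModule tiny_module, is_scheduled=true

    ENTRY %tiny () -> f32[] {
      %c0 = f32[] constant(1)
      %c1 = f32[] constant(2)
      ROOT %sum = f32[] add(f32[] %c0, f32[] %c1)
    }
    )";
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m, ParseHloString(text));
    ASSERT_TRUE(m->has_schedule());  // sequence: %c0, %c1, %sum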
+ const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 3460679558..b9c0b0c4ee 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -23,11 +23,8 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { - HloOrderingProto proto_ordering = - assignment.liveness().hlo_ordering().ToProto(); BufferAssignmentProto proto_assignment = assignment.ToProto(); HloProto proto = MakeHloProto(module); - proto.mutable_hlo_ordering()->Swap(&proto_ordering); proto.mutable_buffer_assignment()->Swap(&proto_assignment); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index a65b33bf40..3fc5dbeb02 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -21,12 +21,64 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" namespace xla { +/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto( + const HloModule* module, const HloScheduleProto& proto) { + tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation; + for (const HloComputation* computation : module->computations()) { + id_to_computation[computation->unique_id()] = computation; + } + + HloSchedule schedule(module); + for (const auto& id_sequence : proto.sequences()) { + int64 computation_id = id_sequence.first; + + auto comp_it = id_to_computation.find(computation_id); + TF_RET_CHECK(comp_it != id_to_computation.end()) + << "No computation exists in HLO module with id " << computation_id; + const HloComputation* computation = comp_it->second; + + tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + id_to_instruction[instruction->unique_id()] = instruction; + } + + HloInstructionSequence& sequence = + schedule.GetOrCreateSequence(computation); + for (const int64 instruction_id : id_sequence.second.instruction_ids()) { + auto instr_it = id_to_instruction.find(instruction_id); + TF_RET_CHECK(instr_it != id_to_instruction.end()) + << "No instruction exists in HLO computation " << computation->name() + << " with id " << instruction_id; + sequence.push_back(instr_it->second); + } + } + TF_RETURN_IF_ERROR(schedule.Verify()); + return std::move(schedule); +} + +StatusOr<HloScheduleProto> HloSchedule::ToProto() const { + TF_RETURN_IF_ERROR(Verify()); + HloScheduleProto proto; + for (const auto& id_sequence : sequences_) { + int64 computation_id = id_sequence.first; + const HloInstructionSequence& sequence = id_sequence.second; + HloScheduleProto::InstructionSequence& proto_sequence = + (*proto.mutable_sequences())[computation_id]; + proto_sequence.mutable_instruction_ids()->Reserve(sequence.size()); + for (const int64 id : sequence.ids()) { + proto_sequence.add_instruction_ids(id); + } + } + return std::move(proto); +} + void HloSchedule::set_sequence( const HloComputation* computation, absl::Span<const HloInstruction* const> sequence) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 21c6988638..270fe6039f 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -21,18 +21,20 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/status.h" namespace xla { +class HloModule; + // Class representing a sequence of HLO instructions such as the sequential // execution order of an HLO computation. 
class HloInstructionSequence { public: HloInstructionSequence() = default; - HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) { + explicit HloInstructionSequence( + absl::Span<const HloInstruction* const> instructions) { for (const HloInstruction* instruction : instructions) { push_back(instruction); } @@ -77,7 +79,12 @@ class HloInstructionSequence { // non-fusion computation in the HLO module. class HloSchedule { public: - HloSchedule(const HloModule* module) : module_(module) {} + explicit HloSchedule(const HloModule* module) : module_(module) {} + + // (De)Serialize an HloSchedule to/from a HloScheduleProto. + static StatusOr<HloSchedule> CreateFromProto(const HloModule* module, + const HloScheduleProto& proto); + StatusOr<HloScheduleProto> ToProto() const; // Returns a reference to the sequence for the given computation. const HloInstructionSequence& sequence( diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 6e17711f57..082bf8bffe 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -855,8 +855,7 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, ? instruction.sharding().GetSubSharding(instruction.shape(), index) : instruction.sharding(); // We propagate the sharding to the copied instruction only if it is a - // special sharding, like tiled ones, or special devices like the - // HostCompute module. + // special sharding, like tiled ones. // Otherwise it is preferable to leave the new instruction without device, // and let the automatic device placer to choose the best location. auto device = sharding.UniqueDevice(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 7d49b8d6c2..a60643bc75 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -75,6 +75,16 @@ void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands, } } +void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers, + llvm::IRBuilder<>* b, llvm::Module* module) { + std::vector<llvm::Value*> buffer_ptrs; + buffer_ptrs.reserve(buffers.size()); + absl::c_transform( + buffers, std::back_inserter(buffer_ptrs), + [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); }); + llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module); +} + llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, llvm::IRBuilder<>* b, llvm::Module* module) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index 887fb61371..94340b91d8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -68,6 +68,11 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands, llvm::IRBuilder<>* b, llvm::Module* module); +// Similar to EmitTuple above, except that the output buffers are provided in +// the form of IrArray. +void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers, + llvm::IRBuilder<>* b, llvm::Module* module); + // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. 
A GetTupleElement instruction // forwards the pointer to underlying tuple element buffer at the given index. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 36b8fb2644..d0bda45cf8 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -75,7 +75,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_headers_lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 8c62adea23..57f7fed61f 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -866,10 +866,7 @@ INSTANTIATE_TEST_CASE_P( BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}}, BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}})); -// TODO(b/64093391) Disabled on GPU due to an assertion failure when running -// IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on -// 2017-07-26. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { +XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) { XlaBuilder builder(TestName()); XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder); diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index c20a7c8fe4..3ae31191a0 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -417,4 +417,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, .status(); } +std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + return absl::make_unique<HloDotInstruction>( + shape, lhs, rhs, dot_dimension_numbers, precision_config); +} } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 7790737c09..a260271b1b 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -24,10 +24,10 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/stream_executor/platform.h" namespace xla { @@ -98,6 +98,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( Status VerifyHloModule(HloModule* const module, bool layout_sensitive, bool allow_mixed_precision); +// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of +// the LHS with dimension 0 of the RHS with no batch dimensions. +// Both LHS and the RHS must be of rank 2. 
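CreateCanonicalDot, declared just below, replaces the HloInstruction factory of the same name that this change removed from hlo_instruction.h; it now lives with the test utilities and returns the concrete HloDotInstruction subclass. A usage sketch assuming rank-2 f32 shapes; parameter creation follows the usual HloInstruction factories:

    xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
    auto lhs = xla::HloInstruction::CreateParameter(0, shape, "lhs");
    auto rhs = xla::HloInstruction::CreateParameter(1, shape, "rhs");

    // Contracts dimension 1 of lhs with dimension 0 of rhs at DEFAULT precision.
    std::unique_ptr<xla::HloDotInstruction> dot =
        xla::CreateCanonicalDot(shape, lhs.get(), rhs.get());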
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc index 23ce1d235b..0c3ec5934e 100644 --- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc +++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc @@ -67,8 +67,8 @@ int main(int argc, char** argv) { floats.push_back(value); } - absl::string_view content(absl::bit_cast<const char*>(floats.data()), - floats.size() * sizeof(float)); + tensorflow::StringPiece content(absl::bit_cast<const char*>(floats.data()), + floats.size() * sizeof(float)); TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output_file, content)); return 0; diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py index 16eb1f0e3f..41c3424fa3 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -57,8 +57,8 @@ class LogicalExpressionTransformer(converter.Base): gast.NotEq: 'tf.not_equal', gast.Or: 'tf.logical_or', gast.USub: 'tf.negative', - gast.Is: 'autograph_utils.dynamic_is', - gast.IsNot: 'autograph_utils.dynamic_is_not' + gast.Is: 'ag__.utils.dynamic_is', + gast.IsNot: 'ag__.utils.dynamic_is_not' } def _expect_simple_symbol(self, operand): diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py index 8f9eee7081..409a73afba 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -47,6 +47,15 @@ class GradientsFunctionTest(converter_testing.TestCase): with self.cached_session() as sess: self.assertTrue(sess.run(result.test_fn(True, False, True))) + def test_ag_utils_lookup(self): + def test_fn(a, b): + return a is b or a is not b + + with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or + ) as result: + with self.cached_session() as sess: + self.assertTrue(sess.run(result.test_fn(True, False))) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index 803fde9089..a4c6fed265 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -38,9 +38,6 @@ class ApiTest(test.TestCase): def setUp(self): config.COMPILED_IMPORT_STATEMENTS = ( 'from __future__ import print_function', - 'from tensorflow.contrib.autograph import utils' - ' as autograph_utils', - 'tf = autograph_utils.fake_tf()', ) def test_decorator_recurses(self): diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py index e42f679cfe..d77c15915b 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py @@ -394,10 +394,16 @@ class AnfTransformer(transformer.Base): # just recur. 
def visit_List(self, node): - return self._visit_strict_expression(node) + node = self.generic_visit(node) + if not isinstance(node.ctx, gast.Store): + self._ensure_fields_trivial(node) + return node def visit_Tuple(self, node): - return self._visit_strict_expression(node) + node = self.generic_visit(node) + if not isinstance(node.ctx, gast.Store): + self._ensure_fields_trivial(node) + return node def transform(node, entity_info, gensym_source=None): diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py index 951974820c..1ffd4bbe55 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -165,6 +165,46 @@ class AnfTransformerTest(test.TestCase): self.assert_body_anfs_as_expected(expected_result, test_function) + def test_nested_multi_value_assign(self): + + def test_function(a, b, c): + x, y = a, a + b + (z, y), x = (c, y + b), x + a + return z, (y, x) + + def expected_result(a, b, c): + tmp_1001 = a + b + x, y = a, tmp_1001 + tmp_1002 = y + b + tmp_1003 = (c, tmp_1002) + tmp_1004 = x + a + (z, y), x = tmp_1003, tmp_1004 + tmp_1005 = y, x + tmp_1006 = z, tmp_1005 + return tmp_1006 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_deeply_nested_multi_value_assign(self): + + def test_function(a): + [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a + return [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] + + def expected_result(a): + [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a + tmp_1001 = b, c + tmp_1002 = [d, e] + tmp_1003 = [tmp_1001, tmp_1002] + tmp_1004 = f, g + tmp_1005 = h, i, j + tmp_1006 = tmp_1003, tmp_1004 + tmp_1007 = [tmp_1005, k] + tmp_1008 = [tmp_1006, tmp_1007] + return tmp_1008 + + self.assert_body_anfs_as_expected(expected_result, test_function) + def test_local_definition_and_binary_compare(self): def test_function(): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 2d8f922a45..e7baa244b2 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -29,6 +29,11 @@ from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno +# TODO(aqj): Do we need this? Do other builtins fail in similar ways +# See b/114389775 for a related bug in pyct +# These symbols are legal in Python, but don't appear in the namespace. +_special_symbols = {'range': range} + class LiveValueResolver(transformer.Base): """Annotates nodes with live values.""" @@ -66,6 +71,8 @@ class LiveValueResolver(transformer.Base): # If the symbol value is for example a primitive, then it will not # have a name. pass + elif node.id in _special_symbols: + anno.setanno(node, 'live_val', _special_symbols[node.id]) else: pass # TODO(mdan): Should we raise an error here? 
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 870ce2442b..4c7a538b38 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -52,7 +52,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): center_bias=True, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -94,6 +95,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: If learner_config is not valid. @@ -134,7 +136,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -159,7 +162,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): center_bias=True, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -201,6 +205,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. """ head = head_lib.regression_head( label_name=label_name, @@ -224,7 +229,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -251,7 +257,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): center_bias=True, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -289,6 +296,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. 
""" super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -303,7 +311,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -329,7 +338,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): center_bias=False, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeRanker instance. This is an estimator that can be trained off the pairwise data and can be @@ -377,6 +387,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. + Raises: ValueError: If learner_config is not valid. """ @@ -395,7 +407,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, 'ranking_model_pair_keys': ranking_model_pair_keys, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -444,7 +457,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - output_leaf_index=False): + output_leaf_index=False, + num_quantiles=100): """Initializes a core version of GradientBoostedDecisionTreeEstimator. Args: @@ -474,6 +488,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. """ def _model_fn(features, labels, mode, config): @@ -493,7 +508,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': True, 'output_leaf_index': output_leaf_index, - 'override_global_step_value': None + 'override_global_step_value': None, + 'num_quantiles': num_quantiles, }, output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) @@ -517,7 +533,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): label_keys=None, logits_modifier_function=None, center_bias=False, - output_leaf_index=False): + output_leaf_index=False, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeRanker instance. This is an estimator that can be trained off the pairwise data and can be @@ -552,6 +569,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): for result_dict in result_iter: # access leaf index list by result_dict["leaf_index"] # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: If learner_config is not valid. 
@@ -576,7 +594,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): 'use_core_libs': True, 'output_leaf_index': output_leaf_index, 'ranking_model_pair_keys': ranking_model_pair_keys, - 'override_global_step_value': None + 'override_global_step_value': None, + 'num_quantiles': num_quantiles, }, output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 04b46c3483..a6e422847d 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -81,6 +81,7 @@ def model_builder(features, logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] override_global_step_value = params.get("override_global_step_value", None) + num_quantiles = params["num_quantiles"] if features is None: raise ValueError("At least one feature must be specified.") @@ -116,7 +117,8 @@ def model_builder(features, logits_dimension=head.logits_dimension, features=training_features, use_core_columns=use_core_libs, - output_leaf_index=output_leaf_index) + output_leaf_index=output_leaf_index, + num_quantiles=num_quantiles) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] @@ -237,6 +239,7 @@ def ranking_model_builder(features, output_leaf_index = params["output_leaf_index"] ranking_model_pair_keys = params["ranking_model_pair_keys"] override_global_step_value = params.get("override_global_step_value", None) + num_quantiles = params["num_quantiles"] if features is None: raise ValueError("At least one feature must be specified.") @@ -299,7 +302,8 @@ def ranking_model_builder(features, logits_dimension=head.logits_dimension, features=main_features, use_core_columns=use_core_libs, - output_leaf_index=output_leaf_index) + output_leaf_index=output_leaf_index, + num_quantiles=num_quantiles) with ops.name_scope("gbdt", "gbdt_optimizer"): # Logits for inference. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index b008c6e534..c7eb2493a8 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -304,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object): feature_columns=None, use_core_columns=False, output_leaf_index=False, - output_leaf_index_modes=None): + output_leaf_index_modes=None, + num_quantiles=100): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -327,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object): output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which dictates when leaf indices will be outputted. By default, leaf indices are only outputted in INFER mode. + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: if inputs are not valid. 
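The gbdt_batch hunks that follow thread num_quantiles through to the quantile stream and tie the sketch accuracy to it: epsilon = 1.0 / num_quantiles, replacing the old hard-coded pair of epsilon = 0.01 and num_quantiles = 100, so requesting more quantiles now tightens the approximation automatically. A toy, self-contained illustration of what an epsilon-approximate quantile guarantee allows, independent of the TensorFlow implementation:

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t num_examples = 1000000;
      const int num_quantiles = 100;  // the new estimators' default argument
      const double epsilon = 1.0 / num_quantiles;

      // An epsilon-approximate quantile may deviate from the exact quantile
      // by at most epsilon * num_examples positions in rank.
      const double rank_slack = epsilon * num_examples;
      std::cout << "allowed rank error: +/- " << rank_slack << std::endl;  // 10000
      return 0;
    }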
@@ -399,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() + self._num_quantiles = num_quantiles self._max_tree_depth = variables.Variable( initial_value=self._learner_config.constraints.max_tree_depth) self._attempted_trees = variables.Variable( @@ -689,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object): loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) weak_learner_type = constant_op.constant( self._learner_config.weak_learner_type) - epsilon = 0.01 - num_quantiles = 100 + num_quantiles = self._num_quantiles + epsilon = 1.0 / num_quantiles strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1ab150d74a..1056894f18 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -229,6 +229,10 @@ class TPUClusterResolver(ClusterResolver): def get_master(self): return self.master() + def get_job_name(self): + if self._shouldResolve(): + return self._job_name + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 34f594f741..b9320e5fef 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -279,7 +279,9 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:function", + "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:session", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 9d8e955245..67242fecfe 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -428,10 +428,10 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) @parameterized.named_parameters( - ("default", None, None), - ("sequential_calls", 1, None), - ("parallel_calls", 2, None), - ("parallel_batches", None, 10), + ("Default", None, None), + ("SequentialCalls", 1, None), + ("ParallelCalls", 2, None), + ("ParallelBatches", None, 10), ) def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): """Test a dataset that maps a TF function across its input elements.""" @@ -505,8 +505,8 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) @parameterized.named_parameters( - ("even", False), - ("uneven", True), + ("Even", False), + ("Uneven", True), ) def testMapAndBatchPartialBatch(self, drop_remainder): iterator = ( @@ -663,7 +663,14 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): for _ in range(3): sess.run(get_next) - @parameterized.parameters(0, 5, 10, 90, 95, 99) + @parameterized.named_parameters( + ("1", 0), + ("2", 5), + ("3", 10), + ("4", 90), + ("5", 95), + ("6", 
99), + ) def testMapAndBatchOutOfRangeError(self, threshold): def raising_py_fn(i): @@ -689,18 +696,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @parameterized.parameters( - (False, dtypes.bool), - (-42, dtypes.int8), - (-42, dtypes.int16), - (-42, dtypes.int32), - (-42, dtypes.int64), - (42, dtypes.uint8), - (42, dtypes.uint16), - (42.0, dtypes.float16), - (42.0, dtypes.float32), - (42.0, dtypes.float64), - (b"hello", dtypes.string), + @parameterized.named_parameters( + ("1", False, dtypes.bool), + ("2", -42, dtypes.int8), + ("3", -42, dtypes.int16), + ("4", -42, dtypes.int32), + ("5", -42, dtypes.int64), + ("6", 42, dtypes.uint8), + ("7", 42, dtypes.uint16), + ("8", 42.0, dtypes.float16), + ("9", 42.0, dtypes.float32), + ("10", 42.0, dtypes.float64), + ("11", b"hello", dtypes.string), ) def testMapAndBatchTypes(self, element, dtype): def gen(): diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 091eb5ce37..61567bc8d7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -17,7 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from tensorflow.contrib.data.python.ops import map_defun +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -25,10 +28,10 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test - class MapDefunTest(test.TestCase): def testMapDefunSimple(self): @@ -146,6 +149,105 @@ class MapDefunTest(test.TestCase): r"indices = 10 is not in \[0, 5\)"): self.evaluate(map_defun_op) + def testMapDefunWithUnspecifiedOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + res = x * 2 + 3 + return (res, res + 1, res + 2) + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], + [dtypes.int32, dtypes.int32, dtypes.int32], + [None, (None,), (2,)]) + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected)) + self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1)) + self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2)) + + def testMapDefunWithDifferentOutputShapeEachRun(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + elems = array_ops.placeholder(dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0] + with session.Session() as sess: + self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3]) + self.assertAllEqual( + sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]]) + + def testMapDefunWithWrongOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], 
[(1,)])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefunWithInvalidInput(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + + c = constant_op.constant(2) + with self.assertRaises(ValueError): + # Fails at graph construction time for inputs with known shapes. + r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0] + p = array_ops.placeholder(dtypes.int32) + r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0] + with session.Session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(r, feed_dict={p: 0}) + + +class MapDefunBenchmark(test.Benchmark): + + def _run(self, op, name=None, num_iters=3000): + with session.Session() as sess: + # Warm up the session + for _ in range(5): + sess.run(op) + start = time.time() + for _ in range(num_iters): + sess.run(op) + end = time.time() + mean_us = (end - start) * 1e6 / num_iters + self.report_benchmark( + name=name, + iters=num_iters, + wall_time=mean_us, + extras={"examples_per_sec": num_iters / (end - start)}) + + def benchmarkDefunVsMapFn(self): + """Benchmarks to compare the performance of MapDefun vs tf.map_fn.""" + + @function.Defun(dtypes.int32) + def defun(x): + return array_ops.identity(x) + + def map_fn(x): + return array_ops.identity(x) + + base = math_ops.range(100) + for input_size in [10, 100, 1000, 10000]: + num_iters = 100000 // input_size + map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()]) + map_fn_op = functional_ops.map_fn(map_fn, base) + + self._run( + map_defun_op, + "benchmarkMapDefun_size_%d" % input_size, + num_iters=num_iters) + self._run( + map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py index 586b4bee5f..6a7ef877f9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -44,22 +44,22 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): for i, fun1 in enumerate(functions): for j, fun2 in enumerate(functions): tests.append(( - "test_{}_{}".format(i, j), + "Test{}{}".format(i, j), [fun1, fun2], )) for k, fun3 in enumerate(functions): tests.append(( - "test_{}_{}_{}".format(i, j, k), + "Test{}{}{}".format(i, j, k), [fun1, fun2, fun3], )) swap = lambda x, n: (n, x) tests.append(( - "swap1", + "Swap1", [lambda x: (x, 42), swap], )) tests.append(( - "swap2", + "Swap2", [lambda x: (x, 42), swap, swap], )) return tuple(tests) @@ -109,13 +109,13 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): for x, fun in enumerate(functions): for y, predicate in enumerate(filters): - tests.append(("mixed_{}_{}".format(x, y), fun, predicate)) + tests.append(("Mixed{}{}".format(x, y), fun, predicate)) # Multi output - tests.append(("multiOne", lambda x: (x, x), + tests.append(("Multi1", lambda x: (x, x), lambda x, y: constant_op.constant(True))) tests.append( - ("multiTwo", lambda x: (x, 2), + ("Multi2", lambda x: (x, 2), lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) return tuple(tests) @@ -172,17 +172,17 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): identity = lambda x: x for x, predicate_1 in enumerate(filters): for y, 
predicate_2 in enumerate(filters): - tests.append(("mixed_{}_{}".format(x, y), identity, + tests.append(("Mixed{}{}".format(x, y), identity, [predicate_1, predicate_2])) for z, predicate_3 in enumerate(filters): - tests.append(("mixed_{}_{}_{}".format(x, y, z), identity, + tests.append(("Mixed{}{}{}".format(x, y, z), identity, [predicate_1, predicate_2, predicate_3])) take_all_multiple = lambda x, y: constant_op.constant(True) # Multi output - tests.append(("multiOne", lambda x: (x, x), + tests.append(("Multi1", lambda x: (x, x), [take_all_multiple, take_all_multiple])) - tests.append(("multiTwo", lambda x: (x, 2), [ + tests.append(("Multi2", lambda x: (x, 2), [ take_all_multiple, lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0) ])) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 4881f63ab9..aa89674c6e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -210,6 +210,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py index ac3892fe81..243f6405a1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base @@ -27,42 +28,38 @@ from tensorflow.python.platform import test class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): - def _build_iterator_graph(self, input_values, cycle_length, block_length): + def _build_iterator_graph(self, input_values, cycle_length, block_length, + num_parallel_calls): repeat_count = 2 return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( repeat_count).interleave( lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length) + cycle_length, block_length, num_parallel_calls) - def testSerializationCore(self): + @parameterized.named_parameters( + ("1", 2, 3, None), + ("2", 2, 3, 1), + ("3", 2, 3, 2), + ("4", 1, 3, None), + ("5", 1, 3, 1), + ("6", 2, 1, None), + ("7", 2, 1, 1), + ("8", 2, 1, 2), + ) + def testSerializationCore(self, cycle_length, block_length, + num_parallel_calls): input_values = np.array([4, 5, 6], dtype=np.int64) num_outputs = np.sum(input_values) * 2 - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 # pylint: disable=g-long-lambda self.run_core_tests( lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), + input_values, cycle_length, block_length, num_parallel_calls), lambda: self._build_iterator_graph( - input_values, cycle_length * 2, block_length * 1), + input_values, cycle_length * 2, block_length, num_parallel_calls), num_outputs) - # 
cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) # pylint: enable=g-long-lambda def testSparseCore(self): @@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest( self.run_core_tests(_build_dataset, None, 20) -if __name__ == '__main__': +if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 8b2f846494..6b3e8e9f6e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -32,18 +32,18 @@ from tensorflow.python.platform import test class SlideDatasetTest(test.TestCase, parameterized.TestCase): - @parameterized.parameters( - (20, 14, 7, 1), - (20, 17, 9, 1), - (20, 14, 14, 1), - (20, 10, 14, 1), - (20, 14, 19, 1), - (20, 4, 1, 2), - (20, 2, 1, 6), - (20, 4, 7, 2), - (20, 2, 7, 6), - (1, 10, 4, 1), - (0, 10, 4, 1), + @parameterized.named_parameters( + ("1", 20, 14, 7, 1), + ("2", 20, 17, 9, 1), + ("3", 20, 14, 14, 1), + ("4", 20, 10, 14, 1), + ("5", 20, 14, 19, 1), + ("6", 20, 4, 1, 2), + ("7", 20, 2, 1, 6), + ("8", 20, 4, 7, 2), + ("9", 20, 2, 7, 6), + ("10", 1, 10, 4, 1), + ("11", 0, 10, 4, 1), ) def testSlideDataset(self, count, window_size, window_shift, window_stride): """Tests a dataset that slides a window its input elements.""" @@ -96,18 +96,18 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @parameterized.parameters( - (20, 14, 7, 1), - (20, 17, 9, 1), - (20, 14, 14, 1), - (20, 10, 14, 1), - (20, 14, 19, 1), - (20, 4, 1, 2), - (20, 2, 1, 6), - (20, 4, 7, 2), - (20, 2, 7, 6), - (1, 10, 4, 1), - (0, 10, 4, 1), + @parameterized.named_parameters( + ("1", 20, 14, 7, 1), + ("2", 20, 17, 9, 1), + ("3", 20, 14, 14, 1), + ("4", 20, 10, 14, 1), + ("5", 20, 14, 19, 1), + ("6", 20, 4, 1, 2), + ("7", 20, 2, 1, 6), + ("8", 20, 4, 7, 2), + ("9", 20, 2, 7, 6), + ("10", 1, 10, 4, 1), + ("11", 0, 10, 4, 1), ) def testSlideDatasetDeprecated(self, count, window_size, stride, window_stride): @@ -160,10 +160,10 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @parameterized.parameters( - (14, 0, 3, 1), - (14, 3, 0, 1), - (14, 3, 3, 0), + @parameterized.named_parameters( + ("1", 14, 0, 3, 1), + ("2", 14, 3, 0, 1), + ("3", 14, 3, 3, 0), ) def testSlideDatasetInvalid(self, count, window_size, window_shift, window_stride): diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 0486e2bce2..4b08ec759d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -33,8 +33,17 @@ from tensorflow.python.platform import test class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): - @parameterized.parameters((1, None), (2, None), (4, None), (8, None), - (16, None), (4, -1), (4, 0), (4, 1), (4, 4)) + @parameterized.named_parameters( + ("1", 1, None), 
+ ("2", 2, None), + ("3", 4, None), + ("4", 8, None), + ("5", 16, None), + ("6", 4, -1), + ("7", 4, 0), + ("8", 4, 1), + ("9", 4, 4), + ) def testNumThreads(self, num_threads, max_intra_op_parallelism): def get_thread_id(_): diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py index 33d95d6754..ff4d9b3260 100644 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -64,15 +64,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): else: self.assertEqual(xs, ys) - @parameterized.parameters( - (None, np.int32([]), dtypes.bool), - (None, np.int32([]), dtypes.int32), - (None, np.int32([]), dtypes.float32), - (None, np.int32([]), dtypes.string), - (None, np.int32([2]), dtypes.int32), - (None, np.int32([2, 2]), dtypes.int32), - ((None, None, None), np.int32([]), dtypes.int32), - ((None, (None, None)), np.int32([]), dtypes.int32), + @parameterized.named_parameters( + ("1", None, np.int32([]), dtypes.bool), + ("2", None, np.int32([]), dtypes.int32), + ("3", None, np.int32([]), dtypes.float32), + ("4", None, np.int32([]), dtypes.string), + ("5", None, np.int32([2]), dtypes.int32), + ("6", None, np.int32([2, 2]), dtypes.int32), + ("7", (None, None, None), np.int32([]), dtypes.int32), + ("8", (None, (None, None)), np.int32([]), dtypes.int32), ) def testWindowDatasetFlatMap(self, structure, shape, dtype): """Tests windowing by chaining it with flat map. @@ -97,15 +97,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (None, np.int32([]), dtypes.bool), - (None, np.int32([]), dtypes.int32), - (None, np.int32([]), dtypes.float32), - (None, np.int32([]), dtypes.string), - (None, np.int32([2]), dtypes.int32), - (None, np.int32([2, 2]), dtypes.int32), - ((None, None, None), np.int32([]), dtypes.int32), - ((None, (None, None)), np.int32([]), dtypes.int32), + @parameterized.named_parameters( + ("1", None, np.int32([]), dtypes.bool), + ("2", None, np.int32([]), dtypes.int32), + ("3", None, np.int32([]), dtypes.float32), + ("4", None, np.int32([]), dtypes.string), + ("5", None, np.int32([2]), dtypes.int32), + ("6", None, np.int32([2, 2]), dtypes.int32), + ("7", (None, None, None), np.int32([]), dtypes.int32), + ("8", (None, (None, None)), np.int32([]), dtypes.int32), ) def testWindowDatasetBatchDense(self, structure, shape, dtype): """Tests batching of dense tensor windows. @@ -135,10 +135,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([]),), - (np.int32([1]),), - (np.int32([1, 2, 3]),), + @parameterized.named_parameters( + ("1", np.int32([])), + ("2", np.int32([1])), + ("3", np.int32([1, 2, 3])), ) def testWindowDatasetBatchDenseDynamicShape(self, shape): """Tests batching of dynamically shaped dense tensor windows. 
@@ -203,15 +203,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): for substructure in structure ]) - @parameterized.parameters( - (None, np.int32([]), dtypes.bool), - (None, np.int32([]), dtypes.int32), - (None, np.int32([]), dtypes.float32), - (None, np.int32([]), dtypes.string), - (None, np.int32([2]), dtypes.int32), - (None, np.int32([2, 2]), dtypes.int32), - ((None, None, None), np.int32([]), dtypes.int32), - ((None, (None, None)), np.int32([]), dtypes.int32), + @parameterized.named_parameters( + ("1", None, np.int32([]), dtypes.bool), + ("2", None, np.int32([]), dtypes.int32), + ("3", None, np.int32([]), dtypes.float32), + ("4", None, np.int32([]), dtypes.string), + ("5", None, np.int32([2]), dtypes.int32), + ("6", None, np.int32([2, 2]), dtypes.int32), + ("7", (None, None, None), np.int32([]), dtypes.int32), + ("8", (None, (None, None)), np.int32([]), dtypes.int32), ) def testWindowDatasetBatchSparse(self, structure, shape, dtype): """Tests batching of sparse tensor windows. @@ -243,10 +243,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([]),), - (np.int32([1]),), - (np.int32([1, 2, 3]),), + @parameterized.named_parameters( + ("1", np.int32([])), + ("2", np.int32([1])), + ("3", np.int32([1, 2, 3])), ) def testWindowDatasetBatchSparseDynamicShape(self, shape): """Tests batching of dynamically shaped sparse tensor windows. @@ -284,17 +284,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): for substructure in structure ])) - @parameterized.parameters( - (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), - (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), - ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), + @parameterized.named_parameters( + ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), + ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), + ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), + ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), + ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("8", (None, + (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), ) def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype, padded_shape): @@ -329,10 +330,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([[1], [2], [3]]), [-1]), - (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + @parameterized.named_parameters( + ("1", np.int32([[1], [2], [3]]), [-1]), + ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + 
("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), ) def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape): """Tests padded batching of dynamically shaped dense tensor windows. @@ -361,9 +362,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([[1]]), np.int32([0])), - (np.int32([[10], [20]]), np.int32([15])), + @parameterized.named_parameters( + ("1", np.int32([[1]]), np.int32([0])), + ("2", np.int32([[10], [20]]), np.int32([15])), ) def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape): """Tests invalid padded batching of dense tensor windows. @@ -420,17 +421,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): for substructure in structure ]) - @parameterized.parameters( - (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), - (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), - ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), + @parameterized.named_parameters( + ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), + ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), + ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), + ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), + ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("8", (None, + (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), ) def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype, padded_shape): @@ -463,10 +465,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int64([[1], [2], [3]]), [-1]), - (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + @parameterized.named_parameters( + ("1", np.int64([[1], [2], [3]]), [-1]), + ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), ) def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes, padded_shape): @@ -495,9 +497,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int64([[1]]), [0]), - (np.int64([[10], [20]]), [15]), + @parameterized.named_parameters( + ("1", np.int64([[1]]), [0]), + ("2", np.int64([[10], [20]]), [15]), ) def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape): """Tests invalid padded batching of sparse tensor windows. 
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py index 54d5cd6da0..3d0d0993c9 100644 --- a/tensorflow/contrib/data/python/ops/map_defun.py +++ b/tensorflow/contrib/data/python/ops/map_defun.py @@ -53,6 +53,4 @@ def map_defun(fn, elems, output_dtypes, output_shapes): elems = [ops.convert_to_tensor(e) for e in elems] output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] - if not all(s.is_fully_defined() for s in output_shapes): - raise ValueError("All fn output shapes must be fully defined.") return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index d39fd57294..3cee3e37a7 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -446,8 +446,7 @@ class TestWithDistributionStrategy(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - with self.assertRaisesRegexp(ValueError, - 'expected input to have 2 dimensions'): + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) # Wrong input shape diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 4fb70ec685..6ba83976fc 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -310,7 +310,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def get_host_cpu_device(self, host_id): if self._tpu_cluster_resolver.get_master() in ('', 'local'): return '/replica:0/task:0/device:CPU:0' - return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,) + job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker' + return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id) def configure(self, session_config=None, diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 77f62df99d..437b3d965d 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -446,6 +446,7 @@ py_library( "//tensorflow/python/estimator", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index 7c49cd00d1..98660bb731 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import training_util @@ -405,6 +406,7 @@ class RNNClassifier(estimator.Estimator): weight_column=None, label_vocabulary=None, optimizer='Adagrad', + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, input_layer_partitioner=None, config=None): """Initializes a `RNNClassifier` instance. @@ -454,6 +456,8 @@ class RNNClassifier(estimator.Estimator): string. 
optimizer: An instance of `tf.Optimizer` or string specifying optimizer type. Defaults to Adagrad optimizer. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`. input_layer_partitioner: Optional. Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. @@ -467,11 +471,15 @@ class RNNClassifier(estimator.Estimator): if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) else: head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access - n_classes, weight_column=weight_column, - label_vocabulary=label_vocabulary) + n_classes, + weight_column=weight_column, + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) + def _model_fn(features, labels, mode, config): return _rnn_model_fn( features=features, diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py index 959b40371a..1aebed348d 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -713,7 +713,7 @@ class RNNClassifierTrainingTest(test.TestCase): # Uses same checkpoint and examples as testBinaryClassEvaluationMetrics. # See that test for loss calculation. - mock_optimizer = self._mock_optimizer(expected_loss=1.119661) + mock_optimizer = self._mock_optimizer(expected_loss=0.559831) sequence_feature_columns = [ seq_fc.sequence_numeric_column('price', shape=(1,))] @@ -748,7 +748,7 @@ class RNNClassifierTrainingTest(test.TestCase): # Uses same checkpoint and examples as testMultiClassEvaluationMetrics. # See that test for loss calculation. - mock_optimizer = self._mock_optimizer(expected_loss=2.662932) + mock_optimizer = self._mock_optimizer(expected_loss=1.331465) sequence_feature_columns = [ seq_fc.sequence_numeric_column('price', shape=(1,))] @@ -812,20 +812,32 @@ class RNNClassifierEvaluationTest(test.TestCase): # probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]] # loss = -label * ln(p) - (1 - label) * ln(1 - p) # = [[0.436326], [0.683335]] + # sum_over_batch_size = (0.436326 + 0.683335)/2 expected_metrics = { - ops.GraphKeys.GLOBAL_STEP: global_step, - metric_keys.MetricKeys.LOSS: 1.119661, - metric_keys.MetricKeys.LOSS_MEAN: 0.559831, - metric_keys.MetricKeys.ACCURACY: 1.0, - metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262, - metric_keys.MetricKeys.LABEL_MEAN: 0.5, - metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, + ops.GraphKeys.GLOBAL_STEP: + global_step, + metric_keys.MetricKeys.LOSS: + 0.559831, + metric_keys.MetricKeys.LOSS_MEAN: + 0.559831, + metric_keys.MetricKeys.ACCURACY: + 1.0, + metric_keys.MetricKeys.PREDICTION_MEAN: + 0.429262, + metric_keys.MetricKeys.LABEL_MEAN: + 0.5, + metric_keys.MetricKeys.ACCURACY_BASELINE: + 0.5, # With default threshold of 0.5, the model is a perfect classifier. - metric_keys.MetricKeys.RECALL: 1.0, - metric_keys.MetricKeys.PRECISION: 1.0, + metric_keys.MetricKeys.RECALL: + 1.0, + metric_keys.MetricKeys.PRECISION: + 1.0, # Positive example is scored above negative, so AUC = 1.0. 
- metric_keys.MetricKeys.AUC: 1.0, - metric_keys.MetricKeys.AUC_PR: 1.0, + metric_keys.MetricKeys.AUC: + 1.0, + metric_keys.MetricKeys.AUC_PR: + 1.0, } self.assertAllClose( sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics)) @@ -871,9 +883,10 @@ class RNNClassifierEvaluationTest(test.TestCase): # [0.059494, 0.572639, 0.367866]] # loss = -1. * log(softmax[label]) # = [[2.105432], [0.557500]] + # sum_over_batch_size = (2.105432 + 0.557500)/2 expected_metrics = { ops.GraphKeys.GLOBAL_STEP: global_step, - metric_keys.MetricKeys.LOSS: 2.662932, + metric_keys.MetricKeys.LOSS: 1.331465, metric_keys.MetricKeys.LOSS_MEAN: 1.331466, metric_keys.MetricKeys.ACCURACY: 0.5, } diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 0ccb4583ab..716bb87e38 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -174,7 +174,7 @@ class FusedConv2DBiasActivationOp : public OpKernel { // Input bias is a 1-D tensor, with size matching output depth. const Tensor& bias = context->input(kBias); - OP_REQUIRES_OK(context, CheckShape(bias, "conv_input")); + OP_REQUIRES_OK(context, CheckShape(bias, "bias")); const Tensor& conv_input_scale_tensor = context->input(kConvInputScale); const Tensor& side_input_scale_tensor = context->input(kSideInputScale); diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 418b0cf392..61185f65a9 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -403,6 +403,7 @@ py_test( srcs = ["python/learn/estimators/dnn_test.py"], shard_count = 4, srcs_version = "PY2AND3", + tags = ["notap"], deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 0091587bf7..f320b53d94 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -36,10 +36,10 @@ cc_library( srcs = ["arena_planner.cc"], hdrs = ["arena_planner.h"], deps = [ - ":context", ":graph_info", ":memory_planner", ":simple_memory_arena", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -54,6 +54,7 @@ cc_test( deps = [ ":arena_planner", "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:framework", "//tensorflow/core:lib", "@com_google_googletest//:gtest", ], @@ -63,27 +64,27 @@ cc_test( # TODO(aselle): Resolve problems preventing C99 usage. 
cc_library( name = "context", - srcs = ["context.c"], hdrs = ["context.h"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( name = "graph_info", hdrs = ["graph_info.h"], - deps = [":context"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( name = "memory_planner", hdrs = ["memory_planner.h"], - deps = [":context"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( name = "simple_memory_arena", srcs = ["simple_memory_arena.cc"], hdrs = ["simple_memory_arena.h"], - deps = [":context"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( @@ -91,7 +92,7 @@ cc_library( hdrs = [ "builtin_op_data.h", ], - deps = [":context"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( @@ -121,12 +122,12 @@ cc_library( name = "framework", srcs = [ "allocation.cc", - "error_reporter.cc", "graph_info.cc", "interpreter.cc", "model.cc", - "op_resolver.cc", + "mutable_op_resolver.cc", "optional_debug_tools.cc", + "stderr_reporter.cc", ] + select({ "//tensorflow:android": [ "nnapi_delegate.cc", @@ -149,9 +150,11 @@ cc_library( "graph_info.h", "interpreter.h", "model.h", + "mutable_op_resolver.h", "nnapi_delegate.h", "op_resolver.h", "optional_debug_tools.h", + "stderr_reporter.h", ], copts = tflite_copts(), linkopts = [ @@ -164,14 +167,14 @@ cc_library( }), deps = [ ":arena_planner", - ":builtin_op_data", - ":context", ":graph_info", ":memory_planner", ":schema_fbs_version", ":simple_memory_arena", ":string", ":util", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", @@ -210,6 +213,8 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", @@ -259,6 +264,8 @@ cc_test( ], deps = [ ":framework", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], @@ -266,9 +273,9 @@ cc_test( # Test OpResolver. cc_test( - name = "op_resolver_test", + name = "mutable_op_resolver_test", size = "small", - srcs = ["op_resolver_test.cc"], + srcs = ["mutable_op_resolver_test.cc"], tags = ["no_oss"], deps = [ ":framework", @@ -277,24 +284,12 @@ cc_test( ], ) -# Test the C extension API code. -cc_test( - name = "context_test", - size = "small", - srcs = ["context_test.cc"], - deps = [ - ":framework", - "//tensorflow/contrib/lite/testing:util", - "@com_google_googletest//:gtest", - ], -) - cc_library( name = "util", srcs = ["util.cc"], hdrs = ["util.h"], deps = [ - ":context", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -304,7 +299,6 @@ cc_test( srcs = ["util_test.cc"], tags = ["no_oss"], deps = [ - ":context", ":util", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc index 8946261814..21cb1832a7 100644 --- a/tensorflow/contrib/lite/allocation.cc +++ b/tensorflow/contrib/lite/allocation.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include <cstring> #include <utility> -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h index 121f3d2646..182bc0977f 100644 --- a/tensorflow/contrib/lite/allocation.h +++ b/tensorflow/contrib/lite/allocation.h @@ -20,8 +20,8 @@ limitations under the License. #include <cstdio> #include <cstdlib> #include <vector> -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/simple_memory_arena.h" #include "tensorflow/contrib/lite/string.h" diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index 55003cf4e9..382577045b 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -18,7 +18,7 @@ limitations under the License. #include <memory> #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/graph_info.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/simple_memory_arena.h" @@ -37,8 +37,8 @@ struct AllocationInfo; // each tensor needs to be allocated and deallocated, and preallocates all the // necessary memory (the PlanAllocations phase). It then assigns portions of // this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may -// share some of the buffer if a tensor B is to be allocated after another tensor -// A has been deallocated. +// share some of the buffer if a tensor B is to be allocated after another +// tensor A has been deallocated. // // If dynamic tensors are used the planning steps can be repeated during model // execution. Since dynamic tensors don't have sizes until after the diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 0246e7fa30..9317e2bb6e 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -49,6 +49,9 @@ def tflite_linkopts_unstripped(): Returns: a select object with proper linkopts """ + + # In case you wonder why there's no --icf: the gains were + # negligible, and it created potential compatibility problems. return select({ "//tensorflow:android": [ "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export. "-Wl,--gc-sections", # Eliminate unused code and data. "-Wl,--as-needed", # Don't link unused libs. ], - "//tensorflow:darwin": [], - "//tensorflow:ios": [], - "//tensorflow/contrib/lite:mips": [], - "//tensorflow/contrib/lite:mips64": [], - "//conditions:default": [ - "-Wl,--icf=all", # Identical code folding. - ], + "//conditions:default": [], }) def tflite_jni_linkopts_unstripped(): @@ -74,17 +71,15 @@ def tflite_jni_linkopts_unstripped(): Returns: a select object with proper linkopts """ + + # In case you wonder why there's no --icf: the gains were + # negligible, and it created potential compatibility problems. return select({ "//tensorflow:android": [ "-Wl,--gc-sections", # Eliminate unused code and data. "-Wl,--as-needed", # Don't link unused libs. 
], - "//tensorflow:darwin": [], - "//tensorflow/contrib/lite:mips": [], - "//tensorflow/contrib/lite:mips64": [], - "//conditions:default": [ - "-Wl,--icf=all", # Identical code folding. - ], + "//conditions:default": [], }) def tflite_linkopts(): diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index aecd71910c..30901bd0fa 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -12,297 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// Compatibility shim for new location of interface definitions. + #ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ -#include <stdint.h> - -#include "tensorflow/contrib/lite/context.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// TODO(aselle): Consider using "if this then that" for testing. - -// Useful placeholder to put in otherwise empty structs to avoid size warnings. -typedef struct { - char dummy_; -} EmptyStructPlaceholder; - -// Possible padding types (for convolutions) -typedef enum { - kTfLitePaddingUnknown = 0, - kTfLitePaddingSame, - kTfLitePaddingValid, -} TfLitePadding; - -typedef struct { - int width; - int height; -} TfLitePaddingValues; - -// Possible fused activation functions. -// TODO(aselle): rename to TfLiteActivation -typedef enum { - kTfLiteActNone = 0, - kTfLiteActRelu, - kTfLiteActRelu1, - kTfLiteActRelu6, - kTfLiteActTanh, - kTfLiteActSignBit, - kTfLiteActSigmoid, -} TfLiteFusedActivation; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; - int dilation_width_factor; - int dilation_height_factor; - TfLiteFusedActivation activation; -} TfLiteConvParams; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; - int filter_width; - int filter_height; - TfLiteFusedActivation activation; - struct { - TfLitePaddingValues padding; - } computed; -} TfLitePoolParams; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; - int depth_multiplier; - TfLiteFusedActivation activation; -} TfLiteDepthwiseConvParams; - -typedef struct { - int rank; - TfLiteFusedActivation activation; -} TfLiteSVDFParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteRNNParams; - -typedef struct { - bool time_major; - TfLiteFusedActivation activation; -} TfLiteSequenceRNNParams; - -typedef enum { - kTfLiteFullyConnectedWeightsFormatDefault = 0, - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, -} TfLiteFullyConnectedWeightsFormat; - -typedef struct { - // Parameters for FullyConnected version 1 or above. - TfLiteFusedActivation activation; - - // Parameters for FullyConnected version 2 or above. 
- TfLiteFullyConnectedWeightsFormat weights_format; -} TfLiteFullyConnectedParams; - -typedef enum { - kTfLiteLshProjectionUnknown = 0, - kTfLiteLshProjectionSparse = 1, - kTfLiteLshProjectionDense = 2, -} TfLiteLSHProjectionType; - -typedef struct { - TfLiteLSHProjectionType type; -} TfLiteLSHProjectionParams; - -typedef struct { - float beta; -} TfLiteSoftmaxParams; - -typedef struct { - int axis; - TfLiteFusedActivation activation; -} TfLiteConcatenationParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteAddParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLiteSpaceToBatchNDParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLiteBatchToSpaceNDParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteMulParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteSubParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteDivParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteL2NormParams; - -typedef struct { - int radius; - float bias; - float alpha; - float beta; -} TfLiteLocalResponseNormParams; - -typedef enum { - kTfLiteLSTMFullKernel = 0, - kTfLiteLSTMBasicKernel -} TfLiteLSTMKernelType; - -typedef struct { - // Parameters for LSTM version 1. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // Parameters for LSTM version 2. - // kTfLiteLSTMBasicKernel is only supported in version 2 or above. - TfLiteLSTMKernelType kernel_type; -} TfLiteLSTMParams; - -typedef struct { - bool align_corners; -} TfLiteResizeBilinearParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLitePadParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLitePadV2Params; - -typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int shape[8]; - int num_dimensions; -} TfLiteReshapeParams; - -typedef struct { - int ngram_size; - int max_skip_size; - bool include_all_ngrams; -} TfLiteSkipGramParams; - -typedef struct { - int block_size; -} TfLiteSpaceToDepthParams; - -typedef struct { - TfLiteType in_data_type; - TfLiteType out_data_type; -} TfLiteCastParams; - -typedef enum { - kTfLiteCombinerTypeSum = 0, - kTfLiteCombinerTypeMean = 1, - kTfLiteCombinerTypeSqrtn = 2, -} TfLiteCombinerType; - -typedef struct { - TfLiteCombinerType combiner; -} TfLiteEmbeddingLookupSparseParams; - -typedef struct { - int axis; -} TfLiteGatherParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLiteTransposeParams; - -typedef struct { - bool keep_dims; -} TfLiteReducerParams; - -typedef struct { - int num_splits; -} TfLiteSplitParams; - -typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. 
- int squeeze_dims[8]; - int num_squeeze_dims; -} TfLiteSqueezeParams; - -typedef struct { - int begin_mask; - int end_mask; - int ellipsis_mask; - int new_axis_mask; - int shrink_axis_mask; -} TfLiteStridedSliceParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMaxParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMinParams; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; -} TfLiteTransposeConvParams; - -typedef struct { - bool validate_indices; -} TfLiteSparseToDenseParams; - -typedef struct { - TfLiteType out_type; -} TfLiteShapeParams; - -typedef struct { - // Parameters supported by version 1: - float min; - float max; - int num_bits; - - // Parameters supported by version 2: - bool narrow_range; -} TfLiteFakeQuantParams; - -typedef struct { - int values_count; - int axis; -} TfLitePackParams; - -typedef struct { - int axis; -} TfLiteOneHotParams; - -typedef struct { - int num; - int axis; -} TfLiteUnpackParams; - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/contrib/lite/c/BUILD new file mode 100644 index 0000000000..663eb63cad --- /dev/null +++ b/tensorflow/contrib/lite/c/BUILD @@ -0,0 +1,39 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "c_api_internal", + srcs = ["c_api_internal.c"], + hdrs = [ + "builtin_op_data.h", + "c_api_internal.h", + ], + visibility = [ + "//tensorflow/contrib/lite:__subpackages__", + ], +) + +# Test the C extension API code. +cc_test( + name = "c_api_internal_test", + size = "small", + srcs = ["c_api_internal_test.cc"], + deps = [ + ":c_api_internal", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "builtin_op_data_test", + size = "small", + srcs = ["builtin_op_data_test.cc"], + copts = ["-Wno-unused-variable"], + deps = [ + ":c_api_internal", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h new file mode 100644 index 0000000000..fa43e6a024 --- /dev/null +++ b/tensorflow/contrib/lite/c/builtin_op_data.h @@ -0,0 +1,298 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ + +#include <stdint.h> + +#include "tensorflow/contrib/lite/c/c_api_internal.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TODO(aselle): Consider using "if this then that" for testing. 
+ +// Possible padding types (for convolutions) +typedef enum { + kTfLitePaddingUnknown = 0, + kTfLitePaddingSame, + kTfLitePaddingValid, +} TfLitePadding; + +typedef struct { + int width; + int height; +} TfLitePaddingValues; + +// Possible fused activation functions. +// TODO(aselle): rename to TfLiteActivation +typedef enum { + kTfLiteActNone = 0, + kTfLiteActRelu, + kTfLiteActRelu1, + kTfLiteActRelu6, + kTfLiteActTanh, + kTfLiteActSignBit, + kTfLiteActSigmoid, +} TfLiteFusedActivation; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int dilation_width_factor; + int dilation_height_factor; + TfLiteFusedActivation activation; +} TfLiteConvParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int filter_width; + int filter_height; + TfLiteFusedActivation activation; + struct { + TfLitePaddingValues padding; + } computed; +} TfLitePoolParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int depth_multiplier; + TfLiteFusedActivation activation; +} TfLiteDepthwiseConvParams; + +typedef struct { + int rank; + TfLiteFusedActivation activation; +} TfLiteSVDFParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteRNNParams; + +typedef struct { + bool time_major; + TfLiteFusedActivation activation; +} TfLiteSequenceRNNParams; + +typedef enum { + kTfLiteFullyConnectedWeightsFormatDefault = 0, + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, +} TfLiteFullyConnectedWeightsFormat; + +typedef struct { + // Parameters for FullyConnected version 1 or above. + TfLiteFusedActivation activation; + + // Parameters for FullyConnected version 2 or above. + TfLiteFullyConnectedWeightsFormat weights_format; +} TfLiteFullyConnectedParams; + +typedef enum { + kTfLiteLshProjectionUnknown = 0, + kTfLiteLshProjectionSparse = 1, + kTfLiteLshProjectionDense = 2, +} TfLiteLSHProjectionType; + +typedef struct { + TfLiteLSHProjectionType type; +} TfLiteLSHProjectionParams; + +typedef struct { + float beta; +} TfLiteSoftmaxParams; + +typedef struct { + int axis; + TfLiteFusedActivation activation; +} TfLiteConcatenationParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteAddParams; + +typedef struct { +} TfLiteSpaceToBatchNDParams; + +typedef struct { +} TfLiteBatchToSpaceNDParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteMulParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteSubParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteDivParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteL2NormParams; + +typedef struct { + int radius; + float bias; + float alpha; + float beta; +} TfLiteLocalResponseNormParams; + +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + +typedef struct { + // Parameters for LSTM version 1. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; +} TfLiteLSTMParams; + +typedef struct { + bool align_corners; +} TfLiteResizeBilinearParams; + +typedef struct { +} TfLitePadParams; + +typedef struct { +} TfLitePadV2Params; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. 
+ int shape[8]; + int num_dimensions; +} TfLiteReshapeParams; + +typedef struct { + int ngram_size; + int max_skip_size; + bool include_all_ngrams; +} TfLiteSkipGramParams; + +typedef struct { + int block_size; +} TfLiteSpaceToDepthParams; + +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + +typedef enum { + kTfLiteCombinerTypeSum = 0, + kTfLiteCombinerTypeMean = 1, + kTfLiteCombinerTypeSqrtn = 2, +} TfLiteCombinerType; + +typedef struct { + TfLiteCombinerType combiner; +} TfLiteEmbeddingLookupSparseParams; + +typedef struct { + int axis; +} TfLiteGatherParams; + +typedef struct { +} TfLiteTransposeParams; + +typedef struct { + bool keep_dims; +} TfLiteReducerParams; + +typedef struct { + int num_splits; +} TfLiteSplitParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int squeeze_dims[8]; + int num_squeeze_dims; +} TfLiteSqueezeParams; + +typedef struct { + int begin_mask; + int end_mask; + int ellipsis_mask; + int new_axis_mask; + int shrink_axis_mask; +} TfLiteStridedSliceParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMaxParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMinParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; +} TfLiteTransposeConvParams; + +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + +typedef struct { + // Parameters supported by version 1: + float min; + float max; + int num_bits; + + // Parameters supported by version 2: + bool narrow_range; +} TfLiteFakeQuantParams; + +typedef struct { + int values_count; + int axis; +} TfLitePackParams; + +typedef struct { + int axis; +} TfLiteOneHotParams; + +typedef struct { + int num; + int axis; +} TfLiteUnpackParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc new file mode 100644 index 0000000000..4d0ba75e68 --- /dev/null +++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include <gtest/gtest.h> + +namespace tflite { + +// Builtin op data is just a set of data definitions, so the only meaningful +// test we can run is whether we can create the structs we expect to find. +// Testing each struct's members might be possible, but it seems unnecessary +// until we've locked down the API. The build rule has copts set to ignore the +// unused variable warning, since this is just a compilation test. 
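Most of these structs are self-explanatory plain data; the strided-slice masks are the least obvious, so a quick sketch before the compile test below. Each mask is assumed here to be a bitfield in which bit i refers to dimension i of the input, mirroring TensorFlow's strided_slice convention (the header itself does not spell this out):

    // Hypothetical mask values for slicing a 3-D input.
    TfLiteStridedSliceParams slice_params = {};
    slice_params.begin_mask = 0x6;        // ignore begin index for dims 1 and 2
    slice_params.end_mask = 0x6;          // ignore end index for dims 1 and 2
    slice_params.shrink_axis_mask = 0x1;  // collapse dim 0 in the output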
+TEST(IntArray, CanCompileStructs) { + TfLitePadding padding = kTfLitePaddingSame; + TfLitePaddingValues padding_values; + TfLiteFusedActivation fused_activation = kTfLiteActRelu; + TfLiteConvParams conv_params; + TfLitePoolParams pool_params; + TfLiteDepthwiseConvParams depthwise_conv_params; + TfLiteSVDFParams svdf_params; + TfLiteRNNParams rnn_params; + TfLiteSequenceRNNParams sequence_rnn_params; + TfLiteFullyConnectedWeightsFormat fully_connected_weights_format = + kTfLiteFullyConnectedWeightsFormatDefault; + TfLiteFullyConnectedParams fully_connected_params; + TfLiteLSHProjectionType projection_type = kTfLiteLshProjectionDense; + TfLiteLSHProjectionParams projection_params; + TfLiteSoftmaxParams softmax_params; + TfLiteConcatenationParams concatenation_params; + TfLiteAddParams add_params; + TfLiteSpaceToBatchNDParams space_to_batch_nd_params; + TfLiteBatchToSpaceNDParams batch_to_space_nd_params; + TfLiteMulParams mul_params; + TfLiteSubParams sub_params; + TfLiteDivParams div_params; + TfLiteL2NormParams l2_norm_params; + TfLiteLocalResponseNormParams local_response_norm_params; + TfLiteLSTMKernelType lstm_kernel_type = kTfLiteLSTMBasicKernel; + TfLiteLSTMParams lstm_params; + TfLiteResizeBilinearParams resize_bilinear_params; + TfLitePadParams pad_params; + TfLitePadV2Params pad_v2_params; + TfLiteReshapeParams reshape_params; + TfLiteSkipGramParams skip_gram_params; + TfLiteSpaceToDepthParams space_to_depth_params; + TfLiteCastParams cast_params; + TfLiteCombinerType combiner_type = kTfLiteCombinerTypeSqrtn; + TfLiteEmbeddingLookupSparseParams lookup_sparse_params; + TfLiteGatherParams gather_params; + TfLiteTransposeParams transpose_params; + TfLiteReducerParams reducer_params; + TfLiteSplitParams split_params; + TfLiteSqueezeParams squeeze_params; + TfLiteStridedSliceParams strided_slice_params; + TfLiteArgMaxParams arg_max_params; + TfLiteArgMinParams arg_min_params; + TfLiteTransposeConvParams transpose_conv_params; + TfLiteSparseToDenseParams sparse_to_dense_params; + TfLiteShapeParams shape_params; + TfLiteFakeQuantParams fake_quant_params; + TfLitePackParams pack_params; + TfLiteOneHotParams one_hot_params; +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/c/c_api_internal.c index 7f2aa316f4..1846bad4b7 100644 --- a/tensorflow/contrib/lite/context.c +++ b/tensorflow/contrib/lite/c/c_api_internal.c @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include <stdio.h> +#include <stdlib.h> #include <string.h> int TfLiteIntArrayGetSizeInBytes(int size) { @@ -76,7 +77,8 @@ void TfLiteTensorFree(TfLiteTensor* t) { void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, bool is_variable, TfLiteTensor* tensor) { + const void* allocation, bool is_variable, + TfLiteTensor* tensor) { TfLiteTensorFree(tensor); tensor->type = type; tensor->name = name; diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h new file mode 100644 index 0000000000..48df68a654 --- /dev/null +++ b/tensorflow/contrib/lite/c/c_api_internal.h @@ -0,0 +1,491 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file defines a C API for implementing operations in tflite. +// These operations can be defined using c++ but the interface between +// the interpreter and the operations are C. +// +// Summary of abstractions +// TF_LITE_ENSURE - Self-sufficient error checking +// TfLiteStatus - Status reporting +// TfLiteIntArray - stores tensor shapes (dims), +// TfLiteContext - allows an op to access the tensors +// TfLiteTensor - tensor (a multidimensional array) +// TfLiteNode - a single node or operation +// TfLiteRegistration - the implementation of a conceptual operation. +// +// Some abstractions in this file are created and managed by Interpreter. +#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ +#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ + +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; + +// The list of external context types known to TF Lite. This list exists solely +// to avoid conflicts and to ensure ops can share the external contexts they +// need. Access to the external contexts is controled by one of the +// corresponding support files. +typedef enum { + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. + kTfLiteMaxExternalContexts = 3 +} TfLiteExternalContextType; + +// An external context is a collection of information unrelated to the TF Lite +// framework, but useful to a subset of the ops. TF Lite knows very little +// about about the actual contexts, but it keeps a list of them, and is able to +// refresh them if configurations like the number of recommended threads +// change. 
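The struct carrying this per-context bookkeeping follows immediately below. As a usage sketch, a hypothetical kernel-side refresh of the gemmlowp context (the real access helpers live in gemm_support.h, as the enum comment above notes; GetExternalContext is declared later in this header):

    // Hypothetical helper; assumes the context's GetExternalContext
    // function pointer has been populated by the interpreter.
    TfLiteStatus RefreshGemmContext(TfLiteContext* context) {
      TfLiteExternalContext* ext =
          context->GetExternalContext(context, kTfLiteGemmLowpContext);
      if (ext != nullptr && ext->Refresh != nullptr) {
        TF_LITE_ENSURE_STATUS(ext->Refresh(context));
      }
      return kTfLiteOk;
    }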
+typedef struct { + TfLiteExternalContextType type; + TfLiteStatus (*Refresh)(struct TfLiteContext* context); +} TfLiteExternalContext; + +// Forward declare so GetNode can use this is in Context. +typedef struct _TfLiteRegistration TfLiteRegistration; +typedef struct _TfLiteDelegate TfLiteDelegate; + +#define kOptionalTensor (-1) + +// Fixed size list of integers. Used for dimensions and inputs/outputs tensor +// indices +typedef struct { + int size; +// gcc 6.1+ have a bug where flexible members aren't properly handled +// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c +#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ + __GNUC_MINOR__ >= 1 + int data[0]; +#else + int data[]; +#endif +} TfLiteIntArray; + +// Given the size (number of elements) in a TfLiteIntArray, calculate its size +// in bytes. +int TfLiteIntArrayGetSizeInBytes(int size); + +// Create a array of a given `size` (uninitialized entries). +// This returns a pointer, that you must free using TfLiteIntArrayFree(). +TfLiteIntArray* TfLiteIntArrayCreate(int size); + +// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise. +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b); + +// Create a copy of an array passed as `src`. +// You are expected to free memory with TfLiteIntArrayFree +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src); + +// Free memory of array `v`. +void TfLiteIntArrayFree(TfLiteIntArray* v); + +// Since we must not depend on any libraries, define a minimal subset of +// error macros while avoiding names that have pre-conceived meanings like +// assert and check. + +// Check whether value is true, and if not return kTfLiteError from +// the current function (and report the error string msg). +#define TF_LITE_ENSURE_MSG(context, value, msg) \ + do { \ + if (!(value)) { \ + (context)->ReportError((context), __FILE__ " " msg); \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +#define TF_LITE_ENSURE(context, a) \ + do { \ + if (!(a)) { \ + (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \ + __LINE__, #a); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_STATUS(a) \ + do { \ + if ((a) != kTfLiteOk) { \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a == b` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +// `a` and `b` may be evaluated more than once, so no side effects or +// extremely expensive computations should be done. +#define TF_LITE_ENSURE_EQ(context, a, b) \ + do { \ + if ((a) != (b)) { \ + (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ + __LINE__, #a, #b, (a), (b)); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_OK(context, status) \ + do { \ + if ((status) != kTfLiteOk) { \ + return status; \ + } \ + } while (0) + +// Single-precision complex data type compatible with the C99 definition. +typedef struct { + float re, im; // real and imaginary parts, respectively. +} TfLiteComplex64; + +// Types supported by tensor +typedef enum { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, + kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, +} TfLiteType; + +// Parameters for asymmetric quantization. 
Quantized values can be converted +// back to float using: +// real_value = scale * (quantized_value - zero_point); +typedef struct { + float scale; + int32_t zero_point; +} TfLiteQuantizationParams; + +// A union of pointers that points to memory for a given tensor. +typedef union { + int* i32; + int64_t* i64; + float* f; + char* raw; + const char* raw_const; + uint8_t* uint8; + bool* b; + int16_t* i16; + TfLiteComplex64* c64; +} TfLitePtrUnion; + +// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped +// data (or data externally allocated). kTfLiteArenaRw is arena allocated +// data. kTfLiteDynamic is for tensors that are allocated during evaluation. +typedef enum { + kTfLiteMemNone = 0, + kTfLiteMmapRo, + kTfLiteArenaRw, + kTfLiteArenaRwPersistent, + kTfLiteDynamic, +} TfLiteAllocationType; + +// The delegates should use zero or positive integers to represent handles. +// -1 is reserved from unallocated status. +typedef int TfLiteBufferHandle; +const TfLiteBufferHandle kTfLiteNullBufferHandle = -1; + +// An tensor in the interpreter system which is a wrapper around a buffer of +// data including a dimensionality (or NULL if not currently defined). +typedef struct { + // The data type specification for data stored in `data`. This affects + // what member of `data` union should be used. + TfLiteType type; + // A union of data pointers. The appropriate type should be used for a typed + // tensor based on `type`. + TfLitePtrUnion data; + // A pointer to a structure representing the dimensionality interpretation + // that the buffer should have. NOTE: the product of elements of `dims` + // and the element datatype size should be equal to `bytes` below. + TfLiteIntArray* dims; + // Quantization information. + TfLiteQuantizationParams params; + // How memory is mapped + // kTfLiteMmapRo: Memory mapped read only. + // i.e. weights + // kTfLiteArenaRw: Arena allocated read write memory + // (i.e. temporaries, outputs). + TfLiteAllocationType allocation_type; + // The number of bytes required to store the data of this Tensor. I.e. + // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if + // type is kTfLiteFloat32 and dims = {3, 2} then + // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. + size_t bytes; + + // An opaque pointer to a tflite::MMapAllocation + const void* allocation; + + // Null-terminated name of this tensor. + const char* name; + + // The delegate which knows how to handle `buffer_handle`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; + + // An integer buffer handle that can be handled by `delegate`. + // The value is valid only when delegate is not null. + // WARNING: This is an experimental interface that is subject to change. + TfLiteBufferHandle buffer_handle; + + // If the delegate uses its own buffer (e.g. GPU memory), the delegate is + // responsible to set data_is_stale to true. + // `delegate->CopyFromBufferHandle` can be called to copy the data from + // delegate buffer. + // WARNING: This is an // experimental interface that is subject to change. + bool data_is_stale; + + // True if the tensor is a variable. + bool is_variable; +} TfLiteTensor; + +// Free data memory of tensor `t`; +void TfLiteTensorDataFree(TfLiteTensor* t); + +// Free memory of tensor `t`; +void TfLiteTensorFree(TfLiteTensor* t); + +// Set all of a tensor's fields (and free any previously allocated data). 
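The declaration of TfLiteTensorReset follows next. Before it, a sketch of how the quantization parameters and the data union documented above combine in practice, assuming a kTfLiteUInt8 tensor (hypothetical helper, not part of the API):

    // real_value = scale * (quantized_value - zero_point), as documented above.
    float DequantizeElement(const TfLiteTensor& t, int i) {
      return t.params.scale *
             (static_cast<int32_t>(t.data.uint8[i]) - t.params.zero_point);
    }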
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, bool is_variable, + TfLiteTensor* tensor); + +// Resize the allocated data of a (dynamic) tensor. Tensors with allocation +// types other than kTfLiteDynamic will be ignored. +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); + +// A structure representing an instance of a node. +// This structure only exhibits the inputs, outputs and user defined data, not +// other features like the type. +typedef struct { + // Inputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* inputs; + + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* outputs; + + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray* temporaries; + + // Opaque data provided by the node implementer through `Registration.init`. + void* user_data; + + // Opaque data provided to the node if the node is a builtin. This is usually + // a structure defined in builtin_op_data.h + void* builtin_data; + + // Custom initial data. This is the opaque data provided in the flatbuffer. + // WARNING: This is an experimental interface that is subject to change. + const void* custom_initial_data; + int custom_initial_data_size; + + // The pointer to the delegate. This is non-null only when the node is + // created by calling `interpreter.ModifyGraphWithDelegate`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; +} TfLiteNode; + +typedef struct TfLiteContext { + // Number of tensors in the context. + size_t tensors_size; + + // The execution plan contains a list of the node indices in execution + // order. execution_plan->size is the current number of nodes. And, + // execution_plan->data[0] is the first node that needs to be run. + // TfLiteDelegates can traverse the current execution plan by iterating + // through each member of this array and using GetNodeAndRegistration() to + // access details about a node. i.e. + // TfLiteIntArray* execution_plan; + // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); + // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { + // int node_index = execution_plan->data[exec_index]; + // TfLiteNode* node; + // TfLiteRegistration* reg; + // context->GetNodeAndRegistration(context, node_index, &node, ®); + // } + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, + TfLiteIntArray** execution_plan); + + // An array of tensors in the interpreter context (of length `tensors_size`) + TfLiteTensor* tensors; + + // opaque full context ptr (an opaque c++ data structure) + void* impl_; + + // Request memory pointer be resized. Updates dimensions on the tensor. + // NOTE: ResizeTensor takes ownership of newSize. + TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Request that a error be reported with format string msg. + void (*ReportError)(struct TfLiteContext*, const char* msg, ...); + + // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. 
If + // non-null, the value pointed to by `first_new_tensor_index` will be set to + // the index of the first new tensor. + TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, + int* first_new_tensor_index); + + // Get a Tensor node by node_index. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index, + TfLiteNode** node, + TfLiteRegistration** registration); + + // Replace ops with one or more stub delegate operations. This function + // does not take ownership of `nodes_to_replace`. + TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( + struct TfLiteContext*, TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate); + + // Number of threads that are recommended to subsystems like gemmlowp and + // eigen. + int recommended_num_threads; + + // Access external contexts by type. + // WARNING: This is an experimental interface that is subject to change. + TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, + TfLiteExternalContextType); + // Set the value of a external context. Does not take ownership of the + // pointer. + // WARNING: This is an experimental interface that is subject to change. + void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, + TfLiteExternalContext*); +} TfLiteContext; + +typedef struct _TfLiteRegistration { + // Initializes the op from serialized data. + // If a built-in op: + // `buffer` is the op's params data (TfLiteLSTMParams*). + // `length` is zero. + // If custom op: + // `buffer` is the op's `custom_options`. + // `length` is the size of the buffer. + // + // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer + // or an instance of a struct). + // + // The returned pointer will be stored with the node in the `user_data` field, + // accessible within prepare and invoke functions below. + // NOTE: if the data is already in the desired format, simply implement this + // function to return `nullptr` and implement the free function to be a no-op. + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + + // The pointer `buffer` is the data previously returned by an init invocation. + void (*free)(TfLiteContext* context, void* buffer); + + // prepare is called when the inputs this node depends on have been resized. + // context->ResizeTensor() can be called to request output tensors to be + // resized. + // + // Returns kTfLiteOk on success. + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + + // Execute the node (should read node->inputs and output to node->outputs). + // Returns kTfLiteOk on success. + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + + // profiling_string is called during summarization of profiling information + // in order to group executions together. Providing a value here will cause a + // given op to appear multiple times is the profiling report. This is + // particularly useful for custom ops that can perform significantly + // different calculations depending on their `user-data`. + const char* (*profiling_string)(const TfLiteContext* context, + const TfLiteNode* node); + + // Builtin codes. If this kernel refers to a builtin this is the code + // of the builtin. This is so we can do marshaling to other frameworks like + // NN API. + // Note: It is the responsibility of the registration binder to set this + // properly. 
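Before the builtin_code field that the comment above introduces, a minimal sketch of wiring the four callbacks just described into a registration, using hypothetical no-op functions:

    // Hypothetical stateless custom op.
    void* NoopInit(TfLiteContext* context, const char* buffer, size_t length) {
      return nullptr;  // No per-node data, so the matching free is a no-op.
    }
    void NoopFree(TfLiteContext* context, void* buffer) {}
    TfLiteStatus NoopPrepare(TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteOk;
    }
    TfLiteStatus NoopEval(TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteOk;
    }
    TfLiteRegistration NoopRegistration() {
      // The remaining members (profiling_string, builtin_code, custom_name,
      // version) are zero-initialized here; the registration binder fills
      // them in, as the comments above and below note.
      TfLiteRegistration r = {NoopInit, NoopFree, NoopPrepare, NoopEval};
      return r;
    }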
+ int32_t builtin_code; + + // Custom op name. If the op is a builtin, this will be null. + // Note: It is the responsibility of the registration binder to set this + // properly. + // WARNING: This is an experimental interface that is subject to change. + const char* custom_name; + + // The version of the op. + // Note: It is the responsibility of the registration binder to set this + // properly. + int version; +} TfLiteRegistration; + +// WARNING: This is an experimental interface that is subject to change. +typedef struct _TfLiteDelegate { + // Data that delegate needs to identify itself. This data is owned by the + // delegate. The delegate is owned in the user code, so the delegate is + // responsible for doing this when it is destroyed. + void* data_; + + // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the + // delegate a view of the current graph through TfLiteContext*. It typically + // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() + // to ask the TensorFlow lite runtime to create macro-nodes to represent + // delegated subgraphs of the original graph. + TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate); + + // Copy the data from delegate buffer handle to raw memory. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size); + + // Copy the data from raw memory to delegate buffer handle. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size); + + // Free the Delegate Buffer Handle. Note: This only frees the handle, but + // this doesn't release the underlying resource (e.g. textures). The + // resources are either owned by application layer or the delegate. + // This can be null if the delegate doesn't use its own buffer. + void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle* handle); +} TfLiteDelegate; + +// WARNING: This is an experimental interface that is subject to change. +// +// Currently, TfLiteDelegateParams has to be allocated in a way that it's +// trivially destructable. It will be stored as `builtin_data` field in +// `TfLiteNode` of the delegate node. +// +// See also the `CreateDelegateParams` function in `interpreter.cc` details. +typedef struct { + TfLiteDelegate* delegate; + TfLiteIntArray* nodes_to_replace; + TfLiteIntArray* input_tensors; + TfLiteIntArray* output_tensors; +} TfLiteDelegateParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/c/c_api_internal_test.cc index 20d6f69a25..af398f3207 100644 --- a/tensorflow/contrib/lite/context_test.cc +++ b/tensorflow/contrib/lite/c/c_api_internal_test.cc @@ -13,16 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/testing/util.h" namespace tflite { // NOTE: this tests only the TfLiteIntArray part of context. 
-// most of context.h is provided in the context of using it with interpreter.h -// and interpreter.cc, so interpreter_test.cc tests context structures more -// thoroughly. +// most of c_api_internal.h is provided in the context of using it with +// interpreter.h and interpreter.cc, so interpreter_test.cc tests context +// structures more thoroughly. TEST(IntArray, TestIntArrayCreate) { TfLiteIntArray* a = TfLiteIntArrayCreate(0); @@ -69,7 +68,6 @@ TEST(IntArray, TestIntArrayEqual) { } // namespace tflite int main(int argc, char** argv) { - ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index b23183b743..b86c2819b8 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -12,484 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines a C API for implementing operations in tflite. -// These operations can be defined using c++ but the interface between -// the interpreter and the operations are C. -// -// Summary of abstractions -// TF_LITE_ENSURE - Self-sufficient error checking -// TfLiteStatus - Status reporting -// TfLiteIntArray - stores tensor shapes (dims), -// TfLiteContext - allows an op to access the tensors -// TfLiteTensor - tensor (a multidimensional array) -// TfLiteNode - a single node or operation -// TfLiteRegistration - the implementation of a conceptual operation. -// -// Some abstractions in this file are created and managed by Interpreter. +// Compatibility shim for moved header location. #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ -#include <stdbool.h> -#include <stdint.h> -#include <stdlib.h> +#include "tensorflow/contrib/lite/c/c_api_internal.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; - -// Forward declarations for use with dependent types. -struct TfLiteContext; -struct TfLiteNode; -struct _TfLiteRegistration; -struct _TfLiteDelegate; - -// The list of external context types known to TF Lite. This list exists solely -// to avoid conflicts and to ensure ops can share the external contexts they -// need. Access to the external contexts is controled by one of the -// corresponding support files. -typedef enum { - kTfLiteEigenContext = 0, // include eigen_support.h to use. - kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. - kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. - kTfLiteMaxExternalContexts = 3 -} TfLiteExternalContextType; - -// An external context is a collection of information unrelated to the TF Lite -// framework, but useful to a subset of the ops. TF Lite knows very little -// about about the actual contexts, but it keeps a list of them, and is able to -// refresh them if configurations like the number of recommended threads -// change. -typedef struct { - TfLiteExternalContextType type; - TfLiteStatus (*Refresh)(struct TfLiteContext* context); -} TfLiteExternalContext; - -#define kOptionalTensor (-1) - -// Fixed size list of integers. 
Used for dimensions and inputs/outputs tensor -// indices -typedef struct { - int size; -// gcc 6.1+ have a bug where flexible members aren't properly handled -// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c -#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ - __GNUC_MINOR__ >= 1 - int data[0]; -#else - int data[]; -#endif -} TfLiteIntArray; - -// Given the size (number of elements) in a TfLiteIntArray, calculate its size -// in bytes. -int TfLiteIntArrayGetSizeInBytes(int size); - -// Create a array of a given `size` (uninitialized entries). -// This returns a pointer, that you must free using TfLiteIntArrayFree(). -TfLiteIntArray* TfLiteIntArrayCreate(int size); - -// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise. -int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b); - -// Create a copy of an array passed as `src`. -// You are expected to free memory with TfLiteIntArrayFree -TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src); - -// Free memory of array `v`. -void TfLiteIntArrayFree(TfLiteIntArray* v); - -// Since we must not depend on any libraries, define a minimal subset of -// error macros while avoiding names that have pre-conceived meanings like -// assert and check. - -// Check whether value is true, and if not return kTfLiteError from -// the current function (and report the error string msg). -#define TF_LITE_ENSURE_MSG(context, value, msg) \ - do { \ - if (!(value)) { \ - (context)->ReportError((context), __FILE__ " " msg); \ - return kTfLiteError; \ - } \ - } while (0) - -// Check whether the value `a` is true, and if not return kTfLiteError from -// the current function, while also reporting the location of the error. -#define TF_LITE_ENSURE(context, a) \ - do { \ - if (!(a)) { \ - (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \ - __LINE__, #a); \ - return kTfLiteError; \ - } \ - } while (0) - -#define TF_LITE_ENSURE_STATUS(a) \ - do { \ - if ((a) != kTfLiteOk) { \ - return kTfLiteError; \ - } \ - } while (0) - -// Check whether the value `a == b` is true, and if not return kTfLiteError from -// the current function, while also reporting the location of the error. -// `a` and `b` may be evaluated more than once, so no side effects or -// extremely expensive computations should be done. -#define TF_LITE_ENSURE_EQ(context, a, b) \ - do { \ - if ((a) != (b)) { \ - (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ - __LINE__, #a, #b, (a), (b)); \ - return kTfLiteError; \ - } \ - } while (0) - -#define TF_LITE_ENSURE_OK(context, status) \ - do { \ - if ((status) != kTfLiteOk) { \ - return status; \ - } \ - } while (0) - -// Single-precision complex data type compatible with the C99 definition. -typedef struct { - float re, im; // real and imaginary parts, respectively. -} TfLiteComplex64; - -// Types supported by tensor -typedef enum { - kTfLiteNoType = 0, - kTfLiteFloat32 = 1, - kTfLiteInt32 = 2, - kTfLiteUInt8 = 3, - kTfLiteInt64 = 4, - kTfLiteString = 5, - kTfLiteBool = 6, - kTfLiteInt16 = 7, - kTfLiteComplex64 = 8, -} TfLiteType; - -// Parameters for asymmetric quantization. Quantized values can be converted -// back to float using: -// real_value = scale * (quantized_value - zero_point); -typedef struct { - float scale; - int32_t zero_point; -} TfLiteQuantizationParams; - -// A union of pointers that points to memory for a given tensor. 
-typedef union { - int* i32; - int64_t* i64; - float* f; - char* raw; - const char* raw_const; - uint8_t* uint8; - bool* b; - int16_t* i16; - TfLiteComplex64* c64; -} TfLitePtrUnion; - -// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped -// data (or data externally allocated). kTfLiteArenaRw is arena allocated -// data. kTfLiteDynamic is for tensors that are allocated during evaluation. -typedef enum { - kTfLiteMemNone = 0, - kTfLiteMmapRo, - kTfLiteArenaRw, - kTfLiteArenaRwPersistent, - kTfLiteDynamic, -} TfLiteAllocationType; - -// The delegates should use zero or positive integers to represent handles. -// -1 is reserved from unallocated status. -typedef int TfLiteBufferHandle; -const TfLiteBufferHandle kTfLiteNullBufferHandle = -1; - -// An tensor in the interpreter system which is a wrapper around a buffer of -// data including a dimensionality (or NULL if not currently defined). -typedef struct { - // The data type specification for data stored in `data`. This affects - // what member of `data` union should be used. - TfLiteType type; - // A union of data pointers. The appropriate type should be used for a typed - // tensor based on `type`. - TfLitePtrUnion data; - // A pointer to a structure representing the dimensionality interpretation - // that the buffer should have. NOTE: the product of elements of `dims` - // and the element datatype size should be equal to `bytes` below. - TfLiteIntArray* dims; - // Quantization information. - TfLiteQuantizationParams params; - // How memory is mapped - // kTfLiteMmapRo: Memory mapped read only. - // i.e. weights - // kTfLiteArenaRw: Arena allocated read write memory - // (i.e. temporaries, outputs). - TfLiteAllocationType allocation_type; - // The number of bytes required to store the data of this Tensor. I.e. - // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if - // type is kTfLiteFloat32 and dims = {3, 2} then - // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. - size_t bytes; - - // An opaque pointer to a tflite::MMapAllocation - const void* allocation; - - // Null-terminated name of this tensor. - const char* name; - - // The delegate which knows how to handle `buffer_handle`. - // WARNING: This is an experimental interface that is subject to change. - struct _TfLiteDelegate* delegate; - - // An integer buffer handle that can be handled by `delegate`. - // The value is valid only when delegate is not null. - // WARNING: This is an experimental interface that is subject to change. - TfLiteBufferHandle buffer_handle; - - // If the delegate uses its own buffer (e.g. GPU memory), the delegate is - // responsible to set data_is_stale to true. - // `delegate->CopyFromBufferHandle` can be called to copy the data from - // delegate buffer. - // WARNING: This is an // experimental interface that is subject to change. - bool data_is_stale; - - // True if the tensor is a variable. - bool is_variable; -} TfLiteTensor; - -// Free data memory of tensor `t`; -void TfLiteTensorDataFree(TfLiteTensor* t); - -// Free memory of tensor `t`; -void TfLiteTensorFree(TfLiteTensor* t); - -// Set all of a tensor's fields (and free any previously allocated data). -void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, - TfLiteQuantizationParams quantization, char* buffer, - size_t size, TfLiteAllocationType allocation_type, - const void* allocation, bool is_variable, - TfLiteTensor* tensor); - -// Resize the allocated data of a (dynamic) tensor. 
Tensors with allocation -// types other than kTfLiteDynamic will be ignored. -void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); - -// A structure representing an instance of a node. -// This structure only exhibits the inputs, outputs and user defined data, not -// other features like the type. -typedef struct TfLiteNode { - // Inputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* inputs; - - // Outputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* outputs; - - // Temporary tensors uses during the computations. This usually contains no - // tensors, but ops are allowed to change that if they need scratch space of - // any sort. - TfLiteIntArray* temporaries; - - // Opaque data provided by the node implementer through `Registration.init`. - void* user_data; - - // Opaque data provided to the node if the node is a builtin. This is usually - // a structure defined in builtin_op_data.h - void* builtin_data; - - // Custom initial data. This is the opaque data provided in the flatbuffer. - // WARNING: This is an experimental interface that is subject to change. - const void* custom_initial_data; - int custom_initial_data_size; - - // The pointer to the delegate. This is non-null only when the node is - // created by calling `interpreter.ModifyGraphWithDelegate`. - // WARNING: This is an experimental interface that is subject to change. - struct _TfLiteDelegate* delegate; -} TfLiteNode; - -typedef struct TfLiteContext { - // Number of tensors in the context. - size_t tensors_size; - - // The execution plan contains a list of the node indices in execution - // order. execution_plan->size is the current number of nodes. And, - // execution_plan->data[0] is the first node that needs to be run. - // TfLiteDelegates can traverse the current execution plan by iterating - // through each member of this array and using GetNodeAndRegistration() to - // access details about a node. i.e. - // TfLiteIntArray* execution_plan; - // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); - // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { - // int node_index = execution_plan->data[exec_index]; - // TfLiteNode* node; - // TfLiteRegistration* reg; - // context->GetNodeAndRegistration(context, node_index, &node, ®); - // } - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, - TfLiteIntArray** execution_plan); - - // An array of tensors in the interpreter context (of length `tensors_size`) - TfLiteTensor* tensors; - - // opaque full context ptr (an opaque c++ data structure) - void* impl_; - - // Request memory pointer be resized. Updates dimensions on the tensor. - // NOTE: ResizeTensor takes ownership of newSize. - TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor, - TfLiteIntArray* new_size); - // Request that a error be reported with format string msg. - void (*ReportError)(struct TfLiteContext*, const char* msg, ...); - - // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If - // non-null, the value pointed to by `first_new_tensor_index` will be set to - // the index of the first new tensor. - TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, - int* first_new_tensor_index); - - // Get a Tensor node by node_index. - // WARNING: This is an experimental interface that is subject to change. 
- TfLiteStatus (*GetNodeAndRegistration)( - struct TfLiteContext*, int node_index, struct TfLiteNode** node, - struct _TfLiteRegistration** registration); - - // Replace ops with one or more stub delegate operations. This function - // does not take ownership of `nodes_to_replace`. - TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( - struct TfLiteContext*, struct _TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace, struct _TfLiteDelegate* delegate); - - // Number of threads that are recommended to subsystems like gemmlowp and - // eigen. - int recommended_num_threads; - - // Access external contexts by type. - // WARNING: This is an experimental interface that is subject to change. - TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, - TfLiteExternalContextType); - // Set the value of a external context. Does not take ownership of the - // pointer. - // WARNING: This is an experimental interface that is subject to change. - void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, - TfLiteExternalContext*); -} TfLiteContext; - -typedef struct _TfLiteRegistration { - // Initializes the op from serialized data. - // If a built-in op: - // `buffer` is the op's params data (TfLiteLSTMParams*). - // `length` is zero. - // If custom op: - // `buffer` is the op's `custom_options`. - // `length` is the size of the buffer. - // - // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer - // or an instance of a struct). - // - // The returned pointer will be stored with the node in the `user_data` field, - // accessible within prepare and invoke functions below. - // NOTE: if the data is already in the desired format, simply implement this - // function to return `nullptr` and implement the free function to be a no-op. - void* (*init)(TfLiteContext* context, const char* buffer, size_t length); - - // The pointer `buffer` is the data previously returned by an init invocation. - void (*free)(TfLiteContext* context, void* buffer); - - // prepare is called when the inputs this node depends on have been resized. - // context->ResizeTensor() can be called to request output tensors to be - // resized. - // - // Returns kTfLiteOk on success. - TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); - - // Execute the node (should read node->inputs and output to node->outputs). - // Returns kTfLiteOk on success. - TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); - - // profiling_string is called during summarization of profiling information - // in order to group executions together. Providing a value here will cause a - // given op to appear multiple times is the profiling report. This is - // particularly useful for custom ops that can perform significantly - // different calculations depending on their `user-data`. - const char* (*profiling_string)(const TfLiteContext* context, - const TfLiteNode* node); - - // Builtin codes. If this kernel refers to a builtin this is the code - // of the builtin. This is so we can do marshaling to other frameworks like - // NN API. - // Note: It is the responsibility of the registration binder to set this - // properly. - int32_t builtin_code; - - // Custom op name. If the op is a builtin, this will be null. - // Note: It is the responsibility of the registration binder to set this - // properly. - // WARNING: This is an experimental interface that is subject to change. - const char* custom_name; - - // The version of the op. 
- // Note: It is the responsibility of the registration binder to set this - // properly. - int version; -} TfLiteRegistration; - -// WARNING: This is an experimental interface that is subject to change. -typedef struct _TfLiteDelegate { - // Data that delegate needs to identify itself. This data is owned by the - // delegate. The delegate is owned in the user code, so the delegate is - // responsible for doing this when it is destroyed. - void* data_; - - // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the - // delegate a view of the current graph through TfLiteContext*. It typically - // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() - // to ask the TensorFlow lite runtime to create macro-nodes to represent - // delegated subgraphs of the original graph. - TfLiteStatus (*Prepare)(struct TfLiteContext* context, - struct _TfLiteDelegate* delegate); - - // Copy the data from delegate buffer handle to raw memory. - // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyFromBufferHandle)(struct TfLiteContext* context, - struct _TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - void* data, size_t size); - - // Copy the data from raw memory to delegate buffer handle. - // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyToBufferHandle)(struct TfLiteContext* context, - struct _TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - void* data, size_t size); - - // Free the Delegate Buffer Handle. Note: This only frees the handle, but - // this doesn't release the underlying resource (e.g. textures). The - // resources are either owned by application layer or the delegate. - // This can be null if the delegate doesn't use its own buffer. - void (*FreeBufferHandle)(struct TfLiteContext* context, - struct _TfLiteDelegate* delegate, - TfLiteBufferHandle* handle); -} TfLiteDelegate; - -// WARNING: This is an experimental interface that is subject to change. -// -// Currently, TfLiteDelegateParams has to be allocated in a way that it's -// trivially destructable. It will be stored as `builtin_data` field in -// `TfLiteNode` of the delegate node. -// -// See also the `CreateDelegateParams` function in `interpreter.cc` details. -typedef struct { - TfLiteDelegate* delegate; - TfLiteIntArray* nodes_to_replace; - TfLiteIntArray* input_tensors; - TfLiteIntArray* output_tensors; -} TfLiteDelegateParams; - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h index abe802e342..ccda4c7393 100644 --- a/tensorflow/contrib/lite/context_util.h +++ b/tensorflow/contrib/lite/context_util.h @@ -17,7 +17,7 @@ limitations under the License. 
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/core/api/BUILD b/tensorflow/contrib/lite/core/api/BUILD new file mode 100644 index 0000000000..e4500534f3 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/BUILD @@ -0,0 +1,57 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +cc_library( + name = "api", + srcs = [ + "error_reporter.cc", + "flatbuffer_conversions.cc", + "op_resolver.cc", + ], + hdrs = [ + "error_reporter.h", + "flatbuffer_conversions.h", + "op_resolver.h", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "error_reporter_test", + size = "small", + srcs = ["error_reporter_test.cc"], + deps = [ + ":api", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "op_resolver_test", + size = "small", + srcs = ["op_resolver_test.cc"], + deps = [ + ":api", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "flatbuffer_conversions_test", + size = "small", + srcs = ["flatbuffer_conversions_test.cc"], + deps = [ + ":api", + "//tensorflow/contrib/lite/c:c_api_internal", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/core/api/error_reporter.cc b/tensorflow/contrib/lite/core/api/error_reporter.cc new file mode 100644 index 0000000000..423f83b1a9 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/error_reporter.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include <cstdarg> + +namespace tflite { + +int ErrorReporter::Report(const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +// TODO(aselle): Make the name of ReportError on context the same, so +// we can use the ensure functions w/o a context and w/ a reporter. +int ErrorReporter::ReportError(void*, const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/core/api/error_reporter.h b/tensorflow/contrib/lite/core/api/error_reporter.h new file mode 100644 index 0000000000..a2f780b003 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/error_reporter.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_ +#define TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_ + +#include <cstdarg> + +namespace tflite { + +// A functor that reports error to supporting system. Invoked similar to +// printf. +// +// Usage: +// ErrorReporter foo; +// foo.Report("test %d", 5); +// or +// va_list args; +// foo.Report("test %d", args); // where args is va_list +// +// Subclass ErrorReporter to provide another reporting destination. +// For example, if you have a GUI program, you might redirect to a buffer +// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter() {} + virtual int Report(const char* format, va_list args) = 0; + int Report(const char* format, ...); + int ReportError(void*, const char* format, ...); +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/core/api/error_reporter_test.cc b/tensorflow/contrib/lite/core/api/error_reporter_test.cc new file mode 100644 index 0000000000..0463eee6be --- /dev/null +++ b/tensorflow/contrib/lite/core/api/error_reporter_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/error_reporter.h" + +#include <cstdio> + +#include <gtest/gtest.h> + +namespace tflite { + +class MockErrorReporter : public ErrorReporter { + public: + int Report(const char* format, va_list args) override { + vsnprintf(buffer_, kBufferSize, format, args); + return 0; + } + char* GetBuffer() { return buffer_; } + + private: + static constexpr int kBufferSize = 256; + char buffer_[kBufferSize]; +}; + +TEST(ErrorReporter, TestReport) { + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + reporter->Report("Error: %d", 23); + EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23")); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc new file mode 100644 index 0000000000..1420fbcdc6 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -0,0 +1,622 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" + +#include <cstdlib> + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" + +namespace tflite { + +namespace { + +// Copies the contents from the flatbuffer int vector `flatbuffer` into the +// int array `buffer`. `flat_vector` and `buffer` represent the same +// configuration operation for a given operation. +void FlatBufferIntVectorToArray(int max_size_of_buffer, + const flatbuffers::Vector<int32_t>* flat_vector, + int* buffer, ErrorReporter* error_reporter) { + if (!flat_vector) { + error_reporter->Report("Input array not provided for operation.\n"); + } else { + int num_dimensions = flat_vector->Length(); + if (num_dimensions > max_size_of_buffer / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in the operation's input array.\n"); + } else { + for (int i = 0; i < num_dimensions; ++i) { + buffer[i] = flat_vector->Get(i); + } + } + } +} + +// Allocate a structure using malloc, but make sure the structure is a POD +// structure that doesn't require constructors to run. The reason we do this, +// is that Interpreter's C extension part will take ownership so destructors +// will not be run during deallocation. +template <class T> +T* MallocPOD() { + static_assert(std::is_pod<T>::value, "Builtin data structure must be POD."); + return static_cast<T*>(malloc(sizeof(T))); +} + +} // namespace + +TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, + ErrorReporter* error_reporter) { + switch (tensor_type) { + case TensorType_FLOAT32: + *type = kTfLiteFloat32; + break; + case TensorType_INT16: + *type = kTfLiteInt16; + break; + case TensorType_INT32: + *type = kTfLiteInt32; + break; + case TensorType_UINT8: + *type = kTfLiteUInt8; + break; + case TensorType_INT64: + *type = kTfLiteInt64; + break; + case TensorType_STRING: + *type = kTfLiteString; + break; + case TensorType_BOOL: + *type = kTfLiteBool; + break; + case TensorType_COMPLEX64: + *type = kTfLiteComplex64; + break; + default: + error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", + EnumNameTensorType(tensor_type), tensor_type); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which +// need to be released by calling `free`.` +// If it returns kTfLiteError, `builtin_data` will be `nullptr`. 
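A sketch of the calling convention just described, with `op` and `error_reporter` assumed to be in scope (hypothetical caller code):

    void* builtin_data = nullptr;
    if (ParseOpData(op, BuiltinOperator_CONV_2D, error_reporter,
                    &builtin_data) == kTfLiteOk) {
      // Ownership passes to the caller; release with free() when done.
      free(builtin_data);
    }
    // On kTfLiteError, builtin_data is guaranteed to stay nullptr.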
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+                         ErrorReporter* error_reporter, void** builtin_data) {
+  auto parse_padding = [](Padding padding) {
+    switch (padding) {
+      case Padding_SAME:
+        return kTfLitePaddingSame;
+      case Padding_VALID:
+        return kTfLitePaddingValid;
+    }
+    return kTfLitePaddingUnknown;
+  };
+  auto parse_activation = [](ActivationFunctionType activation) {
+    switch (activation) {
+      case ActivationFunctionType_NONE:
+        return kTfLiteActNone;
+      case ActivationFunctionType_RELU:
+        return kTfLiteActRelu;
+      case ActivationFunctionType_RELU_N1_TO_1:
+        return kTfLiteActRelu1;
+      case ActivationFunctionType_RELU6:
+        return kTfLiteActRelu6;
+      case ActivationFunctionType_TANH:
+        return kTfLiteActTanh;
+      case ActivationFunctionType_SIGN_BIT:
+        return kTfLiteActSignBit;
+    }
+    return kTfLiteActNone;
+  };
+  auto parseLSHProjectionType = [](LSHProjectionType type) {
+    switch (type) {
+      case LSHProjectionType_SPARSE:
+        return kTfLiteLshProjectionSparse;
+      case LSHProjectionType_DENSE:
+        return kTfLiteLshProjectionDense;
+      default:
+        return kTfLiteLshProjectionUnknown;
+    }
+  };
+  auto parseCombinerType = [](CombinerType type) {
+    switch (type) {
+      case CombinerType_MEAN:
+        return kTfLiteCombinerTypeMean;
+      case CombinerType_SQRTN:
+        return kTfLiteCombinerTypeSqrtn;
+      case CombinerType_SUM:
+      default:
+        return kTfLiteCombinerTypeSum;
+    }
+  };
+
+  *builtin_data = nullptr;
+  switch (op_type) {
+    case BuiltinOperator_CONV_2D: {
+      TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+      if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+        params->padding = parse_padding(conv_params->padding());
+        params->stride_width = conv_params->stride_w();
+        params->stride_height = conv_params->stride_h();
+        params->activation =
+            parse_activation(conv_params->fused_activation_function());
+
+        params->dilation_width_factor = conv_params->dilation_w_factor();
+        params->dilation_height_factor = conv_params->dilation_h_factor();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_CAST: {
+      TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+      if (auto* schema_params = op->builtin_options_as_CastOptions()) {
+        auto in_status =
+            ConvertTensorType(schema_params->in_data_type(),
+                              &params->in_data_type, error_reporter);
+        auto out_status =
+            ConvertTensorType(schema_params->out_data_type(),
+                              &params->out_data_type, error_reporter);
+        if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
+          free(params);
+          return kTfLiteError;
+        }
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_LSH_PROJECTION: {
+      TfLiteLSHProjectionParams* params =
+          MallocPOD<TfLiteLSHProjectionParams>();
+      if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
+        params->type = parseLSHProjectionType(lshParams->type());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_AVERAGE_POOL_2D:
+    case BuiltinOperator_MAX_POOL_2D:
+    case BuiltinOperator_L2_POOL_2D: {
+      TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+      if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
+        params->padding = parse_padding(pool_params->padding());
+        params->stride_width = pool_params->stride_w();
+        params->stride_height = pool_params->stride_h();
+        params->filter_width = pool_params->filter_width();
+        params->filter_height = pool_params->filter_height();
+        params->activation =
+            parse_activation(pool_params->fused_activation_function());
+      }
+      *builtin_data =
reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_DEPTHWISE_CONV_2D: { + TfLiteDepthwiseConvParams* params = + MallocPOD<TfLiteDepthwiseConvParams>(); + if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->depth_multiplier = conv_params->depth_multiplier(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SVDF: { + TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>(); + if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { + params->rank = svdf_params->rank(); + params->activation = + parse_activation(svdf_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { + TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>(); + if (auto* sequence_rnn_params = + op->builtin_options_as_SequenceRNNOptions()) { + params->activation = + parse_activation(sequence_rnn_params->fused_activation_function()); + params->time_major = sequence_rnn_params->time_major(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_RNN: { + TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>(); + if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { + params->activation = + parse_activation(rnn_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { + TfLiteEmbeddingLookupSparseParams* params = + MallocPOD<TfLiteEmbeddingLookupSparseParams>(); + if (auto* embedding_params = + op->builtin_options_as_EmbeddingLookupSparseOptions()) { + params->combiner = parseCombinerType(embedding_params->combiner()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_FULLY_CONNECTED: { + TfLiteFullyConnectedParams* params = + MallocPOD<TfLiteFullyConnectedParams>(); + if (auto* fully_connected_params = + op->builtin_options_as_FullyConnectedOptions()) { + params->activation = parse_activation( + fully_connected_params->fused_activation_function()); + switch (fully_connected_params->weights_format()) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + params->weights_format = + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; + break; + default: + error_reporter->Report("Unhandled fully-connected weights format."); + return kTfLiteError; + } + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_HASHTABLE_LOOKUP: + // no-op. 
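+      // This op has no builtin options to parse, so *builtin_data is left as
+      // nullptr.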
+ break; + case BuiltinOperator_SOFTMAX: { + TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>(); + if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { + params->beta = softmax_params->beta(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_CONCATENATION: { + TfLiteConcatenationParams* params = + MallocPOD<TfLiteConcatenationParams>(); + if (auto* concatenation_params = + op->builtin_options_as_ConcatenationOptions()) { + params->activation = + parse_activation(concatenation_params->fused_activation_function()); + params->axis = concatenation_params->axis(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_MUL: { + auto* params = MallocPOD<TfLiteMulParams>(); + if (auto* schema_params = op->builtin_options_as_MulOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_ADD: { + auto* params = MallocPOD<TfLiteAddParams>(); + if (auto* schema_params = op->builtin_options_as_AddOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_DIV: { + auto* params = MallocPOD<TfLiteDivParams>(); + if (auto* schema_params = op->builtin_options_as_DivOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SUB: { + auto* params = MallocPOD<TfLiteSubParams>(); + if (auto* schema_params = op->builtin_options_as_SubOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_L2_NORMALIZATION: { + auto* params = MallocPOD<TfLiteL2NormParams>(); + if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { + auto* params = MallocPOD<TfLiteLocalResponseNormParams>(); + if (auto* schema_params = + op->builtin_options_as_LocalResponseNormalizationOptions()) { + params->radius = schema_params->radius(); + params->bias = schema_params->bias(); + params->alpha = schema_params->alpha(); + params->beta = schema_params->beta(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + case BuiltinOperator_LSTM: { + TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>(); + if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { + params->activation = + parse_activation(lstm_params->fused_activation_function()); + params->cell_clip = lstm_params->cell_clip(); + params->proj_clip = lstm_params->proj_clip(); + switch (lstm_params->kernel_type()) { + case LSTMKernelType_FULL: + params->kernel_type = kTfLiteLSTMFullKernel; + break; + case LSTMKernelType_BASIC: + params->kernel_type = kTfLiteLSTMBasicKernel; + break; + } + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_RESIZE_BILINEAR: { + auto* params = MallocPOD<TfLiteResizeBilinearParams>(); + if (auto* schema_params = + 
              op->builtin_options_as_ResizeBilinearOptions()) {
+        params->align_corners = schema_params->align_corners();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_RESHAPE: {
+      auto* params = MallocPOD<TfLiteReshapeParams>();
+      if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
+        auto* new_shape = schema_params->new_shape();
+        FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
+                                   params->shape, error_reporter);
+        params->num_dimensions = new_shape->Length();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SKIP_GRAM: {
+      TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+      if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
+        params->ngram_size = skip_gram_params->ngram_size();
+        params->max_skip_size = skip_gram_params->max_skip_size();
+        params->include_all_ngrams = skip_gram_params->include_all_ngrams();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SPACE_TO_DEPTH: {
+      auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+      if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
+        params->block_size = schema_params->block_size();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_GATHER: {
+      TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
+      params->axis = 0;
+      if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
+        params->axis = gather_params->axis();
+      }
+
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_MEAN:
+    case BuiltinOperator_REDUCE_MAX:
+    case BuiltinOperator_REDUCE_MIN:
+    case BuiltinOperator_REDUCE_PROD:
+    case BuiltinOperator_REDUCE_ANY:
+    case BuiltinOperator_SUM: {
+      auto* params = MallocPOD<TfLiteReducerParams>();
+      if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
+        params->keep_dims = schema_params->keep_dims();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SPLIT: {
+      auto* params = MallocPOD<TfLiteSplitParams>();
+      if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
+        params->num_splits = schema_params->num_splits();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SQUEEZE: {
+      auto* params = MallocPOD<TfLiteSqueezeParams>();
+      if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
+        const auto& squeeze_dims = schema_params->squeeze_dims();
+        FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
+                                   params->squeeze_dims, error_reporter);
+        params->num_squeeze_dims = squeeze_dims->Length();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_STRIDED_SLICE: {
+      auto* params = MallocPOD<TfLiteStridedSliceParams>();
+      if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
+        params->begin_mask = schema_params->begin_mask();
+        params->end_mask = schema_params->end_mask();
+        params->ellipsis_mask = schema_params->ellipsis_mask();
+        params->new_axis_mask = schema_params->new_axis_mask();
+        params->shrink_axis_mask = schema_params->shrink_axis_mask();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_ARG_MAX: {
+      auto* params = MallocPOD<TfLiteArgMaxParams>();
+      if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
+        ConvertTensorType(schema_params->output_type(), &params->output_type,
+                          error_reporter);
+      }
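+      // Note that the ConvertTensorType() status is not checked here; on an
+      // unknown tensor type an error is reported but parsing continues.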
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_ARG_MIN: {
+      auto* params = MallocPOD<TfLiteArgMinParams>();
+      if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+        ConvertTensorType(schema_params->output_type(), &params->output_type,
+                          error_reporter);
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_TRANSPOSE_CONV: {
+      TfLiteTransposeConvParams* params =
+          MallocPOD<TfLiteTransposeConvParams>();
+      if (auto* transpose_conv_params =
+              op->builtin_options_as_TransposeConvOptions()) {
+        params->padding = parse_padding(transpose_conv_params->padding());
+        params->stride_width = transpose_conv_params->stride_w();
+        params->stride_height = transpose_conv_params->stride_h();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SPARSE_TO_DENSE: {
+      TfLiteSparseToDenseParams* params =
+          MallocPOD<TfLiteSparseToDenseParams>();
+      if (auto* sparse_to_dense_params =
+              op->builtin_options_as_SparseToDenseOptions()) {
+        params->validate_indices = sparse_to_dense_params->validate_indices();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SHAPE: {
+      auto* params = MallocPOD<TfLiteShapeParams>();
+      if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
+        ConvertTensorType(schema_params->out_type(), &params->out_type,
+                          error_reporter);
+      }
+      *builtin_data = static_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_PACK: {
+      TfLitePackParams* params = MallocPOD<TfLitePackParams>();
+      if (auto* pack_params = op->builtin_options_as_PackOptions()) {
+        params->values_count = pack_params->values_count();
+        params->axis = pack_params->axis();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_DELEGATE: {
+      // TODO(ycling): Revisit when supporting saving delegated models.
+      error_reporter->Report("DELEGATE op shouldn't exist in model.");
+      return kTfLiteError;
+    }
+    case BuiltinOperator_FAKE_QUANT: {
+      auto* params = MallocPOD<TfLiteFakeQuantParams>();
+      if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
+        params->min = schema_params->min();
+        params->max = schema_params->max();
+        params->num_bits = schema_params->num_bits();
+        params->narrow_range = schema_params->narrow_range();
+      }
+      *builtin_data = static_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_ONE_HOT: {
+      auto* params = MallocPOD<TfLiteOneHotParams>();
+      if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
+        params->axis = schema_params->axis();
+      }
+      *builtin_data = static_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_UNPACK: {
+      TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+      if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
+        params->num = unpack_params->num();
+        params->axis = unpack_params->axis();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+
+    // Below are the ops with no builtin_data structure.
+    case BuiltinOperator_BATCH_TO_SPACE_ND:
+    // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
+    // ok for now, since there is no call implementation either.
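+    // Note that `free(nullptr)` is a no-op, so callers may unconditionally
+    // free() whatever comes back through `builtin_data`.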
+    case BuiltinOperator_CALL:
+    case BuiltinOperator_CONCAT_EMBEDDINGS:
+    case BuiltinOperator_CUSTOM:
+    case BuiltinOperator_DEQUANTIZE:
+    case BuiltinOperator_EMBEDDING_LOOKUP:
+    case BuiltinOperator_EQUAL:
+    case BuiltinOperator_EXP:
+    case BuiltinOperator_EXPAND_DIMS:
+    case BuiltinOperator_FLOOR:
+    case BuiltinOperator_GREATER:
+    case BuiltinOperator_GREATER_EQUAL:
+    case BuiltinOperator_LESS:
+    case BuiltinOperator_LESS_EQUAL:
+    case BuiltinOperator_LOG:
+    case BuiltinOperator_LOGISTIC:
+    case BuiltinOperator_LOG_SOFTMAX:
+    case BuiltinOperator_MAXIMUM:
+    case BuiltinOperator_MINIMUM:
+    case BuiltinOperator_NEG:
+    case BuiltinOperator_NOT_EQUAL:
+    case BuiltinOperator_PAD:
+    case BuiltinOperator_PADV2:
+    case BuiltinOperator_PRELU:
+    case BuiltinOperator_RELU:
+    case BuiltinOperator_RELU6:
+    case BuiltinOperator_RELU_N1_TO_1:
+    case BuiltinOperator_RSQRT:
+    case BuiltinOperator_SELECT:
+    case BuiltinOperator_SIN:
+    case BuiltinOperator_SLICE:
+    case BuiltinOperator_SPACE_TO_BATCH_ND:
+    case BuiltinOperator_SQRT:
+    case BuiltinOperator_TANH:
+    case BuiltinOperator_TILE:
+    case BuiltinOperator_TOPK_V2:
+    case BuiltinOperator_TRANSPOSE:
+    case BuiltinOperator_POW:
+    case BuiltinOperator_LOGICAL_OR:
+    case BuiltinOperator_LOGICAL_AND:
+    case BuiltinOperator_LOGICAL_NOT:
+    case BuiltinOperator_FLOOR_DIV:
+      break;
+  }
+  return kTfLiteOk;
+}  // NOLINT[readability/fn_size]
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
new file mode 100644
index 0000000000..4dec6f9cfc
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+
+// These functions transform codes and data structures that are defined in the
+// flatbuffer serialization format into in-memory values that are used by the
+// runtime API and interpreter.
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly, as it is defined by the flatbuffer
+// schemas. If it returns kTfLiteOk, it passes the data out with
+// `builtin_data`, which is allocated on the heap with `malloc`; the calling
+// function takes ownership and is responsible for releasing it with `free`.
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter, void** builtin_data); + +// Converts the tensor data type used in the flat buffer to the representation +// used by the runtime. +TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, + ErrorReporter* error_reporter); + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc new file mode 100644 index 0000000000..b12bdf43b2 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" + +#include <cstring> + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/c/builtin_op_data.h" + +namespace tflite { +namespace { + +class MockErrorReporter : public ErrorReporter { + public: + MockErrorReporter() : buffer_size_(0) {} + int Report(const char* format, va_list args) override { + buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args); + return buffer_size_; + } + char* GetBuffer() { return buffer_; } + int GetBufferSize() { return buffer_size_; } + + private: + static constexpr int kBufferSize = 256; + char buffer_[kBufferSize]; + int buffer_size_; +}; + +} // namespace + +TEST(FlatbufferConversions, TestParseOpDataConv) { + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<void> conv_options = + CreateConv2DOptions(builder, Padding_SAME, 1, 2, + ActivationFunctionType_RELU, 3, 4) + .Union(); + flatbuffers::Offset<Operator> conv_offset = CreateOperatorDirect( + builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options, + nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr); + builder.Finish(conv_offset); + void* conv_pointer = builder.GetBufferPointer(); + const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer); + void* output_data = nullptr; + EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter, + &output_data)); + EXPECT_NE(nullptr, output_data); + TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data); + EXPECT_EQ(kTfLitePaddingSame, params->padding); + EXPECT_EQ(1, params->stride_width); + EXPECT_EQ(2, params->stride_height); + EXPECT_EQ(kTfLiteActRelu, params->activation); + EXPECT_EQ(3, params->dilation_width_factor); + EXPECT_EQ(4, params->dilation_height_factor); + free(output_data); +} + +TEST(FlatbufferConversions, TestParseOpDataCustom) { + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<void> null_options; + 
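+  // Illustrative note: a default-constructed Offset stays null, which models
+  // a custom op that carries no builtin options table.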
flatbuffers::Offset<Operator> custom_offset = CreateOperatorDirect( + builder, 0, nullptr, nullptr, BuiltinOptions_NONE, null_options, nullptr, + CustomOptionsFormat_FLEXBUFFERS, nullptr); + builder.Finish(custom_offset); + void* custom_pointer = builder.GetBufferPointer(); + const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer); + void* output_data = nullptr; + EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter, + &output_data)); + EXPECT_EQ(nullptr, output_data); +} + +TEST(FlatbufferConversions, TestConvertTensorType) { + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + TfLiteType type; + EXPECT_EQ(kTfLiteOk, ConvertTensorType(TensorType_FLOAT32, &type, reporter)); + EXPECT_EQ(kTfLiteFloat32, type); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/core/api/op_resolver.cc b/tensorflow/contrib/lite/core/api/op_resolver.cc new file mode 100644 index 0000000000..55ee924843 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/op_resolver.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/op_resolver.h" + +namespace tflite { + +TfLiteStatus GetRegistrationFromOpCode( + const OperatorCode* opcode, const OpResolver& op_resolver, + ErrorReporter* error_reporter, const TfLiteRegistration** registration) { + TfLiteStatus status = kTfLiteOk; + *registration = nullptr; + auto builtin_code = opcode->builtin_code(); + int version = opcode->version(); + + if (builtin_code > BuiltinOperator_MAX || + builtin_code < BuiltinOperator_MIN) { + error_reporter->Report( + "Op builtin_code out of range: %d. 
Are you using an old TFLite binary "
+        "with a newer model?",
+        builtin_code);
+    status = kTfLiteError;
+  } else if (builtin_code != BuiltinOperator_CUSTOM) {
+    *registration = op_resolver.FindOp(builtin_code, version);
+    if (*registration == nullptr) {
+      error_reporter->Report(
+          "Didn't find op for builtin opcode '%s' version '%d'\n",
+          EnumNameBuiltinOperator(builtin_code), version);
+      status = kTfLiteError;
+    }
+  } else if (!opcode->custom_code()) {
+    error_reporter->Report(
+        "Operator with CUSTOM builtin_code has no custom_code.\n");
+    status = kTfLiteError;
+  } else {
+    const char* name = opcode->custom_code()->c_str();
+    *registration = op_resolver.FindOp(name, version);
+    if (*registration == nullptr) {
+      error_reporter->Report(
+          "Didn't find custom op for name '%s' with version %d\n", name,
+          version);
+      status = kTfLiteError;
+    }
+  }
+  return status;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.h b/tensorflow/contrib/lite/core/api/op_resolver.h
new file mode 100644
index 0000000000..5f5e6b2736
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism by which ops referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+  // Finds the op registration for a builtin operator by enum code.
+  virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+                                           int version) const = 0;
+  // Finds the op registration of a custom operator by op name.
+  virtual const TfLiteRegistration* FindOp(const char* op,
+                                           int version) const = 0;
+  virtual ~OpResolver() {}
+};
+
+// Handles the logic for converting between an OperatorCode structure extracted
+// from a flatbuffer and information about a registered operator implementation.
+TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
+                                       const OpResolver& op_resolver,
+                                       ErrorReporter* error_reporter,
+                                       const TfLiteRegistration** registration);
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/core/api/op_resolver_test.cc b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
new file mode 100644
index 0000000000..167463110e
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/op_resolver.h" + +#include <cstring> + +#include <gtest/gtest.h> + +namespace tflite { +namespace { +void* MockInit(TfLiteContext* context, const char* buffer, size_t length) { + // Do nothing. + return nullptr; +} + +void MockFree(TfLiteContext* context, void* buffer) { + // Do nothing. +} + +TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +class MockOpResolver : public OpResolver { + public: + const TfLiteRegistration* FindOp(BuiltinOperator op, + int version) const override { + if (op == BuiltinOperator_CONV_2D) { + static TfLiteRegistration r = {MockInit, MockFree, MockPrepare, + MockInvoke}; + return &r; + } else { + return nullptr; + } + } + const TfLiteRegistration* FindOp(const char* op, int version) const override { + if (strcmp(op, "mock_custom") == 0) { + static TfLiteRegistration r = {MockInit, MockFree, MockPrepare, + MockInvoke}; + return &r; + } else { + return nullptr; + } + } +}; + +class MockErrorReporter : public ErrorReporter { + public: + MockErrorReporter() : buffer_size_(0) {} + int Report(const char* format, va_list args) override { + buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args); + return buffer_size_; + } + char* GetBuffer() { return buffer_; } + int GetBufferSize() { return buffer_size_; } + + private: + static constexpr int kBufferSize = 256; + char buffer_[kBufferSize]; + int buffer_size_; +}; + +} // namespace + +TEST(OpResolver, TestResolver) { + MockOpResolver mock_resolver; + OpResolver* resolver = &mock_resolver; + + const TfLiteRegistration* registration = + resolver->FindOp(BuiltinOperator_CONV_2D, 0); + EXPECT_NE(nullptr, registration); + EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + + registration = resolver->FindOp(BuiltinOperator_CAST, 0); + EXPECT_EQ(nullptr, registration); + + registration = resolver->FindOp("mock_custom", 0); + EXPECT_NE(nullptr, registration); + EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + + registration = resolver->FindOp("nonexistent_custom", 0); + EXPECT_EQ(nullptr, registration); +} + +TEST(OpResolver, TestGetRegistrationFromOpCodeConv) { + MockOpResolver mock_resolver; + OpResolver* resolver = &mock_resolver; + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<OperatorCode> conv_offset = + CreateOperatorCodeDirect(builder, BuiltinOperator_CONV_2D, nullptr, 0); + builder.Finish(conv_offset); + void* conv_pointer = builder.GetBufferPointer(); + 
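+  // GetRoot() reinterprets the finished flatbuffer, without copying, as its
+  // root table, here an OperatorCode that can be handed to the resolver.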
const OperatorCode* conv_code =
+      flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+  const TfLiteRegistration* registration = nullptr;
+  EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+                                                 &registration));
+  EXPECT_NE(nullptr, registration);
+  EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+  EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+  EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+  EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCast) {
+  MockOpResolver mock_resolver;
+  OpResolver* resolver = &mock_resolver;
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+
+  flatbuffers::FlatBufferBuilder builder;
+  flatbuffers::Offset<OperatorCode> conv_offset =
+      CreateOperatorCodeDirect(builder, BuiltinOperator_CAST, nullptr, 0);
+  builder.Finish(conv_offset);
+  void* conv_pointer = builder.GetBufferPointer();
+  const OperatorCode* conv_code =
+      flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+  const TfLiteRegistration* registration = nullptr;
+  EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+                                                    reporter, &registration));
+  EXPECT_EQ(nullptr, registration);
+  EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCustom) {
+  MockOpResolver mock_resolver;
+  OpResolver* resolver = &mock_resolver;
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+
+  flatbuffers::FlatBufferBuilder builder;
+  flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+      builder, BuiltinOperator_CUSTOM, "mock_custom", 0);
+  builder.Finish(conv_offset);
+  void* conv_pointer = builder.GetBufferPointer();
+  const OperatorCode* conv_code =
+      flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+  const TfLiteRegistration* registration = nullptr;
+  EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+                                                 &registration));
+  EXPECT_NE(nullptr, registration);
+  EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+  EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+  EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+  EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeNonexistentCustom) {
+  MockOpResolver mock_resolver;
+  OpResolver* resolver = &mock_resolver;
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+
+  flatbuffers::FlatBufferBuilder builder;
+  flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+      builder, BuiltinOperator_CUSTOM, "nonexistent_custom", 0);
+  builder.Finish(conv_offset);
+  void* conv_pointer = builder.GetBufferPointer();
+  const OperatorCode* conv_code =
+      flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+  const TfLiteRegistration* registration = nullptr;
+  EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+                                                    reporter, &registration));
+  EXPECT_EQ(nullptr, registration);
+  EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index b6b2357873..bf5d91899c 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -16,6 +16,7 @@ cc_library(
     deps = [
         ":util",
         "//tensorflow/c:c_api_internal",
+
"//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ @@ -54,6 +55,7 @@ cc_library( ":delegate_data", ":kernel", ":util", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite:util", ] + select({ @@ -104,6 +106,7 @@ tf_cc_test( ":delegate_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], @@ -117,6 +120,7 @@ cc_library( ":delegate_data", ":util", "@flatbuffers", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite:string", "//tensorflow/contrib/lite/kernels:kernel_util", @@ -170,6 +174,7 @@ cc_library( hdrs = ["util.h"], deps = [ "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h index a28329ae7d..aaaa045840 100644 --- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h @@ -17,7 +17,7 @@ limitations under the License. #include <map> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" namespace tflite { diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h index 6d15ba47dc..70f3c15af4 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.h +++ b/tensorflow/contrib/lite/delegates/eager/delegate.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ #define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" namespace tflite { diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc index b3a0ffcec1..def063309f 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include <gmock/gmock.h> #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc index 0ee4db1ffb..274c3c082a 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.cc +++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/builtin_ops.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" #include "tensorflow/contrib/lite/delegates/eager/util.h" diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h index 100672c82d..2478abccaa 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.h +++ b/tensorflow/contrib/lite/delegates/eager/kernel.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ #define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { namespace eager { diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h index ff500d18f3..930cb99cb9 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.h +++ b/tensorflow/contrib/lite/delegates/eager/util.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD index 954955f24b..4e7b2948fb 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/BUILD +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -13,6 +13,7 @@ cc_library( deps = [ "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/nnapi:nnapi_lib", ], @@ -29,6 +30,7 @@ tf_cc_test( deps = [ ":nnapi_delegate", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index 980a1cb4a0..e3eebac4da 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/builtin_ops.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h index 44cca2fd28..4852b76974 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h @@ -15,7 +15,7 @@ limitations under the License. 
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ #define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h index 3c5f805f12..5c20eedc25 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/contrib/lite/error_reporter.h @@ -12,43 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// Compatibility shim for moved header location. #ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ #define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ -#include <cstdarg> -#include "tensorflow/contrib/lite/context.h" - -namespace tflite { - -// A functor that reports error to supporting system. Invoked similar to -// printf. -// -// Usage: -// ErrorReporter foo; -// foo.Report("test %d", 5); -// or -// va_list args; -// foo.Report("test %d", args); // where args is va_list -// -// Subclass ErrorReporter to provide another reporting destination. -// For example, if you have a GUI program, you might redirect to a buffer -// that drives a GUI error log box. -class ErrorReporter { - public: - virtual ~ErrorReporter(); - virtual int Report(const char* format, va_list args) = 0; - int Report(const char* format, ...); - int ReportError(void*, const char* format, ...); -}; - -// An error reporter that simplify writes the message to stderr. -struct StderrReporter : public ErrorReporter { - int Report(const char* format, va_list args) override; -}; - -// Return the default error reporter (output to stderr). -ErrorReporter* DefaultErrorReporter(); - -} // namespace tflite +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/stderr_reporter.h" #endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD index 8fc07e8eb7..ea4a543252 100644 --- a/tensorflow/contrib/lite/experimental/c/BUILD +++ b/tensorflow/contrib/lite/experimental/c/BUILD @@ -78,6 +78,7 @@ cc_test( data = ["//tensorflow/contrib/lite:testdata/add.bin"], deps = [ ":c_api", + "//tensorflow/contrib/lite:context", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc index a4ab0e8c30..c589cf71ea 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/experimental/c/c_api.h" +#include <memory> + #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/experimental/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" @@ -29,12 +31,14 @@ extern "C" { TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) { auto model = tflite::FlatBufferModel::BuildFromBuffer( static_cast<const char*>(model_data), model_size); - return model ? 
new TFL_Model{std::move(model)} : nullptr;
+  std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+  return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
 }
 
 TFL_Model* TFL_NewModelFromFile(const char* model_path) {
   auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
-  return model ? new TFL_Model{std::move(model)} : nullptr;
+  std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+  return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
 }
 
 void TFL_DeleteModel(TFL_Model* model) { delete model; }
@@ -72,7 +76,7 @@ TFL_Interpreter* TFL_NewInterpreter(
     }
   }
 
-  return new TFL_Interpreter{std::move(interpreter)};
+  return new TFL_Interpreter{model->impl, std::move(interpreter)};
 }
 
 void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
@@ -129,6 +133,8 @@ void* TFL_TensorData(const TFL_Tensor* tensor) {
   return static_cast<void*>(tensor->data.raw);
 }
 
+const char* TFL_TensorName(const TFL_Tensor* tensor) { return tensor->name; }
+
 TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
                                     size_t input_data_size) {
   if (tensor->bytes != input_data_size) {
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index 3757349b55..b429e76870 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -93,7 +93,8 @@ typedef struct TFL_Interpreter TFL_Interpreter;
 // failure.
 //
 // * `model` must be a valid model instance. The caller retains ownership of the
-// object, and can destroy it immediately after creating the interpreter.
+// object, and can destroy it immediately after creating the interpreter; the
+// interpreter will maintain its own reference to the underlying model data.
 // * `optional_options` may be null. The caller retains ownership of the object,
 // and can safely destroy it immediately after creating the interpreter.
 //
@@ -145,6 +146,11 @@ TFL_CAPI_EXPORT extern int32_t TFL_InterpreterGetOutputTensorCount(
 // Returns the tensor associated with the output index.
 // REQUIRES: 0 <= output_index < TFL_InterpreterGetOutputTensorCount(interpreter)
+//
+// NOTE: The shape and underlying data buffer for output tensors may not
+// be available until after the output tensor has been both sized and allocated.
+// In general, best practice is to interact with the output tensor *after*
+// calling TFL_InterpreterInvoke().
 TFL_CAPI_EXPORT extern const TFL_Tensor* TFL_InterpreterGetOutputTensor(
     const TFL_Interpreter* interpreter, int32_t output_index);
 
@@ -172,12 +178,15 @@ TFL_CAPI_EXPORT extern size_t TFL_TensorByteSize(const TFL_Tensor* tensor);
 // Returns a pointer to the underlying data buffer.
 //
-// Note: The result may be null if tensors have not yet been allocated, e.g.,
+// NOTE: The result may be null if tensors have not yet been allocated, e.g.,
 // if the Tensor has just been created or resized and `TFL_AllocateTensors()`
 // has yet to be called, or if the output tensor is dynamically sized and the
 // interpreter hasn't been invoked.
 TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor);
 
+// Returns the (null-terminated) name of the tensor.
+TFL_CAPI_EXPORT extern const char* TFL_TensorName(const TFL_Tensor* tensor);
+
 // Copies from the provided input buffer into the tensor's buffer.
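+// NOTE: The destination buffer must already exist, i.e. `TFL_AllocateTensors()`
+// must have been called (see the note on TFL_TensorData() above).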
// REQUIRES: input_data_size == TFL_TensorByteSize(tensor)
 TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer(
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index c5c612a4c6..60c2e4e2cd 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -24,7 +24,8 @@ limitations under the License.
 // not be depended on.
 
 struct TFL_Model {
-  std::unique_ptr<tflite::FlatBufferModel> impl;
+  // Sharing is safe as FlatBufferModel is const.
+  std::shared_ptr<const tflite::FlatBufferModel> impl;
 };
 
 struct TFL_InterpreterOptions {
@@ -35,6 +36,9 @@ struct TFL_InterpreterOptions {
 };
 
 struct TFL_Interpreter {
+  // Taking a reference to the (const) model data avoids lifetime-related
+  // issues and complexity around the TFL_Model's own lifetime.
+  std::shared_ptr<const tflite::FlatBufferModel> model;
   std::unique_ptr<tflite::Interpreter> impl;
 };
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
index a631dae890..649dac8d1a 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
@@ -55,6 +55,8 @@ TEST(CApiSimple, Smoke) {
   EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1);
   EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2);
   EXPECT_EQ(TFL_TensorByteSize(input_tensor), sizeof(float) * 2);
+  EXPECT_NE(TFL_TensorData(input_tensor), nullptr);
+  EXPECT_STREQ(TFL_TensorName(input_tensor), "input");
 
   std::array<float, 2> input = {1.f, 3.f};
   ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(),
@@ -70,6 +72,8 @@ TEST(CApiSimple, Smoke) {
   EXPECT_EQ(TFL_TensorNumDims(output_tensor), 1);
   EXPECT_EQ(TFL_TensorDim(output_tensor, 0), 2);
   EXPECT_EQ(TFL_TensorByteSize(output_tensor), sizeof(float) * 2);
+  EXPECT_NE(TFL_TensorData(output_tensor), nullptr);
+  EXPECT_STREQ(TFL_TensorName(output_tensor), "output");
 
   std::array<float, 2> output;
   ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(),
diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD
index 9c06c4ebd9..4786cc62f9 100644
--- a/tensorflow/contrib/lite/experimental/kernels/BUILD
+++ b/tensorflow/contrib/lite/experimental/kernels/BUILD
@@ -53,6 +53,7 @@ cc_library(
         "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:string_util",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:builtin_ops",
         "//tensorflow/contrib/lite/kernels:gemm_support",
         "//tensorflow/contrib/lite/kernels:kernel_util",
@@ -61,8 +62,8 @@ cc_library(
         "//tensorflow/contrib/lite/kernels/internal:optimized",
         "//tensorflow/contrib/lite/kernels/internal:optimized_base",
         "//tensorflow/contrib/lite/kernels/internal:quantization_util",
-        "//tensorflow/contrib/lite/kernels/internal:reference",
         "//tensorflow/contrib/lite/kernels/internal:reference_base",
+        "//tensorflow/contrib/lite/kernels/internal:tensor",
         "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
         "@flatbuffers",
     ],
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
index 121997dcb2..8442c4d46c 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
@@ -14,7 +14,7 @@ limitations
under the License. ==============================================================================*/ #include <vector> #include "flatbuffers/flexbuffers.h" // flatbuffers -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h index 77268d7aeb..8ee83827bb 100644 --- a/tensorflow/contrib/lite/graph_info.h +++ b/tensorflow/contrib/lite/graph_info.h @@ -17,7 +17,7 @@ limitations under the License. #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 5ab53f4c1d..3f8f4d198f 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -21,9 +21,9 @@ limitations under the License. #include <cstring> #include "tensorflow/contrib/lite/arena_planner.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/graph_info.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 2b1f1819b9..f0cd178c19 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -23,10 +23,11 @@ limitations under the License. #include <vector> #include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/profiling/profiler.h" +#include "tensorflow/contrib/lite/stderr_reporter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 5bcf0927d8..cdede430e2 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/contrib/lite/interpreter.h" #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD index 06f46fb923..781289ceb2 100644 --- a/tensorflow/contrib/lite/java/ovic/BUILD +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -35,6 +35,7 @@ java_binary( "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", ], main_class = "org.tensorflow.ovic.OvicValidator", + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", ], @@ -47,6 +48,7 @@ android_library( "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", ], manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/java:tensorflowlite", "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", @@ -61,6 +63,7 @@ java_library( "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", ], javacopts = JAVACOPTS, + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", "//tensorflow/contrib/lite/java:tensorflowlite_java", diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index 55ca47fed7..06b35d77c8 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -20,7 +20,7 @@ limitations under the License. #include <stdio.h> #include <time.h> #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" #include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" @@ -124,9 +124,9 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, */ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env, - jclass clazz, - jlong handle, - jint num_threads); + jclass clazz, + jlong handle, + jint num_threads); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h index c020f13d9c..2f73128bdf 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -17,7 +17,7 @@ limitations under the License. 
#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_ #include <jni.h> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #ifdef __cplusplus extern "C" { diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index b7c5cbf207..40f28aeab4 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -66,7 +66,7 @@ cc_library( deps = [ ":op_macros", "//tensorflow/contrib/lite:arena_planner", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels/internal:optimized", ], ) @@ -82,7 +82,7 @@ cc_library( copts = tflite_copts(), deps = [ ":op_macros", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", "@gemmlowp", ], ) @@ -93,7 +93,7 @@ cc_library( "activation_functor.h", ], deps = [ - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -113,9 +113,9 @@ cc_library( "kernel_util.h", ], deps = [ - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels/internal:round", + "//tensorflow/contrib/lite/kernels/internal:types", ], ) @@ -147,6 +147,15 @@ tf_cc_test( ) cc_library( + name = "padding", + srcs = [], + hdrs = ["padding.h"], + deps = [ + "//tensorflow/contrib/lite/c:c_api_internal", + ], +) + +cc_library( name = "builtin_op_kernels", srcs = [ "activations.cc", @@ -216,7 +225,6 @@ cc_library( "unpack.cc", ], hdrs = [ - "padding.h", ], copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS, visibility = ["//visibility:private"], @@ -225,18 +233,19 @@ cc_library( ":eigen_support", ":kernel_util", ":op_macros", - "//tensorflow/contrib/lite:builtin_op_data", + ":padding", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/kernels/internal:audio_utils", "//tensorflow/contrib/lite/kernels/internal:kernel_utils", "//tensorflow/contrib/lite/kernels/internal:optimized", "//tensorflow/contrib/lite/kernels/internal:optimized_base", "//tensorflow/contrib/lite/kernels/internal:quantization_util", - "//tensorflow/contrib/lite/kernels/internal:reference", "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:tensor", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "@farmhash_archive//:farmhash", "@flatbuffers", @@ -251,6 +260,7 @@ cc_library( ":builtin_op_kernels", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -757,8 +767,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -774,8 +784,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1044,8 +1054,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", 
"//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1147,8 +1157,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1164,8 +1174,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1181,8 +1191,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1198,8 +1208,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1212,8 +1222,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1239,8 +1249,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h index 41ec3cca33..e075dc7054 100644 --- a/tensorflow/contrib/lite/kernels/activation_functor.h +++ b/tensorflow/contrib/lite/kernels/activation_functor.h @@ -19,7 +19,7 @@ limitations under the License. #include <cmath> #include <cstdlib> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 5cdd9fc94f..b2d9b84979 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index af9b5c7013..b4393e8097 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc index 6e05f5a9b2..b91e348c27 100644 --- a/tensorflow/contrib/lite/kernels/arg_min_max.cc +++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc index 1170d84553..44ef587244 100644 --- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc +++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/spectrogram.h" diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index c5a5c0182f..1aa27602e5 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -15,8 +15,8 @@ limitations under the License. #include <stddef.h> #include <stdint.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 4efa9d596d..fe2865dfb9 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 6b8ecdd5c3..541f320138 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index d988ef8b33..2f896c5289 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 8dd48af57f..a7972140ac 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -15,8 +15,8 @@ limitations under the License. #include <string.h> #include <algorithm> #include <complex> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index 8b4d778332..4cd96348a2 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index 605a20ac3e..25ea556d5a 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 3ed0cdb131..ab6bdaecaa 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/eigen_support.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h" diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index 21518156b8..347515f289 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc index 2b0f04489a..3a08f48b00 100644 --- a/tensorflow/contrib/lite/kernels/dequantize.cc +++ b/tensorflow/contrib/lite/kernels/dequantize.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc index 136697f945..d2906632d7 100644 --- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc +++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc @@ -16,8 +16,8 @@ limitations under the License. #include <numeric> #include <vector> #include "flatbuffers/flexbuffers.h" // flatbuffers -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc index d7420ddd8e..7945c095b1 100644 --- a/tensorflow/contrib/lite/kernels/div.cc +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h index b235829642..feb1543f7b 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.h +++ b/tensorflow/contrib/lite/kernels/eigen_support.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace EigenForTFLite { struct ThreadPoolDevice; diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index e19779ea59..04995d70dd 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include <cmath> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index b2dff87e62..fe33f98eb0 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -37,8 +37,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc index d3be36993c..aa75b03990 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc @@ -65,8 +65,8 @@ limitations under the License. #include <algorithm> #include <cmath> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc index ce03cdfe26..673e7be90a 100644 --- a/tensorflow/contrib/lite/kernels/exp.cc +++ b/tensorflow/contrib/lite/kernels/exp.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc index ed33012864..fa1140b19c 100644 --- a/tensorflow/contrib/lite/kernels/expand_dims.cc +++ b/tensorflow/contrib/lite/kernels/expand_dims.cc @@ -15,8 +15,8 @@ limitations under the License. 
==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc index 50dc860e5a..a3bc1813db 100644 --- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc +++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc index 0ef1a50b30..f9bc3747cb 100644 --- a/tensorflow/contrib/lite/kernels/fake_quant.cc +++ b/tensorflow/contrib/lite/kernels/fake_quant.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc index f7d5f5146d..59ff77f35b 100644 --- a/tensorflow/contrib/lite/kernels/floor.cc +++ b/tensorflow/contrib/lite/kernels/floor.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc index 75cf19a5a7..5d62cd2755 100644 --- a/tensorflow/contrib/lite/kernels/floor_div.cc +++ b/tensorflow/contrib/lite/kernels/floor_div.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index eaf5a67d67..7a71fcc219 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc index 2b2a9e6620..badd2de11a 100644 --- a/tensorflow/contrib/lite/kernels/gather.cc +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <string.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index 1d4292955c..1b48884e09 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index 37af772c68..43cd2b3055 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -16,7 +16,7 @@ limitations under the License. 
#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ #include "public/gemmlowp.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { namespace gemm_support { diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc index f37c66acb3..c0b3c3c0c5 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc @@ -39,8 +39,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" #include "tensorflow/contrib/lite/string_util.h" diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 464163bd78..a6fd4ac2dd 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -163,7 +163,7 @@ cc_library( ":tensor_utils", "//third_party/eigen3", "@gemmlowp", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -198,7 +198,7 @@ cc_library( ":round", "//third_party/eigen3", "@gemmlowp", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -220,13 +220,15 @@ cc_library( "optimized/eigen_spatial_convolutions.h", "optimized/eigen_tensor_reduced_instantiations_oss.h", "optimized/multithreaded_conv.h", + # FIXME(petewarden) - This should be removed, since it's a header from the + # :tensor dependency below. "tensor.h", ], deps = [ ":optimized_base", + ":tensor", ":types", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", "//third_party/eigen3", ], ) @@ -236,7 +238,7 @@ cc_test( srcs = ["tensor_test.cc"], tags = ["no_oss"], deps = [ - ":reference", + ":tensor", "@com_google_googletest//:gtest", ], ) @@ -296,7 +298,7 @@ cc_library( ":strided_slice_logic", ":types", "@gemmlowp", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -326,7 +328,7 @@ cc_library( ":strided_slice_logic", ":types", "@gemmlowp", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -341,11 +343,27 @@ cc_library( ) cc_library( + name = "tensor", + hdrs = [ + "tensor.h", + "tensor_ctypes.h", + ], + deps = [ + ":types", + "//tensorflow/contrib/lite/c:c_api_internal", + ], +) + +# Deprecated version of :tensor, kept for backwards compatibility. 
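# Editorial sketch (hypothetical labels, not part of this commit): a dependent
# kernel target migrates by swapping the dep label, e.g.
#   deps = ["//tensorflow/contrib/lite/kernels/internal:reference"]
# becomes
#   deps = ["//tensorflow/contrib/lite/kernels/internal:tensor"]
# while the deprecated alias defined below keeps existing builds working.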
+cc_library( name = "reference", - hdrs = ["tensor.h"], + hdrs = [ + "tensor.h", + "tensor_ctypes.h", + ], deps = [ ":types", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -359,7 +377,7 @@ cc_library( ], deps = [ ":round", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:activation_functor", "//tensorflow/contrib/lite/kernels:op_macros", ], @@ -384,7 +402,7 @@ cc_library( ":cpu_check", ":round", ":types", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:activation_functor", "//tensorflow/contrib/lite/kernels:op_macros", "@arm_neon_2_x86_sse", @@ -398,7 +416,7 @@ cc_library( hdrs = ["kernel_utils.h"], deps = [ ":tensor_utils", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -441,7 +459,7 @@ cc_library( copts = NEON_FLAGS_IF_APPLICABLE, deps = [ "//tensorflow/contrib/lite/kernels:activation_functor", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", "@arm_neon_2_x86_sse", "@gemmlowp", ] + select({ @@ -517,7 +535,7 @@ cc_test( ], deps = [ ":tensor_utils", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index eb4d0108bd..e67fee11b8 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -45,7 +45,7 @@ limitations under the License. #endif #endif -#include "public/gemmlowp.h" +#include "fixedpoint/fixedpoint.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index b9dd40ddf9..56e9367878 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -14,8 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" -#include <algorithm> - #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index 215ad04add..b5558cce55 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" namespace tflite { namespace kernel_utils { diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h index 921aae1303..5fb31889fe 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -26,7 +26,7 @@ limitations under the License. 
#include <tuple> #include <type_traits> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index 70b6994a2b..27418178fd 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -15,7 +15,7 @@ limitations under the License. #include <stdlib.h> #include <string.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 5ca1b4b76f..630a6bbf29 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -17,7 +17,7 @@ limitations under the License. // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index 7e53dc2fa2..f87760a6c3 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -17,7 +17,7 @@ limitations under the License. // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #if defined(_MSC_VER) #define __restrict__ __restrict diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index 2a30910c3f..77e60adc18 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include <string.h> #include <algorithm> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/round.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index f5b3a84f07..714b1164ee 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -17,7 +17,7 @@ limitations under the License. // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #if defined(_MSC_VER) #define __restrict__ __restrict diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index a027a47726..0abacf85e1 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -3488,8 +3488,7 @@ inline void Gather(const tflite::GatherParams& op_params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& coords_shape, const int32* coords_data, const RuntimeShape& output_shape, T* output_data) { - // TODO(b/80418076): Enable these checks when moving legacy ops to - // legacy_reference_ops. + // Enable these checks when moving legacy ops to legacy_reference_ops. // // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1); const int input_rank = op_params.input_rank; @@ -3808,58 +3807,110 @@ inline void Pad(const tflite::PadParams& op_params, } template <typename T> -inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, int shrink_axis_mask, - const std::vector<int>& start_indices, - const std::vector<int>& stop_indices, - const std::vector<int>& strides, T* output_data, - const Dims<4>& output_dims) { - // Note that the axis orders are reversed for runtime ops, so the indices, - // strides and masks must be as well too. - TFLITE_DCHECK_EQ(start_indices.size(), 4); - TFLITE_DCHECK_EQ(stop_indices.size(), 4); - TFLITE_DCHECK_EQ(strides.size(), 4); - const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 3); +inline void StridedSlice(const tflite::StridedSliceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + // Note that the output_shape is not used herein. + tflite::StridedSliceParams params_copy = op_params; + + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + // Reverse and pad to 4 dimensions because that is what the runtime code + // requires (ie. all shapes must be 4D and are given backwards). 
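 // Editorial illustration with hypothetical numbers (not in the original
 // source): a rank-2 spec of start_indices = {1, 0}, stop_indices = {3, 4},
 // strides = {1, 1} is front-padded to start_indices = {0, 0, 1, 0},
 // stop_indices = {0, 0, 3, 4}, strides = {1, 1, 1, 1}, and begin_mask /
 // end_mask are OR'ed with 0b11 so the two padded leading dimensions (each
 // extended to size 1) span their full range.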
+ strided_slice::StridedSlicePadIndices(&params_copy, 4); + + const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0); const int stop_b = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 3, start_b); - const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 2); + strided_slice::StopForAxis(params_copy, input_shape, 0, start_b); + const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1); const int stop_h = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 2, start_h); - const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 1); + strided_slice::StopForAxis(params_copy, input_shape, 1, start_h); + const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2); const int stop_w = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 1, start_w); - const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 0); + strided_slice::StopForAxis(params_copy, input_shape, 2, start_w); + const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3); const int stop_d = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 0, start_d); + strided_slice::StopForAxis(params_copy, input_shape, 3, start_d); T* out_ptr = output_data; for (int in_b = start_b; - !strided_slice::LoopCondition(in_b, stop_b, strides[3]); - in_b += strides[3]) { + !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]); + in_b += params_copy.strides[0]) { for (int in_h = start_h; - !strided_slice::LoopCondition(in_h, stop_h, strides[2]); - in_h += strides[2]) { + !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]); + in_h += params_copy.strides[1]) { for (int in_w = start_w; - !strided_slice::LoopCondition(in_w, stop_w, strides[1]); - in_w += strides[1]) { - for (int in_d = start_d; - !strided_slice::LoopCondition(in_d, stop_d, strides[0]); - in_d += strides[0]) { - *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]); + in_w += params_copy.strides[2]) { + for (int in_d = start_d; !strided_slice::LoopCondition( + in_d, stop_d, params_copy.strides[3]); + in_d += params_copy.strides[3]) { + *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)]; } } } } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy.
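// Editorial worked example (not in the original source): reversing a 4-axis
// begin_mask of 0b0011 with the bit mirror below gives 0xC0000000, and
// shifting right by (32 - 4) = 28 yields 0b1100 -- the same two axes once
// the axis order is flipped end-for-end.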
+inline uint32 LegacyReverseBits32(uint32 n) { + n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1); + n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2); + n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4); + return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) | + ((n & 0xFF000000) >> 24)); +} + +inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) { + TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count); + TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count); + + std::reverse(p->start_indices, p->start_indices + p->start_indices_count); + std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count); + std::reverse(p->strides, p->strides + p->strides_count); + + p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >> + (32 - p->start_indices_count); + p->ellipsis_mask = + LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >> + (32 - p->start_indices_count); + p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >> + (32 - p->start_indices_count); + p->new_axis_mask = + LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >> + (32 - p->start_indices_count); + p->shrink_axis_mask = + LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >> + (32 - p->start_indices_count); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template <typename T> +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, int shrink_axis_mask, + const std::vector<int>& start_indices, + const std::vector<int>& stop_indices, + const std::vector<int>& strides, T* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK_EQ(start_indices.size(), 4); + auto op_params = strided_slice::BuildStridedSliceParams( + begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices, + strides); + StridedSliceReverseIndices(&op_params); + + StridedSlice(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + template <typename T> inline void Slice(const tflite::SliceParams& op_params, const RuntimeShape& input_shape, const T* input_data, diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h index 5994fad5c7..af5db1064c 100644 --- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h @@ -19,9 +19,9 @@ limitations under the License. #include <limits> #include <vector> #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { - namespace strided_slice { // Use until std::clamp() is available from C++17. @@ -32,15 +32,51 @@ inline int Clamp(const int v, const int lo, const int hi) { return v; } +inline void StridedSlicePadIndices(tflite::StridedSliceParams* p, + int dim_count) { + // Add indices and mask bits to fully include extra dimensions + TFLITE_CHECK_LE(dim_count, 4); + TFLITE_CHECK_GE(dim_count, p->start_indices_count); + TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count); + TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count); + + const int pad_count = dim_count - p->start_indices_count; + + // Pad indices at start, so move arrays by pad_count. 
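 // (Editorial example: with pad_count == 2, a rank-2 spec moves from slots
 // {0, 1} to slots {2, 3} before the freed front slots get the defaults
 // assigned below.)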
+ for (int i = p->start_indices_count - 1; i >= 0; --i) { + p->strides[i + pad_count] = p->strides[i]; + p->start_indices[i + pad_count] = p->start_indices[i]; + p->stop_indices[i + pad_count] = p->stop_indices[i]; + } + for (int i = 0; i < pad_count; ++i) { + p->start_indices[i] = 0; + p->stop_indices[i] = 0; + p->strides[i] = 1; + } + + // Pad masks with 0s or 1s as required. + p->shrink_axis_mask <<= pad_count; + p->ellipsis_mask <<= pad_count; + p->new_axis_mask <<= pad_count; + p->begin_mask <<= pad_count; + p->end_mask <<= pad_count; + p->begin_mask |= (1 << pad_count) - 1; + p->end_mask |= (1 << pad_count) - 1; + + p->start_indices_count = dim_count; + p->stop_indices_count = dim_count; + p->strides_count = dim_count; +} + // Return the index for the first element along that axis. This index will be a // positive integer between [0, axis_size - 1] that can be used to index // directly into the data. -template <typename IntType> -inline int StartForAxis(int begin_mask, - std::vector<IntType> const& start_indices, - std::vector<IntType> const& strides, - int const* input_shape, int axis) { - // Begin with the specified index +inline int StartForAxis(const tflite::StridedSliceParams& params, + const RuntimeShape& input_shape, int axis) { + const auto begin_mask = params.begin_mask; + const auto* start_indices = params.start_indices; + const auto* strides = params.strides; + // Begin with the specified index. int start = start_indices[axis]; // begin_mask override @@ -57,7 +93,7 @@ inline int StartForAxis, } // Handle negative indices - int axis_size = input_shape[axis]; + int axis_size = input_shape.Dims(axis); if (start < 0) { start += axis_size; } @@ -73,11 +109,14 @@ // element. ie. So if you were iterating through all elements of a 1D array of // size 4, this function would return 4 as the stop, because it is one past the // "real" indices of 0, 1, 2 & 3. -template <typename IntType> -inline int StopForAxis(int end_mask, int shrink_axis_mask, - std::vector<IntType> const& stop_indices, - std::vector<IntType> const& strides, - int const* input_shape, int axis, int start_for_axis) { +inline int StopForAxis(const tflite::StridedSliceParams& params, + const RuntimeShape& input_shape, int axis, + int start_for_axis) { + const auto end_mask = params.end_mask; + const auto shrink_axis_mask = params.shrink_axis_mask; + const auto* stop_indices = params.stop_indices; + const auto* strides = params.strides; + // Begin with the specified index const bool shrink_axis = shrink_axis_mask & (1 << axis); int stop = stop_indices[axis]; @@ -103,7 +142,7 @@ } // Handle negative indices - const int axis_size = input_shape[axis]; + const int axis_size = input_shape.Dims(axis); if (stop < 0) { stop += axis_size; } @@ -127,6 +166,31 @@ inline bool LoopCondition(int index, int stop, int stride) { return stride > 0 ?
index >= stop : index <= stop; } +inline tflite::StridedSliceParams BuildStridedSliceParams( + int begin_mask, int end_mask, int shrink_axis_mask, + const std::vector<int>& start_indices, const std::vector<int>& stop_indices, + const std::vector<int>& strides) { + tflite::StridedSliceParams op_params; + const int dims_count = start_indices.size(); + + op_params.start_indices_count = dims_count; + op_params.stop_indices_count = dims_count; + op_params.strides_count = dims_count; + for (int i = 0; i < dims_count; ++i) { + op_params.start_indices[i] = start_indices[i]; + op_params.stop_indices[i] = stop_indices[i]; + op_params.strides[i] = strides[i]; + } + + op_params.begin_mask = begin_mask; + op_params.ellipsis_mask = 0; + op_params.end_mask = end_mask; + op_params.new_axis_mask = 0; + op_params.shrink_axis_mask = shrink_axis_mask; + + return op_params; +} + } // namespace strided_slice } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index ee2af5b460..13106456df 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -17,44 +17,12 @@ limitations under the License. #include <complex> #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { -template <typename T> -inline T* GetTensorData(TfLiteTensor* tensor); - -template <> -inline float* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.f : nullptr; -} - -template <> -inline uint8_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.uint8 : nullptr; -} - -template <> -inline int16_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i16 : nullptr; -} - -template <> -inline int32_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i32 : nullptr; -} - -template <> -inline int64_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i64 : nullptr; -} - -template <> -inline bool* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.b : nullptr; -} - template <> inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr @@ -62,39 +30,6 @@ inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) { : nullptr; } -template <typename T> -inline const T* GetTensorData(const TfLiteTensor* tensor); - -template <> -inline const float* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.f : nullptr; -} - -template <> -inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.uint8 : nullptr; -} - -template <> -inline const int16_t* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i16 : nullptr; -} - -template <> -inline const int32_t* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i32 : nullptr; -} - -template <> -inline const int64_t* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i64 : nullptr; -} - -template <> -inline const bool* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? 
tensor->data.b : nullptr; -} - template <> inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr @@ -102,56 +37,14 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) { : nullptr; } -inline int RemapDim(int max_dimensions, int d) { - return max_dimensions - d - 1; -} - -// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object -// even if the original tensors were not 4D. We should consider rewriting them -// to take a more generic 'shape' object. -inline Dims<4> GetTensorDims(const int data[], const int size) { - Dims<4> d; - for (int i = 0; i < 4; ++i) { - int src = size - i - 1; - if (src >= 0) { - d.sizes[i] = data[src]; - } else { - d.sizes[i] = 1; - } - } - d.strides[0] = 1; - for (int i = 1; i < 4; i++) { - d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; - } - return d; -} - inline Dims<4> GetTensorDims(std::vector<int32_t> data) { return GetTensorDims(data.data(), data.size()); } -inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { - if (tensor == nullptr) { - return Dims<4>(); - } - - auto* dims = tensor->dims; - return GetTensorDims(dims->data, dims->size); -} - inline RuntimeShape GetTensorShape(std::vector<int32_t> data) { return RuntimeShape(data.size(), data.data()); } -inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { - if (tensor == nullptr) { - return RuntimeShape(); - } - - auto* dims = tensor->dims; - return RuntimeShape(dims->size, dims->data); -} - // A list of tensors in a format that can be used by kernels like split and // concatenation. template <typename T> diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h new file mode 100644 index 0000000000..77e22a08b4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h @@ -0,0 +1,135 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +template <typename T> +inline T* GetTensorData(TfLiteTensor* tensor); + +template <> +inline float* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.f : nullptr; +} + +template <> +inline uint8_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.uint8 : nullptr; +} + +template <> +inline int16_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + +template <> +inline int32_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i32 : nullptr; +} + +template <> +inline int64_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? 
tensor->data.i64 : nullptr; +} + +template <> +inline bool* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.b : nullptr; +} + +template <typename T> +inline const T* GetTensorData(const TfLiteTensor* tensor); + +template <> +inline const float* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.f : nullptr; +} + +template <> +inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.uint8 : nullptr; +} + +template <> +inline const int16_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + +template <> +inline const int32_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i32 : nullptr; +} + +template <> +inline const int64_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i64 : nullptr; +} + +template <> +inline const bool* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.b : nullptr; +} + +inline int RemapDim(int max_dimensions, int d) { + return max_dimensions - d - 1; +} + +// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object +// even if the original tensors were not 4D. We should consider rewriting them +// to take a more generic 'shape' object. +inline Dims<4> GetTensorDims(const int data[], const int size) { + Dims<4> d; + for (int i = 0; i < 4; ++i) { + int src = size - i - 1; + if (src >= 0) { + d.sizes[i] = data[src]; + } else { + d.sizes[i] = 1; + } + } + d.strides[0] = 1; + for (int i = 1; i < 4; i++) { + d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; + } + return d; +} + +inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return Dims<4>(); + } + + auto* dims = tensor->dims; + return GetTensorDims(dims->data, dims->size); +} + +inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return RuntimeShape(); + } + + TfLiteIntArray* dims = tensor->dims; + const int dims_size = dims->size; + const int32_t* dims_data = dims->data; + return RuntimeShape(dims_size, dims_data); +} + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 1439bf8c37..b0fe5adf65 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #if defined(_MSC_VER) #define __restrict__ __restrict diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc index dad924fc28..6458af714b 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -14,7 +14,7 @@ limitations under the License. 
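// Worked example for the GetTensorDims overloads above (editorial sketch):
// a tensor with dims {2, 3} maps to Dims<4>{sizes = {3, 2, 1, 1},
// strides = {1, 3, 6, 6}}: innermost dimension first, padded with 1s, with
// each stride accumulated from the preceding sizes.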
==============================================================================*/ #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include <gmock/gmock.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/test_util.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index ed46cd984f..e9a5fd7a40 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -16,9 +16,10 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ #include <algorithm> +#include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index 5b3536de0c..e02d7df9ef 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc index 799c1528bd..334d2a2788 100644 --- a/tensorflow/contrib/lite/kernels/local_response_norm.cc +++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc index c71f3b4701..f770cb35d1 100644 --- a/tensorflow/contrib/lite/kernels/logical.cc +++ b/tensorflow/contrib/lite/kernels/logical.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
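// The include migration applied to kernel_util.h and l2norm.cc above recurs
// verbatim across the remaining kernel files in this patch; condensed here
// for reference:
//   -#include "tensorflow/contrib/lite/builtin_op_data.h"
//   -#include "tensorflow/contrib/lite/context.h"
//   +#include "tensorflow/contrib/lite/c/builtin_op_data.h"
//   +#include "tensorflow/contrib/lite/c/c_api_internal.h"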
==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc index 69523b02cc..9fa1c5f100 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection.cc +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -59,8 +59,8 @@ limitations under the License. #include <limits> #include <memory> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" #include <farmhash.h> diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 74dc3f25f9..aaa3ce966e 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc index 0308a3976a..7cb01465ee 100644 --- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc +++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc index 306f676619..66cf147d75 100644 --- a/tensorflow/contrib/lite/kernels/mfcc.cc +++ b/tensorflow/contrib/lite/kernels/mfcc.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/contrib/lite/kernels/internal/mfcc.h" #include "flatbuffers/flexbuffers.h" // flatbuffers -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h" #include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 92d8bc8b67..e0aac8a842 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc index 4124c05388..0ddd0644f5 100644 --- a/tensorflow/contrib/lite/kernels/neg.cc +++ b/tensorflow/contrib/lite/kernels/neg.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc index 9ff3dca932..910aed6f14 100644 --- a/tensorflow/contrib/lite/kernels/one_hot.cc +++ b/tensorflow/contrib/lite/kernels/one_hot.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc index cc326a7d51..4cb98fdd19 100644 --- a/tensorflow/contrib/lite/kernels/pack.cc +++ b/tensorflow/contrib/lite/kernels/pack.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 3bce05353d..0d939405f6 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h index 3cb55f19a9..42b6b45d3b 100644 --- a/tensorflow/contrib/lite/kernels/padding.h +++ b/tensorflow/contrib/lite/kernels/padding.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc index 29a5be0683..6451142391 100644 --- a/tensorflow/contrib/lite/kernels/pooling.cc +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc index d676de5b1d..1e96cc80b1 100644 --- a/tensorflow/contrib/lite/kernels/pow.cc +++ b/tensorflow/contrib/lite/kernels/pow.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc index ca83797936..d94d821e87 100644 --- a/tensorflow/contrib/lite/kernels/reduce.cc +++ b/tensorflow/contrib/lite/kernels/reduce.cc @@ -15,8 +15,8 @@ limitations under the License. #include <string.h> #include <limits> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index 0296152d68..61856ab9de 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -16,8 +16,9 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ #include <unordered_map> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc index 49ba0571e2..f41147b2d6 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <string.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index dafa3aebab..fb045d15f3 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc index 3cdb5db209..3959502d91 100644 --- a/tensorflow/contrib/lite/kernels/select.cc +++ b/tensorflow/contrib/lite/kernels/select.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc index dbcd2ef004..66d4c9e5c1 100644 --- a/tensorflow/contrib/lite/kernels/shape.cc +++ b/tensorflow/contrib/lite/kernels/shape.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc index c90a15b3a2..de80a4016e 100644 --- a/tensorflow/contrib/lite/kernels/skip_gram.cc +++ b/tensorflow/contrib/lite/kernels/skip_gram.cc @@ -33,8 +33,8 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" #include "tensorflow/contrib/lite/string_util.h" diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc index 55e16506df..ccfee41b9c 100644 --- a/tensorflow/contrib/lite/kernels/slice.cc +++ b/tensorflow/contrib/lite/kernels/slice.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include <string.h> #include <cmath> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index 8332ae32cf..3a10d2e60c 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc index 9238e879f8..64c56c017b 100644 --- a/tensorflow/contrib/lite/kernels/space_to_depth.cc +++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc index fec2a6f0d9..178568e07c 100644 --- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc index b144486041..719e2dc606 100644 --- a/tensorflow/contrib/lite/kernels/split.cc +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc index 09a5662fd9..080c51cd18 100644 --- a/tensorflow/contrib/lite/kernels/squeeze.cc +++ b/tensorflow/contrib/lite/kernels/squeeze.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index bed2117f9a..87ffcc4110 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -15,8 +15,8 @@ limitations under the License. #include <string.h> #include <cmath> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index 77a1f59689..1be0c83f17 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 6ba7959752..9903fd5c35 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc index 5181a8f89a..49421eb870 100644 --- a/tensorflow/contrib/lite/kernels/tile.cc +++ b/tensorflow/contrib/lite/kernels/tile.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc index 4f78c224e5..e73ca7b750 100644 --- a/tensorflow/contrib/lite/kernels/tile_test.cc +++ b/tensorflow/contrib/lite/kernels/tile_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc index 2dd760bbfe..6c38b6739e 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <algorithm> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc index 2abb89b617..16106fdafe 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc index 800b0563d7..95359962e0 100644 --- a/tensorflow/contrib/lite/kernels/transpose.cc +++ b/tensorflow/contrib/lite/kernels/transpose.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index a9baa5c698..6f2d98ede8 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index c678f14930..63817bd886 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 0180c2c498..744ee7c109 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -19,8 +19,8 @@ limitations under the License. 
#include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc index 4998f88b41..9ff06f8331 100644 --- a/tensorflow/contrib/lite/kernels/unpack.cc +++ b/tensorflow/contrib/lite/kernels/unpack.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h index 0294ec815c..2d4707f849 100644 --- a/tensorflow/contrib/lite/memory_planner.h +++ b/tensorflow/contrib/lite/memory_planner.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ #define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc index fa9a3cd1d8..92934d1fd1 100644 --- a/tensorflow/contrib/lite/mmap_allocation.cc +++ b/tensorflow/contrib/lite/mmap_allocation.cc @@ -20,7 +20,7 @@ limitations under the License. #include <unistd.h> #include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index aa410ab002..241865b3d8 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -20,8 +20,9 @@ limitations under the License. 
#include <sys/types.h> #include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/contrib/lite/model.h" #ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" @@ -42,41 +43,6 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) { const char* kEmptyTensorName = ""; -TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, - ErrorReporter* error_reporter) { - switch (tensor_type) { - case TensorType_FLOAT32: - *type = kTfLiteFloat32; - break; - case TensorType_INT16: - *type = kTfLiteInt16; - break; - case TensorType_INT32: - *type = kTfLiteInt32; - break; - case TensorType_UINT8: - *type = kTfLiteUInt8; - break; - case TensorType_INT64: - *type = kTfLiteInt64; - break; - case TensorType_STRING: - *type = kTfLiteString; - break; - case TensorType_BOOL: - *type = kTfLiteBool; - break; - case TensorType_COMPLEX64: - *type = kTfLiteComplex64; - break; - default: - error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", - EnumNameTensorType(tensor_type), tensor_type); - return kTfLiteError; - } - return kTfLiteOk; -} - #ifndef TFLITE_MCU // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. @@ -198,39 +164,10 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { auto opcodes = model_->operator_codes(); for (const OperatorCode* opcode : *opcodes) { const TfLiteRegistration* registration = nullptr; - auto builtin_code = opcode->builtin_code(); - int version = opcode->version(); - - if (builtin_code > BuiltinOperator_MAX || - builtin_code < BuiltinOperator_MIN) { - error_reporter_->Report( - "Op builtin_code out or range: %d. Are you using old TFLite binary " - "with newer model?", - builtin_code); - status = kTfLiteError; - } else if (builtin_code != BuiltinOperator_CUSTOM) { - registration = op_resolver_.FindOp(builtin_code, version); - if (registration == nullptr) { - error_reporter_->Report( - "Didn't find op for builtin opcode '%s' version '%d'\n", - EnumNameBuiltinOperator(builtin_code), version); - status = kTfLiteError; - } - } else if (!opcode->custom_code()) { - error_reporter_->Report( - "Operator with CUSTOM builtin_code has no custom_code.\n"); - status = kTfLiteError; - } else { - const char* name = opcode->custom_code()->c_str(); - registration = op_resolver_.FindOp(name, version); - flatbuffer_op_index_to_registration_types_.push_back( - BuiltinOperator_CUSTOM); - if (registration == nullptr) { - error_reporter_->Report( - "Didn't find custom op for name '%s' with version %d\n", name, - version); - status = kTfLiteError; - } + status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_, + ®istration); + if (status != kTfLiteOk) { + return status; } flatbuffer_op_index_to_registration_.push_back(registration); } @@ -247,565 +184,6 @@ std::vector<int> FlatBufferIntArrayToVector(T* flat_array) { return ret; } -// Copies the contents from the flatbuffer int vector `flatbuffer` into the -// int array `buffer`. `flat_vector` and `buffer` represent the same -// configuration operation for a given operation. 
-void FlatBufferIntVectorToArray(int max_size_of_buffer, - const flatbuffers::Vector<int32_t>* flat_vector, - int* buffer, ErrorReporter* error_reporter) { - if (!flat_vector) { - error_reporter->Report("Input array not provided for operation.\n"); - } else { - int num_dimensions = flat_vector->Length(); - if (num_dimensions > max_size_of_buffer / sizeof(int)) { - error_reporter->Report( - "Found too many dimensions in the operation's input array.\n"); - } else { - for (int i = 0; i < num_dimensions; ++i) { - buffer[i] = flat_vector->Get(i); - } - } - } -} - -// Allocate a structure using C malloc, but make sure the structure is a -// POD structure that doesn't require constructors to run. The reason we do -// this, is that Interpreter's C extension part will take ownership and wants -// to use malloc() and free(). -template <class T> -T* MallocPOD() { - static_assert(std::is_pod<T>::value, "Builtin data structure must be POD."); - return static_cast<T*>(malloc(sizeof(T))); -} - -// Parse the appropriate data out of the op. -// -// This handles builtin data explicitly as there are flatbuffer schemas. -// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which -// need to be released by calling `free`.` -// If it returns kTfLiteError, `builtin_data` will be `nullptr`. -TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, - ErrorReporter* error_reporter, void** builtin_data) { - auto parse_padding = [](Padding padding) { - switch (padding) { - case Padding_SAME: - return kTfLitePaddingSame; - case Padding_VALID: - return kTfLitePaddingValid; - } - return kTfLitePaddingUnknown; - }; - auto parse_activation = [](ActivationFunctionType activation) { - switch (activation) { - case ActivationFunctionType_NONE: - return kTfLiteActNone; - case ActivationFunctionType_RELU: - return kTfLiteActRelu; - case ActivationFunctionType_RELU_N1_TO_1: - return kTfLiteActRelu1; - case ActivationFunctionType_RELU6: - return kTfLiteActRelu6; - case ActivationFunctionType_TANH: - return kTfLiteActTanh; - case ActivationFunctionType_SIGN_BIT: - return kTfLiteActSignBit; - } - return kTfLiteActNone; - }; - auto parseLSHProjectionType = [](LSHProjectionType type) { - switch (type) { - case LSHProjectionType_SPARSE: - return kTfLiteLshProjectionSparse; - case LSHProjectionType_DENSE: - return kTfLiteLshProjectionDense; - default: - return kTfLiteLshProjectionUnknown; - } - }; - auto parseCombinerType = [](CombinerType type) { - switch (type) { - case CombinerType_MEAN: - return kTfLiteCombinerTypeMean; - case CombinerType_SQRTN: - return kTfLiteCombinerTypeSqrtn; - case CombinerType_SUM: - default: - return kTfLiteCombinerTypeSum; - } - }; - - *builtin_data = nullptr; - switch (op_type) { - case BuiltinOperator_CONV_2D: { - TfLiteConvParams* params = MallocPOD<TfLiteConvParams>(); - if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { - params->padding = parse_padding(conv_params->padding()); - params->stride_width = conv_params->stride_w(); - params->stride_height = conv_params->stride_h(); - params->activation = - parse_activation(conv_params->fused_activation_function()); - - params->dilation_width_factor = conv_params->dilation_w_factor(); - params->dilation_height_factor = conv_params->dilation_h_factor(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_CAST: { - TfLiteCastParams* params = MallocPOD<TfLiteCastParams>(); - if (auto* schema_params = op->builtin_options_as_CastOptions()) { - auto in_status = - 
ConvertTensorType(schema_params->in_data_type(), - ¶ms->in_data_type, error_reporter); - auto out_status = - ConvertTensorType(schema_params->out_data_type(), - ¶ms->out_data_type, error_reporter); - if (in_status != kTfLiteOk || out_status != kTfLiteOk) { - free(params); - return kTfLiteError; - } - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_LSH_PROJECTION: { - TfLiteLSHProjectionParams* params = - MallocPOD<TfLiteLSHProjectionParams>(); - if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { - params->type = parseLSHProjectionType(lshParams->type()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_AVERAGE_POOL_2D: - case BuiltinOperator_MAX_POOL_2D: - case BuiltinOperator_L2_POOL_2D: { - TfLitePoolParams* params = MallocPOD<TfLitePoolParams>(); - if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { - params->padding = parse_padding(pool_params->padding()); - params->stride_width = pool_params->stride_w(); - params->stride_height = pool_params->stride_h(); - params->filter_width = pool_params->filter_width(); - params->filter_height = pool_params->filter_height(); - params->activation = - parse_activation(pool_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_DEPTHWISE_CONV_2D: { - TfLiteDepthwiseConvParams* params = - MallocPOD<TfLiteDepthwiseConvParams>(); - if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { - params->padding = parse_padding(conv_params->padding()); - params->stride_width = conv_params->stride_w(); - params->stride_height = conv_params->stride_h(); - params->depth_multiplier = conv_params->depth_multiplier(); - params->activation = - parse_activation(conv_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SVDF: { - TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>(); - if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { - params->rank = svdf_params->rank(); - params->activation = - parse_activation(svdf_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { - TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>(); - if (auto* sequence_rnn_params = - op->builtin_options_as_SequenceRNNOptions()) { - params->activation = - parse_activation(sequence_rnn_params->fused_activation_function()); - params->time_major = sequence_rnn_params->time_major(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_RNN: { - TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>(); - if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { - params->activation = - parse_activation(rnn_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { - TfLiteEmbeddingLookupSparseParams* params = - MallocPOD<TfLiteEmbeddingLookupSparseParams>(); - if (auto* embedding_params = - op->builtin_options_as_EmbeddingLookupSparseOptions()) { - params->combiner = parseCombinerType(embedding_params->combiner()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_FULLY_CONNECTED: { - TfLiteFullyConnectedParams* params = - 
MallocPOD<TfLiteFullyConnectedParams>(); - if (auto* fully_connected_params = - op->builtin_options_as_FullyConnectedOptions()) { - params->activation = parse_activation( - fully_connected_params->fused_activation_function()); - switch (fully_connected_params->weights_format()) { - case FullyConnectedOptionsWeightsFormat_DEFAULT: - params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; - break; - case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: - params->weights_format = - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; - break; - default: - error_reporter->Report("Unhandled fully-connected weights format."); - return kTfLiteError; - } - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_HASHTABLE_LOOKUP: - // no-op. - break; - case BuiltinOperator_SOFTMAX: { - TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>(); - if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { - params->beta = softmax_params->beta(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_CONCATENATION: { - TfLiteConcatenationParams* params = - MallocPOD<TfLiteConcatenationParams>(); - if (auto* concatenation_params = - op->builtin_options_as_ConcatenationOptions()) { - params->activation = - parse_activation(concatenation_params->fused_activation_function()); - params->axis = concatenation_params->axis(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_MUL: { - auto* params = MallocPOD<TfLiteMulParams>(); - if (auto* schema_params = op->builtin_options_as_MulOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_ADD: { - auto* params = MallocPOD<TfLiteAddParams>(); - if (auto* schema_params = op->builtin_options_as_AddOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_DIV: { - auto* params = MallocPOD<TfLiteDivParams>(); - if (auto* schema_params = op->builtin_options_as_DivOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SUB: { - auto* params = MallocPOD<TfLiteSubParams>(); - if (auto* schema_params = op->builtin_options_as_SubOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_L2_NORMALIZATION: { - auto* params = MallocPOD<TfLiteL2NormParams>(); - if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { - auto* params = MallocPOD<TfLiteLocalResponseNormParams>(); - if (auto* schema_params = - op->builtin_options_as_LocalResponseNormalizationOptions()) { - params->radius = schema_params->radius(); - params->bias = schema_params->bias(); - params->alpha = schema_params->alpha(); - params->beta = schema_params->beta(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: - case 
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: - case BuiltinOperator_LSTM: { - TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>(); - if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { - params->activation = - parse_activation(lstm_params->fused_activation_function()); - params->cell_clip = lstm_params->cell_clip(); - params->proj_clip = lstm_params->proj_clip(); - switch (lstm_params->kernel_type()) { - case LSTMKernelType_FULL: - params->kernel_type = kTfLiteLSTMFullKernel; - break; - case LSTMKernelType_BASIC: - params->kernel_type = kTfLiteLSTMBasicKernel; - break; - } - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_RESIZE_BILINEAR: { - auto* params = MallocPOD<TfLiteResizeBilinearParams>(); - if (auto* schema_params = - op->builtin_options_as_ResizeBilinearOptions()) { - params->align_corners = schema_params->align_corners(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_RESHAPE: { - auto* params = MallocPOD<TfLiteReshapeParams>(); - if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { - auto* new_shape = schema_params->new_shape(); - FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, - params->shape, error_reporter); - params->num_dimensions = new_shape->Length(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SKIP_GRAM: { - TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>(); - if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { - params->ngram_size = skip_gram_params->ngram_size(); - params->max_skip_size = skip_gram_params->max_skip_size(); - params->include_all_ngrams = skip_gram_params->include_all_ngrams(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SPACE_TO_DEPTH: { - auto* params = MallocPOD<TfLiteSpaceToDepthParams>(); - if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { - params->block_size = schema_params->block_size(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_GATHER: { - TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>(); - params->axis = 0; - if (auto* gather_params = op->builtin_options_as_GatherOptions()) { - params->axis = gather_params->axis(); - } - - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_MEAN: - case BuiltinOperator_REDUCE_MAX: - case BuiltinOperator_REDUCE_MIN: - case BuiltinOperator_REDUCE_PROD: - case BuiltinOperator_SUM: - case BuiltinOperator_REDUCE_ANY: { - auto* params = MallocPOD<TfLiteReducerParams>(); - if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { - params->keep_dims = schema_params->keep_dims(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SPLIT: { - auto* params = MallocPOD<TfLiteSplitParams>(); - if (auto* schema_params = op->builtin_options_as_SplitOptions()) { - params->num_splits = schema_params->num_splits(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SQUEEZE: { - auto* params = MallocPOD<TfLiteSqueezeParams>(); - if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { - const auto& squeeze_dims = schema_params->squeeze_dims(); - FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, - params->squeeze_dims, error_reporter); - params->num_squeeze_dims = squeeze_dims->Length(); - } - *builtin_data = 
reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_STRIDED_SLICE: { - auto* params = MallocPOD<TfLiteStridedSliceParams>(); - if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) { - params->begin_mask = schema_params->begin_mask(); - params->end_mask = schema_params->end_mask(); - params->ellipsis_mask = schema_params->ellipsis_mask(); - params->new_axis_mask = schema_params->new_axis_mask(); - params->shrink_axis_mask = schema_params->shrink_axis_mask(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_ARG_MAX: { - auto* params = MallocPOD<TfLiteArgMaxParams>(); - if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { - ConvertTensorType(schema_params->output_type(), ¶ms->output_type, - error_reporter); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_ARG_MIN: { - auto* params = MallocPOD<TfLiteArgMinParams>(); - if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) { - ConvertTensorType(schema_params->output_type(), ¶ms->output_type, - error_reporter); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_TRANSPOSE_CONV: { - TfLiteTransposeConvParams* params = - MallocPOD<TfLiteTransposeConvParams>(); - if (auto* transpose_conv_params = - op->builtin_options_as_TransposeConvOptions()) { - params->padding = parse_padding(transpose_conv_params->padding()); - params->stride_width = transpose_conv_params->stride_w(); - params->stride_height = transpose_conv_params->stride_h(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SPARSE_TO_DENSE: { - TfLiteSparseToDenseParams* params = - MallocPOD<TfLiteSparseToDenseParams>(); - if (auto* sparse_to_dense_params = - op->builtin_options_as_SparseToDenseOptions()) { - params->validate_indices = sparse_to_dense_params->validate_indices(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SHAPE: { - auto* params = MallocPOD<TfLiteShapeParams>(); - if (auto* schema_params = op->builtin_options_as_ShapeOptions()) { - ConvertTensorType(schema_params->out_type(), ¶ms->out_type, - error_reporter); - } - *builtin_data = static_cast<void*>(params); - break; - } - case BuiltinOperator_PACK: { - TfLitePackParams* params = MallocPOD<TfLitePackParams>(); - if (auto* pack_params = op->builtin_options_as_PackOptions()) { - params->values_count = pack_params->values_count(); - params->axis = pack_params->axis(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_DELEGATE: { - // TODO(ycling): Revisit when supporting saving delegated models. 
- error_reporter->Report("DELEGATE op shouldn't exist in model."); - return kTfLiteError; - } - case BuiltinOperator_FAKE_QUANT: { - auto* params = MallocPOD<TfLiteFakeQuantParams>(); - if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) { - params->min = schema_params->min(); - params->max = schema_params->max(); - params->num_bits = schema_params->num_bits(); - params->narrow_range = schema_params->narrow_range(); - } - *builtin_data = static_cast<void*>(params); - break; - } - case BuiltinOperator_ONE_HOT: { - auto* params = MallocPOD<TfLiteOneHotParams>(); - if (auto* schema_params = op->builtin_options_as_OneHotOptions()) { - params->axis = schema_params->axis(); - } - *builtin_data = static_cast<void*>(params); - break; - } - case BuiltinOperator_UNPACK: { - TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>(); - if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) { - params->num = unpack_params->num(); - params->axis = unpack_params->axis(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - - // Below are the ops with no builtin_data structure. - case BuiltinOperator_BATCH_TO_SPACE_ND: - // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are - // ok for now, since there is no call implementation either. - case BuiltinOperator_CALL: - case BuiltinOperator_CONCAT_EMBEDDINGS: - case BuiltinOperator_CUSTOM: - case BuiltinOperator_DEQUANTIZE: - case BuiltinOperator_EMBEDDING_LOOKUP: - case BuiltinOperator_EQUAL: - case BuiltinOperator_EXP: - case BuiltinOperator_EXPAND_DIMS: - case BuiltinOperator_FLOOR: - case BuiltinOperator_GREATER: - case BuiltinOperator_GREATER_EQUAL: - case BuiltinOperator_LESS: - case BuiltinOperator_LESS_EQUAL: - case BuiltinOperator_LOG: - case BuiltinOperator_LOGISTIC: - case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_MAXIMUM: - case BuiltinOperator_MINIMUM: - case BuiltinOperator_NEG: - case BuiltinOperator_NOT_EQUAL: - case BuiltinOperator_PAD: - case BuiltinOperator_PADV2: - case BuiltinOperator_PRELU: - case BuiltinOperator_RELU: - case BuiltinOperator_RELU6: - case BuiltinOperator_RELU_N1_TO_1: - case BuiltinOperator_RSQRT: - case BuiltinOperator_SELECT: - case BuiltinOperator_SIN: - case BuiltinOperator_SLICE: - case BuiltinOperator_SPACE_TO_BATCH_ND: - case BuiltinOperator_SQRT: - case BuiltinOperator_TANH: - case BuiltinOperator_TILE: - case BuiltinOperator_TOPK_V2: - case BuiltinOperator_TRANSPOSE: - case BuiltinOperator_POW: - case BuiltinOperator_LOGICAL_OR: - case BuiltinOperator_LOGICAL_AND: - case BuiltinOperator_LOGICAL_NOT: - case BuiltinOperator_FLOOR_DIV: - break; - } - return kTfLiteOk; -} - } // namespace TfLiteStatus InterpreterBuilder::ParseNodes( diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 8bc9ecd7ce..6abdfcd079 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -35,9 +35,10 @@ limitations under the License.
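Taken together, the hunks above close out one large switch: each builtin operator either gets a heap-allocated options struct (later handed to the kernel as builtin_data) or, for the trailing run of cases, deliberately gets none. A minimal Python sketch of the same dispatch pattern follows; the registry and parser names are hypothetical, not part of TensorFlow Lite.

```python
# A sketch of the builtin-options dispatch, assuming flatbuffer options are
# already available as dicts; OPTION_PARSERS and the parser names are
# hypothetical, not TFLite APIs.

def parse_gather(options):
    # GATHER defaults axis to 0 when the model omits GatherOptions.
    return {"axis": options.get("axis", 0)}

def parse_reducer(options):
    # MEAN/SUM/REDUCE_* all share the same reducer params.
    return {"keep_dims": options.get("keep_dims", False)}

OPTION_PARSERS = {
    "GATHER": parse_gather,
    "MEAN": parse_reducer,
    "SUM": parse_reducer,
    # Ops with no builtin_data structure map to no parser at all.
    "FLOOR": None,
    "RELU": None,
}

def parse_builtin_options(op_name, options):
    parser = OPTION_PARSERS.get(op_name)
    return parser(options) if parser else None

print(parse_builtin_options("GATHER", {"axis": 1}))  # {'axis': 1}
print(parse_builtin_options("RELU", {}))             # None
```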
#define TENSORFLOW_CONTRIB_LITE_MODEL_H_ #include <memory> -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/op_resolver.h" #include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index df4f60d4ad..ec7d46af7c 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/model.h" #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/testing/util.h" // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc index f6e435e982..8ee63d2a02 100644 --- a/tensorflow/contrib/lite/op_resolver.cc +++ b/tensorflow/contrib/lite/mutable_op_resolver.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" namespace tflite { diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h new file mode 100644 index 0000000000..c319041e9b --- /dev/null +++ b/tensorflow/contrib/lite/mutable_op_resolver.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_ +#define TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_ + +#include <unordered_map> +#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/contrib/lite/util.h" + +namespace tflite { + +// Some versions of gcc don't support partial specialization in class scope, +// so these are defined at namespace scope.
+namespace op_resolver_hasher { +template <typename V> +struct ValueHasher { + size_t operator()(const V& v) const { return std::hash<V>()(v); } +}; + +template <> +struct ValueHasher<tflite::BuiltinOperator> { + size_t operator()(const tflite::BuiltinOperator& v) const { + return std::hash<int>()(static_cast<int>(v)); + } +}; + +template <typename T> +struct OperatorKeyHasher { + size_t operator()(const T& x) const { + size_t a = ValueHasher<typename T::first_type>()(x.first); + size_t b = ValueHasher<typename T::second_type>()(x.second); + return CombineHashes({a, b}); + } +}; +} // namespace op_resolver_hasher + +// An OpResolver that is mutable, also used as the op in gen_op_registration. +// A typical usage: +// MutableOpResolver resolver; +// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD()); +// resolver.AddCustom("CustomOp", Register_CUSTOM_OP()); +// InterpreterBuilder(model, resolver)(&interpreter); +class MutableOpResolver : public OpResolver { + public: + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override; + const TfLiteRegistration* FindOp(const char* op, int version) const override; + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + + private: + typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey; + typedef std::pair<std::string, int> CustomOperatorKey; + + std::unordered_map<BuiltinOperatorKey, TfLiteRegistration, + op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> > + builtins_; + std::unordered_map<CustomOperatorKey, TfLiteRegistration, + op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> > + custom_ops_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc index 10b7e31972..db690eaab9 100644 --- a/tensorflow/contrib/lite/op_resolver_test.cc +++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" #include <gtest/gtest.h> #include "tensorflow/contrib/lite/testing/util.h" diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 484842713d..817486e898 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -18,8 +18,8 @@ limitations under the License. #include <sys/mman.h> #include <sys/stat.h> #include <sys/types.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h index 2bdb2cc5c8..22359d557e 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -16,8 +16,8 @@ limitations under the License. 
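The new mutable_op_resolver.h keys every registration by an (operator, version) pair, which is why it has to define OperatorKeyHasher: std::unordered_map has no default hash for std::pair, and CombineHashes merges the two element hashes. Python tuples hash natively, so the same lookup structure collapses to a few lines. This is an illustrative sketch only; in particular, registering each version in [min_version, max_version] is my reading of the AddBuiltin signature, not something the header states.

```python
# An illustrative Python model of MutableOpResolver's (op, version)-keyed maps.

class MutableOpResolver(object):
    def __init__(self):
        self._builtins = {}    # (builtin_op, version) -> registration
        self._custom_ops = {}  # (custom_name, version) -> registration

    def add_builtin(self, op, registration, min_version=1, max_version=1):
        for version in range(min_version, max_version + 1):
            self._builtins[(op, version)] = registration

    def add_custom(self, name, registration, min_version=1, max_version=1):
        for version in range(min_version, max_version + 1):
            self._custom_ops[(name, version)] = registration

    def find_builtin(self, op, version):
        # Tuples hash out of the box, so no custom hasher is needed here.
        return self._builtins.get((op, version))

    def find_custom(self, name, version):
        return self._custom_ops.get((name, version))

resolver = MutableOpResolver()
resolver.add_builtin("ADD", "Register_ADD", min_version=1, max_version=2)
assert resolver.find_builtin("ADD", 2) == "Register_ADD"
assert resolver.find_builtin("ADD", 3) is None
```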
#define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ #include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/interpreter.h" class ANeuralNetworksModel; diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h index 9d7e3f2085..e93134cbde 100644 --- a/tensorflow/contrib/lite/op_resolver.h +++ b/tensorflow/contrib/lite/op_resolver.h @@ -12,83 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// Compatibility shim for moved header location. #ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ #define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ -#include <unordered_map> -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/util.h" - -namespace tflite { - -// Abstract interface that returns TfLiteRegistrations given op codes or custom -// op names. This is the mechanism that ops being referenced in the flatbuffer -// model are mapped to executable function pointers (TfLiteRegistrations). -class OpResolver { - public: - // Finds the op registration for a builtin operator by enum code. - virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, - int version) const = 0; - // Finds the op registration of a custom operator by op name. - virtual const TfLiteRegistration* FindOp(const char* op, - int version) const = 0; - virtual ~OpResolver() {} -}; - -// Some versions of gcc doesn't support partial specialization in class scope, -// so these are defined in a namescope. -namespace op_resolver_hasher { -template <typename V> -struct ValueHasher { - size_t operator()(const V& v) const { return std::hash<V>()(v); } -}; - -template <> -struct ValueHasher<tflite::BuiltinOperator> { - size_t operator()(const tflite::BuiltinOperator& v) const { - return std::hash<int>()(static_cast<int>(v)); - } -}; - -template <typename T> -struct OperatorKeyHasher { - size_t operator()(const T& x) const { - size_t a = ValueHasher<typename T::first_type>()(x.first); - size_t b = ValueHasher<typename T::second_type>()(x.second); - return CombineHashes({a, b}); - } -}; -} // namespace op_resolver_hasher - -// An OpResolver that is mutable, also used as the op in gen_op_registration. 
-// A typical usage: -// MutableOpResolver resolver; -// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD()); -// resolver.AddCustom("CustomOp", Register_CUSTOM_OP()); -// InterpreterBuilder(model, resolver)(&interpreter); -class MutableOpResolver : public OpResolver { - public: - const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, - int version) const override; - const TfLiteRegistration* FindOp(const char* op, int version) const override; - void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, - int min_version = 1, int max_version = 1); - void AddCustom(const char* name, TfLiteRegistration* registration, - int min_version = 1, int max_version = 1); - - private: - typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey; - typedef std::pair<std::string, int> CustomOperatorKey; - - std::unordered_map<BuiltinOperatorKey, TfLiteRegistration, - op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> > - builtins_; - std::unordered_map<CustomOperatorKey, TfLiteRegistration, - op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> > - custom_ops_; -}; - -} // namespace tflite +#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" #endif // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h index f738315cf2..45d0d8735e 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -17,7 +17,7 @@ limitations under the License. #include <list> #include <memory> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/stderr_reporter.cc index 646913c026..e29a6345fd 100644 --- a/tensorflow/contrib/lite/error_reporter.cc +++ b/tensorflow/contrib/lite/stderr_reporter.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/stderr_reporter.h" #include <cstdarg> #include <cstdio> @@ -22,26 +22,6 @@ limitations under the License. namespace tflite { -ErrorReporter::~ErrorReporter() {} - -int ErrorReporter::Report(const char* format, ...) { - va_list args; - va_start(args, format); - int code = Report(format, args); - va_end(args); - return code; -} - -// TODO(aselle): Make the name of ReportError on context the same, so -// we can use the ensure functions w/o a context and w/ a reporter. -int ErrorReporter::ReportError(void*, const char* format, ...) { - va_list args; - va_start(args, format); - int code = Report(format, args); - va_end(args); - return code; -} - int StderrReporter::Report(const char* format, va_list args) { #ifdef __ANDROID__ // On Android stderr is not captured for applications, only for code run from diff --git a/tensorflow/contrib/lite/stderr_reporter.h b/tensorflow/contrib/lite/stderr_reporter.h new file mode 100644 index 0000000000..c6f4ffbdff --- /dev/null +++ b/tensorflow/contrib/lite/stderr_reporter.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_ +#define TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_ + +#include <cstdarg> +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" + +namespace tflite { + +// An error reporter that simply writes the message to stderr. +struct StderrReporter : public ErrorReporter { + int Report(const char* format, va_list args) override; +}; + +// Return the default error reporter (output to stderr). +ErrorReporter* DefaultErrorReporter(); + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc index a316a40b62..b991e999b6 100644 --- a/tensorflow/contrib/lite/string_util.cc +++ b/tensorflow/contrib/lite/string_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h index 57f129bf5e..d24627b509 100644 --- a/tensorflow/contrib/lite/string_util.h +++ b/tensorflow/contrib/lite/string_util.h @@ -42,7 +42,7 @@ limitations under the License. #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/string.h" namespace tflite { diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc index d53fec7512..a583a9184b 100644 --- a/tensorflow/contrib/lite/string_util_test.cc +++ b/tensorflow/contrib/lite/string_util_test.cc @@ -15,7 +15,7 @@ limitations under the License.
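The stderr_reporter.h split above leaves the abstract ErrorReporter in core/api and keeps only the stderr sink in lite/. The shape of the API is a printf-style entry point on the base class that funnels into one overridable method; a rough Python analogue of that pattern (illustrative only, not a TensorFlow API) looks like this.

```python
# A sketch of the ErrorReporter/StderrReporter pattern; the class names
# mirror the C++ ones above, but this Python code is illustrative only.
import sys

class ErrorReporter(object):
    def report(self, fmt, *args):
        # Subclasses pick the sink, mirroring the pure-virtual
        # Report(const char* format, va_list args).
        raise NotImplementedError

class StderrReporter(ErrorReporter):
    def report(self, fmt, *args):
        message = fmt % args if args else fmt
        sys.stderr.write(message + "\n")
        return len(message)

_DEFAULT = StderrReporter()

def default_error_reporter():
    # Analogous to tflite::DefaultErrorReporter().
    return _DEFAULT

default_error_reporter().report("Node %d failed to invoke", 7)
```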
#include "tensorflow/contrib/lite/string_util.h" #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/testing/util.h" diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 0b3a97d4f5..aad1ecaeb6 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -173,7 +173,6 @@ tf_cc_test( srcs = ["tflite_driver_test.cc"], data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], tags = [ - "no_oss", # b/112769036 "tflite_not_portable_android", "tflite_not_portable_ios", ], @@ -215,6 +214,7 @@ cc_library( deps = [ "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/core/api", ], ) diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h index 8aa639157b..925791d390 100644 --- a/tensorflow/contrib/lite/testing/util.h +++ b/tensorflow/contrib/lite/testing/util.h @@ -17,7 +17,7 @@ limitations under the License. #include <cstdio> -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/string.h" namespace tflite { diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index a75553db84..bea90f1ce8 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -372,6 +372,7 @@ cc_library( ":toco_graphviz_dump_options", ":toco_port", ":types_proto_cc", + "//tensorflow/contrib/lite/kernels/internal:types", "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index c25be078ff..f103bb94ae 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1314,12 +1314,16 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { // Compute output shape for (int axis = 0; axis < num_input_axes; ++axis) { + const auto strided_slice_params = + tflite::strided_slice::BuildStridedSliceParams( + op->begin_mask, op->end_mask, op->shrink_axis_mask, + op->start_indices, op->stop_indices, op->strides); int start_index = tflite::strided_slice::StartForAxis( - op->begin_mask, op->start_indices, op->strides, - input_array.shape().dims().data(), axis); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis); int stop_index = tflite::strided_slice::StopForAxis( - op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides, - input_array.shape().dims().data(), axis, start_index); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis, + start_index); + int dim_size = ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 9d8bd4fc39..8853ed87e6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator 
const& op, Array const& input_array, Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>(); std::vector<int> src_coord(num_input_axes); std::vector<int> stop_for_axis(num_input_axes); + const auto strided_slice_params = + tflite::strided_slice::BuildStridedSliceParams( + op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices, + op.stop_indices, op.strides); + for (int axis = 0; axis < num_input_axes; axis++) { - int start = tflite::strided_slice::StartForAxis( - op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(), - axis); - src_coord[axis] = start; + int start_index = tflite::strided_slice::StartForAxis( + strided_slice_params, ToRuntimeShape(input_array.shape()), axis); + src_coord[axis] = start_index; stop_for_axis[axis] = tflite::strided_slice::StopForAxis( - op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides, - input_shape.dims().data(), axis, start); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis, + start_index); } // In order to handle any number (N) of dimensions, we copy elements one by @@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) { // Reset axis and set carry src_coord[axis] = tflite::strided_slice::StartForAxis( - op.begin_mask, op.start_indices, op.strides, - input_shape.dims().data(), axis); + strided_slice_params, ToRuntimeShape(input_shape), axis); carry = true; } else { carry = false; diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index bdeb203024..5f4b8cb66a 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -28,6 +28,7 @@ limitations under the License. #if TOCO_SUPPORT_PORTABLE_PROTOS #include "third_party/protobuf/include/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS +#include "tensorflow/contrib/lite/kernels/internal/types.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" @@ -139,6 +140,10 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1); // - For the remaining indices [0..i0), d0[i0] == 1. bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1); +inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) { + return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data()); +} + bool IsArrayFullyConnectedWeights(const Model& model, const string& name); // If there is a wildcard dimension (-1), this may return a negative value. diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD index a66812fe87..98e2835b2e 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD @@ -54,6 +54,7 @@ tf_cc_test( linkopts = common_linkopts, linkstatic = 1, tags = [ + "no_oss", # b/114307765 "tflite_not_portable_android", "tflite_not_portable_ios", ], diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/contrib/lite/tools/make/Makefile index e30cc1d70e..59bdb10811 100644 --- a/tensorflow/contrib/lite/tools/make/Makefile +++ b/tensorflow/contrib/lite/tools/make/Makefile @@ -24,6 +24,21 @@ HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32 TARGET := $(HOST_OS) TARGET_ARCH := $(HOST_ARCH) +INCLUDES := \ +-I. 
\ +-I$(MAKEFILE_DIR)/../../../../../ \ +-I$(MAKEFILE_DIR)/../../../../../../ \ +-I$(MAKEFILE_DIR)/downloads/ \ +-I$(MAKEFILE_DIR)/downloads/eigen \ +-I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/neon_2_sse \ +-I$(MAKEFILE_DIR)/downloads/farmhash/src \ +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ +-I$(OBJDIR) +# This is at the end so any globally-installed frameworks like protobuf don't +# override local versions in the source tree. +INCLUDES += -I/usr/local/include + # These are the default libraries needed, but they can be added to or # overridden by the platform-specific settings in target makefiles. LIBS := \ @@ -44,55 +59,17 @@ ARFLAGS := -r TARGET_TOOLCHAIN_PREFIX := CC_PREFIX := -# These target-specific makefiles should modify or replace options like -# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic -# based on platforms or architectures should happen within these files, to -# keep this main makefile focused on the sources and dependencies. -include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) - -# Where compiled objects are stored. -GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/ -OBJDIR := $(GENDIR)obj/ -BINDIR := $(GENDIR)bin/ -LIBDIR := $(GENDIR)lib/ - -INCLUDES := \ --I. \ --I$(MAKEFILE_DIR)/../../../../../ \ --I$(MAKEFILE_DIR)/../../../../../../ \ --I$(MAKEFILE_DIR)/downloads/ \ --I$(MAKEFILE_DIR)/downloads/eigen \ --I$(MAKEFILE_DIR)/downloads/gemmlowp \ --I$(MAKEFILE_DIR)/downloads/neon_2_sse \ --I$(MAKEFILE_DIR)/downloads/farmhash/src \ --I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ --I$(OBJDIR) -# This is at the end so any globally-installed frameworks like protobuf don't -# override local versions in the source tree. -INCLUDES += -I/usr/local/include - -CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ -CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc -AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar - # This library is the main target for this makefile. It will contain a minimal # runtime that can be linked in to other programs. LIB_NAME := libtensorflow-lite.a -LIB_PATH := $(LIBDIR)$(LIB_NAME) - -# A small example program that shows how to link against the library. -MINIMAL_PATH := $(BINDIR)minimal # Benchmark static library and binary BENCHMARK_LIB_NAME := benchmark-lib.a BENCHMARK_BINARY_NAME := benchmark_model -BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME) -BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME) +# A small example program that shows how to link against the library. MINIMAL_SRCS := \ tensorflow/contrib/lite/examples/minimal/minimal.cc -MINIMAL_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) # What sources we want to compile, must be kept in sync with the main Bazel # build files. @@ -105,7 +82,9 @@ PROFILE_SUMMARIZER_SRCS := \ CORE_CC_ALL_SRCS := \ $(wildcard tensorflow/contrib/lite/*.cc) \ -$(wildcard tensorflow/contrib/lite/*.c) +$(wildcard tensorflow/contrib/lite/*.c) \ +$(wildcard tensorflow/contrib/lite/c/*.c) \ +$(wildcard tensorflow/contrib/lite/core/api/*.cc) ifneq ($(BUILD_TYPE),micro) CORE_CC_ALL_SRCS += \ $(wildcard tensorflow/contrib/lite/kernels/*.cc) \ @@ -136,10 +115,6 @@ tensorflow/contrib/lite/nnapi_delegate.cc endif # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) -# File names of the intermediate files target compilation generates. 
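Looking back at the strided-slice hunks just before these Makefile changes (propagate_fixed_sizes.cc and resolve_constant_strided_slice.cc): the refactor bundles the begin/end/shrink-axis masks, indices, and strides into one params struct via BuildStridedSliceParams, and StartForAxis/StopForAxis then resolve each axis's range against a RuntimeShape. Below is a simplified Python model of that per-axis resolution, assuming positive strides and ignoring the ellipsis and new-axis masks (the real kernels handle both, plus negative strides).

```python
# Simplified start/stop resolution for strided slice; positive strides only.
import math

def start_for_axis(params, shape, axis):
    if params["begin_mask"] & (1 << axis):
        start = 0  # begin_mask bit set: slice from the front
    else:
        start = params["start_indices"][axis]
    if start < 0:
        start += shape[axis]  # negative indices count from the end
    return min(max(start, 0), shape[axis])

def stop_for_axis(params, shape, axis, start):
    if params["shrink_axis_mask"] & (1 << axis):
        return start + 1  # shrink_axis: keep exactly one element
    if params["end_mask"] & (1 << axis):
        stop = shape[axis]  # end_mask bit set: slice to the back
    else:
        stop = params["stop_indices"][axis]
    if stop < 0:
        stop += shape[axis]
    return min(max(stop, 0), shape[axis])

params = {"begin_mask": 0b01, "end_mask": 0, "shrink_axis_mask": 0,
          "start_indices": [5, 1], "stop_indices": [9, 3], "strides": [1, 1]}
shape = [4, 6]
for axis in range(len(shape)):
    start = start_for_axis(params, shape, axis)
    stop = stop_for_axis(params, shape, axis, start)
    # Same ceil() that ProcessStridedSliceOperator uses for the output dim.
    dim = int(math.ceil(float(stop - start) / params["strides"][axis]))
    print(axis, start, stop, dim)  # -> 0 0 4 4, then 1 1 3 2
```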
-TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) -LIB_OBJS := $(TF_LITE_CC_OBJS) # Benchmark sources BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark @@ -151,6 +126,40 @@ BENCHMARK_SRCS := $(filter-out \ $(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \ $(BENCHMARK_ALL_SRCS)) +# These target-specific makefiles should modify or replace options like +# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic +# based on platforms or architectures should happen within these files, to +# keep this main makefile focused on the sources and dependencies. +include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) + +ALL_SRCS := \ + $(MINIMAL_SRCS) \ + $(PROFILER_SRCS) \ + $(PROFILER_SUMMARY_SRCS) \ + $(TF_LITE_CC_SRCS) \ + $(BENCHMARK_SRCS) + +# Where compiled objects are stored. +GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/ +OBJDIR := $(GENDIR)obj/ +BINDIR := $(GENDIR)bin/ +LIBDIR := $(GENDIR)lib/ + +LIB_PATH := $(LIBDIR)$(LIB_NAME) +BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME) +BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME) +MINIMAL_BINARY := $(BINDIR)minimal + +CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ +CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc +AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar + +MINIMAL_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) + +LIB_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) + BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) @@ -164,7 +173,7 @@ $(OBJDIR)%.o: %.c $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there are no command-line arguments.
-all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY) +all: $(LIB_PATH) $(MINIMAL_BINARY) $(BENCHMARK_BINARY) # The target that's compiled for micro-controllers micro: $(LIB_PATH) @@ -178,19 +187,18 @@ $(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) -$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH) +$(MINIMAL_BINARY): $(MINIMAL_OBJS) $(LIB_PATH) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ - -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \ + -o $(MINIMAL_BINARY) $(MINIMAL_OBJS) \ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) - $(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS) benchmark_lib: $(BENCHMARK_LIB) -$(info $(BENCHMARK_BINARY)) + $(BENCHMARK_BINARY) : $(BENCHMARK_LIB) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ @@ -213,4 +221,4 @@ cleantarget: $(DEPDIR)/%.d: ; .PRECIOUS: $(DEPDIR)/%.d --include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS))) +-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS))) diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc index 692efb9029..b863108aa4 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc @@ -141,6 +141,7 @@ bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) { op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF || op_code == BuiltinOperator_EMBEDDING_LOOKUP || op_code == BuiltinOperator_RNN || + op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/contrib/lite/tutorials/BUILD new file mode 100644 index 0000000000..67ff1ea124 --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/BUILD @@ -0,0 +1,20 @@ +# Example Estimator model + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "mnist_tflite", + srcs = [ + "dataset.py", + "mnist_tflite.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/contrib/lite/tutorials/dataset.py new file mode 100644 index 0000000000..ba49dfcc9b --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/dataset.py @@ -0,0 +1,122 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tf.data.Dataset interface to the MNIST dataset. 
+ + This is cloned from + https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import shutil +import tempfile + +import numpy as np +from six.moves import urllib +import tensorflow as tf + + +def read32(bytestream): + """Read 4 bytes from bytestream as an unsigned 32-bit integer.""" + dt = np.dtype(np.uint32).newbyteorder('>') + return np.frombuffer(bytestream.read(4), dtype=dt)[0] + + +def check_image_file_header(filename): + """Validate that filename corresponds to images for the MNIST dataset.""" + with tf.gfile.Open(filename, 'rb') as f: + magic = read32(f) + read32(f) # num_images, unused + rows = read32(f) + cols = read32(f) + if magic != 2051: + raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, + f.name)) + if rows != 28 or cols != 28: + raise ValueError( + 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' % + (f.name, rows, cols)) + + +def check_labels_file_header(filename): + """Validate that filename corresponds to labels for the MNIST dataset.""" + with tf.gfile.Open(filename, 'rb') as f: + magic = read32(f) + read32(f) # num_items, unused + if magic != 2049: + raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, + f.name)) + + +def download(directory, filename): + """Download (and unzip) a file from the MNIST dataset if not already done.""" + filepath = os.path.join(directory, filename) + if tf.gfile.Exists(filepath): + return filepath + if not tf.gfile.Exists(directory): + tf.gfile.MakeDirs(directory) + # CVDF mirror of http://yann.lecun.com/exdb/mnist/ + url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz' + _, zipped_filepath = tempfile.mkstemp(suffix='.gz') + print('Downloading %s to %s' % (url, zipped_filepath)) + urllib.request.urlretrieve(url, zipped_filepath) + with gzip.open(zipped_filepath, 'rb') as f_in, \ + tf.gfile.Open(filepath, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + os.remove(zipped_filepath) + return filepath + + +def dataset(directory, images_file, labels_file): + """Download and parse MNIST dataset.""" + + images_file = download(directory, images_file) + labels_file = download(directory, labels_file) + + check_image_file_header(images_file) + check_labels_file_header(labels_file) + + def decode_image(image): + # Normalize from [0, 255] to [0.0, 1.0] + image = tf.decode_raw(image, tf.uint8) + image = tf.cast(image, tf.float32) + image = tf.reshape(image, [784]) + return image / 255.0 + + def decode_label(label): + label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8] + label = tf.reshape(label, []) # label is a scalar + return tf.to_int32(label) + + images = tf.data.FixedLengthRecordDataset( + images_file, 28 * 28, header_bytes=16).map(decode_image) + labels = tf.data.FixedLengthRecordDataset( + labels_file, 1, header_bytes=8).map(decode_label) + return tf.data.Dataset.zip((images, labels)) + + +def train(directory): + """tf.data.Dataset object for MNIST training data.""" + return dataset(directory, 'train-images-idx3-ubyte', + 'train-labels-idx1-ubyte') + + +def test(directory): + """tf.data.Dataset object for MNIST test data.""" + return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte') diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py new file mode 100644 index 0000000000..7b8bf5b5db --- /dev/null +++ 
b/tensorflow/contrib/lite/tutorials/mnist_tflite.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +import tensorflow as tf # pylint: disable=g-bad-import-order +from tensorflow.contrib.lite.tutorials import dataset +flags = tf.app.flags + +flags.DEFINE_string('data_dir', '/tmp/data_dir', + 'Directory where data is stored.') +flags.DEFINE_string('model_file', '', + 'The path to the TFLite flatbuffer model file.') + + +flags = flags.FLAGS + + +def test_image_generator(): + # Generates an iterator over images + with tf.Session() as sess: + input_data = dataset.test( + flags.data_dir).make_one_shot_iterator().get_next() + try: + while True: + yield sess.run(input_data) + except tf.errors.OutOfRangeError: + pass + + +def run_eval(interpreter, input_image): + """Performs evaluation for input image over specified model. + + Args: + interpreter: TFLite interpreter initialized with model to execute. + input_image: Image input to the model. + + Returns: + output: output tensor of model being executed. + """ + + # Get input and output tensors. + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Test model on the input images. + input_image = np.reshape(input_image, input_details[0]['shape']) + interpreter.set_tensor(input_details[0]['index'], input_image) + + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]['index']) + output = np.squeeze(output_data) + return output + + +def main(_): + interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file) + interpreter.allocate_tensors() + num_correct, total = 0, 0 + for input_data in test_image_generator(): + output = run_eval(interpreter, input_data[0]) + total += 1 + if output == input_data[1]: + num_correct += 1 + if total % 500 == 0: + print('Accuracy after %i images: %f' % + (total, float(num_correct) / float(total))) + + +if __name__ == '__main__': + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run(main) diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h index f5b208afbb..6d81f844f8 100644 --- a/tensorflow/contrib/lite/util.h +++ b/tensorflow/contrib/lite/util.h @@ -22,7 +22,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_UTIL_H_ #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc index 32bf917a59..c5c1709f1d 100644 --- a/tensorflow/contrib/lite/util_test.cc +++ b/tensorflow/contrib/lite/util_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
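With dataset.py and mnist_tflite.py both added, the tutorial's flow is: build the tf.data pipeline, push each 784-float image through a TFLite Interpreter, and tally predictions against labels. Here is a minimal sketch of exercising just the data side, assuming the MNIST files can be downloaded into the script's default /tmp/data_dir (TF 1.x graph mode, as in the tutorial):

```python
# Pull one test example through the tutorial's pipeline; assumes network
# access so dataset.download() can fetch the MNIST files.
import tensorflow as tf
from tensorflow.contrib.lite.tutorials import dataset

example = dataset.test('/tmp/data_dir').make_one_shot_iterator().get_next()
with tf.Session() as sess:
    image, label = sess.run(example)
    print(image.shape, label)  # (784,) float32 in [0, 1], scalar int32 label
```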
#include <gmock/gmock.h> #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/util.h" namespace tflite { diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 22b11f1c57..7d26429f9c 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -56,6 +56,7 @@ tensorflow/core/lib/hash/hash.cc tensorflow/core/lib/hash/crc32c.cc tensorflow/core/lib/hash/crc32c_accelerate.cc tensorflow/core/lib/core/threadpool.cc +tensorflow/core/lib/core/stringpiece.cc tensorflow/core/lib/core/status.cc tensorflow/core/lib/core/coding.cc tensorflow/core/lib/core/arena.cc diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 93e589907e..2e4d61d931 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -159,8 +159,10 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:variables", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py index f026f437dc..f55209ec49 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py @@ -25,7 +25,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -48,12 +47,7 @@ class LazyAdamOptimizer(adam.AdamOptimizer): may lead to different empirical results. 
""" - def _apply_sparse_shared(self, - grad, - var, - indices, - scatter_update, - scatter_sub): + def _apply_sparse(self, grad, var): beta1_power, beta2_power = self._get_beta_accumulators() beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) @@ -65,51 +59,56 @@ class LazyAdamOptimizer(adam.AdamOptimizer): # \\(m := beta1 * m + (1 - beta1) * g_t\\) m = self.get_slot(var, "m") - m_t = scatter_update(m, indices, - beta1_t * array_ops.gather(m, indices) + - (1 - beta1_t) * grad) + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) v = self.get_slot(var, "v") - v_t = scatter_update(v, indices, - beta2_t * array_ops.gather(v, indices) + - (1 - beta2_t) * math_ops.square(grad)) + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) - m_t_slice = array_ops.gather(m_t, indices) - v_t_slice = array_ops.gather(v_t, indices) + m_t_slice = array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t - var_update = scatter_sub(var, indices, - lr * m_t_slice / denominator_slice) + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) return control_flow_ops.group(var_update, m_t, v_t) - def _apply_sparse(self, grad, var): - return self._apply_sparse_shared( - grad.values, var, grad.indices, - self._scatter_update, - self._scatter_sub) - def _resource_apply_sparse(self, grad, var, indices): - return self._apply_sparse_shared( - grad, var, indices, - self._resource_scatter_update, - self._resource_scatter_sub) - - # Utility functions for updating resource or non-resource variables. 
- def _scatter_update(self, x, i, v): - return state_ops.scatter_update( - x, i, v, use_locking=self._use_locking) - - def _scatter_sub(self, x, i, v): - return state_ops.scatter_sub( - x, i, v, use_locking=self._use_locking) - - def _resource_scatter_update(self, x, i, v): - update_op = resource_variable_ops.resource_scatter_update(x.handle, i, v) - with ops.control_dependencies([update_op]): - return x.value() - - def _resource_scatter_sub(self, x, i, v): - sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v) - with ops.control_dependencies([sub_op]): - return x.value() + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update(m.handle, + indices, + m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update(v.handle, + indices, + v_t_slice) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, + indices, + var_slice) + + return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py index d3e9e89502..f08ffaa36f 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py @@ -19,12 +19,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.opt.python.training import lazy_adam_optimizer +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -50,9 +53,10 @@ def adam_update_numpy(param, return param_t, m_t, v_t -class AdamOptimizerTest(test.TestCase): +class AdamOptimizerTest(test.TestCase, parameterized.TestCase): - def doTestSparse(self, use_resource=False): + @parameterized.parameters([False, True]) + def testSparse(self, use_resource): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): # Initialize variables for numpy implementation. 
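Both the rewritten _apply_sparse and the new _resource_apply_sparse body above implement the same "lazy" rule: the m and v slots and the variable are updated only at the rows named by the gradient's indices, rather than densifying the sparse gradient. A NumPy model of one such step (an illustration of the math in the comments above, not the TF implementation):

```python
# One lazy-Adam step on sparse rows, modeled in NumPy.
import numpy as np

lr_t, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8
t = 1  # global step
var = np.array([[1.0], [2.0], [3.0]])
m = np.zeros_like(var)
v = np.zeros_like(var)

indices = np.array([0, 2])       # only these rows received gradients
grad = np.array([[0.1], [0.3]])  # grad.values for those rows

lr = lr_t * np.sqrt(1 - beta2**t) / (1 - beta1**t)

# m := beta1 * m + (1 - beta1) * g_t, but only at `indices`
m[indices] = beta1 * m[indices] + (1 - beta1) * grad
# v := beta2 * v + (1 - beta2) * g_t^2, but only at `indices`
v[indices] = beta2 * v[indices] + (1 - beta2) * np.square(grad)
# variable -= lr * m_t / (sqrt(v_t) + epsilon), again only at `indices`
var[indices] -= lr * m[indices] / (np.sqrt(v[indices]) + eps)

print(var.ravel())  # row 1 is untouched: that is the "lazy" part
```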
@@ -68,6 +72,7 @@ class AdamOptimizerTest(test.TestCase): else: var0 = variables.Variable(var0_np) var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) grads0 = ops.IndexedSlices( constant_op.constant(grads0_np), @@ -99,18 +104,17 @@ class AdamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var1_np, var1.eval()) - def testSparse(self): - self.doTestSparse(use_resource=False) - - def testResourceSparse(self): - self.doTestSparse(use_resource=True) - - def testSparseDevicePlacement(self): + @parameterized.parameters([False, True]) + def testSparseDevicePlacement(self, use_resource): for index_dtype in [dtypes.int32, dtypes.int64]: with self.test_session(force_gpu=test.is_gpu_available()): # If a GPU is available, tests that all optimizer ops can be placed on # it (i.e. they have GPU kernels). - var = variables.Variable([[1.0], [2.0]]) + if use_resource: + var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) + else: + var = variables.Variable([[1.0], [2.0]]) + indices = constant_op.constant([0, 1], dtype=index_dtype) gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0) @@ -118,13 +122,21 @@ class AdamOptimizerTest(test.TestCase): variables.global_variables_initializer().run() minimize_op.run() - def testSparseRepeatedIndices(self): + @parameterized.parameters([False, True]) + def testSparseRepeatedIndices(self, use_resource): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): - repeated_index_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) - aggregated_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) + if use_resource: + repeated_index_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + else: + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + grad_repeated_index = ops.IndexedSlices( constant_op.constant( [0.1, 0.1], shape=[2, 1], dtype=dtype), @@ -150,6 +162,204 @@ class AdamOptimizerTest(test.TestCase): self.assertAllClose(aggregated_update_var.eval(), repeated_index_update_var.eval()) + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertIsNotNone(beta1_power) + self.assertIsNotNone(beta2_power) + self.assertIn(beta1_power, opt_variables) + self.assertIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.test_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_optimizer.LazyAdamOptimizer() + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = lazy_adam_optimizer.LazyAdamOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with self.session(graph=g): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with self.session(graph=gg): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = lazy_adam_optimizer.LazyAdamOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two non-slot variables, and two unique slot variables + # for v1 and v2 respectively. + self.assertEqual(6, len(set(opt.variables()))) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 499fec4ffa..c59f667f6a 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -22,6 +22,7 @@ py_test( ":common", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:session", "//tensorflow/python:variable_scope", @@ -89,7 +90,6 @@ py_library( ":common", ":graph_matcher", ":input_to_ops", - "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", @@ -171,7 +171,6 @@ py_library( ":graph_matcher", ":input_to_ops", ":quant_ops", - "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py index bf648e158e..b27117dd48 100644 --- a/tensorflow/contrib/quantize/python/common.py +++ b/tensorflow/contrib/quantize/python/common.py @@ -131,3 +131,29 @@ def DropStringPrefix(s, prefix): return s[len(prefix):] else: return s + + +def RerouteTensor(t0, t1, can_modify=None): + """Reroute the end of the tensor t0 to the ends of the tensor t1. + + Args: + t0: a tf.Tensor. + t1: a tf.Tensor. + can_modify: iterable of operations which can be modified. Any operation + outside can_modify will be left untouched by this function. + + Returns: + The number of individual modifications made by the function.
+ """ + nb_update_inputs = 0 + consumers = t1.consumers() + if can_modify is not None: + consumers = [c for c in consumers if c in can_modify] + consumers_indices = {} + for c in consumers: + consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1] + for c in consumers: + for i in consumers_indices[c]: + c._update_input(i, t0) # pylint: disable=protected-access + nb_update_inputs += 1 + return nb_update_inputs diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py index 06c62f2d26..2b26302f8a 100644 --- a/tensorflow/contrib/quantize/python/common_test.py +++ b/tensorflow/contrib/quantize/python/common_test.py @@ -20,8 +20,10 @@ from __future__ import print_function from tensorflow.contrib.quantize.python import common from tensorflow.python.client import session +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -62,6 +64,29 @@ class CommonTest(test_util.TensorFlowTestCase): _, step_val = sess.run([b, quantization_step_tensor]) self.assertEqual(step_val, 2) + def testRerouteTensor(self): + a = constant_op.constant(1, name='a') + b = constant_op.constant(2, name='b') + c = constant_op.constant(3, name='c') + d = constant_op.constant(4, name='d') + + add_ac = math_ops.add(a, c) + add_ad = math_ops.add(a, d) + + # Ensure that before rerouting the inputs are what we think. + self._CheckOpHasInputs(add_ac.op, [a, c]) + self._CheckOpHasInputs(add_ad.op, [a, d]) + + # references to tensor a should be replaced with b for all ops in + # can_modify. This means add_ac will be changed but add_ad will not. + common.RerouteTensor(b, a, can_modify=[add_ac.op]) + self._CheckOpHasInputs(add_ac.op, [b, c]) + self._CheckOpHasInputs(add_ad.op, [a, d]) + + def _CheckOpHasInputs(self, op, inputs): + for i in inputs: + self.assertIn(i, op.inputs) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index d9f179bee4..2971b28f45 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import re -from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops @@ -134,8 +133,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): bias_add_tensor = math_ops.add( new_layer_tensor, bias_tensor, name='add_fold') - nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor, - match.output_tensor) + nodes_modified_count = common.RerouteTensor(bias_add_tensor, + match.output_tensor) if nodes_modified_count == 0: raise ValueError('Folding batch norms failed, %s had no outputs.' 
% match.output_tensor.name) @@ -370,8 +369,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, lambda: match.bn_decay_mean_tensor, name='freeze_moving_mean') - graph_editor.reroute_ts( - [bn_decay_mean_out], [match.bn_decay_mean_tensor], + common.RerouteTensor( + bn_decay_mean_out, + match.bn_decay_mean_tensor, can_modify=bn_decay_mean_consumers) bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) @@ -380,8 +380,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, lambda: bn_decay_zero, lambda: match.bn_decay_var_tensor, name='freeze_moving_var') - graph_editor.reroute_ts( - [bn_decay_var_out], [match.bn_decay_var_tensor], + common.RerouteTensor( + bn_decay_var_out, + match.bn_decay_var_tensor, can_modify=bn_decay_var_consumers) correction_recip = utils.smart_cond( @@ -486,9 +487,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): activation = common.GetEndpointActivationOp(graph, bn) if activation: - nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], - [original_op.outputs[0]], - can_modify=[activation]) + nodes_modified_count = common.RerouteTensor( + folded_op.outputs[0], original_op.outputs[0], can_modify=[activation]) if nodes_modified_count != 1: raise ValueError('Unexpected inputs to op: %s' % activation.name) continue @@ -497,9 +497,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): # operations instead of Relu* above. add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1) add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add') - nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], - [original_op.outputs[0]], - can_modify=[add_bypass]) + nodes_modified_count = common.RerouteTensor( + folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass]) if nodes_modified_count != 1: raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 2ddbd73ea6..e88db0acd5 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import re -from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops @@ -592,8 +591,8 @@ def _InsertQuantOp(context, name=name_prefix + '/delayed_quant') if consumers: - tensors_modified_count = graph_editor.reroute_ts( - [quant], [inputs], can_modify=consumers) + tensors_modified_count = common.RerouteTensor( + quant, inputs, can_modify=consumers) # Some operations can have multiple output tensors going to the same # consumer. 
Since consumers is a set, we need to ensure that # tensors_modified_count is greater than or equal to the length of the set diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 5874245d58..4e67d80558 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -212,6 +212,7 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = ["noasan"], ) tf_custom_op_library( @@ -279,7 +280,10 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = ["no_oss"], + tags = [ + "no_oss", + "noasan", + ], ) tf_cc_test( @@ -287,6 +291,7 @@ tf_cc_test( size = "small", srcs = ["ops/gru_ops_test.cc"], data = [":python/ops/_gru_ops.so"], + tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), @@ -306,6 +311,7 @@ tf_cc_test( size = "small", srcs = ["ops/lstm_ops_test.cc"], data = [":python/ops/_lstm_ops.so"], + tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index f74c95f962..06c481672c 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -97,10 +97,10 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): The default non-peephole implementation is based on: - http://www.bioinf.jku.at/publications/older/2604.pdf + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf - S. Hochreiter and J. Schmidhuber. - "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + Felix Gers, Jurgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. The peephole implementation is based on: @@ -2448,10 +2448,10 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell): The default non-peephole implementation is based on: - http://www.bioinf.jku.at/publications/older/2604.pdf + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf - S. Hochreiter and J. Schmidhuber. - "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + Felix Gers, Jurgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. The peephole implementation is based on: @@ -2802,9 +2802,11 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell): Training of Deep Neural Networks The default LSTM implementation based on: - http://www.bioinf.jku.at/publications/older/2604.pdf - S. Hochreiter and J. Schmidhuber. - "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf + + Felix Gers, Jurgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. The class uses optional peephole connections, optional cell clipping and an optional projection layer. 
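The `WeightNormLSTMCell` docstring above cites Salimans and Kingma's weight normalization paper; as a refresher (a sketch of the cited idea, not the cell's actual code path), the reparameterization it refers to learns a direction and a scale separately:

    import numpy as np

    def weight_norm(v, g):
      # w = g * v / ||v||: direction (v) and scale (g) are decoupled,
      # which conditions the optimization of the weight matrix.
      return g * v / np.linalg.norm(v)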
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index db970deff5..0042d37acd 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -134,19 +134,19 @@ def _get_default_head(params, weights_name, output_type, name=None): weight_column=weights_name, label_dimension=params.num_outputs, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) else: if params.num_classes == 2: return core_head_lib.binary_classification_head( weight_column=weights_name, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) else: return core_head_lib.multi_class_head( n_classes=params.num_classes, weight_column=weights_name, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) def get_model_fn(params, graph_builder_class, diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index 537d94b797..3c0456dc2f 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -33,6 +33,7 @@ @@shard @@batch_parallel @@rewrite +@@outside_compilation @@CrossShardOptimizer diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 08e0465b71..d8c3872363 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -258,6 +258,8 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer): return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads] def set_weights(self, weights): + # TODO(power): Figure out whether we really need this given there is no + # caller for this API yet. self._opt.set_weights() def get_weights(self): @@ -282,9 +284,9 @@ def _valid_name(tensor_name): def _replicated_optimizer(opt): """Wrap the optimizer `opt` with CrossShardOptimizer if applicable.""" - if tpu_function.get_tpu_context().number_of_shards == 1: - return opt - + # Always wrap `opt` with CrossShardOptimizer, even if we are running on a + # single core. This ensures Keras properly tracks and initializes optimizer + # variables. if isinstance(opt, keras_optimizers.TFOptimizer): return tpu_optimizer.CrossShardOptimizer(opt.optimizer) else: @@ -1420,7 +1422,7 @@ class KerasTPUModel(models.Model): y, sample_weights, batch_size) - self._pipeline_fit_loop( + return self._pipeline_fit_loop( x, y, sample_weights=sample_weights, diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 1e21cc5252..c1f90c3963 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -652,13 +652,28 @@ def split_compile_and_replicate(computation, # TODO(phawkins): consider removing this code. It will # be less confusing to clients if they knowingly choose to use resource # variables. + # Partitioned variables is not supported (b/112311320). + def custom_getter(getter, name, *args, **kwargs): + partitioner = kwargs["partitioner"] + if partitioner is None: + return getter(name, *args, **kwargs) + else: + raise ValueError( + "Partitioned variables are not supported on TPU. 
Got " + "`partitioner` that is {}.".format(partitioner)) + vscope = variable_scope.get_variable_scope() + saved_use_resource = vscope.use_resource + saved_custom_getter = vscope.custom_getter + vscope.set_use_resource(True) + vscope.set_custom_getter(custom_getter) outputs = computation(*computation_inputs) vscope.set_use_resource(saved_use_resource) + vscope.set_custom_getter(saved_custom_getter) # If the computation returns `None`, make it an empty tuple. if outputs is None: diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index ad3dce1784..d4951b156c 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -63,7 +63,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( } CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0); RdmaChannel* rc = rdma_mgr_->FindChannel(src_name); - string key(std::move(parsed.FullKey().ToString())); + string key(parsed.FullKey()); string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_); Device* dst_dev; diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c06fea130f..79ad3b8e54 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -702,6 +702,21 @@ cc_library( ) cc_library( + name = "feature_util", + srcs = ["example/feature_util.cc"], + hdrs = [ + "example/feature_util.h", + "platform/types.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":core_stringpiece", + ":platform_protobuf", + ":protos_all_cc", + ], +) + +cc_library( name = "abi", srcs = ["platform/abi.cc"], hdrs = ["platform/abi.h"], @@ -1339,6 +1354,7 @@ cc_library( "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_softmax_op", + "//tensorflow/core/kernels:mkl_transpose_op", "//tensorflow/core/kernels:mkl_tfconv_op", "//tensorflow/core/kernels:mkl_aggregate_ops", ]) + if_cuda([ @@ -3712,6 +3728,7 @@ tf_cc_test_mkl( ":core_cpu_internal", ":framework", ":framework_internal", + ":lib", ":test", ":test_main", ":testlib", diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt new file mode 100644 index 0000000000..27bc4013c3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "ParallelInterleaveDatasetV2" + visibility: HIDDEN + attr { + name: "f" + description: <<END +A function mapping elements of `input_dataset`, concatenated with +`other_arguments`, to a Dataset variant that contains elements matching +`output_types` and `output_shapes`. +END + } + summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`." +} diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt index 8cef243aee..30fd97a0d7 100644 --- a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt @@ -9,7 +9,7 @@ END in_arg { name: "pattern" description: <<END -A 1-D string tensor of the regular expression to match the input. +A scalar string tensor containing the regular expression to match the input. 
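The corrected description reflects how the kernel consumes `pattern`: a single scalar regex applied elementwise to `input`. A hedged sketch of the Python-level call (the exact symbol location in the 1.x API is an assumption here):

    import tensorflow as tf

    x = tf.constant(["TensorFlow", "tensorflow"])
    # `pattern` is a 0-D (scalar) string, not a 1-D tensor.
    m = tf.regex_full_match(x, pattern="Tensor.*")  # evaluates to [True, False]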
END } out_arg { diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt index 35f55fe106..d33a36ce06 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt index 70a07d9b4c..afdc39da96 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt index b2e3eece38..026b5b3991 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt index 7bac02e23d..a168eed87f 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt index a73306a892..876b860824 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt new file mode 100644 index 0000000000..6d9d9908ca --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt @@ -0,0 +1,29 @@ +op { + graph_op_name: "StaticRegexFullMatch" + in_arg { + name: "input" + description: <<END +A string tensor of the text to be processed. +END + } + out_arg { + name: "output" + description: <<END +A bool tensor with the same shape as `input`. +END + } + attr { + name: "pattern" + description: "The regular expression to match the input." + } + summary: "Check if the input matches the regex pattern." + description: <<END +The input is a string tensor of any shape. 
The pattern is the +regular expression to be matched with every element of the input tensor. +The boolean values (True or False) of the output tensor indicate +if the input matches the regex pattern provided. + +The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +END + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt index 907c6d2022..7a60e4387a 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt @@ -3,15 +3,14 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s -first dimension. -END +A tensor whose shape is a prefix of `data.shape`. END } out_arg { name: "output" description: <<END -Has same shape as data, except for dimension 0 which -has size `num_segments`. +Has same shape as data, except for the first `segment_ids.rank` +dimensions, which are replaced with a single dimension which has size +`num_segments`. END } summary: "Computes the maximum along segments of a tensor." @@ -24,13 +23,16 @@ This operator is similar to the unsorted segment sum operator found [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). Instead of computing the sum over segments, it computes the maximum such that: -\\(output_i = \max_j data_j\\) where max is over `j` such -that `segment_ids[j] == i`. +\\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such +that `segment_ids[j...] == i`. If the maximum is empty for a given segment ID `i`, it outputs the smallest possible value for the specific numeric type, `output[i] = numeric_limits<T>::lowest()`. +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. + <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt> </div> diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt index 37dd973b23..7e139ddf4d 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt @@ -3,15 +3,15 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s -first dimension. +A tensor whose shape is a prefix of `data.shape`. END } out_arg { name: "output" description: <<END -Has same shape as data, except for dimension 0 which -has size `num_segments`. +Has same shape as data, except for the first `segment_ids.rank` +dimensions, which are replaced with a single dimension which has size +`num_segments`. END } summary: "Computes the minimum along segments of a tensor." @@ -24,11 +24,14 @@ This operator is similar to the unsorted segment sum operator found [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). Instead of computing the sum over segments, it computes the minimum such that: -\\(output_i = \min_j data_j\\) where min is over `j` such -that `segment_ids[j] == i`. +\\(output_i = \min_{j...} data[j...]\\) where min is over tuples `j...` such +that `segment_ids[j...] == i`.
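These doc rewrites generalize `segment_ids` from a 1-D tensor to any tensor whose shape is a prefix of `data.shape`, with empty segments producing the type's extreme value and negative ids dropped (both spelled out in the surrounding doc text). A NumPy sketch of those semantics for the min case (a reference model only, not the kernel):

    import numpy as np

    def unsorted_segment_min_ref(data, segment_ids, num_segments):
      # segment_ids.shape must be a prefix of data.shape.
      inner = data.shape[segment_ids.ndim:]
      fill = (np.finfo(data.dtype).max if np.issubdtype(data.dtype, np.floating)
              else np.iinfo(data.dtype).max)
      out = np.full((num_segments,) + inner, fill, dtype=data.dtype)
      for idx in np.ndindex(*segment_ids.shape):
        i = int(segment_ids[idx])
        if i >= 0:  # negative segment ids are dropped
          out[i] = np.minimum(out[i], data[idx])
      return out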
If the minimum is empty for a given segment ID `i`, it outputs the largest possible value for the specific numeric type, `output[i] = numeric_limits<T>::max()`. + +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. END } diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt index efbc023705..9c8ea3b620 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt @@ -3,15 +3,15 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s -first dimension. +A tensor whose shape is a prefix of `data.shape`. END } out_arg { name: "output" description: <<END -Has same shape as data, except for dimension 0 which -has size `num_segments`. +Has same shape as data, except for the first `segment_ids.rank` +dimensions, which are replaced with a single dimension which has size +`num_segments`. END } summary: "Computes the product along segments of a tensor." @@ -25,9 +25,12 @@ This operator is similar to the unsorted segment sum operator found Instead of computing the sum over segments, it computes the product of all entries belonging to a segment such that: -\\(output_i = \prod_j data_j\\) where the product is over `j` such -that `segment_ids[j] == i`. +\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples +`j...` such that `segment_ids[j...] == i`. If there is no entry for a given segment ID `i`, it outputs 1. + +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. END } diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt index a8874950eb..7e5d9265c2 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt @@ -21,7 +21,7 @@ Read for an explanation of segments. Computes a tensor such that -\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such +\\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` need not be sorted and need not cover all values in the full range of valid values. diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 46bb8d92f8..1c9b69721d 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -615,11 +615,14 @@ void PruneFunctionBody(Graph* g) { std::unordered_set<const Node*> nodes; for (auto n : g->nodes()) { // NOTE(mrry): "_Retval" nodes are stateful, and so will be added - // to the seed set of `nodes`. + // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we + // specifically exclude them as seeds, to avoid unconditionally executing + // unused argument nodes (e.g. in a function like `lambda x, y: y`). // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is // still needed. It would be preferable to prune entire loops and/or // conditionals if they are not used in the graph. 
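The code change just below implements this seed rule: stateful ops still seed the liveness sweep, but "_Arg" is now excluded so an unused argument (as in `lambda x, y: y`) can be pruned. Schematically, with hypothetical helper names:

    def prune_seed_nodes(graph_nodes):
      # Seed reachability with control-flow ops and stateful ops, but never
      # with "_Arg": unused function arguments should stay prunable.
      return {n for n in graph_nodes
              if n.is_control_flow()
              or (n.is_stateful() and n.type_string() != "_Arg")}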
- if (n->IsControlFlow() || n->op_def().is_stateful()) { + if (n->IsControlFlow() || + (n->op_def().is_stateful() && n->type_string() != kArgOp)) { nodes.insert(n); } } @@ -925,29 +928,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, } DCHECK(run_opts.runner != nullptr); - Executor::Args* exec_args = new Executor::Args; + Executor::Args exec_args; // Inherit the step_id from the caller. - exec_args->step_id = run_opts.step_id; - exec_args->rendezvous = run_opts.rendezvous; - exec_args->stats_collector = run_opts.stats_collector; - exec_args->cancellation_manager = run_opts.cancellation_manager; - exec_args->collective_executor = run_opts.collective_executor; - exec_args->step_container = run_opts.step_container; - exec_args->runner = *run_opts.runner; - exec_args->call_frame = frame; - - item->exec->RunAsync( - // Executor args - *exec_args, - // Done callback. - std::bind( - [item, frame, exec_args](DoneCallback done, - // Start unbound arguments. - const Status& status) { - delete exec_args; - done(status); - }, - std::move(done), std::placeholders::_1)); + exec_args.step_id = run_opts.step_id; + exec_args.rendezvous = run_opts.rendezvous; + exec_args.stats_collector = run_opts.stats_collector; + exec_args.cancellation_manager = run_opts.cancellation_manager; + exec_args.collective_executor = run_opts.collective_executor; + exec_args.step_container = run_opts.step_container; + exec_args.runner = *run_opts.runner; + exec_args.call_frame = frame; + + item->exec->RunAsync(exec_args, std::move(done)); } bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 120f480198..7bab9be9a6 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -802,9 +802,9 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { // Name "SquareAndAddOneWithStatefulNodes", // Args - {"x: int32"}, + {"x: int32", "y: float32"}, // Return values - {"y: int32"}, + {"z: int32"}, // Attrs {}, // Nodes @@ -822,12 +822,13 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { "RandomUniform", {"shape"}, {{"T", T}, {"dtype", DT_FLOAT}}}, - // y = Add<T>(a, o) - {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); + // z = Add<T>(a, o) + {{"z"}, "Add", {"a", "o"}, {{"T", T}}}}); Init({stateful_func}); auto x = test::AsTensor<int32>({1, 2, 3, 4}); - Tensor y; + auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0}); + Tensor z; FunctionLibraryRuntime::Handle handle; TF_CHECK_OK( @@ -837,18 +838,19 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { StepStatsCollector stats_collector(&stats); FunctionLibraryRuntime::Options opts; opts.stats_collector = &stats_collector; - TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y})); + TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z})); TF_CHECK_OK(flr0_->ReleaseHandle(handle)); TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {}, - {x}, {&y})); - test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17})); + {x, y}, {&z})); + test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17})); stats_collector.FinalizeAndSwap(&stats); - // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute. + // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to + // execute. 
std::set<string> expected_node_names( - {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"}); + {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"}); std::set<string> executed_node_names; for (const auto& node_stats : stats.dev_stats()[0].node_stats()) { executed_node_names.insert(node_stats.node_name()); diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h index 39215efa35..e1b163074f 100644 --- a/tensorflow/core/common_runtime/tracing_device.h +++ b/tensorflow/core/common_runtime/tracing_device.h @@ -35,8 +35,11 @@ class TracingDevice : public Device { : Device(env, attributes) {} void Compute(OpKernel* op_kernel, OpKernelContext* context) override { + const tracing::TraceCollector* trace_collector = + tracing::GetTraceCollector(); if (TF_PREDICT_FALSE( - tracing::GetTraceCollector() || + (trace_collector && + trace_collector->IsEnabled(op_kernel->IsExpensive())) || tracing::GetEventCollector(tracing::EventCategory::kCompute))) { const string& op_name = op_kernel->name(); tracing::ScopedActivity activity(op_name, op_kernel->type_string(), diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 38863db1cc..6994dec3b5 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -693,6 +693,7 @@ uint64 DebugFileIO::diskBytesUsed = 0; mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED); bool DebugFileIO::requestDiskByteUsage(uint64 bytes) { + mutex_lock l(bytes_mu); if (globalDiskBytesLimit == 0) { const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT"); if (env_tfdbg_disk_bytes_limit == nullptr || @@ -707,7 +708,6 @@ bool DebugFileIO::requestDiskByteUsage(uint64 bytes) { if (bytes == 0) { return true; } - mutex_lock l(bytes_mu); if (diskBytesUsed + bytes < globalDiskBytesLimit) { diskBytesUsed += bytes; return true; diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h index 21c21723d0..74bd39cb61 100644 --- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h +++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ #define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ +#include <unordered_set> #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -24,27 +25,26 @@ namespace data { // See below macro for usage details. 
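The registry rewritten below switches from storing `StringPiece` views in a `std::set` to owned `string` keys in an `unordered_set`, which avoids dangling references to caller-owned buffers. A minimal Python analogue of the singleton-registry pattern (illustrative only, not TensorFlow API):

    class WhitelistedStatefulOps:
      # Process-wide singleton holding whitelisted op names.
      _instance = None

      @classmethod
      def global_registry(cls):
        if cls._instance is None:
          cls._instance = cls()
        return cls._instance

      def __init__(self):
        self._op_names = set()  # owns its keys, like unordered_set<string>

      def add(self, op_name):
        self._op_names.add(op_name)

      def contains(self, op_name):
        return op_name in self._op_names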
class WhitelistedStatefulOpRegistry { public: - Status Add(StringPiece op_name) { - op_names_.insert(op_name); + Status Add(string op_name) { + op_names_.insert(std::move(op_name)); return Status::OK(); } - bool Contains(StringPiece op_name) { - return op_names_.find(op_name) != op_names_.end(); - } + bool Contains(const string& op_name) { return op_names_.count(op_name); } static WhitelistedStatefulOpRegistry* Global() { - static WhitelistedStatefulOpRegistry* reg = - new WhitelistedStatefulOpRegistry; + static auto* reg = new WhitelistedStatefulOpRegistry; return reg; } private: - WhitelistedStatefulOpRegistry() {} - WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy); + WhitelistedStatefulOpRegistry() = default; + WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) = + delete; WhitelistedStatefulOpRegistry operator=( - WhitelistedStatefulOpRegistry const& copy); - std::set<StringPiece> op_names_; + WhitelistedStatefulOpRegistry const& copy) = delete; + + std::unordered_set<string> op_names_; }; } // namespace data diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 6710ff9df3..d24e7e8ee4 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -429,18 +429,22 @@ class SymbolicShapeRefiner { // perform shape inference on the function body. // // Propagate shape information of final function body node - // to function node `node`. + // to function node `function_node`. // - // In the event of an error, UpdateNode will simply set `node`'s + // In the event of an error, UpdateNode will simply set `function_node`'s // output shape to be Unknown. - Status UpdateFunction(const NodeDef* node) { - auto it = fun_to_grappler_function_item_.find(node->op()); + Status UpdateFunction(const NodeDef* function_node) { + auto it = fun_to_grappler_function_item_.find(function_node->op()); if (it == fun_to_grappler_function_item_.end()) { return errors::InvalidArgument( - node->op(), " was not previously added to SymbolicShapeRefiner."); + function_node->op(), + " was not previously added to SymbolicShapeRefiner."); } - GrapplerFunctionItem& grappler_function_item = it->second; + // Copy (not reference) so that changes we make here (e.g., replacing + // Placeholder with Const) don't affect one in + // fun_to_grappler_function_item_. + GrapplerFunctionItem grappler_function_item = it->second; GraphView gv(&grappler_function_item.graph); // Forward shapes from function input nodes to argument nodes. @@ -453,7 +457,7 @@ class SymbolicShapeRefiner { "supported."); } NodeDef* fun_node = gv.GetNode(fun_input.input_name); - const string& input = node->input(i); + const string& input = function_node->input(i); const string& node_name = NodeName(input); if (IsControlInput(input)) { @@ -478,16 +482,35 @@ class SymbolicShapeRefiner { TensorShapeProto proto; const auto& handle = input_inference_context->output(output_port_num); input_inference_context->ShapeHandleToProto(handle, &proto); + // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1. + for (int i = 0; i < proto.dim_size(); i++) { + if (proto.dim(i).size() < -1) { + proto.mutable_dim(i)->set_size(-1); + } + } *attr_output_shape.mutable_shape() = proto; (*fun_node->mutable_attr())["shape"] = attr_output_shape; } + // Replace input Placeholders with Consts, if values are known. Note that + // we don't check exceptions here as it's done in the above loop. 
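One detail in the loop above is easy to miss: SymbolicShapeRefiner can carry dimension sizes below -1 as internal symbolic markers, while a TensorShapeProto only understands -1 as "unknown", hence the clamping before the proto is exported. In effect:

    def clamp_symbolic_dims(dims):
      # Sizes < -1 are internal symbolic ids; a TensorShapeProto needs -1.
      return [d if d >= -1 else -1 for d in dims]

    assert clamp_symbolic_dims([8, -5, 3]) == [8, -1, 3]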
+ for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) { + const string& input = function_node->input(i); + const string& node_name = NodeName(input); + NodeDef* input_node = graph_.GetNode(node_name); + // TODO(dyoon): also use Const when output_tensors_as_shape is available. + if (IsConstant(*input_node)) { + TF_CHECK_OK( + ReplaceInputWithConst(*input_node, i, &grappler_function_item)); + } + } + // Perform inference on function body. GraphProperties gp(grappler_function_item); TF_RETURN_IF_ERROR(gp.InferStatically(true)); // Add return nodes for output shapes. - auto ic = GetContext(node); + auto ic = GetContext(function_node); int output = 0; for (auto const& out_arg : grappler_function_item.outputs()) { if (out_arg.output_tensors.size() > 1) { @@ -505,8 +528,9 @@ class SymbolicShapeRefiner { const NodeDef* retnode = gv.GetNode(node_name); if (retnode == nullptr) { - return errors::FailedPrecondition("Unable to find return node ", - node_name, " for ", node->name()); + return errors::FailedPrecondition( + "Unable to find return function_node ", node_name, " for ", + function_node->name()); } auto output_properties = gp.GetOutputProperties(retnode->name()); @@ -671,11 +695,13 @@ class SymbolicShapeRefiner { // true, as the updates to the call node will have changed, even if it's // the same function being called twice with the same input shapes. // Example: simple_function.pbtxt - if (UpdateFunction(node).ok()) { + auto s = UpdateFunction(node); + if (s.ok()) { return Status::OK(); } else { VLOG(1) << "UpdateFunction failed for " << node->op() - << ". Defaulting to ShapeUnknown."; + << ". Defaulting to ShapeUnknown.\n" + << s.ToString(); } } diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 8938b7c32e..3ec68a4e59 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -785,7 +785,58 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { EXPECT_EQ("float: [128,256]", PropToString(prop)); } -TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) { +TEST_F(GraphPropertiesTest, FunctionWithConstInput) { + FunctionDefLibrary library; + // This function is simply + // out = Fill(shape, value), but + // Fill requires values in the shape input, not just shape of it, to infer + // output shape; hence, func + *library.add_function() = FunctionDefHelper::Create( + // Name + "MyFillFunc", + // Inputs + {"shape: int32", "value: float"}, + // Outputs + {"out: float"}, + // Attrs + {}, + // Nodes + { + {{"a"}, + "Fill", + {"shape", "value"}, + {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}}, + }, + // Returns + {{"out", "a:output:0"}}); + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(s.graph()->AddFunctionLibrary(library)); + Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4}); + Output value = ops::Const(s.WithOpName("value"), 0.1f, {}); + auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc", + s.graph()->op_registry()); + tensorflow::Node* func_op; + auto _shape = tensorflow::ops::AsNodeOut(s, shape); + auto _value = tensorflow::ops::AsNodeOut(s, value); + TF_CHECK_OK( + builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op)); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = 
properties.GetOutputProperties("MyFillFunc"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ(DT_FLOAT, out_prop0.dtype()); + EXPECT_FALSE(out_prop0.shape().unknown_rank()); + EXPECT_EQ(4, out_prop0.shape().dim_size()); + EXPECT_EQ(1, out_prop0.shape().dim(0).size()); + EXPECT_EQ(2, out_prop0.shape().dim(1).size()); + EXPECT_EQ(3, out_prop0.shape().dim(2).size()); + EXPECT_EQ(4, out_prop0.shape().dim(3).size()); +} + +TEST_F(GraphPropertiesTest, FunctionWithScalarInput) { // Create graph with a function that takes a scalar value so that we use // Placeholder with scalar as for input to the function shape inference. // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of @@ -818,7 +869,7 @@ TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) { // MyFunc output shouldn't be unknown rank. GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically(false)); + TF_CHECK_OK(properties.InferStatically(true)); const auto out_props = properties.GetOutputProperties("MyFunc"); const OpInfo::TensorProperties out_prop0 = out_props[0]; EXPECT_EQ(DT_FLOAT, out_prop0.dtype()); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 65947ddce5..11ce121cba 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1121,11 +1121,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); NodeDef* tail = node; - // TODO(rmlarsen): Enable after debugging breakage in Bayesflow. - if (ctx().opt_level == RewriterConfig::AGGRESSIVE) { - tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, - *ctx().nodes_to_preserve); - } + tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, + *ctx().nodes_to_preserve); NodeDef* first_transpose; TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose)); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 5fd34efeb1..a5fd33d28b 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -156,7 +156,7 @@ Status MetaOptimizer::InitializeOptimizers( optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>( cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); } - return Status::OK(); + return InitializeCustomGraphOptimizers(optimizers); } Status MetaOptimizer::InitializeOptimizersByName( @@ -180,6 +180,11 @@ Status MetaOptimizer::InitializeOptimizersByName( VLOG(2) << "Can't register an optimizer by name: " << optimizer_name; } } + return InitializeCustomGraphOptimizers(optimizers); +} + +Status MetaOptimizer::InitializeCustomGraphOptimizers( + std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { for (const auto& optimizer_config : cfg_.custom_optimizers()) { auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull( optimizer_config.name()); @@ -208,7 +213,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, } std::vector<std::unique_ptr<GraphOptimizer>> optimizers; - if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) { + if (cfg_.optimizers().empty()) { TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers)); } else { 
TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers)); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index 151a54cbdf..831c5e37c0 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -52,6 +52,9 @@ class MetaOptimizer : public GraphOptimizer { // Initialize active optimizers from RewriterConfig optimizer names. Status InitializeOptimizersByName( std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; + // Initialize active optimizers from RewriterConfig.custom_optimizers. + Status InitializeCustomGraphOptimizers( + std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; // Run optimization pass over a single GrapplerItem. Meta optimizer might run // multiple such passes: 1) for the main graph 2) for the function library diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index 9a03c7dfef..e74e0f7501 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -64,6 +64,13 @@ bool TestOptimizer::optimized_; REGISTER_GRAPH_OPTIMIZER(TestOptimizer); +class TestGraphOptimizer : public TestOptimizer { + public: + string name() const override { return "test_graph_optimizer"; } +}; + +REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer); + class MetaOptimizerTest : public GrapplerTest {}; TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { @@ -83,6 +90,27 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { EXPECT_TRUE(TestOptimizer::IsOptimized()); } +TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TestOptimizer::SetOptimized(false); + TestGraphOptimizer::SetOptimized(false); + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("TestOptimizer"); + auto customGraphOptimizer = rewriter_config.add_custom_optimizers(); + customGraphOptimizer->set_name("TestGraphOptimizer"); + rewriter_config.set_min_graph_nodes(-1); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestOptimizer::IsOptimized()); + EXPECT_TRUE(TestGraphOptimizer::IsOptimized()); +} + TEST_F(MetaOptimizerTest, RunOptimizersTwice) { TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); GrapplerItem item; @@ -98,6 +126,24 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) { TF_EXPECT_OK(status); } +TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + RewriterConfig rewriter_config; + auto customGraphOptimizer = rewriter_config.add_custom_optimizers(); + customGraphOptimizer->set_name("TestGraphOptimizer"); + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); + rewriter_config.set_min_graph_nodes(-1); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestGraphOptimizer::IsOptimized()); +} + TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { using test::function::NDef; diff --git a/tensorflow/core/kernels/BUILD 
b/tensorflow/core/kernels/BUILD index 25063ac823..972fb9efa9 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -643,14 +643,7 @@ cc_library( ":split_v_op", ":strided_slice_op", ":tile_ops", - ] + if_mkl( - [ - ":mkl_transpose_op", - ], - [ - ":transpose_op", - ], - ) + [ + ":transpose_op", ":unique_op", ":unpack_op", ":unravel_index_op", @@ -893,24 +886,13 @@ tf_kernel_library( deps = ARRAY_DEPS, ) -if_mkl( - [tf_mkl_kernel_library( - name = "mkl_transpose_op", - srcs = [ - "mkl_transpose_op.cc", - "transpose_op.cc", - ], - hdrs = ["transpose_op.h"], - deps = ARRAY_DEPS + mkl_deps(), - )], - [tf_kernel_library( - name = "transpose_op", - srcs = [ - "transpose_op.cc", - ], - hdrs = ["transpose_op.h"], - deps = ARRAY_DEPS, - )], +tf_kernel_library( + name = "transpose_op", + srcs = [ + "transpose_op.cc", + ], + hdrs = ["transpose_op.h"], + deps = ARRAY_DEPS + if_mkl([":mkl_transpose_op"]), ) tf_kernel_library( @@ -6351,6 +6333,15 @@ tf_mkl_kernel_library( deps = NN_DEPS + mkl_deps() + [":cwise_op"], ) +tf_mkl_kernel_library( + name = "mkl_transpose_op", + srcs = [ + "mkl_transpose_op.cc", + ], + hdrs = ["transpose_op.h"], + deps = ARRAY_DEPS + mkl_deps(), +) + # NOTE(lespeholt): This rule is deprecated, please use: # tensorflow/core/util/batch_util.h cc_library( diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h index a7836896c7..390db8fe5a 100644 --- a/tensorflow/core/kernels/conditional_accumulator.h +++ b/tensorflow/core/kernels/conditional_accumulator.h @@ -51,9 +51,11 @@ class ConditionalAccumulator // dtype: The datatype of the gradients to be accumulated. // shape: The shape of the accumulated gradients. // name: A name to use for the ConditionalAccumulator. + // reduction_type: The reduction type, i.e., MEAN or SUM ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, - const string& name) - : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {} + const string& name, const string& reduction_type) + : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name, + reduction_type) {} ~ConditionalAccumulator() override{}; protected: diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc index 90593c56b8..292cf0cd64 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.cc +++ b/tensorflow/core/kernels/conditional_accumulator_base.cc @@ -14,12 +14,17 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/kernels/conditional_accumulator_base.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { ConditionalAccumulatorBase::ConditionalAccumulatorBase( - const DataType& dtype, const PartialTensorShape& shape, const string& name) - : dtype_(dtype), shape_(shape), name_(name) { + const DataType& dtype, const PartialTensorShape& shape, const string& name, + const string& reduction_type) + : dtype_(dtype), + shape_(shape), + name_(name), + reduction_type_(reduction_type) { counter_ = 0; current_global_step_ = 0; } @@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx, current_global_step_++; // Average the accumulated gradient - DivideAccumGradByCounter(ctx); + if (reduction_type_ == "MEAN") { + DivideAccumGradByCounter(ctx); + } // Set output for accumulated gradient tensor bool successful_set_output = SetOutput(ctx); diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h index b7b7482a00..4a5ec6f0fb 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.h +++ b/tensorflow/core/kernels/conditional_accumulator_base.h @@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase { // name: A name to use for the ConditionalAccumulator. ConditionalAccumulatorBase(const DataType& dtype, const PartialTensorShape& shape, - const string& name); + const string& name, const string& reduction_type); typedef AsyncOpKernel::DoneCallback DoneCallback; @@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase { const DataType dtype_; const PartialTensorShape shape_; const string name_; + const string reduction_type_; mutex mu_; int counter_ GUARDED_BY(mu_); int64 current_global_step_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h index 012a0dcc12..ca24d690f8 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base_op.h +++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h @@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel { &accumulator_handle_, nullptr)); OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(context, + context->GetAttr("reduction_type", &reduction_type_)); } void Compute(OpKernelContext* ctx) override { @@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel { DataType dtype_; PartialTensorShape shape_; ContainerInfo cinfo_; + string reduction_type_; private: Status SetAccumulatorHandle(OpKernelContext* ctx) diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc index e13bf8a4c6..52ac51a9b6 100644 --- a/tensorflow/core/kernels/conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/conditional_accumulator_op.cc @@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { Creator GetCreator() const override { return [this](ConditionalAccumulatorBase** ret) { ConditionalAccumulator<Device, T>* accumulator = - new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name()); + new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(), + reduction_type_); *ret = accumulator; return Status::OK(); }; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc 
b/tensorflow/core/kernels/data/map_dataset_op.cc index 306486b96a..af301e2b42 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -28,9 +28,7 @@ namespace { class MapDatasetOp : public UnaryDatasetOpKernel { public: - explicit MapDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -186,7 +184,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList func_; diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 3c562fc7f3..b87d61ee44 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/batch_util.h" #include "tensorflow/core/util/reffed_status_callback.h" @@ -60,26 +62,43 @@ class MapDefunOp : public AsyncOpKernel { ~MapDefunOp() override {} + Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) { + // Validates inputs and gets the size of their leading dimension. + *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input(i).dims() == 0) { + return errors::InvalidArgument( + "All inputs must have rank at least 1. Input ", i, + " has a rank of 0."); + } else if (ctx->input(i).dim_size(0) != *batch_size) { + return errors::InvalidArgument( + "All inputs must have the same dimension 0. Input ", i, + " has leading dimension ", ctx->input(i).dim_size(0), + ", while all previous inputs have leading dimension ", *batch_size); + } + } + return Status::OK(); + } + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - int64 batch_size = ctx->input(0).dim_size(0); + int64 batch_size; + OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done); + + // Inputs auto* args = new std::vector<Tensor>; auto* arg_shapes = new std::vector<TensorShape>; + + // Create a copy because every `Compute` may have different output shapes. + auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_); arg_shapes->reserve(ctx->num_inputs()); args->reserve(ctx->num_inputs()); + auto* mu = new mutex; + for (size_t i = 0; i < ctx->num_inputs(); ++i) { args->push_back(ctx->input(i)); arg_shapes->push_back(ctx->input(i).shape()); arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension - OP_REQUIRES_ASYNC( - ctx, batch_size == ctx->input(i).dim_size(0), - errors::InvalidArgument( - "All inputs must have the same dimension 0. 
Input ", i, - " has leading dimension ", ctx->input(i).dim_size(0), - ", while all previous inputs have leading dimension ", batch_size, - "."), - done); } // Outputs @@ -87,10 +106,14 @@ class MapDefunOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done); for (size_t i = 0; i < output_types().size(); ++i) { - Tensor* out = nullptr; - TensorShape output_shape = output_shapes_.at(i); - output_shape.InsertDim(0, batch_size); - OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done); + if (output_shapes_.at(i).IsFullyDefined()) { + Tensor* out = nullptr; + TensorShape output_shape; + output_shapes_.at(i).AsTensorShape(&output_shape); + output_shape.InsertDim(0, batch_size); + OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), + done); + } } SetRunOptions(ctx, &opts_, false); @@ -98,15 +121,19 @@ class MapDefunOp : public AsyncOpKernel { // Run loop StatusCallback callback = std::bind( [](OpKernelContext* ctx, std::vector<Tensor>* args, - std::vector<TensorShape>* arg_shapes, OpOutputList* output, - DoneCallback& done, const Status& status) { + std::vector<TensorShape>* arg_shapes, + std::vector<PartialTensorShape>* output_shapes, OpOutputList* output, + mutex* mu, DoneCallback& done, const Status& status) { delete args; delete arg_shapes; delete output; + delete output_shapes; + delete mu; ctx->SetStatus(status); done(); }, - ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1); + ctx, args, arg_shapes, output_shapes, output, mu, std::move(done), + std::placeholders::_1); auto* refcounted = new ReffedStatusCallback(std::move(callback)); @@ -114,9 +141,11 @@ class MapDefunOp : public AsyncOpKernel { // Start from i = 1 because refcounted is initialized with refcount = 1 refcounted->Ref(); } + for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) { - auto* call_frame = - new MapFunctionCallFrame(*args, *arg_shapes, output, this, i); + auto* call_frame = new MapFunctionCallFrame( + *args, *arg_shapes, output_shapes, mu, output, this, i, + static_cast<size_t>(batch_size)); CancellationManager* c_mgr = new CancellationManager; opts_.cancellation_manager = c_mgr; ctx->function_library()->Run( @@ -133,18 +162,23 @@ class MapDefunOp : public AsyncOpKernel { private: FunctionLibraryRuntime::Handle func_handle_; FunctionLibraryRuntime::Options opts_; - std::vector<TensorShape> output_shapes_; + std::vector<PartialTensorShape> output_shapes_; class MapFunctionCallFrame : public CallFrameInterface { public: MapFunctionCallFrame(const std::vector<Tensor>& args, const std::vector<TensorShape>& arg_shapes, - OpOutputList* output, OpKernel* kernel, size_t iter) + std::vector<PartialTensorShape>* output_shapes, + mutex* output_shapes_mutex, OpOutputList* output, + OpKernel* kernel, size_t iter, size_t batch_size) : args_(args), arg_shapes_(arg_shapes), + output_shapes_(output_shapes), + output_shapes_mutex_(output_shapes_mutex), output_(output), kernel_(kernel), - iter_(iter) {} + iter_(iter), + batch_size_(batch_size) {} ~MapFunctionCallFrame() override {} @@ -182,15 +216,37 @@ class MapDefunOp : public AsyncOpKernel { "output: ", index); } + { // Locking scope + mutex_lock l(*output_shapes_mutex_); + if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) { + return errors::InvalidArgument( + "Mismatch in function retval shape, ", val.shape(), + ", and expected output shape,", + output_shapes_->at(index).DebugString(), "."); + } + if (!output_shapes_->at(index).IsFullyDefined()) { + // Given val, we 
have new information about the output shape at + // this index. Store the shape and allocate the output accordingly. + output_shapes_->at(index) = val.shape(); + + Tensor* out = nullptr; + TensorShape actual_shape = val.shape(); + actual_shape.InsertDim(0, batch_size_); + TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out)); + } + } return batch_util::CopyElementToSlice(val, (*output_)[index], iter_); } private: const std::vector<Tensor>& args_; const std::vector<TensorShape>& arg_shapes_; + std::vector<PartialTensorShape>* output_shapes_; + mutex* output_shapes_mutex_; OpOutputList* output_; const OpKernel* kernel_; const size_t iter_; + const size_t batch_size_; }; }; diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index f8287cf0e3..640f1565b7 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <deque> +#include <utility> #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -21,6 +22,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" @@ -34,8 +36,7 @@ namespace { class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -125,6 +126,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const DataTypeVector& output_dtypes() const override { return output_types_; } + const std::vector<PartialTensorShape>& output_shapes() const override { return output_shapes_; } @@ -1058,7 +1060,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList interleave_func_; @@ -1067,6 +1068,593 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU), ParallelInterleaveDatasetOp); +// The motivation for creating an alternative implementation of parallel +// interleave is to decouple the degree of parallelism from the cycle length. +// This makes it possible to change the degree of parallelism (e.g. through +// auto-tuning) without changing the cycle length (which would change the order +// in which elements are produced). +// +// Furthermore, this class favors modularity over extended functionality. 
In +// particular, it refrains from implementing configurable buffering of output +// elements and prefetching of input iterators, relying on other parts of +// tf.data to provide this functionality if necessary. +// +// The above design choices were made with automated optimizations in mind, +// isolating the degree of parallelism as the single tunable knob of this +// implementation. +class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { + public: + explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); + + int64 cycle_length = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "cycle_length", &cycle_length)); + OP_REQUIRES(ctx, cycle_length > 0, + errors::InvalidArgument("`cycle_length` must be > 0")); + + int64 block_length = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_length", &block_length)); + OP_REQUIRES(ctx, block_length > 0, + errors::InvalidArgument("`block_length` must be > 0")); + + int64 num_parallel_calls = 0; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", + &num_parallel_calls)); + OP_REQUIRES(ctx, num_parallel_calls > 0, + errors::InvalidArgument( + "`num_parallel_calls` must be greater than zero.")); + OP_REQUIRES( + ctx, num_parallel_calls <= cycle_length, + errors::InvalidArgument( + "`num_parallel_calls` must be less than or equal to " + "`cycle_length`.")); + + // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`.
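+    // [Editor's note] A minimal sketch of what the TODO above refers to,
+    // assuming OpInputList exposes begin()/end() iterators over its Tensors:
+    //   std::vector<Tensor> other_arguments(inputs.begin(), inputs.end());
+    // which would replace the reserve-and-push loop below in one statement.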
+ std::vector<Tensor> other_arguments; + other_arguments.reserve(inputs.size()); + for (const Tensor& t : inputs) { + other_arguments.push_back(t); + } + std::unique_ptr<CapturedFunction> captured_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create( + interleave_func_, std::move(other_arguments), &captured_func)); + + *output = new Dataset(ctx, input, interleave_func_, + std::move(captured_func), cycle_length, block_length, + num_parallel_calls, output_types_, output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, + int64 block_length, int64 num_parallel_calls, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : DatasetBase(DatasetContext(ctx)), + input_(input), + interleave_func_(func), + captured_func_(std::move(captured_func)), + cycle_length_(cycle_length), + block_length_(block_length), + num_parallel_calls_(num_parallel_calls), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::ParallelInterleaveV2")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "ParallelInterleaveDatasetV2Op::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name())); + Node* input_node; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); + Node* cycle_length_node; + TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node)); + Node* block_length_node; + TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node)); + Node* num_parallel_calls_node; + TF_RETURN_IF_ERROR( + b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); + DataTypeVector other_arguments_types; + other_arguments_types.reserve(captured_func_->captured_inputs().size()); + std::vector<Node*> other_arguments; + other_arguments.reserve(captured_func_->captured_inputs().size()); + for (const Tensor& t : captured_func_->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + AttrValue f; + b->BuildAttrValue(interleave_func_, &f); + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {{0, input_node}, + {2, cycle_length_node}, + {3, block_length_node}, + {4, num_parallel_calls_node}}, + {{1, other_arguments}}, + {{"f", f}, {"Targuments", other_arguments_types_attr}}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + args_list_(params.dataset->cycle_length_), + current_elements_(params.dataset->cycle_length_), + element_in_use_(params.dataset->cycle_length_, false), + thread_pool_(new thread::ThreadPool( + 
Env::Default(), ThreadOptions(), "parallel_interleave", + dataset()->cycle_length_ /* num_threads */, + false /* low_latency_hint */)) {} + + ~Iterator() override { + mutex_lock l(mu_); + // Cancel the runner thread. + cancelled_ = true; + cond_var_.notify_all(); + // Wait for all in-flight calls to complete. + while (num_calls_ > 0) { + cond_var_.wait(l); + } + } + + Status Initialize(IteratorContext* ctx) override { + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + std::shared_ptr<InvocationResult> result; + do { + { + mutex_lock l(mu_); + EnsureRunnerThreadStarted(ctx); + while (invocation_results_.empty() && + (!end_of_input_ || num_open_ > 0)) { + cond_var_.wait(l); + } + if (!invocation_results_.empty()) { + std::swap(result, invocation_results_.front()); + invocation_results_.pop_front(); + } else { + *end_of_sequence = true; + return Status::OK(); + } + } + cond_var_.notify_all(); + result->notification.WaitForNotification(); + } while (result->skip); + + if (result->status.ok()) { + *out_tensors = std::move(result->return_values); + } + *end_of_sequence = false; + return result->status; + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // Wait for all in-flight calls to complete. + while (num_calls_ > 0) { + cond_var_.wait(l); + } + CHECK_EQ(num_calls_, 0); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name("invocation_results.size"), invocation_results_.size())); + for (size_t i = 0; i < invocation_results_.size(); i++) { + std::shared_ptr<InvocationResult> result = invocation_results_[i]; + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + result->return_values.size())); + for (size_t j = 0; j < result->return_values.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name( + strings::StrCat("invocation_results[", i, "][", j, "]")), + result->return_values[j])); + } + if (result->skip) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].skip")), + "")); + } + } + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cycle_index"), cycle_index_)); + if (end_of_input_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("end_of_input"), "")); + } + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("num_open"), num_open_)); + TF_RETURN_IF_ERROR(WriteCurrentElements(writer)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + int64 invocation_results_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name("invocation_results.size"), &invocation_results_size)); + for (size_t i = 0; i < invocation_results_size; i++) { + std::shared_ptr<InvocationResult> result(new InvocationResult()); + invocation_results_.push_back(result); + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); + size_t num_return_values; + { + int64 size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + &size)); + num_return_values = static_cast<size_t>(size); + if 
(num_return_values != size) { + return errors::InvalidArgument(strings::StrCat( + full_name( + strings::StrCat("invocation_results[", i, "].size")), + ": ", size, " is not a valid value of type size_t.")); + } + } + result->return_values.reserve(num_return_values); + for (size_t j = 0; j < num_return_values; j++) { + result->return_values.emplace_back(); + TF_RETURN_IF_ERROR( + reader->ReadTensor(full_name(strings::StrCat( + "invocation_results[", i, "][", j, "]")), + &result->return_values.back())); + } + result->skip = reader->Contains( + full_name(strings::StrCat("invocation_results[", i, "].skip"))); + result->notification.Notify(); + } + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("cycle_index"), &cycle_index_)); + if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("num_open"), &num_open_)); + TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader)); + return Status::OK(); + } + + private: + struct InvocationResult { + Notification notification; // used for coordination with the consumer + Status status; // the invocation status + std::vector<Tensor> return_values; // the invocation result values + bool skip; // if set the result should be skipped + }; + + void EnsureRunnerThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!runner_thread_) { + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); + runner_thread_.reset(ctx->env()->StartThread( + {}, "runner_thread", + [this, new_ctx]() { RunnerThread(new_ctx); })); + } + } + + // Fetches up to `results.size()` outputs from the cycle element at + // position `cycle_index`. + // + // If end of input is encountered, the `skip` field of the invocation + // result is used to identify results that should be skipped. + void FetchOutputs( + const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index, + const std::vector<std::shared_ptr<InvocationResult>>& results) + LOCKS_EXCLUDED(mu_) { + bool end_of_input = false; + for (auto& result : results) { + if (!end_of_input) { + result->status = current_elements_[cycle_index]->GetNext( + ctx.get(), &result->return_values, &end_of_input); + } + if (end_of_input) { + result->skip = true; + } + result->notification.Notify(); + if (!result->status.ok()) { + break; + } + } + + // Release the ownership of the cycle element iterator, closing the + // iterator if end of input was encountered. + { + if (end_of_input) { + current_elements_[cycle_index].reset(); + } + mutex_lock l(mu_); + element_in_use_[cycle_index] = false; + num_calls_--; + if (end_of_input) { + args_list_[cycle_index].clear(); + num_open_--; + } + } + cond_var_.notify_all(); + } + + int64 MaxInvocationResults() { + return dataset()->cycle_length_ * dataset()->block_length_; + } + + // Method responsible for 1) creating iterators out of input elements, 2) + // determining the order in which elements are fetched from the iterators, + // and 3) scheduling the fetching of the elements to a threadpool. + // + // This method runs in the `runner_thread` background thread. + void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + while (true) { + { + mutex_lock l(mu_); + // Wait until this thread is cancelled, the end of input has been + // reached, or the cycle element at the `cycle_index_` position is + // not in use and there is space in the `invocation_results_` queue. 
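+          // [Editor's note] Worked example with illustrative numbers, not
+          // taken from this change: for cycle_length = 4 and block_length =
+          // 2, MaxInvocationResults() above returns 4 * 2 = 8, so the wait
+          // below blocks once eight invocation results are buffered and
+          // resumes as the consumer drains them.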
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && + (element_in_use_[cycle_index_] || + num_calls_ >= dataset()->num_parallel_calls_ || + invocation_results_.size() >= MaxInvocationResults())) { + cond_var_.wait(l); + } + + if (cancelled_ || (end_of_input_ && num_open_ == 0)) { + return; + } + + while (!element_in_use_[cycle_index_] && + (!end_of_input_ || num_open_ > 0) && + num_calls_ < dataset()->num_parallel_calls_ && + invocation_results_.size() < MaxInvocationResults()) { + if (!current_elements_[cycle_index_]) { + // Try to create a new iterator from the next input element. + Status status = input_impl_->GetNext( + ctx.get(), &args_list_[cycle_index_], &end_of_input_); + if (!status.ok()) { + invocation_results_.emplace_back(new InvocationResult()); + std::shared_ptr<InvocationResult>& result = + invocation_results_.back(); + result->status.Update(status); + result->notification.Notify(); + break; + } + if (!end_of_input_) { + Status status = MakeIteratorFromInputElement( + ctx.get(), args_list_[cycle_index_], cycle_index_, + dataset()->captured_func_.get(), prefix(), + &current_elements_[cycle_index_]); + if (!status.ok()) { + invocation_results_.emplace_back(new InvocationResult()); + std::shared_ptr<InvocationResult>& result = + invocation_results_.back(); + result->status.Update(status); + result->notification.Notify(); + break; + } + ++num_open_; + } + } + if (current_elements_[cycle_index_]) { + // Pre-allocate invocation results for outputs to be fetched + // and then fetch the outputs asynchronously. + std::vector<std::shared_ptr<InvocationResult>> results; + results.reserve(dataset()->block_length_); + for (int i = 0; i < dataset()->block_length_; ++i) { + invocation_results_.emplace_back(new InvocationResult()); + results.push_back(invocation_results_.back()); + } + num_calls_++; + element_in_use_[cycle_index_] = true; + thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this, + ctx, cycle_index_, + std::move(results))); + } + cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; + } + } + cond_var_.notify_all(); + } + } + + Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, + const Status& status) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + CodeKey(index), static_cast<int64>(status.code()))); + if (!status.ok()) { + TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), + status.error_message())); + } + return Status::OK(); + } + + Status ReadStatusLocked(IteratorStateReader* reader, size_t index, + Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 code_int; + TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); + error::Code code = static_cast<error::Code>(code_int); + + if (code != error::Code::OK) { + string error_message; + TF_RETURN_IF_ERROR( + reader->ReadScalar(ErrorMessageKey(index), &error_message)); + *status = Status(code, error_message); + } else { + *status = Status::OK(); + } + return Status::OK(); + } + + string CodeKey(size_t index) { + return full_name( + strings::StrCat("invocation_results[", index, "].code")); + } + + string ErrorMessageKey(size_t index) { + return full_name( + strings::StrCat("invocation_results[", index, "].error_message")); + } + + Status WriteCurrentElements(IteratorStateWriter* writer) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + for (int idx = 0; idx < current_elements_.size(); idx++) { + if (current_elements_[idx]) { + TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx])); + TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("args_size[", idx, "]")), + args_list_[idx].size())); + for (int i = 0; i < args_list_[idx].size(); i++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("args_list_[", idx, "][", i, "]")), + args_list_[idx][i])); + } + } + } + return Status::OK(); + } + + Status ReadCurrentElements(IteratorContext* ctx, + IteratorStateReader* reader) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + for (int idx = 0; idx < current_elements_.size(); idx++) { + if (reader->Contains( + full_name(strings::StrCat("args_size[", idx, "]")))) { + int64 args_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("args_size[", idx, "]")), + &args_size)); + args_list_[idx].resize(args_size); + for (int i = 0; i < args_size; i++) { + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("args_list_[", idx, "][", i, "]")), + &args_list_[idx][i])); + } + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( + ctx, args_list_[idx], idx, dataset()->captured_func_.get(), + prefix(), ¤t_elements_[idx])); + TF_RETURN_IF_ERROR( + RestoreInput(ctx, reader, current_elements_[idx])); + } else { + current_elements_[idx].reset(); + } + } + return Status::OK(); + } + + // Used for coordination between the main thread, the runner thread, and + // the worker threads. + mutex mu_; + + // Used for coordination between the main thread, the runner thread, and + // the worker threads. In particular, the runner thread should only + // schedule new calls when the number of in-flight calls is less than the + // user specified level of parallelism, there are slots available in the + // `invocation_results_` buffer, the current cycle element is not in use, + // and there are elements left to be fetched. + condition_variable cond_var_; + + // Iterator for input elements. + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + + // Identifies current cycle element. + int64 cycle_index_ = 0; + + // Arguments for creating an iterator for cycle elements. + std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_); + + // Iterators for the current cycle elements. Concurrent access is + // protected by `element_in_use_`. + std::vector<std::unique_ptr<IteratorBase>> current_elements_; + + // Identifies cycle elements that are in use by worker threads. + std::vector<bool> element_in_use_ GUARDED_BY(mu_); + + // Buffer for storing the invocation results. + std::deque<std::shared_ptr<InvocationResult>> invocation_results_ + GUARDED_BY(mu_); + + // Identifies whether end of input has been reached. + bool end_of_input_ GUARDED_BY(mu_) = false; + + // Identifies the number of open iterators. + int64 num_open_ GUARDED_BY(mu_) = 0; + + // Identifies the number of outstanding calls. + int64 num_calls_ GUARDED_BY(mu_) = 0; + + std::unique_ptr<thread::ThreadPool> thread_pool_; + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); + + // Identifies whether background activity should be cancelled. 
+ bool cancelled_ GUARDED_BY(mu_) = false; + }; + + const DatasetBase* const input_; + const NameAttrList interleave_func_; + const std::unique_ptr<CapturedFunction> captured_func_; + const int64 cycle_length_; + const int64 block_length_; + const int64 num_parallel_calls_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + }; + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + NameAttrList interleave_func_; +}; + +REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU), + ParallelInterleaveDatasetV2Op); + } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index ac5ed286ee..a0cb179eb8 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -33,11 +33,12 @@ namespace { class ParallelMapDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelMapDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + &use_inter_op_parallelism_)); } protected: @@ -60,10 +61,12 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + func_, std::move(other_arguments), + use_inter_op_parallelism_, &captured_func)); *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_, - output_shapes_, std::move(captured_func)); + output_shapes_, use_inter_op_parallelism_, + std::move(captured_func)); } private: @@ -73,6 +76,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, int32 num_parallel_calls, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, + bool use_inter_op_parallelism, std::unique_ptr<CapturedFunction> captured_func) : DatasetBase(DatasetContext(ctx)), input_(input), @@ -80,6 +84,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { num_parallel_calls_(num_parallel_calls), output_types_(output_types), output_shapes_(output_shapes), + use_inter_op_parallelism_(use_inter_op_parallelism), captured_func_(std::move(captured_func)) { input_->Ref(); } @@ -92,12 +97,27 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return captured_func_->Instantiate(ctx); }; - auto map_func = [this](IteratorContext* ctx, - std::vector<Tensor> input_element, - std::vector<Tensor>* result, StatusCallback done) { - captured_func_->RunAsync(ctx, std::move(input_element), result, - std::move(done)); - }; + ParallelMapIteratorFunction map_func; + if (use_inter_op_parallelism_) { + map_func = [this](IteratorContext* ctx, + std::vector<Tensor> input_element, + std::vector<Tensor>* result, StatusCallback done) { + captured_func_->RunAsync(ctx, std::move(input_element), result, + std::move(done)); + }; + } else { + map_func = [this](IteratorContext* ctx, + std::vector<Tensor> input_element, + std::vector<Tensor>* result, StatusCallback done) { + (*ctx->runner())(std::bind( + [this, ctx, result](std::vector<Tensor>& 
input_element, + StatusCallback& done) { + captured_func_->RunAsync(ctx, std::move(input_element), result, + std::move(done)); + }, + std::move(input_element), std::move(done))); + }; + } return NewParallelMapIterator( {this, strings::StrCat(prefix, "::ParallelMap")}, input_, @@ -167,12 +187,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const int32 num_parallel_calls_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; + const bool use_inter_op_parallelism_; const std::unique_ptr<CapturedFunction> captured_func_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; + bool use_inter_op_parallelism_; NameAttrList func_; }; diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index a7a2935195..baf448e572 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -209,6 +209,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { if (s.ok()) { *out_tensors = std::move(buffer_.front().value); } + auto_tuner_.RecordConsumption(buffer_.size()); buffer_.pop_front(); *end_of_sequence = false; diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc index b01db91720..fb2a4cc8ef 100644 --- a/tensorflow/core/kernels/dynamic_stitch_op.cc +++ b/tensorflow/core/kernels/dynamic_stitch_op.cc @@ -247,8 +247,8 @@ class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> { data.shaped<T, 2>({indices_vec.dimension(0), slice_size}); if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) { - T* merged_base = &merged_flat(0, 0); - const T* data_base = &data_flat(0, 0); + T* merged_base = merged_flat.data(); + const T* data_base = data_flat.data(); for (int i = 0; i < indices_vec.size(); i++) { int32 index = internal::SubtleMustCopy(indices_vec(i)); OP_REQUIRES( diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc index 3b34f650b6..ec949ddc84 100644 --- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc +++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc @@ -48,8 +48,10 @@ void SpatialConvolution(int iters, int num_threads, benchmark.SpatialConvolution(input_dims, filter_dims); - auto output_size = input_dims.TotalSize(); - auto flops = output_size * (input_depth * filter_height * filter_width); + auto num_computed_elements = + (input_dims.TotalSize() / input_depth) * filter_count; + auto flops = + num_computed_elements * (input_depth * filter_height * filter_width); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -75,8 +77,9 @@ void SpatialConvolutionBackwardInput(int iters, int num_threads, benchmark.SpatialConvolutionBackwardInput(input_dims, filter_dims); - auto output_size = input_dims.TotalSize(); - auto flops = output_size * (input_depth * filter_height * filter_width); + auto num_computed_elements = input_dims.TotalSize(); + auto flops = + num_computed_elements * (input_depth * filter_height * filter_width); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -102,8 +105,9 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, benchmark.SpatialConvolutionBackwardKernel(input_dims, filter_dims); - auto filter_size = filter_dims.TotalSize(); - auto flops = filter_size * (input_batches * input_height * input_width); + auto num_computed_elements = filter_dims.TotalSize(); + auto flops = + 
num_computed_elements * (input_batches * input_height * input_width); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -266,8 +270,9 @@ void CuboidConvolution(int iters, int num_threads, benchmark.CuboidConvolution(input_dims, filter_dims); - auto output_size = input_dims.TotalSize(); - auto flops = output_size * + auto num_computed_elements = + (input_dims.TotalSize() / input_depth) * filter_count; + auto flops = num_computed_elements * (input_depth * filter_height * filter_width * filter_planes); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -295,8 +300,8 @@ void CuboidConvolutionBackwardInput(int iters, int num_threads, benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims); - auto output_size = input_dims.TotalSize(); - auto flops = output_size * + auto num_computed_elements = input_dims.TotalSize(); + auto flops = num_computed_elements * (input_depth * filter_height * filter_width * filter_planes); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -324,9 +329,9 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, benchmark.CuboidConvolutionBackwardKernel(input_dims, filter_dims); - auto filter_size = filter_dims.TotalSize(); - auto flops = - filter_size * (input_batches * input_height * input_width * input_planes); + auto num_computed_elements = filter_dims.TotalSize(); + auto flops = num_computed_elements * + (input_batches * input_height * input_width * input_planes); ::tensorflow::testing::ItemsProcessed(flops * iters); } diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 2e8d9c623c..a495758861 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -50,7 +50,7 @@ class MutableHashTableOfScalars final : public LookupInterface { MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {} size_t size() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return table_.size(); } @@ -60,7 +60,7 @@ class MutableHashTableOfScalars final : public LookupInterface { const auto key_values = key.flat<K>(); auto value_values = value->flat<V>(); - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); @@ -95,7 +95,7 @@ class MutableHashTableOfScalars final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); int64 size = table_.size(); Tensor* keys; @@ -125,7 +125,7 @@ class MutableHashTableOfScalars final : public LookupInterface { int64 MemoryUsed() const override { int64 ret = 0; - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (unsigned i = 0; i < table_.bucket_count(); ++i) { size_t bucket_size = table_.bucket_size(i); if (bucket_size == 0) { @@ -138,7 +138,6 @@ class MutableHashTableOfScalars final : public LookupInterface { } private: - // TODO(andreasst): consider using a read/write lock or a concurrent map mutable mutex mu_; std::unordered_map<K, V> table_ GUARDED_BY(mu_); }; @@ -158,7 +157,7 @@ class MutableHashTableOfTensors final : public LookupInterface { } size_t size() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return table_.size(); } @@ -169,7 +168,7 @@ class MutableHashTableOfTensors final : public LookupInterface { auto value_values = value->flat_inner_dims<V, 2>(); int64 value_dim = value_shape_.dim_size(0); - mutex_lock l(mu_); + tf_shared_lock l(mu_); for 
(int64 i = 0; i < key_values.size(); ++i) { ValueArray* value_vec = gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i))); @@ -219,7 +218,7 @@ class MutableHashTableOfTensors final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); int64 size = table_.size(); int64 value_dim = value_shape_.dim_size(0); @@ -254,7 +253,7 @@ class MutableHashTableOfTensors final : public LookupInterface { int64 MemoryUsed() const override { int64 ret = 0; - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (unsigned i = 0; i < table_.bucket_count(); ++i) { size_t bucket_size = table_.bucket_size(i); if (bucket_size == 0) { @@ -268,7 +267,6 @@ class MutableHashTableOfTensors final : public LookupInterface { private: TensorShape value_shape_; - // TODO(andreasst): consider using a read/write lock or a concurrent map mutable mutex mu_; typedef gtl::InlinedVector<V, 4> ValueArray; std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_); @@ -335,7 +333,7 @@ class MutableDenseHashTable final : public LookupInterface { } size_t size() const override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return num_entries_; } @@ -355,7 +353,7 @@ class MutableDenseHashTable final : public LookupInterface { auto value_matrix = value->shaped<V, 2>({num_elements, value_size}); const auto default_flat = default_value.flat<V>(); - mutex_lock l(mu_); + tf_shared_lock l(mu_); const auto key_buckets_matrix = key_buckets_.AccessTensor(ctx)->template matrix<K>(); const auto value_buckets_matrix = @@ -451,7 +449,7 @@ class MutableDenseHashTable final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); + tf_shared_lock l(mu_); Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx); Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx); TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor)); @@ -493,7 +491,7 @@ class MutableDenseHashTable final : public LookupInterface { TensorShape value_shape() const override { return value_shape_; } int64 MemoryUsed() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() + value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes(); } diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc index bdc3b5778f..dd89597369 100644 --- a/tensorflow/core/kernels/map_stage_op.cc +++ b/tensorflow/core/kernels/map_stage_op.cc @@ -410,8 +410,9 @@ class StagingMap : public ResourceBase { copy_or_move_tensors(&it->second, *key, *indices, tuple)); // Remove entry if all the values have been consumed - if (!std::any_of(it->second.begin(), it->second.end(), - std::mem_fn(&OptionalTensor::has_value))) { + if (!std::any_of( + it->second.begin(), it->second.end(), + [](const OptionalTensor& tensor) { return tensor.has_value(); })) { map_.erase(it); } @@ -444,8 +445,9 @@ class StagingMap : public ResourceBase { *key = it->first; // Remove entry if all the values have been consumed - if (!std::any_of(it->second.begin(), it->second.end(), - std::mem_fn(&OptionalTensor::has_value))) { + if (!std::any_of( + it->second.begin(), it->second.end(), + [](const OptionalTensor& tensor) { return tensor.has_value(); })) { map_.erase(it); } diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 9b10c3f3d6..184e0cb003 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ 
b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -1083,7 +1083,7 @@ class MklConvOp : public OpKernel { #endif // Register 2D operations -#define REGISTER_MKL_CPU(T) \ +#define REGISTER_MKL_CPU_2D(T) \ REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ @@ -1100,16 +1100,16 @@ class MklConvOp : public OpKernel { .Label(mkl_op_registry::kMklOpLabel), \ MklDummyOp<CPUDevice, T>); -TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_float(REGISTER_MKL_CPU_2D); // Register 3D operations -#define REGISTER_MKL_CPU(T) \ +#define REGISTER_MKL_CPU_3D(T) \ REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ .Label(mkl_op_registry::kMklOpLabel), \ MklConvOp<CPUDevice, T, false>); -TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_float(REGISTER_MKL_CPU_3D); } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index ec6d241e17..5398e6113f 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -34,11 +34,11 @@ using mkldnn::prop_kind; template <typename T> void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) { - if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg && - fwdParams.alg_kind != pooling_avg_include_padding && - fwdParams.alg_kind != pooling_avg_exclude_padding) { - assert("Pooling algorithm kind is not supported\n"); - } + DCHECK(fwdParams.alg_kind == pooling_max || + fwdParams.alg_kind == pooling_avg || + fwdParams.alg_kind == pooling_avg_include_padding || + fwdParams.alg_kind == pooling_avg_exclude_padding) + << "Pooling algorithm kind is not supported"; context_.alg_kind = fwdParams.alg_kind; // create memory desc @@ -102,7 +102,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, static_cast<void*>(const_cast<T*>(src_data))); context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); if (context_.alg_kind == pooling_max) { // max pooling must have ws - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(ws_data); } context_.fwd_stream->submit(context_.fwd_primitives); @@ -111,7 +111,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, context_.src_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData); if (context_.alg_kind == pooling_max) { // max pooling must have ws - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(DummyData); } } @@ -120,11 +120,11 @@ template class MklPoolingFwdPrimitive<float>; template <typename T> void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) { - if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg && - bwdParams.alg_kind != pooling_avg_include_padding && - bwdParams.alg_kind != pooling_avg_exclude_padding) { - assert("Pooling algorithm kind is not supported\n"); - } + DCHECK(bwdParams.alg_kind == pooling_max || + bwdParams.alg_kind == pooling_avg || + bwdParams.alg_kind == pooling_avg_include_padding || + bwdParams.alg_kind == pooling_avg_exclude_padding) + << "Pooling algorithm kind is not supported"; context_.alg_kind = bwdParams.alg_kind; // check whether it is 2d or 3d @@ -190,7 +190,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, static_cast<void*>(const_cast<T*>(diff_dst_data))); context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); if 
(context_.alg_kind == pooling_max) { - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(const_cast<void*>(ws_data)); } @@ -199,7 +199,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, context_.diff_dst_mem->set_data_handle(DummyData); context_.diff_src_mem->set_data_handle(DummyData); if (context_.alg_kind == pooling_max) { - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(DummyData); } } diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index f4cfc48af5..84385356e1 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -40,7 +40,6 @@ using mkldnn::memory; #include "mkl_dnn.h" #include "mkl_dnn_types.h" #endif -#include "tensorflow/core/platform/default/logging.h" #include "tensorflow/core/util/mkl_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc index 04d8a1bdeb..cfab529662 100644 --- a/tensorflow/core/kernels/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl_softmax_op.cc @@ -88,6 +88,7 @@ class MklSoftmaxOp : public OpKernel { break; default: OP_REQUIRES_OK(context, errors::Aborted("Input dims must be <= 5 and >=1")); + return; } // Create softmax memory for src, dst: both are defined in mkl_util.h, // they are wrapper diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index 5d9257e20b..81ce6d6e95 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -75,28 +75,28 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context, } // Return intersection-over-union overlap between boxes i and j -static inline float IOUGreaterThanThreshold( - typename TTypes<float, 2>::ConstTensor boxes, int i, int j, - float iou_threshold) { - const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2)); - const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3)); - const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2)); - const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3)); - const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2)); - const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3)); - const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2)); - const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3)); - const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i); - const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j); - if (area_i <= 0 || area_j <= 0) return 0.0; - const float intersection_ymin = std::max<float>(ymin_i, ymin_j); - const float intersection_xmin = std::max<float>(xmin_i, xmin_j); - const float intersection_ymax = std::min<float>(ymax_i, ymax_j); - const float intersection_xmax = std::min<float>(xmax_i, xmax_j); - const float intersection_area = - std::max<float>(intersection_ymax - intersection_ymin, 0.0) * - std::max<float>(intersection_xmax - intersection_xmin, 0.0); - const float iou = intersection_area / (area_i + area_j - intersection_area); +template <typename T> +static inline bool IOUGreaterThanThreshold( + typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) { + const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2)); + const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3)); + const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2)); + const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3)); + const 
T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2)); + const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3)); + const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2)); + const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3)); + const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i); + const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j); + if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) return false; + const T intersection_ymin = std::max<T>(ymin_i, ymin_j); + const T intersection_xmin = std::max<T>(xmin_i, xmin_j); + const T intersection_ymax = std::min<T>(ymax_i, ymax_j); + const T intersection_xmax = std::min<T>(xmax_i, xmax_j); + const T intersection_area = + std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0)) * + std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0)); + const T iou = intersection_area / (area_i + area_j - intersection_area); return iou > iou_threshold; } @@ -106,11 +106,13 @@ static inline bool OverlapsGreaterThanThreshold( return overlaps(i, j) > overlap_threshold; } +template <typename T> static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn( const Tensor& boxes, float threshold) { - typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>(); - return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1, - std::placeholders::_2, threshold); + typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>(); + return std::bind(&IOUGreaterThanThreshold<T>, boxes_data, + std::placeholders::_1, std::placeholders::_2, + static_cast<T>(threshold)); } static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn( @@ -121,6 +123,7 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn( std::placeholders::_1, std::placeholders::_2, threshold); } +template <typename T> void DoNonMaxSuppressionOp( OpKernelContext* context, const Tensor& scores, int num_boxes, const Tensor& max_output_size, const float score_threshold, @@ -128,13 +131,13 @@ void DoNonMaxSuppressionOp( bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) { const int output_size = max_output_size.scalar<int>()(); - std::vector<float> scores_data(num_boxes); - std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin()); + std::vector<T> scores_data(num_boxes); + std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin()); // Data structure for selection candidate in NMS.
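// [Editor's note] Selection sketch for the loop that follows: candidates
// whose score exceeds score_threshold enter a max-heap keyed on score;
// boxes are then popped greedily, and suppress_check_fn (IOU- or
// overlap-based) decides whether each popped box survives against the
// boxes already selected.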
struct Candidate { int box_index; - float score; + T score; }; auto cmp = [](const Candidate bs_i, const Candidate bs_j) { @@ -143,13 +146,13 @@ void DoNonMaxSuppressionOp( std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)> candidate_priority_queue(cmp); for (int i = 0; i < scores_data.size(); ++i) { - if (scores_data[i] > score_threshold) { + if (static_cast<float>(scores_data[i]) > score_threshold) { candidate_priority_queue.emplace(Candidate({i, scores_data[i]})); } } std::vector<int> selected; - std::vector<float> selected_scores; + std::vector<T> selected_scores; Candidate next_candidate; while (selected.size() < output_size && !candidate_priority_queue.empty()) { @@ -176,7 +179,7 @@ void DoNonMaxSuppressionOp( int num_valid_outputs = selected.size(); if (pad_to_max_output_size) { selected.resize(output_size, 0); - selected_scores.resize(output_size, 0); + selected_scores.resize(output_size, static_cast<T>(0)); } if (ptr_num_valid_outputs) { *ptr_num_valid_outputs = num_valid_outputs; @@ -221,18 +224,19 @@ class NonMaxSuppressionOp : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_); + auto suppress_check_fn = + CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_); const float score_threshold_val = std::numeric_limits<float>::lowest(); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size, + score_threshold_val, suppress_check_fn); } private: float iou_threshold_; }; -template <typename Device> +template <typename Device, typename T> class NonMaxSuppressionV2Op : public OpKernel { public: explicit NonMaxSuppressionV2Op(OpKernelConstruction* context) @@ -264,11 +268,12 @@ class NonMaxSuppressionV2Op : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val); + auto suppress_check_fn = + CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val); const float score_threshold_val = std::numeric_limits<float>::lowest(); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size, + score_threshold_val, suppress_check_fn); } }; @@ -325,7 +330,7 @@ class NonMaxSuppressionV3V4Base : public OpKernel { float score_threshold_val_; }; -template <typename Device> +template <typename Device, typename T> class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { public: explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) @@ -334,14 +339,14 @@ class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { protected: void DoComputeAndPostProcess(OpKernelContext* context) override { auto suppress_check_fn = - CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_); - DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, - score_threshold_val_, suppress_check_fn); + DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn); } }; -template <typename Device> +template <typename Device, typename T> class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { public: explicit NonMaxSuppressionV4Op(OpKernelConstruction* context) @@ -353,12 +358,12 @@ class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { protected: void 
DoComputeAndPostProcess(OpKernelContext* context) override { auto suppress_check_fn = - CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_); int num_valid_outputs; - DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, - score_threshold_val_, suppress_check_fn, - pad_to_max_output_size_, &num_valid_outputs); + DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn, + pad_to_max_output_size_, &num_valid_outputs); // Allocate scalar output tensor for number of indices computed. Tensor* num_outputs_t = nullptr; @@ -413,22 +418,37 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel { auto suppress_check_fn = CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size, + score_threshold_val, suppress_check_fn); } }; REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), NonMaxSuppressionOp<CPUDevice>); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), - NonMaxSuppressionV2Op<CPUDevice>); +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU), + NonMaxSuppressionV2Op<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2") + .TypeConstraint<Eigen::half>("T") + .Device(DEVICE_CPU), + NonMaxSuppressionV2Op<CPUDevice, Eigen::half>); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU), - NonMaxSuppressionV3Op<CPUDevice>); +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU), + NonMaxSuppressionV3Op<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3") + .TypeConstraint<Eigen::half>("T") + .Device(DEVICE_CPU), + NonMaxSuppressionV3Op<CPUDevice, Eigen::half>); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU), - NonMaxSuppressionV4Op<CPUDevice>); +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU), + NonMaxSuppressionV4Op<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4") + .TypeConstraint<Eigen::half>("T") + .Device(DEVICE_CPU), + NonMaxSuppressionV4Op<CPUDevice, Eigen::half>); REGISTER_KERNEL_BUILDER( Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 876a1704c7..7bb403290d 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/placer.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function.h" @@ -104,13 +105,6 @@ class PartitionedCallOp : public AsyncOpKernel { for (auto d : lib->device_mgr()->ListDevices()) { device_set.AddDevice(d); } - Placer placer(graph.get(), &device_set); - OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done); - - std::unordered_map<string, std::unique_ptr<Graph>> subgraphs; - OP_REQUIRES_OK_ASYNC( - ctx, PartitionHelper(device_set, std::move(graph), &subgraphs), - done); // The FunctionLibraryRuntime's library cannot be mutated from within // an OpKernel, so functions are instantiated in an overlay library. @@ -124,6 +118,47 @@ class PartitionedCallOp : public AsyncOpKernel { new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition()); overlay_libs_.emplace(lib, overlay_lib); + GraphOptimizationPassOptions optimization_options; + // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or + // make it possible to specify the relevant options via attributes. + SessionOptions session_options; + session_options.env = ctx->env(); + optimization_options.session_options = &session_options; + optimization_options.graph = &graph; + optimization_options.flib_def = overlay_lib; + optimization_options.device_set = &device_set; + Placer placer(graph.get(), &device_set); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::PRE_PLACEMENT, optimization_options), + done); + OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PLACEMENT, optimization_options), + done); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, + optimization_options), + done); + + std::unordered_map<string, std::unique_ptr<Graph>> subgraphs; + OP_REQUIRES_OK_ASYNC( + ctx, PartitionHelper(device_set, std::move(graph), &subgraphs), + done); + optimization_options.graph = nullptr; + optimization_options.device_set = nullptr; + optimization_options.partition_graphs = &subgraphs; + OP_REQUIRES_OK_ASYNC(ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PARTITIONING, + optimization_options), + done); + auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>(); for (const auto& pair : subgraphs) { // TODO(akshayka): Fail gracefully if the set of devices corresponds diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc index 5863a2c8e4..7edaaad8f7 100644 --- a/tensorflow/core/kernels/regex_full_match_op.cc +++ b/tensorflow/core/kernels/regex_full_match_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
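For context on the StaticRegexFullMatch kernel added below: because the pattern arrives as a kernel attribute rather than a runtime input, the RE2 object is compiled once in the constructor and reused for every element of the input tensor. A minimal standalone sketch of the same two RE2 calls the kernel relies on (an editor's illustration, not part of this change; the pattern and test strings are made up):

#include <iostream>
#include "re2/re2.h"

int main() {
  // Compile once, mirroring the kernel constructor and its ok() check.
  RE2 re("[a-z]+@[a-z]+\\.com");
  if (!re.ok()) return 1;
  // Per-element matching, mirroring the loop in Compute().
  std::cout << RE2::FullMatch("user@example.com", re) << "\n";  // prints 1
  std::cout << RE2::FullMatch("not-an-email", re) << "\n";      // prints 0
  return 0;
}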
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU), RegexFullMatchOp); +class StaticRegexFullMatchOp : public OpKernel { + public: + explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string pattern; + OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern)); + re_ = MakeUnique<RE2>(pattern); + OP_REQUIRES(ctx, re_->ok(), + errors::InvalidArgument("Invalid pattern: ", pattern, + ", error: ", re_->error())); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat<string>(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat<bool>(); + for (size_t i = 0; i < input_flat.size(); ++i) { + output_flat(i) = RE2::FullMatch(input_flat(i), *re_); + } + } + + private: + std::unique_ptr<RE2> re_; +}; + +REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU), + StaticRegexFullMatchOp); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h index 11149c4d16..a4453bd7ab 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator.h +++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h @@ -50,10 +50,10 @@ class SparseConditionalAccumulator public: SparseConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, - const string& name) + const string& name, const string& reduction_type) : TypedConditionalAccumulatorBase< std::tuple<const Tensor*, const Tensor*, const Tensor*>>( - dtype, shape, name) { + dtype, shape, name, reduction_type) { accum_idx_vec_ = nullptr; count_element_ = nullptr; accum_val_ = nullptr; diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc index 80bc1f1934..1e542a26a7 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc @@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { Creator GetCreator() const override { return [this](ConditionalAccumulatorBase** ret) { SparseConditionalAccumulator<Device, T>* accumulator = - new SparseConditionalAccumulator<Device, T>(dtype_, shape_, - cinfo_.name()); + new SparseConditionalAccumulator<Device, T>( + dtype_, shape_, cinfo_.name(), reduction_type_); *ret = accumulator; return Status::OK(); }; diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h index 9dedb618f9..ca341e511e 100644 --- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h +++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h @@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase { public: TypedConditionalAccumulatorBase(const DataType& dtype, const PartialTensorShape& shape, - const string& name) - : ConditionalAccumulatorBase(dtype, shape, name) {} + const string& name, + const string& reduction_type) + : 
ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {} /** * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc new file mode 100644 index 0000000000..4c488066e4 --- /dev/null +++ b/tensorflow/core/lib/core/stringpiece.cc @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/core/stringpiece.h" + +#include <algorithm> +#include <iostream> + +namespace tensorflow { + +std::ostream& operator<<(std::ostream& o, StringPiece piece) { + o.write(piece.data(), piece.size()); + return o; +} + +size_t StringPiece::find(char c, size_t pos) const { + if (pos >= size_) { + return npos; + } + const char* result = + reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos)); + return result != nullptr ? result - data_ : npos; +} + +// Search range is [0..pos] inclusive. If pos == npos, search everything. +size_t StringPiece::rfind(char c, size_t pos) const { + if (size_ == 0) return npos; + for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) { + if (*p == c) { + return p - data_; + } + } + return npos; +} + +StringPiece StringPiece::substr(size_t pos, size_t n) const { + if (pos > size_) pos = size_; + if (n > size_ - pos) n = size_ - pos; + return StringPiece(data_ + pos, n); +} + +} // namespace tensorflow diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h index e7b17c9b36..02dded42c1 100644 --- a/tensorflow/core/lib/core/stringpiece.h +++ b/tensorflow/core/lib/core/stringpiece.h @@ -31,13 +31,124 @@ limitations under the License. #include <string.h> #include <iosfwd> #include <string> -#include "absl/strings/string_view.h" +#include <type_traits> #include "tensorflow/core/platform/types.h" namespace tensorflow { -// Deprecated: please use absl::string_view directly. -using StringPiece = absl::string_view; +class StringPiece { + public: + typedef size_t size_type; + + // Create an empty slice. + StringPiece() : data_(nullptr), size_(0) {} + + // Create a slice that refers to d[0,n-1]. 
+ StringPiece(const char* d, size_t n) : data_(d), size_(n) {} + + // Create a slice that refers to the contents of "s" + StringPiece(const string& s) : data_(s.data()), size_(s.size()) {} + + // Create a slice that refers to s[0,strlen(s)-1] + StringPiece(const char* s) : data_(s), size_(strlen(s)) {} + + // Return a pointer to the beginning of the referenced data + const char* data() const { return data_; } + + // Return the length (in bytes) of the referenced data + size_t size() const { return size_; } + + // Return true iff the length of the referenced data is zero + bool empty() const { return size_ == 0; } + + typedef const char* const_iterator; + typedef const char* iterator; + iterator begin() const { return data_; } + iterator end() const { return data_ + size_; } + + static const size_t npos = size_type(-1); + + // Return the ith byte in the referenced data. + // REQUIRES: n < size() + char operator[](size_t n) const { + assert(n < size()); + return data_[n]; + } + + // Drop the first "n" bytes from this slice. + void remove_prefix(size_t n) { + assert(n <= size()); + data_ += n; + size_ -= n; + } + + void remove_suffix(size_t n) { + assert(size_ >= n); + size_ -= n; + } + + size_t find(char c, size_t pos = 0) const; + size_t rfind(char c, size_t pos = npos) const; + + StringPiece substr(size_t pos, size_t n = npos) const; + + // Three-way comparison. Returns value: + // < 0 iff "*this" < "b", + // == 0 iff "*this" == "b", + // > 0 iff "*this" > "b" + int compare(StringPiece b) const; + + // Converts to various kinds of strings, including `std::basic_string`. + template <typename S> + explicit operator S() const { + static_assert( + std::is_same<char, typename S::value_type>::value, + "Type mismatch: S must be a string with character type char."); + static_assert( + std::is_same<std::char_traits<char>, typename S::traits_type>::value, + "Type mismatch: S must be a string with traits type " + "std::char_traits<char>."); + if (!data()) return {}; + return S(data(), size()); + } + + private: + const char* data_; + size_t size_; + + // Intentionally copyable +}; + +inline bool operator==(StringPiece x, StringPiece y) { + return ((x.size() == y.size()) && + (memcmp(x.data(), y.data(), x.size()) == 0)); +} + +inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); } + +inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; } +inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; } +inline bool operator<=(StringPiece x, StringPiece y) { + return x.compare(y) <= 0; +} +inline bool operator>=(StringPiece x, StringPiece y) { + return x.compare(y) >= 0; +} + +inline int StringPiece::compare(StringPiece b) const { + const size_t min_len = (size_ < b.size_) ? size_ : b.size_; + int r = memcmp(data_, b.data_, min_len); + if (r == 0) { + if (size_ < b.size_) + r = -1; + else if (size_ > b.size_) + r = +1; + } + return r; +} + +// allow StringPiece to be logged +extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece); } // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 2f6afa5487..6a2bf66d12 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -41,7 +41,7 @@ class RecordWriterOptions { // Options specific to zlib compression. 
#if !defined(IS_SLIM_BUILD) - ZlibCompressionOptions zlib_options; + tensorflow::io::ZlibCompressionOptions zlib_options; #endif // IS_SLIM_BUILD }; diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h index a620f59447..351b6f5de3 100644 --- a/tensorflow/core/lib/strings/strcat.h +++ b/tensorflow/core/lib/strings/strcat.h @@ -124,9 +124,6 @@ class AlphaNum { AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit) AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit) : piece_(str) {} - template <typename A> - AlphaNum(const std::basic_string<char, std::char_traits<char>, A> &str) - : piece_(str) {} // NOLINT(runtime/explicit) StringPiece::size_type size() const { return piece_.size(); } const char *data() const { return piece_.data(); } diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 9836f784ab..c32d6f84f5 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -13070,6 +13070,71 @@ op { is_stateful: true } op { + name: "ConditionalAccumulator" + output_arg { + name: "handle" + type: DT_STRING + is_ref: true + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } + is_stateful: true +} +op { name: "Conj" input_arg { name: "input" @@ -37080,6 +37145,54 @@ op { } } op { + name: "ParallelInterleaveDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + input_arg { + name: "num_parallel_calls" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "ParallelMapDataset" input_arg { name: "input_dataset" @@ -37161,6 +37274,53 @@ op { } } op { + name: "ParallelMapDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "num_parallel_calls" + type: DT_INT32 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: 
true + } + } +} +op { name: "ParameterizedTruncatedNormal" input_arg { name: "shape" @@ -64543,6 +64703,71 @@ op { is_stateful: true } op { + name: "SparseConditionalAccumulator" + output_arg { + name: "handle" + type: DT_STRING + is_ref: true + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } + is_stateful: true +} +op { name: "SparseCross" input_arg { name: "indices" @@ -69336,6 +69561,21 @@ op { } } op { + name: "StaticRegexFullMatch" + input_arg { + name: "input" + type: DT_STRING + } + output_arg { + name: "output" + type: DT_BOOL + } + attr { + name: "pattern" + type: "string" + } +} +op { name: "StaticRegexReplace" input_arg { name: "input" diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index eed0bce174..ffab8ad661 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator") .Attr("shape: shape") .Attr("container: string = ''") .Attr("shared_name: string = ''") + .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(2)); @@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator") .Attr("shape: shape") .Attr("container: string = ''") .Attr("shared_name: string = ''") + .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(2)); diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 1a5ad8f421..9d2b3af51d 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -210,6 +210,7 @@ REGISTER_OP("ParallelMapDataset") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("use_inter_op_parallelism: bool = true") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("MapAndBatchDataset") @@ -326,6 +327,19 @@ REGISTER_OP("ParallelInterleaveDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("ParallelInterleaveDatasetV2") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Input("num_parallel_calls: int64") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("GroupByReducerDataset") .Input("input_dataset: variant") .Input("key_func_other_arguments: Tkey_func_other_arguments") @@ -867,7 +881,7 @@ REGISTER_OP("MapDefun") .Attr("output_shapes: list(shape) >= 1") .Attr("f: func") .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector<TensorShape> 
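The reduction_type attr registered just above lets a ConditionalAccumulator (dense or sparse) hand back either the mean of the accumulated gradients, which remains the default, or their raw sum. A sketch of the intended Python-level usage, assuming tf.ConditionalAccumulator forwards the new attr under the same keyword (not shown in this diff):

```python
import tensorflow as tf

# reduction_type="SUM" makes take_grad return the summed gradients instead
# of dividing by the number of accumulated entries ("MEAN", the default).
acc = tf.ConditionalAccumulator(
    dtype=tf.float32, shape=(), reduction_type="SUM")  # kwarg name assumed

apply_ops = [acc.apply_grad(tf.constant(v)) for v in (1.0, 3.0)]
take = acc.take_grad(num_required=2)

with tf.Session() as sess:
    sess.run(apply_ops)
    print(sess.run(take))  # 4.0 under SUM; 2.0 under the MEAN default
```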
output_shapes; + std::vector<PartialTensorShape> output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); if (output_shapes.size() != c->num_outputs()) { return errors::InvalidArgument( @@ -877,6 +891,10 @@ REGISTER_OP("MapDefun") int64 dim_zero = -1; for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) { + if (c->Rank(c->input(i)) == 0) { + return errors::InvalidArgument( + "Inputs must have rank at least 1. Input ", i, " has rank of 0"); + } auto dim_handle = c->Dim(c->input(i), 0); if (c->ValueKnown(dim_handle)) { if (dim_zero == -1) { diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 11ca0bd259..5427275284 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -683,11 +683,12 @@ REGISTER_OP("NonMaxSuppression") }); REGISTER_OP("NonMaxSuppressionV2") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Output("selected_indices: int32") + .Attr("T: {half, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { // Get inputs and validate ranks. ShapeHandle boxes; @@ -711,22 +712,24 @@ REGISTER_OP("NonMaxSuppressionV2") }); REGISTER_OP("NonMaxSuppressionV3") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") + .Attr("T: {half, float} = DT_FLOAT") .SetShapeFn(NMSShapeFn); REGISTER_OP("NonMaxSuppressionV4") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") .Output("valid_outputs: int32") + .Attr("T: {half, float} = DT_FLOAT") .Attr("pad_to_max_output_size: bool = false") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(NMSShapeFn(c)); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 28b25fdeae..aeb03c5952 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -5592,6 +5592,19 @@ op { s: "" } } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } is_stateful: true } op { @@ -18199,6 +18212,54 @@ op { } } op { + name: "ParallelInterleaveDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + input_arg { + name: "num_parallel_calls" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "ParallelMapDataset" input_arg { name: "input_dataset" @@ -18237,6 +18298,13 @@ op { has_minimum: true minimum: 1 } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } } op { name: "ParameterizedTruncatedNormal" @@ -29617,6 +29685,19 @@ op { s: "" } } + attr { + name: "reduction_type" + type: "string" + default_value { + s: 
"MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } is_stateful: true } op { @@ -32115,6 +32196,21 @@ op { } } op { + name: "StaticRegexFullMatch" + input_arg { + name: "input" + type: DT_STRING + } + output_arg { + name: "output" + type: DT_BOOL + } + attr { + name: "pattern" + type: "string" + } +} +op { name: "StaticRegexReplace" input_arg { name: "input" diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 7aa1e71809..ef8b15dc8a 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -56,6 +56,12 @@ REGISTER_OP("RegexFullMatch") return Status::OK(); }); +REGISTER_OP("StaticRegexFullMatch") + .Input("input: string") + .Attr("pattern: string") + .Output("output: bool") + .SetShapeFn(shape_inference::UnchangedShape); + REGISTER_OP("StringToHashBucketFast") .Input("input: string") .Output("output: int64") diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc index ccddf1eafc..0389149469 100644 --- a/tensorflow/core/platform/default/device_tracer.cc +++ b/tensorflow/core/platform/default/device_tracer.cc @@ -321,6 +321,11 @@ class DeviceTracerImpl : public DeviceTracer, return nullptr; } + bool IsEnabled(bool is_expensive) const override { + // We don't do anything with 'Activities' so we are never 'enabled'. + return false; + } + protected: // This callback is used exclusively by CUPTIManager. friend class CUPTIManager; diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h index e5851f1dfe..9974bbbb4e 100644 --- a/tensorflow/core/platform/tracing.h +++ b/tensorflow/core/platform/tracing.h @@ -155,6 +155,10 @@ class TraceCollector { StringPiece name_part1, StringPiece name_part2, bool is_expensive) const = 0; + // Returns true if this activity handle tracking is enabled for an op of the + // given expensiveness. + virtual bool IsEnabled(bool is_expensive) const = 0; + protected: static string ConcatenateNames(StringPiece first, StringPiece second); diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 5ebd409b15..e755c37039 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -3401,56 +3401,59 @@ func BoostedTreesCenterBias(scope *Scope, tree_ensemble_handle tf.Output, mean_g return op.Output(0) } -// Computes the mean along sparse segments of a tensor. -// -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// Runs multiple additive regression ensemble predictors on input instances and // -// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first -// dimension, selecting a subset of dimension 0, specified by `indices`. +// computes the update to cached logits. It is designed to be used during training. +// It traverses the trees starting from cached tree id and cached node id and +// calculates the updates to be pushed to the cache. // // Arguments: // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting +// tree of prediction. +// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting +// node of prediction. +// bucketized_features: A list of rank 1 Tensors containing bucket id for each +// feature. +// logits_dimension: scalar, dimension of the logits, to be used for partial logits +// shape. 
// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { +// Returns Rank 2 Tensor containing logits update (with respect to cached +// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids. +func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"logits_dimension": logits_dimension} opspec := tf.OpSpec{ - Type: "SparseSegmentMean", + Type: "BoostedTreesTrainingPredict", Input: []tf.Input{ - data, indices, segment_ids, + tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features), }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Pop the element at the top of the stack. +// Serializes the tree ensemble to a proto. // // Arguments: -// handle: The handle to a stack. -// elem_type: The type of the elem that is popped. +// tree_ensemble_handle: Handle to the tree ensemble. // -// Returns The tensor that is popped from the top of the stack. -func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) { +// Returns Stamp token of the tree ensemble resource.Serialized proto of the ensemble. +func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"elem_type": elem_type} opspec := tf.OpSpec{ - Type: "StackPopV2", + Type: "BoostedTreesSerializeEnsemble", Input: []tf.Input{ - handle, + tree_ensemble_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } // Computes the sum along sparse segments of a tensor. @@ -8159,47 +8162,6 @@ func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ... return op.Output(0) } -// RandomPoissonAttr is an optional argument to RandomPoisson. -type RandomPoissonAttr func(optionalAttr) - -// RandomPoissonSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed(value int64) RandomPoissonAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomPoissonSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed2(value int64) RandomPoissonAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Use RandomPoissonV2 instead. -// -// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 -func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomPoisson", - Input: []tf.Input{ - shape, rate, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns the element-wise sum of a list of tensors. 
// // `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not @@ -8348,6 +8310,377 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or return op.Output(0) } +// Returns the truth value of (x > y) element-wise. +// +// *NOTE*: `Greater` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Greater", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. +type ResourceSparseApplyRMSPropAttr func(optionalAttr) + +// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the RMSProp algorithm. +// +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom +// +// Arguments: +// var_: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var, ms and mom. +// +// Returns the created operation. +func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyRMSProp", + Input: []tf.Input{ + var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. +type SampleDistortedBoundingBoxAttr func(optionalAttr) + +// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to non-zero, the random number +// generator is seeded by the given `seed`. Otherwise, it is seeded by a random +// seed. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. 
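The ResourceSparseApplyRMSProp doc comment above states the update rule tersely; written out in NumPy for the sparse case, where only the rows named by indices are touched (a restatement of the documented math, not TensorFlow code):

```python
import numpy as np

def sparse_rmsprop_update(var, ms, mom, lr, rho, momentum, epsilon,
                          grad, indices):
    for g, i in zip(grad, indices):
        ms[i] = rho * ms[i] + (1.0 - rho) * g * g
        mom[i] = momentum * mom[i] + lr * g / np.sqrt(ms[i] + epsilon)
        var[i] -= mom[i]

var = np.ones((3, 2))
ms, mom = np.zeros((3, 2)), np.zeros((3, 2))
sparse_rmsprop_update(var, ms, mom, lr=0.1, rho=0.9, momentum=0.0,
                      epsilon=1e-10, grad=np.ones((2, 2)), indices=[0, 2])
print(var)  # rows 0 and 2 updated; row 1 untouched
```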
+// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. +// +// value: The cropped area of the image must contain at least this +// fraction of any bounding box supplied. The value of this parameter should be +// non-negative. In the case of 0, the cropped area does not need to overlap +// any of the bounding boxes supplied. +// If not specified, defaults to 0.1 +func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["min_object_covered"] = value + } +} + +// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. +// +// value: The cropped area of the image must have an aspect ratio = +// width / height within this range. +// If not specified, defaults to <f:0.75 f:1.33 > +func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["aspect_ratio_range"] = value + } +} + +// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. +// +// value: The cropped area of the image must contain a fraction of the +// supplied image within this range. +// If not specified, defaults to <f:0.05 f:1 > +func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["area_range"] = value + } +} + +// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. +// +// value: Number of attempts at generating a cropped region of the image +// of the specified constraints. After `max_attempts` failures, return the entire +// image. +// If not specified, defaults to 100 +func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["max_attempts"] = value + } +} + +// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// +// value: Controls behavior if no bounding boxes supplied. +// If true, assume an implicit bounding box covering the whole input. If false, +// raise an error. +// If not specified, defaults to false +func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["use_image_if_no_bounding_boxes"] = value + } +} + +// Generate a single randomly distorted bounding box for an image. +// +// Bounding box annotations are often supplied in addition to ground-truth labels +// in image recognition or object localization tasks. A common technique for +// training such a system is to randomly distort an image while preserving +// its content, i.e. *data augmentation*. This Op outputs a randomly distorted +// localization of an object, i.e. bounding box, given an `image_size`, +// `bounding_boxes` and a series of constraints. +// +// The output of this Op is a single bounding box that may be used to crop the +// original image. The output is returned as 3 tensors: `begin`, `size` and +// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the +// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize +// what the bounding box looks like. 
+// +// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. +// +// For example, +// +// ```python +// # Generate a single distorted bounding box. +// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( +// tf.shape(image), +// bounding_boxes=bounding_boxes) +// +// # Draw the bounding box in an image summary. +// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), +// bbox_for_draw) +// tf.summary.image('images_with_box', image_with_box) +// +// # Employ the bounding box to distort the image. +// distorted_image = tf.slice(image, begin, size) +// ``` +// +// Note that if no bounding box information is available, setting +// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit +// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is +// false and no bounding boxes are supplied, an error is raised. +// +// Arguments: +// image_size: 1-D, containing `[height, width, channels]`. +// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes +// associated with the image. +// +// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to +// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to +// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. +// Provide as input to `tf.image.draw_bounding_boxes`. +func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SampleDistortedBoundingBox", + Input: []tf.Input{ + image_size, bounding_boxes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes sigmoid of `x` element-wise. +// +// Specifically, `y = 1 / (1 + exp(-x))`. +func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sigmoid", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FusedBatchNormAttr is an optional argument to FusedBatchNorm. +type FusedBatchNormAttr func(optionalAttr) + +// FusedBatchNormEpsilon sets the optional epsilon attribute to value. +// +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { + return func(m optionalAttr) { + m["epsilon"] = value + } +} + +// FusedBatchNormDataFormat sets the optional data_format attribute to value. +// +// value: The data format for x and y. Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// FusedBatchNormIsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { + return func(m optionalAttr) { + m["is_training"] = value + } +} + +// Batch normalization. 
+// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// +// Arguments: +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// offset: A 1D Tensor for offset, to shift to the normalized x. +// mean: A 1D Tensor for population mean. Used for inference only; +// must be empty for training. +// variance: A 1D Tensor for population variance. Used for inference only; +// must be empty for training. +// +// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow +// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by +// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused +// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance +// in the cuDNN case), to be reused in the gradient computation. +func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FusedBatchNorm", + Input: []tf.Input{ + x, scale, offset, mean, variance, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + +// RandomStandardNormalAttr is an optional argument to RandomStandardNormal. +type RandomStandardNormalAttr func(optionalAttr) + +// RandomStandardNormalSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomStandardNormalSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a normal distribution. +// +// The generated values will have mean 0 and standard deviation 1. +// +// Arguments: +// shape: The shape of the output tensor. +// dtype: The type of the output. +// +// Returns A tensor of the specified shape filled with random normal values. +func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomStandardNormal", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. type ResourceApplyFtrlAttr func(optionalAttr) @@ -12357,235 +12690,6 @@ func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf. 
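FusedBatchNorm's five results can be hard to keep straight from the comment above. In training mode the statistics are computed over every axis except channels; a NumPy restatement for NHWC input (the cuDNN path differs only in returning an inverted variance in the last slot):

```python
import numpy as np

def fused_batch_norm_train(x, scale, offset, epsilon=1e-4):
    # x: [N, H, W, C]; reduce over N, H, W for per-channel statistics.
    batch_mean = x.mean(axis=(0, 1, 2))
    batch_var = x.var(axis=(0, 1, 2))
    y = scale * (x - batch_mean) / np.sqrt(batch_var + epsilon) + offset
    # The statistics are returned twice: once for the running-average
    # update and once for reuse in the gradient computation.
    return y, batch_mean, batch_var, batch_mean, batch_var

x = np.random.randn(2, 4, 4, 3).astype(np.float32)
y, *_ = fused_batch_norm_train(x, scale=np.ones(3), offset=np.zeros(3))
```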
return values } -// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. -type ResourceSparseApplyRMSPropAttr func(optionalAttr) - -// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the RMSProp algorithm. -// -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms -// and mom will not update in iterations during which the grad is zero. -// -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) -// -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom -// -// Arguments: -// var_: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var, ms and mom. -// -// Returns the created operation. -func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyRMSProp", - Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Returns the truth value of (x > y) element-wise. -// -// *NOTE*: `Greater` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Greater", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. -type SampleDistortedBoundingBoxAttr func(optionalAttr) - -// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to non-zero, the random number -// generator is seeded by the given `seed`. Otherwise, it is seeded by a random -// seed. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. 
-// If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. -// -// value: The cropped area of the image must contain at least this -// fraction of any bounding box supplied. The value of this parameter should be -// non-negative. In the case of 0, the cropped area does not need to overlap -// any of the bounding boxes supplied. -// If not specified, defaults to 0.1 -func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["min_object_covered"] = value - } -} - -// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. -// -// value: The cropped area of the image must have an aspect ratio = -// width / height within this range. -// If not specified, defaults to <f:0.75 f:1.33 > -func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["aspect_ratio_range"] = value - } -} - -// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. -// -// value: The cropped area of the image must contain a fraction of the -// supplied image within this range. -// If not specified, defaults to <f:0.05 f:1 > -func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["area_range"] = value - } -} - -// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. -// -// value: Number of attempts at generating a cropped region of the image -// of the specified constraints. After `max_attempts` failures, return the entire -// image. -// If not specified, defaults to 100 -func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["max_attempts"] = value - } -} - -// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. -// -// value: Controls behavior if no bounding boxes supplied. -// If true, assume an implicit bounding box covering the whole input. If false, -// raise an error. -// If not specified, defaults to false -func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["use_image_if_no_bounding_boxes"] = value - } -} - -// Generate a single randomly distorted bounding box for an image. -// -// Bounding box annotations are often supplied in addition to ground-truth labels -// in image recognition or object localization tasks. A common technique for -// training such a system is to randomly distort an image while preserving -// its content, i.e. *data augmentation*. This Op outputs a randomly distorted -// localization of an object, i.e. bounding box, given an `image_size`, -// `bounding_boxes` and a series of constraints. -// -// The output of this Op is a single bounding box that may be used to crop the -// original image. The output is returned as 3 tensors: `begin`, `size` and -// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the -// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize -// what the bounding box looks like. 
-// -// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. -// -// For example, -// -// ```python -// # Generate a single distorted bounding box. -// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( -// tf.shape(image), -// bounding_boxes=bounding_boxes) -// -// # Draw the bounding box in an image summary. -// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), -// bbox_for_draw) -// tf.summary.image('images_with_box', image_with_box) -// -// # Employ the bounding box to distort the image. -// distorted_image = tf.slice(image, begin, size) -// ``` -// -// Note that if no bounding box information is available, setting -// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit -// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is -// false and no bounding boxes are supplied, an error is raised. -// -// Arguments: -// image_size: 1-D, containing `[height, width, channels]`. -// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes -// associated with the image. -// -// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to -// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to -// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. -// Provide as input to `tf.image.draw_bounding_boxes`. -func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SampleDistortedBoundingBox", - Input: []tf.Input{ - image_size, bounding_boxes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // LRNAttr is an optional argument to LRN. type LRNAttr func(optionalAttr) @@ -14396,6 +14500,47 @@ func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } +// RandomPoissonAttr is an optional argument to RandomPoisson. +type RandomPoissonAttr func(optionalAttr) + +// RandomPoissonSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed(value int64) RandomPoissonAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomPoissonSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed2(value int64) RandomPoissonAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Use RandomPoissonV2 instead. +// +// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 +func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomPoisson", + Input: []tf.Input{ + shape, rate, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler. 
type LogUniformCandidateSamplerAttr func(optionalAttr) @@ -16136,148 +16281,6 @@ func ResourceScatterMul(scope *Scope, resource tf.Output, indices tf.Output, upd return scope.AddOperation(opspec) } -// Computes sigmoid of `x` element-wise. -// -// Specifically, `y = 1 / (1 + exp(-x))`. -func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sigmoid", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FusedBatchNormAttr is an optional argument to FusedBatchNorm. -type FusedBatchNormAttr func(optionalAttr) - -// FusedBatchNormEpsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormDataFormat sets the optional data_format attribute to value. -// -// value: The data format for x and y. Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormIsTraining sets the optional is_training attribute to value. -// -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Batch normalization. -// -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. -// -// Arguments: -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// offset: A 1D Tensor for offset, to shift to the normalized x. -// mean: A 1D Tensor for population mean. Used for inference only; -// must be empty for training. -// variance: A 1D Tensor for population variance. Used for inference only; -// must be empty for training. -// -// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow -// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by -// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused -// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance -// in the cuDNN case), to be reused in the gradient computation. -func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedBatchNorm", - Input: []tf.Input{ - x, scale, offset, mean, variance, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// RandomStandardNormalAttr is an optional argument to RandomStandardNormal. -type RandomStandardNormalAttr func(optionalAttr) - -// RandomStandardNormalSeed sets the optional seed attribute to value. 
-// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomStandardNormalSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from a normal distribution. -// -// The generated values will have mean 0 and standard deviation 1. -// -// Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. -// -// Returns A tensor of the specified shape filled with random normal values. -func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomStandardNormal", - Input: []tf.Input{ - shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Component-wise divides a SparseTensor by a dense Tensor. // // *Limitation*: this Op only broadcasts the dense side to the sparse side, but not @@ -17427,26 +17430,6 @@ func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (i return op.Output(0) } -// Serializes the tree ensemble to a proto. -// -// Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. -// -// Returns Stamp token of the tree ensemble resource.Serialized proto of the ensemble. -func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BoostedTreesSerializeEnsemble", - Input: []tf.Input{ - tree_ensemble_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - // StageSizeAttr is an optional argument to StageSize. type StageSizeAttr func(optionalAttr) @@ -20376,6 +20359,58 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf return op.Output(0) } +// Computes the mean along sparse segments of a tensor. +// +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. +// +// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first +// dimension, selecting a subset of dimension 0, specified by `indices`. +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSegmentMean", + Input: []tf.Input{ + data, indices, segment_ids, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Pop the element at the top of the stack. +// +// Arguments: +// handle: The handle to a stack. +// elem_type: The type of the elem that is popped. 
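SparseSegmentMean above is SegmentMean composed with a gather: rows of data are selected by indices, then averaged per segment id. An equivalent NumPy reference (assumes segment_ids are sorted and every segment is non-empty, as the op requires):

```python
import numpy as np

def sparse_segment_mean(data, indices, segment_ids):
    k = segment_ids[-1] + 1  # number of segments
    out = np.zeros((k,) + data.shape[1:])
    counts = np.zeros((k, 1))
    for idx, seg in zip(indices, segment_ids):
        out[seg] += data[idx]
        counts[seg] += 1
    return out / counts

data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(sparse_segment_mean(data, indices=[0, 2], segment_ids=[0, 0]))
# [[3. 4.]] -- the mean of rows 0 and 2
```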
+// +// Returns The tensor that is popped from the top of the stack. +func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"elem_type": elem_type} + opspec := tf.OpSpec{ + Type: "StackPopV2", + Input: []tf.Input{ + handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes hyperbolic cosine of x element-wise. func Cosh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -31743,54 +31778,6 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true return op.Output(0), op.Output(1), op.Output(2) } -// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. -type WholeFileReaderV2Attr func(optionalAttr) - -// WholeFileReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A Reader that outputs the entire contents of a file as a value. -// -// To use, enqueue filenames in a Queue. The output of ReaderRead will -// be a filename (key) and the contents of that file (value). -// -// Returns The handle to reference the Reader. -func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "WholeFileReaderV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Transforms a tf.Example proto (as a string) into typed tensors. // // Arguments: @@ -31861,60 +31848,73 @@ func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf. return sparse_indices, sparse_values, sparse_shapes, dense_values } -// Deserializes a serialized tree ensemble config and replaces current tree +// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. +type WholeFileReaderV2Attr func(optionalAttr) + +// WholeFileReaderV2Container sets the optional container attribute to value. // -// ensemble. +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. // -// Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. -// stamp_token: Token to use as the new value of the resource stamp. -// tree_ensemble_serialized: Serialized proto of the ensemble. +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. 
+// If not specified, defaults to "" +func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A Reader that outputs the entire contents of a file as a value. // -// Returns the created operation. -func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { +// To use, enqueue filenames in a Queue. The output of ReaderRead will +// be a filename (key) and the contents of that file (value). +// +// Returns The handle to reference the Reader. +func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BoostedTreesDeserializeEnsemble", - Input: []tf.Input{ - tree_ensemble_handle, stamp_token, tree_ensemble_serialized, - }, + Type: "WholeFileReaderV2", + + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Runs multiple additive regression ensemble predictors on input instances and +// Deserializes a serialized tree ensemble config and replaces current tree // -// computes the update to cached logits. It is designed to be used during training. -// It traverses the trees starting from cached tree id and cached node id and -// calculates the updates to be pushed to the cache. +// ensemble. // // Arguments: +// tree_ensemble_handle: Handle to the tree ensemble. +// stamp_token: Token to use as the new value of the resource stamp. +// tree_ensemble_serialized: Serialized proto of the ensemble. // -// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting -// tree of prediction. -// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting -// node of prediction. -// bucketized_features: A list of rank 1 Tensors containing bucket id for each -// feature. -// logits_dimension: scalar, dimension of the logits, to be used for partial logits -// shape. -// -// Returns Rank 2 Tensor containing logits update (with respect to cached -// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids. -func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) { +// Returns the created operation. 
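The serialize/deserialize pair above round-trips an ensemble through a (stamp token, serialized proto) pair. Assuming the auto-generated Python bindings for these ops (the `gen_boosted_trees_ops` module path and the pre-existing `ensemble_handle` are assumptions inferred from the op definitions, not shown in this patch), a round trip would look roughly like:

```python
# Hypothetical sketch: `ensemble_handle` is an already-created
# BoostedTrees ensemble resource handle.
from tensorflow.python.ops import gen_boosted_trees_ops as bt_ops

stamp, blob = bt_ops.boosted_trees_serialize_ensemble(ensemble_handle)
restore_op = bt_ops.boosted_trees_deserialize_ensemble(
    ensemble_handle, stamp_token=stamp, tree_ensemble_serialized=blob)
```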
+func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"logits_dimension": logits_dimension} opspec := tf.OpSpec{ - Type: "BoostedTreesTrainingPredict", + Type: "BoostedTreesDeserializeEnsemble", Input: []tf.Input{ - tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features), + tree_ensemble_handle, stamp_token, tree_ensemble_serialized, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index ba9c6a2320..19729813a1 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -78,6 +78,7 @@ py_library( "//tensorflow:__pkg__", "//tensorflow/python/tools:__pkg__", "//tensorflow/python/tools/api/generator:__pkg__", + "//tensorflow/tools/api/tests:__pkg__", ], deps = [ ":array_ops", diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 586f4c6936..7a3fc27592 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 5) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 7) @tf_export("compat.forward_compatible") diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 23c98247bf..631b87a718 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -137,6 +137,8 @@ tf_py_test( size = "small", srcs = ["interleave_dataset_op_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -154,6 +156,7 @@ tf_py_test( size = "small", srcs = ["map_dataset_op_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py index 7dbf7268d7..a35cee594a 100644 --- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py @@ -19,8 +19,10 @@ from __future__ import print_function import itertools +from absl.testing import parameterized +import numpy as np + from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops @@ -28,7 +30,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetTest(test.TestCase): +class InterleaveDatasetTest(test.TestCase, parameterized.TestCase): def _interleave(self, lists, cycle_length, block_length): num_open = 0 @@ -97,84 +99,85 @@ class InterleaveDatasetTest(test.TestCase): expected_elements, self._interleave(input_lists, 7, 2)): self.assertEqual(expected, produced) - def testInterleaveDataset(self): - input_values = 
array_ops.placeholder(dtypes.int64, shape=[None]) - cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) - block_length = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_count = 2 - - dataset = ( - dataset_ops.Dataset.from_tensor_slices(input_values) - .repeat(repeat_count) - .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - next_element = iterator.get_next() + @parameterized.named_parameters( + ("1", np.int64([4, 5, 6]), 1, 3, None), + ("2", np.int64([4, 5, 6]), 1, 3, 1), + ("3", np.int64([4, 5, 6]), 2, 1, None), + ("4", np.int64([4, 5, 6]), 2, 1, 1), + ("5", np.int64([4, 5, 6]), 2, 1, 2), + ("6", np.int64([4, 5, 6]), 2, 3, None), + ("7", np.int64([4, 5, 6]), 2, 3, 1), + ("8", np.int64([4, 5, 6]), 2, 3, 2), + ("9", np.int64([4, 5, 6]), 7, 2, None), + ("10", np.int64([4, 5, 6]), 7, 2, 1), + ("11", np.int64([4, 5, 6]), 7, 2, 3), + ("12", np.int64([4, 5, 6]), 7, 2, 5), + ("13", np.int64([4, 5, 6]), 7, 2, 7), + ("14", np.int64([]), 2, 3, None), + ("15", np.int64([0, 0, 0]), 2, 3, None), + ("16", np.int64([4, 0, 6]), 2, 3, None), + ("17", np.int64([4, 0, 6]), 2, 3, 1), + ("18", np.int64([4, 0, 6]), 2, 3, 2), + ) + def testInterleaveDataset(self, input_values, cycle_length, block_length, + num_parallel_calls): + count = 2 + dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length, num_parallel_calls) + get_next = dataset.make_one_shot_iterator().get_next() + + def repeat(values, count): + result = [] + for value in values: + result.append([value] * value) + return result * count with self.test_session() as sess: - # Cycle length 1 acts like `Dataset.flat_map()`. - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 1, block_length: 3}) - - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): - self.assertEqual(expected_element, sess.run(next_element)) - - # Cycle length > 1. - # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, - # 6, 5, 6, 5, 6, 5, 6, 5] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 1}) for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > 1 and block length > 1. - # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, - # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > len(input_values) * repeat_count. - # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, - # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 7, block_length: 2}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Empty input. 
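The inline comments in the removed test spell out the interleave ordering for each (cycle_length, block_length) pair; the parameterized replacement checks the same ordering through `_interleave`. Reproducing the cycle_length=2, block_length=1 case with the public TF 1.x API (a sketch; the expected prefix comes from the removed test's comment):

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([4, 5, 6]).repeat(2).interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(x),
    cycle_length=2, block_length=1)
nxt = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    # Matches the removed comment's expectation: 4, 5, 4, 5, ...
    print([sess.run(nxt) for _ in range(4)])  # [4, 5, 4, 5]
```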
- sess.run(init_op, feed_dict={input_values: [], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + repeat(input_values, count), cycle_length, block_length): + self.assertEqual(expected_element, sess.run(get_next)) + + for _ in range(2): + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @parameterized.named_parameters( + ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None), + ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1), + ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None), + ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1), + ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2), + ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None), + ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1), + ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2), + ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None), + ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1), + ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3), + ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5), + ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7), + ) + def testInterleaveErrorDataset(self, + input_values, + cycle_length, + block_length, + num_parallel_calls): + dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map( + lambda x: array_ops.check_numerics(x, "message")).interleave( + dataset_ops.Dataset.from_tensors, cycle_length, block_length, + num_parallel_calls) + get_next = dataset.make_one_shot_iterator().get_next() - # Non-empty input leading to empty output. - sess.run(init_op, feed_dict={input_values: [0, 0, 0], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Mixture of non-empty and empty interleaved datasets. 
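The new `testInterleaveErrorDataset` above leans on `check_numerics` to turn NaN inputs into `InvalidArgumentError`s that must surface through the (possibly parallel) interleave. The failure mode in isolation (TF 1.x sketch):

```python
import numpy as np
import tensorflow as tf

checked = tf.check_numerics(tf.constant(np.float32(np.nan)), "message")

with tf.Session() as sess:
    try:
        sess.run(checked)
    except tf.errors.InvalidArgumentError as e:
        print("NaN rejected:", e.message)
```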
- sess.run(init_op, feed_dict={input_values: [4, 0, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) + with self.test_session() as sess: + for value in input_values: + if np.isnan(value): + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + else: + self.assertEqual(value, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + sess.run(get_next) def testSparse(self): @@ -201,20 +204,6 @@ class InterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testEmptyInput(self): - iterator = ( - dataset_ops.Dataset.from_tensor_slices([]) - .repeat(None) - .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index df2c9b170a..fde785be6e 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -22,6 +22,7 @@ import threading import time import warnings +from absl.testing import parameterized import numpy as np from tensorflow.core.framework import attr_value_pb2 @@ -46,7 +47,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -class MapDatasetTest(test.TestCase): +class MapDatasetTest(test.TestCase, parameterized.TestCase): def _buildMapDataset(self, components, count): def _map_fn(x, y, z): @@ -705,6 +706,35 @@ class MapDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"): sess.run(iterator.initializer) +# pylint: disable=g-long-lambda + @parameterized.named_parameters( + ("Map", lambda dataset, func: + dataset_ops.MapDataset(dataset, func, use_inter_op_parallelism=False)), + ("ParallelMap", lambda dataset, func: + dataset_ops.ParallelMapDataset(dataset, func, num_parallel_calls=1, + use_inter_op_parallelism=False)), + ) + def testNoInterOpParallelism(self, make_dataset_fn): + dataset = dataset_ops.Dataset.from_tensors(0) + + def _get_tid(): + return np.int64(threading.current_thread().ident) + + def _map_fn(_): + tids = [] + for _ in range(10): + tids.append(script_ops.py_func(_get_tid, [], dtypes.int64)) + return tids + + dataset = make_dataset_fn(dataset, _map_fn) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + tids = sess.run(get_next) + self.assertTrue(all(tids[0] == tid for tid in tids)) +# pylint: enable=g-long-lambda + class MapDatasetBenchmark(test.Benchmark): diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 6205ee392e..c985e00dd1 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1019,7 +1019,11 @@ class Dataset(object): """ return FlatMapDataset(self, map_func) - def interleave(self, map_func, cycle_length, block_length=1): + def interleave(self, + map_func, + cycle_length, + block_length=1, + num_parallel_calls=None): """Maps `map_func` across this dataset, and 
interleaves the results. For example, you can use `Dataset.interleave()` to process many input files @@ -1082,11 +1086,19 @@ class Dataset(object): processed concurrently. block_length: The number of consecutive elements to produce from each input element before cycling to another input element. + num_parallel_calls: (Optional.) If specified, the implementation creates + a threadpool, which is used to fetch inputs from cycle elements + asynchronously and in parallel. The default behavior is to fetch inputs + from cycle elements synchronously with no parallelism. Returns: Dataset: A `Dataset`. """ - return InterleaveDataset(self, map_func, cycle_length, block_length) + if num_parallel_calls is None: + return InterleaveDataset(self, map_func, cycle_length, block_length) + else: + return ParallelInterleaveDataset(self, map_func, cycle_length, + block_length, num_parallel_calls) def filter(self, predicate): """Filters this dataset according to `predicate`. @@ -2245,9 +2257,14 @@ class MapDataset(Dataset): class ParallelMapDataset(MapDataset): """A `Dataset` that maps a function over elements in its input in parallel.""" - def __init__(self, input_dataset, map_func, num_parallel_calls): + def __init__(self, + input_dataset, + map_func, + num_parallel_calls, + use_inter_op_parallelism=True): """See `Dataset.map()` for details.""" - super(ParallelMapDataset, self).__init__(input_dataset, map_func) + super(ParallelMapDataset, self).__init__(input_dataset, map_func, + use_inter_op_parallelism) self._num_parallel_calls = ops.convert_to_tensor( num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls") @@ -2260,6 +2277,7 @@ class ParallelMapDataset(MapDataset): self._map_func.captured_inputs, f=self._map_func, num_parallel_calls=self._num_parallel_calls, + use_inter_op_parallelism=self._use_inter_op_parallelism, **flat_structure(self)) # pylint: enable=protected-access @@ -2330,6 +2348,36 @@ class InterleaveDataset(FlatMapDataset): return "Dataset.interleave()" +class ParallelInterleaveDataset(FlatMapDataset): + """A `Dataset` that maps a function over its input and interleaves the result. 
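The `num_parallel_calls` argument documented above adds a threadpool for fetching cycle inputs without changing the output order (the tests above assert the same element sequence as the sequential path for every parallelism level). A typical use is interleaving sharded record files (TF 1.x sketch; the shard names are placeholders):

```python
import tensorflow as tf

files = tf.data.Dataset.from_tensor_slices(
    ["shard-0.tfrecord", "shard-1.tfrecord", "shard-2.tfrecord"])
ds = files.interleave(
    tf.data.TFRecordDataset, cycle_length=2, block_length=4,
    num_parallel_calls=2)  # inputs fetched in parallel, order preserved
```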
+ + """ + + def __init__(self, input_dataset, map_func, cycle_length, block_length, + num_parallel_calls): + """See `Dataset.interleave()` for details.""" + super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func) + self._cycle_length = ops.convert_to_tensor( + cycle_length, dtype=dtypes.int64, name="cycle_length") + self._block_length = ops.convert_to_tensor( + block_length, dtype=dtypes.int64, name="block_length") + self._num_parallel_calls = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") + + def _as_variant_tensor(self): + return gen_dataset_ops.parallel_interleave_dataset_v2( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._map_func.captured_inputs, # pylint: disable=protected-access + self._cycle_length, + self._block_length, + self._num_parallel_calls, + f=self._map_func, # pylint: disable=protected-access + **flat_structure(self)) + + def _transformation_name(self): + return "Dataset.interleave()" + + class FilterDataset(Dataset): """A `Dataset` that filters its input according to a predicate function.""" diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index 9d621fcd30..3a5d1f0adf 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -96,37 +96,12 @@ def _yield_value(iterable): yield value -def is_sequence(seq): - """Returns a true if `seq` is a Sequence or dict (except strings/lists). +# See the swig file (../../util/util.i) for documentation. +is_sequence = _pywrap_tensorflow.IsSequenceForData - NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`, - which *does* treat a Python list as a sequence. For ergonomic - reasons, `tf.data` users would prefer to treat lists as - implicit `tf.Tensor` objects, and dicts as (nested) sequences. - Args: - seq: an input sequence. - - Returns: - True if the sequence is a not a string or list and is a - collections.Sequence. - """ - return _pywrap_tensorflow.IsSequenceForData(seq) - - -def flatten(nest): - """Returns a flat sequence from a given nested structure. - - If `nest` is not a sequence, this returns a single-element list: `[nest]`. - - Args: - nest: an arbitrarily nested structure or a scalar object. - Note, numpy arrays are considered scalars. - - Returns: - A Python list, the flattened version of the input. - """ - return _pywrap_tensorflow.FlattenForData(nest) +# See the swig file (../../util/util.i) for documentation. 
+flatten = _pywrap_tensorflow.FlattenForData def assert_same_structure(nest1, nest2, check_types=True): diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 9891068056..be392c7a0f 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -216,9 +216,7 @@ def implicit_val_and_grad(f): "function was being computed.") sources = [v.handle for v in variables] - grad = imperative_grad.imperative_grad(_default_vspace, - this_tape, - nest.flatten(end_node), + grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node), sources) return end_node, list(zip(grad, variables)) @@ -537,8 +535,8 @@ def make_vjp(f, params=None, persistent=True): if dy is not None: dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)] return imperative_grad.imperative_grad( - _default_vspace, this_tape, nest.flatten(result), sources, - output_gradients=dy) + this_tape, nest.flatten(result), sources, output_gradients=dy) + return result, vjp return decorated @@ -631,9 +629,9 @@ def _ones(shape, dtype): _default_vspace = imperative_grad.VSpace( num_elements_fn=_num_elements, aggregate_fn=_aggregate_grads, - tensor_id=ops.tensor_id, zeros=_zeros, ones=_ones) +pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace) def _handle_or_self(x): @@ -695,19 +693,57 @@ class GradientTape(object): del g # Drop the reference to the tape ``` + By default GradientTape will automatically watch any trainable variables that + are accessed inside the context. If you want fine grained control over which + variables are watched you can disable automatic tracking by passing + `watch_accessed_variables=False` to the tape constructor: + + ```python + with tf.GradientTape(watch_accessed_variables=False) as tape: + tape.watch(variable_a) + y = variable_a ** 2 # Gradients will be available for `variable_a`. + z = variable_b ** 3 # No gradients will be avaialble since `variable_b` is + # not being watched. + ``` + + Note that when using models you should ensure that your variables exist when + using `watch_accessed_variables=False`. Otherwise it's quite easy to make your + first iteration not have any gradients: + + ```python + a = tf.keras.layers.Dense(32) + b = tf.keras.layers.Dense(32) + + with tf.GradientTape(watch_accessed_variables=False) as tape: + tape.watch(a.variables) # Since `a.build` has not been called at this point + # `a.variables` will return an empty list and the + # tape will not be watching anything. + result = b(a(inputs)) + tape.gradient(result, a.variables) # The result of this computation will be + # a list of `None`s since a's variables + # are not being watched. + ``` + Note that only tensors with real or complex dtypes are differentiable. """ - def __init__(self, persistent=False): + def __init__(self, persistent=False, watch_accessed_variables=True): """Creates a new GradientTape. Args: persistent: Boolean controlling whether a persistent gradient tape is created. False by default, which means at most one call can be made to the gradient() method on this object. + watch_accessed_variables: Boolean controlling whether the tape will + automatically `watch` any (trainable) variables accessed while the tape + is active. Defaults to True meaning gradients can be requested from any + result computed in the tape derived from reading a trainable `Variable`. + If False users must explicitly `watch` any `Variable`s they want to + request gradients from. 
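The `watch_accessed_variables=False` behavior documented above is exercised by the `testSelectivelyWatchVariables` case added further below; a minimal eager-mode sketch of the same idea (TF 1.x eager with `tf.contrib.eager.Variable` assumed):

```python
import tensorflow as tf
tf.enable_eager_execution()

v1 = tf.contrib.eager.Variable(1.0)
v2 = tf.contrib.eager.Variable(1.0)

with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(v2)   # only v2 is traced
    y = v1 ** 2
    z = v2 ** 3

dy, dz = tape.gradient([y, z], [v1, v2])
print(dy)  # None: v1 was never watched
print(dz)  # 3.0: d(v2**3)/dv2 at v2 == 1
```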
""" self._tape = None self._persistent = persistent + self._watch_accessed_variables = watch_accessed_variables self._recording = False context.context().start_step() @@ -721,15 +757,15 @@ class GradientTape(object): if self._recording: self._pop_tape() - def _push_tape(self, existing_tape=False): + def _push_tape(self): if self._recording: raise ValueError("Tape is already recording.") - if existing_tape: - if self._tape is None: - raise ValueError("There is no existing tape.") - tape.push_tape(self._tape) + if self._tape is None: + self._tape = tape.push_new_tape( + persistent=self._persistent, + watch_accessed_variables=self._watch_accessed_variables) else: - self._tape = tape.push_new_tape(persistent=self._persistent) + tape.push_tape(self._tape) self._recording = True def _pop_tape(self): @@ -748,7 +784,13 @@ class GradientTape(object): tensor: a Tensor or list of Tensors. """ for t in nest.flatten(tensor): - tape.watch(self._tape, _handle_or_self(t)) + if hasattr(t, "handle"): + # There are many variable-like objects, all of them currently have + # `handle` attribute that points to a tensor. If this changes, internals + # of watch_variable need to change as well. + tape.watch_variable(self._tape, t) + else: + tape.watch(self._tape, t) @tf_contextlib.contextmanager def stop_recording(self): @@ -780,7 +822,7 @@ class GradientTape(object): try: yield finally: - self._push_tape(existing_tape=True) + self._push_tape() def reset(self): """Clears all information stored in this tape. @@ -814,6 +856,7 @@ class GradientTape(object): ``` """ self._pop_tape() + self._tape = None self._push_tape() def watched_variables(self): @@ -865,7 +908,9 @@ class GradientTape(object): for x in nest.flatten(output_gradients)] flat_grad = imperative_grad.imperative_grad( - _default_vspace, self._tape, nest.flatten(target), flat_sources, + self._tape, + nest.flatten(target), + flat_sources, output_gradients=output_gradients) if not self._persistent: diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 6673178ee7..f938ed5df8 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -474,6 +474,18 @@ class BackpropTest(test.TestCase): self.assertEqual(backprop.implicit_grad(f)()[0][0], None) @test_util.assert_no_new_tensors + def testGradientTapeReEnterContext(self): + g = backprop.GradientTape() + with g: + x = constant_op.constant(3.0) + g.watch(x) + y = 2*x + with g: + z = 2*y + grad = g.gradient(target=z, sources=[x]) + self.assertEqual(self.evaluate(grad), [4.0]) + + @test_util.assert_no_new_tensors @test_util.run_in_graph_and_eager_modes def testGradientTapeRepeatedSource(self): with backprop.GradientTape(persistent=False) as g: @@ -956,6 +968,60 @@ class BackpropTest(test.TestCase): self.assertAllEqual(grad1, grad2) + @test_util.run_in_graph_and_eager_modes + def testSelectivelyWatchVariables(self): + x1 = resource_variable_ops.ResourceVariable(1.0) + x2 = resource_variable_ops.ResourceVariable(1.0) + with backprop.GradientTape(watch_accessed_variables=False) as tape: + tape.watch(x2) + y = x1**2 + z = x2**3 + self.assertTupleEqual(tape.watched_variables(), (x2,)) + dy, dz = tape.gradient([y, z], [x1, x2]) + self.evaluate([x1.initializer, x2.initializer]) + self.assertIsNone(dy) + self.assertEqual(self.evaluate(dz), 3.0) + + + @test_util.run_in_graph_and_eager_modes + def testDifferentiatingScalarCache(self): + # In the following test, if x2 = x1 (i.e the objects are the exact same), + # then y is essentially, 
2*x1, and dy/dx1 = 2. + # When we had a pure scalar cache in eager, this would be the case. This + # test prevents us from going back to that case. + with backprop.GradientTape(persistent=False) as g: + x1 = constant_op.constant(3.0) + x2 = constant_op.constant(3.0) + g.watch(x1) + g.watch(x2) + y = x1 + x2 + grad = g.gradient(target=y, sources=[x1]) + self.assertEqual(self.evaluate(grad), [1.0]) + + def testVariablesAndConstantsProduceTheSameGradients(self): + + # In the following test, differentiating [y, z] against [a, b] gives: + # (dy/da + dz/da, dy/db + dz/db). + # If a and b are the same constant, dz/da will not be 0 (which it should + # be). + # This is solved by using variable since doing a read_value on a tensor will + # produce a new tensor and corresponding TensorHandle, and not reuse the + # same tensor (which would happen if we are using a cache and reusing + # EagerTensor objects). + def get_grads(a, b): + with backprop.GradientTape() as tape: + tape.watch([a, b]) + y = a**3 + z = b**2 + return tape.gradient([y, z], [a, b]) + + gradients_constants = get_grads( + constant_op.constant(2.0), constant_op.constant(2.0)) + gradients_variables = get_grads( + resource_variable_ops.ResourceVariable(2.0), + resource_variable_ops.ResourceVariable(2.0)) + self.assertAllEqual(gradients_constants, gradients_variables) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index a2e8422671..3bdaf0b214 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -175,6 +175,11 @@ class MicroBenchmarks(test.Benchmark): self._run(func, 30000) + def benchmark_create_constant(self): + func = lambda: constant_op.constant(3.0) + + self._run(func, 30000) + def benchmark_create_float_tensor_from_list_CPU(self): self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index d56c1457e0..03f12139f6 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -519,7 +519,7 @@ class Function(object): for v in self._func_graph.variables: if v.trainable: - tape.watch_variable(v) + tape.variable_accessed(v) captures = self._resolve_captured_inputs() tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 3c79099d87..37a9957cea 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -27,7 +27,6 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function -from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -616,7 +615,6 @@ class FunctionTest(test.TestCase): @function.defun def g(x): - tape.watch_variable(x) y = math_ops.add(x, three) f(y) @@ -630,7 +628,6 @@ class FunctionTest(test.TestCase): return math_ops.add(x, three) def g(x): - tape.watch_variable(three) return f(x) g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0] @@ -1427,14 +1424,14 @@ class FunctionTest(test.TestCase): grad_t, = backprop.gradients_function(sq, [0])(t) self.assertAllEqual(grad_t, [[6, 6], [14, 14]]) - with 
backprop.GradientTape(persistent=True) as gtape: - gtape.watch(t) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(t) one = matmul(t, b=t, transpose_a=True) two = matmul(b=t, a=t, transpose_a=True) three = matmul(a=t, b=t, transpose_a=True) for output in [one, two, three]: - self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]]) + self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]]) def testGradientInFunctionWithKeywordArguments(self): diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py index 000152855d..5f027d107c 100644 --- a/tensorflow/python/eager/imperative_grad.py +++ b/tensorflow/python/eager/imperative_grad.py @@ -24,12 +24,10 @@ from tensorflow.python import pywrap_tensorflow VSpace = collections.namedtuple( - "VSpace", - ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"]) + "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"]) def imperative_grad( - vspace, tape, target, sources, @@ -41,7 +39,6 @@ def imperative_grad( gradients for all sources. Args: - vspace: the vector space in which to differentiate. tape: the gradient tape which stores the trace. target: either a Tensor or list of Tensors to be differentiated. sources: list of Tensors for which we want gradients @@ -60,4 +57,7 @@ def imperative_grad( computation of target. """ return pywrap_tensorflow.TFE_Py_TapeGradient( - tape._tape, vspace, target, sources, output_gradients) # pylint: disable=protected-access + tape._tape, # pylint: disable=protected-access + target, + sources, + output_gradients) diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 86fbd24d68..f34ce6af79 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" +#include "structmember.h" // NOLINT // For PyMemberDef + // forward declare struct EagerTensor; @@ -325,12 +327,36 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { PyObject* context = nullptr; PyObject* device = nullptr; PyObject* dtype = Py_None; - const char* kwlist[] = {"value", "context", "device", "dtype", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O", + PyObject* other_value = nullptr; + const char* kwlist[] = {"value", "context", "device", + "dtype", "other_value", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO", const_cast<char**>(kwlist), &value, &context, - &device, &dtype)) { + &device, &dtype, &other_value)) { return -1; } + + if (other_value != nullptr) { + if (!EagerTensor_CheckExact(other_value)) { + PyErr_SetString(PyExc_TypeError, + tensorflow::strings::StrCat( + "Expecting an EagerTensor for other_value, got ", + Py_TYPE(other_value)->tp_name) + .c_str()); + + return -1; + } + EagerTensor* other = reinterpret_cast<EagerTensor*>(other_value); + self->handle = + TFE_TensorHandleCopySharingTensor(other->handle, self->status); + + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + return -1; + } + + return 0; + } + // Extract dtype int desired_dtype = -1; if (dtype != Py_None) { @@ -619,6 +645,15 @@ static PyGetSetDef EagerTensor_getseters[] = { {nullptr} /* Sentinel */ }; +#if PY_MAJOR_VERSION < 3 +// Only used for Python2 since Python3 seems to set the __dict__ correctly. 
+static PyMemberDef EagerTensor_members[] = { + {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict), + READONLY}, + {nullptr}, +}; +#endif + static PyMethodDef EagerTensor_methods[] = { {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS, PyDoc_STR("_numpy")}, @@ -693,7 +728,7 @@ static PyTypeObject _EagerTensorType = { nullptr, /* tp_iter */ nullptr, /* tp_iternext */ EagerTensor_methods, /* tp_methods */ - nullptr, /* tp_members */ + EagerTensor_members, /* tp_members */ EagerTensor_getseters, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ @@ -829,7 +864,7 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { } EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict); #else - _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class); + _EagerTensorType.tp_base = base_class_type; if (PyType_Ready(&_EagerTensorType) < 0) { if (PyErr_Occurred()) return nullptr; diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 16f8c3c917..f1b4042ec9 100755 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -59,6 +59,10 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e); // This function is not thread-safe. PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e); +// Registers e as the VSpace to use. +// `vspace` must be a imperative_grad.py:VSpace named tuple. +PyObject* TFE_Py_RegisterVSpace(PyObject* e); + // Registers e as the Exception to be raised when the conditions of // TFE_Py_FastPathExecute_C have not been met. When this exception is set, it // is a signal to the calling code that it should fall back to the safer (and @@ -124,9 +128,10 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); // To unset the profiler, pass Py_None as the value of `profiler`. PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler); -// Creates a new tape and adds it to the active set. `persistent` must be a -// PyBool_Type, i.e either Py_True or Py_False -PyObject* TFE_Py_TapeSetNew(PyObject* persistent); +// Creates a new tape and adds it to the active set. `persistent` and +// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`). +PyObject* TFE_Py_TapeSetNew(PyObject* persistent, + PyObject* watch_accessed_variables); // Removes the passed tape from the set of active tapes. void TFE_Py_TapeSetRemove(PyObject* tape); @@ -158,18 +163,20 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, PyObject* input_tensor_ids, PyObject* backward_function); +// Notifies all tapes that a variable has been accessed. +void TFE_Py_TapeVariableAccessed(PyObject* variable); + // Watches the given variable object on the given tape. -void TFE_Py_TapeSetWatchVariable(PyObject* variable); +void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable); // Computes a gradient based on information recorded on the tape.`tape` must -// have been produced by TFE_Py_NewTape. `vspace` must be a -// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python +// have been produced by TFE_Py_NewTape. `target` and `sources` must be python // lists of Tensor objects. `output_gradients` is either None or a python list // of either Tensor or None, and if not None should have the same length as // target. 
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, - PyObject* target, PyObject* sources, - PyObject* output_gradients, TF_Status* status); +PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target, + PyObject* sources, PyObject* output_gradients, + TF_Status* status); // Execute a tensorflow operation assuming that all provided inputs are // correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors, diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 0a33a04dcb..1ed814258b 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -892,9 +892,10 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) { class GradientTape : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> { public: - explicit GradientTape(bool persistent) + explicit GradientTape(bool persistent, bool watch_accessed_variables) : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>( - persistent) {} + persistent), + watch_accessed_variables_(watch_accessed_variables) {} virtual ~GradientTape() { for (const IdAndVariable& v : watched_variables_) { @@ -902,6 +903,12 @@ class GradientTape } } + void VariableAccessed(PyObject* v) { + if (watch_accessed_variables_) { + WatchVariable(v); + } + } + void WatchVariable(PyObject* v) { tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle")); if (handle == nullptr) { @@ -951,6 +958,7 @@ class GradientTape } }; + bool watch_accessed_variables_; tensorflow::mutex watched_variables_mu_; std::set<IdAndVariable, CompareById> watched_variables_ GUARDED_BY(watched_variables_mu_); @@ -1056,11 +1064,13 @@ void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; } void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; } -PyObject* TFE_Py_TapeSetNew(PyObject* persistent) { +PyObject* TFE_Py_TapeSetNew(PyObject* persistent, + PyObject* watch_accessed_variables) { TFE_Py_Tape_Type.tp_new = PyType_GenericNew; if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); - tape->tape = new GradientTape(persistent == Py_True); + tape->tape = new GradientTape(persistent == Py_True, + watch_accessed_variables == Py_True); Py_INCREF(tape); GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape)); return reinterpret_cast<PyObject*>(tape); @@ -1233,13 +1243,20 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) { return list; } -void TFE_Py_TapeSetWatchVariable(PyObject* variable) { +void TFE_Py_TapeVariableAccessed(PyObject* variable) { if (*ThreadTapeIsStopped()) { return; } for (TFE_Py_Tape* tape : SafeTapeSet()) { - tape->tape->WatchVariable(variable); + tape->tape->VariableAccessed(variable); + } +} + +void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) { + if (*ThreadTapeIsStopped()) { + return; } + reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable); } PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { @@ -1348,7 +1365,9 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> { public: - explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {} + explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { + Py_INCREF(py_vspace_); + } tensorflow::Status Initialize() { num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); @@ -1376,6 +1395,8 @@ class PyVSpace 
Py_XDECREF(aggregate_fn_); Py_XDECREF(zeros_); Py_XDECREF(ones_); + + Py_DECREF(py_vspace_); } tensorflow::int64 NumElements(PyObject* tensor) const final { @@ -1491,6 +1512,22 @@ class PyVSpace PyObject* zeros_; PyObject* ones_; }; +PyVSpace* py_vspace = nullptr; + +PyObject* TFE_Py_RegisterVSpace(PyObject* e) { + if (py_vspace != nullptr) { + delete py_vspace; + } + + py_vspace = new PyVSpace(e); + auto status = py_vspace->Initialize(); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { + delete py_vspace; + return nullptr; + } + + Py_RETURN_NONE; +} std::vector<PyObject*> MakeTensorList(PyObject* tensors) { PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); @@ -1507,9 +1544,9 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) { return list; } -PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, - PyObject* target, PyObject* sources, - PyObject* output_gradients, TF_Status* status) { +PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target, + PyObject* sources, PyObject* output_gradients, + TF_Status* status) { TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape); if (!tape_obj->tape->IsPersistent()) { auto* tape_set = GetTapeSet(); @@ -1524,10 +1561,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, return nullptr; } } - PyVSpace c_vspace(vspace); - if (!c_vspace.Initialize().ok()) { - return nullptr; - } std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target); if (PyErr_Occurred()) { @@ -1551,7 +1584,7 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, } std::vector<PyObject*> result; status->status = tape_obj->tape->ComputeGradient( - c_vspace, target_vec, sources_vec, outgrad_vec, &result); + *py_vspace, target_vec, sources_vec, outgrad_vec, &result); if (!status->status.ok()) { if (PyErr_Occurred()) { // Do not propagate the erroneous status as that would swallow the @@ -1893,14 +1926,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, Py_RETURN_NONE; } -void MaybeWatchVariable(PyObject* input) { +void MaybeNotifyVariableAccessed(PyObject* input) { DCHECK(CheckResourceVariable(input)); DCHECK(PyObject_HasAttrString(input, "_trainable")); tensorflow::Safe_PyObjectPtr trainable( PyObject_GetAttrString(input, "_trainable")); if (trainable.get() == Py_False) return; - TFE_Py_TapeSetWatchVariable(input); + TFE_Py_TapeVariableAccessed(input); } bool CastTensor(const FastPathOpExecInfo& op_exec_info, @@ -1931,7 +1964,7 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info, bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, PyObject* input, tensorflow::Safe_PyObjectPtr* output, TF_Status* status) { - MaybeWatchVariable(input); + MaybeNotifyVariableAccessed(input); TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status); auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); }); diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index 6eb62afec4..399d90223c 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -33,9 +33,10 @@ class Tape(object): return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) -def push_new_tape(persistent=False): +def push_new_tape(persistent=False, watch_accessed_variables=True): """Pushes a new tape onto the tape stack.""" - tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent) + tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent, + watch_accessed_variables) return Tape(tape) @@ -49,13 +50,14 @@ def 
watch(tape, tensor): pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access -def watch_variable(variable): - """Marks this variable to be watched by all tapes in the stack. +def watch_variable(tape, variable): + """Marks this variable to be watched by the given tape.""" + pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access - Args: - variable: variable to be watched. - """ - pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable) + +def variable_accessed(variable): + """Notifies all tapes in the stack that a variable has been accessed.""" + pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable) def pop_tape(tape): diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 32742a9b96..344a9b25bd 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops def _create_tensor(value, device=None, dtype=None): @@ -333,6 +334,19 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): "but tensor at index 2 has rank 0"): pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testTensorDir(self): + t = array_ops.zeros(1) + t.test_attr = "Test" + + instance_dir = dir(t) + type_dir = dir(ops.EagerTensor) + + # Monkey patched attributes should show up in dir(t) + self.assertIn("test_attr", instance_dir) + instance_dir.remove("test_attr") + self.assertEqual(instance_dir, type_dir) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index cf8e18b216..00da335fef 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -687,6 +687,7 @@ py_test( "manual", # b/112769036, b/113907597 "no_oss", # b/112769036, b/113907597 "no_windows", + "noasan", # b/114304340 "nomsan", "notsan", # b/67510291 ], diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index d104c961d3..19f18015e4 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -1000,8 +1000,11 @@ class BoostedTreesClassifier(estimator.Estimator): bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + # Need to see a large portion of the data before we can build a layer, for + # example half of data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE classifier = estimator.BoostedTreesClassifier( feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_batches_per_layer=n_batches_per_layer, n_trees=100, ... <some other params> ) @@ -1024,7 +1027,8 @@ class BoostedTreesClassifier(estimator.Estimator): the model. All items in the set should be instances of classes derived from `FeatureColumn`. n_batches_per_layer: the number of batches to collect statistics per - layer. + layer. The total number of batches is total number of data divided by + batch size. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. 
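The clarified `n_batches_per_layer` docstrings above pin the parameter to dataset size over batch size; concretely (numbers hypothetical):

```python
NUM_EXAMPLES = 10000  # hypothetical dataset size
BATCH_SIZE = 100      # hypothetical batch size

# Total number of batches = number of examples divided by batch size.
total_batches = NUM_EXAMPLES // BATCH_SIZE                  # 100
# As in the docstring example: see half the data per layer.
n_batches_per_layer = int(0.5 * NUM_EXAMPLES / BATCH_SIZE)  # 50
```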
@@ -1138,8 +1142,11 @@ class BoostedTreesRegressor(estimator.Estimator): bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + # Need to see a large portion of the data before we can build a layer, for + # example half of data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE regressor = estimator.BoostedTreesRegressor( feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_batches_per_layer=n_batches_per_layer, n_trees=100, ... <some other params> ) @@ -1162,7 +1169,8 @@ class BoostedTreesRegressor(estimator.Estimator): the model. All items in the set should be instances of classes derived from `FeatureColumn`. n_batches_per_layer: the number of batches to collect statistics per - layer. + layer. The total number of batches is total number of data divided by + batch size. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py index 290c4604ce..7e5a0c80a7 100644 --- a/tensorflow/python/estimator/keras_test.py +++ b/tensorflow/python/estimator/keras_test.py @@ -26,20 +26,23 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import keras as keras_lib +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config as run_config_lib -from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.optimizers import SGD from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import rmsprop from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util try: @@ -90,6 +93,15 @@ def simple_subclassed_model(): return SimpleModel() +def gen_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=False): + def input_fn(): + ds = dataset_ops.Dataset.from_tensor_slices((x, y) if y is not None else x) + if shuffle: + ds = ds.shuffle(1000) + return ds.repeat(num_epochs).batch(batch_size) + return input_fn + + def get_resource_for_simple_model(model_type='sequential', is_evaluate=False,): if model_type == 'sequential': @@ -117,19 +129,19 @@ def get_resource_for_simple_model(model_type='sequential', y_train = keras.utils.to_categorical(y_train) y_test = keras.utils.to_categorical(y_test) - train_input_fn = numpy_io.numpy_input_fn( + train_input_fn = gen_input_fn( x=randomize_io_type(x_train, input_name), y=randomize_io_type(y_train, output_name), shuffle=False, num_epochs=None, batch_size=16) - evaluate_input_fn = numpy_io.numpy_input_fn( + evaluate_input_fn = gen_input_fn( x=randomize_io_type(x_test, input_name), y=randomize_io_type(y_test, output_name), num_epochs=1, shuffle=False) - predict_input_fn = numpy_io.numpy_input_fn( + predict_input_fn = gen_input_fn( x=randomize_io_type(x_test, input_name), num_epochs=1, shuffle=False) inference_input_fn = 
evaluate_input_fn if is_evaluate else predict_input_fn @@ -203,7 +215,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer='rmsprop', metrics=['mse', keras.metrics.categorical_accuracy]) - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) before_eval_results = est_keras.evaluate( @@ -228,7 +240,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): metrics=['mse', keras.metrics.categorical_accuracy]) my_hook = MyHook() - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) before_eval_results = est_keras.evaluate( @@ -252,7 +264,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) my_hook = MyHook() - with self.test_session(): + with self.cached_session(): keras_model.fit(x_train, y_train, epochs=1) keras_est = keras_lib.model_to_estimator( @@ -274,7 +286,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) @@ -297,7 +309,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) @@ -316,7 +328,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) - with self.test_session(): + with self.cached_session(): # Create state keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE), np.random.random((10, _NUM_CLASS))) @@ -343,7 +355,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): x_test, y_test), _, eval_input_fn = get_resource_for_simple_model( model_type='functional', is_evaluate=True) - with self.test_session(): + with self.cached_session(): metrics = [ 'binary_accuracy', 'binary_crossentropy', 'categorical_accuracy', 'categorical_crossentropy', 'cosine_proximity', 'hinge', @@ -357,7 +369,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.fit(x_train, y_train, epochs=1) keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32) - with self.test_session(): + with self.cached_session(): keras_est = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) est_eval = keras_est.evaluate(input_fn=eval_input_fn) @@ -385,7 +397,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): x_test, _), _, pred_input_fn = get_resource_for_simple_model( model_type='sequential', is_evaluate=False) - with self.test_session(): + with self.cached_session(): keras_model.compile( loss='categorical_crossentropy', optimizer='adam', @@ -393,7 +405,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.fit(x_train, y_train, epochs=1) keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)] - with self.test_session(): + with self.cached_session(): keras_est = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) est_pred = [ @@ -439,7 +451,7 @@ class 
TestKerasEstimator(test_util.TensorFlowTestCase): output_dict = {'dense_2': c_test, 'dense_3': d_test} return input_dict, output_dict - with self.test_session(): + with self.cached_session(): model = multi_inputs_multi_outputs_model() est_keras = keras_lib.model_to_estimator( keras_model=model, config=self._config) @@ -456,7 +468,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): x_test, _), _, pred_input_fn = get_resource_for_simple_model( model_type='functional', is_evaluate=False) - with self.test_session(): + with self.cached_session(): keras_model.compile( loss='categorical_crossentropy', optimizer='rmsprop', @@ -466,7 +478,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): fname = os.path.join(self._base_dir, 'keras_model.h5') keras.models.save_model(keras_model, fname) - with self.test_session(): + with self.cached_session(): keras_est = keras_lib.model_to_estimator( keras_model_path=fname, config=self._config) est_pred = [ @@ -479,19 +491,19 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, 'Either'): keras_lib.model_to_estimator() - with self.test_session(): + with self.cached_session(): keras_model = simple_sequential_model() with self.assertRaisesRegexp(ValueError, 'not both'): keras_lib.model_to_estimator( keras_model=keras_model, keras_model_path=tempfile.mkdtemp(dir=self._base_dir)) - with self.test_session(): + with self.cached_session(): keras_model = simple_sequential_model() with self.assertRaisesRegexp(ValueError, 'compiled'): keras_lib.model_to_estimator(keras_model=keras_model) - with self.test_session(): + with self.cached_session(): keras_model = simple_sequential_model() with self.assertRaisesRegexp(ValueError, 'not a local path'): keras_lib.model_to_estimator( @@ -516,10 +528,10 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): model = simple_functional_model() model.compile( loss='categorical_crossentropy', optimizer='adam', metrics=['acc']) - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=model, config=self._config) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(KeyError, 'Difference: .*invalid_input_name'): est_keras.train(input_fn=invald_input_name_input_fn, steps=100) @@ -547,20 +559,20 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): y_train = keras.utils.to_categorical(y_train, 2) input_name = keras_model.input_names[0] output_name = keras_model.output_names[0] - train_input_fn = numpy_io.numpy_input_fn( + train_input_fn = gen_input_fn( x=randomize_io_type(x_train, input_name), y=randomize_io_type(y_train, output_name), shuffle=False, num_epochs=None, batch_size=16) with self.assertRaisesRegexp(ValueError, 'relu6'): - with self.test_session(): + with self.cached_session(): est = keras_lib.model_to_estimator( keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir)) est.train(input_fn=train_input_fn, steps=1) - with self.test_session(): + with self.cached_session(): est = keras_lib.model_to_estimator( keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir), @@ -586,7 +598,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): } }) with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): - with self.test_session(): + with self.cached_session(): keras_lib.model_to_estimator( keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir)) @@ -602,7 +614,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): 
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3) sess_config = config_pb2.ConfigProto(gpu_options=gpu_options) self._config._session_config = sess_config - with self.test_session(): + with self.cached_session(): keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) self.assertEqual( @@ -618,7 +630,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer='rmsprop', metrics=['mse', keras.metrics.categorical_accuracy]) - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, model_dir=self._base_dir, config=run_config_lib.RunConfig()) @@ -629,7 +641,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): self.assertEqual(self._base_dir, est_keras._config.model_dir) self.assertEqual(self._base_dir, est_keras._model_dir) - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, model_dir=self._base_dir, config=None) @@ -648,7 +660,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer='rmsprop', metrics=['mse', keras.metrics.categorical_accuracy]) - with self.test_session(): + with self.cached_session(): with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, @@ -663,7 +675,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer='rmsprop', metrics=['mse', keras.metrics.categorical_accuracy]) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in ' 'constructor and `RunConfig`'): keras_lib.model_to_estimator( @@ -676,7 +688,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): loss='categorical_crossentropy', optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) - with self.test_session(): + with self.cached_session(): keras_model.train_on_batch( np.random.random((10,) + _INPUT_SIZE), np.random.random((10, _NUM_CLASS))) @@ -690,6 +702,32 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) + def assert_increasing_global_step(self, optimizer): + keras_model, _, _, train_input_fn, _ = get_resource_for_simple_model( + model_type='sequential', is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer=optimizer, + metrics=['mse', keras.metrics.categorical_accuracy]) + with self.cached_session() as sess: + keras_model_fn = keras_lib._create_keras_model_fn(keras_model) + global_step = training_util.create_global_step() + features, labels = train_input_fn().make_one_shot_iterator().get_next() + spec = keras_model_fn(features, labels, mode=model_fn_lib.ModeKeys.TRAIN) + + sess.run(variables.global_variables_initializer()) + sess.run(variables.local_variables_initializer()) + + self.assertEqual(global_step.eval(), 0) # Sanity check + sess.run(spec.train_op) + self.assertEqual(global_step.eval(), 1) + + def test_model_fn_increments_global_step_tf_optimizer(self): + self.assert_increasing_global_step(rmsprop.RMSPropOptimizer(1e-3)) + + def test_model_fn_increments_global_step_keras_optimizer(self): + self.assert_increasing_global_step('rmsprop') + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index 1017d4ba47..ac53a84eef 100644 --- a/tensorflow/python/feature_column/BUILD +++ 
b/tensorflow/python/feature_column/BUILD @@ -12,6 +12,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_column", + ":feature_column_v2", "//tensorflow/python:util", ], ) diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index aa66ed77e9..28c5c82d2c 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -385,6 +385,10 @@ class FeatureLayer(Layer): 'You can wrap a categorical column with an ' 'embedding_column or indicator_column. Given: {}'.format(column)) + @property + def _is_feature_layer(self): + return True + def build(self, _): for column in sorted(self._feature_columns, key=lambda x: x.name): if isinstance(column, SharedEmbeddingColumn): @@ -409,7 +413,13 @@ class FeatureLayer(Layer): A `Tensor` which represents input layer of a model. Its shape is (batch_size, first_layer_dimension) and its dtype is `float32`. first_layer_dimension is determined based on given `feature_columns`. + + Raises: + ValueError: If features are not a dictionary. """ + if not isinstance(features, dict): + raise ValueError('We expected a dictionary here. Instead we got: ', + features) transformation_cache = FeatureTransformationCache(features) output_tensors = [] ordered_columns = [] @@ -431,6 +441,12 @@ class FeatureLayer(Layer): _verify_static_batch_size_equality(output_tensors, ordered_columns) return array_ops.concat(output_tensors, 1) + def compute_output_shape(self, input_shape): + total_elements = 0 + for column in sorted(self._feature_columns, key=lambda x: x.name): + total_elements += column.variable_shape.num_elements() + return (input_shape[0], total_elements) + def linear_model(features, feature_columns, diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index 6b343ecf3e..58168e0f9e 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -2786,6 +2786,21 @@ class FeatureLayerTest(test.TestCase): with _initialized_session(): self.assertAllClose([[1., 2.], [5., 6.]], net.eval()) + def test_compute_output_shape(self): + price1 = fc.numeric_column('price1', shape=2) + price2 = fc.numeric_column('price2', shape=4) + with ops.Graph().as_default(): + features = { + 'price1': [[1., 2.], [5., 6.]], + 'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]] + } + feature_layer = FeatureLayer([price1, price2]) + self.assertEqual((None, 6), feature_layer.compute_output_shape((None,))) + net = feature_layer(features) + with _initialized_session(): + self.assertAllClose( + [[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]], net.eval()) + def test_raises_if_shape_mismatch(self): price = fc.numeric_column('price', shape=2) with ops.Graph().as_default(): diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index eca34ac26e..4b2706d4cf 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -105,7 +105,8 @@ def convert_to_eager_tensor(value, ctx, dtype=None): scalar_cache = ctx.scalar_cache() tensor = scalar_cache.get(cache_key, None) if tensor is not None: - return tensor + return ops.EagerTensor( + value, context=handle, device=device, dtype=dtype, other_value=tensor) t = ops.EagerTensor(value, context=handle, device=device, dtype=dtype) scalar_cache[cache_key] = t return t diff --git 
a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 11b681d544..3c2a736fb9 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -606,8 +606,8 @@ class TensorShape(object): slice. Raises: - ValueError: If `key` is a slice, and any of its elements are negative, or - if `self` is completely unknown and the step is set. + ValueError: If `key` is a slice and `self` is completely unknown and + the step is set. """ if self._dims is not None: if isinstance(key, slice): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 0925598e33..4bece9e25e 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -465,29 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f): f(self, **kwargs) gc.collect() previous_count = len(gc.get_objects()) - collection_sizes_before = { - collection: len(ops.get_collection(collection)) - for collection in ops.get_default_graph().collections - } + if ops.has_default_graph(): + collection_sizes_before = { + collection: len(ops.get_collection(collection)) + for collection in ops.get_default_graph().collections + } for _ in range(3): f(self, **kwargs) # Note that gc.get_objects misses anything that isn't subject to garbage # collection (C types). Collections are a common source of leaks, so we # test for collection sizes explicitly. - for collection_key in ops.get_default_graph().collections: - collection = ops.get_collection(collection_key) - size_before = collection_sizes_before.get(collection_key, 0) - if len(collection) > size_before: - raise AssertionError( - ("Collection %s increased in size from " - "%d to %d (current items %s).") % (collection_key, size_before, - len(collection), collection)) - # Make sure our collection checks don't show up as leaked memory by - # removing references to temporary variables. - del collection - del collection_key - del size_before - del collection_sizes_before + if ops.has_default_graph(): + for collection_key in ops.get_default_graph().collections: + collection = ops.get_collection(collection_key) + size_before = collection_sizes_before.get(collection_key, 0) + if len(collection) > size_before: + raise AssertionError( + ("Collection %s increased in size from " + "%d to %d (current items %s).") % + (collection_key, size_before, len(collection), collection)) + # Make sure our collection checks don't show up as leaked memory by + # removing references to temporary variables. + del collection + del collection_key + del size_before + del collection_sizes_before gc.collect() # There should be no new Python objects hanging around. 
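With the `ops.has_default_graph()` guards above, the decorator's collection bookkeeping is skipped for purely eager tests; the remaining leak check is plain object counting. A stripped-down sketch of that idea (illustrative only, not the decorator's exact code; `f` stands for the wrapped test function):

    import gc

    def assert_no_new_objects(f, runs=3):
      gc.collect()
      before = len(gc.get_objects())
      for _ in range(runs):
        f()
      gc.collect()
      # Any growth after a full collection pass is treated as a leak.
      leaked = len(gc.get_objects()) - before
      assert leaked <= 0, 'leaked %d Python objects' % leaked
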
new_count = len(gc.get_objects()) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 7246341519..290e182a79 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -700,6 +700,20 @@ py_test( ) py_test( + name = "feature_columns_integration_test", + size = "small", + srcs = ["engine/feature_columns_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//tensorflow/python/feature_column:feature_column_py", + "//third_party/py/numpy", + ], +) + +py_test( name = "training_eager_test", size = "medium", srcs = ["engine/training_eager_test.py"], diff --git a/tensorflow/python/keras/engine/feature_columns_integration_test.py b/tensorflow/python/keras/engine/feature_columns_integration_test.py new file mode 100644 index 0000000000..e0478ee357 --- /dev/null +++ b/tensorflow/python/keras/engine/feature_columns_integration_test.py @@ -0,0 +1,237 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests specific to Feature Columns integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.framework import test_util as tf_test_util +from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.platform import test +from tensorflow.python.training import rmsprop + + +class TestDNNModel(keras.models.Model): + + def __init__(self, feature_columns, units, name=None, **kwargs): + super(TestDNNModel, self).__init__(name=name, **kwargs) + self._input_layer = fc.FeatureLayer(feature_columns, name='input_layer') + self._dense_layer = keras.layers.Dense(units, name='dense_layer') + + def call(self, features): + net = self._input_layer(features) + net = self._dense_layer(net) + return net + + +class FeatureColumnsIntegrationTest(test.TestCase): + """Most Sequential model API tests are covered in `training_test.py`. 
+ + """ + + @tf_test_util.run_in_graph_and_eager_modes + def test_sequential_model(self): + columns = [fc.numeric_column('a')] + model = keras.models.Sequential([ + fc.FeatureLayer(columns), + keras.layers.Dense(64, activation='relu'), + keras.layers.Dense(20, activation='softmax') + ]) + model.compile( + optimizer=rmsprop.RMSPropOptimizer(1e-3), + loss='categorical_crossentropy', + metrics=['accuracy']) + + x = {'a': np.random.random((10, 1))} + y = np.random.randint(20, size=(10, 1)) + y = keras.utils.to_categorical(y, num_classes=20) + model.fit(x, y, epochs=1, batch_size=5) + model.fit(x, y, epochs=1, batch_size=5) + model.evaluate(x, y, batch_size=5) + model.predict(x, batch_size=5) + + @tf_test_util.run_in_graph_and_eager_modes + def test_sequential_model_with_ds_input(self): + columns = [fc.numeric_column('a')] + model = keras.models.Sequential([ + fc.FeatureLayer(columns), + keras.layers.Dense(64, activation='relu'), + keras.layers.Dense(20, activation='softmax') + ]) + model.compile( + optimizer=rmsprop.RMSPropOptimizer(1e-3), + loss='categorical_crossentropy', + metrics=['accuracy']) + + y = np.random.randint(20, size=(100, 1)) + y = keras.utils.to_categorical(y, num_classes=20) + x = {'a': np.random.random((100, 1))} + ds1 = dataset_ops.Dataset.from_tensor_slices(x) + ds2 = dataset_ops.Dataset.from_tensor_slices(y) + ds = dataset_ops.Dataset.zip((ds1, ds2)).batch(5) + model.fit(ds, steps_per_epoch=1) + model.fit(ds, steps_per_epoch=1) + model.evaluate(ds, steps=1) + model.predict(ds, steps=1) + + @tf_test_util.run_in_graph_and_eager_modes + def test_subclassed_model_with_feature_columns(self): + col_a = fc.numeric_column('a') + col_b = fc.numeric_column('b') + + dnn_model = TestDNNModel([col_a, col_b], 20) + + dnn_model.compile( + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001), + loss='categorical_crossentropy', + metrics=['accuracy']) + + x = {'a': np.random.random((10, 1)), 'b': np.random.random((10, 1))} + y = np.random.randint(20, size=(10, 1)) + y = keras.utils.to_categorical(y, num_classes=20) + dnn_model.fit(x=x, y=y, epochs=1, batch_size=5) + dnn_model.fit(x=x, y=y, epochs=1, batch_size=5) + dnn_model.evaluate(x=x, y=y, batch_size=5) + dnn_model.predict(x=x, batch_size=5) + + @tf_test_util.run_in_graph_and_eager_modes + def test_subclassed_model_with_feature_columns_with_ds_input(self): + col_a = fc.numeric_column('a') + col_b = fc.numeric_column('b') + + dnn_model = TestDNNModel([col_a, col_b], 20) + + dnn_model.compile( + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001), + loss='categorical_crossentropy', + metrics=['accuracy']) + + y = np.random.randint(20, size=(100, 1)) + y = keras.utils.to_categorical(y, num_classes=20) + x = {'a': np.random.random((100, 1)), 'b': np.random.random((100, 1))} + ds1 = dataset_ops.Dataset.from_tensor_slices(x) + ds2 = dataset_ops.Dataset.from_tensor_slices(y) + ds = dataset_ops.Dataset.zip((ds1, ds2)).batch(5) + dnn_model.fit(ds, steps_per_epoch=1) + dnn_model.fit(ds, steps_per_epoch=1) + dnn_model.evaluate(ds, steps=1) + dnn_model.predict(ds, steps=1) + + @tf_test_util.run_in_graph_and_eager_modes + def DISABLED_test_function_model_feature_layer_input(self): + col_a = fc.numeric_column('a') + col_b = fc.numeric_column('b') + + feature_layer = fc.FeatureLayer([col_a, col_b], name='fc') + dense = keras.layers.Dense(4) + + # This seems problematic.... We probably need something for FeatureLayer + # the way Input is for InputLayer. 
+ output = dense(feature_layer) + + model = keras.models.Model([feature_layer], [output]) + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + loss_weights = [1., 0.5] + model.compile( + optimizer, + loss, + metrics=[metrics_module.CategoricalAccuracy(), 'mae'], + loss_weights=loss_weights) + + data = ({'a': np.arange(10), 'b': np.arange(10)}, np.arange(10, 20)) + print(model.fit(*data, epochs=1)) + + @tf_test_util.run_in_graph_and_eager_modes + def DISABLED_test_function_model_multiple_feature_layer_inputs(self): + col_a = fc.numeric_column('a') + col_b = fc.numeric_column('b') + col_c = fc.numeric_column('c') + + fc1 = fc.FeatureLayer([col_a, col_b], name='fc1') + fc2 = fc.FeatureLayer([col_b, col_c], name='fc2') + dense = keras.layers.Dense(4) + + # This seems problematic.... We probably need something for FeatureLayer + # the way Input is for InputLayer. + output = dense(fc1) + dense(fc2) + + model = keras.models.Model([fc1, fc2], [output]) + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + loss_weights = [1., 0.5] + model.compile( + optimizer, + loss, + metrics=[metrics_module.CategoricalAccuracy(), 'mae'], + loss_weights=loss_weights) + + data_list = ([{ + 'a': np.arange(10), + 'b': np.arange(10) + }, { + 'b': np.arange(10), + 'c': np.arange(10) + }], np.arange(10, 100)) + print(model.fit(*data_list, epochs=1)) + + data_bloated_list = ([{ + 'a': np.arange(10), + 'b': np.arange(10), + 'c': np.arange(10) + }, { + 'a': np.arange(10), + 'b': np.arange(10), + 'c': np.arange(10) + }], np.arange(10, 100)) + print(model.fit(*data_bloated_list, epochs=1)) + + data_dict = ({ + 'fc1': { + 'a': np.arange(10), + 'b': np.arange(10) + }, + 'fc2': { + 'b': np.arange(10), + 'c': np.arange(10) + } + }, np.arange(10, 100)) + print(model.fit(*data_dict, epochs=1)) + + data_bloated_dict = ({ + 'fc1': { + 'a': np.arange(10), + 'b': np.arange(10), + 'c': np.arange(10) + }, + 'fc2': { + 'a': np.arange(10), + 'b': np.arange(10), + 'c': np.arange(10) + } + }, np.arange(10, 100)) + print(model.fit(*data_bloated_dict, epochs=1)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 966b446f22..d224dfffdd 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -45,6 +45,7 @@ from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -862,7 +863,8 @@ class Model(Network): Fraction of the training data to be used as validation data. Returns: - A tuple of 3 lists: input arrays, target arrays, sample-weight arrays. + A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict + or not), target arrays, sample-weight arrays. If the model's input and targets are symbolic, these lists are empty (since the model takes no user-provided data, instead the data comes from the symbolic inputs/targets). @@ -928,11 +930,16 @@ class Model(Network): 'Make sure that your dataset can generate ' 'required number of samples.') - if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: - raise ValueError('Please provide model inputs as a list or tuple of 2 ' - 'elements: input and target pair. 
' - 'Received %s' % next_element - x, y = next_element + if (not isinstance(next_element, (list, tuple)) or + len(next_element) not in [2, 3]): + raise ValueError( + 'Please provide model inputs as a list or tuple of 2 or 3 ' + 'elements: (input, target) or (input, target, sample_weights). ' + 'Received %s' % next_element) + if len(next_element) == 2: + x, y = next_element + else: + x, y, sample_weight = next_element x, y, sample_weights = self._standardize_weights(x, y, sample_weight, class_weight, batch_size) return x, y, sample_weights @@ -948,6 +955,7 @@ class Model(Network): all_inputs = [] is_build_called = False is_compile_called = False + dict_inputs = False if not self.inputs: # We need to use `x` to set the model inputs. # We type-check that `x` and `y` are either single arrays @@ -959,7 +967,9 @@ class Model(Network): 'array or a list of arrays. You passed: x=' + str(x)) all_inputs += list(x) elif isinstance(x, dict): - raise ValueError('Please do not pass a dictionary as model inputs.') + dict_inputs = True + keys = sorted(x.keys()) + all_inputs = [x[k] for k in keys] else: if not isinstance(x, np.ndarray) and not tensor_util.is_tensor(x): raise ValueError('Please provide as model inputs either a single ' @@ -972,6 +982,8 @@ class Model(Network): if not self.inputs: is_build_called = True self._set_inputs(x) + else: + dict_inputs = isinstance(self.inputs, dict) if y is not None: if not self.optimizer: @@ -1124,6 +1136,10 @@ class Model(Network): 'a number of samples that can be ' 'divided by the batch size. Found: ' + str(x[0].shape[0]) + ' samples') + + # If dictionary inputs were provided, we return a dictionary as well. + if dict_inputs: + x = dict(zip(feed_input_names, x)) return x, y, sample_weights @checkpointable.no_automatic_dependency_tracking @@ -1146,6 +1162,9 @@ class Model(Network): training: Boolean or None. Only relevant in symbolic mode. Specifies whether to build the model's graph in inference mode (False), training mode (True), or using the Keras learning phase (None). + Raises: + ValueError: If dict inputs are passed to a Sequential Model where the + first layer isn't FeatureLayer. """ call_convention = getattr( self, @@ -1162,6 +1181,14 @@ class Model(Network): if tensor_util.is_tensor(inputs): input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:]) self.build(input_shape=input_shape) + elif isinstance(inputs, dict): + # We assert that the first layer is a FeatureLayer. + if not training_utils.is_feature_layer(self.layers[0]): + raise ValueError('Passing a dictionary input to a Sequential Model ' + 'which doesn\'t have FeatureLayer as the first layer ' + 'is an error') + input_shape = (None,) + self.build(input_shape=input_shape) else: input_shape = (None,) + inputs.shape[1:] self.build(input_shape=input_shape) @@ -1189,36 +1216,22 @@ class Model(Network): assert context.executing_eagerly() if self.inputs: raise ValueError('Model inputs are already set.') + # On-the-fly setting of model inputs/outputs as DeferredTensors, # to keep track of number of inputs and outputs and their ndim.
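The `_standardize_user_data` changes above start routing dict inputs through sorted keys; the `_set_inputs` rewrite below delegates the remaining bookkeeping to the new `training_utils.ModelInputs` helper added later in this diff. As a rough illustration of the normalization that helper provides (behavior taken from `ModelInputsTest` in `training_utils_test.py` further down):

    import numpy as np
    from tensorflow.python.keras.engine import training_utils

    mi = training_utils.ModelInputs({'b': np.ones(10), 'a': np.ones(20)})
    mi.get_input_names()  # ['a', 'b']: dict keys, sorted
    mi.as_list()          # the corresponding values, flattened in key order

    mi = training_utils.ModelInputs(np.ones(10))
    mi.get_input_names()  # ['input_1']: names are made up for non-dict inputs
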
- if isinstance(inputs, (list, tuple)): - if tensor_util.is_tensor(inputs[0]): - dummy_output_values = self.call( - training_utils.cast_if_floating_dtype(inputs)) - else: - dummy_output_values = self.call( - [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs]) - dummy_input_values = list(inputs) - else: - if tensor_util.is_tensor(inputs): - dummy_output_values = self.call( - training_utils.cast_if_floating_dtype(inputs)) - else: - dummy_output_values = self.call( - ops.convert_to_tensor(inputs, dtype=K.floatx())) - dummy_input_values = [inputs] - if isinstance(dummy_output_values, (list, tuple)): - dummy_output_values = list(dummy_output_values) - else: - dummy_output_values = [dummy_output_values] + model_inputs = training_utils.ModelInputs(inputs) + dummy_input_values = model_inputs.get_input_values() + dummy_output_values = self.call(dummy_input_values) + + self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True) + self.input_names = model_inputs.get_input_names() + + dummy_output_values = nest.flatten(dummy_output_values) self.outputs = [ - base_layer.DeferredTensor(shape=(None for _ in v.shape), - dtype=v.dtype) for v in dummy_output_values] - self.inputs = [ - base_layer.DeferredTensor(shape=(None for _ in v.shape), - dtype=v.dtype) for v in dummy_input_values] - self.input_names = [ - 'input_%d' % (i + 1) for i in range(len(dummy_input_values))] + base_layer.DeferredTensor(shape=(None + for _ in v.shape), dtype=v.dtype) + for v in dummy_output_values + ] self.output_names = [ 'output_%d' % (i + 1) for i in range(len(dummy_output_values))] self.built = True @@ -1248,58 +1261,29 @@ class Model(Network): # On-the-fly setting of symbolic model inputs (either by using the tensor # provided, or by creating a placeholder if Numpy data was provided). - self.inputs = [] - self.input_names = [] + model_inputs = training_utils.ModelInputs(inputs) + dummy_input_values = model_inputs.get_symbolic_inputs() + self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True) + self.input_names = model_inputs.get_input_names() + self._feed_inputs = [] self._feed_input_names = [] self._feed_input_shapes = [] - if isinstance(inputs, (list, tuple)): - inputs = list(inputs) - else: - inputs = [inputs] - - for i, v in enumerate(inputs): - name = 'input_%d' % (i + 1) - self.input_names.append(name) - if isinstance(v, list): - v = np.asarray(v) - if v.ndim == 1: - v = np.expand_dims(v, 1) - if isinstance(v, (np.ndarray)): - # We fix the placeholder shape except the batch size. - # This is suboptimal, but it is the best we can do with the info - # we have. The user should call `model._set_inputs(placeholders)` - # to specify custom placeholders if the need arises. - shape = (None,) + v.shape[1:] - placeholder = K.placeholder(shape=shape, name=name) - self.inputs.append(placeholder) - self._feed_inputs.append(placeholder) - self._feed_input_names.append(name) - self._feed_input_shapes.append(shape) - else: - # Assumed tensor - TODO(fchollet) additional type check? - self.inputs.append(v) - if K.is_placeholder(v): - self._feed_inputs.append(v) - self._feed_input_names.append(name) - self._feed_input_shapes.append(K.int_shape(v)) + + for k, v in model_inputs.as_dict(): + if K.is_placeholder(v): + self._feed_inputs.append(v) + self._feed_input_names.append(k) + self._feed_input_shapes.append(K.int_shape(v)) if outputs is None: # Obtain symbolic outputs by calling the model. 
- if len(self.inputs) == 1: - if self._expects_training_arg: - outputs = self.call(self.inputs[0], training=training) - else: - outputs = self.call(self.inputs[0]) + if self._expects_training_arg: + outputs = self.call(dummy_input_values, training=training) else: - if self._expects_training_arg: - outputs = self.call(self.inputs, training=training) - else: - outputs = self.call(self.inputs) - if isinstance(outputs, (list, tuple)): - outputs = list(outputs) - else: - outputs = [outputs] + outputs = self.call(dummy_input_values) + + outputs = nest.flatten(outputs) self.outputs = outputs self.output_names = [ 'output_%d' % (i + 1) for i in range(len(self.outputs))] @@ -1331,7 +1315,8 @@ class Model(Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset or a dataset iterator. Should return a tuple + of either (inputs, targets) or (inputs, targets, sample_weights). y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and @@ -1396,7 +1381,8 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset or a dataset iterator. + supported when `x` is a dataset or a dataset iterator, instead + provide the sample_weights as the third element of `x`. initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run). diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index e2c458c65f..95b864bef0 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -55,7 +55,7 @@ def fit_loop(model, Arguments: model: Keras Model instance. - inputs: List of input arrays. + inputs: Either a list of arrays or a dictionary. targets: List of target arrays. sample_weights: Optional list of sample weight arrays. batch_size: Integer batch size or None if unknown. @@ -88,6 +88,7 @@ def fit_loop(model, sample_weights = sample_weights or [] val_sample_weights = val_sample_weights or [] + inputs = training_utils.ModelInputs(inputs).as_list() if model.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = inputs + targets + sample_weights + [1] else: @@ -262,6 +263,7 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None): model._make_predict_function() f = model.predict_function + inputs = training_utils.ModelInputs(inputs).as_list() if model.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = inputs + [0] else: @@ -368,6 +370,7 @@ def test_loop(model, f = model.test_function sample_weights = sample_weights or [] + inputs = training_utils.ModelInputs(inputs).as_list() if model.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = inputs + targets + sample_weights + [0] else: diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index e440e02bfb..939732cd67 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -70,7 +70,8 @@ def fit_loop( # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged. 
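The `fit` docstring changes above move sample weighting for dataset inputs into the dataset itself: instead of the (unsupported) `sample_weight=` argument, weights travel as the third element of each dataset tuple. A minimal sketch of the new calling convention, mirroring `test_dataset_with_sample_weights` in `training_test.py` later in this diff (`model`, `inputs`, `targets` and `sample_weights` as defined there):

    dataset = dataset_ops.Dataset.from_tensor_slices(
        (inputs, targets, sample_weights))  # weights ride along as element 3
    dataset = dataset.repeat(100).batch(10)
    model.fit(dataset, epochs=1, steps_per_epoch=2)  # no sample_weight= needed
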
if current_strategy.__class__.__name__ == 'TPUStrategy': return _experimental_fit_loop( - model, iterator, epochs, initial_epoch, steps_per_epoch) + model, iterator, epochs, verbose, callbacks, initial_epoch, + steps_per_epoch) clone_model_on_towers( model, current_strategy, make_callback_model=True) @@ -201,6 +202,8 @@ def _experimental_fit_loop( model, iterator, epochs=100, + verbose=1, + callbacks=None, initial_epoch=0, steps_per_epoch=None): """fit function when using TPU DistributionStrategy for training. @@ -209,6 +212,8 @@ def _experimental_fit_loop( model: Keras Model instance. iterator: Iterator that returns inputs and targets epochs: Number of times to iterate over the data + verbose: Verbosity mode, 0, 1 or 2 + callbacks: List of callbacks to be called during training initial_epoch: Epoch at which to start training (useful for resuming a previous training run) steps_per_epoch: Total number of steps (batches of samples) @@ -225,7 +230,6 @@ def _experimental_fit_loop( # TODO(priyag): Add validation that shapes are fully defined for TPU case. - # TODO(priyag, sourabhbajaj): This should be moved into a callback instead. K.get_session().run(current_strategy.initialize()) def _per_device_train_function(model): @@ -298,19 +302,35 @@ def _experimental_fit_loop( assert steps_per_epoch is not None - # TODO(priyag, sourabhbajaj): Add callbacks support. + # TODO(sourabhbajaj): Convert this into a proper validation function + if callbacks: + raise NotImplementedError( + 'Callbacks are not supported with TPUStrategy right now.') + + callbacks = cbks.configure_callbacks( + callbacks, + model, + do_validation=False, + val_inputs=None, + val_targets=None, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + verbose=verbose) + # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback + # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run # TODO(priyag, sourabhbajaj): Add validation. + callbacks.on_train_begin() for epoch in range(initial_epoch, epochs): - for step_index in range( - 0, steps_per_epoch, current_strategy.steps_per_run): + callbacks.on_epoch_begin(epoch) + epoch_logs = {} + for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run): + # TODO(sourabhbajaj): Add the size parameter in batch_logs once callbacks + # are fixed as we need to replace size with a combination of steps_per_run + # and batch_size + batch_logs = {'batch': step_index} + callbacks.on_batch_begin(step_index, batch_logs) try: - _, outs = K.get_session().run([train_op, output_tensors]) - # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper - # summaries through callbacks. - print('Epoch: {}, step_index: {}, loss: {}'.format( - epoch, step_index, outs['loss'])) - for label, out in outs.items(): - print(label, ': ', out) + _, outputs = K.get_session().run([train_op, output_tensors]) except errors.OutOfRangeError: logging.warning('Your dataset iterator ran out of data; ' 'interrupting training. Make sure that your dataset ' @@ -319,6 +339,16 @@ def _experimental_fit_loop( steps_per_epoch * epochs) break + batch_logs.update(outputs) + callbacks.on_batch_end(step_index, batch_logs) + if callbacks.model.stop_training: + break + + callbacks.on_epoch_end(epoch, epoch_logs) + if callbacks.model.stop_training: + break + callbacks.on_train_end() + # Copy the weights back from the replicated model to the original model. 
with current_strategy.scope(): updated_weights = current_strategy.unwrap( @@ -326,8 +356,7 @@ def _experimental_fit_loop( model.set_weights(updated_weights) K.get_session().run(current_strategy.finalize()) - - # TODO(priyag, sourabhbajaj): Return history. + return model.history def test_loop(model, iterator, verbose=0, steps=None): diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index 1e377149b6..939a7f2356 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -67,7 +67,8 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): Arguments: model: The model on which metrics are being calculated. - inputs: List of input arrays. + inputs: Either a dictionary of inputs to the model or a list of input + arrays. targets: List of target arrays. sample_weights: Optional list of sample weight arrays. training: Whether the model should be run in inference or training mode. @@ -82,7 +83,7 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): kwargs = {} if model._expects_training_arg: kwargs['training'] = training - if len(inputs) == 1: + if len(inputs) == 1 and not isinstance(inputs, dict): inputs = inputs[0] if model._compute_output_and_mask_jointly: @@ -369,6 +370,8 @@ def iterator_test_loop(model, inputs, steps, verbose=0): # Get current step size. if isinstance(x, list): step_size = x[0].get_shape().as_list()[0] + elif isinstance(x, dict): + step_size = list(x.values())[0].get_shape().as_list()[0] else: step_size = x.get_shape().as_list()[0] @@ -417,11 +420,12 @@ def iterator_predict_loop(model, inputs, steps, verbose=0): """ assert isinstance(inputs, iterator_ops.EagerIterator) if not isinstance(inputs.output_shapes, - (list, tuple)) or len(inputs.output_shapes) > 2: + (list, tuple)) or len(inputs.output_shapes) > 3: raise ValueError( - 'Please provide data as a list or tuple of 1 or 2 elements ' - ' - input or input and target pair. Received %s. We do not use the ' - '`target` value here.' % inputs.output_shapes) + 'Please provide data as a list or tuple of 1, 2, or 3 elements ' + ' - `(input)`, or `(input, target)`, or `(input, target, ' + 'sample_weights)`. Received %s. We do not use the `target` or ' + '`sample_weights` value here.'
% inputs.output_shapes) outs = [] if verbose == 1: progbar = generic_utils.Progbar(target=steps) @@ -444,10 +448,13 @@ def iterator_predict_loop(model, inputs, steps, verbose=0): x, _, _ = model._standardize_user_data(x) x = training_utils.cast_if_floating_dtype(x) + if isinstance(x, list) and len(x) == 1: + x = x[0] + if model._expects_training_arg: - batch_outs = model.call(x[0] if len(x) == 1 else x, training=False) + batch_outs = model.call(x, training=False) else: - batch_outs = model.call(x[0] if len(x) == 1 else x) + batch_outs = model.call(x) if not isinstance(batch_outs, list): batch_outs = [batch_outs] diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index bf5c7fd7f8..1d0d113e40 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -481,8 +481,8 @@ class LossWeightingTest(test.TestCase): num_hidden=10, num_classes=num_classes, input_dim=input_dim) model.compile( loss='categorical_crossentropy', - metrics=['acc'], - weighted_metrics=['mae'], + metrics=['acc', metrics_module.CategoricalAccuracy()], + weighted_metrics=['mae', metrics_module.CategoricalAccuracy()], optimizer=RMSPropOptimizer(learning_rate=learning_rate)) np.random.seed(1337) @@ -536,6 +536,25 @@ class LossWeightingTest(test.TestCase): self.assertLess(score[0], ref_score[0]) @tf_test_util.run_in_graph_and_eager_modes + def test_sequential_model_fails_with_dict_inputs(self): + num_classes = 5 + model = testing_utils.get_small_sequential_mlp( + num_hidden=10, num_classes=num_classes) + model.compile( + RMSPropOptimizer(learning_rate=0.001), + metrics=['acc'], + weighted_metrics=['mae'], + loss='categorical_crossentropy') + + x = {'dense_input': np.random.random((10, 1))} + y = np.random.randint(num_classes, size=(10, 1)) + + with self.assertRaisesRegexp( + ValueError, 'Passing a dictionary input to a Sequential Model which ' + 'doesn\'t have FeatureLayer as the first layer is an error'): + model.fit(x, y, batch_size=5, epochs=1) + + @tf_test_util.run_in_graph_and_eager_modes def test_sample_weights(self): num_classes = 5 batch_size = 5 @@ -550,8 +569,8 @@ class LossWeightingTest(test.TestCase): num_hidden=10, num_classes=num_classes, input_dim=input_dim) model.compile( RMSPropOptimizer(learning_rate=learning_rate), - metrics=['acc'], - weighted_metrics=['mae'], + metrics=['acc', metrics_module.CategoricalAccuracy()], + weighted_metrics=['mae', metrics_module.CategoricalAccuracy()], loss='categorical_crossentropy') np.random.seed(43) @@ -679,8 +698,8 @@ class LossWeightingTest(test.TestCase): model.compile( RMSPropOptimizer(learning_rate=learning_rate), loss='binary_crossentropy', - metrics=['acc'], - weighted_metrics=['mae'], + metrics=['acc', metrics_module.CategoricalAccuracy()], + weighted_metrics=['mae', metrics_module.CategoricalAccuracy()], sample_weight_mode='temporal') model.fit( @@ -2097,6 +2116,43 @@ class TestTrainingWithDataset(test.TestCase): 'you should specify the `steps` argument'): model.predict(dataset, verbose=0) + @tf_test_util.run_in_graph_and_eager_modes + def test_dataset_with_sample_weights(self): + model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3) + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + metrics = ['mae', metrics_module.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((10, 3), np.float32) + targets = np.zeros((10, 4), np.float32) + sample_weights = np.ones((10), np.float32) + dataset =
dataset_ops.Dataset.from_tensor_slices((inputs, targets, + sample_weights)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + model.train_on_batch(dataset) + model.predict_on_batch(dataset) + + @tf_test_util.run_in_graph_and_eager_modes + def test_dataset_with_sparse_labels(self): + model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3) + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'sparse_categorical_crossentropy' + model.compile(optimizer, loss) + + inputs = np.zeros((10, 3)) + targets = np.random.randint(0, 4, size=10, dtype=np.int32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + def test_dataset_input_shape_validation(self): with self.test_session(): model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3) @@ -2108,8 +2164,10 @@ class TestTrainingWithDataset(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - with self.assertRaisesRegexp(ValueError, - r'expected (.*?) to have 2 dimensions'): + with self.assertRaisesRegexp( + ValueError, + r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)' + ): model.train_on_batch(dataset) # Wrong input shape diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index f94697c913..898e9223cb 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -22,18 +22,22 @@ import copy import math import numpy as np +import six from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.util import nest def _map_nested(data, func): @@ -210,10 +214,11 @@ def check_num_samples(ins, def standardize_single_array(x): if x is None: return None - elif tensor_util.is_tensor(x): - return x - elif x.ndim == 1: - x = np.expand_dims(x, 1) + if x.shape is not None and len(x.shape) == 1: + if tensor_util.is_tensor(x): + return array_ops.expand_dims(x, axis=1) + else: + return np.expand_dims(x, 1) return x @@ -245,7 +250,8 @@ def standardize_input_data(data, ValueError: in case of improperly formatted user-provided data. """ if not names: - if data is not None and hasattr(data, '__len__') and len(data): + if (data is not None and hasattr(data, '__len__') and len(data) and + not isinstance(data, dict)): raise ValueError('Error when checking model ' + exception_prefix + ': ' 'expected no data, but got:', data) return [] @@ -341,7 +347,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type): Raises: ValueError: In case of invalid user-provided argument. 
""" - if x_weight is None or len(x_weight) == 0: # pylint: disable=g-explicit-length-test + if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test return [None for _ in output_names] if len(output_names) == 1: if isinstance(x_weight, list) and len(x_weight) == 1: @@ -675,7 +681,8 @@ def standardize_weights(y, 'Expected sample_weight with rank ' 'less than or equal to ' + str(len(y.shape))) - if y.shape[:sample_weight.ndim] != sample_weight.shape: + if (not tensor_util.is_tensor(sample_weight) and + y.shape[:sample_weight.ndim] != sample_weight.shape): raise ValueError( 'Found a sample_weight array with shape ' + str(sample_weight.shape) + ' for an input with shape ' + str(y.shape) + '. ' @@ -717,6 +724,8 @@ def has_symbolic_tensors(ls): def has_tensors(ls): if isinstance(ls, (list, tuple)): return any(tensor_util.is_tensor(v) for v in ls) + if isinstance(ls, dict): + return any(tensor_util.is_tensor(v) for _, v in six.iteritems(ls)) return tensor_util.is_tensor(ls) @@ -777,7 +786,9 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None): 'Received: %s' % (x, y)) if sample_weight is not None: raise ValueError('`sample_weight` argument is not supported when input ' - '`x` is a dataset or a dataset iterator. ' + '`x` is a dataset or a dataset iterator. Instead, you' + 'can provide sample_weight as the third element of your' + 'dataset, i.e. (inputs, targets, sample_weight). ' 'Received: x=%s, sample_weight=%s' % (x, sample_weight)) if validation_split is not None and validation_split != 0.0: raise ValueError( @@ -825,6 +836,12 @@ def check_steps_argument(input_data, steps, steps_name): return False +def cast_single_tensor(x): + if tensor_util.is_tensor(x) and x.dtype.is_floating: + return math_ops.cast(x, dtype=K.floatx()) + return x + + def cast_if_floating_dtype(x): """Casts the given data tensors to the default floating point type. @@ -842,13 +859,7 @@ def cast_if_floating_dtype(x): raise RuntimeError( 'Please provide tensors for casting, got: {x}'.format(x=x)) - if isinstance(x, (list, tuple)): - return [ - math_ops.cast(val, dtype=K.floatx()) - if tensor_util.is_tensor(val) and val.dtype.is_floating else val - for val in x - ] - return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x + return nest.map_structure(cast_single_tensor, x) def get_output_sample_weight_and_mode(skip_target_weighing_indices, @@ -929,3 +940,103 @@ def prepare_sample_weights(output_names, sample_weight_mode, sample_weights.append(weight) sample_weight_modes.append(mode) return sample_weights, sample_weight_modes + + +# TODO(rohanj): This is a hack to get around not depending on feature_column and +# create a cyclical dependency. Figure out a cleaner solution +def is_feature_layer(layer): + """Returns whether `layer` is a FeatureLayer or not.""" + return getattr(layer, '_is_feature_layer', False) + + +class ModelInputs(object): + """Encapsulates model inputs. + + Allows for transforming model inputs while keeping the same structure. 
+ """ + + def __init__(self, inputs): + self._inputs = inputs + self._is_dict = isinstance(self._inputs, dict) + self._is_single_input = not isinstance(self._inputs, (list, tuple, dict)) + self._flattened_inputs = [] + self._input_names = [] + if isinstance(self._inputs, dict): + for k in sorted(self._inputs.keys()): + self._flattened_inputs.append(self._inputs[k]) + self._input_names.append(k) + else: + self._flattened_inputs = nest.flatten(self._inputs) + self._input_names = [ + 'input_%d' % (i + 1) for i in range(len(self._flattened_inputs)) + ] + assert len(self._input_names) == len(self._flattened_inputs) + + def get_input_names(self): + """Returns keys to name inputs by. + + In case inputs provided were a list, tuple or single entry, we make up a + key 'input_%d'. For dictionary case, we return a sorted list of keys. + """ + return self._input_names + + def _get(self, return_single_as_list=False): + """Returns provided inputs, potentially transformed. + + Inputs are returned in the same format they were provided i.e. lists + are returned as lists, single entries as single entries (unless + `return_single_as_list` is true), dictionaries as dictionaries. + + Args: + return_single_as_list: Returns a list of size 1 for single entry case. + """ + if self._is_dict: + return dict(zip(self._input_names, self._flattened_inputs)) + if self._is_single_input and not return_single_as_list: + return self._flattened_inputs[0] + return self._flattened_inputs + + def get_input_values(self): + """Returns input values passed in.""" + if context.executing_eagerly(): + for i in range(len(self._flattened_inputs)): + v = self._flattened_inputs[i] + if tensor_util.is_tensor(v): + v = cast_single_tensor(v) + else: + v = ops.convert_to_tensor(v, dtype=K.floatx()) + self._flattened_inputs[i] = v + return self._get(return_single_as_list=False) + + def get_symbolic_inputs(self, return_single_as_list=False): + """Returns inputs to be set as self.inputs for a model.""" + for i in range(len(self._flattened_inputs)): + k = self._input_names[i] + v = self._flattened_inputs[i] + if context.executing_eagerly(): + v = base_layer.DeferredTensor( + shape=(None for _ in v.shape), dtype=v.dtype) + else: + if isinstance(v, list): + v = np.asarray(v) + if v.ndim == 1: + v = np.expand_dims(v, 1) + if isinstance(v, (np.ndarray)): + # We fix the placeholder shape except the batch size. + # This is suboptimal, but it is the best we can do with the info + # we have. The user should call `model._set_inputs(placeholders)` + # to specify custom placeholders if the need arises. 
+ shape = (None,) + v.shape[1:] + v = K.placeholder(shape=shape, name=k) + self._flattened_inputs[i] = v + + return self._get(return_single_as_list) + + def as_dict(self): + """An iterable over a dictionary version of inputs.""" + for i in range(len(self._flattened_inputs)): + yield self._input_names[i], self._flattened_inputs[i] + + def as_list(self): + """Returning the inputs as a list.""" + return self._flattened_inputs diff --git a/tensorflow/python/keras/engine/training_utils_test.py b/tensorflow/python/keras/engine/training_utils_test.py index 297a1ae494..e777cb6db3 100644 --- a/tensorflow/python/keras/engine/training_utils_test.py +++ b/tensorflow/python/keras/engine/training_utils_test.py @@ -20,8 +20,11 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import training_utils from tensorflow.python.platform import test @@ -146,5 +149,91 @@ class TrainingUtilTest(test.TestCase): self.assertEquals(any_true, False) +class ModelInputsTest(test.TestCase): + + def test_single_thing(self): + a = np.ones(10) + model_inputs = training_utils.ModelInputs(a) + self.assertEquals(['input_1'], model_inputs.get_input_names()) + vals = model_inputs.get_input_values() + self.assertAllEqual(np.ones(10), vals) + self.assertFalse(tensor_util.is_tensor(vals)) + vals = model_inputs.get_symbolic_inputs() + self.assertTrue(tensor_util.is_tensor(vals)) + vals = model_inputs.get_symbolic_inputs(return_single_as_list=True) + self.assertEquals(1, len(vals)) + self.assertTrue(tensor_util.is_tensor(vals[0])) + + def test_single_thing_eager(self): + with context.eager_mode(): + a = np.ones(10) + model_inputs = training_utils.ModelInputs(a) + self.assertEquals(['input_1'], model_inputs.get_input_names()) + vals = model_inputs.get_input_values() + self.assertAllEqual(np.ones(10), vals) + self.assertTrue(tensor_util.is_tensor(vals)) + vals = model_inputs.get_symbolic_inputs() + self.assertTrue(isinstance(vals, base_layer.DeferredTensor)) + vals = model_inputs.get_symbolic_inputs(return_single_as_list=True) + self.assertEquals(1, len(vals)) + self.assertTrue(isinstance(vals[0], base_layer.DeferredTensor)) + + def test_list(self): + a = [np.ones(10), np.ones(20)] + model_inputs = training_utils.ModelInputs(a) + self.assertEquals(['input_1', 'input_2'], model_inputs.get_input_names()) + vals = model_inputs.get_input_values() + self.assertEqual(2, len(vals)) + self.assertAllEqual(np.ones(10), vals[0]) + self.assertAllEqual(np.ones(20), vals[1]) + self.assertFalse(tensor_util.is_tensor(vals[0])) + self.assertFalse(tensor_util.is_tensor(vals[1])) + vals = model_inputs.get_symbolic_inputs() + self.assertTrue(tensor_util.is_tensor(vals[0])) + self.assertTrue(tensor_util.is_tensor(vals[1])) + + def test_list_eager(self): + with context.eager_mode(): + a = [np.ones(10), np.ones(20)] + model_inputs = training_utils.ModelInputs(a) + self.assertEquals(['input_1', 'input_2'], model_inputs.get_input_names()) + vals = model_inputs.get_input_values() + self.assertEqual(2, len(vals)) + self.assertAllEqual(np.ones(10), vals[0]) + self.assertAllEqual(np.ones(20), vals[1]) + self.assertTrue(tensor_util.is_tensor(vals[0])) + self.assertTrue(tensor_util.is_tensor(vals[1])) + vals = model_inputs.get_symbolic_inputs() + 
self.assertTrue(isinstance(vals[0], base_layer.DeferredTensor)) + self.assertTrue(isinstance(vals[1], base_layer.DeferredTensor)) + + def test_dict(self): + a = {'b': np.ones(10), 'a': np.ones(20)} + model_inputs = training_utils.ModelInputs(a) + self.assertEquals(['a', 'b'], model_inputs.get_input_names()) + vals = model_inputs.get_input_values() + self.assertAllEqual(np.ones(20), vals['a']) + self.assertAllEqual(np.ones(10), vals['b']) + self.assertFalse(tensor_util.is_tensor(vals['a'])) + self.assertFalse(tensor_util.is_tensor(vals['b'])) + vals = model_inputs.get_symbolic_inputs() + self.assertTrue(tensor_util.is_tensor(vals['a'])) + self.assertTrue(tensor_util.is_tensor(vals['b'])) + + def test_dict_eager(self): + with context.eager_mode(): + a = {'b': np.ones(10), 'a': np.ones(20)} + model_inputs = training_utils.ModelInputs(a) + self.assertEquals(['a', 'b'], model_inputs.get_input_names()) + vals = model_inputs.get_input_values() + self.assertAllEqual(np.ones(20), vals['a']) + self.assertAllEqual(np.ones(10), vals['b']) + self.assertTrue(tensor_util.is_tensor(vals['a'])) + self.assertTrue(tensor_util.is_tensor(vals['b'])) + vals = model_inputs.get_symbolic_inputs() + self.assertTrue(isinstance(vals['a'], base_layer.DeferredTensor)) + self.assertTrue(isinstance(vals['b'], base_layer.DeferredTensor)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 81c760b1f6..473d8cd95b 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -22,7 +22,10 @@ from __future__ import print_function from abc import ABCMeta from abc import abstractmethod +import functools +import sys import types +import weakref import six from tensorflow.python.eager import context @@ -137,6 +140,21 @@ def result_wrapper(result_fn): return tf_decorator.make_decorator(result_fn, decorated) +def weakmethod(method): + """Creates a weak reference to the bound method.""" + + cls = method.im_class + func = method.im_func + instance_ref = weakref.ref(method.im_self) + + @functools.wraps(method) + def inner(*args, **kwargs): + return func.__get__(instance_ref(), cls)(*args, **kwargs) + + del method + return inner + + def safe_div(numerator, denominator): """Divides two tensors element-wise, returning 0 if the denominator is <= 0. @@ -318,14 +336,27 @@ class Metric(Layer): def __new__(cls, *args, **kwargs): obj = super(Metric, cls).__new__(cls) - # TODO(psv): Fix reference cycle issue here. - - # Converting update_state_fn() into a graph function, so that - # we can return a single op that performs all of the variable updates. - defuned_update_state_fn = function.defun(obj.update_state) - obj.update_state = types.MethodType( - update_state_wrapper(defuned_update_state_fn), obj) - obj.result = types.MethodType(result_wrapper(obj.result), obj) + + if sys.version_info < (3,): + # Wrap methods in `weakmethod` function to remove binding and create a + # weak reference. This is to remove the reference cycle that is created + # here. This is not an issue in Python 3.
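The Python 2 branch above exists because assigning a wrapped bound method back onto the instance makes the object reference itself through its own attribute; per the comment, that cycle was problematic under Python 2, and `weakmethod` rebinds the call through a `weakref.ref` so no strong self-reference is created. A toy version of the pattern (illustrative only; the `Counter` class is hypothetical):

    import weakref

    class Counter(object):

      def __init__(self):
        self.n = 0

      def bump(self):
        self.n += 1

    c = Counter()
    ref = weakref.ref(c)

    def bump_weak():
      # Dereferences the weakref at call time; holds no strong ref to `c`.
      return Counter.bump(ref())

    c.bump = bump_weak  # unlike `c.bump = c.bump`, this creates no cycle
    c.bump()
    assert c.n == 1
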
+ if context.executing_eagerly(): + update_state = weakmethod(obj.update_state) + else: + update_state = function.defun(obj.update_state) + obj.update_state = weakmethod( + types.MethodType(update_state_wrapper(update_state), obj)) + result = weakmethod(obj.result) + obj.result = weakmethod(types.MethodType(result_wrapper(result), obj)) + else: + # Converting update_state_fn() into a graph function, so that + # we can return a single op that performs all of the variable updates. + defuned_update_state_fn = function.defun(obj.update_state) + obj.update_state = types.MethodType( + update_state_wrapper(defuned_update_state_fn), obj) + obj.result = types.MethodType(result_wrapper(obj.result), obj) + return obj def __call__(self, *args, **kwargs): diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 779c08c42d..4195ea18ad 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -212,7 +212,7 @@ class KerasMetricsTest(test.TestCase): self.assertAllClose( val_outs[2], history.history['val_true_positives'][-1], atol=1e-5) - @test_util.run_in_graph_and_eager_modes + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def test_mean(self): m = metrics.Mean(name='my_mean') @@ -394,7 +394,7 @@ class KerasMetricsTest(test.TestCase): self.assertTrue(acc_obj.stateful) self.assertEqual(len(acc_obj.variables), 2) self.assertEqual(acc_obj.dtype, dtypes.float32) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(variables.variables_initializer(acc_obj.variables)) # verify that correct value is returned update_op = acc_obj.update_state([[0, 0, 1], [0, 1, 0]], diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index c3b7301eba..f0733a9105 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -414,10 +414,10 @@ def clone_and_build_model( this argument must be set to `True` (default `False`). To restore the original model, use the function `in_place_subclassed_model_state_restoration(model)`. - optimizer_iterations: An iterations variable to pass to the optimizer if - the model uses a TFOptimizer, and if the clone is compiled. This is used - when a Keras model is cloned into an Estimator model function, because - Estimators create their own global step variable. + optimizer_iterations: An iterations variable that will be incremented by the + optimizer if the clone is compiled. This argument is used when a Keras + model is cloned into an Estimator model function, because Estimators + create their own global step variable. Returns: Clone of the model. 
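A condensed illustration of the `optimizer_iterations` contract described in the docstring above, adapted from `assert_optimizer_iterations_increases` in `models_test.py` further down this diff (`model`, `inp` and `out` as built there): the clone's optimizer counts its steps directly in the variable the caller passed in.

    global_step = keras.backend.variable(123, dtype=dtypes.int64)
    clone = models.clone_and_build_model(
        model, compile_clone=True, optimizer_iterations=global_step)
    clone.train_on_batch(inp, out)                 # one optimizer step
    assert keras.backend.eval(global_step) == 124  # incremented in place
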
@@ -458,6 +458,8 @@ def clone_and_build_model( else: optimizer_config = model.optimizer.get_config() optimizer = model.optimizer.__class__.from_config(optimizer_config) + if optimizer_iterations is not None: + optimizer.iterations = optimizer_iterations clone.compile( optimizer, diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py index 1d0f56f3c8..c550caeb80 100644 --- a/tensorflow/python/keras/models_test.py +++ b/tensorflow/python/keras/models_test.py @@ -25,7 +25,9 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend as K from tensorflow.python.keras import metrics from tensorflow.python.keras import models from tensorflow.python.ops import random_ops @@ -51,7 +53,7 @@ class TestModel(keras.Model): class TestModelCloning(test.TestCase): def test_clone_sequential_model(self): - with self.test_session(): + with self.cached_session(): val_a = np.random.random((10, 4)) val_out = np.random.random((10, 4)) @@ -64,7 +66,7 @@ class TestModelCloning(test.TestCase): # Everything should work in a new session. keras.backend.clear_session() - with self.test_session(): + with self.cached_session(): # With placeholder creation new_model = keras.models.clone_model(model) # update ops from batch norm needs to be included @@ -89,7 +91,7 @@ class TestModelCloning(test.TestCase): new_model.train_on_batch(None, val_out) def test_clone_functional_model(self): - with self.test_session(): + with self.cached_session(): val_a = np.random.random((10, 4)) val_b = np.random.random((10, 4)) val_out = np.random.random((10, 4)) @@ -110,7 +112,7 @@ class TestModelCloning(test.TestCase): # Everything should work in a new session. keras.backend.clear_session() - with self.test_session(): + with self.cached_session(): # With placeholder creation new_model = keras.models.clone_model(model) self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2) @@ -137,7 +139,7 @@ class TestModelCloning(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_clone_functional_model_with_masking(self): - with self.test_session(): + with self.cached_session(): x = np.array([[[1], [1]], [[0], [0]]]) inputs = keras.Input((2, 1)) outputs = keras.layers.Masking(mask_value=0)(inputs) @@ -238,7 +240,7 @@ class TestModelDeepCopy(test.TestCase): class TestCloneAndBuildModel(test.TestCase): def test_clone_and_build_non_compiled_model(self): - with self.test_session(): + with self.cached_session(): inp = np.random.random((10, 4)) out = np.random.random((10, 4)) @@ -251,7 +253,7 @@ class TestCloneAndBuildModel(test.TestCase): # Everything should work in a new session. keras.backend.clear_session() - with self.test_session(): + with self.cached_session(): # With placeholder creation new_model = models.clone_and_build_model(model, compile_clone=True) with self.assertRaisesRegexp(RuntimeError, 'must compile'): @@ -289,7 +291,7 @@ class TestCloneAndBuildModel(test.TestCase): # Everything should work in a new session. 
keras.backend.clear_session() - with self.test_session(): + with self.cached_session(): # With placeholder creation new_model = models.clone_and_build_model( model, compile_clone=True, in_place_reset=is_subclassed) @@ -316,7 +318,7 @@ class TestCloneAndBuildModel(test.TestCase): new_model.evaluate(inp, out) def test_clone_and_build_compiled_sequential_model(self): - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(4, input_shape=(4,))) model.add(keras.layers.BatchNormalization()) @@ -328,7 +330,7 @@ class TestCloneAndBuildModel(test.TestCase): self._clone_and_build_test_helper(model) def test_clone_and_build_functional_model(self): - with self.test_session(): + with self.cached_session(): input_a = keras.Input(shape=(4,)) dense_1 = keras.layers.Dense(4,) dense_2 = keras.layers.Dense(4,) @@ -358,12 +360,42 @@ class TestCloneAndBuildModel(test.TestCase): out = self.layer2(out) return out - with self.test_session(): + with self.cached_session(): model = SubclassedModel() model.compile('rmsprop', 'mse', metrics=['acc', metrics.categorical_accuracy]) self._clone_and_build_test_helper(model, True) + def assert_optimizer_iterations_increases(self, optimizer): + with self.cached_session(): + input_a = keras.Input(shape=(4,)) + dense_1 = keras.layers.Dense(4,) + dense_2 = keras.layers.Dense(4,) + + x_a = dense_1(input_a) + x_a = keras.layers.Dropout(0.5)(x_a) + x_a = keras.layers.BatchNormalization()(x_a) + x_a = dense_2(x_a) + model = keras.models.Model(input_a, x_a) + model.compile(optimizer, 'mse', + metrics=['acc', metrics.categorical_accuracy]) + + global_step = keras.backend.variable(123, dtype=dtypes.int64) + clone_model = models.clone_and_build_model( + model, compile_clone=True, optimizer_iterations=global_step) + + inp = np.random.random((10, 4)) + out = np.random.random((10, 4)) + clone_model.train_on_batch(inp, out) + + self.assertEqual(K.eval(global_step), 124) + + def test_replace_tf_optimizer_iterations_variable(self): + self.assert_optimizer_iterations_increases(adam.AdamOptimizer(0.01)) + + def test_replace_keras_optimizer_iterations_variable(self): + self.assert_optimizer_iterations_increases('adam') + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 3026c7755a..0403211d92 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -622,6 +622,7 @@ cuda_py_test( "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", ], + tags = ["notap"], ) cuda_py_test( @@ -779,6 +780,7 @@ tf_py_test( size = "small", srcs = ["regex_full_match_op_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", @@ -1634,6 +1636,7 @@ cuda_py_test( srcs = ["functional_ops_test.py"], additional_deps = [ "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py index 400d38b936..de52a70cc0 100644 --- a/tensorflow/python/kernel_tests/clip_ops_test.py +++ b/tensorflow/python/kernel_tests/clip_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops 
import gradient_checker +from tensorflow.python.ops import gradients_impl from tensorflow.python.platform import test @@ -158,13 +159,19 @@ class ClipTest(test.TestCase): ans = clip_ops.clip_by_norm(x, clip_norm) tf_ans = ans.eval() - clip_tensor = constant_op.constant(4.0) ans = clip_ops.clip_by_norm(x, clip_norm) tf_ans_tensor = ans.eval() self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans_tensor) + def testClipByNormGradientZeros(self): + with self.test_session(use_gpu=True): + x = array_ops.zeros([3]) + b = clip_ops.clip_by_norm(x, 1.) + grad, = gradients_impl.gradients(b, x) + self.assertAllEqual(grad.eval(), [1., 1., 1.]) + def testClipByNormBadShape(self): with self.test_session(use_gpu=True): x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3, 1]) diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py index 7570523495..86802664d1 100644 --- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py +++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py @@ -42,14 +42,22 @@ class ConditionalAccumulatorTest(test.TestCase): with ops.Graph().as_default(): q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q") self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ name:'Q' op:'ConditionalAccumulator' attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'shape' value { shape { unknown_rank: true} } } attr { key: 'container' value { s: '' } } attr { key: 'shared_name' value { s: '' } } + attr { key: 'reduction_type' value {s: 'MEAN'} } """, q.accumulator_ref.op.node_def) + def testConstructorWithInvalidArg(self): + with ops.Graph().as_default(): + with self.assertRaises(ValueError): + data_flow_ops.ConditionalAccumulator( + dtypes_lib.float32, name="Q", reduction_type="Invalid") + def testConstructorWithShape(self): with ops.Graph().as_default(): q = data_flow_ops.ConditionalAccumulator( @@ -57,7 +65,8 @@ class ConditionalAccumulatorTest(test.TestCase): name="Q", shape=tensor_shape.TensorShape([1, 5, 2, 8])) self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ name:'Q' op:'ConditionalAccumulator' attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'shape' value { shape { dim {size: 1 } @@ -67,6 +76,7 @@ class ConditionalAccumulatorTest(test.TestCase): } } } attr { key: 'container' value { s: '' } } attr { key: 'shared_name' value { s: '' } } + attr { key: 'reduction_type' value {s: 'MEAN'} } """, q.accumulator_ref.op.node_def) def testAccumulatorSizeEmpty(self): @@ -237,12 +247,11 @@ class ConditionalAccumulatorTest(test.TestCase): extract_t.op.run() self.assertEqual(q.num_accumulated().eval(), 0) - def testAccumulatorTakeGrad(self): + def testAccumulatorTakeGradMean(self): with self.test_session(): q = data_flow_ops.ConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1])) elems = [10.0, 20.0] - elems_ave = sum(elems) / len(elems) accum_ops = [q.apply_grad((x,), local_step=0) for x in elems] takeg_t = q.take_grad(1) @@ -251,7 +260,7 @@ class ConditionalAccumulatorTest(test.TestCase): accum_op.run() val = takeg_t.eval() - self.assertEqual(elems_ave, val) + self.assertEqual(15.0, val) accum_ops = [q.apply_grad((x,), local_step=1) for x in elems] takeg_t = q.take_grad(constant_op.constant(1)) @@ -260,7 +269,42 @@ class 
ConditionalAccumulatorTest(test.TestCase): accum_op.run() val = takeg_t.eval() - self.assertEqual(elems_ave, val) + self.assertEqual(15.0, val) + + def testAccumulatorTakeGradSum(self): + with self.test_session(): + q = data_flow_ops.ConditionalAccumulator( + dtypes_lib.float32, + name="Q", + shape=tensor_shape.TensorShape([1]), + reduction_type="SUM") + elems = [10.0, 20.0] + + accum_ops = [q.apply_grad((x,), local_step=0) for x in elems] + takeg_t = q.take_grad(1) + + for accum_op in accum_ops: + accum_op.run() + + val = takeg_t.eval() + self.assertEqual(30.0, val) + + accum_ops = [q.apply_grad((x,), local_step=1) for x in elems] + takeg_t = q.take_grad(constant_op.constant(1)) + + for accum_op in accum_ops: + accum_op.run() + + val = takeg_t.eval() + self.assertEqual(30.0, val) + + def testAccumulatorTakeGradInvalidReductionType(self): + with self.assertRaises(ValueError): + data_flow_ops.ConditionalAccumulator( + dtypes_lib.float32, + name="Q", + shape=tensor_shape.TensorShape([1]), + reduction_type="Invalid") def testAccumulatorInvalidTakeGrad(self): with self.test_session(): @@ -277,7 +321,7 @@ class ConditionalAccumulatorTest(test.TestCase): with self.assertRaises(errors_impl.InvalidArgumentError): takeg_t.eval() - def testAccumulatorRepeatedTakeGrad(self): + def testAccumulatorRepeatedTakeGradMean(self): with self.test_session(): q = data_flow_ops.ConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1])) @@ -304,6 +348,36 @@ class ConditionalAccumulatorTest(test.TestCase): val = takeg_t.eval() self.assertEqual(elems_ave + 0.0, val) + def testAccumulatorRepeatedTakeGradSum(self): + with self.test_session(): + q = data_flow_ops.ConditionalAccumulator( + dtypes_lib.float32, + name="Q", + shape=tensor_shape.TensorShape([1]), + reduction_type="SUM") + + elems = [10.0, 20.0] + elems_sum = 30.0 + accum_ops = [q.apply_grad((x,), local_step=0) for x in elems] + takeg_t = q.take_grad(1) + + for accum_op in accum_ops: + accum_op.run() + + val = takeg_t.eval() + self.assertEqual(elems_sum, val) + + elems = [20.0, 30.0] + elems_sum = 50.0 + accum_ops = [q.apply_grad((x,), local_step=1) for x in elems] + takeg_t = q.take_grad(1) + + for accum_op in accum_ops: + accum_op.run() + + val = takeg_t.eval() + self.assertEqual(elems_sum, val) + def testAccumulatorIncrementGlobalStep(self): with self.test_session(): q = data_flow_ops.ConditionalAccumulator( diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py index c4d4ce780b..49b9569e2b 100644 --- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py @@ -104,6 +104,27 @@ class DynamicStitchTestBase(object): # Dimension 0 is max(flatten(indices))+1. 
self.assertEqual([8, 2], stitched_t.get_shape().as_list()) + def testZeroSizeTensor(self): + with self.test_session(use_gpu=True): + indices = [ + constant_op.constant([0, 4, 7]), + constant_op.constant([1, 6]), + constant_op.constant([2, 3, 5]), + array_ops.zeros([0], dtype=dtypes.int32) + ] + data = [ + constant_op.constant([[0, 1], [40, 41], [70, 71]]), + constant_op.constant([[10, 11], [60, 61]]), + constant_op.constant([[20, 21], [30, 31], [50, 51]]), + array_ops.zeros([0, 2], dtype=dtypes.int32) + ] + stitched_t = self.stitch_op(indices, data) + stitched_val = stitched_t.eval() + self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41], + [50, 51], [60, 61], [70, 71]], stitched_val) + # Dimension 0 is max(flatten(indices))+1. + self.assertEqual([8, 2], stitched_t.get_shape().as_list()) + def testHigherRank(self): with self.test_session(use_gpu=True) as sess: indices = [ diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 3ddb5e06c9..e39daf1371 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import iterator_ops @@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(Run(sess, 20.), 210.) self.assertAllEqual(Run(sess, 100.), 5050.) + def testWhileLowering(self): + + def Run(n, fetch_by_name): + for use_gpu in (True, False): + with ops.Graph().as_default() as g: + + @function.Defun(*[dtypes.float32] * 2) + def Cond(n, unused_x): + return n > 0 + + @function.Defun(*[dtypes.float32] * 2) + def Body(n, x): + return n - 1, x + n + + # outputs: [0, n*(n+1)/2] + outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while") + + # `outputs` is the list of output tensors of the While op. We + # arbitrarily choose the 0th tensor to get the While op and set the + # lowering attribute on it. + outputs[0].op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + if not fetch_by_name: + fetch = outputs[1] + else: + fetch = "my_while:1" + with self.test_session(graph=g, use_gpu=use_gpu) as sess: + return sess.run(fetch) + + self.assertAllEqual(Run(20., False), 210.) + self.assertAllEqual(Run(20., True), 210.) + self.assertAllEqual(Run(100., False), 5050.) + self.assertAllEqual(Run(100., True), 5050.) 
+ def testWhileError(self): for use_gpu in (True, False): with ops.Graph().as_default() as g: diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py index 5daae1b79b..7bd8c3ca27 100644 --- a/tensorflow/python/kernel_tests/regex_full_match_op_test.py +++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py @@ -18,37 +18,77 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.ops import gen_string_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -class RegexFullMatchOpTest(test.TestCase): +@parameterized.parameters( + (gen_string_ops.regex_full_match), + (gen_string_ops.static_regex_full_match)) +class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase): - def testRegexFullMatch(self): + def testRegexFullMatch(self, op): values = ["abaaba", "abcdabcde"] with self.test_session(): - input_vector = constant_op.constant(values, dtypes.string) - matched = string_ops.regex_full_match(input_vector, "a.*a").eval() + input_tensor = constant_op.constant(values, dtypes.string) + matched = op(input_tensor, "a.*a").eval() self.assertAllEqual([True, False], matched) - def testEmptyMatch(self): + def testRegexFullMatchTwoDims(self, op): + values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]] + with self.test_session(): + input_tensor = constant_op.constant(values, dtypes.string) + matched = op(input_tensor, "a.*a").eval() + self.assertAllEqual([[True, False], [True, False]], matched) + + def testEmptyMatch(self, op): values = ["abc", "1"] with self.test_session(): - input_vector = constant_op.constant(values, dtypes.string) - matched = string_ops.regex_full_match(input_vector, "").eval() + input_tensor = constant_op.constant(values, dtypes.string) + matched = op(input_tensor, "").eval() self.assertAllEqual([False, False], matched) - def testInvalidPattern(self): + def testInvalidPattern(self, op): values = ["abc", "1"] with self.test_session(): - input_vector = constant_op.constant(values, dtypes.string) + input_tensor = constant_op.constant(values, dtypes.string) invalid_pattern = "A[" - matched = string_ops.regex_full_match(input_vector, invalid_pattern) + matched = op(input_tensor, invalid_pattern) with self.assertRaisesOpError("Invalid pattern"): matched.eval() +class RegexFullMatchOpTest(test.TestCase): + + def testRegexFullMatchDelegation(self): + with compat.forward_compatibility_horizon(2018, 11, 1): + with self.test_session(): + input_tensor = constant_op.constant("foo", dtypes.string) + pattern = "[a-z]" + op = string_ops.regex_full_match(input_tensor, pattern) + self.assertTrue(op.name.startswith("RegexFullMatch"), op.name) + + pattern_tensor = constant_op.constant("[a-z]*", dtypes.string) + op_tensor = string_ops.regex_full_match(input_tensor, pattern_tensor) + self.assertTrue(op_tensor.name.startswith("RegexFullMatch"), op.name) + + def testStaticRegexFullMatchDelegation(self): + with compat.forward_compatibility_horizon(2018, 11, 20): + with self.test_session(): + input_tensor = constant_op.constant("foo", dtypes.string) + pattern = "[a-z]*" + op = string_ops.regex_full_match(input_tensor, pattern) + self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name) + + 
pattern_tensor = constant_op.constant("[a-z]*", dtypes.string) + op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor) + self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op.name) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py index d749843410..3bb5e899fe 100644 --- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py +++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py @@ -61,14 +61,22 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q") self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ name:'Q' op:'SparseConditionalAccumulator' attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'shape' value { shape { unknown_rank: true} } } attr { key: 'container' value { s: '' } } attr { key: 'shared_name' value { s: '' } } + attr { key: 'reduction_type' value {s: 'MEAN'} } """, q.accumulator_ref.op.node_def) + def testConstructorWithInvalidArg(self): + with ops.Graph().as_default(): + with self.assertRaises(ValueError): + data_flow_ops.SparseConditionalAccumulator( + dtypes_lib.float32, name="Q", reduction_type="Invalid") + def testConstructorWithShape(self): with ops.Graph().as_default(): q = data_flow_ops.SparseConditionalAccumulator( @@ -76,7 +84,8 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): name="Q", shape=tensor_shape.TensorShape([1, 5, 2, 8])) self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ name:'Q' op:'SparseConditionalAccumulator' attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'shape' value { shape { dim {size: 1 } @@ -86,6 +95,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): } } } attr { key: 'container' value { s: '' } } attr { key: 'shared_name' value { s: '' } } + attr { key: 'reduction_type' value {s: 'MEAN'} } """, q.accumulator_ref.op.node_def) def testAccumulatorSizeEmpty(self): @@ -164,7 +174,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): result = sess.run(accums[i].take_indexed_slices_grad(1)) self._assertEqual_indexedslices(expected_tensors[i], result) - def testAccumulatorTakeGrad(self): + def testAccumulatorTakeGradMean(self): with self.test_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q", shape=()) @@ -180,9 +190,34 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): takeg_t = q.take_indexed_slices_grad(1) val = sess.run(takeg_t) - self.assertAllEqual(val.indices, [0, 1, 2]) - self.assertAllEqual(val.values, [[0.5, 0.5], [0, 2], [3, 0]]) - self.assertAllEqual(val.dense_shape, [-1, 2]) + self.assertAllEqual([0, 1, 2], val.indices) + self.assertAllEqual([[0.5, 0.5], [0, 2], [3, 0]], val.values) + self.assertAllEqual([-1, 2], val.dense_shape) + + def testAccumulatorTakeGradSum(self): + with self.test_session() as sess: + q = data_flow_ops.SparseConditionalAccumulator( + dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM") + + grad_indexed_slices = ops.IndexedSlices( + indices=[0, 1], values=np.array([[1, 0], [0, 2]]).astype(np.float32)) + accum_op = q.apply_indexed_slices_grad(grad_indexed_slices) + accum_op.run() + accum_op = q.apply_grad([0, 2], + np.array([[0, 1], [3, 0]]).astype(np.float32), + [3, 
2]) + accum_op.run() + + takeg_t = q.take_indexed_slices_grad(1) + val = sess.run(takeg_t) + self.assertAllEqual([0, 1, 2], val.indices) + self.assertAllEqual([[1, 1], [0, 2], [3, 0]], val.values) + self.assertAllEqual([-1, 2], val.dense_shape) + + def testAccumulatorTakeGradInvalidReductionType(self): + with self.assertRaises(ValueError): + data_flow_ops.SparseConditionalAccumulator( + dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid") def testAccumulatorRepeatedTakeGrad(self): with self.test_session() as sess: @@ -222,7 +257,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): self.assertAllEqual(val.values, [[5, 5], [0, 20], [30, 0]]) self.assertAllEqual(val.dense_shape, [-1, 2]) - def testParallelApplyGrad(self): + def testParallelApplyGradMean(self): with self.test_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2])) @@ -253,6 +288,40 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32), val, sess) + def testParallelApplyGradSum(self): + with self.test_session() as sess: + q = data_flow_ops.SparseConditionalAccumulator( + dtypes_lib.float32, + name="Q", + shape=tensor_shape.TensorShape([2, 2]), + reduction_type="SUM") + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + accum_ops = [] + for x in elems: + x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32)) + accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0)) + takeg_t = q.take_indexed_slices_grad(1) + + def apply_indexed_slices_grad(accum_op): + sess.run(accum_op) + + threads = [ + self.checkedThread(target=apply_indexed_slices_grad, args=(o,)) + for o in accum_ops + ] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + val = sess.run(takeg_t) + + expected_val = 550.0 + self._assertEqual_nparray( + np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32), + val, sess) + def testParallelTakeGrad(self): with self.test_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc index 9500fc6a7c..07ce071845 100644 --- a/tensorflow/python/lib/io/py_record_reader.cc +++ b/tensorflow/python/lib/io/py_record_reader.cc @@ -30,6 +30,8 @@ namespace io { PyRecordReader::PyRecordReader() {} +// NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking +// RecordReaderOptions, if this changes the API can be updated at that time. 
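The MEAN/SUM split exercised by the accumulator tests above comes down to how `take_grad` aggregates the gradients applied since the last take: MEAN divides the elementwise sum by the number of applied gradients, SUM returns the sum as-is, and any other string is rejected at construction time. A minimal pure-Python model of that contract, using the same values as the tests (10.0 and 20.0 yield 30.0 under SUM, 15.0 under MEAN); this toy class is illustrative only, since the real op also tracks `local_step`, shapes, and sparse indices.

class ToyAccumulator(object):
  """Models only the reduction_type behavior of ConditionalAccumulator."""

  def __init__(self, reduction_type="MEAN"):
    if reduction_type not in ("MEAN", "SUM"):
      raise ValueError("Invalid reduction type: %s" % reduction_type)
    self._reduction_type = reduction_type
    self._grads = []

  def apply_grad(self, grad):
    self._grads.append(grad)

  def take_grad(self):
    total = sum(self._grads)
    count = len(self._grads)
    self._grads = []  # Taking the gradient resets the accumulator.
    return total / count if self._reduction_type == "MEAN" else total


acc = ToyAccumulator("SUM")
acc.apply_grad(10.0)
acc.apply_grad(20.0)
assert acc.take_grad() == 30.0  # With "MEAN" this would be 15.0.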
PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset, const string& compression_type_string, TF_Status* out_status) { diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc index e4e5268b0f..faf20df868 100644 --- a/tensorflow/python/lib/io/py_record_writer.cc +++ b/tensorflow/python/lib/io/py_record_writer.cc @@ -28,7 +28,7 @@ namespace io { PyRecordWriter::PyRecordWriter() {} PyRecordWriter* PyRecordWriter::New(const string& filename, - const string& compression_type_string, + const io::RecordWriterOptions& options, TF_Status* out_status) { std::unique_ptr<WritableFile> file; Status s = Env::Default()->NewWritableFile(filename, &file); @@ -38,10 +38,6 @@ PyRecordWriter* PyRecordWriter::New(const string& filename, } PyRecordWriter* writer = new PyRecordWriter; writer->file_ = std::move(file); - - RecordWriterOptions options = - RecordWriterOptions::CreateRecordWriterOptions(compression_type_string); - writer->writer_.reset(new RecordWriter(writer->file_.get(), options)); return writer; } diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h index 61a4960ee6..9b0792c6db 100644 --- a/tensorflow/python/lib/io/py_record_writer.h +++ b/tensorflow/python/lib/io/py_record_writer.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -36,10 +37,8 @@ class RecordWriter; // by multiple threads. class PyRecordWriter { public: - // TODO(vrv): make this take a shared proto to configure - // the compression options. static PyRecordWriter* New(const string& filename, - const string& compression_type_string, + const io::RecordWriterOptions& compression_options, TF_Status* out_status); ~PyRecordWriter(); diff --git a/tensorflow/python/lib/io/py_record_writer.i b/tensorflow/python/lib/io/py_record_writer.i index 3181c9afce..b2c2bda5dd 100644 --- a/tensorflow/python/lib/io/py_record_writer.i +++ b/tensorflow/python/lib/io/py_record_writer.i @@ -18,6 +18,11 @@ limitations under the License. %include "tensorflow/python/platform/base.i" %include "tensorflow/python/lib/core/strings.i" +// Define int8_t explicitly instead of including "stdint.i", since "stdint.h" +// and "stdint.i" disagree on the definition of int64_t. +typedef signed char int8; +%{ typedef signed char int8; %} + %feature("except") tensorflow::io::PyRecordWriter::New { // Let other threads run while we write Py_BEGIN_ALLOW_THREADS @@ -26,6 +31,7 @@ limitations under the License. } %newobject tensorflow::io::PyRecordWriter::New; +%newobject tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions; %feature("except") tensorflow::io::PyRecordWriter::WriteRecord { // Let other threads run while we write @@ -35,6 +41,8 @@ limitations under the License. } %{ +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/python/lib/io/py_record_writer.h" %} @@ -48,7 +56,21 @@ limitations under the License. 
%unignore tensorflow::io::PyRecordWriter::Flush; %unignore tensorflow::io::PyRecordWriter::Close; %unignore tensorflow::io::PyRecordWriter::New; %unignore tensorflow::io::ZlibCompressionOptions; %unignore tensorflow::io::ZlibCompressionOptions::flush_mode; %unignore tensorflow::io::ZlibCompressionOptions::input_buffer_size; %unignore tensorflow::io::ZlibCompressionOptions::output_buffer_size; %unignore tensorflow::io::ZlibCompressionOptions::window_bits; %unignore tensorflow::io::ZlibCompressionOptions::compression_level; %unignore tensorflow::io::ZlibCompressionOptions::compression_method; %unignore tensorflow::io::ZlibCompressionOptions::mem_level; %unignore tensorflow::io::ZlibCompressionOptions::compression_strategy; %unignore tensorflow::io::RecordWriterOptions; %unignore tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions; %unignore tensorflow::io::RecordWriterOptions::zlib_options; %include "tensorflow/core/lib/io/record_writer.h" %include "tensorflow/core/lib/io/zlib_compression_options.h" %include "tensorflow/python/lib/io/py_record_writer.h" %unignoreall diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index 2b3e986f6b..cce71a2bab 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -33,8 +33,6 @@ class TFRecordCompressionType(object): GZIP = 2 -# NOTE(vrv): This will eventually be converted into a proto. to match -# the interface used by the C++ RecordWriter. @tf_export("python_io.TFRecordOptions") class TFRecordOptions(object): """Options used for manipulating TFRecord files.""" @@ -44,14 +42,105 @@ class TFRecordOptions(object): TFRecordCompressionType.NONE: "" } - def __init__(self, compression_type): + def __init__(self, + compression_type=None, + flush_mode=None, + input_buffer_size=None, + output_buffer_size=None, + window_bits=None, + compression_level=None, + compression_method=None, + mem_level=None, + compression_strategy=None): + # pylint: disable=line-too-long + """Creates a `TFRecordOptions` instance. + + Options only affect TFRecordWriter when compression_type is not `None`. + Documentation, details, and defaults can be found in + [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h) + and in the [zlib manual](http://www.zlib.net/manual.html). + Leaving an option as `None` allows C++ to set a reasonable default. + + Args: + compression_type: `TFRecordCompressionType` or `None`. + flush_mode: flush mode or `None`. Default: Z_NO_FLUSH. + input_buffer_size: int or `None`. + output_buffer_size: int or `None`. + window_bits: int or `None`. + compression_level: 0 to 9, or `None`. + compression_method: compression method or `None`. + mem_level: 1 to 9, or `None`. + compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY. + + Returns: + A `TFRecordOptions` object. + + Raises: + ValueError: If compression_type is invalid. + """ + # pylint: enable=line-too-long + # Check compression_type is valid, but for backwards compatibility don't + # immediately convert to a string.
+ self.get_compression_type_string(compression_type) self.compression_type = compression_type + self.flush_mode = flush_mode + self.input_buffer_size = input_buffer_size + self.output_buffer_size = output_buffer_size + self.window_bits = window_bits + self.compression_level = compression_level + self.compression_method = compression_method + self.mem_level = mem_level + self.compression_strategy = compression_strategy @classmethod def get_compression_type_string(cls, options): + """Convert various option types to a unified string. + + Args: + options: `TFRecordOption`, `TFRecordCompressionType`, or string. + + Returns: + Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`). + + Raises: + ValueError: If compression_type is invalid. + """ if not options: return "" - return cls.compression_type_map[options.compression_type] + elif isinstance(options, TFRecordOptions): + return cls.get_compression_type_string(options.compression_type) + elif isinstance(options, TFRecordCompressionType): + return cls.compression_type_map[options] + elif options in TFRecordOptions.compression_type_map: + return cls.compression_type_map[options] + elif options in TFRecordOptions.compression_type_map.values(): + return options + else: + raise ValueError('Not a valid compression_type: "{}"'.format(options)) + + def _as_record_writer_options(self): + """Convert to RecordWriterOptions for use with PyRecordWriter.""" + options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions( + compat.as_bytes( + self.get_compression_type_string(self.compression_type))) + + if self.flush_mode is not None: + options.zlib_options.flush_mode = self.flush_mode + if self.input_buffer_size is not None: + options.zlib_options.input_buffer_size = self.input_buffer_size + if self.output_buffer_size is not None: + options.zlib_options.output_buffer_size = self.output_buffer_size + if self.window_bits is not None: + options.zlib_options.window_bits = self.window_bits + if self.compression_level is not None: + options.zlib_options.compression_level = self.compression_level + if self.compression_method is not None: + options.zlib_options.compression_method = self.compression_method + if self.mem_level is not None: + options.zlib_options.mem_level = self.mem_level + if self.compression_strategy is not None: + options.zlib_options.compression_strategy = self.compression_strategy + return options @tf_export("python_io.tf_record_iterator") @@ -100,16 +189,21 @@ class TFRecordWriter(object): Args: path: The path to the TFRecords file. - options: (optional) A TFRecordOptions object. + options: (optional) String specifying compression type, + `TFRecordCompressionType`, or `TFRecordOptions` object. Raises: IOError: If `path` cannot be opened for writing. + ValueError: If valid compression_type can't be determined from `options`. 
""" - compression_type = TFRecordOptions.get_compression_type_string(options) + if not isinstance(options, TFRecordOptions): + options = TFRecordOptions(compression_type=options) with errors.raise_exception_on_not_ok_status() as status: + # pylint: disable=protected-access self._writer = pywrap_tensorflow.PyRecordWriter_New( - compat.as_bytes(path), compat.as_bytes(compression_type), status) + compat.as_bytes(path), options._as_record_writer_options(), status) + # pylint: enable=protected-access def __enter__(self): """Enter a `with` block.""" diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py index b853b64ae4..def8fe23e5 100644 --- a/tensorflow/python/lib/io/tf_record_test.py +++ b/tensorflow/python/lib/io/tf_record_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import gzip import os +import random +import string import zlib import six @@ -131,9 +133,6 @@ class TFCompressionTestCase(test.TestCase): class TFRecordWriterTest(TFCompressionTestCase): - def setUp(self): - super(TFRecordWriterTest, self).setUp() - def _AssertFilesEqual(self, a, b, equal): for an, bn in zip(a, b): with open(an, "rb") as af, open(bn, "rb") as bf: @@ -142,6 +141,37 @@ class TFRecordWriterTest(TFCompressionTestCase): else: self.assertNotEqual(af.read(), bf.read()) + def _CompressionSizeDelta(self, records, options_a, options_b): + """Validate compression with options_a and options_b and return size delta. + + Compress records with options_a and options_b. Uncompress both compressed + files and assert that the contents match the original records. Finally + calculate how much smaller the file compressed with options_a was than the + file compressed with options_b. + + Args: + records: The records to compress + options_a: First set of options to compress with, the baseline for size. + options_b: Second set of options to compress with. + + Returns: + The difference in file size when using options_a vs options_b. A positive + value means options_a was a better compression than options_b. A negative + value means options_b had better compression than options_a. + + """ + + fn_a = self._WriteRecordsToFile(records, "tfrecord_a", options=options_a) + test_a = list(tf_record.tf_record_iterator(fn_a, options=options_a)) + self.assertEqual(records, test_a, options_a) + + fn_b = self._WriteRecordsToFile(records, "tfrecord_b", options=options_b) + test_b = list(tf_record.tf_record_iterator(fn_b, options=options_b)) + self.assertEqual(records, test_b, options_b) + + # Negative number => better compression. + return os.path.getsize(fn_a) - os.path.getsize(fn_b) + def testWriteReadZLibFiles(self): # Write uncompressed then compress manually. 
options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) @@ -188,6 +218,76 @@ class TFRecordWriterTest(TFCompressionTestCase): ] self._AssertFilesEqual(uncompressed_files, files, True) + def testNoCompressionType(self): + self.assertEqual( + "", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions())) + + self.assertEqual( + "", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions(""))) + + with self.assertRaises(ValueError): + tf_record.TFRecordOptions(5) + + with self.assertRaises(ValueError): + tf_record.TFRecordOptions("BZ2") + + def testZlibCompressionType(self): + zlib_t = tf_record.TFRecordCompressionType.ZLIB + + self.assertEqual( + "ZLIB", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions("ZLIB"))) + + self.assertEqual( + "ZLIB", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions(zlib_t))) + + self.assertEqual( + "ZLIB", + tf_record.TFRecordOptions.get_compression_type_string( + tf_record.TFRecordOptions(tf_record.TFRecordOptions(zlib_t)))) + + def testCompressionOptions(self): + # Create record with mix of random and repeated data to test compression on. + rnd = random.Random(123) + random_record = compat.as_bytes( + "".join(rnd.choice(string.digits) for _ in range(10000))) + repeated_record = compat.as_bytes(_TEXT) + for _ in range(10000): + start_i = rnd.randint(0, len(_TEXT)) + length = rnd.randint(10, 200) + repeated_record += _TEXT[start_i:start_i + length] + records = [random_record, repeated_record, random_record] + + tests = [ + ("compression_level", 2, -1), # Lower compression is worse. + ("compression_level", 6, 0), # Default compression_level is equal. + ("flush_mode", zlib.Z_FULL_FLUSH, 1), # A few less bytes. + ("flush_mode", zlib.Z_NO_FLUSH, 0), # NO_FLUSH is the default. + ("input_buffer_size", 4096, 0), # Increases time not size. + ("output_buffer_size", 4096, 0), # Increases time not size. + ("window_bits", 8, -1), # Smaller than default window increases size. + ("compression_strategy", zlib.Z_HUFFMAN_ONLY, -1), # Worse. + ("compression_strategy", zlib.Z_FILTERED, -1), # Worse. + ] + + compression_type = tf_record.TFRecordCompressionType.ZLIB + options_a = tf_record.TFRecordOptions(compression_type) + for prop, value, delta_sign in tests: + options_b = tf_record.TFRecordOptions( + compression_type=compression_type, **{prop: value}) + delta = self._CompressionSizeDelta(records, options_a, options_b) + self.assertTrue( + delta == 0 if delta_sign == 0 else delta // delta_sign > 0, + "Setting {} = {}, file was {} smaller didn't match sign of {}".format( + prop, value, delta, delta_sign)) + class TFRecordWriterZlibTest(TFCompressionTestCase): @@ -318,6 +418,7 @@ class TFRecordIteratorTest(TFCompressionTestCase): for _ in tf_record.tf_record_iterator(fn_truncated): pass + class TFRecordWriterCloseAndFlushTests(test.TestCase): def setUp(self, compression_type=TFRecordCompressionType.NONE): diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 21ccbc6c33..c8b883350d 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1275,7 +1275,7 @@ unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__ def split(value, num_or_size_splits, axis=0, num=None, name="split"): """Splits a tensor into sub tensors. 
- If `num_or_size_splits` is an integer type, `num_split`, then splits `value` + If `num_or_size_splits` is an integer type, then `value` is split along dimension `axis` into `num_split` smaller tensors. Requires that `num_split` evenly divides `value.shape[axis]`. diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index 78b395a6c1..29468431b3 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -144,7 +144,11 @@ def clip_by_norm(t, clip_norm, axes=None, name=None): t = ops.convert_to_tensor(t, name="t") # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm - l2norm = math_ops.sqrt(math_ops.reduce_sum(t * t, axes, keepdims=True)) + l2sum = math_ops.reduce_sum(t * t, axes, keepdims=True) + pred = l2sum > 0 + # Two-tap tf.where trick to bypass NaN gradients + l2sum_safe = array_ops.where(pred, l2sum, array_ops.ones_like(l2sum)) + l2norm = array_ops.where(pred, math_ops.sqrt(l2sum_safe), l2sum) intermediate = t * clip_norm # Assert that the shape is compatible with the initial shape, # to prevent unintentional broadcasting. diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 7af2ca56be..69c0fcbbee 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -1229,7 +1229,8 @@ class ConditionalAccumulator(ConditionalAccumulatorBase): dtype, shape=None, shared_name=None, - name="conditional_accumulator"): + name="conditional_accumulator", + reduction_type="MEAN"): """Creates a new ConditionalAccumulator. Args: @@ -1238,9 +1239,14 @@ class ConditionalAccumulator(ConditionalAccumulatorBase): shared_name: Optional. If non-empty, this accumulator will be shared under the given name across multiple sessions. name: Optional name for the accumulator. + reduction_type: Reduction type to use when taking the gradient. """ accumulator_ref = gen_data_flow_ops.conditional_accumulator( - dtype=dtype, shape=shape, shared_name=shared_name, name=name) + dtype=dtype, + shape=shape, + shared_name=shared_name, + name=name, + reduction_type=reduction_type) super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref) def apply_grad(self, grad, local_step=0, name=None): @@ -1312,15 +1318,21 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase): shared_name: Optional. If non-empty, this accumulator will be shared under the given name across multiple sessions. name: Optional name for the accumulator. + reduction_type: Reduction type to use when taking the gradient. """ def __init__(self, dtype, shape=None, shared_name=None, - name="sparse_conditional_accumulator"): + name="sparse_conditional_accumulator", + reduction_type="MEAN"): accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator( - dtype=dtype, shape=shape, shared_name=shared_name, name=name) + dtype=dtype, + shape=shape, + shared_name=shared_name, + name=name, + reduction_type=reduction_type) super(SparseConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 12356944f8..de260f3140 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -330,6 +330,8 @@ def _random_flip(image, flip_index, seed, scope_name): lambda: image, name=scope ) + if isinstance(result, tuple): + result = result[0] # TODO(b/111124878) remove this logic (CondV2). 
return fix_image_flip_shape(image, result) elif shape.ndims == 4: uniform_random = random_ops.random_uniform( diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index f7502c4018..795e6bbc3e 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -3657,6 +3657,47 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): scores = constant_op.constant([0.9]) image_ops.non_max_suppression(boxes, scores, 3, [[0.5]]) + def testDataTypes(self): + # Test case for GitHub issue 20199. + boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + max_output_size_np = 3 + iou_threshold_np = 0.5 + # Note: There are multiple versions of non_max_suppression v2, v3, v4. + # gen_image_ops.non_max_suppression_v2: + for dtype in [np.float16, np.float32]: + with self.test_session(): + boxes = constant_op.constant(boxes_np, dtype=dtype) + scores = constant_op.constant(scores_np, dtype=dtype) + max_output_size = constant_op.constant(max_output_size_np) + iou_threshold = constant_op.constant(iou_threshold_np) + selected_indices = gen_image_ops.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold).eval() + self.assertAllClose(selected_indices, [3, 0, 5]) + # image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3. + for dtype in [np.float16, np.float32]: + with self.test_session(): + boxes = constant_op.constant(boxes_np, dtype=dtype) + scores = constant_op.constant(scores_np, dtype=dtype) + max_output_size = constant_op.constant(max_output_size_np) + iou_threshold = constant_op.constant(iou_threshold_np) + selected_indices = image_ops.non_max_suppression( + boxes, scores, max_output_size, iou_threshold).eval() + self.assertAllClose(selected_indices, [3, 0, 5]) + # gen_image_ops.non_max_suppression_v4. + score_threshold = float('-inf') + for dtype in [np.float16, np.float32]: + with self.test_session(): + boxes = constant_op.constant(boxes_np, dtype=dtype) + scores = constant_op.constant(scores_np, dtype=dtype) + max_output_size = constant_op.constant(max_output_size_np) + iou_threshold = constant_op.constant(iou_threshold_np) + selected_indices, _ = gen_image_ops.non_max_suppression_v4( + boxes, scores, max_output_size, iou_threshold, score_threshold) + selected_indices = selected_indices.eval() + self.assertAllClose(selected_indices, [3, 0, 5]) + class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 9b0ab00c7a..33e7a5533b 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2571,7 +2571,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments): @tf_export("unsorted_segment_mean") def unsorted_segment_mean(data, segment_ids, num_segments, name=None): - r""" Computes the mean along segments of a tensor. + r"""Computes the mean along segments of a tensor. Read [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation) @@ -2582,17 +2582,26 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None): Instead of computing the sum over segments, it computes the mean of all entries belonging to a segment such that: - \\(output_i = 1/N_i \sum data_j\\) where the sum is over `j` such - that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences - of id \\i\\. 
+ \\(output_i = 1/N_i \sum_{j...} data[j...]\\) where the sum is over tuples + `j...` such that `segment_ids[j...] == i` with \\N_i\\ being the number of + occurrences of id \\i\\. If there is no entry for a given segment ID `i`, it outputs 0. - segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s - first dimension. + If the given segment ID `i` is negative, the value is dropped and will not + be added to the sum of the segment. - output: Has same shape as data, except for dimension 0 which - has size `num_segments`. + Args: + data: A `Tensor` with floating point or complex dtype. + segment_ids: An integer tensor whose shape is a prefix of `data.shape`. + num_segments: An integer scalar `Tensor`. The number of distinct + segment IDs. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has same shape as data, except for the first `segment_ids.rank` + dimensions, which are replaced with a single dimension which has size + `num_segments`. """ with ops.name_scope(name, "UnsortedSegmentMean"): data = ops.convert_to_tensor(data) @@ -2615,20 +2624,29 @@ def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None): Additionally to computing the sum over segments, it divides the results by sqrt(N). - \\(output_i = 1/sqrt(N_i) \sum data_j\\) where the sum is over `j` such - that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences - of id \\i\\. + \\(output_i = 1/sqrt(N_i) \sum_{j...} data[j...]\\) where the sum is over + tuples `j...` such that `segment_ids[j...] == i` with \\N_i\\ being the + number of occurrences of id \\i\\. If there is no entry for a given segment ID `i`, it outputs 0. Note that this op only supports floating point and complex dtypes, due to tf.sqrt only supporting these types. - segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s - first dimension. + If the given segment ID `i` is negative, the value is dropped and will not + be added to the sum of the segment. - output: Has same shape as data, except for dimension 0 which - has size `num_segments`. + Args: + data: A `Tensor` with floating point or complex dtype. + segment_ids: An integer tensor whose shape is a prefix of `data.shape`. + num_segments: An integer scalar `Tensor`. The number of distinct + segment IDs. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has same shape as data, except for the first `segment_ids.rank` + dimensions, which are replaced with a single dimension which has size + `num_segments`. 
""" with ops.name_scope(name, "UnsortedSegmentSqrtN"): data = ops.convert_to_tensor(data) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 4800352ac2..55c2eb5fa4 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -750,7 +750,7 @@ class ResourceVariable(variables.RefVariable): def _read_variable_op(self): if self.trainable: - tape.watch_variable(self) + tape.variable_accessed(self) result = gen_resource_variable_ops.read_variable_op(self._handle, self._dtype) if not context.executing_eagerly(): @@ -781,7 +781,7 @@ class ResourceVariable(variables.RefVariable): """Reads the value of this variable sparsely, using `gather`.""" with ops.name_scope("Gather" if name is None else name) as name: if self.trainable: - tape.watch_variable(self) + tape.variable_accessed(self) value = gen_resource_variable_ops.resource_gather( self._handle, indices, dtype=self._dtype, name=name) return array_ops.identity(value) @@ -949,12 +949,12 @@ class ResourceVariable(variables.RefVariable): def _lazy_read(self, op): if self.trainable: - tape.watch_variable(self) + tape.variable_accessed(self) return _UnreadVariable( handle=self._handle, dtype=self.dtype, shape=self._shape, in_graph_mode=self._in_graph_mode, deleter=self._handle_deleter if not self._in_graph_mode else None, - parent_op=op, parent_name=self._handle_name, unique_id=self._unique_id) + parent_op=op, unique_id=self._unique_id) def assign(self, value, use_locking=None, name=None, read_value=True): """Assigns a new value to this variable. @@ -1293,8 +1293,7 @@ class _UnreadVariable(ResourceVariable): """ def __init__(self, handle, dtype, # pylint: disable=super-init-not-called - shape, in_graph_mode, deleter, parent_op, parent_name, - unique_id): + shape, in_graph_mode, deleter, parent_op, unique_id): # We do not call super init on purpose. self._trainable = False self._save_slice_info = None diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index fa13568596..c11c9ccaae 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -428,7 +428,7 @@ class BasicRNNCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + % str(input_shape)) input_depth = inputs_shape[-1] self._kernel = self.add_variable( @@ -525,7 +525,7 @@ class GRUCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + % str(input_shape)) input_depth = inputs_shape[-1] self._gate_kernel = self.add_variable( @@ -705,7 +705,7 @@ class BasicLSTMCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + % str(input_shape)) input_depth = inputs_shape[-1] h_depth = self._num_units @@ -783,10 +783,10 @@ class LSTMCell(LayerRNNCell): The default non-peephole implementation is based on: - http://www.bioinf.jku.at/publications/older/2604.pdf + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf - S. Hochreiter and J. Schmidhuber. - "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + Felix Gers, Jurgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. 
The peephole implementation is based on: @@ -908,7 +908,7 @@ class LSTMCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + % str(inputs_shape)) input_depth = inputs_shape[-1] h_depth = self._num_units if self._num_proj is None else self._num_proj diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index c832ba4e2a..29fefbe3a5 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -41,12 +41,41 @@ from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import + +# pylint: disable=redefined-builtin +def regex_full_match(input, pattern, name=None): + r"""Match elements of `input` with regex `pattern`. + + Args: + input: string `Tensor`, the source strings to process. + pattern: string or scalar string `Tensor`, regular expression to use, + see more details at https://github.com/google/re2/wiki/Syntax + name: Name of the op. + + Returns: + bool `Tensor` of the same shape as `input` with match results. + """ + # TODO(b/112455102): Remove compat.forward_compatible once past the horizon. + if not compat.forward_compatible(2018, 11, 10): + return gen_string_ops.regex_full_match( + input=input, pattern=pattern, name=name) + if isinstance(pattern, util_compat.bytes_or_text_types): + # When `pattern` is static through the life of the op we can + # use a version which performs the expensive regex compilation once at + # creation time. + return gen_string_ops.static_regex_full_match( + input=input, pattern=pattern, name=name) + return gen_string_ops.regex_full_match( + input=input, pattern=pattern, name=name) + +regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__ + # Expose regex_full_match in strings namespace tf_export("strings.regex_full_match")(regex_full_match) def regex_replace(source, pattern, rewrite, replace_global=True): - r"""Replace elements of `source` matching regex `pattern with `rewrite`. + r"""Replace elements of `source` matching regex `pattern` with `rewrite`. Args: source: string `Tensor`, the source strings to process. @@ -128,6 +157,7 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv shape.set_shape([2]) return sparse_tensor.SparseTensor(indices, values, shape) + @tf_export("strings.split") def string_split_v2(source, sep=None, maxsplit=-1): """Split elements of `source` based on `sep` into a `SparseTensor`. @@ -170,7 +200,7 @@ def string_split_v2(source, sep=None, maxsplit=-1): second column corresponds to the index of the split component in this row. """ if sep is None: - sep = '' + sep = "" sep = ops.convert_to_tensor(sep, dtype=dtypes.string) source = ops.convert_to_tensor(source, dtype=dtypes.string) diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index a31861ae40..be8f425481 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -52,9 +52,10 @@ limitations under the License.
%rename("%s") TFE_Py_TapeSetShouldRecord; %rename("%s") TFE_Py_TapeSetDeleteTrace; %rename("%s") TFE_Py_TapeSetRecordOperation; -%rename("%s") TFE_Py_TapeSetWatchVariable; %rename("%s") TFE_Py_TapeGradient; +%rename("%s") TFE_Py_TapeVariableAccessed; %rename("%s") TFE_Py_TapeWatch; +%rename("%s") TFE_Py_TapeWatchVariable; %rename("%s") TFE_Py_TapeWatchedVariables; %rename("%s") TFE_NewContextOptions; %rename("%s") TFE_ContextOptionsSetConfig; @@ -65,6 +66,7 @@ limitations under the License. %rename("%s") TFE_Py_TensorShapeOnDevice; %rename("%s") TFE_ContextStartStep; %rename("%s") TFE_ContextEndStep; +%rename("%s") TFE_Py_RegisterVSpace; %{ #include "tensorflow/python/eager/pywrap_tfe.h" diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 7a37eda5ea..c9bc33e218 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -225,6 +225,7 @@ py_library( ":signature_constants", ":utils", "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_ops", "//tensorflow/python:util", ], ) diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index f8ad788f77..37f927f381 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -21,9 +21,7 @@ from __future__ import print_function from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import utils from tensorflow.python.util.tf_export import tf_export @@ -316,80 +314,3 @@ def _is_valid_classification_signature(signature_def): return True - -def _get_shapes_from_tensor_info_dict(tensor_info_dict): - """Returns a map of keys to TensorShape objects. - - Args: - tensor_info_dict: map with TensorInfo proto as values. - - Returns: - Map with corresponding TensorShape objects as values. - """ - return { - key: tensor_shape.TensorShape(tensor_info.tensor_shape) - for key, tensor_info in tensor_info_dict.items() - } - - -def _get_types_from_tensor_info_dict(tensor_info_dict): - """Returns a map of keys to DType objects. - - Args: - tensor_info_dict: map with TensorInfo proto as values. - - Returns: - Map with corresponding DType objects as values. - """ - return { - key: dtypes.DType(tensor_info.dtype) - for key, tensor_info in tensor_info_dict.items() - } - - -def get_signature_def_input_shapes(signature): - """Returns map of parameter names to their shapes. - - Args: - signature: SignatureDef proto. - - Returns: - Map from string to TensorShape objects. - """ - return _get_shapes_from_tensor_info_dict(signature.inputs) - - -def get_signature_def_input_types(signature): - """Returns map of output names to their types. - - Args: - signature: SignatureDef proto. - - Returns: - Map from string to DType objects. - """ - return _get_types_from_tensor_info_dict(signature.inputs) - - -def get_signature_def_output_shapes(signature): - """Returns map of output names to their shapes. - - Args: - signature: SignatureDef proto. - - Returns: - Map from string to TensorShape objects. - """ - return _get_shapes_from_tensor_info_dict(signature.outputs) - - -def get_signature_def_output_types(signature): - """Returns map of output names to their types. 
- - Args: - signature: SignatureDef proto. - - Returns: - Map from string to DType objects. - """ - return _get_types_from_tensor_info_dict(signature.outputs) diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py index ebc5450633..18c55d8d33 100644 --- a/tensorflow/python/saved_model/signature_def_utils_test.py +++ b/tensorflow/python/saved_model/signature_def_utils_test.py @@ -275,44 +275,6 @@ class SignatureDefUtilsTest(test.TestCase): self.assertEqual(method_name, signature_def.method_name) self.assertEqual(3, len(signature_def.outputs)) - def testGetShapeAndTypes(self): - inputs = { - "input-1": constant_op.constant(["a", "b"]), - "input-2": array_ops.placeholder(dtypes.float32, [10, 11]), - } - outputs = { - "output-1": array_ops.placeholder(dtypes.float32, [10, 32]), - "output-2": constant_op.constant([["b"]]), - } - signature_def = _make_signature(inputs, outputs) - self.assertEqual( - signature_def_utils_impl.get_signature_def_input_shapes(signature_def), - {"input-1": [2], "input-2": [10, 11]}) - self.assertEqual( - signature_def_utils_impl.get_signature_def_output_shapes(signature_def), - {"output-1": [10, 32], "output-2": [1, 1]}) - self.assertEqual( - signature_def_utils_impl.get_signature_def_input_types(signature_def), - {"input-1": dtypes.string, "input-2": dtypes.float32}) - self.assertEqual( - signature_def_utils_impl.get_signature_def_output_types(signature_def), - {"output-1": dtypes.float32, "output-2": dtypes.string}) - - def testGetNonFullySpecifiedShapes(self): - outputs = { - "output-1": array_ops.placeholder(dtypes.float32, [None, 10, None]), - "output-2": array_ops.sparse_placeholder(dtypes.float32), - } - signature_def = _make_signature({}, outputs) - shapes = signature_def_utils_impl.get_signature_def_output_shapes( - signature_def) - self.assertEqual(len(shapes), 2) - # Must compare shapes with as_list() since 2 equivalent non-fully defined - # shapes are not equal to each other. - self.assertEqual(shapes["output-1"].as_list(), [None, 10, None]) - # Must compare `dims` since its an unknown shape. 
- self.assertEqual(shapes["output-2"].dims, None) - def _assertValidSignature(self, inputs, outputs, method_name): signature_def = signature_def_utils_impl.build_signature_def( inputs, outputs, method_name) diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 01d43e09d1..1c1a1a54cd 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -137,6 +137,7 @@ py_test( size = "small", srcs = ["strip_unused_test.py"], srcs_version = "PY2AND3", + tags = ["notap"], deps = [ ":strip_unused_lib", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl index 2810d83bd2..271cf2afaf 100644 --- a/tensorflow/python/tools/api/generator/api_gen.bzl +++ b/tensorflow/python/tools/api/generator/api_gen.bzl @@ -12,10 +12,15 @@ ESTIMATOR_API_INIT_FILES = [ # END GENERATED ESTIMATOR FILES ] +def get_compat_files( + file_paths, + compat_api_version): + """Prepends compat/v<compat_api_version> to file_paths.""" + return ["compat/v%d/%s" % (compat_api_version, f) for f in file_paths] + def gen_api_init_files( name, output_files = TENSORFLOW_API_INIT_FILES, - compat_output_files = {}, root_init_template = None, srcs = [], api_name = "tensorflow", @@ -23,7 +28,8 @@ def gen_api_init_files( compat_api_versions = [], package = "tensorflow.python", package_dep = "//tensorflow/python:no_contrib", - output_package = "tensorflow"): + output_package = "tensorflow", + output_dir = ""): """Creates API directory structure and __init__.py files. Creates a genrule that generates a directory structure with __init__.py @@ -37,8 +43,6 @@ def gen_api_init_files( tf_export. For e.g. if an op is decorated with @tf_export('module1.module2', 'module3'). Then, output_files should include module1/module2/__init__.py and module3/__init__.py. - compat_output_files: Dictionary mapping each compat_api_version to the - set of __init__.py file paths that should be generated for that version. root_init_template: Python init file that should be used as template for root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this template will be replaced with root imports collected by this genrule. @@ -53,14 +57,16 @@ def gen_api_init_files( process package_dep: Python library target containing your package. output_package: Package where generated API will be added to. + output_dir: Subdirectory to output API to. + If non-empty, must end with '/'. 
""" root_init_template_flag = "" if root_init_template: root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" - api_gen_binary_target = "create_" + package + "_api" + api_gen_binary_target = ("create_" + package + "_api_%d") % api_version native.py_binary( - name = "create_" + package + "_api", + name = api_gen_binary_target, srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"], main = "//tensorflow/python/tools/api/generator:create_python_api.py", srcs_version = "PY2AND3", @@ -72,14 +78,9 @@ def gen_api_init_files( ], ) - all_output_files = list(output_files) + all_output_files = ["%s%s" % (output_dir, f) for f in output_files] compat_api_version_flags = "" for compat_api_version in compat_api_versions: - compat_files = compat_output_files.get(compat_api_version, []) - all_output_files.extend([ - "compat/v%d/%s" % (compat_api_version, f) - for f in compat_files - ]) compat_api_version_flags += " --compat_apiversion=%d" % compat_api_version native.genrule( @@ -87,12 +88,15 @@ def gen_api_init_files( outs = all_output_files, cmd = ( "$(location :" + api_gen_binary_target + ") " + - root_init_template_flag + " --apidir=$(@D) --apiname=" + - api_name + " --apiversion=" + str(api_version) + + root_init_template_flag + " --apidir=$(@D)" + output_dir + + " --apiname=" + api_name + " --apiversion=" + str(api_version) + compat_api_version_flags + " --package=" + package + " --output_package=" + output_package + " $(OUTS)" ), srcs = srcs, tools = [":" + api_gen_binary_target], - visibility = ["//tensorflow:__pkg__"], + visibility = [ + "//tensorflow:__pkg__", + "//tensorflow/tools/api/tests:__pkg__", + ], ) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 6716c79f87..c5289564fe 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -546,7 +546,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str, input_examples = preprocess_input_examples_arg_string(input_examples_str) for input_tensor_key, (filename, variable_name) in inputs.items(): - data = np.load(file_io.FileIO(filename, mode='r')) + data = np.load(file_io.FileIO(filename, mode='rb')) # When a variable_name key is specified for the input file if variable_name: diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 76625624e4..3bd4bd75bd 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -1025,7 +1025,7 @@ class ProfilerHook(session_run_hook.SessionRunHook): def before_run(self, run_context): self._request_summary = ( - self._next_step is None or + self._next_step is not None and self._timer.should_trigger_for_step(self._next_step)) requests = {"global_step": self._global_step_tensor} opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) @@ -1035,6 +1035,10 @@ class ProfilerHook(session_run_hook.SessionRunHook): def after_run(self, run_context, run_values): stale_global_step = run_values.results["global_step"] + if self._next_step is None: + # Update the timer so that it does not activate until N steps or seconds + # have passed. 
+ self._timer.update_last_triggered_step(stale_global_step) global_step = stale_global_step + 1 if self._request_summary: global_step = run_context.session.run(self._global_step_tensor) diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index b49a871a56..fe8a3e9062 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -1454,52 +1454,50 @@ class ProfilerHookTest(test.TestCase): with self.assertRaises(ValueError): basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None) - def test_save_secs_saves_in_first_step(self): + def test_save_secs_does_not_save_in_first_step(self): with self.graph.as_default(): hook = basic_session_run_hooks.ProfilerHook( save_secs=2, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: sess.run(self.train_op) - self.assertEqual(1, self._count_timeline_files()) + self.assertEqual(0, self._count_timeline_files()) @test.mock.patch.object(time, 'time') def test_save_secs_saves_periodically(self, mock_time): # Pick a fixed start time. - current_time = 1484863632.320497 + current_time = 1484863632. with self.graph.as_default(): mock_time.return_value = current_time hook = basic_session_run_hooks.ProfilerHook( save_secs=2, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. - self.assertEqual(1, self._count_timeline_files()) sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) + self.assertEqual(0, self._count_timeline_files()) # Simulate 2.5 seconds of sleep. mock_time.return_value = current_time + 2.5 sess.run(self.train_op) # Saved. + self.assertEqual(1, self._count_timeline_files()) # Pretend some small amount of time has passed. - mock_time.return_value = current_time + 0.1 + mock_time.return_value = current_time + 2.6 sess.run(self.train_op) # Not saved. # Edge test just before we should save the timeline. - mock_time.return_value = current_time + 1.9 + mock_time.return_value = current_time + 4.4 sess.run(self.train_op) # Not saved. - self.assertEqual(2, self._count_timeline_files()) + self.assertEqual(1, self._count_timeline_files()) mock_time.return_value = current_time + 4.5 sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) + self.assertEqual(2, self._count_timeline_files()) - def test_save_steps_saves_in_first_step(self): + def test_save_steps_does_not_save_in_first_step(self): with self.graph.as_default(): hook = basic_session_run_hooks.ProfilerHook( - save_secs=2, output_dir=self.output_dir) + save_steps=1, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) + self.assertEqual(0, self._count_timeline_files()) def test_save_steps_saves_periodically(self): with self.graph.as_default(): @@ -1507,6 +1505,8 @@ class ProfilerHookTest(test.TestCase): save_steps=2, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: self.assertEqual(0, self._count_timeline_files()) + sess.run(self.train_op) # Not saved. + self.assertEqual(0, self._count_timeline_files()) sess.run(self.train_op) # Saved. self.assertEqual(1, self._count_timeline_files()) sess.run(self.train_op) # Not saved. 
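The test changes above pin down the new ProfilerHook contract: the first step only arms the timer, and a timeline is first written once `save_steps` or `save_secs` has elapsed. A hedged usage sketch, in which `train_op` and the output directory are illustrative:

    import tensorflow as tf

    global_step = tf.train.get_or_create_global_step()
    train_op = tf.assign_add(global_step, 1)
    hook = tf.train.ProfilerHook(save_steps=2, output_dir="/tmp/profile")
    with tf.train.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(train_op)  # First step: arms the timer; nothing is saved.
        sess.run(train_op)  # Interval elapsed: a timeline file is written.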
@@ -1515,20 +1515,19 @@ class ProfilerHookTest(test.TestCase): self.assertEqual(2, self._count_timeline_files()) sess.run(self.train_op) # Not saved. self.assertEqual(2, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) - def test_run_metadata_saves_in_first_step(self): + def test_run_metadata_saves(self): writer_cache.FileWriterCache.clear() fake_summary_writer.FakeSummaryWriter.install() fake_writer = writer_cache.FileWriterCache.get(self.output_dir) with self.graph.as_default(): hook = basic_session_run_hooks.ProfilerHook( - save_secs=2, output_dir=self.output_dir) + save_steps=1, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: + sess.run(self.train_op) # Not saved. sess.run(self.train_op) # Saved. self.assertEqual( - list(fake_writer._added_run_metadata.keys()), ['step_1']) + list(fake_writer._added_run_metadata.keys()), ['step_2']) fake_summary_writer.FakeSummaryWriter.uninstall() diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py index 9189d8f3e8..095a90ddd4 100644 --- a/tensorflow/python/training/checkpointable/base.py +++ b/tensorflow/python/training/checkpointable/base.py @@ -17,11 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import functools import json import weakref +import six + from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -91,7 +94,45 @@ class CheckpointInitialValue(ops.Tensor): return self._checkpoint_position -class PythonStringStateSaveable(saveable_object.SaveableObject): +class NoRestoreSaveable(saveable_object.SaveableObject): + """Embeds a tensor in a checkpoint with no restore ops.""" + + def __init__(self, tensor, name, dtype=None): + spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype) + super(NoRestoreSaveable, self).__init__(tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + return control_flow_ops.no_op() + + +@six.add_metaclass(abc.ABCMeta) +class PythonStateSaveable(saveable_object.SaveableObject): + """An interface for saving/restoring volatile Python state.""" + + @abc.abstractmethod + def feed_dict_additions(self): + """When running a graph, indicates fresh state to feed. + + Returns: + A dictionary mapping `Tensor`s to current Python state. + """ + pass + + @abc.abstractmethod + def freeze(self): + """Create a new `SaveableObject` which freezes current state as a constant. + + Used when executing eagerly to embed the current state as a constant, or + when creating a static tf.train.Saver with the frozen current Python state. + + Returns: + A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has + no Python state associated with it). + """ + pass + + +class PythonStringStateSaveable(PythonStateSaveable): """Saves Python state in a checkpoint.""" def __init__(self, name, state_callback, restore_callback=None): @@ -104,19 +145,26 @@ class PythonStringStateSaveable(saveable_object.SaveableObject): restore_callback: A function taking a Python string, used to restore state. Optional; defaults to doing nothing. 
""" + self._state_callback = state_callback self._restore_callback = restore_callback - if context.executing_eagerly(): - self._save_string = ( - lambda: constant_op.constant(state_callback(), dtype=dtypes.string)) - else: + with ops.device("/cpu:0"): self._save_string = constant_op.constant("", dtype=dtypes.string) - self.feed_dict_additions = ( - lambda: {self._save_string: state_callback()}) spec = saveable_object.SaveSpec( self._save_string, "", name, dtype=dtypes.string) super(PythonStringStateSaveable, self).__init__( self._save_string, [spec], name) + def feed_dict_additions(self): + """When running a graph, indicates fresh state to feed.""" + return {self._save_string: self._state_callback()} + + def freeze(self): + """Create a frozen `SaveableObject` which saves the current state.""" + return NoRestoreSaveable( + tensor=self._state_callback, + dtype=dtypes.string, + name=self.name) + def python_restore(self, restored_strings): """Called to restore Python state.""" if self._restore_callback: @@ -309,7 +357,7 @@ class _CheckpointPosition(object): if self._checkpoint.saveable_object_cache is not None: self._checkpoint.saveable_object_cache.setdefault( self.checkpointable, {})[serialized_tensor.name] = [saveable] - if isinstance(saveable, PythonStringStateSaveable): + if isinstance(saveable, PythonStateSaveable): python_saveables.append(saveable) else: named_saveables[serialized_tensor.checkpoint_key] = saveable @@ -819,7 +867,7 @@ class CheckpointableBase(object): def _state_callback(): dereferenced_self = weak_self() if dereferenced_self: - return json.dumps(self, + return json.dumps(dereferenced_self, default=serialization.get_json_type, sort_keys=True).encode("utf8") else: diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index 13dddd37ac..56c4043d9d 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops as io_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope @@ -557,7 +556,14 @@ def _serialize_checkpointables( object_graph_proto = ( checkpointable_object_graph_pb2.CheckpointableObjectGraph()) named_saveables = [] - feed_additions = {} + if saveables_cache is None: + # No SaveableObject caching. Either we're executing eagerly, or building a + # static save which is specialized to the current Python state. + feed_additions = None + else: + # If we are caching SaveableObjects, we need to build up a feed_dict with + # functions computing volatile Python state to be saved with the checkpoint. 
+ feed_additions = {} for checkpoint_id, checkpointable in enumerate(checkpointable_objects): assert node_ids[checkpointable] == checkpoint_id object_proto = object_graph_proto.nodes.add() @@ -616,18 +622,25 @@ def _serialize_checkpointables( for saveable in saveables: if hasattr(saveable, "full_name"): attribute.full_name = saveable.full_name - saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None) - if saveable_feed_dict_fn is not None: - saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable - for new_feed_key in saveable_feed_dict.keys(): - if new_feed_key in feed_additions: - raise AssertionError( - ("The object %s tried to feed a value for the Tensor %s " - "when saving, but another object is already feeding a " - "value.") - % (checkpointable, new_feed_key)) - feed_additions.update(saveable_feed_dict) - named_saveables.extend(saveables) + if isinstance(saveable, base.PythonStateSaveable): + if feed_additions is None: + assert saveables_cache is None + # If we're not caching saveables, then we're either executing + # eagerly or building a static save/restore (e.g. for a + # SavedModel). In either case, we should embed the current Python + # state in the graph rather than relying on a feed dict. + saveable = saveable.freeze() + else: + saveable_feed_dict = saveable.feed_dict_additions() + for new_feed_key in saveable_feed_dict.keys(): + if new_feed_key in feed_additions: + raise AssertionError( + ("The object %s tried to feed a value for the Tensor %s " + "when saving, but another object is already feeding a " + "value.") + % (checkpointable, new_feed_key)) + feed_additions.update(saveable_feed_dict) + named_saveables.append(saveable) for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access child_proto = object_proto.children.add() @@ -827,16 +840,6 @@ def capture_dependencies(template): yield -class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): - - def __init__(self, tensor, name): - spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name) - super(_NoRestoreSaveable, self).__init__(tensor, [spec], name) - - def restore(self, restored_tensors, restored_shapes): - return control_flow_ops.no_op() - - class _LoadStatus(object): """Abstract base for load status callbacks.""" @@ -1241,6 +1244,78 @@ class CheckpointableSaver(object): else: return self._root_checkpointable_ref + def _gather_saveables( + self, object_graph_tensor=None, saveable_object_cache=None): + """Wraps _serialize_object_graph to include the object graph proto.""" + assert ((object_graph_tensor is None and saveable_object_cache is None) + or (object_graph_tensor is not None + and saveable_object_cache is not None)) + (named_saveable_objects, graph_proto, + feed_additions) = _serialize_object_graph( + self._root_checkpointable, + saveables_cache=saveable_object_cache) + if object_graph_tensor is None: + with ops.device("/cpu:0"): + object_graph_tensor = constant_op.constant( + graph_proto.SerializeToString(), dtype=dtypes.string) + else: + feed_additions.update( + {object_graph_tensor: graph_proto.SerializeToString()}) + assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects + named_saveable_objects.append( + base.NoRestoreSaveable( + tensor=object_graph_tensor, + name=base.OBJECT_GRAPH_PROTO_KEY)) + return named_saveable_objects, graph_proto, feed_additions + + def freeze(self): + """Creates a `tf.train.Saver` with the current object graph frozen.""" + named_saveable_objects, _, _ = self._gather_saveables( + 
object_graph_tensor=None, saveable_object_cache=None) + return saver_lib.Saver( + var_list=named_saveable_objects, max_to_keep=None) + + def _prepare_save(self, + object_graph_tensor=None, + saveable_object_cache=None): + """Create or retrieve save ops. + + When graph building, `saveable_object_cache` will typically be non-`None`, + meaning that existing `SaveableObject`s are re-used across calls to + `_prepare_save` even if the object graph has grown. This avoids + unnecessarily re-creating save ops. + + Args: + object_graph_tensor: A `Tensor` to which the current object graph will be + fed. + saveable_object_cache: A dictionary; if specified, used to cache + `SaveableObject`s. + + Returns: + A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s + to feed when running save ops. The feed dict contains the current object + graph and any Python state to be saved in the checkpoint. + """ + (named_saveable_objects, graph_proto, + feed_additions) = self._gather_saveables( + object_graph_tensor=object_graph_tensor, + saveable_object_cache=saveable_object_cache) + if (self._last_save_object_graph != graph_proto + # When executing eagerly, we need to re-create SaveableObjects each time + # save() is called so they pick up new Tensors passed to their + # constructors. That means the Saver needs to be copied with a new + # var_list. + or context.executing_eagerly()): + if self._last_save_object_graph is not None: + self._last_save_saver = _copy_saver_with_new_var_list( + old_saver=self._last_save_saver, + new_var_list=named_saveable_objects) + else: + self._last_save_saver = saver_lib.Saver( + var_list=named_saveable_objects, max_to_keep=None) + self._last_save_object_graph = graph_proto + return self._last_save_saver, feed_additions + def save(self, file_prefix, checkpoint_number=None, session=None): """Save a training checkpoint. @@ -1263,44 +1338,29 @@ class CheckpointableSaver(object): Returns: The full path to the checkpoint. """ - named_variables, graph_proto, feed_additions = _serialize_object_graph( - self._root_checkpointable, - saveables_cache=self._saveable_object_cache) - if not context.executing_eagerly(): - if session is None: - session = ops.get_default_session() + feed_additions = {} + graph_building = not context.executing_eagerly() + if graph_building: if self._object_graph_feed_tensor is None: with ops.device("/cpu:0"): self._object_graph_feed_tensor = constant_op.constant( "", dtype=dtypes.string) object_graph_tensor = self._object_graph_feed_tensor - feed_additions.update( - {object_graph_tensor: graph_proto.SerializeToString()}) else: + object_graph_tensor = None + + saver, new_feed_additions = self._prepare_save( + object_graph_tensor=object_graph_tensor, + saveable_object_cache=self._saveable_object_cache) + if new_feed_additions: + feed_additions.update(new_feed_additions) + if not graph_building: session = None - with ops.device("/cpu:0"): - object_graph_tensor = constant_op.constant( - graph_proto.SerializeToString(), dtype=dtypes.string) - assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables - named_variables.append( - _NoRestoreSaveable( - tensor=object_graph_tensor, - name=base.OBJECT_GRAPH_PROTO_KEY)) - if (self._last_save_object_graph != graph_proto - # When executing eagerly, we need to re-create SaveableObjects each time - # save() is called so they pick up new Tensors passed to their - # constructors. That means the Saver needs to be copied with a new - # var_list. 
- or context.executing_eagerly()): - if self._last_save_object_graph is not None: - self._last_save_saver = _copy_saver_with_new_var_list( - old_saver=self._last_save_saver, new_var_list=named_variables) - else: - self._last_save_saver = saver_lib.Saver( - var_list=named_variables, max_to_keep=None) - self._last_save_object_graph = graph_proto + elif session is None: + session = ops.get_default_session() + with ops.device("/cpu:0"): - save_path = self._last_save_saver.save( + save_path = saver.save( sess=_SessionWithFeedDictAdditions( session=session, feed_additions=feed_additions), save_path=file_prefix, @@ -1422,6 +1482,30 @@ class CheckpointableSaver(object): return load_status +def frozen_saver(root_checkpointable): + """Creates a static `tf.train.Saver` from a checkpointable object. + + The returned `Saver` saves object-based checkpoints, but these checkpoints + will no longer reflect structural changes to the object graph, only changes to + the values of `Variable`s added as dependencies of the root object before + `freeze` was called. + + `restore` works on the returned `Saver`, but requires that the object graph of + the checkpoint being loaded exactly matches the object graph when `freeze` was + called. This is in contrast to the object-based restore performed by + `tf.train.Checkpoint`, which attempts a fuzzy matching between a checkpoint's + object graph and the current Python object graph. + + Args: + root_checkpointable: A checkpointable object to save. + + Returns: + A `tf.train.Saver` which saves object-based checkpoints for the object graph + frozen at the time `frozen_saver` was called. + """ + return CheckpointableSaver(root_checkpointable).freeze() + + @tf_export("train.Checkpoint") class Checkpoint(tracking.Checkpointable): """Groups checkpointable objects, saving and restoring them. diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index bef4bf2a16..0d32d21426 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -560,6 +560,46 @@ class CheckpointingTests(test.TestCase): self.evaluate(root.save_counter)) @test_util.run_in_graph_and_eager_modes + def testFreezing(self): + with self.cached_session(use_gpu=True) as session: + # Save an object-based checkpoint using a frozen saver + directory = self.get_temp_dir() + prefix = os.path.join(directory, "ckpt") + v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64) + checkpoint = checkpointable_utils.Checkpoint(v=v) + self.evaluate(v.assign(3)) + # Create the save counter so assert_consumed doesn't complain about it not + # existing in the checkpoint on restore.
+ self.evaluate(checkpoint.save_counter.assign(12)) + saver = checkpointable_utils.frozen_saver(checkpoint) + save_path = saver.save(session, prefix) + self.evaluate(v.assign(10)) + # Use the frozen saver to restore the same object graph + saver.restore(session, save_path) + self.assertEqual(3, self.evaluate(v)) + + # Restore using another frozen saver on an identical object graph + del v, checkpoint, saver + v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64) + checkpoint = checkpointable_utils.Checkpoint(v=v) + saver = checkpointable_utils.frozen_saver(checkpoint) + saver.restore(session, save_path) + self.assertEqual(3, self.evaluate(v)) + + # Restore as an object-based checkpoint + del v, checkpoint, saver + checkpoint = checkpointable_utils.Checkpoint() + status = checkpoint.restore(save_path) + v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64) + if context.executing_eagerly(): + self.assertEqual(12, self.evaluate(checkpoint.save_counter)) + self.assertEqual(0, self.evaluate(v)) + checkpoint.v = v + status.assert_consumed().run_restore_ops() + self.assertEqual(3, self.evaluate(v)) + self.assertEqual(12, self.evaluate(checkpoint.save_counter)) + + @test_util.run_in_graph_and_eager_modes def testCustomNumbering(self): directory = self.get_temp_dir() prefix = os.path.join(directory, "ckpt") diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index 6d336ac39d..104a615636 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -104,9 +104,36 @@ Raises: %unignore tensorflow::swig::Flatten; %noexception tensorflow::swig::Flatten; +%feature("docstring") tensorflow::swig::IsSequenceForData +"""Returns True if `seq` is a Sequence or dict (except strings/lists). + +NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`, +which *does* treat a Python list as a sequence. For ergonomic +reasons, `tf.data` users would prefer to treat lists as +implicit `tf.Tensor` objects, and dicts as (nested) sequences. + +Args: + seq: an input sequence. + +Returns: + True if the sequence is not a string or list and is a + collections.Sequence. +""" %unignore tensorflow::swig::IsSequenceForData; %noexception tensorflow::swig::IsSequenceForData; +%feature("docstring") tensorflow::swig::FlattenForData +"""Returns a flat sequence from a given nested structure. + +If `nest` is not a sequence, this returns a single-element list: `[nest]`. + +Args: + nest: an arbitrarily nested structure or a scalar object. + Note, numpy arrays are considered scalars. + +Returns: + A Python list, the flattened version of the input. +""" %unignore tensorflow::swig::FlattenForData; %noexception tensorflow::swig::FlattenForData; diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 207f22c931..3c533c7f99 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -3275,6 +3275,26 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl( "This configuration potentially produces incorrect results."); }()); + // Zero out the result buffer for strided conv backward filter for NHWC + // layouts. cuDNN 7.1.4 and 7.2 have a non-deterministic bug if the buffer is + // not zeroed. + // + // The wrong result caused by this bug is very flaky. It can take up to + // 20 runs to produce a mismatch. + // + // TODO(timshen): add an nvbugs link.
+ if (CUDNN_VERSION >= 7100 && + algorithm_config.algorithm().algo_id() == + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 && + cudnn_type == CUDNN_DATA_HALF && + input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth && + filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput && + output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth && + (convolution_descriptor.vertical_filter_stride() > 1 || + convolution_descriptor.horizontal_filter_stride() > 1)) { + stream->ThenMemZero(backward_filter_data, backward_filter_data->size()); + } + RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter( cudnn.handle(), /*alpha=*/alpha, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt index d23b3bd0ca..15e0ab76b6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], " + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], " } member_method { name: "apply_grad" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt index cbf655498c..2f4257a66a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "<type \'object\'>" member_method { name: "__init__" - argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], " } member_method { name: "gradient" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt index 2260279ad2..39ff336c4f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], " + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], " } member_method { name: "apply_grad" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 834f0954d5..87745420ee 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -60,7 +60,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', 
\'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index 4d854a4cee..6dd46365b0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index 601f095a60..35b7105eba 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index 587829a4c0..8ae370af98 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt index 0853716023..614ba42d3e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt @@ -8,7 +8,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "get_compression_type_string" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt index 
d23b3bd0ca..15e0ab76b6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], " + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], " } member_method { name: "apply_grad" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt index cbf655498c..2f4257a66a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "<type \'object\'>" member_method { name: "__init__" - argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], " } member_method { name: "gradient" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt index 2260279ad2..39ff336c4f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], " + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], " } member_method { name: "apply_grad" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index 834f0954d5..87745420ee 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -60,7 +60,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 4d854a4cee..6dd46365b0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', 
\'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index 601f095a60..35b7105eba 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 587829a4c0..8ae370af98 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt index 0853716023..614ba42d3e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt @@ -8,7 +8,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "get_compression_type_string" diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 8764409e4d..4efa4a9651 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -15,7 +15,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") py_test( name = "api_compatibility_test", - srcs = ["api_compatibility_test.py"], + srcs = [ + "api_compatibility_test.py", + "//tensorflow:tf_python_api_gen_v2", + ], data = [ "//tensorflow/tools/api/golden:api_golden_v1", "//tensorflow/tools/api/golden:api_golden_v2", diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 43d19bc99c..99bed5714f 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -34,6 +34,7 @@ import sys import unittest import tensorflow as tf +from tensorflow._api import v2 as tf_v2 from google.protobuf import message from google.protobuf import 
text_format @@ -232,14 +233,14 @@ class ApiCompatibilityTest(test.TestCase): return visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf.compat.v1, visitor) + traverse.traverse(tf_v2.compat.v1, visitor) def testNoSubclassOfMessageV2(self): if not hasattr(tf.compat, 'v2'): return visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf.compat.v2, visitor) + traverse.traverse(tf_v2, visitor) def _checkBackwardsCompatibility( self, root, golden_file_pattern, api_version, @@ -300,27 +301,24 @@ class ApiCompatibilityTest(test.TestCase): sys.version_info.major == 2, 'API compatibility test goldens are generated using python2.') def testAPIBackwardsCompatibilityV1(self): - if not hasattr(tf.compat, 'v1'): - return api_version = 1 golden_file_pattern = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) self._checkBackwardsCompatibility( - tf.compat.v1, golden_file_pattern, api_version) + tf_v2.compat.v1, golden_file_pattern, api_version) @unittest.skipUnless( sys.version_info.major == 2, 'API compatibility test goldens are generated using python2.') def testAPIBackwardsCompatibilityV2(self): - if not hasattr(tf.compat, 'v2'): - return api_version = 2 golden_file_pattern = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) self._checkBackwardsCompatibility( - tf.compat.v2, golden_file_pattern, api_version) + tf_v2, golden_file_pattern, api_version, + additional_private_map={'tf.compat': ['v1']}) if __name__ == '__main__': diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu b/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu new file mode 100644 index 0000000000..08dc026328 --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu @@ -0,0 +1,43 @@ +# To push a new version, run: +# $ docker build -f Dockerfile.rbe.gcc.gpu \ +# --tag "gcr.io/asci-toolchain/nosla-nvidia-gcc" . +# $ docker push gcr.io/asci-toolchain/nosla-nvidia-gcc +FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 + +LABEL maintainer="Manuel Klimek <klimek@google.com>" + +# TODO(b/110903506): Fix the nvidia docker image by providing a link to the +# SONAME of libcuda.so. Alternatively, consider using gold or lld which do not +# run into the same problem - that will only work once the tensorflow build does +# not link to libcuda from generators anymore. +# https://github.com/NVIDIA/nvidia-docker/issues/775 +RUN ln -s libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 + +# TODO(klimek): Once the TODO in tensorflow's configure.py to correctly find +# libnccl is resolved, delete this block. +RUN ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so \ + && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so.2 + +# TODO(b/110903506): Fix tensorflow to not require the use of LD_LIBRARY_PATH. +# The stubs/libcuda.so is not meant to be used at runtime. The correct way to +# pass the path to bfd-ld is to pass -Wl,-rpath-link=/usr/local/cuda/lib64/stubs +# to all binaries transitively depending on libcuda. Optimally the tensorflow +# build would do that internally. +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs + +# Copy and run the install scripts.
+COPY install/*.sh /install/ +ARG DEBIAN_FRONTEND=noninteractive +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa && \ + add-apt-repository -y ppa:george-edison55/cmake-3.x +RUN /install/install_deb_packages.sh +RUN /install/install_pip_packages.sh +RUN /install/install_golang.sh + +# Install nccl2. +RUN apt-get update && apt-get install -y \ + libnccl2 \ + libnccl-dev \ + && rm -rf /var/lib/apt/lists/* + diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 1d7d9df72f..c8472102cb 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -86,7 +86,7 @@ # When set, overrides TF_BUILD_IS_OPT and TF_BUILD_MAVX # options, as this will replace the two. # TF_SKIP_CONTRIB_TESTS: -# If set to any non-empty or non-0 value, will skipp running +# If set to any non-empty or non-0 value, will skip running # contrib tests. # TF_NIGHTLY: # If this run is being used to build the tf_nightly pip @@ -131,7 +131,13 @@ BAZEL_CMD="bazel test" BAZEL_BUILD_ONLY_CMD="bazel build" BAZEL_CLEAN_CMD="bazel clean" -DEFAULT_BAZEL_CONFIGS="" +# Default flags: +# --test_summary=detailed: Tell us more about which targets are being built +# --keep_going: Don't stop at the first failure; tell us all the failures +# --build_tests_only: Don't build targets depended on by tests if the test is +# disabled. Also saves some compilation time. Otherwise, +# tries to build everything. +DEFAULT_BAZEL_CONFIGS="--test_summary=detailed --build_tests_only --keep_going" PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh" PIP_TEST_TUTORIALS_FLAG="--test_tutorials" @@ -148,9 +154,7 @@ EXTRA_PARAMS="" BAZEL_TARGET="//tensorflow/... -//tensorflow/compiler/..." if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then - BAZEL_TARGET="$BAZEL_TARGET -//tensorflow/contrib/..." -else - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/..." + BAZEL_TARGET="${BAZEL_TARGET} -//tensorflow/contrib/..." fi TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data" diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index af478eded4..a9ae715c6a 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -119,6 +119,8 @@ pip2 install keras_applications==1.0.5 --no-deps pip3 install keras_applications==1.0.5 --no-deps pip2 install keras_preprocessing==1.0.3 --no-deps pip3 install keras_preprocessing==1.0.3 --no-deps +pip2 install --upgrade h5py==2.8.0 +pip3 install --upgrade h5py==2.8.0 # Install last working version of setuptools. pip2 install --upgrade setuptools==39.1.0 diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 93ea0c3db6..37e6b51f66 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -87,6 +87,7 @@ pip3.5 install --upgrade setuptools==39.1.0 # Keras pip3.5 install keras_applications==1.0.5 pip3.5 install keras_preprocessing==1.0.3 +pip3.5 install --upgrade h5py==2.8.0 # Install last working version of setuptools.
pip3.5 install --upgrade setuptools==39.1.0 diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh index 7a9eef7c64..7520ff74cb 100755 --- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh @@ -99,6 +99,7 @@ pip3 install --upgrade termcolor # Install last working version of setuptools. pip3 install --upgrade setuptools==39.1.0 +pip3 install --upgrade h5py==2.8.0 # Keras pip3 install keras_applications==1.0.5 diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh index 333a89d3f5..c18f0d6e69 100644 --- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh +++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh @@ -53,7 +53,7 @@ export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH" # Setting default values to CUDA related environment variables export TF_CUDA_VERSION=${TF_CUDA_VERSION:-9.0} -export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7.0} +export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7} export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-3.7} export CUDA_TOOLKIT_PATH=${CUDA_TOOLKIT_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"} export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"} diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md index c484c162cb..d64db35afb 100644 --- a/tensorflow/tools/dockerfiles/README.md +++ b/tensorflow/tools/dockerfiles/README.md @@ -2,8 +2,8 @@ This directory houses TensorFlow's Dockerfiles. **DO NOT EDIT THE DOCKERFILES MANUALLY!** They are maintained by `assembler.py`, which builds Dockerfiles from -the files in `partials/` and the rules in `spec.yml`. See [the Maintaining -section](#maintaining) for more information. +the files in `partials/` and the rules in `spec.yml`. See [the Contributing +section](#contributing) for more information. ## Building diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 483921fc2f..1cd9cb7ca9 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -36,23 +36,6 @@ from tensorflow.tools.docs import pretty_docs from tensorflow.tools.docs import py_guide_parser -def _is_free_function(py_object, full_name, index): - """Check if input is a free function (and not a class- or static method).""" - if not tf_inspect.isfunction(py_object): - return False - - # Static methods are functions to tf_inspect (in 2.7), so check if the parent - # is a class. If there is no parent, it's not a function. - if '.' not in full_name: - return False - - parent_name = full_name.rsplit('.', 1)[0] - if tf_inspect.isclass(index[parent_name]): - return False - - return True - - def write_docs(output_dir, parser_config, yaml_toc, @@ -109,7 +92,7 @@ def write_docs(output_dir, # Methods and some routines are documented only as part of their class. 
if not (tf_inspect.ismodule(py_object) or tf_inspect.isclass(py_object) or - _is_free_function(py_object, full_name, parser_config.index)): + parser.is_free_function(py_object, full_name, parser_config.index)): continue sitepath = os.path.join('api_docs/python', @@ -548,6 +531,13 @@ class DocGenerator(object): help='The path from the site-root to api_docs ' 'directory for this project') + self.argument_parser.add_argument( + '--api_cache_out_path', + type=str, + default=None, + help='Path to store a json-serialized api-index, so links can be ' + 'inserted into docs without rebuilding the api_docs') + def add_output_dir_argument(self): self.argument_parser.add_argument( '--output_dir', @@ -648,6 +638,9 @@ class DocGenerator(object): visitor = self.run_extraction() reference_resolver = self.make_reference_resolver(visitor, doc_index) + if getattr(flags, 'api_cache_out_path', None): + reference_resolver.to_json_file(flags.api_cache_out_path) + # Build the guide_index for the api_docs back links. root_title = getattr(flags, 'root_title', 'TensorFlow') guide_index = _build_guide_index( diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 549056c6c4..a6159fa692 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -35,6 +35,28 @@ from tensorflow.python.util import tf_inspect from tensorflow.tools.docs import doc_controls +def is_free_function(py_object, full_name, index): + """Check if input is a free function (and not a class- or static method). + + Args: + py_object: The object in question. + full_name: The full name of the object, like `tf.module.symbol`. + index: The {full_name:py_object} dictionary for the public API. + + Returns: + True if the object is a stand-alone function, and not part of a class + definition. + """ + if not tf_inspect.isfunction(py_object): + return False + + parent_name = full_name.rsplit('.', 1)[0] + if tf_inspect.isclass(index[parent_name]): + return False + + return True + + # A regular expression capturing a python identifier. IDENTIFIER_RE = r'[a-zA-Z_]\w*' @@ -74,7 +96,7 @@ class _Errors(object): return self._errors == other._errors # pylint: disable=protected-access -def documentation_path(full_name): +def documentation_path(full_name, is_fragment=False): """Returns the file path for the documentation for the given API symbol. Given the fully qualified name of a library symbol, compute the path to which @@ -84,12 +106,22 @@ def documentation_path(full_name): Args: full_name: Fully qualified name of a library symbol. - + is_fragment: If `False` produce a direct markdown link (`tf.a.b.c` --> + `tf/a/b/c.md`). If `True` produce a fragment link, `tf.a.b.c` --> + `tf/a/b.md#c`. Returns: The file path to which to write the documentation for `full_name`. """ - dirs = full_name.split('.') - return os.path.join(*dirs) + '.md' + parts = full_name.split('.') + if is_fragment: + parts, fragment = parts[:-1], parts[-1] + + result = os.path.join(*parts) + '.md' + + if is_fragment: + result = result + '#' + fragment + + return result def _get_raw_docstring(py_object): @@ -136,8 +168,7 @@ class ReferenceResolver(object): doc. """ - def __init__(self, duplicate_of, doc_index, is_class, is_module, - py_module_names): + def __init__(self, duplicate_of, doc_index, is_fragment, py_module_names): """Initializes a Reference Resolver. Args: duplicate_of: A map from duplicate names to preferred names of API symbols. doc_index: A `dict` mapping symbol name strings to objects with `url` and `title` fields.
@@ -74,7 +96,7 @@ class _Errors(object):
     return self._errors == other._errors  # pylint: disable=protected-access
 
 
-def documentation_path(full_name):
+def documentation_path(full_name, is_fragment=False):
   """Returns the file path for the documentation for the given API symbol.
 
   Given the fully qualified name of a library symbol, compute the path to which
@@ -84,12 +106,22 @@ def documentation_path(full_name):
   Args:
     full_name: Fully qualified name of a library symbol.
-
+    is_fragment: If `False` produce a direct markdown link (`tf.a.b.c` -->
+      `tf/a/b/c.md`). If `True` produce a fragment link, `tf.a.b.c` -->
+      `tf/a/b.md#c`.
   Returns:
     The file path to which to write the documentation for `full_name`.
   """
-  dirs = full_name.split('.')
-  return os.path.join(*dirs) + '.md'
+  parts = full_name.split('.')
+  if is_fragment:
+    parts, fragment = parts[:-1], parts[-1]
+
+  result = os.path.join(*parts) + '.md'
+
+  if is_fragment:
+    result = result + '#' + fragment
+
+  return result
 
 
 def _get_raw_docstring(py_object):
@@ -136,8 +168,7 @@ class ReferenceResolver(object):
     doc.
   """
 
-  def __init__(self, duplicate_of, doc_index, is_class, is_module,
-               py_module_names):
+  def __init__(self, duplicate_of, doc_index, is_fragment, py_module_names):
     """Initializes a Reference Resolver.
 
     Args:
      duplicate_of: A map from duplicate names to preferred names of API
        symbols.
      doc_index: A `dict` mapping symbol name strings to objects with `url`
        and `title` fields. Used to resolve @{$doc} references in docstrings.
-      is_class: A map from full names to bool for each symbol.
-      is_module: A map from full names to bool for each symbol.
+      is_fragment: A map from full names to bool for each symbol. If True, the
+        object lives at a page fragment `tf.a.b.c` --> `tf/a/b#c`. If False,
+        the object has a page to itself: `tf.a.b.c` --> `tf/a/b/c`.
       py_module_names: A list of string names of Python modules.
     """
     self._duplicate_of = duplicate_of
     self._doc_index = doc_index
-    self._is_class = is_class
-    self._is_module = is_module
-    self._all_names = set(is_class.keys())
+    self._is_fragment = is_fragment
+    self._all_names = set(is_fragment.keys())
     self._py_module_names = py_module_names
 
     self.current_doc_full_name = None
@@ -180,21 +211,18 @@ class ReferenceResolver(object):
 
     Returns:
       an instance of `ReferenceResolver` ()
     """
-    is_class = {
-        name: tf_inspect.isclass(visitor.index[name])
-        for name, obj in visitor.index.items()
-    }
+    is_fragment = {}
+    for name, obj in visitor.index.items():
+      has_page = (
+          tf_inspect.isclass(obj) or tf_inspect.ismodule(obj) or
+          is_free_function(obj, name, visitor.index))
 
-    is_module = {
-        name: tf_inspect.ismodule(visitor.index[name])
-        for name, obj in visitor.index.items()
-    }
+      is_fragment[name] = not has_page
 
     return cls(
         duplicate_of=visitor.duplicate_of,
         doc_index=doc_index,
-        is_class=is_class,
-        is_module=is_module,
+        is_fragment=is_fragment,
         **kwargs)
 
   @classmethod
@@ -210,6 +238,10 @@ class ReferenceResolver(object):
     Args:
       filepath: The file path to write the json to.
     """
+    try:
+      os.makedirs(os.path.dirname(filepath))
+    except OSError:
+      pass
     json_dict = {}
     for key, value in self.__dict__.items():
       # Drop these two fields. `_doc_index` is not serializable. `_all_names` is
@@ -223,7 +255,7 @@ class ReferenceResolver(object):
       json_dict[key.lstrip('_')] = value
 
     with open(filepath, 'w') as f:
-      json.dump(json_dict, f)
+      json.dump(json_dict, f, indent=2, sort_keys=True)
 
   def replace_references(self, string, relative_path_to_root):
     """Replace "@{symbol}" references with links to symbol's documentation page.
@@ -339,19 +371,7 @@ class ReferenceResolver(object):
       raise TFDocsError(
           'Cannot make link to "%s": Not in index.' % master_name)
 
-    # If this is a member of a class, link to the class page with an anchor.
-    ref_path = None
-    if not (self._is_class[master_name] or self._is_module[master_name]):
-      idents = master_name.split('.')
-      if len(idents) > 1:
-        class_name = '.'.join(idents[:-1])
-        assert class_name in self._all_names
-        if self._is_class[class_name]:
-          ref_path = documentation_path(class_name) + '#%s' % idents[-1]
-
-    if not ref_path:
-      ref_path = documentation_path(master_name)
-
+    ref_path = documentation_path(master_name, self._is_fragment[master_name])
     return os.path.join(relative_path_to_root, ref_path)
 
   def _one_ref(self, match, relative_path_to_root):
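With `_reference_to_link` delegating to it (the hunk above), `documentation_path` alone now decides whether a symbol gets its own page or a fragment on its parent's page. The new mapping, isolated into a runnable sketch copied from the patched function:

```python
import os

def documentation_path(full_name, is_fragment=False):
    parts = full_name.split('.')
    if is_fragment:
        # The last identifier becomes an anchor on the parent's page.
        parts, fragment = parts[:-1], parts[-1]
    result = os.path.join(*parts) + '.md'
    if is_fragment:
        result = result + '#' + fragment
    return result

print(documentation_path('tf.a.b.c'))                    # tf/a/b/c.md
print(documentation_path('tf.a.b.c', is_fragment=True))  # tf/a/b.md#c
```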
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py
index 71e96afa10..8a41796fb9 100644
--- a/tensorflow/tools/docs/parser_test.py
+++ b/tensorflow/tools/docs/parser_test.py
@@ -28,6 +28,12 @@ from tensorflow.python.util import tf_inspect
 from tensorflow.tools.docs import doc_controls
 from tensorflow.tools.docs import parser
 
+# The test needs a real module. `types.ModuleType()` doesn't work, as the
+# result is a `builtin` module. Using "parser" here is arbitrary. The tests
+# don't depend on the module contents. At this point in the process the
+# public API has already been extracted.
+test_module = parser
+
 
 def test_function(unused_arg, unused_kwarg='default'):
   """Docstring for test function."""
@@ -334,15 +340,16 @@ class ParserTest(googletest.TestCase):
     self.assertEqual('my_method', page_info.methods[0].short_name)
 
   def test_docs_for_module(self):
-    # Get the current module.
-    module = sys.modules[__name__]
     index = {
-        'TestModule': module,
-        'TestModule.test_function': test_function,
+        'TestModule':
+            test_module,
+        'TestModule.test_function':
+            test_function,
         'TestModule.test_function_with_args_kwargs':
-            test_function_with_args_kwargs,
-        'TestModule.TestClass': TestClass,
+            test_function_with_args_kwargs,
+        'TestModule.TestClass':
+            TestClass,
     }
 
     visitor = DummyVisitor(index=index, duplicate_of={})
@@ -365,11 +372,13 @@ class ParserTest(googletest.TestCase):
         base_dir='/')
 
     page_info = parser.docs_for_object(
-        full_name='TestModule', py_object=module, parser_config=parser_config)
+        full_name='TestModule',
+        py_object=test_module,
+        parser_config=parser_config)
 
     # Make sure the brief docstring is present
-    self.assertEqual(tf_inspect.getdoc(module).split('\n')[0],
-                     page_info.doc.brief)
+    self.assertEqual(
+        tf_inspect.getdoc(test_module).split('\n')[0], page_info.doc.brief)
 
     # Make sure that the members are there
     funcs = {f_info.obj for f_info in page_info.functions}
@@ -378,8 +387,9 @@ class ParserTest(googletest.TestCase):
     classes = {cls_info.obj for cls_info in page_info.classes}
     self.assertEqual({TestClass}, classes)
 
-    # Make sure this file is contained as the definition location.
-    self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)
+    # Make sure the module's file is contained as the definition location.
+    self.assertEqual(
+        os.path.relpath(test_module.__file__, '/'), page_info.defined_in.path)
 
   def test_docs_for_function(self):
     index = {
@@ -495,6 +505,7 @@ class ParserTest(googletest.TestCase):
     duplicate_of = {'tf.third': 'tf.fourth'}
 
     index = {
+        'tf': test_module,
         'tf.fancy': test_function_with_fancy_docstring,
         'tf.reference': HasOneMember,
         'tf.reference.foo': HasOneMember.foo,
@@ -521,20 +532,18 @@ class ParserTest(googletest.TestCase):
         'NumPy has nothing as awesome as this function.\n')
 
   def test_generate_index(self):
-    module = sys.modules[__name__]
     index = {
-        'TestModule': module,
-        'test_function': test_function,
-        'TestModule.test_function': test_function,
-        'TestModule.TestClass': TestClass,
-        'TestModule.TestClass.a_method': TestClass.a_method,
-        'TestModule.TestClass.a_property': TestClass.a_property,
-        'TestModule.TestClass.ChildClass': TestClass.ChildClass,
-    }
-    duplicate_of = {
-        'TestModule.test_function': 'test_function'
+        'tf': test_module,
+        'tf.TestModule': test_module,
+        'tf.test_function': test_function,
+        'tf.TestModule.test_function': test_function,
+        'tf.TestModule.TestClass': TestClass,
+        'tf.TestModule.TestClass.a_method': TestClass.a_method,
+        'tf.TestModule.TestClass.a_property': TestClass.a_property,
+        'tf.TestModule.TestClass.ChildClass': TestClass.ChildClass,
     }
+    duplicate_of = {'tf.TestModule.test_function': 'tf.test_function'}
 
     visitor = DummyVisitor(index=index, duplicate_of=duplicate_of)
 
@@ -553,7 +562,7 @@ class ParserTest(googletest.TestCase):
     self.assertIn('TestModule.test_function', docs)
     # Leading backtick to make sure it's included top-level.
     # This depends on formatting, but should be stable.
-    self.assertIn('<code>test_function', docs)
+    self.assertIn('<code>tf.test_function', docs)
 
   def test_argspec_for_functools_partial(self):
     # pylint: disable=unused-argument
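The comment at the top of the test file is easy to verify: a module synthesized with `types.ModuleType` carries no `__file__`, so introspection reports it as built-in and the `defined_in` check above would have no path to compare. A quick demonstration:

```python
import inspect
import types

synthetic = types.ModuleType('synthetic')
try:
    inspect.getfile(synthetic)
except TypeError as error:
    print(error)  # e.g. "<module 'synthetic'> is a built-in module"

# A real, imported module carries a source location:
print(inspect.getfile(inspect))
```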
@@ -665,22 +674,18 @@ class ParserTest(googletest.TestCase):
     duplicate_of = {'AClass': ['AClass2']}
     doc_index = {'doc': you_cant_serialize_this}
-    is_class = {
+    is_fragment = {
         'tf': False,
-        'tf.AClass': True,
-        'tf.AClass2': True,
-        'tf.function': False
-    }
-    is_module = {
-        'tf': True,
+        'tf.VERSION': True,
         'tf.AClass': False,
+        'tf.AClass.method': True,
         'tf.AClass2': False,
         'tf.function': False
     }
     py_module_names = ['tf', 'tfdbg']
 
-    resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_class,
-                                        is_module, py_module_names)
+    resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_fragment,
+                                        py_module_names)
 
     outdir = googletest.GetTempDir()
 
@@ -692,6 +697,23 @@ class ParserTest(googletest.TestCase):
     # There are no __slots__, so all fields are visible in __dict__.
     self.assertEqual(resolver.__dict__, resolver2.__dict__)
 
+  def testIsFreeFunction(self):
+
+    result = parser.is_free_function(test_function, 'test_module.test_function',
+                                     {'test_module': test_module})
+    self.assertTrue(result)
+
+    result = parser.is_free_function(test_function, 'TestClass.test_function',
+                                     {'TestClass': TestClass})
+    self.assertFalse(result)
+
+    result = parser.is_free_function(TestClass, 'TestClass', {})
+    self.assertFalse(result)
+
+    result = parser.is_free_function(test_module, 'test_module', {})
+    self.assertFalse(result)
+
 
 RELU_DOC = """Computes rectified linear: `max(features, 0)`
 
 Args:
diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py
index 448f246e0e..1a3e79621f 100644
--- a/tensorflow/tools/docs/pretty_docs.py
+++ b/tensorflow/tools/docs/pretty_docs.py
@@ -255,8 +255,9 @@ def _build_module_page(page_info):
   #   at least for basic types.
   parts.append('## Other Members\n\n')
 
+  h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n'
   for item in page_info.other_members:
-    parts.append('`{short_name}`\n\n'.format(**item._asdict()))
+    parts.append(h3.format(**item._asdict()))
 
   return ''.join(parts)
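The pretty-docs change pairs with the new fragment links: entries under "Other Members" now render as anchored headings that a `tf/a/b.md#c` style link can target, instead of inline code. What the new template emits (the member name below is illustrative):

```python
# The patched "Other Members" template, applied to a single member.
h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n'
print(h3.format(short_name='VERSION'))
# <h3 id="VERSION"><code>VERSION</code></h3>
```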
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 61419f25ae..3102239a19 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -167,17 +167,21 @@ class InstallHeaders(Command):
     # directories for -I
     install_dir = re.sub('/google/protobuf_archive/src', '', install_dir)
 
-    # Copy eigen code into tensorflow/include.
+    # Copy external code headers into tensorflow/include.
     # A symlink would do, but the wheel file that gets created ignores
     # symlink within the directory hierarchy.
     # NOTE(keveman): Figure out how to customize bdist_wheel package so
     # we can do the symlink.
-    if 'tensorflow/include/external/eigen_archive/' in install_dir:
-      extra_dir = install_dir.replace(
-          'tensorflow/include/external/eigen_archive', '')
-      if not os.path.exists(extra_dir):
-        self.mkpath(extra_dir)
-      self.copy_file(header, extra_dir)
+    external_header_locations = [
+        'tensorflow/include/external/eigen_archive/',
+        'tensorflow/include/external/com_google_absl/',
+    ]
+    for location in external_header_locations:
+      if location in install_dir:
+        extra_dir = install_dir.replace(location, '')
+        if not os.path.exists(extra_dir):
+          self.mkpath(extra_dir)
+        self.copy_file(header, extra_dir)
 
     if not os.path.exists(install_dir):
       self.mkpath(install_dir)
@@ -227,6 +231,8 @@ headers = (list(find_files('*.h', 'tensorflow/core')) +
            list(find_files('*.h', 'tensorflow/stream_executor')) +
            list(find_files('*.h', 'google/protobuf_archive/src')) +
            list(find_files('*', 'third_party/eigen3')) +
+           list(find_files('*.h',
+                           'tensorflow/include/external/com_google_absl')) +
            list(find_files('*', 'tensorflow/include/external/eigen_archive')))
 
 setup(
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 2bf867c7e1..0ff695d9f8 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -106,11 +106,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
     tf_http_archive(
         name = "com_google_absl",
         urls = [
-            "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/c075ad321696fa5072e097f0a51e4fe76a6fe13e.tar.gz",
-            "https://github.com/abseil/abseil-cpp/archive/c075ad321696fa5072e097f0a51e4fe76a6fe13e.tar.gz",
+            "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
+            "https://github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
         ],
-        sha256 = "cb4e11259742954f88802be6f33c1007c16502d90d68e8898b5e5084264ca8a9",
-        strip_prefix = "abseil-cpp-c075ad321696fa5072e097f0a51e4fe76a6fe13e",
+        sha256 = "f4f34f90083d5259f9a1a4067749d842599748d8ca03c1d9fe723124a7045c63",
+        strip_prefix = "abseil-cpp-fb462224c058487763f263b7995d70efd0242c17",
         build_file = clean_dep("//third_party:com_google_absl.BUILD"),
     )
 
@@ -491,11 +491,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
     tf_http_archive(
         name = "llvm",
         urls = [
-            "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz",
-            "https://github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz",
+            "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
+            "https://github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
         ],
-        sha256 = "c7252290a113f694cccbb4b325c67b56f3aa6f5b3044524302c0e79db2da7e2a",
-        strip_prefix = "llvm-dc6d9ec3646865125d057b6f515b4543df79920a",
+        sha256 = "2bda8dd724ab432c162fb6eace259ccf8a97f13cb627336611bff68da2f33ec2",
+        strip_prefix = "llvm-738b5f5028ef39cbb023967f80fa2e5dd568556b",
        build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
     )
 
diff --git a/third_party/gpus/cuda/remote.BUILD.tpl b/third_party/gpus/cuda/remote.BUILD.tpl
index f774def5e6..100c7bb7c4 100644
--- a/third_party/gpus/cuda/remote.BUILD.tpl
+++ b/third_party/gpus/cuda/remote.BUILD.tpl
@@ -75,6 +75,11 @@ alias(
 )
 
 alias(
+    name = "cudnn_header",
+    actual = "%{remote_cuda_repo}/cuda:cudnn_header",
+)
+
+alias(
     name = "cufft",
     actual = "%{remote_cuda_repo}/cuda:cufft",
 )
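The eigen-only special case in `setup.py` becomes a loop, so absl headers bundled under `tensorflow/include/external/` are also mirrored to their canonical include path. The path rewrite, extracted into a runnable sketch (the sample path is illustrative):

```python
# The prefix-stripping rule from InstallHeaders, in isolation.
external_header_locations = [
    'tensorflow/include/external/eigen_archive/',
    'tensorflow/include/external/com_google_absl/',
]

def extra_install_dir(install_dir):
    for location in external_header_locations:
        if location in install_dir:
            # Headers also get copied with the archive prefix stripped,
            # so e.g. `#include "absl/..."` resolves from the wheel.
            return install_dir.replace(location, '')
    return None

print(extra_install_dir('tensorflow/include/external/com_google_absl/absl/base/'))
# absl/base/
```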
diff --git a/third_party/llvm/llvm.autogenerated.BUILD b/third_party/llvm/llvm.autogenerated.BUILD
index 0ac27e26a4..776935739a 100644
--- a/third_party/llvm/llvm.autogenerated.BUILD
+++ b/third_party/llvm/llvm.autogenerated.BUILD
@@ -109,16 +109,23 @@ template_rule(
 )
 
 # A common library that all LLVM targets depend on.
+# TODO(b/113996071): We need to glob all potentially #included files and stage
+# them here because LLVM's build files are not strict headers clean, and remote
+# build execution requires all inputs to be depended upon.
 cc_library(
     name = "config",
-    hdrs = [
+    hdrs = glob([
+        "**/*.h",
+        "**/*.def",
+        "**/*.inc.cpp",
+    ]) + [
         "include/llvm/Config/AsmParsers.def",
         "include/llvm/Config/AsmPrinters.def",
         "include/llvm/Config/Disassemblers.def",
         "include/llvm/Config/Targets.def",
-        "include/llvm/Config/abi-breaking.h",
         "include/llvm/Config/config.h",
         "include/llvm/Config/llvm-config.h",
+        "include/llvm/Config/abi-breaking.h",
     ],
     defines = llvm_defines,
     includes = ["include"],
diff --git a/third_party/nccl/BUILD b/third_party/nccl/BUILD
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/nccl/BUILD
diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl
index 5d1ebf0686..ce9447096e 100644
--- a/third_party/nccl/nccl_configure.bzl
+++ b/third_party/nccl/nccl_configure.bzl
@@ -16,6 +16,7 @@ load(
 
 _NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
 _TF_NCCL_VERSION = "TF_NCCL_VERSION"
+_TF_NCCL_CONFIG_REPO = "TF_NCCL_CONFIG_REPO"
 
 _DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR"
 _DEFINE_NCCL_MINOR = "#define NCCL_MINOR"
@@ -48,25 +49,8 @@ alias(
 """
 
 # Local build results in dynamic link and the license should not be included.
-_NCCL_LOCAL_BUILD_TEMPLATE = """
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "nccl",
-    srcs = ["nccl/lib/libnccl.so.%s"],
-    hdrs = ["nccl/include/nccl.h"],
-    include_prefix = "third_party/nccl",
-    strip_include_prefix = "nccl/include",
-    deps = [
-        "@local_config_cuda//cuda:cuda_headers",
-    ],
-    visibility = ["//visibility:public"],
-)
-"""
-
+_NCCL_REMOTE_BUILD_TEMPLATE = Label("//third_party/nccl:remote.BUILD.tpl")
+_NCCL_LOCAL_BUILD_TEMPLATE = Label("//third_party/nccl:system.BUILD.tpl")
 
 def _find_nccl_header(repository_ctx, nccl_install_path):
     """Finds the NCCL header on the system.
@@ -137,6 +121,13 @@ def _nccl_configure_impl(repository_ctx):
         repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT)
         return
 
+    if _TF_NCCL_CONFIG_REPO in repository_ctx.os.environ:
+        # Forward to the pre-configured remote repository.
+        repository_ctx.template("BUILD", _NCCL_REMOTE_BUILD_TEMPLATE, {
+            "%{target}": repository_ctx.os.environ[_TF_NCCL_CONFIG_REPO],
+        })
+        return
+
     nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
     if matches_version("1", nccl_version):
         # Alias to GitHub target from @nccl_archive.
@@ -148,8 +139,10 @@ def _nccl_configure_impl(repository_ctx):
     # Create target for locally installed NCCL.
     nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip()
     _check_nccl_version(repository_ctx, nccl_install_path, nccl_version)
-    repository_ctx.symlink(nccl_install_path, "nccl")
-    repository_ctx.file("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE % nccl_version)
+    repository_ctx.template("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE, {
+        "%{version}": nccl_version,
+        "%{install_path}": nccl_install_path,
+    })
 
 nccl_configure = repository_rule(
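With the remote branch added, the configure rule now chooses between four generated BUILD files. Restated in plain Python as a sketch: the environment variable names are the ones the rule reads, but the CPU-only early-out (modeled here as a missing `TF_NCCL_VERSION`) and the version test (`matches_version` is more lenient than equality) are simplifying assumptions:

```python
def nccl_build_mode(environ):
    if 'TF_NCCL_VERSION' not in environ:
        return 'dummy'    # no NCCL requested; emit the dummy BUILD file
    if 'TF_NCCL_CONFIG_REPO' in environ:
        return 'remote'   # forward to a pre-configured remote repository
    if environ['TF_NCCL_VERSION'].strip() == '1':
        return 'archive'  # alias to the open-source @nccl_archive target
    return 'local'        # fill system.BUILD.tpl from NCCL_INSTALL_PATH

print(nccl_build_mode({'TF_NCCL_VERSION': '2',
                       'NCCL_INSTALL_PATH': '/usr/local/nccl'}))  # local
```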
diff --git a/third_party/nccl/remote.BUILD.tpl b/third_party/nccl/remote.BUILD.tpl
new file mode 100644
index 0000000000..d66fc5563d
--- /dev/null
+++ b/third_party/nccl/remote.BUILD.tpl
@@ -0,0 +1,6 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+alias(name="LICENSE", actual = "%{target}:LICENSE")
+alias(name = "nccl", actual = "%{target}:nccl")
diff --git a/third_party/nccl/system.BUILD.tpl b/third_party/nccl/system.BUILD.tpl
new file mode 100644
index 0000000000..7ca835dedf
--- /dev/null
+++ b/third_party/nccl/system.BUILD.tpl
@@ -0,0 +1,26 @@
+filegroup(
+    name = "LICENSE",
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "nccl",
+    srcs = ["libnccl.so.%{version}"],
+    hdrs = ["nccl.h"],
+    include_prefix = "third_party/nccl",
+    deps = [
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+genrule(
+    name = "nccl-files",
+    outs = [
+        "libnccl.so.%{version}",
+        "nccl.h",
+    ],
+    cmd = """cp "%{install_path}/include/nccl.h" "$(@D)/nccl.h" &&
+cp "%{install_path}/lib/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
+)
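Both new templates rely on `repository_ctx.template`'s `%{...}` substitutions rather than the `%`-formatting the old inline string used. The effect, sketched in Python (the version and install path are hypothetical values):

```python
# String substitution equivalent to what repository_ctx.template performs
# on system.BUILD.tpl. The values are placeholders for illustration only.
substitutions = {
    '%{version}': '2.2.13',
    '%{install_path}': '/usr/local/nccl',
}

line = 'srcs = ["libnccl.so.%{version}"],'
for placeholder, value in substitutions.items():
    line = line.replace(placeholder, value)
print(line)  # srcs = ["libnccl.so.2.2.13"],
```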