aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--RELEASE.md18
-rw-r--r--WORKSPACE2
-rw-r--r--configure.py61
-rw-r--r--tensorflow/c/c_api.cc1
-rw-r--r--tensorflow/c/c_api.h7
-rw-r--r--tensorflow/c/c_api_function_test.cc3
-rw-r--r--tensorflow/cc/saved_model/BUILD13
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc158
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc15
-rw-r--r--tensorflow/compiler/jit/xla_fusion_optimizer.cc12
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.h6
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/diag_op.cc105
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc4
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc12
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.h6
-rw-r--r--tensorflow/compiler/xla/client/lib/math.cc6
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric.cc21
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric.h6
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric_test.cc26
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc24
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h24
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i24
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py18
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc203
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h11
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc40
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc72
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc30
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h11
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/sort_util.cc201
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/sort_util.h34
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking.cc6
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc45
-rw-r--r--tensorflow/compiler/xla/shape_tree.h140
-rw-r--r--tensorflow/compiler/xla/shape_tree_test.cc21
-rw-r--r--tensorflow/compiler/xla/shape_util.cc51
-rw-r--r--tensorflow/compiler/xla/shape_util.h13
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc43
-rw-r--r--tensorflow/contrib/autograph/README.md17
-rw-r--r--tensorflow/contrib/autograph/pyct/compiler.py14
-rw-r--r--tensorflow/contrib/data/__init__.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py82
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD3
-rw-r--r--tensorflow/contrib/data/python/ops/get_single_element.py30
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py5
-rw-r--r--tensorflow/contrib/eager/python/datasets.py64
-rw-r--r--tensorflow/contrib/eager/python/examples/gan/mnist.py5
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py2
-rw-r--r--tensorflow/contrib/estimator/BUILD204
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py16
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py14
-rw-r--r--tensorflow/contrib/lite/Makefile3
-rwxr-xr-xtensorflow/contrib/lite/build_ios_universal_lib.sh3
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h1
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD5
-rwxr-xr-xtensorflow/contrib/lite/download_dependencies.sh2
-rw-r--r--tensorflow/contrib/lite/java/AndroidManifest.xml12
-rw-r--r--tensorflow/contrib/lite/kernels/add.cc85
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc62
-rw-r--r--tensorflow/contrib/lite/kernels/div_test.cc61
-rw-r--r--tensorflow/contrib/lite/kernels/internal/common.h133
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h239
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h802
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h234
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h916
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h112
-rw-r--r--tensorflow/contrib/lite/kernels/sub.cc68
-rw-r--r--tensorflow/contrib/lite/model.cc1
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc1
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py2
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs7
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h156
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc2
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py15
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py18
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py5
-rw-r--r--tensorflow/contrib/model_pruning/README.md11
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py8
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py31
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py72
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py9
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py11
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.cc48
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.h10
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc1
-rw-r--r--tensorflow/core/framework/register_types.h6
-rw-r--r--tensorflow/core/kernels/BUILD2
-rw-r--r--tensorflow/core/kernels/argmax_op.cc2
-rw-r--r--tensorflow/core/kernels/decode_proto_op.cc367
-rw-r--r--tensorflow/core/kernels/encode_proto_op.cc284
-rw-r--r--tensorflow/core/kernels/identity_op.cc1
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc1
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc4
-rw-r--r--tensorflow/core/kernels/scatter_nd_op_cpu_impl.h7
-rw-r--r--tensorflow/core/lib/core/refcount.h11
-rw-r--r--tensorflow/core/lib/io/record_reader_writer_test.cc84
-rw-r--r--tensorflow/core/ops/array_ops.cc14
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt120
-rw-r--r--tensorflow/core/ops/ops.pbtxt5
-rw-r--r--tensorflow/core/protobuf/config.proto3
-rw-r--r--tensorflow/core/public/version.h2
-rw-r--r--tensorflow/core/util/proto/BUILD10
-rw-r--r--tensorflow/core/util/proto/decode.h298
-rw-r--r--tensorflow/core/util/proto/proto_utils.cc70
-rw-r--r--tensorflow/core/util/proto/proto_utils.h33
-rw-r--r--tensorflow/docs_src/guide/eager.md16
-rw-r--r--tensorflow/docs_src/install/install_c.md2
-rw-r--r--tensorflow/docs_src/install/install_go.md2
-rw-r--r--tensorflow/docs_src/install/install_java.md22
-rw-r--r--tensorflow/docs_src/install/install_linux.md38
-rw-r--r--tensorflow/docs_src/install/install_mac.md10
-rw-r--r--tensorflow/docs_src/install/install_sources.md4
-rw-r--r--tensorflow/python/client/session.py14
-rw-r--r--tensorflow/python/client/session_list_devices_test.py8
-rw-r--r--tensorflow/python/client/session_test.py21
-rw-r--r--tensorflow/python/client/tf_session.i5
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py26
-rwxr-xr-xtensorflow/python/debug/examples/examples_test.sh2
-rw-r--r--tensorflow/python/estimator/BUILD8
-rw-r--r--tensorflow/python/estimator/canned/metric_keys.py5
-rw-r--r--tensorflow/python/grappler/layout_optimizer_test.py2
-rw-r--r--tensorflow/python/kernel_tests/bitcast_op_test.py8
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py23
-rw-r--r--tensorflow/python/ops/array_ops.py8
-rw-r--r--tensorflow/python/ops/math_ops.py15
-rw-r--r--tensorflow/python/ops/parallel_for/__init__.py10
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py13
-rw-r--r--tensorflow/python/ops/variable_scope.py54
-rw-r--r--tensorflow/python/ops/variables.py57
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py25
-rw-r--r--tensorflow/python/training/checkpointable/layer_utils.py13
-rw-r--r--tensorflow/python/training/checkpointable/util.py2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc29
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh31
-rwxr-xr-xtensorflow/tools/ci_build/ci_sanity.sh2
-rwxr-xr-xtensorflow/tools/ci_build/install/install_bazel.sh2
-rwxr-xr-xtensorflow/tools/ci_build/install/install_bazel_from_source.sh2
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh33
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh8
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/common_env.sh3
-rw-r--r--tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh10
-rw-r--r--tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh10
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn72
-rw-r--r--tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb2
-rw-r--r--tensorflow/tools/docs/generate.py5
-rw-r--r--tensorflow/tools/docs/generate_lib.py30
-rw-r--r--tensorflow/tools/docs/generate_lib_test.py13
-rwxr-xr-xtensorflow/tools/pip_package/build_pip_package.sh6
-rw-r--r--tensorflow/workspace.bzl8
-rw-r--r--third_party/examples/eager/spinn/spinn.py2
-rw-r--r--third_party/toolchains/cpus/py/BUILD242
-rw-r--r--third_party/toolchains/cpus/py3/BUILD234
161 files changed, 4439 insertions, 3109 deletions
diff --git a/RELEASE.md b/RELEASE.md
index 7bb1e3e1c8..6b67072f8e 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -34,18 +34,22 @@
* Using `tf.layers` in a subclassed `tf.keras.Model` class. See
[here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details
* `tf.data`:
- * The `DatasetBase::DebugString()` method is now `const`.
- * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets.
+ * `Dataset.from_generator()` now accepts an `args` list, in order to create nested generators.
+ * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed.
+ * `tf.contrib.data.sample_from_datasets()` and `tf.contrib.data.choose_from_datasets()` make it easier to sample or deterministically choose elements from multiple datasets.
+ * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings, and two infrequently used arguments removed.
+ * (C++) `DatasetBase::DebugString()` is now `const`.
+ * (C++) `DatasetBase::MakeIterator()` has been renamed to `DatasetBase::MakeIteratorInternal()`.
+ * (C++) `IteratorBase::Initialize()` method was added to support raising errors during iterator construction.
* Eager Execution:
+ * Added the ability to pause recording operations for gradient computation via `tf.GradientTape.stop_recording`.
+ * Updated documentation, introductory notebooks.
* `tf.keras`:
* Move Keras code out of _impl folder and remove API files.
* `tf.keras.Model.save_weights` now saves in TensorFlow format by default.
* Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods.
-* Accelerated Linear Algebra (XLA):
-* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB).
+* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB).
* `tf.contrib`:
- * Add `tf.contrib.data.choose_from_datasets()`.
- * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`.
* `tf.contrib.framework.zero_initializer` supports ResourceVariable.
* Adding "constrained_optimization" to tensorflow/contrib.
* Other:
@@ -55,7 +59,6 @@
* More consistent GcsFileSystem behavior for certain reads past EOF.
* Update benchmark for tf.scan to match ranges across eager and graph modes.
* Fixed bug in `tf.reduce_prod gradient` for complex dtypes.
- * Add optional `args` argument to `Dataset.from_generator()`.
* Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)").
* Benchmark for tf.scan in graph and eager modes.
* Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D.
@@ -65,7 +68,6 @@
* LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`.
* Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary.
* Added `tf.train.Checkpoint` for reading/writing object-based checkpoints.
- * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed.
* Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product.
* Allow LinearOperator to broadcast.
* SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other.
diff --git a/WORKSPACE b/WORKSPACE
index 17961829a6..fd7570a80a 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -18,7 +18,7 @@ closure_repositories()
# files, in case the parsing of those build files depends on the bazel
# version we require here.
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
-check_bazel_version_at_least("0.15.0")
+check_bazel_version_at_least("0.10.0")
load("//tensorflow:workspace.bzl", "tf_workspace")
diff --git a/configure.py b/configure.py
index 25729adf36..60fe54b2f6 100644
--- a/configure.py
+++ b/configure.py
@@ -882,7 +882,7 @@ def set_tf_cudnn_version(environ_cp):
default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '
'installed. Refer to README.md for more details. [Default'
- ' is %s]:') % (tf_cudnn_version, default_cudnn_path)
+ ' is %s]: ') % (tf_cudnn_version, default_cudnn_path)
cudnn_install_path = get_from_env_or_user_or_default(
environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path)
@@ -1201,7 +1201,7 @@ def set_tf_cuda_compute_capabilities(environ_cp):
'https://developer.nvidia.com/cuda-gpus.\nPlease'
' note that each additional compute '
'capability significantly increases your '
- 'build time and binary size. [Default is: %s]' %
+ 'build time and binary size. [Default is: %s]: ' %
default_cuda_compute_capabilities)
tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
@@ -1402,14 +1402,36 @@ def set_build_strip_flag():
write_to_bazelrc('build --strip=always')
-def set_windows_build_flags():
- if is_windows():
- # The non-monolithic build is not supported yet
- write_to_bazelrc('build --config monolithic')
- # Suppress warning messages
- write_to_bazelrc('build --copt=-w --host_copt=-w')
- # Output more verbose information when something goes wrong
- write_to_bazelrc('build --verbose_failures')
+def set_windows_build_flags(environ_cp):
+ """Set Windows specific build options."""
+ # The non-monolithic build is not supported yet
+ write_to_bazelrc('build --config monolithic')
+ # Suppress warning messages
+ write_to_bazelrc('build --copt=-w --host_copt=-w')
+ # Output more verbose information when something goes wrong
+ write_to_bazelrc('build --verbose_failures')
+ # The host and target platforms are the same in Windows build. So we don't
+ # have to distinct them. This avoids building the same targets twice.
+ write_to_bazelrc('build --distinct_host_configuration=false')
+ # Enable short object file path to avoid long path issue on Windows.
+ # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0
+ # Short object file path will be enabled by default.
+ write_to_bazelrc('build --experimental_shortened_obj_file_path=true')
+
+ if get_var(
+ environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline',
+ True,
+ ('Would you like to override eigen strong inline for some C++ '
+ 'compilation to reduce the compiling time?'),
+ 'Eigen strong inline overridden.',
+ 'Not overriding eigen strong inline, '
+ 'some compilations could take more than 20 mins.'):
+ # Due to a known MSVC compiler issue
+ # https://github.com/tensorflow/tensorflow/issues/10521
+ # Overriding eigen strong inline speeds up the compiling of
+ # conv_grad_ops_3d.cc and conv_ops_3d.cc by 20 minutes,
+ # but this also hurts the performance. Let users decide what they want.
+ write_to_bazelrc('build --define=override_eigen_strong_inline=true')
def config_info_line(name, help_text):
@@ -1429,7 +1451,7 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
- check_bazel_version('0.15.0')
+ check_bazel_version('0.10.0')
reset_tf_configure_bazelrc(args.workspace)
cleanup_makefile()
@@ -1537,7 +1559,8 @@ def main():
set_grpc_build_flags()
set_cc_opt_flags(environ_cp)
set_build_strip_flag()
- set_windows_build_flags()
+ if is_windows():
+ set_windows_build_flags(environ_cp)
if get_var(
environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
@@ -1549,11 +1572,15 @@ def main():
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
- print('Preconfigured Bazel build configs. You can use any of the below by '
- 'adding "--config=<>" to your build command. See tools/bazel.rc for '
- 'more details.')
- config_info_line('mkl', 'Build with MKL support.')
- config_info_line('monolithic', 'Config for mostly static monolithic build.')
+ # On Windows, we don't have MKL support and the build is always monolithic.
+ # So no need to print the following message.
+ # TODO(pcloudy): remove the following if check when they make sense on Windows
+ if not is_windows():
+ print('Preconfigured Bazel build configs. You can use any of the below by '
+ 'adding "--config=<>" to your build command. See tools/bazel.rc for '
+ 'more details.')
+ config_info_line('mkl', 'Build with MKL support.')
+ config_info_line('monolithic', 'Config for mostly static monolithic build.')
if __name__ == '__main__':
main()
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 5c218d3f25..a3003953a3 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -963,6 +963,7 @@ TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
+TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
#undef TF_DEVICELIST_METHOD
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 1eb75ef11f..fddc09d45e 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1521,6 +1521,13 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
const TF_DeviceList* list, int index, TF_Status* status);
+// Retrieve the incarnation number of a given device.
+//
+// If index is out of bounds, an error code will be set in the status object,
+// and 0 will be returned.
+TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation(
+ const TF_DeviceList* list, int index, TF_Status* status);
+
// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index 610274696f..f7ca219c89 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -1516,7 +1516,8 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {
TF_Output inputs[] = {};
TF_Output outputs[] = {{random, 0}};
- *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1,
+ *func = TF_GraphToFunction(func_graph.get(), name,
+ /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 0, inputs, 1, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, "", s.get());
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index 730b1b669b..3d3895c8fa 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -39,9 +39,20 @@ cc_library(
hdrs = ["reader.h"],
deps = [
":constants",
+ ] + if_not_mobile([
+ # TODO(b/111634734): :lib and :protos_all contain dependencies that
+ # cannot be built on mobile platforms. Instead, include the appropriate
+ # tf_lib depending on the build platform.
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- ],
+ ]) + if_mobile([
+ # Mobile-friendly SavedModel proto. See go/portable-proto for more info.
+ "//tensorflow/core:saved_model_portable_proto",
+ ]) + if_android([
+ "//tensorflow/core:android_tensorflow_lib",
+ ]) + if_ios([
+ "//tensorflow/core:ios_tensorflow_lib",
+ ]),
)
tf_cc_test(
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index b2d119029a..d81e5fe900 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -44,10 +44,6 @@ class Predicate {
enum class Kind { kAnd, kOr, kNot, kSymbol };
virtual string ToString() const = 0;
- virtual bool operator==(const Predicate& other) const = 0;
- virtual bool operator!=(const Predicate& other) const {
- return !(*this == other);
- }
int64 hash() const { return hash_; }
virtual Kind kind() const = 0;
@@ -58,6 +54,8 @@ class Predicate {
private:
const int64 hash_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
};
int64 HashPredicateSequence(Predicate::Kind kind,
@@ -69,19 +67,6 @@ int64 HashPredicateSequence(Predicate::Kind kind,
return hash;
}
-bool PredicateSequenceEqual(gtl::ArraySlice<Predicate*> lhs,
- gtl::ArraySlice<Predicate*> rhs) {
- if (lhs.size() != rhs.size()) {
- return false;
- }
- for (int64 i = 0; i < lhs.size(); i++) {
- if (*lhs[i] != *rhs[i]) {
- return false;
- }
- }
- return true;
-}
-
// Represents a logical conjunction of a set of predicates.
class AndPredicate : public Predicate {
public:
@@ -102,17 +87,9 @@ class AndPredicate : public Predicate {
return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
}
- bool operator==(const Predicate& other) const override {
- return other.kind() == Kind::kAnd &&
- PredicateSequenceEqual(
- dynamic_cast<const AndPredicate&>(other).operands(), operands());
- }
-
Kind kind() const override { return Kind::kAnd; }
- const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
- return operands_;
- }
+ const gtl::ArraySlice<Predicate*> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@@ -138,16 +115,8 @@ class OrPredicate : public Predicate {
return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
}
- bool operator==(const Predicate& other) const override {
- return other.kind() == Kind::kOr &&
- PredicateSequenceEqual(
- dynamic_cast<const OrPredicate&>(other).operands(), operands());
- }
-
Kind kind() const override { return Kind::kOr; }
- const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
- return operands_;
- }
+ const gtl::ArraySlice<Predicate*> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@@ -164,11 +133,6 @@ class NotPredicate : public Predicate {
return strings::StrCat("~", operand()->ToString());
}
- bool operator==(const Predicate& other) const override {
- return other.kind() == Kind::kNot &&
- *dynamic_cast<const NotPredicate&>(other).operand() == *operand();
- }
-
Kind kind() const override { return Kind::kNot; }
Predicate* operand() const { return operand_; }
@@ -188,14 +152,6 @@ class SymbolPredicate : public Predicate {
must_be_true_(must_be_true) {}
string ToString() const override { return tensor_id_.ToString(); }
- bool operator==(const Predicate& other) const override {
- return other.kind() == Kind::kSymbol &&
- must_be_true() ==
- dynamic_cast<const SymbolPredicate&>(other).must_be_true() &&
- dynamic_cast<const SymbolPredicate&>(other).tensor_id() ==
- tensor_id();
- }
-
Kind kind() const override { return Kind::kSymbol; }
// If `must_be_true()` is true this SymbolPredicate represents the proposition
@@ -225,16 +181,37 @@ class PredicateFactory {
Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) {
return MakeAndOrImpl(operands, /*is_and=*/true);
}
+
Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) {
return MakeAndOrImpl(operands, /*is_and=*/false);
}
Predicate* MakeNotPredicate(Predicate* pred) {
- return Make<NotPredicate>(pred);
+ SignatureForNot signature = pred;
+ auto it = interned_not_instances_.find(signature);
+ if (it == interned_not_instances_.end()) {
+ std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
+ Predicate* new_pred_ptr = new_pred.get();
+ interned_not_instances_.emplace(signature, std::move(new_pred));
+ return new_pred_ptr;
+ } else {
+ return it->second.get();
+ }
}
Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) {
- return Make<SymbolPredicate>(tensor_id, must_be_true);
+ SignatureForSymbol signature = {tensor_id, must_be_true};
+ auto it = interned_symbol_instances_.find(signature);
+ if (it == interned_symbol_instances_.end()) {
+ std::unique_ptr<Predicate> new_pred =
+ Make<SymbolPredicate>(tensor_id, must_be_true);
+ Predicate* new_pred_ptr = new_pred.get();
+ interned_symbol_instances_.emplace(std::move(signature),
+ std::move(new_pred));
+ return new_pred_ptr;
+ } else {
+ return it->second.get();
+ }
}
Predicate* MakeTrue() { return MakeAndPredicate({}); }
@@ -242,29 +219,53 @@ class PredicateFactory {
private:
template <typename PredicateT, typename... Args>
- Predicate* Make(Args... args) {
- std::unique_ptr<PredicateT> pred(
+ std::unique_ptr<Predicate> Make(Args&&... args) {
+ return std::unique_ptr<PredicateT>(
new PredicateT(std::forward<Args>(args)...));
- predicate_storage_.emplace_back(std::move(pred));
- return predicate_storage_.back().get();
}
Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and);
- struct PredicatePtrHash {
- size_t operator()(const Predicate* pred) const { return pred->hash(); }
+ // Predicate instances are interned, meaning that there is only a single
+ // instance of a Predicate object with a given content. This makes checking
+ // for structural equality super-cheap -- we can just compare pointers.
+ //
+ // We intern predicates by maintaining a map from the content of a Predicate
+ // to the only instance of said predicate we allow to exist in the
+ // interned_and_or_instances_, interned_not_instances_ and
+ // interned_symbol_instances_ fields. These maps also double up as storage
+ // for the owning pointers to predicate instances.
+
+ using SignatureForAndOr =
+ std::pair<Predicate::Kind, gtl::ArraySlice<Predicate*>>;
+ using SignatureForNot = Predicate*;
+ using SignatureForSymbol = std::pair<SafeTensorId, bool>;
+
+ struct HashSignatureForAndOr {
+ size_t operator()(const SignatureForAndOr& signature) const {
+ size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first);
+ for (Predicate* p : signature.second) {
+ hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p));
+ }
+ return hash;
+ }
};
- struct PredicatePtrEq {
- size_t operator()(const Predicate* a, const Predicate* b) const {
- return *a == *b;
+ struct HashSignatureForSymbol {
+ size_t operator()(const SignatureForSymbol& signature) const {
+ return Hash64Combine(SafeTensorId::Hasher()(signature.first),
+ ::tensorflow::hash<bool>()(signature.second));
}
};
- using PredicateSet =
- gtl::FlatSet<Predicate*, PredicatePtrHash, PredicatePtrEq>;
-
- std::vector<std::unique_ptr<Predicate>> predicate_storage_;
+ gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>,
+ HashSignatureForAndOr>
+ interned_and_or_instances_;
+ gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>>
+ interned_not_instances_;
+ gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>,
+ HashSignatureForSymbol>
+ interned_symbol_instances_;
};
// Common code to create AndPredicate or OrPredicate instances.
@@ -272,7 +273,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
bool is_and) {
Predicate::Kind pred_kind =
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
- PredicateSet simplified_ops_set;
+ gtl::FlatSet<Predicate*> simplified_ops_set;
std::vector<Predicate*> simplified_ops;
for (Predicate* op : operands) {
// Simplify A&A => A and A|A => A.
@@ -300,7 +301,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
}
// Simplify "A&~A=>False" and "A|~A=>True".
- PredicateSet negated_ops;
+ gtl::FlatSet<Predicate*> negated_ops;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kNot) {
negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
@@ -317,8 +318,26 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
simplified_ops.begin(), simplified_ops.end(),
[](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
- return is_and ? Make<AndPredicate>(std::move(simplified_ops))
- : Make<OrPredicate>(std::move(simplified_ops));
+ auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
+ if (it == interned_and_or_instances_.end()) {
+ simplified_ops.shrink_to_fit();
+ // NB! Because we'll use a non-owning reference to simplified_ops in the
+ // key for interned_and_or_instances_ we need to be careful to std::move()
+ // it all the way through.
+ gtl::ArraySlice<Predicate*> operands_slice = simplified_ops;
+ std::unique_ptr<Predicate> new_pred =
+ is_and ? Make<AndPredicate>(std::move(simplified_ops))
+ : Make<OrPredicate>(std::move(simplified_ops));
+
+ Predicate* new_pred_ptr = new_pred.get();
+ CHECK(interned_and_or_instances_
+ .emplace(SignatureForAndOr(pred_kind, operands_slice),
+ std::move(new_pred))
+ .second);
+ return new_pred_ptr;
+ } else {
+ return it->second.get();
+ }
}
class DeadnessAnalysisImpl : public DeadnessAnalysis {
@@ -491,8 +510,9 @@ bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
// Today we just compare the predicates for equality (with some
// canonicalization/simplification happening before) but we could be more
- // sophisticated here if need be.
- if (pred != nullptr && *pred != *it->second) {
+ // sophisticated here if need be. Comparing pointers is sufficient because
+ // we intern Predicate instances by their content.
+ if (pred != nullptr && pred != it->second) {
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
<< ") -> true";
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 6558f14dd6..73db0d5952 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
-#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
@@ -464,12 +463,6 @@ Status MarkForCompilationPass::Run(
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
const FunctionLibraryDefinition* fld = options.flib_def;
- std::unique_ptr<DeadnessAnalysis> deadness;
- {
- XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 0);
- TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness));
- }
-
auto is_compilable = [&](const Node* node, const DeviceType& device_type) {
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
@@ -497,14 +490,6 @@ Status MarkForCompilationPass::Run(
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
if (status.ok()) return compile;
- // If inputs to `node` can have conflicting deadness (i.e. some are alive
- // and some are dead) then don't compile it. XLA cannot represent the
- // deadness semantics of these nodes correctly and auto-clustering these
- // nodes can cause deadness propagate to nodes that should be live.
- if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
- return false;
- }
-
// Check for fusable ops only if requested.
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
return false;
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
index b70e1cf52b..74257b09a8 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
-#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
@@ -147,9 +146,6 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
TF_RETURN_IF_ERROR(
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
- std::unique_ptr<DeadnessAnalysis> deadness;
- TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
-
// Collect nodes that can be fused via XLA, while ignoring those that
// explicitly ask for XLA: (*) nodes that are marked to be compiled
// explicitly. (*) nodes assigned to XLA device.
@@ -189,14 +185,6 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
continue;
}
- // If inputs to `node` can have conflicting deadness (i.e. some are alive
- // and some are dead) then don't compile it. XLA cannot represent the
- // deadness semantics of these nodes correctly and auto-clustering these
- // nodes can cause deadness propagate to nodes that should be live.
- if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
- continue;
- }
-
compilation_candidates.insert(node);
}
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 90531174ff..1ea3fa4cf2 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -122,7 +122,11 @@ class XlaTensorBuffer : public TensorBuffer {
data_ = const_cast<void*>(ptr);
}
- ~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); }
+ ~XlaTensorBuffer() override {
+ if (data_) {
+ allocator_->DeallocateRaw(data_);
+ }
+ }
void* data() const override { return data_; }
size_t size() const override { return expected_size_; }
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 9cb3d04546..0aafda7fb4 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -691,11 +691,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([[10], [7], [2]], dtype=np.float32),
np.float32(7),
expected=np.array([[False], [False], [True]], dtype=np.bool))
- self._testBinary(
- less_op,
- np.array([[10], [7], [2], [-1]], dtype=np.int64),
- np.int64(7),
- expected=np.array([[False], [False], [True], [True]], dtype=np.bool))
+ if np.int64 in self.numeric_types:
+ self._testBinary(
+ less_op,
+ np.array([[10], [7], [2], [-1]], dtype=np.int64),
+ np.int64(7),
+ expected=np.array(
+ [[False], [False], [True], [True]], dtype=np.bool))
for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]:
self._testBinary(
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 6dec414c53..22cda27567 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -123,8 +123,6 @@ class DiagPartOp : public XlaOpKernel {
explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
const TensorShape input_shape = ctx->InputShape(0);
auto dims = input_shape.dim_sizes();
@@ -150,37 +148,13 @@ class DiagPartOp : public XlaOpKernel {
new_dims.push_back(dims[i]);
}
- xla::XlaOp diag = ctx->Input(0);
-
- // TODO(b/30878775): use Slice with strides when supported, in place of
- // the Pad -> Reshape -> Slice.
-
- // Picture:
- // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0],
- // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0],
- // [0, 0, 3, 0] [3, 0, 0, 0, 0],
- // [0, 0, 0, 4]] [4, 0, 0, 0, 0]]
- // and then slice out the first column.
-
- // Flattens the input to 1D.
- int64 size = input_shape.num_elements();
- diag = xla::Reshape(diag, {size});
-
- // Adds padding after the last element of 'new_size'.
- xla::PaddingConfig config;
- auto* dim = config.add_dimensions();
- dim->set_edge_padding_high(new_size);
- auto zero = XlaHelpers::Zero(builder, input_type(0));
- diag = xla::Pad(diag, zero, config);
-
- // Reshapes so the diagonal is now in the first column.
- diag = xla::Reshape(diag, {new_size, new_size + 1});
+ xla::XlaOp input = ctx->Input(0);
- // Slices out the first column and reshapes to the final shape.
- diag = xla::Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
- diag = xla::Reshape(diag, new_dims);
+ xla::XlaOp output = xla::Reshape(
+ xla::GetMatrixDiagonal(xla::Reshape(input, {new_size, new_size})),
+ new_dims);
- ctx->SetOutput(0, diag);
+ ctx->SetOutput(0, output);
}
};
@@ -220,8 +194,6 @@ class MatrixDiagPartOp : public XlaOpKernel {
explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
const TensorShape input_shape = ctx->InputShape(0);
auto dims = input_shape.dim_sizes();
@@ -229,71 +201,8 @@ class MatrixDiagPartOp : public XlaOpKernel {
errors::InvalidArgument("Expected 2 <= dims, got shape ",
input_shape.DebugString()));
- xla::XlaOp diag = ctx->Input(0);
-
- int last_dim = dims.size() - 1;
- int64 last_dim_size = dims[last_dim];
-
- // The smaller of the last two dimension sizes.
- int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]);
-
- // TODO(b/30878775): use Slice with strides when supported, in place of
- // the Pad -> Reshape -> Slice.
-
- // Picture: for each 2D matrix in the tensor's last two dimensions:
- // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0],
- // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0],
- // [0, 0, 3, 0]] [3, 0, 0, 0, 0],
- // and then slice out the first column.
- //
- // Another example, with tall and narrow input.
- // [[1, 0] pad and reshape to [[1, 0, 0],
- // [0, 2] =================> [2, 0, 0]]
- // [0, 0]
- // [0, 0]]
-
- // Collapses the last two dimensions.
- std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1);
- flattened_dims.back() *= dims.back();
- diag = xla::Reshape(diag, flattened_dims);
-
- // Slices or pads the last dimension to 'target_size'.
- int64 actual_size = flattened_dims.back();
- int64 target_size = smaller_dim_size * (last_dim_size + 1);
- if (actual_size < target_size) {
- xla::PaddingConfig config =
- xla::MakeNoPaddingConfig(flattened_dims.size());
- auto* dim = config.mutable_dimensions(flattened_dims.size() - 1);
- dim->set_edge_padding_high(target_size - actual_size);
- auto zero = XlaHelpers::Zero(builder, input_type(0));
- diag = xla::Pad(diag, zero, config);
- } else if (actual_size > target_size) {
- std::vector<int64> start(flattened_dims.size(), 0);
- std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end());
- std::vector<int64> strides(flattened_dims.size(), 1);
- limits[flattened_dims.size() - 1] = target_size;
- diag = xla::Slice(diag, start, limits, strides);
- }
-
- // Reshape so the target values are in the first position of the last
- // dimension.
- std::vector<int64> unflattened_dims(dims.begin(), dims.end());
- dims[last_dim - 1] = smaller_dim_size;
- dims[last_dim] = last_dim_size + 1;
- diag = xla::Reshape(diag, dims);
-
- // Slices out the first column and reshapes to the final shape.
- std::vector<int64> start(dims.size(), 0);
- std::vector<int64> limits(dims.begin(), dims.end());
- std::vector<int64> strides(dims.size(), 1);
- limits[last_dim] = 1;
- diag = xla::Slice(diag, start, limits, strides);
-
- // Collapses away the last dimension.
- dims.pop_back();
- diag = xla::Reshape(diag, dims);
-
- ctx->SetOutput(0, diag);
+ xla::XlaOp input = ctx->Input(0);
+ ctx->SetOutput(0, xla::GetMatrixDiagonal(input));
}
};
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 6a5be1c2be..739032fef7 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -132,7 +132,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
// Discard updates with negative indices, since some users expect this.
auto index_in_range = xla::ReduceAll(
xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true),
- xla::CreateScalarAndComputation(body_builder));
+ xla::CreateScalarAndComputation(xla::PRED, body_builder));
// Make the index in bounds to prevent implementation defined behavior.
index = xla::Max(index, zero_index);
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index e405f8dfaa..a2dd5a0d57 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -325,7 +325,7 @@ xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b,
}
// Rescale the input to be unit triangular
- auto diag = Diagonal(a);
+ auto diag = xla::GetMatrixDiagonal(a);
xla::XlaOp scaled_a;
std::vector<int64> broadcast_dimensions(ndims - 1);
std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(), 0);
@@ -490,7 +490,7 @@ xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b,
}
// Rescale the input to be unit triangular
- auto diag = Diagonal(a);
+ auto diag = xla::GetMatrixDiagonal(a);
xla::XlaOp scaled_a;
std::vector<int64> broadcast_dimensions(ndims - 1);
std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(), 0);
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 978fc40f34..de1d785e19 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -94,16 +94,18 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type,
});
}
-XlaComputation CreateScalarAndComputation(XlaBuilder* builder) {
+XlaComputation CreateScalarAndComputation(PrimitiveType type,
+ XlaBuilder* builder) {
return CreateScalarComputation(
- "and", PRED, builder,
+ "and", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
return And(lhs, rhs);
});
}
-XlaComputation CreateScalarOrComputation(XlaBuilder* builder) {
- return CreateScalarComputation("or", PRED, builder,
+XlaComputation CreateScalarOrComputation(PrimitiveType type,
+ XlaBuilder* builder) {
+ return CreateScalarComputation("or", type, builder,
[](XlaBuilder* b, const XlaOp& lhs,
const XlaOp& rhs) { return Or(lhs, rhs); });
}
@@ -112,7 +114,7 @@ XlaOp Any(XlaOp predicates) {
XlaBuilder* builder = predicates.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
auto f = ConstantR0<bool>(builder, false);
- XlaComputation logical_or = CreateScalarOrComputation(builder);
+ XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
builder->GetShape(predicates));
std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index d0b916e8c8..8367e09450 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -45,10 +45,12 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type,
XlaBuilder* builder);
// Creates a scalar logical AND computation and returns it.
-XlaComputation CreateScalarAndComputation(XlaBuilder* builder);
+XlaComputation CreateScalarAndComputation(PrimitiveType type,
+ XlaBuilder* builder);
// Creates a scalar logical OR computation and returns it.
-XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
+XlaComputation CreateScalarOrComputation(PrimitiveType type,
+ XlaBuilder* builder);
// Returns whether any predicate in "predicates" is set.
//
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index a6d606f944..0221de7672 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -25,11 +25,9 @@ XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); }
XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); }
-XlaOp Square(XlaOp operand) { return Pow(operand, ScalarLike(operand, 2.0)); }
+XlaOp Square(XlaOp operand) { return operand * operand; }
-XlaOp Reciprocal(XlaOp operand) {
- return Pow(operand, ScalarLike(operand, -1.0));
-}
+XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; }
namespace {
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc
index cdbeb189f4..a6e460aa75 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric.cc
@@ -79,25 +79,30 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
return ConvertElementType(indicator, type);
}
-XlaOp Diagonal(XlaOp x) {
+XlaOp GetMatrixDiagonal(XlaOp x) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
TF_RET_CHECK(n_dims >= 2);
- const int64 n = shape.dimensions(n_dims - 1);
const int64 m = shape.dimensions(n_dims - 2);
+ const int64 n = shape.dimensions(n_dims - 1);
tensorflow::gtl::ArraySlice<int64> major_dims(
AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, U32, n);
auto b = Iota(builder, U32, m);
- auto indicator = Eq(a, Broadcast(b, {n}), /*broadcast_dimensions=*/{0});
+ auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
auto mask = Broadcast(indicator, major_dims);
- XlaComputation add =
- CreateScalarAddComputation(shape.element_type(), builder);
- auto diag = Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
- add, {n_dims - 1});
- return diag;
+
+ // TPUs don't support S64 add reduction at the moment. But fortunately
+ // OR-reductions work just as well for integers.
+ XlaComputation reducer =
+ primitive_util::IsIntegralType(shape.element_type())
+ ? CreateScalarOrComputation(shape.element_type(), builder)
+ : CreateScalarAddComputation(shape.element_type(), builder);
+
+ return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
+ reducer, {m >= n ? n_dims - 2 : n_dims - 1});
});
}
diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h
index 3ec084636b..e9037b722c 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.h
+++ b/tensorflow/compiler/xla/client/lib/numeric.h
@@ -29,8 +29,10 @@ XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
// else.
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
-// Get the diagonals of the last two dimensions.
-XlaOp Diagonal(XlaOp x);
+// Get the diagonals of the last two dimensions. If 'x' has shape
+// [..., M, N], then the output has shape [..., min(M, N)], containing the
+// diagonal elements (i.e., with indices [..., i, i]).
+XlaOp GetMatrixDiagonal(XlaOp x);
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc
index bc8a73e9d7..bfea3f539d 100644
--- a/tensorflow/compiler/xla/client/lib/numeric_test.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc
@@ -24,7 +24,11 @@ limitations under the License.
namespace xla {
namespace {
-using NumericTest = ClientLibraryTestBase;
+class NumericTest : public ClientLibraryTestBase {
+ protected:
+ template <typename T>
+ void TestMatrixDiagonal();
+};
XLA_TEST_F(NumericTest, Iota) {
XlaBuilder builder(TestName());
@@ -33,5 +37,25 @@ XLA_TEST_F(NumericTest, Iota) {
ComputeAndCompareR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {});
}
+template <typename T>
+void NumericTest::TestMatrixDiagonal() {
+ XlaBuilder builder("GetMatrixDiagonal");
+ Array3D<T> input(2, 3, 4);
+ input.FillIota(0);
+
+ XlaOp a;
+ auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
+ GetMatrixDiagonal(a);
+ Array2D<T> expected({{0, 5, 10}, {12, 17, 22}});
+
+ ComputeAndCompareR2<T>(&builder, expected, {a_data.get()});
+}
+
+XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal<int32>(); }
+
+XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal<int64>(); }
+
+XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal<float>(); }
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 66b1c08a39..f25348e735 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -617,6 +617,8 @@ _FORWARD_BINOP(Xor)
_FORWARD_BINOP(ShiftLeft)
_FORWARD_BINOP(ShiftRightArithmetic)
_FORWARD_BINOP(ShiftRightLogical)
+_FORWARD_BINOP(Atan2)
+_FORWARD_BINOP(Pow)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
@@ -630,13 +632,27 @@ _FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
-_FORWARD_UNOP(Sqrt)
-_FORWARD_UNOP(Square)
-_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
-_FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
+_FORWARD_UNOP(Sqrt)
+_FORWARD_UNOP(Rsqrt)
+_FORWARD_UNOP(Square)
+_FORWARD_UNOP(Reciprocal)
+_FORWARD_UNOP(Erfc)
+_FORWARD_UNOP(Erf)
+_FORWARD_UNOP(ErfInv)
+_FORWARD_UNOP(Lgamma)
+_FORWARD_UNOP(Digamma)
+_FORWARD_UNOP(Acos)
+_FORWARD_UNOP(Asin)
+_FORWARD_UNOP(Atan)
+_FORWARD_UNOP(Tan)
+_FORWARD_UNOP(Acosh)
+_FORWARD_UNOP(Asinh)
+_FORWARD_UNOP(Atanh)
+_FORWARD_UNOP(Cosh)
+_FORWARD_UNOP(Sinh)
#undef _FORWARD
#undef _FORWARD_UNOP
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 17ad044578..0e0d8ac29a 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -336,6 +336,8 @@ class LocalComputationBuilder {
_FORWARD_BINOP(ShiftLeft)
_FORWARD_BINOP(ShiftRightArithmetic)
_FORWARD_BINOP(ShiftRightLogical)
+ _FORWARD_BINOP(Atan2)
+ _FORWARD_BINOP(Pow)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
@@ -349,13 +351,27 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
- _FORWARD_UNOP(Sqrt)
- _FORWARD_UNOP(Square)
- _FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
- _FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
+ _FORWARD_UNOP(Sqrt)
+ _FORWARD_UNOP(Rsqrt)
+ _FORWARD_UNOP(Square)
+ _FORWARD_UNOP(Reciprocal)
+ _FORWARD_UNOP(Erfc)
+ _FORWARD_UNOP(Erf)
+ _FORWARD_UNOP(ErfInv)
+ _FORWARD_UNOP(Lgamma)
+ _FORWARD_UNOP(Digamma)
+ _FORWARD_UNOP(Acos)
+ _FORWARD_UNOP(Asin)
+ _FORWARD_UNOP(Atan)
+ _FORWARD_UNOP(Tan)
+ _FORWARD_UNOP(Acosh)
+ _FORWARD_UNOP(Asinh)
+ _FORWARD_UNOP(Atanh)
+ _FORWARD_UNOP(Cosh)
+ _FORWARD_UNOP(Sinh)
#undef _FORWARD
#undef _FORWARD_UNOP
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 42bf76e5d8..eeccbd7cfa 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -1005,13 +1005,29 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin;
%unignore xla::swig::LocalComputationBuilder::Tanh;
-%unignore xla::swig::LocalComputationBuilder::Sqrt;
-%unignore xla::swig::LocalComputationBuilder::Square;
-%unignore xla::swig::LocalComputationBuilder::Pow;
+%unignore xla::swig::LocalComputationBuilder::Atan2;
%unignore xla::swig::LocalComputationBuilder::IsFinite;
-%unignore xla::swig::LocalComputationBuilder::Reciprocal;
+%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
+%unignore xla::swig::LocalComputationBuilder::Sqrt;
+%unignore xla::swig::LocalComputationBuilder::Rsqrt;
+%unignore xla::swig::LocalComputationBuilder::Square;
+%unignore xla::swig::LocalComputationBuilder::Reciprocal;
+%unignore xla::swig::LocalComputationBuilder::Erfc;
+%unignore xla::swig::LocalComputationBuilder::Erf;
+%unignore xla::swig::LocalComputationBuilder::ErfInv;
+%unignore xla::swig::LocalComputationBuilder::Lgamma;
+%unignore xla::swig::LocalComputationBuilder::Digamma;
+%unignore xla::swig::LocalComputationBuilder::Acos;
+%unignore xla::swig::LocalComputationBuilder::Asin;
+%unignore xla::swig::LocalComputationBuilder::Atan;
+%unignore xla::swig::LocalComputationBuilder::Tan;
+%unignore xla::swig::LocalComputationBuilder::Acosh;
+%unignore xla::swig::LocalComputationBuilder::Asinh;
+%unignore xla::swig::LocalComputationBuilder::Atanh;
+%unignore xla::swig::LocalComputationBuilder::Cosh;
+%unignore xla::swig::LocalComputationBuilder::Sinh;
%unignore xla::swig::DestructureLocalShapedBufferTuple;
%unignore xla::swig::DeleteLocalShapedBuffer;
%unignore xla::swig::DeleteLocalComputation;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index f93d7bda2d..ef043e4ca0 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -99,12 +99,27 @@ _UNARY_OPS = [
'Cos',
'Sin',
'Tanh',
+ 'IsFinite',
'Sqrt',
+ 'Rsqrt',
'Square',
- 'IsFinite',
'Reciprocal',
'Neg',
'Sort',
+ 'Erf',
+ 'Erfc',
+ 'ErfInv',
+ 'Lgamma',
+ 'Digamma',
+ 'Acos',
+ 'Asin',
+ 'Atan',
+ 'Tan',
+ 'Acosh',
+ 'Asinh',
+ 'Atanh',
+ 'Cosh',
+ 'Sinh',
]
_BINARY_OPS = [
@@ -128,6 +143,7 @@ _BINARY_OPS = [
'ShiftLeft',
'ShiftRightArithmetic',
'ShiftRightLogical',
+ 'Atan2',
]
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a043795a21..ca39797e81 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -170,6 +170,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
+ "//tensorflow/compiler/xla/service/llvm_ir:sort_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index a08b72e3af..449a18e710 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
-#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -125,135 +125,14 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
}
Status IrEmitter::HandleSort(HloInstruction* sort) {
- auto keys = sort->operand(0);
auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
if (values != nullptr) {
// TODO(b/26783907): Also sort the values by their corresponding key.
return Unimplemented("Key/Value Sort is not implemented on GPU");
}
int dimension_to_sort = sort->dimensions(0);
- const llvm_ir::IrArray& keys_array = GetIrArray(*keys, *sort);
- const llvm_ir::IrArray& target_array = GetIrArray(*sort, *sort);
-
- const Shape& keys_shape = keys->shape();
-
- // TODO(b/26783907): This case can probably be avoided with the Algebraic
- // Simplifier.
- if (ShapeUtil::IsScalar(keys_shape)) {
- return Status::OK();
- }
-
- // Create loop nests which loop through the operand dimensions. The sort
- // dimension is handled in three separate innermost loops which perform the
- // sorting.
- llvm_ir::ForLoopNest loop_nest(IrName(sort), &ir_builder_);
- llvm_ir::IrArray::Index keys_index = EmitOperandArrayLoopNest(
- keys_array, dimension_to_sort, "keys", &loop_nest);
-
- // 'compare_keys_index' is the index of the element that 'keys_index' should
- // be compared to.
- llvm_ir::IrArray::Index compare_keys_index(keys_index.GetType());
- for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) {
- if (dimension != dimension_to_sort) {
- compare_keys_index.push_back(keys_index[dimension]);
- } else {
- compare_keys_index.push_back(nullptr);
- }
- }
-
- // Create the sorting loops which do the sorting.
- int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
- std::unique_ptr<llvm_ir::ForLoop> stages_loop = loop_nest.AddLoop(
- /*start_index=*/0,
- /*end_index=*/
- tensorflow::Log2Ceiling64(dimension_to_sort_bound),
- /*suffix=*/"sort_stages");
- std::unique_ptr<llvm_ir::ForLoop> mask_loop = loop_nest.AddLoop(
- /*suffix=*/"mask",
- /*start_index=*/keys_index.GetConstantWithIndexType(0),
- /*end_index=*/stages_loop->GetIndVarValue());
- std::unique_ptr<llvm_ir::ForLoop> compare_loop = loop_nest.AddLoop(
- /*start_index=*/0,
- /*end_index=*/dimension_to_sort_bound,
- /*suffix=*/"compare");
-
- // Naive C++ code for the inner loops (without parallelization):
- //
- // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
- // ++stage) {
- // int64 first_xor_mask = (1LL << (stage + 1)) - 1;
- // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
- // int64 j = i ^ first_xor_mask;
- // if (i < j && j < dimension_to_sort_bound) {
- // int64 min_key = std::min(keys[i], keys[j]);
- // keys[j] = std::max(keys[i], keys[j]);
- // keys[i] = min_key;
- // }
- // }
- // for (int64 mask = 0; mask < stage; ++mask) {
- // int64 later_xor_mask = (1LL << (stage - (mask + 1));
- // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
- // int64 j = i ^ later_xor_mask;
- // if (i < j && j < dimension_to_sort_bound) {
- // int64 min_key = std::min(keys[i], keys[j]);
- // keys[j] = std::max(keys[i], keys[j]);
- // keys[i] = min_key;
- // }
- // }
- // }
- // }
- //
- // This follows the algorithm described on Wikipedia:
- // https://en.wikipedia.org/wiki/Bitonic_sorter
-
- SetToFirstInsertPoint(stages_loop->GetBodyBasicBlock(), &ir_builder_);
- // The first xor mask of a stage is 2^(stage + 1) - 1.
- auto first_xor_mask = ir_builder_.CreateSub(
- ir_builder_.CreateShl(
- keys_index.GetConstantWithIndexType(1),
- ir_builder_.CreateAdd(stages_loop->GetIndVarValue(),
- keys_index.GetConstantWithIndexType(1))),
- keys_index.GetConstantWithIndexType(1));
- std::unique_ptr<llvm_ir::ForLoop> first_compare_loop =
- llvm_ir::ForLoop::EmitForLoop(
- /*prefix=*/"first_compare",
- /*start_index=*/keys_index.GetConstantWithIndexType(0),
- /*end_index=*/
- keys_index.GetConstantWithIndexType(
- keys_shape.dimensions(dimension_to_sort)),
- /*step=*/keys_index.GetConstantWithIndexType(1),
- /*ir_builder=*/&ir_builder_);
-
- SetToFirstInsertPoint(first_compare_loop->GetBodyBasicBlock(), &ir_builder_);
- // 'first_compare_loop' iterates through the 'dimension_to_sort'.
- keys_index[dimension_to_sort] = first_compare_loop->GetIndVarValue();
- compare_keys_index[dimension_to_sort] = ir_builder_.CreateXor(
- first_compare_loop->GetIndVarValue(), first_xor_mask);
- EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
- target_array);
-
- SetToFirstInsertPoint(compare_loop->GetPreheaderBasicBlock(), &ir_builder_);
- // The later masks of a stage are 2^(stage - (mask_loop_ind_var + 1)).
- auto later_xor_mask = ir_builder_.CreateShl(
- keys_index.GetConstantWithIndexType(1),
- ir_builder_.CreateSub(
- stages_loop->GetIndVarValue(),
- ir_builder_.CreateAdd(mask_loop->GetIndVarValue(),
- keys_index.GetConstantWithIndexType(1))));
-
- SetToFirstInsertPoint(compare_loop->GetBodyBasicBlock(), &ir_builder_);
- // 'compare_loop' iterates through the 'dimension_to_sort'.
- keys_index[dimension_to_sort] = compare_loop->GetIndVarValue();
- compare_keys_index[dimension_to_sort] =
- ir_builder_.CreateXor(compare_loop->GetIndVarValue(), later_xor_mask);
- EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
- target_array);
-
- // Set the IR builder insert point to the exit basic block of the outer most
- // loop. This ensures later instructions are inserted after this loop nest.
- ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
-
- return Status::OK();
+ return llvm_ir::EmitSortInPlace(dimension_to_sort, GetIrArray(*sort, *sort),
+ IrName(sort), &ir_builder_);
}
Status IrEmitter::HandleSend(HloInstruction*) {
@@ -527,44 +406,6 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
return Status::OK();
}
-void IrEmitter::EmitCompareLoop(
- int64 dimension_to_sort, const llvm_ir::IrArray::Index& keys_index,
- const llvm_ir::IrArray::Index& compare_keys_index,
- const llvm_ir::IrArray& keys_array) {
- // TODO(b/26783907): parallelize this loop.
-
- // if (is_smaller_index &&
- // compare_keys[dimension_to_sort] < dimension_to_sort_bound)
- llvm::Value* is_smaller_index = ir_builder_.CreateICmpSLT(
- keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]);
- int64 dimension_to_sort_bound =
- keys_array.GetShape().dimensions(dimension_to_sort);
- auto if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateAnd(
- is_smaller_index,
- ir_builder_.CreateICmpSLT(
- compare_keys_index[dimension_to_sort],
- keys_index.GetConstantWithIndexType(dimension_to_sort_bound))),
- "smaller_comparison_index", &ir_builder_, /*emit_else=*/false);
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
- auto key1 = keys_array.EmitReadArrayElement(keys_index, &ir_builder_);
- auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, &ir_builder_);
- auto key_type = keys_array.GetShape().element_type();
- auto comparison =
- primitive_util::IsFloatingPointType(key_type)
- // TODO(b/26783907): Figure out how to handle NaNs.
- ? ir_builder_.CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2)
- : ir_builder_.CreateICmp(
- primitive_util::IsSignedIntegralType(key_type)
- ? llvm::ICmpInst::ICMP_SLT
- : llvm::ICmpInst::ICMP_ULT,
- key1, key2);
- auto min_key = ir_builder_.CreateSelect(comparison, key1, key2);
- auto max_key = ir_builder_.CreateSelect(comparison, key2, key1);
- keys_array.EmitWriteArrayElement(keys_index, min_key, &ir_builder_);
- keys_array.EmitWriteArrayElement(compare_keys_index, max_key, &ir_builder_);
-}
-
Status IrEmitter::EmitAtomicOperationForNestedComputation(
const HloComputation& computation, llvm::Value* output_address,
llvm::Value* source_address) {
@@ -691,10 +532,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
// operand dimensions. The reduction dimension of the LHS and RHS are handled
// in a separate innermost loop which performs the sum of products.
llvm_ir::ForLoopNest loop_nest(IrName(dot), &ir_builder_);
- llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest(
- lhs_array, lhs_reduction_dimension, "lhs", &loop_nest);
- llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest(
- rhs_array, rhs_reduction_dimension, "rhs", &loop_nest);
+ llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest(
+ lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
+ llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest(
+ rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
// Create the reduction loop which does the sum of products reduction.
std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
@@ -943,36 +784,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
}
-llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest(
- const llvm_ir::IrArray& operand_array, int64 reduction_dimension,
- tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) {
- // Prepares the dimension list we will use to emit the loop nest. Outermost
- // loops are added first. Add loops in major-to-minor order, and skip the
- // reduction dimension.
- std::vector<int64> dimensions;
- const Shape& shape = operand_array.GetShape();
- for (int i = 0; i < LayoutUtil::MinorToMajor(shape).size(); ++i) {
- int64 dimension = LayoutUtil::Major(shape.layout(), i);
- if (dimension != reduction_dimension) {
- dimensions.push_back(dimension);
- }
- }
-
- // Create loop nest with one for-loop for each dimension of the
- // output.
- llvm_ir::IrArray::Index index =
- loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
- // Verify every dimension except the reduction dimension was set in the index.
- for (size_t dimension = 0; dimension < index.size(); ++dimension) {
- if (dimension == reduction_dimension) {
- DCHECK_EQ(nullptr, index[dimension]);
- } else {
- DCHECK_NE(nullptr, index[dimension]);
- }
- }
- return index;
-}
-
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
const HloComputation& computation,
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index e9ad4a752b..77e48d729c 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -171,17 +171,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
const HloModuleConfig& hlo_module_config_;
private:
- // Emits a series of nested loops for iterating over an operand array in the
- // dot operation. Loops are constructed in major to minor dimension layout
- // order. No loop is emitted for the given reduction_dimension. The function
- // returns an IrArray index for the given operand_array containing the indvars
- // of the loops. All dimensions of the index are filled except for the
- // reduction dimension. name_suffix is the string to append to the names of
- // LLVM constructs (eg, basic blocks) constructed by this method.
- llvm_ir::IrArray::Index EmitOperandArrayLoopNest(
- const llvm_ir::IrArray& operand_array, int64 reduction_dimension,
- tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest);
-
// A helper method for EmitAtomicOperationForNestedComputation. Certain
// computations, such as floating-point addition and integer maximization, can
// be simply implemented using an LLVM atomic instruction. If "computation" is
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index ea661b3c2c..f95fbb01f9 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -71,7 +72,6 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
// In that case, the operand of the reduce needs to have the same shape
// as the other tuple operands, but also we need to compare the output
// shapes of the reduces.
- // TODO(tjoerg): Allow differences in fp precision.
auto* element_instr_1 = get_element_instr(instr1);
auto* element_instr_2 = get_element_instr(instr2);
if (element_instr_1->opcode() == HloOpcode::kReduce &&
@@ -80,8 +80,8 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
return false;
}
// The elementwise output shapes must be the same (including layout).
- return ShapeUtil::Equal(get_element_shape(element_instr_1),
- get_element_shape(element_instr_2));
+ return ShapeUtil::EqualIgnoringFpPrecision(
+ get_element_shape(element_instr_1), get_element_shape(element_instr_2));
}
namespace {
@@ -107,6 +107,27 @@ bool IsInputFusibleReduction(HloInstruction* instr) {
return IsReductionToVector(*instr);
}
}
+
+// The code emitted for reduction suffers from poor data locality if the layouts
+// of input parameters differ. In such situtations it is beneficial not to fuse.
+// We consider input params with maximum rank only. Params with smaller ranks
+// will be broadcasted and have not been observed to cause data locality issues.
+// TODO(b/110927656): Improve reduce emitters to remove this limitation.
+bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
+ int64 max_rank = 0;
+ const Layout* max_rank_layout;
+ for (HloInstruction* param : instr->fused_parameters()) {
+ if (ShapeUtil::Rank(param->shape()) > max_rank) {
+ max_rank = ShapeUtil::Rank(param->shape());
+ max_rank_layout = &param->shape().layout();
+ }
+ }
+ return c_all_of(instr->fused_parameters(), [&](HloInstruction* param) {
+ return (ShapeUtil::Rank(param->shape()) < max_rank) ||
+ (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
+ });
+}
+
} // namespace
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
@@ -173,29 +194,41 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
// fusions operands.
for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) {
if (consumer->user_count() == 0) {
+ VLOG(3) << consumer->name() << " has no users.";
continue;
}
if (!IsInputFusibleReduction(consumer)) {
+ VLOG(3) << consumer->name() << " is not an input-fusable reduction.";
continue;
}
+ VLOG(3) << consumer->name()
+ << " is a fusion candidate. Looking for fuseable operands.";
auto consumer_operands = consumer->operands();
for (size_t i = 0; i < consumer_operands.size(); ++i) {
HloInstruction* producer = consumer_operands[i];
if (!producer->IsFusable()) {
+ VLOG(3) << producer->name() << " is not fusable.";
continue;
}
const bool is_loop_fusion =
producer->opcode() == HloOpcode::kFusion &&
producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
if (!is_loop_fusion) {
+ VLOG(3) << producer->name() << " is not a loop fusion.";
continue;
}
if (!ShapesCompatibleForFusion(producer, consumer)) {
+ VLOG(3) << producer->name() << " has an incompatible shape.";
+ continue;
+ }
+ if (!ReduceFriendlyInputLayouts(producer)) {
+ VLOG(3) << producer->name() << " has inputs with mixed layouts.";
continue;
}
// If we have already decided to fuse this producer, skip it.
if (ContainsKey(to_fuse, producer)) {
+ VLOG(3) << producer->name() << " will be fused with another consumer.";
continue;
}
// Do not fuse a producer if the other operands of the fusion are
@@ -204,6 +237,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
return producer != operand &&
reachability()->IsReachable(producer, operand);
})) {
+ VLOG(3) << producer->name() << " would introduce a cycle when fused.";
break;
}
to_fuse.insert(producer);
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index a6dc635b52..451e49f23a 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -40,7 +40,7 @@ const char kModulePrefix[] = R"(
scalar_mul_computation {
scalar_lhs.1 = f32[] parameter(0)
scalar_rhs.1 = f32[] parameter(1)
- ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1)
+ ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
})";
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
@@ -349,5 +349,75 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
+TEST_F(MultiOutputFusionTest,
+ ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_select {
+ p1.1 = f16[2,2,2]{2,1,0} parameter(1)
+ c0 = f16[] constant(0)
+ broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={}
+ greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
+ p0.1 = f16[2,2,2]{2,1,0} parameter(0)
+ ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[2,2,2]{2,1,0} parameter(0)
+ convert = f32[2,2,2]{2,1,0} convert(p0.2)
+ c1 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(convert, c1), dimensions={2}, to_apply=scalar_add_computation
+ mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+ ENTRY reduce {
+ p0 = f16[2,2,2]{2,1,0} parameter(0)
+ p1 = f16[2,2,2]{2,1,0} parameter(1)
+ select = f16[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+ fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
+ gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
+ gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
+ ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) tuple(gte1, gte1, select)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
+ op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
+}
+
+TEST_F(MultiOutputFusionTest,
+ ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ mixed_input_layouts_computation {
+ p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
+ c0 = f16[] constant(0)
+ broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
+ greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
+ ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+ ENTRY reduce {
+ p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1)
+ loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
+ reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index c14a5bfb53..462be543bc 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -181,6 +181,20 @@ cc_library(
)
cc_library(
+ name = "sort_util",
+ srcs = ["sort_util.cc"],
+ hdrs = ["sort_util.h"],
+ deps = [
+ ":ir_array",
+ ":llvm_loop",
+ ":llvm_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
name = "tuple_ops",
srcs = ["tuple_ops.cc"],
hdrs = ["tuple_ops.h"],
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index c9ae7d3afd..1227534779 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -262,5 +262,35 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
return index;
}
+IrArray::Index ForLoopNest::EmitOperandArrayLoopNest(
+ const llvm_ir::IrArray& operand_array, int64 dimension_to_skip,
+ tensorflow::StringPiece name_suffix) {
+ // Prepares the dimension list we will use to emit the loop nest. Outermost
+ // loops are added first. Add loops in major-to-minor order, and skip the
+ // 'dimension_to_skip' dimension.
+ std::vector<int64> dimensions;
+ const Shape& shape = operand_array.GetShape();
+ for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
+ if (dimension != dimension_to_skip) {
+ dimensions.push_back(dimension);
+ }
+ }
+
+ // Create loop nest with one for-loop for each dimension of the
+ // output.
+ llvm_ir::IrArray::Index index =
+ AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
+ // Verify every dimension except the 'dimension_to_skip' dimension was set in
+ // the index.
+ for (size_t dimension = 0; dimension < index.size(); ++dimension) {
+ if (dimension == dimension_to_skip) {
+ DCHECK_EQ(nullptr, index[dimension]);
+ } else {
+ DCHECK_NE(nullptr, index[dimension]);
+ }
+ }
+ return index;
+}
+
} // namespace llvm_ir
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index 0dd5b9d3b2..b3266022db 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -248,6 +248,17 @@ class ForLoopNest {
const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::StringPiece suffix);
+ // Emits a series of nested loops for iterating over an operand array. Loops
+ // are constructed in major to minor dimension layout order. No loop is
+ // emitted for the given 'dimension_to_skip'. The function returns an IrArray
+ // index for the given operand_array containing the indvars of the loops. All
+ // dimensions of the index are filled except for 'dimension_to_skip'.
+ // name_suffix is the string to append to the names of LLVM constructs (eg,
+ // basic blocks) constructed by this method.
+ IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array,
+ int64 dimension_to_skip,
+ tensorflow::StringPiece name_suffix);
+
// Convenience methods which return particular basic blocks of the outermost
// or innermost loops. These methods return nullptr if no loops have been
// added yet.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
new file mode 100644
index 0000000000..16a9a5aaeb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
@@ -0,0 +1,201 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
+
+// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instructions.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace llvm_ir {
+
+namespace {
+// Adds the inner comparison loop where we compare elements pointed to by
+// 'keys_index' and 'compare_keys_index'.
+void EmitCompareLoop(int64 dimension_to_sort,
+ const llvm_ir::IrArray::Index& keys_index,
+ const llvm_ir::IrArray::Index& compare_keys_index,
+ const llvm_ir::IrArray& keys_array,
+ llvm::IRBuilder<>* ir_builder) {
+ // TODO(b/26783907): parallelize this loop.
+
+ // if (is_smaller_index &&
+ // compare_keys[dimension_to_sort] < dimension_to_sort_bound)
+ llvm::Value* is_smaller_index = ir_builder->CreateICmpSLT(
+ keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]);
+ int64 dimension_to_sort_bound =
+ keys_array.GetShape().dimensions(dimension_to_sort);
+ auto if_data = llvm_ir::EmitIfThenElse(
+ ir_builder->CreateAnd(
+ is_smaller_index,
+ ir_builder->CreateICmpSLT(
+ compare_keys_index[dimension_to_sort],
+ keys_index.GetConstantWithIndexType(dimension_to_sort_bound))),
+ "smaller_comparison_index", ir_builder, /*emit_else=*/false);
+ SetToFirstInsertPoint(if_data.true_block, ir_builder);
+ auto key1 = keys_array.EmitReadArrayElement(keys_index, ir_builder);
+ auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, ir_builder);
+ auto key_type = keys_array.GetShape().element_type();
+ auto comparison =
+ primitive_util::IsFloatingPointType(key_type)
+ // TODO(b/26783907): Figure out how to handle NaNs.
+ ? ir_builder->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2)
+ : ir_builder->CreateICmp(
+ primitive_util::IsSignedIntegralType(key_type)
+ ? llvm::ICmpInst::ICMP_SLT
+ : llvm::ICmpInst::ICMP_ULT,
+ key1, key2);
+ auto min_key = ir_builder->CreateSelect(comparison, key1, key2);
+ auto max_key = ir_builder->CreateSelect(comparison, key2, key1);
+ keys_array.EmitWriteArrayElement(keys_index, min_key, ir_builder);
+ keys_array.EmitWriteArrayElement(compare_keys_index, max_key, ir_builder);
+}
+} // namespace
+
+Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
+ tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder) {
+ const Shape& keys_shape = keys_array.GetShape();
+
+ // TODO(b/26783907): This case can probably be avoided with the Algebraic
+ // Simplifier.
+ if (ShapeUtil::IsScalar(keys_shape)) {
+ return Status::OK();
+ }
+
+ // Create loop nests which loop through the operand dimensions. The sort
+ // dimension is handled in three separate innermost loops which perform the
+ // sorting.
+ ForLoopNest loop_nest(name, ir_builder);
+ IrArray::Index keys_index =
+ loop_nest.EmitOperandArrayLoopNest(keys_array, dimension_to_sort, "keys");
+
+ // 'compare_keys_index' is the index of the element that 'keys_index' should
+ // be compared to.
+ IrArray::Index compare_keys_index(keys_index.GetType());
+ for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) {
+ if (dimension != dimension_to_sort) {
+ compare_keys_index.push_back(keys_index[dimension]);
+ } else {
+ compare_keys_index.push_back(nullptr);
+ }
+ }
+
+ // Create the sorting loops which do the sorting.
+ int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
+ std::unique_ptr<ForLoop> stages_loop = loop_nest.AddLoop(
+ /*start_index=*/0,
+ /*end_index=*/
+ tensorflow::Log2Ceiling64(dimension_to_sort_bound),
+ /*suffix=*/"sort_stages");
+ std::unique_ptr<ForLoop> mask_loop = loop_nest.AddLoop(
+ /*suffix=*/"mask",
+ /*start_index=*/keys_index.GetConstantWithIndexType(0),
+ /*end_index=*/stages_loop->GetIndVarValue());
+ std::unique_ptr<ForLoop> compare_loop = loop_nest.AddLoop(
+ /*start_index=*/0,
+ /*end_index=*/dimension_to_sort_bound,
+ /*suffix=*/"compare");
+
+ // Naive C++ code for the inner loops (without parallelization):
+ //
+ // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
+ // ++stage) {
+ // int64 first_xor_mask = (1LL << (stage + 1)) - 1;
+ // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+ // int64 j = i ^ first_xor_mask;
+ // if (i < j && j < dimension_to_sort_bound) {
+ // int64 min_key = std::min(keys[i], keys[j]);
+ // keys[j] = std::max(keys[i], keys[j]);
+ // keys[i] = min_key;
+ // }
+ // }
+ // for (int64 mask = 0; mask < stage; ++mask) {
+ // int64 later_xor_mask = (1LL << (stage - (mask + 1));
+ // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+ // int64 j = i ^ later_xor_mask;
+ // if (i < j && j < dimension_to_sort_bound) {
+ // int64 min_key = std::min(keys[i], keys[j]);
+ // keys[j] = std::max(keys[i], keys[j]);
+ // keys[i] = min_key;
+ // }
+ // }
+ // }
+ // }
+ //
+ // This follows the algorithm described on Wikipedia:
+ // https://en.wikipedia.org/wiki/Bitonic_sorter
+
+ SetToFirstInsertPoint(stages_loop->GetBodyBasicBlock(), ir_builder);
+ // The first xor mask of a stage is 2^(stage + 1) - 1.
+ auto first_xor_mask = ir_builder->CreateSub(
+ ir_builder->CreateShl(
+ keys_index.GetConstantWithIndexType(1),
+ ir_builder->CreateAdd(stages_loop->GetIndVarValue(),
+ keys_index.GetConstantWithIndexType(1))),
+ keys_index.GetConstantWithIndexType(1));
+ std::unique_ptr<ForLoop> first_compare_loop = ForLoop::EmitForLoop(
+ /*prefix=*/"first_compare",
+ /*start_index=*/keys_index.GetConstantWithIndexType(0),
+ /*end_index=*/
+ keys_index.GetConstantWithIndexType(dimension_to_sort_bound),
+ /*step=*/keys_index.GetConstantWithIndexType(1),
+ /*ir_builder=*/ir_builder);
+
+ SetToFirstInsertPoint(first_compare_loop->GetBodyBasicBlock(), ir_builder);
+ // 'first_compare_loop' iterates through the 'dimension_to_sort'.
+ keys_index[dimension_to_sort] = first_compare_loop->GetIndVarValue();
+ compare_keys_index[dimension_to_sort] = ir_builder->CreateXor(
+ first_compare_loop->GetIndVarValue(), first_xor_mask);
+ EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, keys_array,
+ ir_builder);
+
+ SetToFirstInsertPoint(compare_loop->GetPreheaderBasicBlock(), ir_builder);
+ // The later masks of a stage are 2^(stage - (mask_loop_ind_var + 1)).
+ auto later_xor_mask = ir_builder->CreateShl(
+ keys_index.GetConstantWithIndexType(1),
+ ir_builder->CreateSub(
+ stages_loop->GetIndVarValue(),
+ ir_builder->CreateAdd(mask_loop->GetIndVarValue(),
+ keys_index.GetConstantWithIndexType(1))));
+
+ SetToFirstInsertPoint(compare_loop->GetBodyBasicBlock(), ir_builder);
+ // 'compare_loop' iterates through the 'dimension_to_sort'.
+ keys_index[dimension_to_sort] = compare_loop->GetIndVarValue();
+ compare_keys_index[dimension_to_sort] =
+ ir_builder->CreateXor(compare_loop->GetIndVarValue(), later_xor_mask);
+ EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, keys_array,
+ ir_builder);
+
+ // Set the IR builder insert point to the exit basic block of the outer most
+ // loop. This ensures later instructions are inserted after this loop nest.
+ ir_builder->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
+
+ return Status::OK();
+}
+
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
new file mode 100644
index 0000000000..fc45bfab12
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
+
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace llvm_ir {
+// Emits llvm IR to sort the 'dimension_to_sort' dimension of 'keys_array' into
+// ascending order.
+Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
+ tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder);
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
index 10fc4958fa..62af45128a 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
@@ -61,6 +61,12 @@ StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody(
WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) {
int64 index = invariant_gte->tuple_index();
const HloInstruction& invariant_value = *init_value.operand(index);
+
+ // Should have at least one user that's not while_body_root.
+ if (invariant_gte->user_count() <= 1) {
+ continue;
+ }
+
if (invariant_value.opcode() == HloOpcode::kConstant) {
auto* constant_instr =
while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk"));
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
index 393e758038..266039d2ff 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
@@ -196,5 +196,50 @@ ENTRY entry {
op::GetTupleElement(op::Parameter(0)),
op::GetTupleElement(op::Parameter(0))));
}
+
+TEST_F(WhileLoopConstantSinkingTest, DontCreateDeadConstant) {
+ const char* const hlo_string = R"(
+HloModule ModuleWithWhile
+
+body {
+ p_body = (f32[2],f32[2]) parameter(0)
+ p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0
+ p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1
+
+ outfeed = token[] outfeed(p_body.0)
+ ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1)
+}
+
+condition {
+ p_cond = (f32[2],f32[2]) parameter(0)
+ ROOT result = pred[] constant(true)
+}
+
+ENTRY entry {
+ const_0 = f32[2] constant({1, 2})
+ const_1 = f32[2] constant({2, 1})
+ while_init = (f32[2],f32[2]) tuple(const_0, const_1)
+ ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition,
+ body=body
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_string));
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ WhileLoopConstantSinking{}.Run(module.get()));
+ ASSERT_TRUE(changed);
+
+ auto* while_body = module->GetComputationWithName("body");
+ EXPECT_THAT(while_body->root_instruction(),
+ op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
+ op::GetTupleElement()));
+ for (const HloInstruction* inst : while_body->instructions()) {
+ if (inst->opcode() == HloOpcode::kConstant) {
+ EXPECT_GT(inst->user_count(), 0);
+ }
+ }
+}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 4aacc87b78..c74dd648ad 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -44,10 +44,6 @@ struct ShapeTreeNode {
// Data corresponding to this node.
std::pair<ShapeIndex, T> data;
- // Children of this node, as indices into the container's nodes_ array.
- std::vector<size_t> children;
-
- // Tells whether this is a leaf node.
bool is_leaf = true;
explicit ShapeTreeNode(ShapeIndex index)
@@ -56,6 +52,20 @@ struct ShapeTreeNode {
: data(std::move(index), std::move(data)) {}
};
+// Internal representation of an index table entry.
+struct IndexTableEntry {
+ // Index of the node in the ShapeTreeNode vector.
+ uint32 index;
+ // Index of the first child in a IndexTableEntry vector. In the index
+ // table all children entries for a given node will be placed next to each
+ // other. This allows us to use a single field to index them.
+ uint32 children_start;
+#ifndef NDEBUG
+ // Number of children, used for bounds checking.
+ uint32 children_count;
+#endif
+};
+
} // namespace internal
template <typename ContainerType, typename IteratorType, typename ValueType>
@@ -84,6 +94,7 @@ template <typename T>
class ShapeTree {
public:
using Node = internal::ShapeTreeNode<T>;
+ using Index = internal::IndexTableEntry;
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
@@ -267,11 +278,12 @@ class ShapeTree {
private:
// Initialize node->children based on 'shape'. All children are assigned the
// the given 'init_value'.
- void InitChildren(const Shape& shape, const T& init_value, Node* node);
+ void InitChildren(const Shape& shape, const T& init_value, Node* node,
+ Index* index);
// Initialize node->children based on 'shape'. All children have
// default-constructed data values.
- void InitChildren(const Shape& shape, Node* node);
+ void InitChildren(const Shape& shape, Node* node, Index* index);
// Returns the number of subshapes, including interior nodes, in shape.
int64 CountSubshapes(const Shape& shape);
@@ -291,6 +303,9 @@ class ShapeTree {
// The nodes in this shape tree.
std::vector<Node> nodes_;
+ // Index table for node lookups.
+ std::vector<Index> index_table_;
+
// If we own our Shape, this field contains it, and shape_ is a pointer into
// here. Otherwise if we don't own our shape, this is nullptr.
std::shared_ptr<Shape> shape_storage_;
@@ -373,36 +388,74 @@ int64 ShapeTree<T>::CountSubshapes(const Shape& shape) {
template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
- Node* node) {
+ Node* node, Index* index) {
if (ShapeUtil::IsTuple(shape)) {
const int64 size = ShapeUtil::TupleElementCount(shape);
- node->children.reserve(size);
+#ifndef NDEBUG
+ index->children_count = size;
+#endif
node->is_leaf = false;
ShapeIndex shape_index = node->data.first;
shape_index.push_back(0);
+
+ // At the end of the index_table, reserve a continuous space to hold the
+ // children of current node. In order to enforce the invariant that all
+ // children of a given node are placed together, we need to do the
+ // reservation before we recurse into any of its children.
+ int64 children_start_position = index_table_.size();
+ index_table_.resize(index_table_.size() + size);
+
for (int i = 0; i < size; ++i) {
shape_index[shape_index.size() - 1] = i;
- node->children.push_back(nodes_.size());
+ index_table_[children_start_position + i].index = nodes_.size();
+ // The first child of the node in the index table is placed at the end of
+ // the table.
+ index_table_[children_start_position + i].children_start =
+ index_table_.size();
nodes_.emplace_back(shape_index, init_value);
- InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back());
+ InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
+ &index_table_[children_start_position + i]);
}
+ } else {
+#ifndef NDEBUG
+ index->children_count = 0;
+#endif
}
}
template <typename T>
-void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
+void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
if (ShapeUtil::IsTuple(shape)) {
const int64 size = ShapeUtil::TupleElementCount(shape);
- node->children.reserve(size);
+#ifndef NDEBUG
+ index->children_count = size;
+#endif
node->is_leaf = false;
ShapeIndex shape_index = node->data.first;
shape_index.push_back(0);
+
+ // At the end of the index_table, reserve a continuous space to hold the
+ // children of current node. In order to enforce the invariant that all
+ // children of a given node are placed together, we need to do the
+ // reservation before we recurse into any of its children.
+ int64 children_start_position = index_table_.size();
+ index_table_.resize(index_table_.size() + size);
+
for (int i = 0; i < size; ++i) {
shape_index[shape_index.size() - 1] = i;
- node->children.push_back(nodes_.size());
+ index_table_[children_start_position + i].index = nodes_.size();
+ // The first child of the node in the index table is placed at the end of
+ // the table.
+ index_table_[children_start_position + i].children_start =
+ index_table_.size();
nodes_.emplace_back(shape_index);
- InitChildren(shape.tuple_shapes(i), &nodes_.back());
+ InitChildren(shape.tuple_shapes(i), &nodes_.back(),
+ &index_table_[children_start_position + i]);
}
+ } else {
+#ifndef NDEBUG
+ index->children_count = 0;
+#endif
}
}
@@ -413,24 +466,36 @@ ShapeTree<T>::ShapeTree(Shape shape)
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(shape_storage_.get());
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{});
- InitChildren(*shape_, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, &nodes_[0], &index_table_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{});
- InitChildren(*shape_, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, &nodes_[0], &index_table_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
: shape_storage_(shape), shape_(shape_storage_.get()) {
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{});
- InitChildren(*shape_, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, &nodes_[0], &index_table_[0]);
}
template <typename T>
@@ -440,26 +505,38 @@ ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(shape_storage_.get());
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{}, init_value);
- InitChildren(*shape_, init_value, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
: shape_(shape) {
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{}, init_value);
- InitChildren(*shape_, init_value, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
const T& init_value)
: shape_storage_(shape), shape_(shape_storage_.get()) {
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{}, init_value);
- InitChildren(*shape_, init_value, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
}
template <typename T>
@@ -474,13 +551,16 @@ T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
template <typename T>
internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
- Node* node = &nodes_[0];
+ Index* iter = &index_table_[0];
for (const int64 i : index) {
CHECK_GE(i, 0);
- CHECK_LT(i, node->children.size());
- node = &nodes_[node->children[i]];
+#ifndef NDEBUG
+ CHECK_LT(i, iter->children_count);
+#endif
+ iter = &index_table_[iter->children_start + i];
}
- return node;
+
+ return &nodes_[iter->index];
}
template <typename T>
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
index 51de82e957..4391078b64 100644
--- a/tensorflow/compiler/xla/shape_tree_test.cc
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -227,14 +227,16 @@ TEST_F(ShapeTreeTest, NestedTupleShape) {
TEST_F(ShapeTreeTest, InvalidIndexingTuple) {
ShapeTree<int> shape_tree{tuple_shape_};
-
+#ifndef NDEBUG
EXPECT_DEATH(shape_tree.element({4}), "");
+#endif
}
TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
ShapeTree<int> shape_tree{nested_tuple_shape_};
-
+#ifndef NDEBUG
EXPECT_DEATH(shape_tree.element({0, 0}), "");
+#endif
}
TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
@@ -602,12 +604,15 @@ void BM_Iterate(int iters, int depth, int fan_out) {
}
}
-BENCHMARK(BM_Construct)->ArgPair(2, 8);
-BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8);
-BENCHMARK(BM_Copy)->ArgPair(2, 8);
-BENCHMARK(BM_Move)->ArgPair(2, 8);
-BENCHMARK(BM_ForEach)->ArgPair(2, 8);
-BENCHMARK(BM_Iterate)->ArgPair(2, 8);
+#define BENCHMARK_WITH_ARGS(name) \
+ BENCHMARK(name)->ArgPair(2, 8)->ArgPair(1, 1000)
+
+BENCHMARK_WITH_ARGS(BM_Construct);
+BENCHMARK_WITH_ARGS(BM_ConstructUnowned);
+BENCHMARK_WITH_ARGS(BM_Copy);
+BENCHMARK_WITH_ARGS(BM_Move);
+BENCHMARK_WITH_ARGS(BM_ForEach);
+BENCHMARK_WITH_ARGS(BM_Iterate);
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index f4668c0f55..6480148336 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -883,40 +883,51 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
int64 shape_size = [&shape]() {
- int64 shape_size;
if (LayoutUtil::IsSparseArray(shape)) {
- shape_size = LayoutUtil::MaxSparseElements(shape.layout());
- if (shape_size < 0) {
- return shape_size;
+ int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout());
+ if (max_sparse_elements < 0) {
+ return max_sparse_elements;
}
- shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape));
- if (shape_size < 0) {
- return shape_size;
+ int64 sparse_elements_size = MultiplyWithoutOverflow(
+ max_sparse_elements, ByteSizeOfPrimitiveType(shape.element_type()));
+ if (sparse_elements_size < 0) {
+ return sparse_elements_size;
}
- shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64));
- if (shape_size < 0) {
- return shape_size;
+ int64 sparse_indices_size =
+ MultiplyWithoutOverflow(max_sparse_elements, ShapeUtil::Rank(shape));
+ if (sparse_indices_size < 0) {
+ return sparse_indices_size;
+ }
+ sparse_indices_size =
+ MultiplyWithoutOverflow(sparse_indices_size, sizeof(int64));
+ if (sparse_indices_size < 0) {
+ return sparse_indices_size;
+ }
+ // At this point, both sparse_indices_size and sparse_elements_size are
+ // non-negative, so we can easily check if adding them wraps.
+ if (static_cast<uint64>(sparse_elements_size) +
+ static_cast<uint64>(sparse_indices_size) >
+ INT64_MAX) {
+ return static_cast<int64>(-1);
}
}
- shape_size = 1;
-
// This is intentionally unconditional: even if the shape is sparse, we want
// to verify the densified version has a reasonable size.
+ int64 dense_shape_size = 1;
if (shape.dimensions().empty()) {
- return shape_size;
+ return dense_shape_size;
}
for (int64 dim : shape.dimensions()) {
- shape_size = MultiplyWithoutOverflow(shape_size, dim);
- if (shape_size < 0) {
- return shape_size;
+ dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
+ if (dense_shape_size < 0) {
+ return dense_shape_size;
}
}
- shape_size = MultiplyWithoutOverflow(
- shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
-
- return shape_size;
+ dense_shape_size = MultiplyWithoutOverflow(
+ dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
+ return dense_shape_size;
}();
if (shape_size < 0) {
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 17c1d7b10a..d6f17fc965 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
@@ -73,10 +74,12 @@ class ShapeIndex {
// push_front is O(n^2), but shapes don't usually have a ton of dimensions.
void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
- std::vector<int64>::const_iterator begin() const { return indices_.begin(); }
- std::vector<int64>::const_iterator end() const { return indices_.end(); }
- std::vector<int64>::iterator begin() { return indices_.begin(); }
- std::vector<int64>::iterator end() { return indices_.end(); }
+ using container_type = tensorflow::gtl::InlinedVector<int64, 2>;
+
+ container_type::const_iterator begin() const { return indices_.begin(); }
+ container_type::const_iterator end() const { return indices_.end(); }
+ container_type::iterator begin() { return indices_.begin(); }
+ container_type::iterator end() { return indices_.end(); }
const int64* data() const { return indices_.data(); }
@@ -97,7 +100,7 @@ class ShapeIndex {
string ToString() const;
private:
- std::vector<int64> indices_;
+ container_type indices_;
};
// A view into a ShapeIndex as above, with the cheap/easy ability to consume the
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 1407fca72f..e4a8ddf86a 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -125,10 +125,10 @@ class ReduceTest : public ClientLibraryTestBase {
XlaComputation reduce;
if (and_reduce) {
init_value = ConstantR0<bool>(&builder, true);
- reduce = CreateScalarAndComputation(&builder);
+ reduce = CreateScalarAndComputation(PRED, &builder);
} else {
init_value = ConstantR0<bool>(&builder, false);
- reduce = CreateScalarOrComputation(&builder);
+ reduce = CreateScalarOrComputation(PRED, &builder);
}
Reduce(pred_values, init_value, reduce,
/*dimensions_to_reduce=*/{0});
@@ -163,10 +163,10 @@ class ReduceTest : public ClientLibraryTestBase {
XlaComputation reduce_op;
if (and_reduce) {
init_value = ConstantR0<bool>(&builder, true);
- reduce_op = CreateScalarAndComputation(&builder);
+ reduce_op = CreateScalarAndComputation(PRED, &builder);
} else {
init_value = ConstantR0<bool>(&builder, false);
- reduce_op = CreateScalarOrComputation(&builder);
+ reduce_op = CreateScalarOrComputation(PRED, &builder);
}
Reduce(input_pred, init_value, reduce_op,
@@ -798,13 +798,17 @@ XLA_TEST_F(ReduceTest, VectorizedReduce_Min) {
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) {
RunVectorizedReduceTestForType<bool>(
- static_cast<FuncGenerator>(CreateScalarAndComputation),
+ static_cast<FuncGenerator>([](XlaBuilder* builder) {
+ return CreateScalarAndComputation(PRED, builder);
+ }),
[](bool a, bool b) { return a && b; }, true);
}
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) {
RunVectorizedReduceTestForType<bool>(
- static_cast<FuncGenerator>(CreateScalarOrComputation),
+ static_cast<FuncGenerator>([](XlaBuilder* builder) {
+ return CreateScalarOrComputation(PRED, builder);
+ }),
[](bool a, bool b) { return a || b; }, false);
}
@@ -963,5 +967,32 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) {
ErrorSpec(0.0001));
}
+XLA_TEST_F(ReduceTest, AndReduceU64) {
+ XlaBuilder builder(TestName());
+ Array2D<uint64> initializer = {{0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL},
+ {0XFFFFFFFFFFFFFFD6LL, 101},
+ {1, 0XFFFFFFFFFFFFFFFFLL}};
+ auto reducer = CreateScalarAndComputation(U64, &builder);
+ auto m = ConstantR2FromArray2D(&builder, initializer);
+ Reduce(m, ConstantR0<uint64>(&builder, 0xFFFFFFFFFFFFFFFFLL), reducer, {1});
+
+ std::vector<uint64> expected = {0x1204461080145890LL, 68, 1};
+ ComputeAndCompareR1<uint64>(&builder, expected, {});
+}
+
+XLA_TEST_F(ReduceTest, OrReduceU64) {
+ XlaBuilder builder(TestName());
+ Array2D<uint64> initializer = {{0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL},
+ {0xFFFFFFFFFFFFFFD6LL, 101},
+ {1, 0xCAFEBEEFABABABABLL}};
+ auto reducer = CreateScalarOrComputation(U64, &builder);
+ auto m = ConstantR2FromArray2D(&builder, initializer);
+ Reduce(m, ConstantR0<uint64>(&builder, 0), reducer, {1});
+
+ std::vector<uint64> expected = {0X3BFDFF7ABEFEFEF0LL, 0XFFFFFFFFFFFFFFF7LL,
+ 0xCAFEBEEFABABABABLL};
+ ComputeAndCompareR1<uint64>(&builder, expected, {});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md
index 679ab48e5c..cc54da4daa 100644
--- a/tensorflow/contrib/autograph/README.md
+++ b/tensorflow/contrib/autograph/README.md
@@ -1,6 +1,6 @@
# AutoGraph
-IMPORTANT: AutoGraph is alpha software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
+IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
AutoGraph is a Python to TensorFlow compiler.
@@ -68,12 +68,21 @@ Then import the `autograph` module from `tf.contrib`:
from tensorflow.contrib import autograph as ag
```
-### Interactive demo notebooks
+### Related links
-For more extensive examples, check out these interactive notebooks:
+Articles:
- * [RNN trained using Keras and Estimators](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb)
+ * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
+
+Interactive notebooks:
+
+ * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb)
+ * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb)
* [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb)
+ * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb)
+ * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb)
+ * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb)
+ * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb)
## Using with annotations
diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py
index c172ab21f6..c90a5e89c2 100644
--- a/tensorflow/contrib/autograph/pyct/compiler.py
+++ b/tensorflow/contrib/autograph/pyct/compiler.py
@@ -71,7 +71,16 @@ def _build_source_map(node, code):
def ast_to_source(node, indentation=' '):
- """Return the source code of given AST."""
+ """Return the source code of given AST.
+
+ Args:
+ node: The code to compile, as an AST object.
+ indentation: The string to use for indentation.
+
+ Returns:
+ code: The source code generated from the AST object
+ source_mapping: A mapping between the user and AutoGraph generated code.
+ """
original_node = node
if isinstance(node, gast.AST):
node = gast.gast_to_ast(node)
@@ -105,7 +114,8 @@ def ast_to_object(node,
exit.
Returns:
- A module object containing the compiled source code.
+ compiled_node: A module object containing the compiled source code.
+ source: The source code of the compiled object
Raises:
ValueError: If ag_source_map__ is already in the namespace of the compiled
node.
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 675330716b..7878e46e88 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -52,6 +52,7 @@ See @{$guide/datasets$Importing Data} for an overview.
@@prefetch_to_device
@@read_batch_features
@@rejection_resample
+@@reduce_dataset
@@sample_from_datasets
@@scan
@@shuffle_and_repeat
@@ -77,6 +78,7 @@ from tensorflow.contrib.data.python.ops.counter import Counter
from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
from tensorflow.contrib.data.python.ops.get_single_element import get_single_element
+from tensorflow.contrib.data.python.ops.get_single_element import reduce_dataset
from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
from tensorflow.contrib.data.python.ops.grouping import group_by_reducer
from tensorflow.contrib.data.python.ops.grouping import group_by_window
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index d372bed479..036dc795bb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -60,7 +60,7 @@ py_test(
py_test(
name = "csv_dataset_op_test",
- size = "small",
+ size = "medium",
srcs = ["csv_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
@@ -121,6 +121,7 @@ py_test(
srcs = ["get_single_element_test.py"],
deps = [
"//tensorflow/contrib/data/python/ops:get_single_element",
+ "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -128,6 +129,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index 87b7c6ddb7..e6883d53e0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+import numpy as np
+
from tensorflow.contrib.data.python.ops import get_single_element
+from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -27,40 +30,69 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GetSingleElementTest(test.TestCase):
+class GetSingleElementTest(test.TestCase, parameterized.TestCase):
- def testGetSingleElement(self):
- skip_value = array_ops.placeholder(dtypes.int64, shape=[])
- take_value = array_ops.placeholder_with_default(
- constant_op.constant(1, dtype=dtypes.int64), shape=[])
+ @parameterized.named_parameters(
+ ("Zero", 0, 1),
+ ("Five", 5, 1),
+ ("Ten", 10, 1),
+ ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
+ ("MoreThanOne", 0, 2, errors.InvalidArgumentError,
+ "Dataset had more than one element."),
+ )
+ def testGetSingleElement(self, skip, take, error=None, error_msg=None):
+ skip_t = array_ops.placeholder(dtypes.int64, shape=[])
+ take_t = array_ops.placeholder(dtypes.int64, shape=[])
def make_sparse(x):
x_1d = array_ops.reshape(x, [1])
x_2d = array_ops.reshape(x, [1, 1])
return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
- dataset = (dataset_ops.Dataset.range(100)
- .skip(skip_value)
- .map(lambda x: (x * x, make_sparse(x)))
- .take(take_value))
-
+ dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
+ lambda x: (x * x, make_sparse(x))).take(take_t)
element = get_single_element.get_single_element(dataset)
with self.test_session() as sess:
- for x in [0, 5, 10]:
- dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x})
- self.assertEqual(x * x, dense_val)
- self.assertAllEqual([[x]], sparse_val.indices)
- self.assertAllEqual([x], sparse_val.values)
- self.assertAllEqual([x], sparse_val.dense_shape)
-
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "Dataset was empty."):
- sess.run(element, feed_dict={skip_value: 100})
-
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "Dataset had more than one element."):
- sess.run(element, feed_dict={skip_value: 0, take_value: 2})
+ if error is None:
+ dense_val, sparse_val = sess.run(
+ element, feed_dict={
+ skip_t: skip,
+ take_t: take
+ })
+ self.assertEqual(skip * skip, dense_val)
+ self.assertAllEqual([[skip]], sparse_val.indices)
+ self.assertAllEqual([skip], sparse_val.values)
+ self.assertAllEqual([skip], sparse_val.dense_shape)
+ else:
+ with self.assertRaisesRegexp(error, error_msg):
+ sess.run(element, feed_dict={skip_t: skip, take_t: take})
+
+ @parameterized.named_parameters(
+ ("SumZero", 0),
+ ("SumOne", 1),
+ ("SumFive", 5),
+ ("SumTen", 10),
+ )
+ def testReduceDataset(self, stop):
+ def init_fn(_):
+ return np.int64(0)
+
+ def reduce_fn(state, value):
+ return state + value
+
+ def finalize_fn(state):
+ return state
+
+ sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+
+ stop_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset_ops.Dataset.range(stop_t)
+ element = get_single_element.reduce_dataset(dataset, sum_reducer)
+
+ with self.test_session() as sess:
+ value = sess.run(element, feed_dict={stop_t: stop})
+ self.assertEqual(stop * (stop - 1) / 2, value)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 160d7fe22a..1ad021ea03 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -28,10 +28,12 @@ py_library(
srcs = ["get_single_element.py"],
srcs_version = "PY2AND3",
deps = [
+ ":grouping",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
+ "//third_party/py/numpy",
],
)
@@ -129,6 +131,7 @@ py_library(
"//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py
index 0f4cd8e20c..ef9284456e 100644
--- a/tensorflow/contrib/data/python/ops/get_single_element.py
+++ b/tensorflow/contrib/data/python/ops/get_single_element.py
@@ -17,6 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
@@ -68,3 +71,30 @@ def get_single_element(dataset):
return sparse.deserialize_sparse_tensors(
nested_ret, dataset.output_types, dataset.output_shapes,
dataset.output_classes)
+
+
+def reduce_dataset(dataset, reducer):
+ """Returns the result of reducing the `dataset` using `reducer`.
+
+ Args:
+ dataset: A @{tf.data.Dataset} object.
+ reducer: A @{tf.contrib.data.Reducer} object representing the reduce logic.
+
+ Returns:
+ A nested structure of @{tf.Tensor} objects, corresponding to the result
+ of reducing `dataset` using `reducer`.
+
+ Raises:
+ TypeError: if `dataset` is not a `tf.data.Dataset` object.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+
+ # The sentinel dataset is used in case the reduced dataset is empty.
+ sentinel_dataset = dataset_ops.Dataset.from_tensors(
+ reducer.finalize_func(reducer.init_func(np.int64(0))))
+ reduced_dataset = dataset.apply(
+ grouping.group_by_reducer(lambda x: np.int64(0), reducer))
+
+ return get_single_element(
+ reduced_dataset.concatenate(sentinel_dataset).take(1))
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 50212d3b52..45abd6376c 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -480,6 +480,11 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
self._finalize_func = _remote_finalize_func
self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+ g = ops.get_default_graph()
+ _remote_init_func.add_to_graph(g)
+ _remote_next_func.add_to_graph(g)
+ _remote_finalize_func.add_to_graph(g)
# pylint: enable=protected-scope
# The one_shot_iterator implementation needs a 0 arg _make_dataset function
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index 58c548d798..e31dbbe80f 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -18,33 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import threading
-
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
-_uid_counter = 0
-_uid_lock = threading.Lock()
-
-
-def _generate_shared_name(prefix):
- with _uid_lock:
- global _uid_counter
- uid = _uid_counter
- _uid_counter += 1
- return "{}{}".format(prefix, uid)
-
class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset.
@@ -80,38 +61,18 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
"`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate "
"over the dataset instead.")
- super(Iterator, self).__init__(dataset)
if not context.context().device_spec.device_type:
is_remote_device = False
else:
is_remote_device = context.context().device_spec.device_type != "CPU"
- self._buffer_resource_handle = None
if is_remote_device:
- with ops.device("/device:CPU:0"):
- iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
- self._resource)
-
- @function.Defun(dtypes.string)
- def remote_fn(h):
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, self.output_types, self.output_shapes, self.output_classes)
- return remote_iterator.get_next()
-
- remote_fn.add_to_graph(None)
- target = constant_op.constant("/device:CPU:0")
- with ops.device(self._device):
- self._buffer_resource_handle = prefetching_ops.function_buffering_resource( # pylint: disable=line-too-long
- string_arg=iter_string_handle,
- output_types=self._flat_output_types,
- f=remote_fn,
- target_device=target,
- buffer_size=10,
- container="",
- shared_name=_generate_shared_name(
- "contrib_eager_iterator_function_buffer_resource"))
- self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long
- handle=self._buffer_resource_handle,
- handle_device=self._device)
+ with ops.device(None):
+ # Let the placer figure out where to place the various functions etc.
+ # created by the CopyToDeviceDataset.
+ dataset = dataset.apply(prefetching_ops.copy_to_device(
+ context.context().device_name))
+ dataset = dataset.prefetch(1)
+ super(Iterator, self).__init__(dataset)
def _next_internal(self):
"""Returns a nested structure of `tf.Tensor`s containing the next element.
@@ -120,16 +81,7 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
# that there is no more data to iterate over.
# TODO(b/77291417): Fix
with context.execution_mode(context.SYNC):
- if self._buffer_resource_handle is not None:
- with ops.device(self._device):
- ret = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=self._buffer_resource_handle,
- output_types=self._flat_output_types)
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self._output_types, ret), self._output_types,
- self._output_shapes, self._output_classes)
- else:
- return super(Iterator, self)._next_internal()
+ return super(Iterator, self)._next_internal()
# TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset
# attributes(potential).
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
index b33243021b..9a42179299 100644
--- a/tensorflow/contrib/eager/python/examples/gan/mnist.py
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -29,7 +29,6 @@ import time
import tensorflow as tf
-import tensorflow.contrib.eager as tfe
from tensorflow.examples.tutorials.mnist import input_data
layers = tf.keras.layers
@@ -265,7 +264,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
def main(_):
(device, data_format) = ('/gpu:0', 'channels_first')
- if FLAGS.no_gpu or tfe.num_gpus() <= 0:
+ if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0:
(device, data_format) = ('/cpu:0', 'channels_last')
print('Using device %s, and data format %s.' % (device, data_format))
@@ -291,7 +290,7 @@ def main(_):
latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if latest_cpkt:
print('Using latest checkpoint at ' + latest_cpkt)
- checkpoint = tfe.Checkpoint(**model_objects)
+ checkpoint = tf.train.Checkpoint(**model_objects)
# Restore variables on creation if a checkpoint exists.
checkpoint.restore(latest_cpkt)
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
index 43c8c355dc..232f9a8ef0 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
@@ -31,7 +31,7 @@
"\n",
"On a colab GPU(Tesla K80), the model takes around 40 seconds per epoch to train.\n",
"\n",
- "Below is the output generated after training the generator and discriminator models for 100 epochs.\n",
+ "Below is the output generated after training the generator and discriminator models for 150 epochs.\n",
"\n",
"![sample output](https://tensorflow.org/images/gan/dcgan.gif)"
]
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index d64bf5354e..15776c694e 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -315,7 +315,7 @@ def main(_):
FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
use_cudnn_rnn)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
- checkpoint = tfe.Checkpoint(
+ checkpoint = tf.train.Checkpoint(
learning_rate=learning_rate, model=model,
# GradientDescentOptimizer has no state to checkpoint, but noting it
# here lets us swap in an optimizer that does.
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 11d40f5982..1aa3df8d8d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -28,7 +28,7 @@ py_library(
":multi_head",
":replicate_model_fn",
":rnn",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -54,22 +54,10 @@ py_test(
deps = [
":baseline",
":head",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:session",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:numpy_io",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -96,11 +84,8 @@ py_test(
],
deps = [
":boosted_trees",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:numpy_io",
- "//tensorflow/python/feature_column",
"//third_party/py/numpy",
],
)
@@ -110,7 +95,7 @@ py_library(
srcs = ["python/estimator/dnn.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:nn",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:dnn",
],
@@ -129,16 +114,11 @@ py_test(
deps = [
":dnn",
":head",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:dnn_testing_utils",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -149,7 +129,7 @@ py_library(
srcs = ["python/estimator/dnn_linear_combined.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:nn",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:dnn_linear_combined",
],
@@ -168,18 +148,12 @@ py_test(
deps = [
":dnn_linear_combined",
":head",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:nn",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:dnn_testing_utils",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:linear_testing_utils",
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -192,10 +166,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:clip_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:util",
@@ -211,18 +182,11 @@ py_test(
tags = ["notsan"], # b/62863147
deps = [
":extenders",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/predictor",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/estimator:linear",
- "//tensorflow/python/feature_column",
"//third_party/py/numpy",
],
)
@@ -246,21 +210,11 @@ py_test(
tags = ["notsan"], # b/62863147
deps = [
":export",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:metrics",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:export_output",
"//tensorflow/python/estimator:model_fn",
- "//tensorflow/python/saved_model:loader",
- "//tensorflow/python/saved_model:tag_constants",
],
)
@@ -271,25 +225,12 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:nn",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:export_output",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:signature_constants",
],
)
@@ -300,25 +241,10 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":head",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:signature_constants",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -331,8 +257,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:estimator_py",
],
)
@@ -345,10 +270,7 @@ py_test(
tags = ["notsan"],
deps = [
":hooks",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:estimator_py",
"//third_party/py/numpy",
"@six_archive//:six",
@@ -377,16 +299,11 @@ py_test(
deps = [
":head",
":linear",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:linear_testing_utils",
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -399,8 +316,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:dnn",
"//tensorflow/python/estimator:linear",
],
@@ -413,9 +329,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":logit_fns",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:session",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:model_fn",
],
)
@@ -427,18 +341,11 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:export_output",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
- "//tensorflow/python/saved_model:signature_constants",
"@six_archive//:six",
],
)
@@ -451,15 +358,10 @@ py_test(
deps = [
":head",
":multi_head",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:string_ops",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/saved_model:signature_constants",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -472,24 +374,10 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:device",
- "//tensorflow/python:device_lib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:export_output",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:util",
- "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
@@ -500,6 +388,7 @@ cuda_py_test(
srcs = ["python/estimator/replicate_model_fn_test.py"],
additional_deps = [
"@absl_py//absl/testing:parameterized",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:dnn",
"//tensorflow/python/estimator:export_export",
@@ -508,21 +397,6 @@ cuda_py_test(
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:optimizers",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
":replicate_model_fn",
],
tags = [
@@ -538,22 +412,11 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":extenders",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/contrib/feature_column:feature_column_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:rnn",
- "//tensorflow/python:rnn_cell",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
- "//tensorflow/python/feature_column",
"@six_archive//:six",
],
)
@@ -572,21 +435,10 @@ py_test(
deps = [
":head",
":rnn",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/contrib/data",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:parsing_utils",
- "//tensorflow/python/feature_column",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -597,13 +449,7 @@ py_library(
srcs = ["python/estimator/early_stopping.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
],
)
@@ -614,7 +460,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":early_stopping",
- "//tensorflow/python:client_testlib",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"@absl_py//absl/testing:parameterized",
],
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index c9d86ef4ab..34f765d565 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -943,20 +943,30 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
class_probabilities = array_ops.slice(
probabilities, begin=begin, size=size)
class_labels = array_ops.slice(labels, begin=begin, size=size)
- prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
+ if self._label_vocabulary is None:
+ prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
+ else:
+ prob_key = (
+ keys.PROBABILITY_MEAN_AT_NAME % self._label_vocabulary[class_id])
metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access
head_lib._predictions_mean( # pylint:disable=protected-access
predictions=class_probabilities,
weights=weights,
name=prob_key))
- auc_key = keys.AUC_AT_CLASS % class_id
+ if self._label_vocabulary is None:
+ auc_key = keys.AUC_AT_CLASS % class_id
+ else:
+ auc_key = keys.AUC_AT_NAME % self._label_vocabulary[class_id]
metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access
head_lib._auc( # pylint:disable=protected-access
labels=class_labels,
predictions=class_probabilities,
weights=weights,
name=auc_key))
- auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
+ if self._label_vocabulary is None:
+ auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
+ else:
+ auc_pr_key = keys.AUC_PR_AT_NAME % self._label_vocabulary[class_id]
metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access
head_lib._auc( # pylint:disable=protected-access
labels=class_labels,
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 7b884402d4..2d367adb47 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -694,12 +694,14 @@ class MultiLabelHead(test.TestCase):
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
keys.AUC_PR: 0.7639,
- keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2.,
- keys.AUC_AT_CLASS % 0: 0.,
- keys.AUC_PR_AT_CLASS % 0: 1.,
- keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2.,
- keys.AUC_AT_CLASS % 1: 1.,
- keys.AUC_PR_AT_CLASS % 1: 1.,
+ keys.PROBABILITY_MEAN_AT_NAME % 'a':
+ np.sum(_sigmoid(logits[:, 0])) / 2.,
+ keys.AUC_AT_NAME % 'a': 0.,
+ keys.AUC_PR_AT_NAME % 'a': 1.,
+ keys.PROBABILITY_MEAN_AT_NAME % 'b':
+ np.sum(_sigmoid(logits[:, 1])) / 2.,
+ keys.AUC_AT_NAME % 'b': 1.,
+ keys.AUC_PR_AT_NAME % 'b': 1.,
}
self._test_eval(
diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile
index a616138d33..df5954744a 100644
--- a/tensorflow/contrib/lite/Makefile
+++ b/tensorflow/contrib/lite/Makefile
@@ -82,8 +82,9 @@ endif
# Settings for the host compiler.
CXX := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}g++
-CXXFLAGS += --std=c++11 -O3 -DNDEBUG
+CXXFLAGS += -O3 -DNDEBUG
CCFLAGS := ${CXXFLAGS}
+CXXFLAGS += --std=c++11
CC := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}gcc
AR := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}ar
CFLAGS :=
diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh
index e9531aef19..31df43a175 100755
--- a/tensorflow/contrib/lite/build_ios_universal_lib.sh
+++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh
@@ -21,7 +21,7 @@ cd "$SCRIPT_DIR/../../.."
# Build library for supported architectures and packs them in a fat binary.
make_library() {
- for arch in x86_64 i386 armv7 armv7s arm64
+ for arch in x86_64 armv7 armv7s arm64
do
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=${arch} \
-j 8 \
@@ -29,7 +29,6 @@ make_library() {
done
lipo \
tensorflow/contrib/lite/gen/lib/ios_x86_64/${1} \
- tensorflow/contrib/lite/gen/lib/ios_i386/${1} \
tensorflow/contrib/lite/gen/lib/ios_armv7/${1} \
tensorflow/contrib/lite/gen/lib/ios_armv7s/${1} \
tensorflow/contrib/lite/gen/lib/ios_arm64/${1} \
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 4c7b27c4e0..558e547121 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -108,6 +108,7 @@ typedef enum {
kTfLiteBuiltinFakeQuant = 80,
kTfLiteBuiltinReduceProd = 81,
kTfLiteBuiltinReduceMax = 82,
+ kTfLiteBuiltinPack = 83,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 9f31ffdf67..03a4b7bf1d 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -42,10 +42,6 @@ cc_library(
name = "delegate_data",
srcs = ["delegate_data.cc"],
hdrs = ["delegate_data.h"],
- tags = [
- "no_oss",
- "tflite_not_portable",
- ],
deps = [
":buffer_map",
"//tensorflow/core:core_cpu",
@@ -59,6 +55,7 @@ cc_test(
size = "small",
srcs = ["delegate_data_test.cc"],
tags = [
+ "no_oss",
"tflite_not_portable",
],
deps = [
diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh
index 840015a7fa..8c7df474d5 100755
--- a/tensorflow/contrib/lite/download_dependencies.sh
+++ b/tensorflow/contrib/lite/download_dependencies.sh
@@ -35,7 +35,7 @@ GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.g
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip"
FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz"
-FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip"
+FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/v1.8.0.zip"
FFT2D_URL="https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz"
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
diff --git a/tensorflow/contrib/lite/java/AndroidManifest.xml b/tensorflow/contrib/lite/java/AndroidManifest.xml
index f705feacbe..c3849e6868 100644
--- a/tensorflow/contrib/lite/java/AndroidManifest.xml
+++ b/tensorflow/contrib/lite/java/AndroidManifest.xml
@@ -1,7 +1,11 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
- package="org.tensorflow.lite">
- <application>
- </application>
-</manifest>
+ package="org.tensorflow.lite">
+
+ <uses-sdk
+ android:minSdkVersion="4"
+ android:targetSdkVersion="19" />
+ <application />
+
+</manifest>
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index f44d531cbf..af9b5c7013 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -110,15 +110,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
QuantizeMultiplierSmallerThanOneExp(
real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);
- data->input1_shift *= -1;
QuantizeMultiplierSmallerThanOneExp(
real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);
- data->input2_shift *= -1;
QuantizeMultiplierSmallerThanOneExp(
real_output_multiplier, &data->output_multiplier, &data->output_shift);
- data->output_shift *= -1;
CalculateActivationRangeUint8(params->activation, output,
&data->output_activation_min,
@@ -152,14 +149,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
CheckedLog2(output->params.scale, &output_scale_log2_rounded);
TF_LITE_ENSURE(context, output_scale_is_pot);
- data->input1_shift = output_scale_log2_rounded - input1_scale_log2_rounded;
- data->input2_shift = output_scale_log2_rounded - input2_scale_log2_rounded;
+ data->input1_shift = input1_scale_log2_rounded - output_scale_log2_rounded;
+ data->input2_shift = input2_scale_log2_rounded - output_scale_log2_rounded;
// Shifting of one input is supported. The graph quantization should ensure
// that the other input matches the output.
TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0);
- TF_LITE_ENSURE(context, data->input1_shift >= 0);
- TF_LITE_ENSURE(context, data->input2_shift >= 0);
+ TF_LITE_ENSURE(context, data->input1_shift <= 0);
+ TF_LITE_ENSURE(context, data->input2_shift <= 0);
CalculateActivationRangeQuantized(context, params->activation, output,
&data->output_activation_min,
@@ -173,24 +170,27 @@ template <KernelType kernel_type>
void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_ADD(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_ADD(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_ADD(reference_ops, BroadcastAdd, int32_t);
+ TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t);
} else {
TF_LITE_ADD(reference_ops, Add, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_ADD(optimized_ops, BroadcastAdd, int32_t);
+ TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t);
} else {
TF_LITE_ADD(optimized_ops, Add, int32_t);
}
@@ -198,13 +198,13 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_ADD(reference_ops, BroadcastAdd, float);
+ TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float);
} else {
TF_LITE_ADD(reference_ops, Add, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_ADD(optimized_ops, BroadcastAdd, float);
+ TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float);
} else {
TF_LITE_ADD(optimized_ops, Add, float);
}
@@ -220,30 +220,43 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input2,
TfLiteTensor* output) {
if (output->type == kTfLiteUInt8) {
-#define TF_LITE_ADD(type, opname) \
- type::opname( \
- data->left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- data->input1_offset, data->input1_multiplier, data->input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), \
- data->input2_offset, data->input2_multiplier, data->input2_shift, \
- data->output_offset, data->output_multiplier, data->output_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output));
+#define TF_LITE_ADD(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ op_params.left_shift = data->left_shift; \
+ op_params.input1_offset = data->input1_offset; \
+ op_params.input1_multiplier = data->input1_multiplier; \
+ op_params.input1_shift = data->input1_shift; \
+ op_params.input2_offset = data->input2_offset; \
+ op_params.input2_multiplier = data->input2_multiplier; \
+ op_params.input2_shift = data->input2_shift; \
+ op_params.output_offset = data->output_offset; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = data->output_shift; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
// The quantized version of Add doesn't support activations, so we
// always use BroadcastAdd.
if (kernel_type == kReference) {
- TF_LITE_ADD(reference_ops, BroadcastAdd);
+ TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow);
} else {
- TF_LITE_ADD(optimized_ops, BroadcastAdd);
+ TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow);
}
#undef TF_LITE_ADD
} else if (output->type == kTfLiteInt16) {
-#define TF_LITE_ADD(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- data->input1_shift, GetTensorData<int16_t>(input2), \
- GetTensorDims(input2), data->input2_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<int16_t>(output), GetTensorDims(output));
+#define TF_LITE_ADD(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ op_params.input1_shift = data->input1_shift; \
+ op_params.input2_shift = data->input2_shift; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<int16_t>(output))
// The quantized version of Add doesn't support activations, so we
// always use BroadcastAdd.
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index bc5c3783fd..d7420ddd8e 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -78,29 +78,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
template <KernelType kernel_type>
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteDivParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_DIV(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_DIV(reference_ops, BroadcastDiv);
+void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_DIV(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
+ GetTensorData<data_type>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<data_type>(output), GetTensorDims(output))
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(reference_ops, BroadcastDiv, int32_t);
+ } else {
+ TF_LITE_DIV(reference_ops, Div, int32_t);
+ }
} else {
- TF_LITE_DIV(reference_ops, Div);
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(optimized_ops, BroadcastDiv, int32_t);
+ } else {
+ TF_LITE_DIV(optimized_ops, Div, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_DIV(optimized_ops, BroadcastDiv);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(reference_ops, BroadcastDiv, float);
+ } else {
+ TF_LITE_DIV(reference_ops, Div, float);
+ }
} else {
- TF_LITE_DIV(optimized_ops, Div);
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(optimized_ops, BroadcastDiv, float);
+ } else {
+ TF_LITE_DIV(optimized_ops, Div, float);
+ }
}
}
#undef TF_LITE_DIV
@@ -115,11 +130,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
} else {
context->ReportError(
- context, "Div only supports FLOAT32 and quantized UINT8 now, got %d.",
+ context,
+ "Div only supports FLOAT32, INT32 and quantized UINT8 now, got %d.",
output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/div_test.cc b/tensorflow/contrib/lite/kernels/div_test.cc
index 276b8289fb..97aa2fe04e 100644
--- a/tensorflow/contrib/lite/kernels/div_test.cc
+++ b/tensorflow/contrib/lite/kernels/div_test.cc
@@ -52,6 +52,13 @@ class FloatDivOpModel : public BaseDivOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
+class IntegerDivOpModel : public BaseDivOpModel {
+ public:
+ using BaseDivOpModel::BaseDivOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
TEST(FloatDivOpTest, NoActivation) {
FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -75,7 +82,7 @@ TEST(FloatDivOpTest, ActivationRELU_N1_TO_1) {
}
TEST(FloatDivOpTest, VariousInputShapes) {
- std::vector<std::initializer_list<int>> test_shapes = {
+ std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]},
@@ -92,7 +99,7 @@ TEST(FloatDivOpTest, VariousInputShapes) {
}
TEST(FloatDivOpTest, WithBroadcast) {
- std::vector<std::initializer_list<int>> test_shapes = {
+ std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]},
@@ -108,6 +115,56 @@ TEST(FloatDivOpTest, WithBroadcast) {
}
}
+TEST(IntegerDivOpTest, NoActivation) {
+ IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-2, 2, -15, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {5, -2, -3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, -1, 5, 1}));
+}
+
+TEST(IntegerDivOpTest, ActivationRELU_N1_TO_1) {
+ IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_RELU_N1_TO_1);
+ m.PopulateTensor<int32_t>(m.input1(), {-2, 2, -12, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, -15, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 0, 1}));
+}
+
+TEST(IntegerDivOpTest, VariousInputShapes) {
+ std::vector<std::vector<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerDivOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 3, 8, 11, -20});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 6, 5, -11, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 1, 0, 1, -1, 20}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(IntegerDivOpTest, WithBroadcast) {
+ std::vector<std::vector<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerDivOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, // always a scalar
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 21, 7, 8, 11, -123});
+ m.PopulateTensor<int32_t>(m.input2(), {3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-6, 7, 2, 2, 3, -41}))
+ << "With shape number " << i;
+ }
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index b86ca49c11..310a8980e6 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -127,6 +127,139 @@ int CountLeadingZeros(T integer_input) {
return leading_zeros;
}
+// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// BROADCASTING.
+//
+// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
+// rectangular array of numbers.
+//
+// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
+// However, as Dims<N> is to be deprecated, this class exists as an adaptor
+// to enable simple unoptimized implementations of element-wise broadcasting
+// operations.
+template <int N>
+struct NdArrayDesc {
+ // The "extent" of each dimension. Indices along dimension d must be in the
+ // half-open interval [0, extents[d]).
+ int extents[N];
+
+ // The number of *elements* (not bytes) between consecutive indices of each
+ // dimension.
+ int strides[N];
+};
+
+// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// BROADCASTING.
+//
+// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
+inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
+ int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
+ return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
+ i3 * desc.strides[3];
+}
+
+// Given the dimensions of the operands for an element-wise binary broadcast,
+// adjusts them so that they can be directly iterated over with simple loops.
+// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
+// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
+//
+// This function assumes that the two input shapes are compatible up to
+// broadcasting and the shorter one has already been prepended with 1s to be the
+// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
+// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
+// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
+// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
+//
+// When two shapes are compatible up to broadcasting, for each dimension d,
+// the input extents are either equal, or one of them is 1.
+//
+// This function performs the following for each dimension d:
+// - If the extents are equal, then do nothing since the loop that walks over
+// both of the input arrays is correct.
+// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
+// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
+// array0 to be referenced *at any index* in dimension d and still access the
+// same slice.
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
+ const Dims<N>& input1_dims,
+ NdArrayDesc<N>* desc0_out,
+ NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ // Copy dims to desc.
+ for (int i = 0; i < N; ++i) {
+ desc0_out->extents[i] = input0_dims.sizes[i];
+ desc0_out->strides[i] = input0_dims.strides[i];
+ desc1_out->extents[i] = input1_dims.sizes[i];
+ desc1_out->strides[i] = input1_dims.strides[i];
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = ArraySize(input0_dims, i);
+ const int extent1 = ArraySize(input1_dims, i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
+}
+
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(
+ const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
+ NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
+ auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
+
+ // Copy dims to desc, calculating strides.
+ int desc0_stride = 1;
+ int desc1_stride = 1;
+ for (int i = N - 1; i >= 0; --i) {
+ desc0_out->extents[i] = extended_input0_shape.Dims(i);
+ desc0_out->strides[i] = desc0_stride;
+ desc0_stride *= extended_input0_shape.Dims(i);
+ desc1_out->extents[i] = extended_input1_shape.Dims(i);
+ desc1_out->strides[i] = desc1_stride;
+ desc1_stride *= extended_input1_shape.Dims(i);
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = extended_input0_shape.Dims(i);
+ const int extent1 = extended_input1_shape.Dims(i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
+}
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index 6db41d7961..d5503073a7 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -55,6 +55,245 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<int32>::min();
+ op_params.quantized_activation_max = std::numeric_limits<int32>::max();
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAddFivefold(
+ int y0, int y1, int y2, int y3, int y4, int left_shift,
+ const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ tflite::ArithmeticParams op_params;
+ op_params.broadcast_category =
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.broadcast_shape[4] = y0;
+ op_params.broadcast_shape[3] = y1;
+ op_params.broadcast_shape[2] = y2;
+ op_params.broadcast_shape[1] = y3;
+ op_params.broadcast_shape[0] = y4;
+ BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
+ int input1_shift, const int16* input2_data,
+ const Dims<4>& input2_dims, int input2_shift,
+ int16 output_activation_min, int16 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, -32768);
+ TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, int kwidth, int kheight,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 2f73036e03..78567d52ea 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -42,10 +42,12 @@ namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
+using reference_ops::BroadcastAdd4DSlow;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
using reference_ops::BroadcastLess;
using reference_ops::BroadcastLessEqual;
+using reference_ops::BroadcastSub4DSlow;
using reference_ops::Concatenation;
using reference_ops::DepthConcatenation;
using reference_ops::Dequantize;
@@ -217,98 +219,6 @@ SaturatingRoundingMultiplyByPOTParam(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
-// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
-// BROADCASTING.
-//
-// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
-// rectangular array of numbers.
-//
-// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
-// However, as Dims<N> is to be deprecated, this class exists as an adaptor
-// to enable simple unoptimized implementations of element-wise broadcasting
-// operations.
-template <int N>
-struct NdArrayDesc {
- // The "extent" of each dimension. Indices along dimension d must be in the
- // half-open interval [0, extents[d]).
- int extents[N];
-
- // The number of *elements* (not bytes) between consecutive indices of each
- // dimension.
- int strides[N];
-};
-
-// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
-// ELEMENT-WISE BROADCASTING.
-//
-// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
-inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
- int i3) {
- TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
- TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
- TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
- TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
- return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
- i3 * desc.strides[3];
-}
-
-// Given the dimensions of the operands for an element-wise binary broadcast,
-// adjusts them so that they can be directly iterated over with simple loops.
-// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
-// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
-//
-// This function assumes that the two input shapes are compatible up to
-// broadcasting and the shorter one has already been prepended with 1s to be the
-// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
-// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
-// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
-// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
-//
-// When two shapes are compatible up to broadcasting, for each dimension d,
-// the input extents are either equal, or one of them is 1.
-//
-// This function performs the following for each dimension d:
-// - If the extents are equal, then do nothing since the loop that walks over
-// both of the input arrays is correct.
-// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
-// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
-// array0 to be referenced *at any index* in dimension d and still access the
-// same slice.
-template <int N>
-inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
- const Dims<N>& input1_dims,
- NdArrayDesc<N>* desc0_out,
- NdArrayDesc<N>* desc1_out) {
- TFLITE_DCHECK(desc0_out != nullptr);
- TFLITE_DCHECK(desc1_out != nullptr);
-
- // Copy dims to desc.
- for (int i = 0; i < N; ++i) {
- desc0_out->extents[i] = input0_dims.sizes[i];
- desc0_out->strides[i] = input0_dims.strides[i];
- desc1_out->extents[i] = input1_dims.sizes[i];
- desc1_out->strides[i] = input1_dims.strides[i];
- }
-
- // Walk over each dimension. If the extents are equal do nothing.
- // Otherwise, set the desc with extent 1 to have extent equal to the other and
- // stride 0.
- for (int i = 0; i < N; ++i) {
- const int extent0 = ArraySize(input0_dims, i);
- const int extent1 = ArraySize(input1_dims, i);
- if (extent0 != extent1) {
- if (extent0 == 1) {
- desc0_out->strides[i] = 0;
- desc0_out->extents[i] = extent1;
- } else {
- TFLITE_DCHECK_EQ(extent1, 1);
- desc1_out->strides[i] = 0;
- desc1_out->extents[i] = extent0;
- }
- }
- }
-}
-
inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
for (int i = 0; i < 4; i++) {
if (dims1.sizes[i] != dims2.sizes[i]) {
@@ -2478,20 +2388,17 @@ inline void L2Normalization(const uint8* input_data,
}
}
-inline void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const float* input1_data,
+ const RuntimeShape& input2_shape, const float* input2_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Add");
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
int i = 0;
- const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
#ifdef USE_NEON
- const auto activation_min = vdupq_n_f32(output_activation_min);
- const auto activation_max = vdupq_n_f32(output_activation_max);
+ const auto activation_min = vdupq_n_f32(params.float_activation_min);
+ const auto activation_max = vdupq_n_f32(params.float_activation_max);
for (; i <= size - 16; i += 16) {
auto a10 = vld1q_f32(input1_data + i);
auto a11 = vld1q_f32(input1_data + i + 4);
@@ -2530,29 +2437,26 @@ inline void Add(const float* input1_data, const Dims<4>& input1_dims,
for (; i < size; i++) {
auto x = input1_data[i] + input2_data[i];
- output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
- output_activation_max);
+ output_data[i] = ActivationFunctionWithMinMax(
+ x, params.float_activation_min, params.float_activation_max);
}
}
// Element-wise add that can often be used for inner loop of broadcast add as
// well as the non-broadcast add.
-inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
- int32 input1_offset, int32 input1_multiplier,
- int input1_shift, const uint8* input2_data,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data) {
+inline void AddElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
int i = 0;
- TFLITE_DCHECK_GT(input1_offset, -256);
- TFLITE_DCHECK_GT(input2_offset, -256);
- TFLITE_DCHECK_LT(input1_offset, 256);
- TFLITE_DCHECK_LT(input2_offset, 256);
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
#ifdef USE_NEON
- const auto output_activation_min_vector = vdup_n_u8(output_activation_min);
- const auto output_activation_max_vector = vdup_n_u8(output_activation_max);
+ const auto output_activation_min_vector =
+ vdup_n_u8(params.quantized_activation_min);
+ const auto output_activation_max_vector =
+ vdup_n_u8(params.quantized_activation_max);
for (; i <= size - 8; i += 8) {
const auto input1_val_original = vld1_u8(input1_data + i);
const auto input2_val_original = vld1_u8(input2_data + i);
@@ -2561,9 +2465,9 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
const auto input2_val_s16 =
vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
const auto input1_val =
- vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset));
+ vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
const auto input2_val =
- vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset));
+ vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
const auto input1_val_high = vget_high_s16(input1_val);
const auto input1_val_low = vget_low_s16(input1_val);
const auto input2_val_high = vget_high_s16(input2_val);
@@ -2572,32 +2476,32 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
auto x12 = vmovl_s16(input1_val_high);
auto x21 = vmovl_s16(input2_val_low);
auto x22 = vmovl_s16(input2_val_high);
- const auto left_shift_dup = vdupq_n_s32(left_shift);
+ const auto left_shift_dup = vdupq_n_s32(params.left_shift);
x11 = vshlq_s32(x11, left_shift_dup);
x12 = vshlq_s32(x12, left_shift_dup);
x21 = vshlq_s32(x21, left_shift_dup);
x22 = vshlq_s32(x22, left_shift_dup);
- x11 = vqrdmulhq_n_s32(x11, input1_multiplier);
- x12 = vqrdmulhq_n_s32(x12, input1_multiplier);
- x21 = vqrdmulhq_n_s32(x21, input2_multiplier);
- x22 = vqrdmulhq_n_s32(x22, input2_multiplier);
- const auto input1_shift_dup = vdupq_n_s32(-input1_shift);
- const auto input2_shift_dup = vdupq_n_s32(-input2_shift);
+ x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
+ x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
+ x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
+ x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
+ const auto input1_shift_dup = vdupq_n_s32(params.input1_shift);
+ const auto input2_shift_dup = vdupq_n_s32(params.input2_shift);
x11 = vshlq_s32(x11, input1_shift_dup);
x12 = vshlq_s32(x12, input1_shift_dup);
x21 = vshlq_s32(x21, input2_shift_dup);
x22 = vshlq_s32(x22, input2_shift_dup);
auto s1 = vaddq_s32(x11, x21);
auto s2 = vaddq_s32(x12, x22);
- s1 = vqrdmulhq_n_s32(s1, output_multiplier);
- s2 = vqrdmulhq_n_s32(s2, output_multiplier);
+ s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
+ s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
using gemmlowp::RoundingDivideByPOT;
- s1 = RoundingDivideByPOT(s1, output_shift);
- s2 = RoundingDivideByPOT(s2, output_shift);
+ s1 = RoundingDivideByPOT(s1, -params.output_shift);
+ s2 = RoundingDivideByPOT(s2, -params.output_shift);
const auto s1_narrowed = vmovn_s32(s1);
const auto s2_narrowed = vmovn_s32(s2);
const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
- vdupq_n_s16(output_offset));
+ vdupq_n_s16(params.output_offset));
const auto clamped =
vmax_u8(output_activation_min_vector,
vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
@@ -2606,101 +2510,74 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
#endif // NEON
for (; i < size; ++i) {
- const int32 input1_val = input1_offset + input1_data[i];
- const int32 input2_val = input2_offset + input2_data[i];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, params.input1_multiplier, params.input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
+ shifted_input2_val, params.input2_multiplier, params.input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
const int32 raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output = std::min(
- output_activation_max, std::max(output_activation_min, raw_output));
+ raw_sum, params.output_multiplier, params.output_shift) +
+ params.output_offset;
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
output_data[i] = static_cast<uint8>(clamped_output);
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void Add(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier, int input2_shift,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("Add/8bit");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-
- TFLITE_DCHECK_GT(input1_offset, -256);
- TFLITE_DCHECK_GT(input2_offset, -256);
- TFLITE_DCHECK_LT(input1_offset, 256);
- TFLITE_DCHECK_LT(input2_offset, 256);
- AddElementwise(flat_size, left_shift, input1_data, input1_offset,
- input1_multiplier, input1_shift, input2_data, input2_offset,
- input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ AddElementwise(flat_size, params, input1_data, input2_data, output_data);
}
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Add/Int16");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int input1_shift = params.input1_shift;
+ const int flat_size =
+ MatchingFlatSize(output_shape, input1_shape, input2_shape);
+ const int16 output_activation_min = params.quantized_activation_min;
+ const int16 output_activation_max = params.quantized_activation_max;
- TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0);
- TFLITE_DCHECK_GE(input1_shift, 0);
- TFLITE_DCHECK_GE(input2_shift, 0);
+ TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
+ TFLITE_DCHECK_LE(input1_shift, 0);
+ TFLITE_DCHECK_LE(params.input2_shift, 0);
const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
- const int input_shift = input1_shift == 0 ? input2_shift : input1_shift;
+ const int input_right_shift =
+ input1_shift == 0 ? -params.input2_shift : -input1_shift;
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
- F0 scaled_input =
- F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift));
+ F0 scaled_input = F0::FromRaw(
+ gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
const int16 raw_output = result.raw();
const int16 clamped_output = std::min(
@@ -2709,195 +2586,59 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Add(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32 output_activation_min, int32 output_activation_max,
- int32* output_data, const Dims<4>& output_dims) {
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int32* input1_data,
+ const RuntimeShape& input2_shape, const int32* input2_data,
+ const RuntimeShape& output_shape, int32* output_data) {
gemmlowp::ScopedProfilingLabel label("Add/int32");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] + input2_data[i], output_activation_min,
- output_activation_max);
- }
-}
-
-template <FusedActivationFunctionType Ac>
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, -32768);
- TFLITE_DCHECK_EQ(output_activation_max, 32767);
- }
-
- Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims,
- input2_shift, output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-void Add(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Add/int32");
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
-
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto input2_map = MapAsVector(input2_data, input2_dims);
- auto output_map = MapAsVector(output_data, output_dims);
- if (AreSameDims(input1_dims, input2_dims)) {
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ if (input1_shape == input2_shape) {
output_map.array() = input1_map.array() + input2_map.array();
- } else if (FlatSize(input2_dims) == 1) {
+ } else if (input2_shape.FlatSize() == 1) {
auto scalar = input2_data[0];
output_map.array() = input1_map.array() + scalar;
- } else if (FlatSize(input1_dims) == 1) {
+ } else if (input1_shape.FlatSize() == 1) {
auto scalar = input1_data[0];
output_map.array() = scalar + input2_map.array();
} else {
// Should not come here.
TFLITE_DCHECK(false);
}
+ output_map = output_map.cwiseMax(params.quantized_activation_min);
+ output_map = output_map.cwiseMin(params.quantized_activation_max);
}
-// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-template <typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
- }
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAddGeneric/8bit");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
- }
-}
-
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
+inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit");
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input1_multiplier = unswitched_params.input2_multiplier;
+ switched_params.input1_shift = unswitched_params.input2_shift;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+ switched_params.input2_multiplier = unswitched_params.input1_multiplier;
+ switched_params.input2_shift = unswitched_params.input1_shift;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
// Fivefold nested loops. The second input resets its position for each
// iteration of the second loop. The first input resets its position at the
// beginning of the fourth loop. The innermost loop is an elementwise add of
@@ -2905,82 +2646,29 @@ inline void BroadcastAddFivefold(
uint8* output_data_ptr = output_data;
const uint8* input1_data_ptr = input1_data;
const uint8* input2_data_reset = input2_data;
- for (int i4 = 0; i4 < y4; ++i4) {
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
+ for (int i0 = 0; i0 < y0; ++i0) {
const uint8* input2_data_ptr;
- for (int i3 = 0; i3 < y3; ++i3) {
+ for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
- for (int i1 = 0; i1 < y1; ++i1) {
- AddElementwise(
- y0, left_shift, input1_data_ptr, input1_offset, input1_multiplier,
- input1_shift, input2_data_ptr, input2_offset, input2_multiplier,
- input2_shift, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data_ptr);
- input2_data_ptr += y0;
- output_data_ptr += y0;
+ for (int i3 = 0; i3 < y3; ++i3) {
+ AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
}
- input1_data_ptr += y0;
+ input1_data_ptr += y4;
}
}
input2_data_reset = input2_data_ptr;
}
}
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
- input1_multiplier, input1_shift, input2_data, input2_dims,
- input2_offset, input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims,
- input1_offset, input1_multiplier, input1_shift,
- input2_data, input2_dims, input2_offset,
- input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
const float* input2_data, const Dims<4>& input2_dims,
float output_activation_min, float output_activation_max,
@@ -3305,135 +2993,78 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
// TODO(aselle): This is not actually optimized yet.
-inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Sub");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void SubNonBroadcast(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("SubNonBroadcast");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] - input2_data[i], output_activation_min,
- output_activation_max);
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
}
}
-inline void Sub(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32 output_activation_min, int32 output_activation_max,
- int32* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Sub/int32");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, input2_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] - input2_data[i], output_activation_min,
- output_activation_max);
+ input1_data[i] - input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
}
}
-// TODO(jiawen): We can implement BroadcastSub on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastSub is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-template <typename T>
-void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("SubWithActivation/float");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
}
}
-inline void BroadcastSub(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+template <typename T>
+void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Sub");
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- const int32 raw_sub = scaled_input1_val - scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sub, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ if (input1_shape == input2_shape) {
+ output_map.array() = input1_map.array() - input2_map.array();
+ } else if (input1_shape.FlatSize() == 1) {
+ auto scalar = input1_data[0];
+ output_map.array() = scalar - input2_map.array();
+ } else if (input2_shape.FlatSize() == 1) {
+ auto scalar = input2_data[0];
+ output_map.array() = input1_map.array() - scalar;
+ } else {
+ BroadcastSub4DSlow(params, input1_shape, input1_data, input2_shape,
+ input2_data, output_shape, output_data);
}
}
@@ -5876,63 +5507,6 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- }
- }
- }
- }
-}
-
-template <typename T>
-void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
- const Dims<4>& input2_dims, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Sub");
-
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto input2_map = MapAsVector(input2_data, input2_dims);
- auto output_map = MapAsVector(output_data, output_dims);
- if (AreSameDims(input1_dims, input2_dims)) {
- output_map.array() = input1_map.array() - input2_map.array();
- } else if (FlatSize(input1_dims) == 1) {
- auto scalar = input1_data[0];
- output_map.array() = scalar - input2_map.array();
- } else if (FlatSize(input2_dims) == 1) {
- auto scalar = input2_data[0];
- output_map.array() = input1_map.array() - scalar;
- } else {
- GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims,
- output_data, output_dims);
- }
-}
-
-template <typename T>
void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, T* output_data,
const Dims<4>& output_dims) {
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index f715d34bc1..bcf5e4e4f6 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -63,6 +63,240 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<int32>::min();
+ op_params.quantized_activation_max = std::numeric_limits<int32>::max();
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAddFivefold(
+ int y0, int y1, int y2, int y3, int y4, int left_shift,
+ const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ tflite::ArithmeticParams op_params;
+ op_params.broadcast_category =
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.broadcast_shape[4] = y0;
+ op_params.broadcast_shape[3] = y1;
+ op_params.broadcast_shape[2] = y2;
+ op_params.broadcast_shape[1] = y3;
+ op_params.broadcast_shape[0] = y4;
+ BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
+ int input1_shift, const int16* input2_data,
+ const Dims<4>& input2_dims, int input2_shift,
+ int16 output_activation_min, int16 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, -32768);
+ TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<T>::min();
+ op_params.quantized_activation_max = std::numeric_limits<T>::max();
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, int kwidth, int kheight,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 6fabb9c268..10e23f0b41 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -158,98 +158,6 @@ SaturatingRoundingMultiplyByPOTParam(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
-// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
-// BROADCASTING.
-//
-// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
-// rectangular array of numbers.
-//
-// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
-// However, as Dims<N> is to be deprecated, this class exists as an adaptor
-// to enable simple unoptimized implementations of element-wise broadcasting
-// operations.
-template <int N>
-struct NdArrayDesc {
- // The "extent" of each dimension. Indices along dimension d must be in the
- // half-open interval [0, extents[d]).
- int extents[N];
-
- // The number of *elements* (not bytes) between consecutive indices of each
- // dimension.
- int strides[N];
-};
-
-// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
-// ELEMENT-WISE BROADCASTING.
-//
-// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
-inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
- int i3) {
- TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
- TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
- TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
- TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
- return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
- i3 * desc.strides[3];
-}
-
-// Given the dimensions of the operands for an element-wise binary broadcast,
-// adjusts them so that they can be directly iterated over with simple loops.
-// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
-// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
-//
-// This function assumes that the two input shapes are compatible up to
-// broadcasting and the shorter one has already been prepended with 1s to be the
-// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
-// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
-// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
-// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
-//
-// When two shapes are compatible up to broadcasting, for each dimension d,
-// the input extents are either equal, or one of them is 1.
-//
-// This function performs the following for each dimension d:
-// - If the extents are equal, then do nothing since the loop that walks over
-// both of the input arrays is correct.
-// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
-// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
-// array0 to be referenced *at any index* in dimension d and still access the
-// same slice.
-template <int N>
-inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
- const Dims<N>& input1_dims,
- NdArrayDesc<N>* desc0_out,
- NdArrayDesc<N>* desc1_out) {
- TFLITE_DCHECK(desc0_out != nullptr);
- TFLITE_DCHECK(desc1_out != nullptr);
-
- // Copy dims to desc.
- for (int i = 0; i < N; ++i) {
- desc0_out->extents[i] = input0_dims.sizes[i];
- desc0_out->strides[i] = input0_dims.strides[i];
- desc1_out->extents[i] = input1_dims.sizes[i];
- desc1_out->strides[i] = input1_dims.strides[i];
- }
-
- // Walk over each dimension. If the extents are equal do nothing.
- // Otherwise, set the desc with extent 1 to have extent equal to the other and
- // stride 0.
- for (int i = 0; i < N; ++i) {
- const int extent0 = ArraySize(input0_dims, i);
- const int extent1 = ArraySize(input1_dims, i);
- if (extent0 != extent1) {
- if (extent0 == 1) {
- desc0_out->strides[i] = 0;
- desc0_out->extents[i] = extent1;
- } else {
- TFLITE_DCHECK_EQ(extent1, 1);
- desc1_out->strides[i] = 0;
- desc1_out->extents[i] = extent0;
- }
- }
- }
-}
-
inline void Conv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
@@ -1065,114 +973,108 @@ inline void L2Normalization(const uint8* input_data,
}
template <typename T>
-inline void Add(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] + input2_data[i], output_activation_min,
- output_activation_max);
+ input1_data[i] + input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const float* input1_data,
+ const RuntimeShape& input2_shape, const float* input2_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < size; i++) {
+ auto x = input1_data[i] + input2_data[i];
+ output_data[i] = ActivationFunctionWithMinMax(
+ x, params.float_activation_min, params.float_activation_max);
+ }
}
-template <FusedActivationFunctionType Ac>
-inline void Add(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier, int input2_shift,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- const int batches =
- MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
- const int height =
- MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
- const int width =
- MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
- const int depth =
- MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- const int32 input1_val =
- input1_offset + input1_data[Offset(input1_dims, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[Offset(input2_dims, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
+// Element-wise add that can often be used for inner loop of broadcast add as
+// well as the non-broadcast add.
+inline void AddElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+
+ for (int i = 0; i < size; ++i) {
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, params.input1_multiplier, params.input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, params.input2_multiplier, params.input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, params.output_multiplier, params.output_shift) +
+ params.output_offset;
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
+ output_data[i] = static_cast<uint8>(clamped_output);
}
}
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ AddElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
- TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0);
- TFLITE_DCHECK_GE(input1_shift, 0);
- TFLITE_DCHECK_GE(input2_shift, 0);
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+
+ const int input1_shift = params.input1_shift;
+ const int flat_size =
+ MatchingFlatSize(output_shape, input1_shape, input2_shape);
+ const int16 output_activation_min = params.quantized_activation_min;
+ const int16 output_activation_max = params.quantized_activation_max;
+
+ TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
+ TFLITE_DCHECK_LE(input1_shift, 0);
+ TFLITE_DCHECK_LE(params.input2_shift, 0);
const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
- const int input_shift = input1_shift == 0 ? input2_shift : input1_shift;
+ const int input_right_shift =
+ input1_shift == 0 ? -params.input2_shift : -input1_shift;
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
- F0 scaled_input =
- F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift));
+ F0 scaled_input = F0::FromRaw(
+ gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
const int16 raw_output = result.raw();
const int16 clamped_output = std::min(
@@ -1181,42 +1083,28 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-template <FusedActivationFunctionType Ac>
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, -32768);
- TFLITE_DCHECK_EQ(output_activation_max, 32767);
- }
-
- Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims,
- input2_shift, output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
-template <typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
-
+// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1229,49 +1117,77 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.float_activation_min, params.float_activation_max);
}
}
}
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
- BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+ }
+ }
+ }
}
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit");
-
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const uint8* input1_data,
+ const RuntimeShape& input2_shape,
+ const uint8* input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1284,33 +1200,37 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ params.input1_offset +
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ params.input2_offset +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ const int32 shifted_input1_val =
+ input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val =
+ input2_val * (1 << params.left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, params.input1_multiplier,
+ params.input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
+ shifted_input2_val, params.input2_multiplier,
+ params.input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
const int32 raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
+ raw_sum, params.output_multiplier, params.output_shift) +
+ params.output_offset;
const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
static_cast<uint8>(clamped_output);
}
}
@@ -1318,117 +1238,62 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data,
}
}
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit");
-
- int sb1 = y0;
- int sa2 = y0;
- int sb2 = y0 * y1;
- int sa3 = y0 * y2;
- int sa4 = y0 * y2 * y3;
- int sb4 = y0 * y1 * y2;
-
+inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input1_multiplier = unswitched_params.input2_multiplier;
+ switched_params.input1_shift = unswitched_params.input2_shift;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+ switched_params.input2_multiplier = unswitched_params.input1_multiplier;
+ switched_params.input2_shift = unswitched_params.input1_shift;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+ // Fivefold nested loops. The second input resets its position for each
+ // iteration of the second loop. The first input resets its position at the
+ // beginning of the fourth loop. The innermost loop is an elementwise add of
+ // sections of the arrays.
uint8* output_data_ptr = output_data;
- for (int i4 = 0; i4 < y4; ++i4) {
- for (int i3 = 0; i3 < y3; ++i3) {
+ const uint8* input1_data_ptr = input1_data;
+ const uint8* input2_data_reset = input2_data;
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
- for (int i1 = 0; i1 < y1; ++i1) {
- for (int i0 = 0; i0 < y0; ++i0) {
- const int32 input1_val =
- input1_offset +
- input1_data[i4 * sa4 + i3 * sa3 + i2 * sa2 + i0];
- const int32 input2_val =
- input2_offset +
- input2_data[i4 * sb4 + i2 * sb2 + i1 * sb1 + i0];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sum, output_multiplier, kReverseShift * output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- *output_data_ptr = static_cast<uint8>(clamped_output);
- ++output_data_ptr;
- }
+ for (int i3 = 0; i3 < y3; ++i3) {
+ AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
}
+ input1_data_ptr += y4;
}
}
+ input2_data_reset = input2_data_ptr;
}
}
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
- input1_multiplier, input1_shift, input2_data, input2_dims,
- input2_offset, input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims,
- input1_offset, input1_multiplier, input1_shift,
- input2_data, input2_dims, input2_offset,
- input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
template <typename T>
inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
@@ -1654,10 +1519,11 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Div(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+template <typename T>
+inline void Div(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
@@ -1666,16 +1532,35 @@ inline void Div(const float* input1_data, const Dims<4>& input1_dims,
}
}
-template <typename T>
-inline void Sub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void SubNonBroadcast(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] - input2_data[i], output_activation_min,
- output_activation_max);
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
+ }
+}
+
+inline void SubNonBroadcast(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
}
}
@@ -1683,16 +1568,24 @@ inline void Sub(const T* input1_data, const Dims<4>& input1_dims,
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
-template <typename T>
-void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub");
-
+// TODO(benoitjacob): BroadcastSub is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+inline void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1705,36 +1598,35 @@ void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.float_activation_min, params.float_activation_max);
}
}
}
}
}
-inline void BroadcastSub(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit");
-
+inline void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const uint8* input1_data,
+ const RuntimeShape& input2_shape,
+ const uint8* input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1747,33 +1639,37 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ params.input1_offset +
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ params.input2_offset +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ const int32 shifted_input1_val =
+ input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val =
+ input2_val * (1 << params.left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, params.input1_multiplier,
+ params.input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
+ shifted_input2_val, params.input2_multiplier,
+ params.input2_shift);
const int32 raw_sub = scaled_input1_val - scaled_input2_val;
const int32 raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- raw_sub, output_multiplier, kReverseShift * output_shift) +
- output_offset;
+ raw_sub, params.output_multiplier, params.output_shift) +
+ params.output_offset;
const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
static_cast<uint8>(clamped_output);
}
}
@@ -1781,6 +1677,156 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data,
}
}
+inline void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ }
+ }
+ }
+ }
+}
+
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+}
+
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, input2_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
+ }
+}
+
template <FusedActivationFunctionType Ac, typename Scalar>
void Concatenation(int concat_dim, const Scalar* const* input_data,
const Dims<4>* const* input_dims, int inputs_count,
@@ -3717,38 +3763,6 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
- const Dims<4>& input2_dims, T* output_data,
- const Dims<4>& output_dims) {
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- }
- }
- }
- }
-}
-
-template <typename T>
void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, T* output_data,
const Dims<4>& output_dims) {
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 737cfb69c9..fe113dfdd3 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -119,6 +119,8 @@ class RuntimeShape {
// larger shapes are separately allocated.
static constexpr int kMaxSmallSize = 4;
+ RuntimeShape& operator=(RuntimeShape const&) = delete;
+
RuntimeShape() : size_(0) {}
explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
@@ -135,6 +137,20 @@ class RuntimeShape {
BuildFrom(init_list);
}
+ // Avoid using this constructor. We should be able to delete it when C++17
+ // rolls out.
+ RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
+ if (size_ > kMaxSmallSize) {
+ dims_pointer_ = new int32[size_];
+ }
+ std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
+ }
+
+ bool operator==(const RuntimeShape& comp) const {
+ return this->size_ == comp.size_ &&
+ std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
+ }
+
~RuntimeShape() {
if (size_ > kMaxSmallSize) {
delete[] dims_pointer_;
@@ -191,6 +207,16 @@ class RuntimeShape {
}
}
+ // This will probably be factored out. Old code made substantial use of 4-D
+ // shapes, and so this function is used to extend smaller shapes. Note that
+ // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
+ // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
+ // inputs should already be 4-D, so this function should not be needed.
+ inline static RuntimeShape ExtendedShape(int new_shape_size,
+ const RuntimeShape& shape) {
+ return RuntimeShape(new_shape_size, shape, 1);
+ }
+
inline void BuildFrom(const std::initializer_list<int> init_list) {
BuildFrom<const std::initializer_list<int>>(init_list);
}
@@ -208,7 +234,25 @@ class RuntimeShape {
return buffer_size;
}
+ bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }
+
private:
+ // For use only by ExtendFrom(), written to guarantee (return-value) copy
+ // elision in C++17.
+ // This creates a shape padded to the desired size with the specified value.
+ RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
+ : size_(0) {
+ TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
+ TFLITE_CHECK_LE(new_shape_size, kMaxSmallSize);
+ Resize(new_shape_size);
+ const int size_increase = new_shape_size - shape.DimensionsCount();
+ for (int i = 0; i < size_increase; ++i) {
+ SetDim(i, pad_value);
+ }
+ std::memcpy(DimsData() + size_increase, shape.DimsData(),
+ sizeof(int32) * shape.DimensionsCount());
+ }
+
int32 size_;
union {
int32 dims_[kMaxSmallSize];
@@ -364,6 +408,7 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -374,6 +419,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape,
inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0,
const RuntimeShape& check_shape_1) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -385,6 +431,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0,
const RuntimeShape& check_shape_1,
const RuntimeShape& check_shape_2) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -397,6 +444,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_1,
const RuntimeShape& check_shape_2,
const RuntimeShape& check_shape_3) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -601,14 +649,74 @@ struct PoolParams {
int stride_width;
int filter_height;
int filter_width;
- // uint8, etc, inference params.
+ // uint8, etc, activation params.
int32 quantized_activation_min;
int32 quantized_activation_max;
- // float inference params.
+ // float activation params.
float float_activation_min;
float float_activation_max;
};
+enum class BroadcastableOpCategory : uint8 {
+ kNone,
+ kNonBroadcast, // Matching input shapes.
+ kFirstInputBroadcastsFast, // Fivefold nested loops.
+ kSecondInputBroadcastsFast, // Fivefold nested loops.
+ kGenericBroadcast, // Fall-back.
+};
+
+// For Add, Sub, Mul ops.
+struct ArithmeticParams {
+ // Shape dependent / common to data / op types.
+ BroadcastableOpCategory broadcast_category;
+ // uint8 inference params.
+ int32 input1_offset;
+ int32 input2_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ // Add / Sub, not Mul, uint8 inference params.
+ int left_shift;
+ int32 input1_multiplier;
+ int input1_shift;
+ int32 input2_multiplier;
+ int input2_shift;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+
+ // Processed output dimensions.
+ // Let input "a" be the one that broadcasts in the faster-changing dimension.
+ // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
+ // {b0, b1, b2, b3, b4},
+ // broadcast_shape[4] = b0 = a0.
+ // broadcast_shape[3] = b1; a1 = 1.
+ // broadcast_shape[2] = b2 = a2.
+ // broadcast_shape[1] = a3; b3 = 1.
+ // broadcast_shape[0] = b4 = a4.
+ int broadcast_shape[5];
+};
+
+template <typename T>
+inline void SetActivationParams(T min, T max, ArithmeticParams* params);
+
+template <>
+inline void SetActivationParams(float min, float max,
+ ArithmeticParams* params) {
+ params->float_activation_min = min;
+ params->float_activation_max = max;
+}
+
+template <>
+inline void SetActivationParams(int32 min, int32 max,
+ ArithmeticParams* params) {
+ params->quantized_activation_min = min;
+ params->quantized_activation_max = max;
+}
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 541c85f756..77a1f59689 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -81,40 +81,43 @@ template <KernelType kernel_type>
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_SUB(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_SUB(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_SUB(reference_ops, BroadcastSub, int32_t);
+ TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, int32_t);
} else {
- TF_LITE_SUB(reference_ops, Sub, int32_t);
+ TF_LITE_SUB(reference_ops, SubWithActivation, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_SUB(optimized_ops, BroadcastSub, int32_t);
+ TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, int32_t);
} else {
- TF_LITE_SUB(optimized_ops, Sub, int32_t);
+ TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t);
}
}
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_SUB(reference_ops, BroadcastSub, float);
+ TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, float);
} else {
- TF_LITE_SUB(reference_ops, Sub, float);
+ TF_LITE_SUB(reference_ops, SubWithActivation, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_SUB(optimized_ops, BroadcastSub, float);
+ TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, float);
} else {
- TF_LITE_SUB(optimized_ops, Sub, float);
+ TF_LITE_SUB(optimized_ops, SubWithActivation, float);
}
}
}
@@ -143,36 +146,43 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int input1_shift;
QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier,
&input1_multiplier, &input1_shift);
- input1_shift *= -1;
int32 input2_multiplier;
int input2_shift;
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier,
&input2_multiplier, &input2_shift);
- input2_shift *= -1;
int32 output_multiplier;
int output_shift;
QuantizeMultiplierSmallerThanOneExp(real_output_multiplier,
&output_multiplier, &output_shift);
- output_shift *= -1;
int32 output_activation_min, output_activation_max;
CalculateActivationRangeUint8(params->activation, output,
&output_activation_min, &output_activation_max);
-#define TF_LITE_SUB(type, opname) \
- type::opname(left_shift, GetTensorData<uint8_t>(input1), \
- GetTensorDims(input1), input1_offset, input1_multiplier, \
- input1_shift, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, input2_multiplier, \
- input2_shift, output_offset, output_multiplier, output_shift, \
- output_activation_min, output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output));
+#define TF_LITE_SUB(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ op_params.left_shift = left_shift; \
+ op_params.input1_offset = input1_offset; \
+ op_params.input1_multiplier = input1_multiplier; \
+ op_params.input1_shift = input1_shift; \
+ op_params.input2_offset = input2_offset; \
+ op_params.input2_multiplier = input2_multiplier; \
+ op_params.input2_shift = input2_shift; \
+ op_params.output_offset = output_offset; \
+ op_params.output_multiplier = output_multiplier; \
+ op_params.output_shift = output_shift; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
// The quantized version of Sub doesn't support activations, so we
// always use BroadcastSub.
if (kernel_type == kReference) {
- TF_LITE_SUB(reference_ops, BroadcastSub);
+ TF_LITE_SUB(reference_ops, BroadcastSub4DSlow);
} else {
- TF_LITE_SUB(optimized_ops, BroadcastSub);
+ TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow);
}
#undef TF_LITE_SUB
}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 6c1ba3694a..5e6106a87e 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -764,6 +764,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_TRANSPOSE:
case BuiltinOperator_POW:
+ case BuiltinOperator_PACK:
break;
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 5950840e8a..710ce1632e 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -614,6 +614,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_SHAPE:
case tflite::BuiltinOperator_POW:
case tflite::BuiltinOperator_FAKE_QUANT:
+ case tflite::BuiltinOperator_PACK:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index 9bd1f4f76e..d17482e601 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -257,7 +257,7 @@ def run_main(_):
parser.add_argument(
"--input_arrays",
type=str,
- help="Names of the output arrays, comma-separated.")
+ help="Names of the input arrays, comma-separated.")
parser.add_argument(
"--input_shapes",
type=str,
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 6c3189a884..0434199a08 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -164,6 +164,7 @@ enum BuiltinOperator : byte {
FAKE_QUANT = 80,
REDUCE_PROD = 81,
REDUCE_MAX = 82,
+ PACK = 83,
}
// Options for the builtin operators.
@@ -226,6 +227,7 @@ union BuiltinOptions {
PowOptions,
ArgMinOptions,
FakeQuantOptions,
+ PackOptions,
}
enum Padding : byte { SAME, VALID }
@@ -537,6 +539,11 @@ table FakeQuantOptions {
narrow_range:bool;
}
+table PackOptions {
+ values_count:int;
+ axis:int;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 8052404319..9b84030938 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -205,6 +205,9 @@ struct PowOptionsT;
struct FakeQuantOptions;
struct FakeQuantOptionsT;
+struct PackOptions;
+struct PackOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -353,11 +356,12 @@ enum BuiltinOperator {
BuiltinOperator_FAKE_QUANT = 80,
BuiltinOperator_REDUCE_PROD = 81,
BuiltinOperator_REDUCE_MAX = 82,
+ BuiltinOperator_PACK = 83,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_REDUCE_MAX
+ BuiltinOperator_MAX = BuiltinOperator_PACK
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[82] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[83] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -440,7 +444,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[82] {
BuiltinOperator_ARG_MIN,
BuiltinOperator_FAKE_QUANT,
BuiltinOperator_REDUCE_PROD,
- BuiltinOperator_REDUCE_MAX
+ BuiltinOperator_REDUCE_MAX,
+ BuiltinOperator_PACK
};
return values;
}
@@ -530,6 +535,7 @@ inline const char **EnumNamesBuiltinOperator() {
"FAKE_QUANT",
"REDUCE_PROD",
"REDUCE_MAX",
+ "PACK",
nullptr
};
return names;
@@ -600,11 +606,12 @@ enum BuiltinOptions {
BuiltinOptions_PowOptions = 56,
BuiltinOptions_ArgMinOptions = 57,
BuiltinOptions_FakeQuantOptions = 58,
+ BuiltinOptions_PackOptions = 59,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_FakeQuantOptions
+ BuiltinOptions_MAX = BuiltinOptions_PackOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[59] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[60] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -664,7 +671,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[59] {
BuiltinOptions_ShapeOptions,
BuiltinOptions_PowOptions,
BuiltinOptions_ArgMinOptions,
- BuiltinOptions_FakeQuantOptions
+ BuiltinOptions_FakeQuantOptions,
+ BuiltinOptions_PackOptions
};
return values;
}
@@ -730,6 +738,7 @@ inline const char **EnumNamesBuiltinOptions() {
"PowOptions",
"ArgMinOptions",
"FakeQuantOptions",
+ "PackOptions",
nullptr
};
return names;
@@ -976,6 +985,10 @@ template<> struct BuiltinOptionsTraits<FakeQuantOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions;
};
+template<> struct BuiltinOptionsTraits<PackOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_PackOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1471,6 +1484,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_FakeQuantOptions ?
reinterpret_cast<const FakeQuantOptionsT *>(value) : nullptr;
}
+ PackOptionsT *AsPackOptions() {
+ return type == BuiltinOptions_PackOptions ?
+ reinterpret_cast<PackOptionsT *>(value) : nullptr;
+ }
+ const PackOptionsT *AsPackOptions() const {
+ return type == BuiltinOptions_PackOptions ?
+ reinterpret_cast<const PackOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5304,6 +5325,72 @@ inline flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(
flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct PackOptionsT : public flatbuffers::NativeTable {
+ typedef PackOptions TableType;
+ int32_t values_count;
+ int32_t axis;
+ PackOptionsT()
+ : values_count(0),
+ axis(0) {
+ }
+};
+
+struct PackOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PackOptionsT NativeTableType;
+ enum {
+ VT_VALUES_COUNT = 4,
+ VT_AXIS = 6
+ };
+ int32_t values_count() const {
+ return GetField<int32_t>(VT_VALUES_COUNT, 0);
+ }
+ int32_t axis() const {
+ return GetField<int32_t>(VT_AXIS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_VALUES_COUNT) &&
+ VerifyField<int32_t>(verifier, VT_AXIS) &&
+ verifier.EndTable();
+ }
+ PackOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PackOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PackOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_values_count(int32_t values_count) {
+ fbb_.AddElement<int32_t>(PackOptions::VT_VALUES_COUNT, values_count, 0);
+ }
+ void add_axis(int32_t axis) {
+ fbb_.AddElement<int32_t>(PackOptions::VT_AXIS, axis, 0);
+ }
+ explicit PackOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PackOptionsBuilder &operator=(const PackOptionsBuilder &);
+ flatbuffers::Offset<PackOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PackOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PackOptions> CreatePackOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t values_count = 0,
+ int32_t axis = 0) {
+ PackOptionsBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ builder_.add_values_count(values_count);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PackOptions> CreatePackOptions(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5611,6 +5698,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const FakeQuantOptions *builtin_options_as_FakeQuantOptions() const {
return builtin_options_type() == BuiltinOptions_FakeQuantOptions ? static_cast<const FakeQuantOptions *>(builtin_options()) : nullptr;
}
+ const PackOptions *builtin_options_as_PackOptions() const {
+ return builtin_options_type() == BuiltinOptions_PackOptions ? static_cast<const PackOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -5874,6 +5964,10 @@ template<> inline const FakeQuantOptions *Operator::builtin_options_as<FakeQuant
return builtin_options_as_FakeQuantOptions();
}
+template<> inline const PackOptions *Operator::builtin_options_as<PackOptions>() const {
+ return builtin_options_as_PackOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7937,6 +8031,35 @@ inline flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(flatbuffers:
_narrow_range);
}
+inline PackOptionsT *PackOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new PackOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void PackOptions::UnPackTo(PackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = values_count(); _o->values_count = _e; };
+ { auto _e = axis(); _o->axis = _e; };
+}
+
+inline flatbuffers::Offset<PackOptions> PackOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePackOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PackOptions> CreatePackOptions(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PackOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _values_count = _o->values_count;
+ auto _axis = _o->axis;
+ return tflite::CreatePackOptions(
+ _fbb,
+ _values_count,
+ _axis);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -8358,6 +8481,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const FakeQuantOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_PackOptions: {
+ auto ptr = reinterpret_cast<const PackOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -8608,6 +8735,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const FakeQuantOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_PackOptions: {
+ auto ptr = reinterpret_cast<const PackOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -8846,6 +8977,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const FakeQuantOptionsT *>(value);
return CreateFakeQuantOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_PackOptions: {
+ auto ptr = reinterpret_cast<const PackOptionsT *>(value);
+ return CreatePackOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -9084,6 +9219,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new FakeQuantOptionsT(*reinterpret_cast<FakeQuantOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_PackOptions: {
+ value = new PackOptionsT(*reinterpret_cast<PackOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -9381,6 +9520,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_PackOptions: {
+ auto ptr = reinterpret_cast<PackOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index ba36017baf..770092e12c 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -53,8 +53,6 @@ tensorflow::Env* env = tensorflow::Env::Default();
// Key is a substring of the test name and value is a bug number.
// TODO(ahentz): make sure we clean this list up frequently.
std::map<string, string> kBrokenTests = {
- {R"(^\/div.*int32)", "68808744"},
-
// Pad and PadV2 only supports 4D tensors.
{R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
"70527055"},
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 889accdd5a..8d510ede58 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -280,6 +280,21 @@ class HashTableOpTest(test.TestCase):
table.init.run()
self.assertAllEqual(3, table.size().eval())
+ def testHashTableInt32String(self):
+ with self.test_session():
+ default_val = "n/a"
+ keys = constant_op.constant([0, 1, 2], dtypes.int32)
+ values = constant_op.constant(["brain", "salad", "surgery"])
+ table = lookup.HashTable(
+ lookup.KeyValueTensorInitializer(keys, values), default_val)
+ table.init.run()
+
+ input_tensor = constant_op.constant([0, 1, -1])
+ output = table.lookup(input_tensor)
+
+ result = output.eval()
+ self.assertAllEqual([b"brain", b"salad", b"n/a"], result)
+
class MutableHashTableOpTest(test.TestCase):
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index b14202ff9e..a328670526 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -3715,6 +3715,7 @@ def count(values,
name=None):
"""Computes the number of examples, or sum of `weights`.
+ This metric keeps track of the denominator in `tf.metrics.mean`.
When evaluating some metric (e.g. mean) on one or more subsets of the data,
this auxiliary metric is useful for keeping track of how many examples there
are in each subset.
@@ -3741,15 +3742,21 @@ def count(values,
ValueError: If `weights` is not `None` and its shape doesn't match `values`,
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
+ RuntimeError: If eager execution is enabled.
"""
+ if context.executing_eagerly():
+ raise RuntimeError('tf.contrib.metrics.count is not supported when eager '
+ 'execution is enabled.')
with variable_scope.variable_scope(name, 'count', (values, weights)):
+
count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
if weights is None:
num_values = math_ops.to_float(array_ops.size(values))
else:
- _, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ values = math_ops.to_float(values)
+ values, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
predictions=values,
labels=None,
weights=weights)
@@ -3758,15 +3765,14 @@ def count(values,
num_values = math_ops.reduce_sum(weights)
with ops.control_dependencies([values]):
- update_op = state_ops.assign_add(count_, num_values)
+ update_count_op = state_ops.assign_add(count_, num_values)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, count_)
+ count_ = metrics_impl._aggregate_variable(count_, metrics_collections) # pylint: disable=protected-access
if updates_collections:
- ops.add_to_collections(updates_collections, update_op)
+ ops.add_to_collections(updates_collections, update_count_op)
- return count_, update_op
+ return count_, update_count_op
def cohen_kappa(labels,
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index a09fc4abd4..401fedcbed 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -6854,6 +6854,11 @@ class CountTest(test.TestCase):
array_ops.ones([4, 3]), updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+ def testReturnType(self):
+ c, op = metrics.count(array_ops.ones([4, 3]))
+ self.assertTrue(isinstance(c, ops.Tensor))
+ self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor))
+
def testBasic(self):
with self.test_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 86f4fd6adf..9143d082bf 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -66,10 +66,10 @@ is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.
### Adding pruning ops to the training graph
-The final step involves adding ops to the training graph that monitors the
-distribution of the layer's weight magnitudes and determines the layer threshold
-such masking all the weights below this threshold achieves the sparsity level
-desired for the current training step. This can be achieved as follows:
+The final step involves adding ops to the training graph that monitor the
+distribution of the layer's weight magnitudes and determine the layer threshold,
+such that masking all the weights below this threshold achieves the sparsity
+level desired for the current training step. This can be achieved as follows:
```python
tf.app.flags.DEFINE_string(
@@ -79,7 +79,7 @@ tf.app.flags.DEFINE_string(
with tf.graph.as_default():
# Create global step variable
- global_step = tf.train.get_global_step()
+ global_step = tf.train.get_or_create_global_step()
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
@@ -103,6 +103,7 @@ with tf.graph.as_default():
mon_sess.run(mask_update_op)
```
+Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work!
## Example: Pruning and training deep CNNs on the cifar10 dataset
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index 4b7af18b33..da9d398cbc 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -518,11 +518,11 @@ class Pruning(object):
summary.scalar('last_mask_update_step', self._last_update_step)
masks = get_masks()
thresholds = get_thresholds()
- for index, mask in enumerate(masks):
+ for mask, threshold in zip(masks, thresholds):
if not self._exists_in_do_not_prune_list(mask.name):
- summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask))
- summary.scalar(thresholds[index].op.name + '/threshold',
- thresholds[index])
+ summary.scalar(mask.op.name + '/sparsity',
+ nn_impl.zero_fraction(mask))
+ summary.scalar(threshold.op.name + '/threshold', threshold)
def print_hparams(self):
logging.info(self._spec.to_json())
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
index 5f7f510352..e3570e38a3 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
@@ -106,34 +106,27 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
self.assertEqual(v, ev)
continue
- # This can be a little confusing. For testing we are using TestValue in
- # two ways: it's the proto that we decode for testing, and it's used in
- # the expected value as a union type.
- #
- # The two cases are slightly different: this is the second case. We may be
- # fetching the uint64_value from the test proto, but in the expected proto
- # we store it in the int64_value field because TensorFlow doesn't support
- # unsigned int64.
tf_type_to_primitive_value_field = {
+ dtypes.bool:
+ 'bool_value',
dtypes.float32:
'float_value',
dtypes.float64:
'double_value',
- dtypes.int32:
- 'int32_value',
- dtypes.uint8:
- 'uint8_value',
dtypes.int8:
'int8_value',
- dtypes.string:
- 'string_value',
+ dtypes.int32:
+ 'int32_value',
dtypes.int64:
'int64_value',
- dtypes.bool:
- 'bool_value',
- # Unhandled TensorFlow types:
- # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
- # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
+ dtypes.string:
+ 'string_value',
+ dtypes.uint8:
+ 'uint8_value',
+ dtypes.uint32:
+ 'uint32_value',
+ dtypes.uint64:
+ 'uint64_value',
}
tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
if tf_field_name is None:
diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
index cbc7b3d3f8..2950c7dfdc 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
@@ -44,7 +44,7 @@ class ProtoOpTestBase(test.TestCase):
("minmax", ProtoOpTestBase.minmax_test_case()),
("nested", ProtoOpTestBase.nested_test_case()),
("optional", ProtoOpTestBase.optional_test_case()),
- ("promote_unsigned", ProtoOpTestBase.promote_unsigned_test_case()),
+ ("promote", ProtoOpTestBase.promote_test_case()),
("ragged", ProtoOpTestBase.ragged_test_case()),
("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
("simple", ProtoOpTestBase.simple_test_case()),
@@ -83,13 +83,13 @@ class ProtoOpTestBase(test.TestCase):
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "uint64_value_with_default"
- field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(4)
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(4)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "fixed64_value_with_default"
- field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(6)
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(6)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "int32_value_with_default"
@@ -108,13 +108,13 @@ class ProtoOpTestBase(test.TestCase):
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "uint32_value_with_default"
- field.dtype = types_pb2.DT_INT32
- field.value.int32_value.append(9)
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(9)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "fixed32_value_with_default"
- field.dtype = types_pb2.DT_INT32
- field.value.int32_value.append(7)
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(7)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "bool_value_with_default"
@@ -202,15 +202,15 @@ class ProtoOpTestBase(test.TestCase):
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "uint64_value"
- field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(0)
- field.value.int64_value.append(-1)
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(0)
+ field.value.uint64_value.append(18446744073709551615)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "fixed64_value"
- field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(0)
- field.value.int64_value.append(-1)
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(0)
+ field.value.uint64_value.append(18446744073709551615)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "int32_value"
@@ -232,15 +232,15 @@ class ProtoOpTestBase(test.TestCase):
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "uint32_value"
- field.dtype = types_pb2.DT_INT32
- field.value.int32_value.append(0)
- field.value.int32_value.append(-1)
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(0)
+ field.value.uint32_value.append(4294967295)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "fixed32_value"
- field.dtype = types_pb2.DT_INT32
- field.value.int32_value.append(0)
- field.value.int32_value.append(-1)
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(0)
+ field.value.uint32_value.append(4294967295)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "bool_value"
@@ -289,28 +289,40 @@ class ProtoOpTestBase(test.TestCase):
return test_case
@staticmethod
- def promote_unsigned_test_case():
+ def promote_test_case():
test_case = test_example_pb2.TestCase()
value = test_case.values.add()
+ value.sint32_value.append(2147483647)
+ value.sfixed32_value.append(2147483647)
+ value.int32_value.append(2147483647)
value.fixed32_value.append(4294967295)
value.uint32_value.append(4294967295)
test_case.shapes.append(1)
test_case.sizes.append(1)
field = test_case.fields.add()
- field.name = "fixed32_value"
+ field.name = "sint32_value"
field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(4294967295)
+ field.value.int64_value.append(2147483647)
test_case.sizes.append(1)
field = test_case.fields.add()
- field.name = "uint32_value"
+ field.name = "sfixed32_value"
field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(4294967295)
- # Comes from an explicitly-specified default
- test_case.sizes.append(0)
+ field.value.int64_value.append(2147483647)
+ test_case.sizes.append(1)
field = test_case.fields.add()
- field.name = "uint32_value_with_default"
+ field.name = "int32_value"
field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(9)
+ field.value.int64_value.append(2147483647)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(4294967295)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(4294967295)
return test_case
@staticmethod
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 211c59cb90..750e677263 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -595,7 +595,8 @@ class _InternalTPUContext(object):
raise ValueError(message)
if mode == model_fn_lib.ModeKeys.TRAIN:
- if self._train_batch_size % num_replicas != 0:
+ if (self._train_batch_size % num_replicas != 0 and
+ not self.is_input_broadcast_with_iterators()):
raise ValueError(
'train batch size {} must be divisible by number of replicas {}'
.format(self._train_batch_size, num_replicas))
@@ -605,7 +606,8 @@ class _InternalTPUContext(object):
raise ValueError(
'eval_batch_size in TPUEstimator constructor cannot be `None`'
'if .evaluate is running on TPU.')
- if self._eval_batch_size % num_replicas != 0:
+ if (self._eval_batch_size % num_replicas != 0 and
+ not self.is_input_broadcast_with_iterators()):
raise ValueError(
'eval batch size {} must be divisible by number of replicas {}'
.format(self._eval_batch_size, num_replicas))
@@ -619,7 +621,8 @@ class _InternalTPUContext(object):
raise ValueError(
'predict_batch_size in TPUEstimator constructor should not be '
'`None` if .predict is running on TPU.')
- if self._predict_batch_size % num_replicas != 0:
+ if (self._predict_batch_size % num_replicas != 0 and
+ not self.is_input_broadcast_with_iterators()):
raise ValueError(
'predict batch size {} must be divisible by number of replicas {}'
.format(self._predict_batch_size, num_replicas))
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 78b79b111e..73dfefd19c 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -858,7 +858,8 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
raise TypeError('Mode PREDICT not yet supported in BROADCAST mode.')
- hooks.append(inputs.dataset_initializer_hook())
+ if is_dataset:
+ hooks.append(inputs.dataset_initializer_hook())
num_replicas_per_host = ctx.num_of_replicas_per_host
def tpu_ordinal_function_impl(replica_id):
@@ -1336,7 +1337,8 @@ class _ModelFnWrapper(object):
loss = tpu_estimator_spec.loss
captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
to_record = {}
- to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
+ if tpu_estimator_spec.eval_metrics:
+ to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
if tpu_estimator_spec.host_call is not None:
# We assume that evaluate won't update global step, so we don't wrap
# this host_call.
@@ -1639,7 +1641,7 @@ class _OutfeedHostCall(object):
RuntimeError: If outfeed tensor is scalar.
"""
if not self._names:
- return []
+ return {}
ret = {}
# For each i, dequeue_ops[i] is a list containing the tensors from all
@@ -2514,7 +2516,8 @@ class TPUEstimator(estimator_lib.Estimator):
host_call_ret = host_calls.create_tpu_hostcall()
eval_metric_ops = {}
eval_update_ops = []
- for k, v in host_call_ret['eval_metrics'].items():
+
+ for k, v in host_call_ret.get('eval_metrics', {}).items():
eval_metric_ops[k] = (v[0], dummy_update_op)
eval_update_ops.append(v[1])
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 514713bb96..fc12027291 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -3226,6 +3226,7 @@ tf_cc_tests(
":test",
":test_main",
"//third_party/eigen3",
+ "@zlib_archive//:zlib",
],
)
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 9cda17867b..f8ca039d15 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -465,49 +465,33 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) {
Chunk* c = ChunkFromHandle(h);
CHECK(c->in_use() && (c->bin_num == kInvalidBinNum));
- // Mark the chunk as no longer in use
+ // Mark the chunk as no longer in use.
c->allocation_id = -1;
// Updates the stats.
stats_.bytes_in_use -= c->size;
- // This chunk is no longer in-use, consider coalescing the chunk
- // with adjacent chunks.
- ChunkHandle chunk_to_reassign = h;
+ ChunkHandle coalesced_chunk = h;
- // If the next chunk is free, coalesce the two
- if (c->next != kInvalidChunkHandle) {
- Chunk* cnext = ChunkFromHandle(c->next);
- if (!cnext->in_use()) {
- // VLOG(8) << "Chunk at " << cnext->ptr << " merging with c " <<
- // c->ptr;
-
- chunk_to_reassign = h;
-
- // Deletes c->next
- RemoveFreeChunkFromBin(c->next);
- Merge(h, ChunkFromHandle(h)->next);
- }
+ // If the next chunk is free, merge it into c and delete it.
+ if (c->next != kInvalidChunkHandle && !ChunkFromHandle(c->next)->in_use()) {
+ // VLOG(8) << "Merging c->next " << ChunkFromHandle(c->next)->ptr
+ // << " with c " << c->ptr;
+ RemoveFreeChunkFromBin(c->next);
+ Merge(h, c->next);
}
- // If the previous chunk is free, coalesce the two
- c = ChunkFromHandle(h);
- if (c->prev != kInvalidChunkHandle) {
- Chunk* cprev = ChunkFromHandle(c->prev);
- if (!cprev->in_use()) {
- // VLOG(8) << "Chunk at " << c->ptr << " merging into c->prev "
- // << cprev->ptr;
+ // If the previous chunk is free, merge c into it and delete c.
+ if (c->prev != kInvalidChunkHandle && !ChunkFromHandle(c->prev)->in_use()) {
+ // VLOG(8) << "Merging c " << c->ptr << " into c->prev "
+ // << ChunkFromHandle(c->prev)->ptr;
- chunk_to_reassign = c->prev;
-
- // Deletes c
- RemoveFreeChunkFromBin(c->prev);
- Merge(ChunkFromHandle(h)->prev, h);
- c = ChunkFromHandle(h);
- }
+ coalesced_chunk = c->prev;
+ RemoveFreeChunkFromBin(c->prev);
+ Merge(c->prev, h);
}
- InsertFreeChunkIntoBin(chunk_to_reassign);
+ InsertFreeChunkIntoBin(coalesced_chunk);
}
void BFCAllocator::AddAllocVisitor(Visitor visitor) {
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 52aedb1e9c..cd8ff6e5c0 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -191,18 +191,14 @@ class BFCAllocator : public VisitableAllocator {
DCHECK_EQ(0, memory_size % kMinAllocationSize);
const size_t n_handles =
(memory_size + kMinAllocationSize - 1) / kMinAllocationSize;
- handles_ = new ChunkHandle[n_handles];
+ handles_.reset(new ChunkHandle[n_handles]);
for (size_t i = 0; i < n_handles; i++) {
handles_[i] = kInvalidChunkHandle;
}
}
- AllocationRegion() {}
-
- ~AllocationRegion() { delete[] handles_; }
-
+ AllocationRegion() = default;
AllocationRegion(AllocationRegion&& other) { Swap(other); }
-
AllocationRegion& operator=(AllocationRegion&& other) {
Swap(other);
return *this;
@@ -241,7 +237,7 @@ class BFCAllocator : public VisitableAllocator {
// Array of size "memory_size / kMinAllocationSize". It is
// indexed by (p-base) / kMinAllocationSize, contains ChunkHandle
// for the memory allocation represented by "p"
- ChunkHandle* handles_ = nullptr;
+ std::unique_ptr<ChunkHandle[]> handles_;
TF_DISALLOW_COPY_AND_ASSIGN(AllocationRegion);
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 7110ffd40c..3292ef2f62 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -225,6 +225,7 @@ class BaseGPUDevice::StreamGroupFactory {
int num_d2d_streams =
options.experimental().num_dev_to_dev_copy_streams();
+ if (num_d2d_streams == 0) num_d2d_streams = 1;
if (num_d2d_streams < 1 || num_d2d_streams > 4) {
LOG(ERROR)
<< "Illegal GPUOptions.experimental.num_dev_to_dev_copy_streams="
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index e90596980f..f1cd37ecda 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -151,6 +151,12 @@ limitations under the License.
// Defines for sets of types.
+// TODO(b/111604096): Add uint32 and uint64 to TF_CALL_INTEGRAL_TYPES.
+//
+// The uint32 and uint64 types were introduced in 10/2017 to be used via XLA and
+// thus were not included in TF_CALL_INTEGRAL_TYPES. Including them in
+// TF_CALL_INTEGRAL_TYPES should only happen after evaluating the effect on the
+// TF binary size and performance.
#define TF_CALL_INTEGRAL_TYPES(m) \
TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
TF_CALL_uint8(m) TF_CALL_int8(m)
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 99e5e3cfca..10cbcdecc8 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6320,6 +6320,7 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core/util/proto:decode",
"//tensorflow/core/util/proto:descriptors",
+ "//tensorflow/core/util/proto:proto_utils",
"//third_party/eigen3",
],
)
@@ -6332,6 +6333,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/util/proto:descriptors",
+ "//tensorflow/core/util/proto:proto_utils",
"//third_party/eigen3",
],
)
diff --git a/tensorflow/core/kernels/argmax_op.cc b/tensorflow/core/kernels/argmax_op.cc
index 49cd997fed..adc573e40c 100644
--- a/tensorflow/core/kernels/argmax_op.cc
+++ b/tensorflow/core/kernels/argmax_op.cc
@@ -59,7 +59,7 @@ class ArgOp : public OpKernel {
int axis = dim < 0 ? dim + input_dims : dim;
- OP_REQUIRES(context, axis >= 0 && axis < input_dims,
+ OP_REQUIRES(context, FastBoundsCheck(axis, input_dims),
errors::InvalidArgument("Expected dimension in the range [",
-input_dims, ", ", input_dims,
"), but got ", dim));
diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc
index 6d3dcc1c59..b54e1ea8ac 100644
--- a/tensorflow/core/kernels/decode_proto_op.cc
+++ b/tensorflow/core/kernels/decode_proto_op.cc
@@ -13,21 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// DecodeProto is a TensorFlow Op which extracts arbitrary fields
-// from protos serialized as strings.
+// DecodeProto is a TensorFlow op which extracts arbitrary fields from protos
+// serialized as strings.
//
// See docs in ../ops/decode_proto_op.cc.
//
-// This implementation reads the serialized format using a handful of
-// calls from the WireFormatLite API used by generated proto code.
-// WireFormatLite is marked as an "internal" proto API but is widely
-// used in practice and highly unlikely to change.
-// This will be much faster than the previous implementation based on
-// constructing a temporary dynamic message in memory and using the
-// proto reflection api to read it.
-// It can be used with any proto whose descriptors are available at
-// runtime but should be competitive in speed with approaches that
-// compile in the proto definitions.
+// This implementation reads the serialized format using a handful of calls from
+// the WireFormatLite API used by generated proto code. WireFormatLite is marked
+// as an "internal" proto API but is widely used in practice and highly unlikely
+// to change. This will be much faster than the previous implementation based on
+// constructing a temporary dynamic message in memory and using the proto
+// reflection api to read it. It can be used with any proto whose descriptors
+// are available at runtime but should be competitive in speed with approaches
+// that compile in the proto definitions.
#include <memory>
#include <string>
@@ -36,11 +34,13 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/decode.h"
#include "tensorflow/core/util/proto/descriptors.h"
+#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -58,53 +58,6 @@ using ::tensorflow::protobuf::io::CodedInputStream;
const bool kFailOnDecodeError = true;
-// Returns true if the proto field type can be converted to the
-// tensorflow::DataType.
-bool CheckOutputType(FieldDescriptor::Type field_type, DataType output_type) {
- switch (field_type) {
- case WireFormatLite::TYPE_DOUBLE:
- return output_type == tensorflow::DT_DOUBLE;
- case WireFormatLite::TYPE_FLOAT:
- return output_type == tensorflow::DT_FLOAT ||
- output_type == tensorflow::DT_DOUBLE;
- case WireFormatLite::TYPE_INT64:
- return output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_UINT64:
- return output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_INT32:
- return output_type == tensorflow::DT_INT32;
- case WireFormatLite::TYPE_FIXED64:
- return output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_FIXED32:
- return output_type == tensorflow::DT_INT32 ||
- output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_BOOL:
- return output_type == tensorflow::DT_BOOL;
- case WireFormatLite::TYPE_STRING:
- return output_type == tensorflow::DT_STRING;
- case WireFormatLite::TYPE_GROUP:
- return output_type == tensorflow::DT_STRING;
- case WireFormatLite::TYPE_MESSAGE:
- return output_type == tensorflow::DT_STRING;
- case WireFormatLite::TYPE_BYTES:
- return output_type == tensorflow::DT_STRING;
- case WireFormatLite::TYPE_UINT32:
- return output_type == tensorflow::DT_INT32 ||
- output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_ENUM:
- return output_type == tensorflow::DT_INT32;
- case WireFormatLite::TYPE_SFIXED32:
- return output_type == tensorflow::DT_INT32;
- case WireFormatLite::TYPE_SFIXED64:
- return output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_SINT32:
- return output_type == tensorflow::DT_INT32;
- case WireFormatLite::TYPE_SINT64:
- return output_type == tensorflow::DT_INT64;
- // default: intentionally omitted in order to enable static checking.
- }
-}
-
// Used to store the default value of a protocol message field, casted to the
// type of the output tensor.
//
@@ -113,13 +66,15 @@ struct DefaultValue {
DataType dtype = DataType::DT_INVALID;
union Value {
bool v_bool; // DT_BOOL
- uint8 v_uint8; // DT_UINT8
+ double v_double; // DT_DOUBLE
+ float v_float; // DT_FLOAT
int8 v_int8; // DT_INT8
int32 v_int32; // DT_INT32
int64 v_int64; // DT_INT64
- float v_float; // DT_FLOAT
- double v_double; // DT_DOUBLE
const char* v_string; // DT_STRING
+ uint8 v_uint8; // DT_UINT8
+ uint8 v_uint32; // DT_UINT32
+ uint8 v_uint64; // DT_UINT64
};
Value value;
};
@@ -138,23 +93,29 @@ Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) {
case DT_BOOL:
result->value.v_bool = static_cast<bool>(value);
break;
- case DT_INT32:
- result->value.v_int32 = static_cast<int32>(value);
+ case DT_DOUBLE:
+ result->value.v_double = static_cast<double>(value);
+ break;
+ case DT_FLOAT:
+ result->value.v_float = static_cast<float>(value);
break;
case DT_INT8:
result->value.v_int8 = static_cast<int8>(value);
break;
- case DT_UINT8:
- result->value.v_uint8 = static_cast<uint8>(value);
+ case DT_INT32:
+ result->value.v_int32 = static_cast<int32>(value);
break;
case DT_INT64:
result->value.v_int64 = static_cast<int64>(value);
break;
- case DT_FLOAT:
- result->value.v_float = static_cast<float>(value);
+ case DT_UINT8:
+ result->value.v_uint8 = static_cast<uint8>(value);
break;
- case DT_DOUBLE:
- result->value.v_double = static_cast<double>(value);
+ case DT_UINT32:
+ result->value.v_uint32 = static_cast<uint32>(value);
+ break;
+ case DT_UINT64:
+ result->value.v_uint64 = static_cast<uint64>(value);
break;
default:
// We should never get here, given the type checking that occurs earlier.
@@ -241,13 +202,11 @@ struct FieldInfo {
number = field_desc->number();
// The wire format library defines the same constants used in
- // descriptor.proto. This static_cast is safe because they
- // are guaranteed to stay in sync.
- // We need the field type from the FieldDescriptor here
- // because the wire format doesn't tell us anything about
- // what happens inside a packed repeated field: there is
- // enough information in the wire format to skip the
- // whole field but not enough to know how to parse what's
+ // descriptor.proto. This static_cast is safe because they are guaranteed to
+ // stay in sync. We need the field type from the FieldDescriptor here
+ // because the wire format doesn't tell us anything about what happens
+ // inside a packed repeated field: there is enough information in the wire
+ // format to skip the whole field but not enough to know how to parse what's
// inside. For that we go to the schema.
type = static_cast<WireFormatLite::FieldType>(field_desc->type());
is_repeated = field_desc->is_repeated();
@@ -257,16 +216,15 @@ struct FieldInfo {
FieldInfo(const FieldInfo&) = delete;
FieldInfo& operator=(const FieldInfo&) = delete;
- // Internally we sort field descriptors by wire number for
- // fast lookup. In general this is different from the order
- // given by the user. Output_index gives the index into
- // the field_names and output_types attributes and into
+ // Internally we sort field descriptors by wire number for fast lookup. In
+ // general this is different from the order given by the user. Output_index
+ // gives the index into the field_names and output_types attributes and into
// the output tensor list.
int output_index = -1;
- // This is a cache of the relevant fields from `FieldDescriptorProto`.
- // This was added after noticing that FieldDescriptor->type() was
- // using 6% of the cpu profile.
+ // This is a cache of the relevant fields from `FieldDescriptorProto`. This
+ // was added after noticing that FieldDescriptor->type() was using 6% of the
+ // cpu profile.
WireFormatLite::FieldType type;
int number;
bool is_repeated;
@@ -275,16 +233,16 @@ struct FieldInfo {
// A CountCollector counts sizes of repeated and optional fields in a proto.
//
-// Each field is tracked by a single CountCollector instance. The
-// instance manages a single count, which is stored as a pointer (it
-// is intended to be a reference to the `sizes` output which is being
-// filled in). The pointer is passed in at initialization.
+// Each field is tracked by a single CountCollector instance. The instance
+// manages a single count, which is stored as a pointer (it is intended to be a
+// reference to the `sizes` output which is being filled in). The pointer is
+// passed in at initialization.
//
-// Counting is done as a separate pass in order to allocate output tensors
-// all at once. This allows the TensorFlow runtime to optimize allocation
-// for the consumer, while removing the need for copying inside this op.
-// After this pass, the DenseCollector class (below) gathers the data:
-// It is more complex and provides better motivation for the API here.
+// Counting is done as a separate pass in order to allocate output tensors all
+// at once. This allows the TensorFlow runtime to optimize allocation for the
+// consumer, while removing the need for copying inside this op. After this
+// pass, the DenseCollector class (below) gathers the data: it is more complex
+// and provides better motivation for the API here.
class CountCollector {
public:
CountCollector() = delete;
@@ -298,8 +256,8 @@ class CountCollector {
if (*count_ptr_ == 0 || field.is_repeated) {
(*count_ptr_)++;
}
- // We expect a wire type based on the schema field_type, to allow
- // a little more checking.
+ // We expect a wire type based on the schema field_type, to allow a little
+ // more checking.
if (!SkipValue(input, field)) {
return errors::DataLoss("ReadValue: Failed skipping field when counting");
}
@@ -329,8 +287,8 @@ class CountCollector {
return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
}
- // Dispatch to the appropriately typed field reader based on the
- // schema type.
+ // Dispatch to the appropriately typed field reader based on the schema
+ // type.
Status st;
switch (field.type) {
case WireFormatLite::TYPE_DOUBLE:
@@ -409,18 +367,17 @@ class CountCollector {
return input->Skip(length);
}
- // Counts the number of packed varints in an array.
- // The end of a varint is signaled by a value < 0x80,
- // so counting them requires parsing the bytestream.
- // It is the caller's responsibility to ensure that len > 0.
+ // Counts the number of packed varints in an array. The end of a varint is
+ // signaled by a value < 0x80, so counting them requires parsing the
+ // bytestream. It is the caller's responsibility to ensure that len > 0.
Status CountPackedVarint(const uint8* buf, size_t len) {
const uint8* bound = buf + len;
int count;
- // The last byte in a valid encoded varint is guaranteed to have
- // the high bit unset. We rely on this property to prevent
- // ReadVarint64FromArray from going out of bounds, so validate
- // the end of the buf before scanning anything.
+ // The last byte in a valid encoded varint is guaranteed to have the high
+ // bit unset. We rely on this property to prevent ReadVarint64FromArray from
+ // going out of bounds, so validate the end of the buf before scanning
+ // anything.
if (bound[-1] & 0x80) {
return errors::DataLoss("Corrupt packed varint");
}
@@ -439,8 +396,8 @@ class CountCollector {
return Status::OK();
}
- // Counts the number of fixed-size values in a packed field.
- // This can be done without actually parsing anything.
+ // Counts the number of fixed-size values in a packed field. This can be done
+ // without actually parsing anything.
template <typename T>
Status CountPackedFixed(const uint8* unused_buf, size_t len) {
int count = len / sizeof(T);
@@ -452,10 +409,9 @@ class CountCollector {
return Status::OK();
}
- // Skips a single value in the input stream.
- // Dispatches to the appropriately typed field skipper based on the
- // schema type tag.
- // This is not as permissive as just handling the wire type.
+ // Skips a single value in the input stream. Dispatches to the appropriately
+ // typed field skipper based on the schema type tag. This is not as permissive
+ // as just handling the wire type.
static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
uint32 tmp32;
protobuf_uint64 tmp64;
@@ -507,13 +463,13 @@ class CountCollector {
// A DenseCollector accumulates values from a proto into a tensor.
//
-// There is an instance of DenseCollector for each field of each
-// proto. The DenseCollector deserializes the value from the wire
-// directly into the preallocated output Tensor.
+// There is an instance of DenseCollector for each field of each proto. The
+// DenseCollector deserializes the value from the wire directly into the
+// preallocated output Tensor.
//
-// This class is named DenseCollector because in the future there should
-// be a SparseCollector that accumulates field data into sparse tensors if
-// the user requests it.
+// This class is named DenseCollector because in the future there should be a
+// SparseCollector that accumulates field data into sparse tensors if the user
+// requests it.
class DenseCollector {
public:
DenseCollector() = delete;
@@ -578,40 +534,43 @@ class DenseCollector {
}
}
- // Fills in any missing values in the output array with defaults.
- // Dispatches to the appropriately typed field default based on the
- // runtime type tag.
+ // Fills in any missing values in the output array with defaults. Dispatches
+ // to the appropriately typed field default based on the runtime type tag.
Status FillWithDefaults() {
switch (default_value_.dtype) {
+ case DataType::DT_BOOL:
+ return FillDefault<bool>(default_value_.value.v_bool);
case DataType::DT_FLOAT:
return FillDefault<float>(default_value_.value.v_float);
case DataType::DT_DOUBLE:
return FillDefault<double>(default_value_.value.v_double);
- case DataType::DT_INT32:
- return FillDefault<int32>(default_value_.value.v_int32);
- case DataType::DT_UINT8:
- return FillDefault<uint8>(default_value_.value.v_uint8);
case DataType::DT_INT8:
return FillDefault<int8>(default_value_.value.v_int8);
- case DataType::DT_STRING:
- return FillDefault<string>(default_value_.value.v_string);
+ case DataType::DT_INT32:
+ return FillDefault<int32>(default_value_.value.v_int32);
case DataType::DT_INT64:
return FillDefault<int64>(default_value_.value.v_int64);
- case DataType::DT_BOOL:
- return FillDefault<bool>(default_value_.value.v_bool);
+ case DataType::DT_STRING:
+ return FillDefault<string>(default_value_.value.v_string);
+ case DataType::DT_UINT8:
+ return FillDefault<uint8>(default_value_.value.v_uint8);
+ case DataType::DT_UINT32:
+ return FillDefault<uint32>(default_value_.value.v_uint32);
+ case DataType::DT_UINT64:
+ return FillDefault<uint64>(default_value_.value.v_uint64);
default:
// There are many tensorflow dtypes not handled here, but they
// should not come up unless type casting is added to the Op.
// Chaining with tf.cast() should do the right thing until then.
- return errors::DataLoss(
- "Failed filling defaults in unknown tf::DataType");
+ return errors::DataLoss("Failed filling defaults for ",
+ DataTypeString(default_value_.dtype));
}
}
private:
- // Fills empty values in the dense representation with a
- // default value. This uses next_repeat_index_ which counts the number
- // of parsed values for the field.
+ // Fills empty values in the dense representation with a default value. This
+ // uses next_repeat_index_ which counts the number of parsed values for the
+ // field.
template <class T>
Status FillDefault(const T& default_value) {
for (int i = next_repeat_index_; i < max_repeat_count_; i++) {
@@ -622,11 +581,10 @@ class DenseCollector {
int32 next_repeat_index_ = 0;
- // This is a pointer to data_[message_index_].
- // There is no bounds checking at this level: we computed the max
- // repeat size for each field in CountCollector and use the same
- // code to traverse it here, so we are guaranteed not to be called
- // for more items than we have allocated space.
+ // This is a pointer to data_[message_index_]. There is no bounds checking at
+ // this level: we computed the max repeat size for each field in
+ // CountCollector and use the same code to traverse it here, so we are
+ // guaranteed not to be called for more items than we have allocated space.
void* const datap_ = nullptr;
const DefaultValue default_value_;
@@ -665,7 +623,6 @@ class DecodeProtoOp : public OpKernel {
"have the same length"));
// Gather the field descriptors and check that requested output types match.
-
int field_index = 0;
std::vector<const FieldDescriptor*> field_descs;
for (const string& name : field_names) {
@@ -673,18 +630,16 @@ class DecodeProtoOp : public OpKernel {
OP_REQUIRES(context, fd != nullptr,
errors::InvalidArgument("Unknown field: ", name,
" in message type ", message_type));
- OP_REQUIRES(context,
- CheckOutputType(fd->type(), output_types[field_index]),
- // Many TensorFlow types don't have corresponding proto types
- // and the user will get an error if they are requested. It
- // would be nice to allow conversions here, but tf.cast
- // already exists so we don't duplicate the functionality.
- // Known unhandled types:
- // DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
- // DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
- errors::InvalidArgument("Unexpected output type for ",
- fd->full_name(), ": ", fd->cpp_type(),
- " to ", output_types[field_index]));
+ OP_REQUIRES(
+ context,
+ proto_utils::IsCompatibleType(fd->type(), output_types[field_index]),
+ // Many TensorFlow types don't have corresponding proto types and the
+ // user will get an error if they are requested. It would be nice to
+ // allow conversions here, but tf.cast already exists so we don't
+ // duplicate the functionality.
+ errors::InvalidArgument("Unexpected output type for ",
+ fd->full_name(), ": ", fd->cpp_type(), " to ",
+ output_types[field_index]));
field_index++;
field_descs.push_back(fd);
@@ -726,10 +681,9 @@ class DecodeProtoOp : public OpKernel {
errors::InvalidArgument("format must be one of binary or text"));
is_binary_ = format == "binary";
- // Enable the initial protobuf sanitizer, which is much
- // more expensive than the decoder.
- // TODO(nix): Remove this once the fast decoder
- // has passed security review.
+ // Enable the initial protobuf sanitizer, which is much more expensive than
+ // the decoder.
+ // TODO(nix): Remove this once the fast decoder has passed security review.
OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
}
@@ -742,9 +696,9 @@ class DecodeProtoOp : public OpKernel {
int field_count = fields_.size();
- // Save the argument shape for later, then flatten the input
- // Tensor since we are working componentwise. We will restore
- // the same shape in the returned Tensor.
+ // Save the argument shape for later, then flatten the input Tensor since we
+ // are working componentwise. We will restore the same shape in the returned
+ // Tensor.
const TensorShape& shape_prefix = buf_tensor.shape();
TensorShape sizes_shape = shape_prefix;
@@ -752,8 +706,8 @@ class DecodeProtoOp : public OpKernel {
Tensor* sizes_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor));
- // This is used to allocate binary bufs if used. It serves only
- // to define memory ownership.
+ // This is used to allocate binary bufs if used. It serves only to define
+ // memory ownership.
std::vector<string> tmp_binary_bufs(message_count);
// These are the actual buffers to use, which may be in tmp_binary_bufs
@@ -768,8 +722,8 @@ class DecodeProtoOp : public OpKernel {
bufs.push_back(buf);
}
} else {
- // We will have to allocate a copy, either to convert from text to
- // binary or to sanitize a binary proto.
+ // We will have to allocate a copy, either to convert from text to binary
+ // or to sanitize a binary proto.
for (int mi = 0; mi < message_count; ++mi) {
ReserializeMessage(ctx, buf_tensor.flat<string>()(mi),
&tmp_binary_bufs[mi]);
@@ -780,16 +734,14 @@ class DecodeProtoOp : public OpKernel {
}
}
- // Walk through all the strings in the input tensor, counting
- // the number of fields in each.
- // We can't allocate our actual output Tensor until we know the
- // maximum repeat count, so we do a first pass through the serialized
- // proto just counting fields.
- // We always allocate at least one value so that optional fields
- // are populated with default values - this avoids a TF
- // conditional when handling the output data.
- // The caller can distinguish between real data and defaults
- // using the repeat count matrix that is returned by decode_proto.
+ // Walk through all the strings in the input tensor, counting the number of
+ // fields in each. We can't allocate our actual output Tensor until we know
+ // the maximum repeat count, so we do a first pass through the serialized
+ // proto just counting fields. We always allocate at least one value so that
+ // optional fields are populated with default values - this avoids a TF
+ // conditional when handling the output data. The caller can distinguish
+ // between real data and defaults using the repeat count matrix that is
+ // returned by decode_proto.
std::vector<int32> max_sizes(field_count, 1);
for (int mi = 0; mi < message_count; ++mi) {
CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes);
@@ -814,14 +766,12 @@ class DecodeProtoOp : public OpKernel {
// REGISTER_OP(...)
// .Attr("output_types: list(type) >= 0")
// .Output("values: output_types")
- OP_REQUIRES_OK(ctx,
- // ctx->allocate_output(output_indices_[fi] + 1,
- ctx->allocate_output(fields_[fi]->output_index + 1,
- out_shape, &outputs[fi]));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1,
+ out_shape, &outputs[fi]));
}
- // Make the second pass through the serialized proto, decoding
- // into preallocated tensors.
+ // Make the second pass through the serialized proto, decoding into
+ // preallocated tensors.
AccumulateFields(ctx, bufs, outputs);
}
@@ -976,6 +926,7 @@ class DecodeProtoOp : public OpKernel {
// Look up the FieldDescriptor for a particular field number.
bool LookupField(int field_number, int* field_index) {
// Look up the FieldDescriptor using linear search.
+ //
// TODO(nix): this could be sped up with binary search, but we are
// already way off the fastpath at this point. If you see a hotspot
// here, somebody is sending you very inefficient protos.
@@ -1010,6 +961,7 @@ class DecodeProtoOp : public OpKernel {
// This takes advantage of the sorted field numbers in most serialized
// protos: it tries the next expected field first rather than doing
// a lookup by field number.
+ //
// TODO(nix): haberman@ suggests a hybrid approach with a lookup table
// for small field numbers and a hash table for larger ones. This would
// be a simpler approach that should offer comparable speed in most
@@ -1029,9 +981,9 @@ class DecodeProtoOp : public OpKernel {
last_good_field_index = field_index;
}
} else {
- // If we see a field that is past the next field we want,
- // it was empty. Look for the one after that.
- // Repeat until we run out of fields that we care about.
+ // If we see a field that is past the next field we want, it was
+ // empty. Look for the one after that. Repeat until we run out of
+ // fields that we care about.
while (field_number >= next_good_field_number) {
if (field_number == next_good_field_number) {
last_good_field_number = field_number;
@@ -1044,10 +996,9 @@ class DecodeProtoOp : public OpKernel {
next_good_field_number =
fields_[last_good_field_index + 1]->number;
} else {
- // Saw something past the last field we care about.
- // Continue parsing the message just in case there
- // are disordered fields later, but any remaining
- // ordered fields will have no effect.
+ // Saw something past the last field we care about. Continue
+ // parsing the message just in case there are disordered fields
+ // later, but any remaining ordered fields will have no effect.
next_good_field_number = INT_MAX;
}
}
@@ -1077,20 +1028,20 @@ class DecodeProtoOp : public OpKernel {
WireFormatLite::WireType wire_type,
CodedInputStream* input, CollectorClass* collector) {
// The wire format library defines the same constants used in
- // descriptor.proto. This static_cast is safe because they
- // are guaranteed to stay in sync.
- // We need the field type from the FieldDescriptor here
- // because the wire format doesn't tell us anything about
- // what happens inside a packed repeated field: there is
- // enough information in the wire format to skip the
- // whole field but not enough to know how to parse what's
- // inside. For that we go to the schema.
+ // descriptor.proto. This static_cast is safe because they are guaranteed to
+ // stay in sync.
+ //
+ // We need the field type from the FieldDescriptor here because the wire
+ // format doesn't tell us anything about what happens inside a packed
+ // repeated field: there is enough information in the wire format to skip
+ // the whole field but not enough to know how to parse what's inside. For
+ // that we go to the schema.
WireFormatLite::WireType schema_wire_type =
WireFormatLite::WireTypeForFieldType(field.type);
- // Handle packed repeated fields. SkipField would skip the
- // whole length-delimited blob without letting us count the
- // values, so we have to scan them ourselves.
+ // Handle packed repeated fields. SkipField would skip the whole
+ // length-delimited blob without letting us count the values, so we have to
+ // scan them ourselves.
if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
// Handle packed repeated primitives.
@@ -1098,11 +1049,7 @@ class DecodeProtoOp : public OpKernel {
if (!input->ReadVarintSizeAsInt(&length)) {
return errors::DataLoss("CollectField: Failed reading packed size");
}
- Status st = collector->ReadPackedValues(input, field, length);
- if (!st.ok()) {
- return st;
- }
- return Status::OK();
+ return collector->ReadPackedValues(input, field, length);
}
// Read ordinary values, including strings, bytes, and messages.
@@ -1118,9 +1065,9 @@ class DecodeProtoOp : public OpKernel {
}
string message_type_;
- // Note that fields are sorted by increasing field number,
- // which is not in general the order given by the user-specified
- // field_names and output_types Op attributes.
+ // Note that fields are sorted by increasing field number, which is not in
+ // general the order given by the user-specified field_names and output_types
+ // Op attributes.
std::vector<std::unique_ptr<const FieldInfo>> fields_;
// Owned_desc_pool_ is null when using descriptor_source=local.
@@ -1131,12 +1078,12 @@ class DecodeProtoOp : public OpKernel {
// True if decoding binary format, false if decoding text format.
bool is_binary_;
- // True if the protos should be sanitized before parsing.
- // Enables the initial protobuf sanitizer, which is much
- // more expensive than the decoder. The flag defaults to true
- // but can be set to false for trusted sources.
- // TODO(nix): flip the default to false when the fast decoder
- // has passed security review.
+ // True if the protos should be sanitized before parsing. Enables the initial
+ // protobuf sanitizer, which is much more expensive than the decoder. The flag
+ // defaults to true but can be set to false for trusted sources.
+ //
+ // TODO(nix): Flip the default to false when the fast decoder has passed
+ // security review.
bool sanitize_;
TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);
diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc
index 3b02ae52a2..4a0c1943e5 100644
--- a/tensorflow/core/kernels/encode_proto_op.cc
+++ b/tensorflow/core/kernels/encode_proto_op.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/descriptors.h"
+#include "tensorflow/core/util/proto/proto_utils.h"
namespace tensorflow {
namespace {
@@ -42,9 +43,9 @@ using ::tensorflow::protobuf::internal::WireFormatLite;
using ::tensorflow::protobuf::io::CodedOutputStream;
using ::tensorflow::protobuf::io::StringOutputStream;
-// Computes the total serialized size for a packed repeated field.
-// For fixed-size types this can just multiply, but for variable-sized
-// types it has to iterate through the values in the tensor.
+// Computes the total serialized size for a packed repeated field. For
+// fixed-size types this can just multiply, but for variable-sized types it has
+// to iterate through the values in the tensor.
template <WireFormatLite::FieldType FieldType, typename TensorT>
size_t TotalPackedSize(const Tensor& input, int message_index, int size);
@@ -83,11 +84,11 @@ size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64>(const Tensor& input,
}
template <>
-size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, int64>(const Tensor& input,
- int message_index,
- int size) {
+size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, uint64>(const Tensor& input,
+ int message_index,
+ int size) {
size_t data_size = 0;
- auto input_t = input.flat_inner_dims<int64>();
+ auto input_t = input.flat_inner_dims<uint64>();
for (int64 i = 0; i < size; i++) {
data_size += WireFormatLite::UInt64Size(
input_t(static_cast<int64>(message_index), i));
@@ -96,6 +97,19 @@ size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, int64>(const Tensor& input,
}
template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int64>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int64>();
+ for (int64 i = 0; i < size; i++) {
+ data_size += WireFormatLite::Int32Size(
+ input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
+
+template <>
size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input,
int message_index,
int size) {
@@ -109,23 +123,20 @@ size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input,
}
template <>
-size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, int64>(const Tensor& input,
- int message_index,
- int size) {
+size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, uint64>(
+ const Tensor& input, int message_index, int size) {
return size * WireFormatLite::kFixed64Size;
}
template <>
-size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int64>(const Tensor& input,
- int message_index,
- int size) {
+size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint64>(
+ const Tensor& input, int message_index, int size) {
return size * WireFormatLite::kFixed32Size;
}
template <>
-size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int32>(const Tensor& input,
- int message_index,
- int size) {
+size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint32>(
+ const Tensor& input, int message_index, int size) {
return size * WireFormatLite::kFixed32Size;
}
@@ -137,11 +148,11 @@ size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input,
}
template <>
-size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int64>(const Tensor& input,
- int message_index,
- int size) {
+size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint64>(const Tensor& input,
+ int message_index,
+ int size) {
size_t data_size = 0;
- auto input_t = input.flat_inner_dims<int64>();
+ auto input_t = input.flat_inner_dims<uint64>();
for (int64 i = 0; i < size; i++) {
data_size += WireFormatLite::UInt32Size(
input_t(static_cast<int64>(message_index), i));
@@ -150,11 +161,11 @@ size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int64>(const Tensor& input,
}
template <>
-size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int32>(const Tensor& input,
- int message_index,
- int size) {
+size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint32>(const Tensor& input,
+ int message_index,
+ int size) {
size_t data_size = 0;
- auto input_t = input.flat_inner_dims<int32>();
+ auto input_t = input.flat_inner_dims<uint32>();
for (int64 i = 0; i < size; i++) {
data_size += WireFormatLite::UInt32Size(
input_t(static_cast<int64>(message_index), i));
@@ -182,6 +193,12 @@ size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>(
}
template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int64>(
+ const Tensor& input, int message_index, int size) {
+ return size * WireFormatLite::kSFixed32Size;
+}
+
+template <>
size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64>(
const Tensor& input, int message_index, int size) {
return size * WireFormatLite::kSFixed64Size;
@@ -201,6 +218,19 @@ size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input,
}
template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int64>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int64>();
+ for (int64 i = 0; i < size; i++) {
+ data_size += WireFormatLite::SInt32Size(
+ input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
+
+template <>
size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input,
int message_index,
int size) {
@@ -213,14 +243,13 @@ size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input,
return data_size;
}
-// Writes a possibly repeated primitive field.
-// TensorFlow does not have unsigned types, so we decode them to signed and
-// encode them back to unsigned.
+// Writes a possibly repeated primitive field. TensorFlow does not have unsigned
+// types, so we decode them to signed and encode them back to unsigned.
template <typename TensorT, typename ProtoT,
WireFormatLite::FieldType FieldType,
void Writer(ProtoT, CodedOutputStream*)>
-void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
- int message_index, int size, CodedOutputStream* output) {
+Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
+ int message_index, int size, CodedOutputStream* output) {
auto wire_type = WireFormatLite::WireTypeForFieldType(
WireFormatLite::FieldType(field_desc.type()));
@@ -250,12 +279,14 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
Writer(value, output);
}
}
+ return Status::OK();
}
// Writes a possibly repeated string, bytes, or message field.
template <typename T, void Writer(int, const T&, CodedOutputStream*)>
-void WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
- int message_index, int size, CodedOutputStream* output) {
+Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
+ int message_index, int size,
+ CodedOutputStream* output) {
auto input_t = input.flat_inner_dims<T>();
for (int64 i = 0; i < size; i++) {
const T& value = input_t(static_cast<int64>(message_index), i);
@@ -264,14 +295,14 @@ void WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
// small speedup.
Writer(field_desc.number(), value, output);
}
+ return Status::OK();
}
-// Writes a group field.
-// Groups are treated like submessages, but tag-delimited
-// instead of length-delimited. WireFormatLite handles this
-// differently so we code it ourselves.
-void WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
- int message_index, int size, CodedOutputStream* output) {
+// Writes a group field. Groups are treated like submessages, but tag-delimited
+// instead of length-delimited. WireFormatLite handles this differently so we
+// code it ourselves.
+Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
+ int message_index, int size, CodedOutputStream* output) {
auto input_t = input.flat_inner_dims<string>();
for (int64 i = 0; i < size; i++) {
const string& value = input_t(static_cast<int64>(message_index), i);
@@ -282,16 +313,16 @@ void WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
WireFormatLite::WriteTag(field_desc.number(),
WireFormatLite::WIRETYPE_END_GROUP, output);
}
+ return Status::OK();
}
-// Writes a (possibly repeated) field into an output stream.
-// It is the caller's responsibility to ensure that the type of
-// the input tensor is compatible with the type of the proto
-// field descriptor, and that (message_index, size-1) is within
-// bounds.
-void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
- int message_index, int size, CodedOutputStream* output) {
- DataType tf_type = input.dtype();
+// Writes a (possibly repeated) field into an output stream. It is the caller's
+// responsibility to ensure that the type of the input tensor is compatible with
+// the type of the proto field descriptor, and that (message_index, size-1) is
+// within bounds.
+Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
+ int message_index, int size, CodedOutputStream* output) {
+ DataType dtype = input.dtype();
switch (field_desc.type()) {
case WireFormatLite::TYPE_DOUBLE:
@@ -299,7 +330,7 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
WireFormatLite::WriteDoubleNoTag>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_FLOAT:
- switch (tf_type) {
+ switch (dtype) {
case DataType::DT_FLOAT:
return WriteField<float, float, WireFormatLite::TYPE_FLOAT,
WireFormatLite::WriteFloatNoTag>(
@@ -309,36 +340,48 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
WireFormatLite::WriteFloatNoTag>(
field_desc, input, message_index, size, output);
default:
- return;
+ return errors::DataLoss("Failed writing TYPE_FLOAT for ",
+ DataTypeString(dtype));
}
case WireFormatLite::TYPE_INT64:
return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_INT64,
WireFormatLite::WriteInt64NoTag>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_UINT64:
- return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_UINT64,
+ return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_UINT64,
WireFormatLite::WriteUInt64NoTag>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_INT32:
- return WriteField<int32, int32, WireFormatLite::TYPE_INT32,
- WireFormatLite::WriteInt32NoTag>(
- field_desc, input, message_index, size, output);
+ switch (dtype) {
+ case DataType::DT_INT64:
+ return WriteField<int64, int32, WireFormatLite::TYPE_INT32,
+ WireFormatLite::WriteInt32NoTag>(
+ field_desc, input, message_index, size, output);
+ case DataType::DT_INT32:
+ return WriteField<int32, int32, WireFormatLite::TYPE_INT32,
+ WireFormatLite::WriteInt32NoTag>(
+ field_desc, input, message_index, size, output);
+ default:
+ return errors::DataLoss("Failed writing TYPE_INT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_FIXED64:
- return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_FIXED64,
+ return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_FIXED64,
WireFormatLite::WriteFixed64NoTag>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_FIXED32:
- switch (tf_type) {
- case DataType::DT_INT64:
- return WriteField<int64, uint32, WireFormatLite::TYPE_FIXED32,
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ return WriteField<uint64, uint32, WireFormatLite::TYPE_FIXED32,
WireFormatLite::WriteFixed32NoTag>(
field_desc, input, message_index, size, output);
- case DataType::DT_INT32:
- return WriteField<int32, uint32, WireFormatLite::TYPE_FIXED32,
+ case DataType::DT_UINT32:
+ return WriteField<uint32, uint32, WireFormatLite::TYPE_FIXED32,
WireFormatLite::WriteFixed32NoTag>(
field_desc, input, message_index, size, output);
default:
- return;
+ return errors::DataLoss("Failed writing TYPE_FIXED32 for ",
+ DataTypeString(dtype));
}
case WireFormatLite::TYPE_BOOL:
return WriteField<bool, bool, WireFormatLite::TYPE_BOOL,
@@ -356,34 +399,55 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
return WriteVarLenField<string, WireFormatLite::WriteBytes>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_UINT32:
- switch (tf_type) {
- case DataType::DT_INT64:
- return WriteField<int64, uint32, WireFormatLite::TYPE_UINT32,
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ return WriteField<uint64, uint32, WireFormatLite::TYPE_UINT32,
WireFormatLite::WriteUInt32NoTag>(
field_desc, input, message_index, size, output);
- case DataType::DT_INT32:
- return WriteField<int32, uint32, WireFormatLite::TYPE_UINT32,
+ case DataType::DT_UINT32:
+ return WriteField<uint32, uint32, WireFormatLite::TYPE_UINT32,
WireFormatLite::WriteUInt32NoTag>(
field_desc, input, message_index, size, output);
default:
- return;
+ return errors::DataLoss("Failed writing TYPE_UINT32 for ",
+ DataTypeString(dtype));
}
case WireFormatLite::TYPE_ENUM:
return WriteField<int32, int32, WireFormatLite::TYPE_ENUM,
WireFormatLite::WriteEnumNoTag>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_SFIXED32:
- return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32,
- WireFormatLite::WriteSFixed32NoTag>(
- field_desc, input, message_index, size, output);
+ switch (dtype) {
+ case DataType::DT_INT64:
+ return WriteField<int64, int32, WireFormatLite::TYPE_SFIXED32,
+ WireFormatLite::WriteSFixed32NoTag>(
+ field_desc, input, message_index, size, output);
+ case DataType::DT_INT32:
+ return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32,
+ WireFormatLite::WriteSFixed32NoTag>(
+ field_desc, input, message_index, size, output);
+ default:
+ return errors::DataLoss("Failed writing TYPE_SFIXED32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SFIXED64:
return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SFIXED64,
WireFormatLite::WriteSFixed64NoTag>(
field_desc, input, message_index, size, output);
case WireFormatLite::TYPE_SINT32:
- return WriteField<int32, int32, WireFormatLite::TYPE_SINT32,
- WireFormatLite::WriteSInt32NoTag>(
- field_desc, input, message_index, size, output);
+ switch (dtype) {
+ case DataType::DT_INT64:
+ return WriteField<int64, int32, WireFormatLite::TYPE_SINT32,
+ WireFormatLite::WriteSInt32NoTag>(
+ field_desc, input, message_index, size, output);
+ case DataType::DT_INT32:
+ return WriteField<int32, int32, WireFormatLite::TYPE_SINT32,
+ WireFormatLite::WriteSInt32NoTag>(
+ field_desc, input, message_index, size, output);
+ default:
+ return errors::DataLoss("Failed writing TYPE_SINT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SINT64:
return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SINT64,
WireFormatLite::WriteSInt64NoTag>(
@@ -392,42 +456,6 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
}
}
-// Checks that a Protobuf field is compatible with a TensorFlow datatype.
-// This is separated from WriteField to lift it out of the inner loop.
-bool IsCompatibleType(const FieldDescriptor& field_desc, DataType tf_type) {
- switch (field_desc.type()) {
- case WireFormatLite::TYPE_DOUBLE:
- return tf_type == DataType::DT_DOUBLE;
- case WireFormatLite::TYPE_FLOAT:
- return tf_type == DataType::DT_FLOAT || tf_type == DataType::DT_DOUBLE;
- case WireFormatLite::TYPE_INT64:
- case WireFormatLite::TYPE_SFIXED64:
- case WireFormatLite::TYPE_SINT64:
- return tf_type == DataType::DT_INT64;
- case WireFormatLite::TYPE_UINT64:
- return tf_type == DataType::DT_INT64;
- case WireFormatLite::TYPE_INT32:
- case WireFormatLite::TYPE_ENUM:
- case WireFormatLite::TYPE_SFIXED32:
- case WireFormatLite::TYPE_SINT32:
- return tf_type == DataType::DT_INT32;
- case WireFormatLite::TYPE_FIXED64:
- return tf_type == DataType::DT_INT64;
- case WireFormatLite::TYPE_FIXED32:
- case WireFormatLite::TYPE_UINT32:
- return tf_type == DataType::DT_INT64 || tf_type == DataType::DT_INT32;
- case WireFormatLite::TYPE_BOOL:
- return tf_type == DataType::DT_BOOL;
- case WireFormatLite::TYPE_STRING:
- case WireFormatLite::TYPE_GROUP:
- case WireFormatLite::TYPE_MESSAGE:
- case WireFormatLite::TYPE_BYTES:
- return tf_type == DataType::DT_STRING;
- // default: intentionally omitted in order to enable static checking.
- }
- return false;
-}
-
class EncodeProtoOp : public OpKernel {
public:
explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -475,14 +503,14 @@ class EncodeProtoOp : public OpKernel {
});
}
- void Compute(OpKernelContext* cx) override {
+ void Compute(OpKernelContext* ctx) override {
const Tensor* sizes_tensor;
- OP_REQUIRES_OK(cx, cx->input("sizes", &sizes_tensor));
+ OP_REQUIRES_OK(ctx, ctx->input("sizes", &sizes_tensor));
OpInputList values;
- OP_REQUIRES_OK(cx, cx->input_list("values", &values));
+ OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
- OP_REQUIRES(cx, field_descs_.size() == values.size(),
+ OP_REQUIRES(ctx, field_descs_.size() == values.size(),
errors::InvalidArgument(
"Length of inputs list must match field_names"));
@@ -493,12 +521,14 @@ class EncodeProtoOp : public OpKernel {
const Tensor& v = values[i];
// The type of each value tensor must match the corresponding field.
- OP_REQUIRES(cx, IsCompatibleType(*field_descs_[i], v.dtype()),
- errors::InvalidArgument(
- "Incompatible type for field " + field_names_[i] +
- ". Saw dtype: ",
- DataTypeString(v.dtype()),
- " but field type is: ", field_descs_[i]->type_name()));
+ OP_REQUIRES(
+ ctx,
+ proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()),
+ errors::InvalidArgument(
+ "Incompatible type for field " + field_names_[i] +
+ ". Saw dtype: ",
+ DataTypeString(v.dtype()),
+ " but field type is: ", field_descs_[i]->type_name()));
// All value tensors must have the same shape prefix (i.e. batch size).
TensorShape shape_prefix = v.shape();
@@ -507,14 +537,14 @@ class EncodeProtoOp : public OpKernel {
// Do some initialization on the first input value. The rest will
// have to match this one.
if (i == 0) {
- OP_REQUIRES(cx, v.dims() >= 1,
+ OP_REQUIRES(ctx, v.dims() >= 1,
errors::InvalidArgument(
"Expected value to be at least a vector, saw shape: ",
v.shape().DebugString()));
common_prefix = shape_prefix;
message_count = common_prefix.num_elements();
} else {
- OP_REQUIRES(cx, shape_prefix == common_prefix,
+ OP_REQUIRES(ctx, shape_prefix == common_prefix,
errors::InvalidArgument(
"Values must match up to the last dimension"));
}
@@ -523,7 +553,7 @@ class EncodeProtoOp : public OpKernel {
TensorShape expected_sizes_shape = common_prefix;
expected_sizes_shape.AddDim(field_descs_.size());
- OP_REQUIRES(cx, sizes_tensor->shape() == expected_sizes_shape,
+ OP_REQUIRES(ctx, sizes_tensor->shape() == expected_sizes_shape,
errors::InvalidArgument(
"sizes should be batch_size + [len(field_names)]. Saw: ",
sizes_tensor->shape().DebugString(),
@@ -536,12 +566,11 @@ class EncodeProtoOp : public OpKernel {
int max_size = v.dim_size(v.dims() - 1);
// The last dimension of a value tensor must be greater than the
- // corresponding
- // size in the sizes tensor.
+ // corresponding size in the sizes tensor.
for (int message_index = 0; message_index < message_count;
message_index++) {
OP_REQUIRES(
- cx, sizes(message_index, i) <= max_size,
+ ctx, sizes(message_index, i) <= max_size,
errors::InvalidArgument(
"Size to write must not be larger than value tensor; but saw: ",
sizes(message_index, i), " > ", max_size, " at message ",
@@ -551,13 +580,13 @@ class EncodeProtoOp : public OpKernel {
// This pointer is owned by the context.
Tensor* output_tensor;
- OP_REQUIRES_OK(cx, cx->allocate_output(0, common_prefix, &output_tensor));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, common_prefix, &output_tensor));
auto bufs = output_tensor->flat<string>();
for (int message_index = 0; message_index < message_count;
message_index++) {
// TODO(nix): possibly optimize allocation here by calling
- // bufs(message_index).reserve(DEFAULT_BUF_SIZE);
+ // `bufs(message_index).reserve(DEFAULT_BUF_SIZE)`.
StringOutputStream output_string(&bufs(message_index));
CodedOutputStream out(&output_string);
// Write fields in ascending field_number order.
@@ -566,7 +595,8 @@ class EncodeProtoOp : public OpKernel {
const Tensor& v = values[i];
int size = sizes(message_index, i);
if (!size) continue;
- WriteField(field_desc, v, message_index, size, &out);
+ OP_REQUIRES_OK(ctx,
+ WriteField(field_desc, v, message_index, size, &out));
}
}
}
@@ -578,8 +608,8 @@ class EncodeProtoOp : public OpKernel {
// Owned_desc_pool_ is null when using descriptor_source=local.
std::unique_ptr<DescriptorPool> owned_desc_pool_;
- // Contains indices into field_names_, sorted by field number since
- // that's the order of writing.
+ // Contains indices into field_names_, sorted by field number since that's the
+ // order of writing.
std::vector<int> sorted_field_index_;
TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp);
diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc
index dffb4d7171..6f79729883 100644
--- a/tensorflow/core/kernels/identity_op.cc
+++ b/tensorflow/core/kernels/identity_op.cc
@@ -145,6 +145,7 @@ REGISTER_GPU_KERNEL(Variant);
REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
REGISTER_GPU_HOST_KERNEL(string);
+REGISTER_GPU_HOST_KERNEL(ResourceHandle);
#undef REGISTER_GPU_HOST_KERNEL
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 57b7798ba0..07e754a6ef 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -822,6 +822,7 @@ REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(string, string);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int32, string);
#undef REGISTER_KERNEL
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index e1fc2ea128..c44753e25e 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -277,6 +277,9 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
TF_CALL_string(REGISTER_SCATTER_ND_CPU);
+TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
+TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
+TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA
@@ -309,6 +312,7 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_int32(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_int32(REGISTER_SCATTER_ND_UPDATE_SYCL);
+TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
diff --git a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
index 7cfffa20c5..472f5a3547 100644
--- a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
@@ -161,15 +161,16 @@ struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
REGISTER_SCATTER_ND_INDEX(string, scatter_nd_op::UpdateOp::ADD);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH)
-
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH);
+TF_CALL_bool(REGISTER_SCATTER_ND_MATH);
#undef REGISTER_SCATTER_ND_MATH
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_INDEX
#undef REGISTER_SCATTER_ND_FULL
-#ifdef TENSORFLOW_USE_SYCL
// Implementation of update functor for SYCL.
+#ifdef TENSORFLOW_USE_SYCL
+
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> {
Index operator()(
diff --git a/tensorflow/core/lib/core/refcount.h b/tensorflow/core/lib/core/refcount.h
index eb41f9ff36..87bcfec411 100644
--- a/tensorflow/core/lib/core/refcount.h
+++ b/tensorflow/core/lib/core/refcount.h
@@ -17,6 +17,8 @@ limitations under the License.
#define TENSORFLOW_LIB_CORE_REFCOUNT_H_
#include <atomic>
+#include <memory>
+
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -58,6 +60,15 @@ class RefCounted {
void operator=(const RefCounted&) = delete;
};
+// A deleter class to form a std::unique_ptr that unrefs objects.
+struct RefCountDeleter {
+ void operator()(tensorflow::core::RefCounted* o) const { o->Unref(); }
+};
+
+// A unique_ptr that unrefs the owned object on destruction.
+template <typename T>
+using RefCountPtr = std::unique_ptr<T, RefCountDeleter>;
+
// Helper class to unref an object when out-of-scope.
class ScopedUnref {
public:
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
index 95ac040602..c36c909399 100644
--- a/tensorflow/core/lib/io/record_reader_writer_test.cc
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
+#include <zlib.h>
#include <vector>
#include "tensorflow/core/platform/env.h"
@@ -33,6 +34,89 @@ static std::vector<int> BufferSizes() {
12, 13, 14, 15, 16, 17, 18, 19, 20, 65536};
}
+namespace {
+
+io::RecordReaderOptions GetMatchingReaderOptions(
+ const io::RecordWriterOptions& options) {
+ if (options.compression_type == io::RecordWriterOptions::ZLIB_COMPRESSION) {
+ return io::RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
+ }
+ return io::RecordReaderOptions::CreateRecordReaderOptions("");
+}
+
+uint64 GetFileSize(const string& fname) {
+ Env* env = Env::Default();
+ uint64 fsize;
+ TF_CHECK_OK(env->GetFileSize(fname, &fsize));
+ return fsize;
+}
+
+void VerifyFlush(const io::RecordWriterOptions& options) {
+ std::vector<string> records = {
+ "abcdefghijklmnopqrstuvwxyz",
+ "ZYXWVUTSRQPONMLKJIHGFEDCBA0123456789!@#$%^&*()",
+ "G5SyohOL9UmXofSOOwWDrv9hoLLMYPJbG9r38t3uBRcHxHj2PdKcPDuZmKW62RIY",
+ "aaaaaaaaaaaaaaaaaaaaaaaaaa",
+ };
+
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/record_reader_writer_flush_test";
+
+ std::unique_ptr<WritableFile> file;
+ TF_CHECK_OK(env->NewWritableFile(fname, &file));
+ io::RecordWriter writer(file.get(), options);
+
+ std::unique_ptr<RandomAccessFile> read_file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &read_file));
+ io::RecordReaderOptions read_options = GetMatchingReaderOptions(options);
+ io::RecordReader reader(read_file.get(), read_options);
+
+ EXPECT_EQ(GetFileSize(fname), 0);
+ for (size_t i = 0; i < records.size(); i++) {
+ uint64 start_size = GetFileSize(fname);
+
+ // Write a new record.
+ TF_EXPECT_OK(writer.WriteRecord(records[i]));
+ TF_CHECK_OK(writer.Flush());
+ TF_CHECK_OK(file->Flush());
+
+ // Verify that file size has changed after file flush.
+ uint64 new_size = GetFileSize(fname);
+ EXPECT_GT(new_size, start_size);
+
+ // Verify that file has all records written so far and no more.
+ uint64 offset = 0;
+ string record;
+ for (size_t j = 0; j <= i; j++) {
+ // Check that j'th record is written correctly.
+ TF_CHECK_OK(reader.ReadRecord(&offset, &record));
+ EXPECT_EQ(record, records[j]);
+ }
+
+ // Verify that file has no more records.
+ CHECK_EQ(reader.ReadRecord(&offset, &record).code(), error::OUT_OF_RANGE);
+ }
+}
+
+} // namespace
+
+TEST(RecordReaderWriterTest, TestFlush) {
+ io::RecordWriterOptions options;
+ VerifyFlush(options);
+}
+
+TEST(RecordReaderWriterTest, TestZlibSyncFlush) {
+ io::RecordWriterOptions options;
+ options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION;
+ // The default flush_mode is Z_NO_FLUSH and only writes to the file when the
+ // buffer is full or the file is closed, which makes testing harder.
+ // By using Z_SYNC_FLUSH the test can verify Flush does write out records of
+ // approximately the right size at the right times.
+ options.zlib_options.flush_mode = Z_SYNC_FLUSH;
+
+ VerifyFlush(options);
+}
+
TEST(RecordReaderWriterTest, TestBasics) {
Env* env = Env::Default();
string fname = testing::TmpDir() + "/record_reader_writer_test";
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index fce0b93cd7..d6ae75473f 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2549,14 +2549,16 @@ REGISTER_OP("ExtractImagePatches")
REGISTER_OP("Bitcast")
.Input("input: T")
.Output("output: type")
- // All supported dtypes are listed here to include qint16 and quint16.
+ // All supported dtypes are listed here to include qint16, quint16, uint32,
+ // and uint64.
.Attr(
- "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, int8, "
- "int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32}")
+ "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
+ "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
+ "qint16, quint16, qint32}")
.Attr(
"type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
- "int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, "
- "qint32}")
+ "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
+ "qint16, quint16, qint32}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
if (!c->RankKnown(input)) {
@@ -2879,7 +2881,7 @@ REGISTER_OP("ScatterNdNonAliasingAdd")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output: T")
- .Attr("T: numbertype")
+ .Attr("T: {numbertype, bool}")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index d94fa2cad7..69351cd392 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -11046,6 +11046,71 @@ op {
}
}
op {
+ name: "Bitcast"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_UINT32
+ type: DT_UINT64
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT16
+ type: DT_QUINT16
+ type: DT_QINT32
+ }
+ }
+ }
+ attr {
+ name: "type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_UINT32
+ type: DT_UINT64
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT16
+ type: DT_QUINT16
+ type: DT_QINT32
+ }
+ }
+ }
+}
+op {
name: "BitwiseAnd"
input_arg {
name: "x"
@@ -55178,6 +55243,61 @@ op {
}
}
op {
+ name: "ScatterNdNonAliasingAdd"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ type: DT_BOOL
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "ScatterNdSub"
input_arg {
name: "ref"
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 4f24ab480f..978bb0bbf4 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4132,6 +4132,8 @@ op {
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
+ type: DT_UINT32
+ type: DT_UINT64
type: DT_INT8
type: DT_INT16
type: DT_COMPLEX64
@@ -4157,6 +4159,8 @@ op {
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
+ type: DT_UINT32
+ type: DT_UINT64
type: DT_INT8
type: DT_INT16
type: DT_COMPLEX64
@@ -26178,6 +26182,7 @@ op {
type: DT_HALF
type: DT_UINT32
type: DT_UINT64
+ type: DT_BOOL
}
}
}
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 77639461d9..22a2691dcc 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -145,7 +145,8 @@ message GPUOptions {
bool use_unified_memory = 2;
// If > 1, the number of device-to-device copy streams to create
- // for each GPUDevice.
+ // for each GPUDevice. Default value is 0, which is automatically
+ // converted to 1.
int32 num_dev_to_dev_copy_streams = 3;
}
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index cb1fd09dbb..cea5e8ffb0 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -24,7 +24,7 @@ limitations under the License.
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX "-rc0"
+#define TF_VERSION_SUFFIX ""
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/proto/BUILD b/tensorflow/core/util/proto/BUILD
index ade14ed162..7e549c7764 100644
--- a/tensorflow/core/util/proto/BUILD
+++ b/tensorflow/core/util/proto/BUILD
@@ -60,3 +60,13 @@ cc_library(
],
alwayslink = 1,
)
+
+cc_library(
+ name = "proto_utils",
+ srcs = ["proto_utils.cc"],
+ hdrs = ["proto_utils.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
diff --git a/tensorflow/core/util/proto/decode.h b/tensorflow/core/util/proto/decode.h
index 74634a356a..cbcb203ee7 100644
--- a/tensorflow/core/util/proto/decode.h
+++ b/tensorflow/core/util/proto/decode.h
@@ -27,6 +27,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -103,6 +104,16 @@ template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
const uint8* ReadFromArray(const uint8* buf, TensorType* value);
template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT32>(
+ const uint8* buf, int64* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = static_cast<int64>(temp);
+ return buf;
+}
+
+template <>
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
const uint8* buf, int32* value) {
uint32 temp;
@@ -123,8 +134,8 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
}
template <>
-inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>(
- const uint8* buf, int64* value) {
+inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT32>(
+ const uint8* buf, uint64* value) {
uint32 temp;
bool unused_ok; // The Counting pass would have failed if this were corrupt.
buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
@@ -133,22 +144,26 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>(
}
template <>
-inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_UINT32>(
- const uint8* buf, int32* value) {
- uint32 temp;
+inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_UINT32>(
+ const uint8* buf, uint32* value) {
bool unused_ok; // The Counting pass would have failed if this were corrupt.
- buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
- *value = WrapUnsignedAsSigned32(temp);
- return buf;
+ return ReadVarint32FromArray(buf, &unused_ok, value);
+}
+
+template <>
+inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT64>(
+ const uint8* buf, uint64* value) {
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ return ReadVarint64FromArray(buf, &unused_ok, value);
}
template <>
-inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT64>(
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT32>(
const uint8* buf, int64* value) {
uint64 temp;
bool unused_ok; // The Counting pass would have failed if this were corrupt.
buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
- *value = static_cast<int64>(temp);
+ *value = WireFormatLite::ZigZagDecode32(temp);
return buf;
}
@@ -173,8 +188,8 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
}
template <>
-inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>(
- const uint8* buf, int64* value) {
+inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED32>(
+ const uint8* buf, uint64* value) {
uint32 temp;
buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
WireFormatLite::TYPE_FIXED32>(
@@ -184,8 +199,8 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>(
}
template <>
-inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>(
- const uint8* buf, int32* value) {
+inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_FIXED32>(
+ const uint8* buf, uint32* value) {
uint32 temp;
buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
WireFormatLite::TYPE_FIXED32>(
@@ -195,8 +210,8 @@ inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>(
}
template <>
-inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>(
- const uint8* buf, int64* value) {
+inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED64>(
+ const uint8* buf, uint64* value) {
protobuf_uint64 temp;
buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
WireFormatLite::TYPE_FIXED64>(
@@ -206,6 +221,17 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>(
}
template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED32>(
+ const uint8* buf, int64* value) {
+ int32 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<int32,
+ WireFormatLite::TYPE_SFIXED32>(
+ buf, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
const uint8* buf, int32* value) {
return WireFormatLite::ReadPrimitiveFromArray<int32,
@@ -233,6 +259,17 @@ inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
}
template <>
+inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_FLOAT>(
+ const uint8* buf, double* value) {
+ float temp;
+ buf =
+ WireFormatLite::ReadPrimitiveFromArray<float, WireFormatLite::TYPE_FLOAT>(
+ buf, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
const uint8* buf, double* value) {
return WireFormatLite::ReadPrimitiveFromArray<double,
@@ -334,48 +371,56 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
inline Status ReadValue(CodedInputStream* input,
WireFormatLite::FieldType field_type, int field_number,
DataType dtype, int index, void* datap) {
- // Dispatch to the appropriately typed field reader based on the
- // schema type.
+ // Dispatch to the appropriately typed field reader based on the schema type.
switch (field_type) {
case WireFormatLite::TYPE_DOUBLE:
return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
input, index, datap);
case WireFormatLite::TYPE_FLOAT:
- if (dtype == DataType::DT_FLOAT) {
- return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
- input, index, datap);
- }
- if (dtype == DataType::DT_DOUBLE) {
- return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_DOUBLE:
+ return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
+ input, index, datap);
+ case DataType::DT_FLOAT:
+ return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_FLOAT for ",
+ DataTypeString(dtype));
}
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_FLOAT");
case WireFormatLite::TYPE_INT64:
return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
input, index, datap);
case WireFormatLite::TYPE_UINT64:
- return ReadPrimitive<protobuf_uint64, int64, WireFormatLite::TYPE_UINT64>(
- input, index, datap);
+ return ReadPrimitive<protobuf_uint64, uint64,
+ WireFormatLite::TYPE_UINT64>(input, index, datap);
case WireFormatLite::TYPE_INT32:
- return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_INT64:
+ return ReadPrimitive<int32, int64, WireFormatLite::TYPE_INT32>(
+ input, index, datap);
+ case DataType::DT_INT32:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_INT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_FIXED64:
- return ReadPrimitive<protobuf_uint64, int64,
+ return ReadPrimitive<protobuf_uint64, uint64,
WireFormatLite::TYPE_FIXED64>(input, index, datap);
case WireFormatLite::TYPE_FIXED32:
- if (dtype == DataType::DT_INT64) {
- return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_FIXED32>(
- input, index, datap);
- }
- if (dtype == DataType::DT_INT32) {
- return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_FIXED32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_FIXED32>(
+ input, index, datap);
+ case DataType::DT_UINT32:
+ return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_FIXED32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
+ DataTypeString(dtype));
}
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_FIXED32");
case WireFormatLite::TYPE_BOOL:
return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
datap);
@@ -388,29 +433,47 @@ inline Status ReadValue(CodedInputStream* input,
case WireFormatLite::TYPE_BYTES:
return ReadBytes(input, index, datap);
case WireFormatLite::TYPE_UINT32:
- if (dtype == DataType::DT_INT64) {
- return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_UINT32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_UINT32>(
+ input, index, datap);
+ case DataType::DT_UINT32:
+ return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_UINT32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_UINT32 for ",
+ DataTypeString(dtype));
}
- if (dtype == DataType::DT_INT32) {
- return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_UINT32>(
- input, index, datap);
- }
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_UINT32");
case WireFormatLite::TYPE_ENUM:
return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
input, index, datap);
case WireFormatLite::TYPE_SFIXED32:
- return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_INT64:
+ return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SFIXED32>(
+ input, index, datap);
+ case DataType::DT_INT32:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_SFIXED32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SFIXED64:
return ReadPrimitive<protobuf_int64, int64,
WireFormatLite::TYPE_SFIXED64>(input, index, datap);
case WireFormatLite::TYPE_SINT32:
- return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_INT64:
+ return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SINT32>(
+ input, index, datap);
+ case DataType::DT_INT32:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_SINT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SINT64:
return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
input, index, datap);
@@ -425,47 +488,66 @@ inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
const WireFormatLite::FieldType field_type,
const int field_number, const DataType dtype,
const int stride, int* index, void* data) {
- // Dispatch to the appropriately typed field reader based on the
- // schema type.
+ // Dispatch to the appropriately typed field reader based on the schema type.
switch (field_type) {
case WireFormatLite::TYPE_DOUBLE:
*index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_FLOAT:
- *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
+ switch (dtype) {
+ case DataType::DT_DOUBLE:
+ *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_FLOAT>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_FLOAT:
+ *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_FLOAT for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_INT64:
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_UINT64:
- *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT64>(
+ *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT64>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_INT32:
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
+ switch (dtype) {
+ case DataType::DT_INT64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_INT32:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_INT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_FIXED64:
- *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED64>(
+ *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED64>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_FIXED32:
- if (dtype == DataType::DT_INT64) {
- *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
- }
- if (dtype == DataType::DT_INT32) {
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_FIXED32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_UINT32:
+ *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_FIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
+ DataTypeString(dtype));
}
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_FIXED32");
case WireFormatLite::TYPE_BOOL:
*index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
buf, buf_size, *index, stride, data);
@@ -476,38 +558,56 @@ inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
case WireFormatLite::TYPE_BYTES:
return errors::DataLoss("Non-primitive type encountered as packed");
case WireFormatLite::TYPE_UINT32:
- if (dtype == DataType::DT_INT64) {
- *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_UINT32:
+ *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_UINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_UINT32 for ",
+ DataTypeString(dtype));
}
- if (dtype == DataType::DT_INT32) {
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_UINT32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
- }
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_UINT32");
case WireFormatLite::TYPE_ENUM:
*index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_SFIXED32:
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
-
+ switch (dtype) {
+ case DataType::DT_INT64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_INT32:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_INT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SFIXED64:
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_SINT32:
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
-
+ switch (dtype) {
+ case DataType::DT_INT64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_INT32:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_SINT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SINT64:
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
buf, buf_size, *index, stride, data);
diff --git a/tensorflow/core/util/proto/proto_utils.cc b/tensorflow/core/util/proto/proto_utils.cc
new file mode 100644
index 0000000000..201f05a129
--- /dev/null
+++ b/tensorflow/core/util/proto/proto_utils.cc
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+#include "tensorflow/core/util/proto/proto_utils.h"
+
+namespace tensorflow {
+namespace proto_utils {
+
+using tensorflow::protobuf::FieldDescriptor;
+using tensorflow::protobuf::internal::WireFormatLite;
+
+bool IsCompatibleType(FieldDescriptor::Type field_type, DataType dtype) {
+ switch (field_type) {
+ case WireFormatLite::TYPE_DOUBLE:
+ return dtype == tensorflow::DT_DOUBLE;
+ case WireFormatLite::TYPE_FLOAT:
+ return dtype == tensorflow::DT_FLOAT || dtype == tensorflow::DT_DOUBLE;
+ case WireFormatLite::TYPE_INT64:
+ return dtype == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_UINT64:
+ return dtype == tensorflow::DT_UINT64;
+ case WireFormatLite::TYPE_INT32:
+ return dtype == tensorflow::DT_INT32 || dtype == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_FIXED64:
+ return dtype == tensorflow::DT_UINT64;
+ case WireFormatLite::TYPE_FIXED32:
+ return dtype == tensorflow::DT_UINT32 || dtype == tensorflow::DT_UINT64;
+ case WireFormatLite::TYPE_BOOL:
+ return dtype == tensorflow::DT_BOOL;
+ case WireFormatLite::TYPE_STRING:
+ return dtype == tensorflow::DT_STRING;
+ case WireFormatLite::TYPE_GROUP:
+ return dtype == tensorflow::DT_STRING;
+ case WireFormatLite::TYPE_MESSAGE:
+ return dtype == tensorflow::DT_STRING;
+ case WireFormatLite::TYPE_BYTES:
+ return dtype == tensorflow::DT_STRING;
+ case WireFormatLite::TYPE_UINT32:
+ return dtype == tensorflow::DT_UINT32 || dtype == tensorflow::DT_UINT64;
+ case WireFormatLite::TYPE_ENUM:
+ return dtype == tensorflow::DT_INT32;
+ case WireFormatLite::TYPE_SFIXED32:
+ return dtype == tensorflow::DT_INT32 || dtype == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_SFIXED64:
+ return dtype == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_SINT32:
+ return dtype == tensorflow::DT_INT32 || dtype == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_SINT64:
+ return dtype == tensorflow::DT_INT64;
+ // default: intentionally omitted in order to enable static checking.
+ }
+}
+
+} // namespace proto_utils
+} // namespace tensorflow
diff --git a/tensorflow/core/util/proto/proto_utils.h b/tensorflow/core/util/proto/proto_utils.h
new file mode 100644
index 0000000000..d5e0b9006c
--- /dev/null
+++ b/tensorflow/core/util/proto/proto_utils.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
+#define TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace proto_utils {
+
+using tensorflow::protobuf::FieldDescriptor;
+
+// Returns true if the proto field type can be converted to the tensor dtype.
+bool IsCompatibleType(FieldDescriptor::Type field_type, DataType dtype);
+
+} // namespace proto_utils
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
index 42ad9652f8..3b54d6d2bb 100644
--- a/tensorflow/docs_src/guide/eager.md
+++ b/tensorflow/docs_src/guide/eager.md
@@ -504,13 +504,13 @@ with tf.device("gpu:0"):
### Object-based saving
-`tfe.Checkpoint` can save and restore `tf.Variable`s to and from
+`tf.train.Checkpoint` can save and restore `tf.Variable`s to and from
checkpoints:
```py
x = tf.Variable(10.)
-checkpoint = tfe.Checkpoint(x=x) # save as "x"
+checkpoint = tf.train.Checkpoint(x=x) # save as "x"
x.assign(2.) # Assign a new value to the variables and save.
save_path = checkpoint.save('./ckpt/')
@@ -523,18 +523,18 @@ checkpoint.restore(save_path)
print(x) # => 2.0
```
-To save and load models, `tfe.Checkpoint` stores the internal state of objects,
+To save and load models, `tf.train.Checkpoint` stores the internal state of objects,
without requiring hidden variables. To record the state of a `model`,
-an `optimizer`, and a global step, pass them to a `tfe.Checkpoint`:
+an `optimizer`, and a global step, pass them to a `tf.train.Checkpoint`:
```py
model = MyModel()
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
checkpoint_dir = ‘/path/to/model_dir’
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
-root = tfe.Checkpoint(optimizer=optimizer,
- model=model,
- optimizer_step=tf.train.get_or_create_global_step())
+root = tf.train.Checkpoint(optimizer=optimizer,
+ model=model,
+ optimizer_step=tf.train.get_or_create_global_step())
root.save(file_prefix=checkpoint_prefix)
# or
@@ -824,7 +824,7 @@ gives you eager's interactive experimentation and debuggability with the
distributed performance benefits of graph execution.
Write, debug, and iterate in eager execution, then import the model graph for
-production deployment. Use `tfe.Checkpoint` to save and restore model
+production deployment. Use `tf.train.Checkpoint` to save and restore model
variables, this allows movement between eager and graph execution environments.
See the examples in:
[tensorflow/contrib/eager/python/examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples).
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
index 4e1c32f972..cf869e8655 100644
--- a/tensorflow/docs_src/install/install_c.md
+++ b/tensorflow/docs_src/install/install_c.md
@@ -38,7 +38,7 @@ enable TensorFlow for C:
OS="linux" # Change to "darwin" for macOS
TARGET_DIRECTORY="/usr/local"
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.9.0-rc0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.9.0.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
index 162a820f22..4ec7e42773 100644
--- a/tensorflow/docs_src/install/install_go.md
+++ b/tensorflow/docs_src/install/install_go.md
@@ -38,7 +38,7 @@ steps to install this library and enable TensorFlow for Go:
TF_TYPE="cpu" # Change to "gpu" for GPU support
TARGET_DIRECTORY='/usr/local'
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.9.0-rc0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.9.0.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index c196bb9b31..c5f760d254 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -36,7 +36,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.9.0-rc0</version>
+ <version>1.9.0</version>
</dependency>
```
@@ -65,7 +65,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.9.0-rc0</version>
+ <version>1.9.0</version>
</dependency>
</dependencies>
</project>
@@ -124,12 +124,12 @@ instead:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId>
- <version>1.9.0-rc0</version>
+ <version>1.9.0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId>
- <version>1.9.0-rc0</version>
+ <version>1.9.0</version>
</dependency>
```
@@ -148,7 +148,7 @@ refer to the simpler instructions above instead.
Take the following steps to install TensorFlow for Java on Linux or macOS:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0-rc0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0.jar),
which is the TensorFlow Java Archive (JAR).
2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
@@ -167,7 +167,7 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
OS=$(uname -s | tr '[:upper:]' '[:lower:]')
mkdir -p ./jni
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.9.0-rc0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.9.0.tar.gz" |
tar -xz -C ./jni
### Install on Windows
@@ -175,10 +175,10 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
Take the following steps to install TensorFlow for Java on Windows:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0-rc0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0.jar),
which is the TensorFlow Java Archive (JAR).
2. Download the following Java Native Interface (JNI) file appropriate for
- [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.9.0-rc0.zip).
+ [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.9.0.zip).
3. Extract this .zip file.
__Note__: The native library (`tensorflow_jni.dll`) requires `msvcp140.dll` at runtime, which is included in the [Visual C++ 2015 Redistributable](https://www.microsoft.com/en-us/download/details.aspx?id=48145) package.
@@ -227,7 +227,7 @@ must be part of your `classpath`. For example, you can include the
downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
as follows:
-<pre><b>javac -cp libtensorflow-1.9.0-rc0.jar HelloTF.java</b></pre>
+<pre><b>javac -cp libtensorflow-1.9.0.jar HelloTF.java</b></pre>
### Running
@@ -241,11 +241,11 @@ two files are available to the JVM:
For example, the following command line executes the `HelloTF` program on Linux
and macOS X:
-<pre><b>java -cp libtensorflow-1.9.0-rc0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.9.0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
And the following command line executes the `HelloTF` program on Windows:
-<pre><b>java -cp libtensorflow-1.9.0-rc0.jar;. -Djava.library.path=jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.9.0.jar;. -Djava.library.path=jni HelloTF</b></pre>
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 7534d0fac1..3a9a01c57e 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -65,7 +65,7 @@ We *recommend* using `pip` version 8.1 or higher. If using a release before
version 8.1, upgrade `pip`:
<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">sudo pip install -U pip</code>
+ <code class="devsite-terminal">pip install --upgrade pip</code>
</pre>
If not using Ubuntu and [setuptools](https://pypi.org/project/setuptools/) is
@@ -102,7 +102,7 @@ When the Virtualenv is activated, the shell prompt displays as `(venv) $`.
Within the active virtual environment, upgrade `pip`:
<pre class="prettyprint lang-bsh">
-(venv)$ pip install -U pip
+(venv)$ pip install --upgrade pip
</pre>
You can install other Python packages within the virtual environment without
@@ -120,7 +120,7 @@ Choose one of the available TensorFlow packages for installation:
Within an active Virtualenv environment, use `pip` to install the package:
<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">pip install -U tensorflow</code>
+ <code class="devsite-terminal">pip install --upgrade tensorflow</code>
</pre>
Use `pip list` to show the packages installed in the virtual environment.
@@ -198,7 +198,7 @@ We *recommend* using `pip` version 8.1 or higher. If using a release before
version 8.1, upgrade `pip`:
<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">sudo pip install -U pip</code>
+ <code class="devsite-terminal">pip install --upgrade pip</code>
</pre>
If not using Ubuntu and [setuptools](https://pypi.org/project/setuptools/) is
@@ -220,8 +220,8 @@ Choose one of the available TensorFlow packages for installation:
And use `pip` to install the package for Python 2 or 3:
<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">sudo pip install -U tensorflow # Python 2.7</code>
- <code class="devsite-terminal">sudo pip3 install -U tensorflow # Python 3.n</code>
+ <code class="devsite-terminal">pip install --upgrade --user tensorflow # Python 2.7</code>
+ <code class="devsite-terminal">pip3 install --upgrade --user tensorflow # Python 3.n</code>
</pre>
Use `pip list` to show the packages installed on the system.
@@ -239,8 +239,8 @@ If the above steps failed, try installing the TensorFlow binary using the remote
URL of the `pip` package:
<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">sudo pip install --upgrade <var>remote-pkg-URL</var> # Python 2.7</code>
- <code class="devsite-terminal">sudo pip3 install --upgrade <var>remote-pkg-URL</var> # Python 3.n</code>
+ <code class="devsite-terminal">pip install --user --upgrade <var>remote-pkg-URL</var> # Python 2.7</code>
+ <code class="devsite-terminal">pip3 install --user --upgrade <var>remote-pkg-URL</var> # Python 3.n</code>
</pre>
The <var>remote-pkg-URL</var> depends on the operating system, Python version,
@@ -255,8 +255,8 @@ encounter problems.
To uninstall TensorFlow on your system, use one of following commands:
<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">sudo pip uninstall tensorflow # for Python 2.7</code>
- <code class="devsite-terminal">sudo pip3 uninstall tensorflow # for Python 3.n</code>
+ <code class="devsite-terminal">pip uninstall tensorflow # for Python 2.7</code>
+ <code class="devsite-terminal">pip3 uninstall tensorflow # for Python 3.n</code>
</pre>
<a name="InstallingDocker"></a>
@@ -436,7 +436,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
<pre>
(tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp34-cp34m-linux_x86_64.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp34-cp34m-linux_x86_64.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -650,13 +650,13 @@ This section documents the relevant values for Linux installations.
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp27-none-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0-cp27-none-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -667,13 +667,13 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp34-cp34m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0-cp34-cp34m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -684,13 +684,13 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0-cp35-cp35m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -701,13 +701,13 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0-cp36-cp36m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index 3372e9e1e0..1a7b2b815d 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -119,7 +119,7 @@ Take the following steps to install TensorFlow with Virtualenv:
TensorFlow in the active Virtualenv is as follows:
<pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py3-none-any.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common-installation-problems).
@@ -242,7 +242,7 @@ take the following steps:
issue the following command:
<pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl</b> </pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py3-none-any.whl</b> </pre>
If the preceding command fails, see
[installation problems](#common-installation-problems).
@@ -350,7 +350,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
TensorFlow for Python 2.7:
<pre> (<i>targetDirectory</i>)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py2-none-any.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -517,7 +517,7 @@ The value you specify depends on your Python version.
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py2-none-any.whl
</pre>
@@ -525,5 +525,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-a
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py3-none-any.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index 502f4de7a6..4c09ba1a8b 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -330,10 +330,10 @@ Invoke `pip install` to install that pip package. The filename of the `.whl`
file depends on your platform. For example, the following command will install
the pip package
-for TensorFlow 1.9.0rc0 on Linux:
+for TensorFlow 1.9.0 on Linux:
<pre>
-$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.9.0rc0-py2-none-any.whl</b>
+$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.9.0-py2-none-any.whl</b>
</pre>
## Validate your installation
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index e037925961..8ede6ab54c 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -540,10 +540,11 @@ class _DeviceAttributes(object):
(in bytes).
"""
- def __init__(self, name, device_type, memory_limit_bytes):
+ def __init__(self, name, device_type, memory_limit_bytes, incarnation):
self._name = device.canonical_name(name)
self._device_type = device_type
self._memory_limit_bytes = memory_limit_bytes
+ self._incarnation = incarnation
@property
def name(self):
@@ -557,11 +558,16 @@ class _DeviceAttributes(object):
def memory_limit_bytes(self):
return self._memory_limit_bytes
+ @property
+ def incarnation(self):
+ return self._incarnation
+
def __repr__(self):
- return '_DeviceAttributes(%s, %s, %d)' % (
+ return '_DeviceAttributes(%s, %s, %d, %d)' % (
self.name,
self.device_type,
self.memory_limit_bytes,
+ self.incarnation,
)
@@ -658,7 +664,9 @@ class BaseSession(SessionInterface):
name = tf_session.TF_DeviceListName(raw_device_list, i)
device_type = tf_session.TF_DeviceListType(raw_device_list, i)
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
- device_list.append(_DeviceAttributes(name, device_type, memory))
+ incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
+ device_list.append(
+ _DeviceAttributes(name, device_type, memory, incarnation))
tf_session.TF_DeleteDeviceList(raw_device_list)
return device_list
diff --git a/tensorflow/python/client/session_list_devices_test.py b/tensorflow/python/client/session_list_devices_test.py
index c5d82c213a..dd381c689f 100644
--- a/tensorflow/python/client/session_list_devices_test.py
+++ b/tensorflow/python/client/session_list_devices_test.py
@@ -37,6 +37,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
devices = sess.list_devices()
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
+ # All valid device incarnations must be non-zero.
+ self.assertTrue(all(d.incarnation != 0 for d in devices))
def testInvalidDeviceNumber(self):
opts = tf_session.TF_NewSessionOptions()
@@ -54,6 +56,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
devices = sess.list_devices()
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
+ # All valid device incarnations must be non-zero.
+ self.assertTrue(all(d.incarnation != 0 for d in devices))
def testListDevicesClusterSpecPropagation(self):
server1 = server_lib.Server.create_local_server()
@@ -67,11 +71,13 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
config = config_pb2.ConfigProto(cluster_def=cluster_def)
with session.Session(server1.target, config=config) as sess:
devices = sess.list_devices()
- device_names = set([d.name for d in devices])
+ device_names = set(d.name for d in devices)
self.assertTrue(
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
self.assertTrue(
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
+ # All valid device incarnations must be non-zero.
+ self.assertTrue(all(d.incarnation != 0 for d in devices))
if __name__ == '__main__':
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index b72e029d1c..052be68385 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -35,6 +35,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as framework_device_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
@@ -104,18 +105,20 @@ class SessionTest(test_util.TensorFlowTestCase):
copy_val)
def testManyCPUs(self):
- # TODO(keveman): Implement ListDevices and test for the number of
- # devices returned by ListDevices.
with session.Session(
config=config_pb2.ConfigProto(device_count={
- 'CPU': 2
- })):
+ 'CPU': 2, 'GPU': 0
+ })) as sess:
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
+ devices = sess.list_devices()
+ self.assertEqual(2, len(devices))
+ for device in devices:
+ self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
+ device.name).device_type)
+
def testPerSessionThreads(self):
- # TODO(keveman): Implement ListDevices and test for the number of
- # devices returned by ListDevices.
with session.Session(
config=config_pb2.ConfigProto(use_per_session_threads=True)):
inp = constant_op.constant(10.0, name='W1')
@@ -1868,19 +1871,21 @@ class SessionTest(test_util.TensorFlowTestCase):
def testDeviceAttributes(self):
attrs = session._DeviceAttributes(
- '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337)
+ '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000)
self.assertEqual(1337, attrs.memory_limit_bytes)
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
self.assertEqual('TYPE', attrs.device_type)
+ self.assertEqual(1000000, attrs.incarnation)
str_repr = '%s' % attrs
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
def testDeviceAttributesCanonicalization(self):
attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
- 'TYPE', 1337)
+ 'TYPE', 1337, 1000000)
self.assertEqual(1337, attrs.memory_limit_bytes)
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
self.assertEqual('TYPE', attrs.device_type)
+ self.assertEqual(1000000, attrs.incarnation)
str_repr = '%s' % attrs
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 985cb90436..1cdd8e0b6a 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -138,6 +138,11 @@ tensorflow::ImportNumpy();
$result = PyLong_FromLongLong($1);
}
+// Convert TF_DeviceListIncarnation uint64_t output to Python integer
+%typemap(out) uint64_t {
+ $result = PyLong_FromUnsignedLongLong($1);
+}
+
// We use TF_OperationGetControlInputs_wrapper instead of
// TF_OperationGetControlInputs
%ignore TF_OperationGetControlInputs;
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 35de2f2841..f0784ed3d0 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -499,23 +499,23 @@ class EagerIterator(object):
"tf.data.Dataset.make_initializable_iterator or "
"tf.data.Dataset.make_one_shot_iterator for graph construction".
format(type(self)))
- with ops.device("/device:CPU:0"):
- ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
- self._output_classes = dataset.output_classes
- self._output_types = dataset.output_types
- self._output_shapes = dataset.output_shapes
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._output_types, self._output_classes))
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._output_shapes, self._output_classes))
+ self._device = context.context().device_name
+ ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
+ self._output_classes = dataset.output_classes
+ self._output_types = dataset.output_types
+ self._output_shapes = dataset.output_shapes
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes))
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._output_shapes, self._output_classes))
+ with ops.colocate_with(ds_variant):
self._resource = gen_dataset_ops.anonymous_iterator(
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
gen_dataset_ops.make_iterator(ds_variant, self._resource)
- # Delete the resource when this object is deleted
- self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
- handle=self._resource, handle_device="/device:CPU:0")
- self._device = context.context().device_name
+ # Delete the resource when this object is deleted
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._resource, handle_device=self._device)
def __iter__(self):
return self
diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh
index 2d35b2d8bb..f7d597c8c0 100755
--- a/tensorflow/python/debug/examples/examples_test.sh
+++ b/tensorflow/python/debug/examples/examples_test.sh
@@ -99,7 +99,7 @@ if [[ -d "${CUSTOM_DUMP_ROOT}" ]]; then
fi
# Test debugging of tf.keras.
-cat << EOF | "${DEBUG_KERAS_BIN}" --debug --ui_type=readline
+cat << EOF | ${DEBUG_KERAS_BIN} --debug --ui_type=readline
run -f has_inf_or_nan
EOF
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 6c415b1bf2..fd46163050 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -40,9 +40,9 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":gc",
+ ":metric_keys",
+ ":util",
"//tensorflow:tensorflow_py_no_contrib",
- "//tensorflow/python/estimator:metric_keys",
- "//tensorflow/python/estimator:util",
],
)
@@ -683,9 +683,9 @@ py_test(
],
deps = [
":keras",
+ ":numpy_io",
+ ":run_config",
"//tensorflow:tensorflow_py_no_contrib",
- "//tensorflow/python/estimator:numpy_io",
- "//tensorflow/python/estimator:run_config",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/estimator/canned/metric_keys.py b/tensorflow/python/estimator/canned/metric_keys.py
index 4f7c849ba4..9d49240fea 100644
--- a/tensorflow/python/estimator/canned/metric_keys.py
+++ b/tensorflow/python/estimator/canned/metric_keys.py
@@ -47,3 +47,8 @@ class MetricKeys(object):
PROBABILITY_MEAN_AT_CLASS = 'probability_mean/class%d'
AUC_AT_CLASS = 'auc/class%d'
AUC_PR_AT_CLASS = 'auc_precision_recall/class%d'
+
+ # The following require a class name applied.
+ PROBABILITY_MEAN_AT_NAME = 'probability_mean/%s'
+ AUC_AT_NAME = 'auc/%s'
+ AUC_PR_AT_NAME = 'auc_precision_recall/%s'
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 7d07c77c79..8cc971c61d 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -1340,7 +1340,7 @@ class LayoutOptimizerTest(test.TestCase):
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
- self.assertAllEqual(output_val_ref, output_val)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testLoop(self):
if test.is_gpu_available(cuda_only=True):
diff --git a/tensorflow/python/kernel_tests/bitcast_op_test.py b/tensorflow/python/kernel_tests/bitcast_op_test.py
index a535468b05..a2c6b54273 100644
--- a/tensorflow/python/kernel_tests/bitcast_op_test.py
+++ b/tensorflow/python/kernel_tests/bitcast_op_test.py
@@ -76,12 +76,18 @@ class BitcastTest(test.TestCase):
datatype = dtypes.int8
array_ops.bitcast(x, datatype, None)
- def testQuantizeType(self):
+ def testQuantizedType(self):
shape = [3, 4]
x = np.zeros(shape, np.uint16)
datatype = dtypes.quint16
self._testBitcast(x, datatype, shape)
+ def testUnsignedType(self):
+ shape = [3, 4]
+ x = np.zeros(shape, np.int64)
+ datatype = dtypes.uint64
+ self._testBitcast(x, datatype, shape)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index f9b9c77bbf..080319f6e8 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -370,6 +370,29 @@ class ScatterNdTest(test.TestCase):
return array_ops.scatter_nd(indices, updates, shape)
@test_util.run_in_graph_and_eager_modes
+ def testBool(self):
+ indices = constant_op.constant(
+ [[4], [3], [1], [7]], dtype=dtypes.int32)
+ updates = constant_op.constant(
+ [False, True, False, True], dtype=dtypes.bool)
+ expected = np.array(
+ [False, False, False, True, False, False, False, True])
+ scatter = self.scatter_nd(indices, updates, shape=(8,))
+ result = self.evaluate(scatter)
+ self.assertAllEqual(expected, result)
+
+ # Same indice is updated twice by same value.
+ indices = constant_op.constant(
+ [[4], [3], [3], [7]], dtype=dtypes.int32)
+ updates = constant_op.constant(
+ [False, True, True, True], dtype=dtypes.bool)
+ expected = np.array([
+ False, False, False, True, False, False, False, True])
+ scatter = self.scatter_nd(indices, updates, shape=(8,))
+ result = self.evaluate(scatter)
+ self.assertAllEqual(expected, result)
+
+ @test_util.run_in_graph_and_eager_modes
def testInvalidShape(self):
# TODO(apassos) figure out how to unify these errors
with self.assertRaises(errors.InvalidArgumentError
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 361667ec49..ec6488ea63 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -636,10 +636,10 @@ def strided_slice(input_,
`foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor.
If the ith bit of `shrink_axis_mask` is set, it implies that the ith
- specification shrinks the dimensionality by 1. `begin[i]`, `end[i]` and
- `strides[i]` must imply a slice of size 1 in the dimension. For example in
- Python one might do `foo[:, 3, :]` which would result in
- `shrink_axis_mask` equal to 2.
+ specification shrinks the dimensionality by 1, taking on the value at index
+ `begin[i]`. `end[i]` and `strides[i]` are ignored in this case. For example in
+ Python one might do `foo[:, 3, :]` which would result in `shrink_axis_mask`
+ equal to 2.
NOTE: `begin` and `end` are zero-indexed.
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c28dca5137..fbe6b62302 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -628,16 +628,17 @@ def cast(x, dtype, name=None):
```
The operation supports data types (for `x` and `dtype`) of
- `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`, `float16`, `float32`,
- `float64`, `complex64`, `complex128`, `bfloat16`. In case of casting from
- complex types (`complex64`, `complex128`) to real types, only the real part
- of `x` is returned. In case of casting from real types to complex types
- (`complex64`, `complex128`), the imaginary part of the returned value is set
- to `0`. The handling of complex types here matches the behavior of numpy.
+ `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
+ `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
+ In case of casting from complex types (`complex64`, `complex128`) to real
+ types, only the real part of `x` is returned. In case of casting from real
+ types to complex types (`complex64`, `complex128`), the imaginary part of the
+ returned value is set to `0`. The handling of complex types here matches the
+ behavior of numpy.
Args:
x: A `Tensor` or `SparseTensor` of numeric type. It could be
- `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`,
+ `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
`float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
dtype: The destination type. The list of supported dtypes is the same
as `x`.
diff --git a/tensorflow/python/ops/parallel_for/__init__.py b/tensorflow/python/ops/parallel_for/__init__.py
index b49d865968..dd8bc6d487 100644
--- a/tensorflow/python/ops/parallel_for/__init__.py
+++ b/tensorflow/python/ops/parallel_for/__init__.py
@@ -23,13 +23,3 @@ from tensorflow.python.ops.parallel_for.control_flow_ops import for_loop
from tensorflow.python.ops.parallel_for.control_flow_ops import pfor
from tensorflow.python.ops.parallel_for.gradients import batch_jacobian
from tensorflow.python.ops.parallel_for.gradients import jacobian
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- 'pfor',
- 'for_loop',
- 'jacobian',
- 'batch_jacobian',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 5979b76ff2..1f56ad25bf 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -1294,16 +1294,3 @@ def is_resource_variable(var):
""""Returns True if `var` is to be considered a ResourceVariable."""
return isinstance(var, ResourceVariable) or hasattr(
var, "_should_act_as_resource_variable")
-
-
-_DEFAULT_USE_RESOURCE = False
-
-
-def _default_variable_creator(_, *args, **kwds):
- use_resource = kwds.pop("use_resource", _DEFAULT_USE_RESOURCE)
- use_resource = use_resource or context.executing_eagerly()
- if use_resource:
- return ResourceVariable(*args, **kwds)
- return variables.RefVariable(*args, **kwds)
-
-variables.default_variable_creator = _default_variable_creator
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 0f37dcc027..aca44bcd44 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -2349,7 +2349,10 @@ def default_variable_creator(next_creator=None, **kwargs):
validate_shape = kwargs.get("validate_shape", True)
caching_device = kwargs.get("caching_device", None)
name = kwargs.get("name", None)
+ variable_def = kwargs.get("variable_def", None)
dtype = kwargs.get("dtype", None)
+ expected_shape = kwargs.get("expected_shape", None)
+ import_scope = kwargs.get("import_scope", None)
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
@@ -2360,23 +2363,24 @@ def default_variable_creator(next_creator=None, **kwargs):
if use_resource is None:
use_resource = get_variable_scope().use_resource
- if use_resource or (use_resource is None and context.executing_eagerly()):
+ use_resource = use_resource or context.executing_eagerly()
+ if use_resource:
return resource_variable_ops.ResourceVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
- constraint=constraint)
- elif not use_resource and context.executing_eagerly():
- raise RuntimeError(
- "VariableScope should use resource variable when eager execution is"
- " enabled, but use_resource is False."
- )
+ constraint=constraint, variable_def=variable_def,
+ import_scope=import_scope)
else:
- return variables.Variable(
+ return variables.RefVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
- constraint=constraint)
+ constraint=constraint, variable_def=variable_def,
+ expected_shape=expected_shape, import_scope=import_scope)
+
+
+variables.default_variable_creator = default_variable_creator
def _make_getter(captured_getter, captured_previous):
@@ -2384,36 +2388,8 @@ def _make_getter(captured_getter, captured_previous):
return lambda **kwargs: captured_getter(captured_previous, **kwargs)
-def variable(initial_value=None,
- trainable=None,
- collections=None,
- validate_shape=True,
- caching_device=None,
- name=None,
- dtype=None,
- constraint=None,
- use_resource=None,
- synchronization=VariableSynchronization.AUTO,
- aggregation=VariableAggregation.NONE):
- previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
- for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
- previous_getter = _make_getter(getter, previous_getter)
-
- # Reset `aggregation` that is explicitly set as `None` to the enum None value.
- if aggregation is None:
- aggregation = VariableAggregation.NONE
- return previous_getter(
- initial_value=initial_value,
- trainable=trainable,
- collections=collections,
- validate_shape=validate_shape,
- caching_device=caching_device,
- name=name,
- dtype=dtype,
- constraint=constraint,
- use_resource=use_resource,
- synchronization=synchronization,
- aggregation=aggregation)
+# TODO(apassos) remove forwarding symbol
+variable = variables.Variable
@tf_contextlib.contextmanager
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 6bb2d6f669..d03d93beeb 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -40,15 +40,15 @@ from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
-def default_variable_creator(_, *args, **kwds):
- del args, kwds
- raise NotImplementedError("resource_variable_ops needs to be imported")
+def default_variable_creator(_, **kwds):
+ del kwds
+ raise NotImplementedError("variable_scope needs to be imported")
def _make_getter(captured_getter, captured_previous):
"""To avoid capturing loop variables."""
- def getter(*args, **kwargs):
- return captured_getter(captured_previous, *args, **kwargs)
+ def getter(**kwargs):
+ return captured_getter(captured_previous, **kwargs)
return getter
@@ -86,11 +86,48 @@ class VariableAggregation(enum.Enum):
class VariableMetaclass(type):
"""Metaclass to allow construction of tf.Variable to be overridden."""
+ def _variable_call(cls,
+ initial_value=None,
+ trainable=None,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Call on Variable class. Useful to force the signature."""
+ previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
+ for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
+ previous_getter = _make_getter(getter, previous_getter)
+
+ # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ collections=collections,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ variable_def=variable_def,
+ dtype=dtype,
+ expected_shape=expected_shape,
+ import_scope=import_scope,
+ constraint=constraint,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
def __call__(cls, *args, **kwargs):
if cls is Variable:
- previous_getter = lambda *a, **k: default_variable_creator(None, *a, **k)
- # TODO(apassos) use a stack of getters here
- return previous_getter(*args, **kwargs)
+ return cls._variable_call(*args, **kwargs)
else:
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
@@ -650,8 +687,8 @@ class Variable(six.with_metaclass(VariableMetaclass,
@staticmethod
def from_proto(variable_def, import_scope=None):
"""Returns a `Variable` object created from `variable_def`."""
- return Variable(variable_def=variable_def,
- import_scope=import_scope)
+ return RefVariable(variable_def=variable_def,
+ import_scope=import_scope)
class SaveSliceInfo(object):
"""Information on how to save this Variable as a slice.
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index ec8c9da809..7bee00a927 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
import numpy
+import six
from tensorflow.python.eager import context
from tensorflow.python.eager import test
@@ -72,11 +73,14 @@ class ListTests(test.TestCase):
model = HasList()
output = model(array_ops.ones([32, 2]))
self.assertAllEqual([32, 12], output.shape)
- self.assertEqual(2, len(model.layers))
- self.assertIs(model.layer_list, model.layers[0])
- self.assertEqual(10, len(model.layers[0].layers))
+ self.assertEqual(11, len(model.layers))
+ self.assertEqual(10, len(model.layer_list.layers))
+ six.assertCountEqual(
+ self,
+ model.layers,
+ model.layer_list.layers + model.layers_with_updates)
for index in range(10):
- self.assertEqual(3 + index, model.layers[0].layers[index].units)
+ self.assertEqual(3 + index, model.layer_list.layers[index].units)
self.assertEqual(2, len(model._checkpoint_dependencies))
self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
self.assertIs(model.layers_with_updates,
@@ -123,9 +127,11 @@ class ListTests(test.TestCase):
self.l2 = []
model = HasEqualContainers()
- model.l1.append(HasEqualContainers())
- model.l2.append(HasEqualContainers())
- self.assertEqual([model.l1, model.l2], model.layers)
+ first_layer = HasEqualContainers()
+ model.l1.append(first_layer)
+ second_layer = HasEqualContainers()
+ model.l2.append(second_layer)
+ self.assertEqual([first_layer, second_layer], model.layers)
def testNotCheckpointable(self):
class NotCheckpointable(object):
@@ -260,9 +266,8 @@ class MappingTests(test.TestCase):
model = HasMapping()
output = model(array_ops.ones([32, 2]))
self.assertAllEqual([32, 7], output.shape)
- self.assertEqual(1, len(model.layers))
- self.assertIs(model.layer_dict, model.layers[0])
- self.assertEqual(3, len(model.layers[0].layers))
+ self.assertEqual(5, len(model.layers))
+ six.assertCountEqual(self, model.layers, model.layer_dict.layers)
self.assertEqual(1, len(model._checkpoint_dependencies))
self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref)
self.evaluate([v.initializer for v in model.variables])
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py
index 978fcb2252..d65b631fe9 100644
--- a/tensorflow/python/training/checkpointable/layer_utils.py
+++ b/tensorflow/python/training/checkpointable/layer_utils.py
@@ -32,10 +32,15 @@ def is_layer(obj):
def filter_empty_layer_containers(layer_list):
"""Filter out empty Layer-like containers."""
- return [layer for layer in layer_list
- # Filter out only empty Checkpointable data structures. Empty Networks
- # will still show up in Model.layers.
- if is_layer(layer) or getattr(layer, "layers", True)]
+ filtered = []
+ for obj in layer_list:
+ if is_layer(obj):
+ filtered.append(obj)
+ else:
+ # Checkpointable data structures will not show up in ".layers" lists, but
+ # the layers they contain will.
+ filtered.extend(obj.layers)
+ return filtered
def gather_trainable_weights(trainable, sub_layers, extra_variables):
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 6ae5765b13..686232fe27 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -747,7 +747,7 @@ def capture_dependencies(template):
initial_value=initializer,
name=name,
**inner_kwargs)
- if name.startswith(name_prefix):
+ if name is not None and name.startswith(name_prefix):
scope_stripped_name = name[len(name_prefix) + 1:]
if not checkpointable_parent:
return template._add_variable_with_custom_getter( # pylint: disable=protected-access
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index e85b6db511..766a0dafb5 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -2411,19 +2411,6 @@ port::Status CudnnSupport::DoConvolveImpl(
stream, cudnn, algorithm_config, input_nd, filter,
conv, output_nd, scratch_allocator, &scratch));
- if (cudnn_type == CUDNN_DATA_HALF &&
- filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
- (algo_desc.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM ||
- input_descriptor.layout() != dnn::DataLayout::kBatchYXDepth ||
- output_descriptor.layout() != dnn::DataLayout::kBatchYXDepth)) {
- // TODO(timshen): Attach a nvbugs number.
- return port::Status(
- port::error::INTERNAL,
- "Cudnn doesn't return an error code on this documented unsupported "
- "layout combination. Instead, it accesses out-of-bounds memory. "
- "Being nice and returning an error instead.");
- }
-
std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (is_profiling) {
timer.reset(new CUDATimer(parent_)); // NOLINT
@@ -3093,21 +3080,9 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
}
}
- if (cudnn_type == CUDNN_DATA_HALF &&
- filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
- ((algo_desc.algo_id() != CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 &&
- algo_desc.algo_id() != CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1) ||
- input_descriptor.layout() != dnn::DataLayout::kBatchYXDepth ||
- output_descriptor.layout() != dnn::DataLayout::kBatchYXDepth)) {
- return port::Status(
- port::error::INTERNAL,
- "Cudnn doesn't return an error code on this documented unsupported "
- "layout combination. Instead, it crashes. Being nice and returning an "
- "error instead. See nvbugs/2260917");
- }
-
// Cudnn 7.1.4 has a bug if the workspace of the following convolution is not
- // zero-initialized. See nvbugs/2254619.
+ // zero-initialized.
+ // TODO(timshen): Add an nvbugs/ link.
if (CUDNN_VERSION >= 7000 &&
algorithm_config.algorithm().algo_id() ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 &&
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index 08e2c3edd2..5115be8c6d 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -150,36 +150,7 @@ BAZEL_TARGET="//tensorflow/... -//tensorflow/compiler/..."
if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then
BAZEL_TARGET="$BAZEL_TARGET -//tensorflow/contrib/..."
else
- BAZEL_TARGET="${BAZEL_TARGET} -//tensorflow/contrib/lite/..."
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:context_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:framework"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:interpreter_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:model_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/toco:toco"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:simple_memory_arena_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:string_util_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:activations_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:add_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:basic_rnn_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:concatenation_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:conv_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:depthwise_conv_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:fully_connected_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:hashtable_lookup_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:local_response_norm_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:lsh_projection_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:lstm_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:l2norm_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:mul_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:pooling_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:reshape_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:resize_bilinear_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:skip_gram_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:softmax_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:space_to_depth_test"
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:svdf_test"
+ BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/..."
fi
TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data"
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index 866fe95d2b..db37edf809 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -354,7 +354,7 @@ do_external_licenses_check(){
# Whitelist
echo ${EXTRA_LICENSE_FILE}
- grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -e "@embedded_jdk//" -v ${EXTRA_LICENSES_FILE} > temp.txt
+ grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -v ${EXTRA_LICENSES_FILE} > temp.txt
mv temp.txt ${EXTRA_LICENSES_FILE}
diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh
index e284401b8a..adbff8f6ef 100755
--- a/tensorflow/tools/ci_build/install/install_bazel.sh
+++ b/tensorflow/tools/ci_build/install/install_bazel.sh
@@ -15,7 +15,7 @@
# ==============================================================================
# Select bazel version.
-BAZEL_VERSION="0.15.0"
+BAZEL_VERSION="0.14.1"
set +e
local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
diff --git a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
index 87be81577d..9d24b3e421 100755
--- a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
+++ b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
@@ -18,7 +18,7 @@
# It will compile bazel from source and install it in /usr/local/bin
# Select bazel version.
-BAZEL_VERSION="0.15.0"
+BAZEL_VERSION="0.14.1"
set +e
local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
index 2b68de3c5b..f6fa9251d4 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
@@ -34,35 +34,4 @@ yes "" | $PYTHON_BIN_PATH configure.py
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \
--test_size_filters=small,medium --test_output=errors -- \
- //tensorflow/contrib/... \
- -//tensorflow/contrib/lite/... \
- //tensorflow/contrib/lite:context_test \
- //tensorflow/contrib/lite:framework \
- //tensorflow/contrib/lite:interpreter_test \
- //tensorflow/contrib/lite:model_test \
- //tensorflow/contrib/lite/toco:toco \
- //tensorflow/contrib/lite:simple_memory_arena_test \
- //tensorflow/contrib/lite:string_util_test \
- //tensorflow/contrib/lite/kernels:activations_test \
- //tensorflow/contrib/lite/kernels:add_test \
- //tensorflow/contrib/lite/kernels:basic_rnn_test \
- //tensorflow/contrib/lite/kernels:concatenation_test \
- //tensorflow/contrib/lite/kernels:conv_test \
- //tensorflow/contrib/lite/kernels:depthwise_conv_test \
- //tensorflow/contrib/lite/kernels:embedding_lookup_test \
- //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test \
- //tensorflow/contrib/lite/kernels:fully_connected_test \
- //tensorflow/contrib/lite/testing:generated_zip_tests \
- //tensorflow/contrib/lite/kernels:hashtable_lookup_test \
- //tensorflow/contrib/lite/kernels:local_response_norm_test \
- //tensorflow/contrib/lite/kernels:lsh_projection_test \
- //tensorflow/contrib/lite/kernels:lstm_test \
- //tensorflow/contrib/lite/kernels:l2norm_test \
- //tensorflow/contrib/lite/kernels:mul_test \
- //tensorflow/contrib/lite/kernels:pooling_test \
- //tensorflow/contrib/lite/kernels:reshape_test \
- //tensorflow/contrib/lite/kernels:resize_bilinear_test \
- //tensorflow/contrib/lite/kernels:skip_gram_test \
- //tensorflow/contrib/lite/kernels:softmax_test \
- //tensorflow/contrib/lite/kernels:space_to_depth_test \
- //tensorflow/contrib/lite/kernels:svdf_test
+ //tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index 0482cf619a..c03cbd9c66 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -33,10 +33,10 @@ function set_remote_cache_options {
echo "build --tls_enabled=true" >> "${TMP_BAZELRC}"
echo "build --remote_timeout=3600" >> "${TMP_BAZELRC}"
echo "build --auth_enabled=true" >> "${TMP_BAZELRC}"
- echo "build --spawn_strategy=standalone" >> "${TMP_BAZELRC}"
- echo "build --strategy=Javac=standalone" >> "${TMP_BAZELRC}"
- echo "build --strategy=Closure=standalone" >> "${TMP_BAZELRC}"
- echo "build --genrule_strategy=standalone" >> "${TMP_BAZELRC}"
+ echo "build --spawn_strategy=remote" >> "${TMP_BAZELRC}"
+ echo "build --strategy=Javac=remote" >> "${TMP_BAZELRC}"
+ echo "build --strategy=Closure=remote" >> "${TMP_BAZELRC}"
+ echo "build --genrule_strategy=remote" >> "${TMP_BAZELRC}"
echo "build --google_credentials=$GOOGLE_CLOUD_CREDENTIAL" >> "${TMP_BAZELRC}"
}
diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
index 3af132217e..333a89d3f5 100644
--- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
@@ -26,7 +26,8 @@
# * Bazel windows executable copied as "bazel.exe" and included in PATH.
# Use a temporary directory with a short name.
-export TMPDIR="C:/tmp"
+export TMPDIR=${TMPDIR:-"C:/tmp"}
+export TMPDIR=$(cygpath -m "$TMPDIR")
mkdir -p "$TMPDIR"
# Set bash path
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index 61dec249f3..dc7ea1dc57 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -67,16 +67,12 @@ for ARG in "$@"; do
done
if [[ "$release_build" != 1 ]]; then
- # --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
+ # Overriding eigen strong inline speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
# by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521
- # Because this hurts the performance of TF, we don't enable it in release build.
- echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}"
+ # Because this hurts the performance of TF, we don't override it in release build.
+ export TF_OVERRIDE_EIGEN_STRONG_INLINE=0
fi
-# The host and target platforms are the same in Windows build. So we don't have
-# to distinct them. This helps avoid building the same targets twice.
-echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}"
-
# Enable short object file path to avoid long path issue on Windows.
echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}"
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index e232306653..a4175a0e81 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -67,16 +67,12 @@ for ARG in "$@"; do
done
if [[ "$release_build" != 1 ]]; then
- # --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
+ # Overriding eigen strong inline speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
# by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521
- # Because this hurts the performance of TF, we don't enable it in release build.
- echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}"
+ # Because this hurts the performance of TF, we don't override it in release build.
+ export TF_OVERRIDE_EIGEN_STRONG_INLINE=0
fi
-# The host and target platforms are the same in Windows build. So we don't have
-# to distinct them. This helps avoid building the same targets twice.
-echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}"
-
# Enable short object file path to avoid long path issue on Windows.
echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}"
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index f7fe4119da..fd94d64268 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -63,7 +63,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.15.0
+ENV BAZEL_VERSION 0.14.1
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 957a7ed799..44120bf274 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -83,7 +83,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.15.0
+ENV BAZEL_VERSION 0.14.1
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
index 30bc2d2806..3bedc8cf34 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
@@ -4,7 +4,7 @@ LABEL maintainer="Gunhan Gulsoy <gunan@google.com>"
# It is possible to override these for releases.
ARG TF_BRANCH=master
-ARG BAZEL_VERSION=0.15.0
+ARG BAZEL_VERSION=0.5.4
ARG TF_AVAILABLE_CPUS=32
RUN apt-get update && apt-get install -y --no-install-recommends \
diff --git a/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb b/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb
index 0633b03259..8fa871ef77 100644
--- a/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb
+++ b/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb
@@ -665,7 +665,7 @@
"source": [
"## What's next?\n",
"\n",
- "This has been a gentle introduction to TensorFlow, focused on what TensorFlow is and the very basics of doing anything in TensorFlow. If you'd like more, the next tutorial in the series is Getting Started with TensorFlow, also available in the [notebooks directory](..)."
+ "This has been a gentle introduction to TensorFlow, focused on what TensorFlow is and the very basics of doing anything in TensorFlow. If you'd like more, the next tutorial in the series is Getting Started with TensorFlow, also available in the [notebooks directory](../notebooks)."
]
}
],
diff --git a/tensorflow/tools/docs/generate.py b/tensorflow/tools/docs/generate.py
index fc93085e3e..f96887e4c7 100644
--- a/tensorflow/tools/docs/generate.py
+++ b/tensorflow/tools/docs/generate.py
@@ -31,6 +31,11 @@ if __name__ == '__main__':
doc_generator = generate_lib.DocGenerator()
doc_generator.add_output_dir_argument()
doc_generator.add_src_dir_argument()
+ doc_generator.argument_parser.add_argument(
+ '--site_api_path',
+ type=str, default='api_docs/python',
+ help='The path from the site-root to api_docs'
+ 'directory for this project')
# This doc generator works on the TensorFlow codebase. Since this script lives
# at tensorflow/tools/docs, and all code is defined somewhere inside
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index e7634cd5dc..4f70a69364 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -55,7 +55,8 @@ def write_docs(output_dir,
parser_config,
yaml_toc,
root_title='TensorFlow',
- search_hints=True):
+ search_hints=True,
+ site_api_path=None):
"""Write previously extracted docs to disk.
Write a docs page for each symbol included in the indices of parser_config to
@@ -73,6 +74,8 @@ def write_docs(output_dir,
root_title: The title name for the root level index.md.
search_hints: (bool) include meta-data search hints at the top of each
output file.
+ site_api_path: Used to write the api-duplicates _redirects.yaml file. if
+ None (the default) the file is not generated.
Raises:
ValueError: if `output_dir` is not an absolute path
@@ -92,6 +95,9 @@ def write_docs(output_dir,
# - symbol name(string):pathname (string)
symbol_to_file = {}
+ # Collect redirects for an api _redirects.yaml file.
+ redirects = ['redirects:\n']
+
# Parse and write Markdown pages, resolving cross-links (@{symbol}).
for full_name, py_object in six.iteritems(parser_config.index):
parser_config.reference_resolver.current_doc_full_name = full_name
@@ -150,6 +156,25 @@ def write_docs(output_dir,
raise OSError(
'Cannot write documentation for %s to %s' % (full_name, directory))
+ if site_api_path:
+ duplicates = parser_config.duplicates.get(full_name, [])
+ if not duplicates:
+ continue
+
+ duplicates = [item for item in duplicates if item != full_name]
+ template = ('- from: /{}\n'
+ ' to: /{}\n')
+ for dup in duplicates:
+ from_path = os.path.join(site_api_path, dup.replace('.', '/'))
+ to_path = os.path.join(site_api_path, full_name.replace('.', '/'))
+ redirects.append(
+ template.format(from_path, to_path))
+
+ if site_api_path:
+ api_redirects_path = os.path.join(output_dir, '_redirects.yaml')
+ with open(api_redirects_path, 'w') as redirect_file:
+ redirect_file.write(''.join(redirects))
+
if yaml_toc:
# Generate table of contents
@@ -608,7 +633,8 @@ class DocGenerator(object):
parser_config,
yaml_toc=self.yaml_toc,
root_title=root_title,
- search_hints=getattr(flags, 'search_hints', True))
+ search_hints=getattr(flags, 'search_hints', True),
+ site_api_path=getattr(flags, 'site_api_path', None))
# Replace all the @{} references in files under `FLAGS.src_dir`
replace_refs(flags.src_dir, flags.output_dir, reference_resolver, '*.md')
diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py
index 7a6f9fd9f7..de18b13254 100644
--- a/tensorflow/tools/docs/generate_lib_test.py
+++ b/tensorflow/tools/docs/generate_lib_test.py
@@ -107,7 +107,18 @@ class GenerateTest(googletest.TestCase):
output_dir = googletest.GetTempDir()
- generate_lib.write_docs(output_dir, parser_config, yaml_toc=True)
+ generate_lib.write_docs(output_dir, parser_config, yaml_toc=True,
+ site_api_path='api_docs/python')
+
+ # Check redirects
+ redirects_file = os.path.join(output_dir, '_redirects.yaml')
+ self.assertTrue(os.path.exists(redirects_file))
+ with open(redirects_file) as f:
+ redirects = f.read()
+ self.assertEqual(redirects.split(), [
+ 'redirects:', '-', 'from:', '/api_docs/python/tf/test_function', 'to:',
+ '/api_docs/python/tf/TestModule/test_function'
+ ])
# Make sure that the right files are written to disk.
self.assertTrue(os.path.exists(os.path.join(output_dir, 'index.md')))
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 4101b34a11..ca40f2eaa8 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -17,8 +17,12 @@
set -e
+function is_absolute {
+ [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]]
+}
+
function real_path() {
- [[ $1 = /* ]] && echo "$1" || echo "$PWD/${1#./}"
+ is_absolute "$1" && echo "$1" || echo "$PWD/${1#./}"
}
function cp_external() {
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 4b4f31813c..2c8658fc59 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -487,11 +487,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/10c3b3d15ed6a788ac12221b784caf81fb8248b5.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/10c3b3d15ed6a788ac12221b784caf81fb8248b5.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/62b518b75a780a3bc75982cbe54b0e7bc262aa6e.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/62b518b75a780a3bc75982cbe54b0e7bc262aa6e.tar.gz",
],
- sha256 = "a9feb6b47267c30fd7c19ebfdf4dbde6757054f716fa77c09bcb1106799c3253",
- strip_prefix = "llvm-10c3b3d15ed6a788ac12221b784caf81fb8248b5",
+ sha256 = "51ab0edcf7dde0207f5cf141aec16b14fcac5290112cdf1ea671a2757f719f8b",
+ strip_prefix = "llvm-62b518b75a780a3bc75982cbe54b0e7bc262aa6e",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index c242ef3fdd..de63ebe9e6 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -626,7 +626,7 @@ def train_or_infer_spinn(embed,
model = SNLIClassifier(config, embed)
global_step = tf.train.get_or_create_global_step()
trainer = SNLIClassifierTrainer(model, config.lr)
- checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step)
+ checkpoint = tf.train.Checkpoint(trainer=trainer, global_step=global_step)
checkpoint.restore(tf.train.latest_checkpoint(config.logdir))
if inference_sentence_pair:
diff --git a/third_party/toolchains/cpus/py/BUILD b/third_party/toolchains/cpus/py/BUILD
index c175742cbf..1235988abb 100644
--- a/third_party/toolchains/cpus/py/BUILD
+++ b/third_party/toolchains/cpus/py/BUILD
@@ -6,18 +6,24 @@ licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
+# To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib
+# See https://docs.python.org/3/extending/windows.html
+cc_import(
+ name = "python_lib",
+ interface_library = select({
+ ":windows": ":python_import_lib",
+ # A placeholder for Unix platforms which makes --no_build happy.
+ "//conditions:default": "not-existing.lib",
+ }),
+ system_provided = 1,
+)
+
cc_library(
name = "python_headers",
hdrs = [":python_include"],
- data = select({
- ":windows": [":python_import_lib"],
- "//conditions:default": [],
- }),
includes = ["python_include"],
- linkopts = select({
- # TODO(pcloudy): Ideally, this should just go into deps after resolving
- # https://github.com/bazelbuild/bazel/issues/3237,
- ":windows": ["$(locations :python_import_lib)"],
+ deps = select({
+ ":windows": [":python_lib"],
"//conditions:default": [],
}),
)
@@ -37,161 +43,135 @@ config_setting(
genrule(
name = "python_include",
outs = [
+ "python_include/Python-ast.h",
+ "python_include/Python.h",
+ "python_include/abstract.h",
+ "python_include/asdl.h",
+ "python_include/ast.h",
+ "python_include/bitset.h",
+ "python_include/boolobject.h",
+ "python_include/bufferobject.h",
+ "python_include/bytearrayobject.h",
+ "python_include/bytes_methods.h",
+ "python_include/bytesobject.h",
+ "python_include/cStringIO.h",
+ "python_include/cellobject.h",
+ "python_include/ceval.h",
+ "python_include/classobject.h",
+ "python_include/cobject.h",
"python_include/code.h",
+ "python_include/codecs.h",
+ "python_include/compile.h",
+ "python_include/complexobject.h",
+ "python_include/datetime.h",
+ "python_include/descrobject.h",
+ "python_include/dictobject.h",
"python_include/dtoa.h",
- "python_include/tupleobject.h",
- "python_include/object.h",
- "python_include/ast.h",
- "python_include/pymacconfig.h",
+ "python_include/enumobject.h",
"python_include/errcode.h",
+ "python_include/eval.h",
+ "python_include/fileobject.h",
+ "python_include/floatobject.h",
"python_include/frameobject.h",
- "python_include/pgenheaders.h",
- "python_include/cellobject.h",
+ "python_include/funcobject.h",
+ "python_include/genobject.h",
+ "python_include/graminit.h",
+ "python_include/grammar.h",
+ "python_include/import.h",
"python_include/intobject.h",
- "python_include/pythread.h",
- "python_include/cStringIO.h",
- "python_include/boolobject.h",
+ "python_include/intrcheck.h",
+ "python_include/iterobject.h",
+ "python_include/listobject.h",
+ "python_include/longintrepr.h",
+ "python_include/longobject.h",
+ "python_include/marshal.h",
+ "python_include/memoryobject.h",
+ "python_include/metagrammar.h",
+ "python_include/methodobject.h",
"python_include/modsupport.h",
- "python_include/import.h",
- "python_include/pymath.h",
+ "python_include/moduleobject.h",
"python_include/node.h",
- "python_include/funcobject.h",
- "python_include/eval.h",
- "python_include/longintrepr.h",
- "python_include/floatobject.h",
- "python_include/rangeobject.h",
- "python_include/pyfpe.h",
- "python_include/pystrcmp.h",
- "python_include/dictobject.h",
- "python_include/pyarena.h",
+ "python_include/object.h",
"python_include/objimpl.h",
- "python_include/bitset.h",
- "python_include/memoryobject.h",
- "python_include/bytearrayobject.h",
+ "python_include/opcode.h",
+ "python_include/osdefs.h",
+ "python_include/parsetok.h",
+ "python_include/patchlevel.h",
+ "python_include/pgen.h",
+ "python_include/pgenheaders.h",
+ "python_include/py_curses.h",
+ "python_include/pyarena.h",
+ "python_include/pycapsule.h",
+ "python_include/pyconfig.h",
+ "python_include/pyctype.h",
"python_include/pydebug.h",
"python_include/pyerrors.h",
- "python_include/weakrefobject.h",
- "python_include/grammar.h",
- "python_include/symtable.h",
- "python_include/longobject.h",
- "python_include/structmember.h",
- "python_include/enumobject.h",
- "python_include/classobject.h",
- "python_include/unicodeobject.h",
- "python_include/sliceobject.h",
- "python_include/pystrtod.h",
- "python_include/genobject.h",
- "python_include/pymactoolbox.h",
- "python_include/compile.h",
"python_include/pyexpat.h",
- "python_include/asdl.h",
- "python_include/codecs.h",
- "python_include/pyctype.h",
- "python_include/sysmodule.h",
- "python_include/methodobject.h",
- "python_include/graminit.h",
- "python_include/cobject.h",
- "python_include/intrcheck.h",
- "python_include/pyport.h",
- "python_include/warnings.h",
- "python_include/osdefs.h",
- "python_include/fileobject.h",
- "python_include/stringobject.h",
- "python_include/timefuncs.h",
- "python_include/traceback.h",
- "python_include/ceval.h",
- "python_include/bytes_methods.h",
- "python_include/pyconfig.h",
- "python_include/Python.h",
- "python_include/moduleobject.h",
- "python_include/pystate.h",
- "python_include/descrobject.h",
- "python_include/ucnhash.h",
+ "python_include/pyfpe.h",
"python_include/pygetopt.h",
+ "python_include/pymacconfig.h",
+ "python_include/pymactoolbox.h",
+ "python_include/pymath.h",
"python_include/pymem.h",
- "python_include/complexobject.h",
- "python_include/structseq.h",
- "python_include/datetime.h",
+ "python_include/pyport.h",
+ "python_include/pystate.h",
+ "python_include/pystrcmp.h",
+ "python_include/pystrtod.h",
"python_include/pythonrun.h",
- "python_include/numpy/oldnumeric.h",
- "python_include/numpy/npy_1_7_deprecated_api.h",
- "python_include/numpy/ufunc_api.txt",
- "python_include/numpy/multiarray_api.txt",
- "python_include/numpy/halffloat.h",
- "python_include/numpy/npy_common.h",
- "python_include/numpy/utils.h",
- "python_include/numpy/npy_interrupt.h",
- "python_include/numpy/npy_endian.h",
- "python_include/numpy/__ufunc_api.h",
- "python_include/numpy/_neighborhood_iterator_imp.h",
- "python_include/numpy/ufuncobject.h",
- "python_include/numpy/ndarraytypes.h",
- "python_include/numpy/npy_math.h",
- "python_include/numpy/noprefix.h",
- "python_include/numpy/npy_3kcompat.h",
- "python_include/numpy/arrayscalars.h",
- "python_include/numpy/npy_os.h",
- "python_include/numpy/ndarrayobject.h",
- "python_include/numpy/npy_no_deprecated_api.h",
- "python_include/numpy/arrayobject.h",
- "python_include/numpy/_numpyconfig.h",
- "python_include/numpy/__multiarray_api.h",
- "python_include/numpy/npy_cpu.h",
- "python_include/numpy/old_defines.h",
- "python_include/numpy/numpyconfig.h",
- "python_include/pycapsule.h",
+ "python_include/pythread.h",
+ "python_include/rangeobject.h",
"python_include/setobject.h",
- "python_include/listobject.h",
- "python_include/bytesobject.h",
- "python_include/pgen.h",
- "python_include/patchlevel.h",
- "python_include/opcode.h",
- "python_include/parsetok.h",
- "python_include/marshal.h",
+ "python_include/sliceobject.h",
+ "python_include/stringobject.h",
+ "python_include/structmember.h",
+ "python_include/structseq.h",
+ "python_include/symtable.h",
+ "python_include/sysmodule.h",
+ "python_include/timefuncs.h",
"python_include/token.h",
- "python_include/iterobject.h",
- "python_include/abstract.h",
- "python_include/py_curses.h",
- "python_include/metagrammar.h",
- "python_include/bufferobject.h",
- "python_include/Python-ast.h",
+ "python_include/traceback.h",
+ "python_include/tupleobject.h",
+ "python_include/ucnhash.h",
+ "python_include/unicodeobject.h",
+ "python_include/warnings.h",
+ "python_include/weakrefobject.h",
],
cmd = """
-cp "/usr/include/python2.7/code.h" "$(@D)/python_include/code.h" && cp "/usr/include/python2.7/dtoa.h" "$(@D)/python_include/dtoa.h" && cp "/usr/include/python2.7/tupleobject.h" "$(@D)/python_include/tupleobject.h" && cp "/usr/include/python2.7/object.h" "$(@D)/python_include/object.h" && cp "/usr/include/python2.7/ast.h" "$(@D)/python_include/ast.h" && cp "/usr/include/python2.7/pymacconfig.h" "$(@D)/python_include/pymacconfig.h" && cp "/usr/include/python2.7/errcode.h" "$(@D)/python_include/errcode.h" && cp "/usr/include/python2.7/frameobject.h" "$(@D)/python_include/frameobject.h" && cp "/usr/include/python2.7/pgenheaders.h" "$(@D)/python_include/pgenheaders.h" && cp "/usr/include/python2.7/cellobject.h" "$(@D)/python_include/cellobject.h" && cp "/usr/include/python2.7/intobject.h" "$(@D)/python_include/intobject.h" && cp "/usr/include/python2.7/pythread.h" "$(@D)/python_include/pythread.h" && cp "/usr/include/python2.7/cStringIO.h" "$(@D)/python_include/cStringIO.h" && cp "/usr/include/python2.7/boolobject.h" "$(@D)/python_include/boolobject.h" && cp "/usr/include/python2.7/modsupport.h" "$(@D)/python_include/modsupport.h" && cp "/usr/include/python2.7/import.h" "$(@D)/python_include/import.h" && cp "/usr/include/python2.7/pymath.h" "$(@D)/python_include/pymath.h" && cp "/usr/include/python2.7/node.h" "$(@D)/python_include/node.h" && cp "/usr/include/python2.7/funcobject.h" "$(@D)/python_include/funcobject.h" && cp "/usr/include/python2.7/eval.h" "$(@D)/python_include/eval.h" && cp "/usr/include/python2.7/longintrepr.h" "$(@D)/python_include/longintrepr.h" && cp "/usr/include/python2.7/floatobject.h" "$(@D)/python_include/floatobject.h" && cp "/usr/include/python2.7/rangeobject.h" "$(@D)/python_include/rangeobject.h" && cp "/usr/include/python2.7/pyfpe.h" "$(@D)/python_include/pyfpe.h" && cp "/usr/include/python2.7/pystrcmp.h" "$(@D)/python_include/pystrcmp.h" && cp "/usr/include/python2.7/dictobject.h" "$(@D)/python_include/dictobject.h" && cp "/usr/include/python2.7/pyarena.h" "$(@D)/python_include/pyarena.h" && cp "/usr/include/python2.7/objimpl.h" "$(@D)/python_include/objimpl.h" && cp "/usr/include/python2.7/bitset.h" "$(@D)/python_include/bitset.h" && cp "/usr/include/python2.7/memoryobject.h" "$(@D)/python_include/memoryobject.h" && cp "/usr/include/python2.7/bytearrayobject.h" "$(@D)/python_include/bytearrayobject.h" && cp "/usr/include/python2.7/pydebug.h" "$(@D)/python_include/pydebug.h" && cp "/usr/include/python2.7/pyerrors.h" "$(@D)/python_include/pyerrors.h" && cp "/usr/include/python2.7/weakrefobject.h" "$(@D)/python_include/weakrefobject.h" && cp "/usr/include/python2.7/grammar.h" "$(@D)/python_include/grammar.h" && cp "/usr/include/python2.7/symtable.h" "$(@D)/python_include/symtable.h" && cp "/usr/include/python2.7/longobject.h" "$(@D)/python_include/longobject.h" && cp "/usr/include/python2.7/structmember.h" "$(@D)/python_include/structmember.h" && cp "/usr/include/python2.7/enumobject.h" "$(@D)/python_include/enumobject.h" && cp "/usr/include/python2.7/classobject.h" "$(@D)/python_include/classobject.h" && cp "/usr/include/python2.7/unicodeobject.h" "$(@D)/python_include/unicodeobject.h" && cp "/usr/include/python2.7/sliceobject.h" "$(@D)/python_include/sliceobject.h" && cp "/usr/include/python2.7/pystrtod.h" "$(@D)/python_include/pystrtod.h" && cp "/usr/include/python2.7/genobject.h" "$(@D)/python_include/genobject.h" && cp "/usr/include/python2.7/pymactoolbox.h" "$(@D)/python_include/pymactoolbox.h" && cp "/usr/include/python2.7/compile.h" "$(@D)/python_include/compile.h" && cp "/usr/include/python2.7/pyexpat.h" "$(@D)/python_include/pyexpat.h" && cp "/usr/include/python2.7/asdl.h" "$(@D)/python_include/asdl.h" && cp "/usr/include/python2.7/codecs.h" "$(@D)/python_include/codecs.h" && cp "/usr/include/python2.7/pyctype.h" "$(@D)/python_include/pyctype.h" && cp "/usr/include/python2.7/sysmodule.h" "$(@D)/python_include/sysmodule.h" && cp "/usr/include/python2.7/methodobject.h" "$(@D)/python_include/methodobject.h" && cp "/usr/include/python2.7/graminit.h" "$(@D)/python_include/graminit.h" && cp "/usr/include/python2.7/cobject.h" "$(@D)/python_include/cobject.h" && cp "/usr/include/python2.7/intrcheck.h" "$(@D)/python_include/intrcheck.h" && cp "/usr/include/python2.7/pyport.h" "$(@D)/python_include/pyport.h" && cp "/usr/include/python2.7/warnings.h" "$(@D)/python_include/warnings.h" && cp "/usr/include/python2.7/osdefs.h" "$(@D)/python_include/osdefs.h" && cp "/usr/include/python2.7/fileobject.h" "$(@D)/python_include/fileobject.h" && cp "/usr/include/python2.7/stringobject.h" "$(@D)/python_include/stringobject.h" && cp "/usr/include/python2.7/timefuncs.h" "$(@D)/python_include/timefuncs.h" && cp "/usr/include/python2.7/traceback.h" "$(@D)/python_include/traceback.h" && cp "/usr/include/python2.7/ceval.h" "$(@D)/python_include/ceval.h" && cp "/usr/include/python2.7/bytes_methods.h" "$(@D)/python_include/bytes_methods.h" && cp "/usr/include/python2.7/pyconfig.h" "$(@D)/python_include/pyconfig.h" && cp "/usr/include/python2.7/Python.h" "$(@D)/python_include/Python.h" && cp "/usr/include/python2.7/moduleobject.h" "$(@D)/python_include/moduleobject.h" && cp "/usr/include/python2.7/pystate.h" "$(@D)/python_include/pystate.h" && cp "/usr/include/python2.7/descrobject.h" "$(@D)/python_include/descrobject.h" && cp "/usr/include/python2.7/ucnhash.h" "$(@D)/python_include/ucnhash.h" && cp "/usr/include/python2.7/pygetopt.h" "$(@D)/python_include/pygetopt.h" && cp "/usr/include/python2.7/pymem.h" "$(@D)/python_include/pymem.h" && cp "/usr/include/python2.7/complexobject.h" "$(@D)/python_include/complexobject.h" && cp "/usr/include/python2.7/structseq.h" "$(@D)/python_include/structseq.h" && cp "/usr/include/python2.7/datetime.h" "$(@D)/python_include/datetime.h" && cp "/usr/include/python2.7/pythonrun.h" "$(@D)/python_include/pythonrun.h" && cp "/usr/include/python2.7/numpy/oldnumeric.h" "$(@D)/python_include/numpy/oldnumeric.h" && cp "/usr/include/python2.7/numpy/npy_1_7_deprecated_api.h" "$(@D)/python_include/numpy/npy_1_7_deprecated_api.h" && cp "/usr/include/python2.7/numpy/ufunc_api.txt" "$(@D)/python_include/numpy/ufunc_api.txt" && cp "/usr/include/python2.7/numpy/multiarray_api.txt" "$(@D)/python_include/numpy/multiarray_api.txt" && cp "/usr/include/python2.7/numpy/halffloat.h" "$(@D)/python_include/numpy/halffloat.h" && cp "/usr/include/python2.7/numpy/npy_common.h" "$(@D)/python_include/numpy/npy_common.h" && cp "/usr/include/python2.7/numpy/utils.h" "$(@D)/python_include/numpy/utils.h" && cp "/usr/include/python2.7/numpy/npy_interrupt.h" "$(@D)/python_include/numpy/npy_interrupt.h" && cp "/usr/include/python2.7/numpy/npy_endian.h" "$(@D)/python_include/numpy/npy_endian.h" && cp "/usr/include/python2.7/numpy/__ufunc_api.h" "$(@D)/python_include/numpy/__ufunc_api.h" && cp "/usr/include/python2.7/numpy/_neighborhood_iterator_imp.h" "$(@D)/python_include/numpy/_neighborhood_iterator_imp.h" && cp "/usr/include/python2.7/numpy/ufuncobject.h" "$(@D)/python_include/numpy/ufuncobject.h" && cp "/usr/include/python2.7/numpy/ndarraytypes.h" "$(@D)/python_include/numpy/ndarraytypes.h" && cp "/usr/include/python2.7/numpy/npy_math.h" "$(@D)/python_include/numpy/npy_math.h" && cp "/usr/include/python2.7/numpy/noprefix.h" "$(@D)/python_include/numpy/noprefix.h" && cp "/usr/include/python2.7/numpy/npy_3kcompat.h" "$(@D)/python_include/numpy/npy_3kcompat.h" && cp "/usr/include/python2.7/numpy/arrayscalars.h" "$(@D)/python_include/numpy/arrayscalars.h" && cp "/usr/include/python2.7/numpy/npy_os.h" "$(@D)/python_include/numpy/npy_os.h" && cp "/usr/include/python2.7/numpy/ndarrayobject.h" "$(@D)/python_include/numpy/ndarrayobject.h" && cp "/usr/include/python2.7/numpy/npy_no_deprecated_api.h" "$(@D)/python_include/numpy/npy_no_deprecated_api.h" && cp "/usr/include/python2.7/numpy/arrayobject.h" "$(@D)/python_include/numpy/arrayobject.h" && cp "/usr/include/python2.7/numpy/_numpyconfig.h" "$(@D)/python_include/numpy/_numpyconfig.h" && cp "/usr/include/python2.7/numpy/__multiarray_api.h" "$(@D)/python_include/numpy/__multiarray_api.h" && cp "/usr/include/python2.7/numpy/npy_cpu.h" "$(@D)/python_include/numpy/npy_cpu.h" && cp "/usr/include/python2.7/numpy/old_defines.h" "$(@D)/python_include/numpy/old_defines.h" && cp "/usr/include/python2.7/numpy/numpyconfig.h" "$(@D)/python_include/numpy/numpyconfig.h" && cp "/usr/include/python2.7/pycapsule.h" "$(@D)/python_include/pycapsule.h" && cp "/usr/include/python2.7/setobject.h" "$(@D)/python_include/setobject.h" && cp "/usr/include/python2.7/listobject.h" "$(@D)/python_include/listobject.h" && cp "/usr/include/python2.7/bytesobject.h" "$(@D)/python_include/bytesobject.h" && cp "/usr/include/python2.7/pgen.h" "$(@D)/python_include/pgen.h" && cp "/usr/include/python2.7/patchlevel.h" "$(@D)/python_include/patchlevel.h" && cp "/usr/include/python2.7/opcode.h" "$(@D)/python_include/opcode.h" && cp "/usr/include/python2.7/parsetok.h" "$(@D)/python_include/parsetok.h" && cp "/usr/include/python2.7/marshal.h" "$(@D)/python_include/marshal.h" && cp "/usr/include/python2.7/token.h" "$(@D)/python_include/token.h" && cp "/usr/include/python2.7/iterobject.h" "$(@D)/python_include/iterobject.h" && cp "/usr/include/python2.7/abstract.h" "$(@D)/python_include/abstract.h" && cp "/usr/include/python2.7/py_curses.h" "$(@D)/python_include/py_curses.h" && cp "/usr/include/python2.7/metagrammar.h" "$(@D)/python_include/metagrammar.h" && cp "/usr/include/python2.7/bufferobject.h" "$(@D)/python_include/bufferobject.h" && cp "/usr/include/python2.7/Python-ast.h" "$(@D)/python_include/Python-ast.h"
+cp "/usr/include/python2.7/Python-ast.h" "$(@D)/python_include/Python-ast.h" && cp "/usr/include/python2.7/Python.h" "$(@D)/python_include/Python.h" && cp "/usr/include/python2.7/abstract.h" "$(@D)/python_include/abstract.h" && cp "/usr/include/python2.7/asdl.h" "$(@D)/python_include/asdl.h" && cp "/usr/include/python2.7/ast.h" "$(@D)/python_include/ast.h" && cp "/usr/include/python2.7/bitset.h" "$(@D)/python_include/bitset.h" && cp "/usr/include/python2.7/boolobject.h" "$(@D)/python_include/boolobject.h" && cp "/usr/include/python2.7/bufferobject.h" "$(@D)/python_include/bufferobject.h" && cp "/usr/include/python2.7/bytearrayobject.h" "$(@D)/python_include/bytearrayobject.h" && cp "/usr/include/python2.7/bytes_methods.h" "$(@D)/python_include/bytes_methods.h" && cp "/usr/include/python2.7/bytesobject.h" "$(@D)/python_include/bytesobject.h" && cp "/usr/include/python2.7/cStringIO.h" "$(@D)/python_include/cStringIO.h" && cp "/usr/include/python2.7/cellobject.h" "$(@D)/python_include/cellobject.h" && cp "/usr/include/python2.7/ceval.h" "$(@D)/python_include/ceval.h" && cp "/usr/include/python2.7/classobject.h" "$(@D)/python_include/classobject.h" && cp "/usr/include/python2.7/cobject.h" "$(@D)/python_include/cobject.h" && cp "/usr/include/python2.7/code.h" "$(@D)/python_include/code.h" && cp "/usr/include/python2.7/codecs.h" "$(@D)/python_include/codecs.h" && cp "/usr/include/python2.7/compile.h" "$(@D)/python_include/compile.h" && cp "/usr/include/python2.7/complexobject.h" "$(@D)/python_include/complexobject.h" && cp "/usr/include/python2.7/datetime.h" "$(@D)/python_include/datetime.h" && cp "/usr/include/python2.7/descrobject.h" "$(@D)/python_include/descrobject.h" && cp "/usr/include/python2.7/dictobject.h" "$(@D)/python_include/dictobject.h" && cp "/usr/include/python2.7/dtoa.h" "$(@D)/python_include/dtoa.h" && cp "/usr/include/python2.7/enumobject.h" "$(@D)/python_include/enumobject.h" && cp "/usr/include/python2.7/errcode.h" "$(@D)/python_include/errcode.h" && cp "/usr/include/python2.7/eval.h" "$(@D)/python_include/eval.h" && cp "/usr/include/python2.7/fileobject.h" "$(@D)/python_include/fileobject.h" && cp "/usr/include/python2.7/floatobject.h" "$(@D)/python_include/floatobject.h" && cp "/usr/include/python2.7/frameobject.h" "$(@D)/python_include/frameobject.h" && cp "/usr/include/python2.7/funcobject.h" "$(@D)/python_include/funcobject.h" && cp "/usr/include/python2.7/genobject.h" "$(@D)/python_include/genobject.h" && cp "/usr/include/python2.7/graminit.h" "$(@D)/python_include/graminit.h" && cp "/usr/include/python2.7/grammar.h" "$(@D)/python_include/grammar.h" && cp "/usr/include/python2.7/import.h" "$(@D)/python_include/import.h" && cp "/usr/include/python2.7/intobject.h" "$(@D)/python_include/intobject.h" && cp "/usr/include/python2.7/intrcheck.h" "$(@D)/python_include/intrcheck.h" && cp "/usr/include/python2.7/iterobject.h" "$(@D)/python_include/iterobject.h" && cp "/usr/include/python2.7/listobject.h" "$(@D)/python_include/listobject.h" && cp "/usr/include/python2.7/longintrepr.h" "$(@D)/python_include/longintrepr.h" && cp "/usr/include/python2.7/longobject.h" "$(@D)/python_include/longobject.h" && cp "/usr/include/python2.7/marshal.h" "$(@D)/python_include/marshal.h" && cp "/usr/include/python2.7/memoryobject.h" "$(@D)/python_include/memoryobject.h" && cp "/usr/include/python2.7/metagrammar.h" "$(@D)/python_include/metagrammar.h" && cp "/usr/include/python2.7/methodobject.h" "$(@D)/python_include/methodobject.h" && cp "/usr/include/python2.7/modsupport.h" "$(@D)/python_include/modsupport.h" && cp "/usr/include/python2.7/moduleobject.h" "$(@D)/python_include/moduleobject.h" && cp "/usr/include/python2.7/node.h" "$(@D)/python_include/node.h" && cp "/usr/include/python2.7/object.h" "$(@D)/python_include/object.h" && cp "/usr/include/python2.7/objimpl.h" "$(@D)/python_include/objimpl.h" && cp "/usr/include/python2.7/opcode.h" "$(@D)/python_include/opcode.h" && cp "/usr/include/python2.7/osdefs.h" "$(@D)/python_include/osdefs.h" && cp "/usr/include/python2.7/parsetok.h" "$(@D)/python_include/parsetok.h" && cp "/usr/include/python2.7/patchlevel.h" "$(@D)/python_include/patchlevel.h" && cp "/usr/include/python2.7/pgen.h" "$(@D)/python_include/pgen.h" && cp "/usr/include/python2.7/pgenheaders.h" "$(@D)/python_include/pgenheaders.h" && cp "/usr/include/python2.7/py_curses.h" "$(@D)/python_include/py_curses.h" && cp "/usr/include/python2.7/pyarena.h" "$(@D)/python_include/pyarena.h" && cp "/usr/include/python2.7/pycapsule.h" "$(@D)/python_include/pycapsule.h" && cp "/usr/include/python2.7/pyconfig.h" "$(@D)/python_include/pyconfig.h" && cp "/usr/include/python2.7/pyctype.h" "$(@D)/python_include/pyctype.h" && cp "/usr/include/python2.7/pydebug.h" "$(@D)/python_include/pydebug.h" && cp "/usr/include/python2.7/pyerrors.h" "$(@D)/python_include/pyerrors.h" && cp "/usr/include/python2.7/pyexpat.h" "$(@D)/python_include/pyexpat.h" && cp "/usr/include/python2.7/pyfpe.h" "$(@D)/python_include/pyfpe.h" && cp "/usr/include/python2.7/pygetopt.h" "$(@D)/python_include/pygetopt.h" && cp "/usr/include/python2.7/pymacconfig.h" "$(@D)/python_include/pymacconfig.h" && cp "/usr/include/python2.7/pymactoolbox.h" "$(@D)/python_include/pymactoolbox.h" && cp "/usr/include/python2.7/pymath.h" "$(@D)/python_include/pymath.h" && cp "/usr/include/python2.7/pymem.h" "$(@D)/python_include/pymem.h" && cp "/usr/include/python2.7/pyport.h" "$(@D)/python_include/pyport.h" && cp "/usr/include/python2.7/pystate.h" "$(@D)/python_include/pystate.h" && cp "/usr/include/python2.7/pystrcmp.h" "$(@D)/python_include/pystrcmp.h" && cp "/usr/include/python2.7/pystrtod.h" "$(@D)/python_include/pystrtod.h" && cp "/usr/include/python2.7/pythonrun.h" "$(@D)/python_include/pythonrun.h" && cp "/usr/include/python2.7/pythread.h" "$(@D)/python_include/pythread.h" && cp "/usr/include/python2.7/rangeobject.h" "$(@D)/python_include/rangeobject.h" && cp "/usr/include/python2.7/setobject.h" "$(@D)/python_include/setobject.h" && cp "/usr/include/python2.7/sliceobject.h" "$(@D)/python_include/sliceobject.h" && cp "/usr/include/python2.7/stringobject.h" "$(@D)/python_include/stringobject.h" && cp "/usr/include/python2.7/structmember.h" "$(@D)/python_include/structmember.h" && cp "/usr/include/python2.7/structseq.h" "$(@D)/python_include/structseq.h" && cp "/usr/include/python2.7/symtable.h" "$(@D)/python_include/symtable.h" && cp "/usr/include/python2.7/sysmodule.h" "$(@D)/python_include/sysmodule.h" && cp "/usr/include/python2.7/timefuncs.h" "$(@D)/python_include/timefuncs.h" && cp "/usr/include/python2.7/token.h" "$(@D)/python_include/token.h" && cp "/usr/include/python2.7/traceback.h" "$(@D)/python_include/traceback.h" && cp "/usr/include/python2.7/tupleobject.h" "$(@D)/python_include/tupleobject.h" && cp "/usr/include/python2.7/ucnhash.h" "$(@D)/python_include/ucnhash.h" && cp "/usr/include/python2.7/unicodeobject.h" "$(@D)/python_include/unicodeobject.h" && cp "/usr/include/python2.7/warnings.h" "$(@D)/python_include/warnings.h" && cp "/usr/include/python2.7/weakrefobject.h" "$(@D)/python_include/weakrefobject.h"
""",
)
genrule(
name = "numpy_include",
outs = [
- "numpy_include/numpy/oldnumeric.h",
- "numpy_include/numpy/npy_1_7_deprecated_api.h",
- "numpy_include/numpy/ufunc_api.txt",
- "numpy_include/numpy/multiarray_api.txt",
- "numpy_include/numpy/halffloat.h",
- "numpy_include/numpy/npy_common.h",
- "numpy_include/numpy/utils.h",
- "numpy_include/numpy/npy_interrupt.h",
- "numpy_include/numpy/npy_endian.h",
+ "numpy_include/numpy/__multiarray_api.h",
"numpy_include/numpy/__ufunc_api.h",
"numpy_include/numpy/_neighborhood_iterator_imp.h",
- "numpy_include/numpy/ufuncobject.h",
+ "numpy_include/numpy/_numpyconfig.h",
+ "numpy_include/numpy/arrayobject.h",
+ "numpy_include/numpy/arrayscalars.h",
+ "numpy_include/numpy/halffloat.h",
+ "numpy_include/numpy/multiarray_api.txt",
+ "numpy_include/numpy/ndarrayobject.h",
"numpy_include/numpy/ndarraytypes.h",
- "numpy_include/numpy/npy_math.h",
"numpy_include/numpy/noprefix.h",
+ "numpy_include/numpy/npy_1_7_deprecated_api.h",
"numpy_include/numpy/npy_3kcompat.h",
- "numpy_include/numpy/arrayscalars.h",
- "numpy_include/numpy/npy_os.h",
- "numpy_include/numpy/ndarrayobject.h",
- "numpy_include/numpy/npy_no_deprecated_api.h",
- "numpy_include/numpy/arrayobject.h",
- "numpy_include/numpy/_numpyconfig.h",
- "numpy_include/numpy/__multiarray_api.h",
+ "numpy_include/numpy/npy_common.h",
"numpy_include/numpy/npy_cpu.h",
- "numpy_include/numpy/old_defines.h",
+ "numpy_include/numpy/npy_endian.h",
+ "numpy_include/numpy/npy_interrupt.h",
+ "numpy_include/numpy/npy_math.h",
+ "numpy_include/numpy/npy_no_deprecated_api.h",
+ "numpy_include/numpy/npy_os.h",
"numpy_include/numpy/numpyconfig.h",
+ "numpy_include/numpy/old_defines.h",
+ "numpy_include/numpy/oldnumeric.h",
+ "numpy_include/numpy/ufunc_api.txt",
+ "numpy_include/numpy/ufuncobject.h",
+ "numpy_include/numpy/utils.h",
],
cmd = """
-cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/oldnumeric.h" "$(@D)/numpy_include/numpy/oldnumeric.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_1_7_deprecated_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ufunc_api.txt" "$(@D)/numpy_include/numpy/ufunc_api.txt" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/multiarray_api.txt" "$(@D)/numpy_include/numpy/multiarray_api.txt" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/halffloat.h" "$(@D)/numpy_include/numpy/halffloat.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_common.h" "$(@D)/numpy_include/numpy/npy_common.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/utils.h" "$(@D)/numpy_include/numpy/utils.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_interrupt.h" "$(@D)/numpy_include/numpy/npy_interrupt.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_endian.h" "$(@D)/numpy_include/numpy/npy_endian.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/__ufunc_api.h" "$(@D)/numpy_include/numpy/__ufunc_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h" "$(@D)/numpy_include/numpy/_neighborhood_iterator_imp.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ufuncobject.h" "$(@D)/numpy_include/numpy/ufuncobject.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarraytypes.h" "$(@D)/numpy_include/numpy/ndarraytypes.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_math.h" "$(@D)/numpy_include/numpy/npy_math.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/noprefix.h" "$(@D)/numpy_include/numpy/noprefix.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_3kcompat.h" "$(@D)/numpy_include/numpy/npy_3kcompat.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayscalars.h" "$(@D)/numpy_include/numpy/arrayscalars.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_os.h" "$(@D)/numpy_include/numpy/npy_os.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarrayobject.h" "$(@D)/numpy_include/numpy/ndarrayobject.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_no_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_no_deprecated_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayobject.h" "$(@D)/numpy_include/numpy/arrayobject.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/_numpyconfig.h" "$(@D)/numpy_include/numpy/_numpyconfig.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/__multiarray_api.h" "$(@D)/numpy_include/numpy/__multiarray_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_cpu.h" "$(@D)/numpy_include/numpy/npy_cpu.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/old_defines.h" "$(@D)/numpy_include/numpy/old_defines.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/numpyconfig.h" "$(@D)/numpy_include/numpy/numpyconfig.h"
+cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/__multiarray_api.h" "$(@D)/numpy_include/numpy/__multiarray_api.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/__ufunc_api.h" "$(@D)/numpy_include/numpy/__ufunc_api.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h" "$(@D)/numpy_include/numpy/_neighborhood_iterator_imp.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/_numpyconfig.h" "$(@D)/numpy_include/numpy/_numpyconfig.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayobject.h" "$(@D)/numpy_include/numpy/arrayobject.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayscalars.h" "$(@D)/numpy_include/numpy/arrayscalars.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/halffloat.h" "$(@D)/numpy_include/numpy/halffloat.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/multiarray_api.txt" "$(@D)/numpy_include/numpy/multiarray_api.txt" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarrayobject.h" "$(@D)/numpy_include/numpy/ndarrayobject.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarraytypes.h" "$(@D)/numpy_include/numpy/ndarraytypes.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/noprefix.h" "$(@D)/numpy_include/numpy/noprefix.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_1_7_deprecated_api.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_3kcompat.h" "$(@D)/numpy_include/numpy/npy_3kcompat.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_common.h" "$(@D)/numpy_include/numpy/npy_common.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_cpu.h" "$(@D)/numpy_include/numpy/npy_cpu.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_endian.h" "$(@D)/numpy_include/numpy/npy_endian.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_interrupt.h" "$(@D)/numpy_include/numpy/npy_interrupt.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_math.h" "$(@D)/numpy_include/numpy/npy_math.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_no_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_no_deprecated_api.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_os.h" "$(@D)/numpy_include/numpy/npy_os.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/numpyconfig.h" "$(@D)/numpy_include/numpy/numpyconfig.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/old_defines.h" "$(@D)/numpy_include/numpy/old_defines.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/oldnumeric.h" "$(@D)/numpy_include/numpy/oldnumeric.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ufunc_api.txt" "$(@D)/numpy_include/numpy/ufunc_api.txt" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ufuncobject.h" "$(@D)/numpy_include/numpy/ufuncobject.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/utils.h" "$(@D)/numpy_include/numpy/utils.h"
""",
)
diff --git a/third_party/toolchains/cpus/py3/BUILD b/third_party/toolchains/cpus/py3/BUILD
index 932a25239f..d47256ebef 100644
--- a/third_party/toolchains/cpus/py3/BUILD
+++ b/third_party/toolchains/cpus/py3/BUILD
@@ -6,18 +6,24 @@ licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
+# To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib
+# See https://docs.python.org/3/extending/windows.html
+cc_import(
+ name = "python_lib",
+ interface_library = select({
+ ":windows": ":python_import_lib",
+ # A placeholder for Unix platforms which makes --no_build happy.
+ "//conditions:default": "not-existing.lib",
+ }),
+ system_provided = 1,
+)
+
cc_library(
name = "python_headers",
hdrs = [":python_include"],
- data = select({
- ":windows": [":python_import_lib"],
- "//conditions:default": [],
- }),
includes = ["python_include"],
- linkopts = select({
- # TODO(pcloudy): Ideally, this should just go into deps after resolving
- # https://github.com/bazelbuild/bazel/issues/3237,
- ":windows": ["$(locations :python_import_lib)"],
+ deps = select({
+ ":windows": [":python_lib"],
"//conditions:default": [],
}),
)
@@ -37,143 +43,143 @@ config_setting(
genrule(
name = "python_include",
outs = [
- "python_include/code.h",
- "python_include/dtoa.h",
- "python_include/tupleobject.h",
- "python_include/object.h",
- "python_include/ast.h",
- "python_include/pymacconfig.h",
- "python_include/errcode.h",
- "python_include/frameobject.h",
- "python_include/typeslots.h",
- "python_include/pgenheaders.h",
- "python_include/cellobject.h",
- "python_include/pythread.h",
- "python_include/boolobject.h",
+ "python_include/Python-ast.h",
+ "python_include/Python.h",
+ "python_include/abstract.h",
"python_include/accu.h",
- "python_include/modsupport.h",
- "python_include/import.h",
- "python_include/pymath.h",
- "python_include/node.h",
- "python_include/funcobject.h",
- "python_include/eval.h",
- "python_include/pyatomic.h",
- "python_include/longintrepr.h",
- "python_include/floatobject.h",
- "python_include/rangeobject.h",
- "python_include/pyfpe.h",
- "python_include/pystrcmp.h",
- "python_include/fileutils.h",
- "python_include/dictobject.h",
- "python_include/pyarena.h",
- "python_include/osmodule.h",
- "python_include/objimpl.h",
+ "python_include/asdl.h",
+ "python_include/ast.h",
"python_include/bitset.h",
- "python_include/memoryobject.h",
+ "python_include/bltinmodule.h",
+ "python_include/boolobject.h",
"python_include/bytearrayobject.h",
- "python_include/pydebug.h",
- "python_include/pyerrors.h",
- "python_include/weakrefobject.h",
- "python_include/grammar.h",
- "python_include/symtable.h",
- "python_include/longobject.h",
- "python_include/structmember.h",
- "python_include/enumobject.h",
- "python_include/pymacro.h",
+ "python_include/bytes_methods.h",
+ "python_include/bytesobject.h",
+ "python_include/cellobject.h",
+ "python_include/ceval.h",
"python_include/classobject.h",
- "python_include/unicodeobject.h",
- "python_include/sliceobject.h",
- "python_include/pystrtod.h",
- "python_include/genobject.h",
- "python_include/compile.h",
- "python_include/pyexpat.h",
- "python_include/asdl.h",
+ "python_include/code.h",
"python_include/codecs.h",
+ "python_include/compile.h",
+ "python_include/complexobject.h",
+ "python_include/datetime.h",
+ "python_include/descrobject.h",
+ "python_include/dictobject.h",
+ "python_include/dtoa.h",
"python_include/dynamic_annotations.h",
- "python_include/pyctype.h",
- "python_include/sysmodule.h",
- "python_include/methodobject.h",
+ "python_include/enumobject.h",
+ "python_include/errcode.h",
+ "python_include/eval.h",
+ "python_include/fileobject.h",
+ "python_include/fileutils.h",
+ "python_include/floatobject.h",
+ "python_include/frameobject.h",
+ "python_include/funcobject.h",
+ "python_include/genobject.h",
"python_include/graminit.h",
- "python_include/bltinmodule.h",
+ "python_include/grammar.h",
+ "python_include/import.h",
"python_include/intrcheck.h",
- "python_include/pyport.h",
- "python_include/warnings.h",
- "python_include/osdefs.h",
- "python_include/pydtrace.h",
- "python_include/pylifecycle.h",
- "python_include/fileobject.h",
- "python_include/pytime.h",
- "python_include/traceback.h",
- "python_include/ceval.h",
- "python_include/bytes_methods.h",
- "python_include/namespaceobject.h",
- "python_include/pyconfig.h",
- "python_include/Python.h",
+ "python_include/iterobject.h",
+ "python_include/listobject.h",
+ "python_include/longintrepr.h",
+ "python_include/longobject.h",
+ "python_include/marshal.h",
+ "python_include/memoryobject.h",
+ "python_include/metagrammar.h",
+ "python_include/methodobject.h",
+ "python_include/modsupport.h",
"python_include/moduleobject.h",
- "python_include/pystate.h",
- "python_include/descrobject.h",
+ "python_include/namespaceobject.h",
+ "python_include/node.h",
+ "python_include/object.h",
+ "python_include/objimpl.h",
"python_include/odictobject.h",
- "python_include/ucnhash.h",
+ "python_include/opcode.h",
+ "python_include/osdefs.h",
+ "python_include/osmodule.h",
+ "python_include/parsetok.h",
+ "python_include/patchlevel.h",
+ "python_include/pgen.h",
+ "python_include/pgenheaders.h",
+ "python_include/py_curses.h",
+ "python_include/pyarena.h",
+ "python_include/pyatomic.h",
+ "python_include/pycapsule.h",
+ "python_include/pyconfig.h",
+ "python_include/pyctype.h",
+ "python_include/pydebug.h",
+ "python_include/pydtrace.h",
+ "python_include/pyerrors.h",
+ "python_include/pyexpat.h",
+ "python_include/pyfpe.h",
"python_include/pygetopt.h",
+ "python_include/pyhash.h",
+ "python_include/pylifecycle.h",
+ "python_include/pymacconfig.h",
+ "python_include/pymacro.h",
+ "python_include/pymath.h",
"python_include/pymem.h",
- "python_include/complexobject.h",
- "python_include/structseq.h",
- "python_include/datetime.h",
+ "python_include/pyport.h",
+ "python_include/pystate.h",
+ "python_include/pystrcmp.h",
+ "python_include/pystrhex.h",
+ "python_include/pystrtod.h",
"python_include/pythonrun.h",
- "python_include/pyhash.h",
- "python_include/pycapsule.h",
+ "python_include/pythread.h",
+ "python_include/pytime.h",
+ "python_include/rangeobject.h",
"python_include/setobject.h",
- "python_include/listobject.h",
- "python_include/bytesobject.h",
- "python_include/pgen.h",
- "python_include/patchlevel.h",
- "python_include/opcode.h",
- "python_include/parsetok.h",
- "python_include/pystrhex.h",
- "python_include/marshal.h",
+ "python_include/sliceobject.h",
+ "python_include/structmember.h",
+ "python_include/structseq.h",
+ "python_include/symtable.h",
+ "python_include/sysmodule.h",
"python_include/token.h",
- "python_include/iterobject.h",
- "python_include/abstract.h",
- "python_include/py_curses.h",
- "python_include/metagrammar.h",
- "python_include/Python-ast.h",
+ "python_include/traceback.h",
+ "python_include/tupleobject.h",
+ "python_include/typeslots.h",
+ "python_include/ucnhash.h",
+ "python_include/unicodeobject.h",
+ "python_include/warnings.h",
+ "python_include/weakrefobject.h",
],
cmd = """
-cp "/opt/python3.6/include/python3.6m/code.h" "$(@D)/python_include/code.h" && cp "/opt/python3.6/include/python3.6m/dtoa.h" "$(@D)/python_include/dtoa.h" && cp "/opt/python3.6/include/python3.6m/tupleobject.h" "$(@D)/python_include/tupleobject.h" && cp "/opt/python3.6/include/python3.6m/object.h" "$(@D)/python_include/object.h" && cp "/opt/python3.6/include/python3.6m/ast.h" "$(@D)/python_include/ast.h" && cp "/opt/python3.6/include/python3.6m/pymacconfig.h" "$(@D)/python_include/pymacconfig.h" && cp "/opt/python3.6/include/python3.6m/errcode.h" "$(@D)/python_include/errcode.h" && cp "/opt/python3.6/include/python3.6m/frameobject.h" "$(@D)/python_include/frameobject.h" && cp "/opt/python3.6/include/python3.6m/typeslots.h" "$(@D)/python_include/typeslots.h" && cp "/opt/python3.6/include/python3.6m/pgenheaders.h" "$(@D)/python_include/pgenheaders.h" && cp "/opt/python3.6/include/python3.6m/cellobject.h" "$(@D)/python_include/cellobject.h" && cp "/opt/python3.6/include/python3.6m/pythread.h" "$(@D)/python_include/pythread.h" && cp "/opt/python3.6/include/python3.6m/boolobject.h" "$(@D)/python_include/boolobject.h" && cp "/opt/python3.6/include/python3.6m/accu.h" "$(@D)/python_include/accu.h" && cp "/opt/python3.6/include/python3.6m/modsupport.h" "$(@D)/python_include/modsupport.h" && cp "/opt/python3.6/include/python3.6m/import.h" "$(@D)/python_include/import.h" && cp "/opt/python3.6/include/python3.6m/pymath.h" "$(@D)/python_include/pymath.h" && cp "/opt/python3.6/include/python3.6m/node.h" "$(@D)/python_include/node.h" && cp "/opt/python3.6/include/python3.6m/funcobject.h" "$(@D)/python_include/funcobject.h" && cp "/opt/python3.6/include/python3.6m/eval.h" "$(@D)/python_include/eval.h" && cp "/opt/python3.6/include/python3.6m/pyatomic.h" "$(@D)/python_include/pyatomic.h" && cp "/opt/python3.6/include/python3.6m/longintrepr.h" "$(@D)/python_include/longintrepr.h" && cp "/opt/python3.6/include/python3.6m/floatobject.h" "$(@D)/python_include/floatobject.h" && cp "/opt/python3.6/include/python3.6m/rangeobject.h" "$(@D)/python_include/rangeobject.h" && cp "/opt/python3.6/include/python3.6m/pyfpe.h" "$(@D)/python_include/pyfpe.h" && cp "/opt/python3.6/include/python3.6m/pystrcmp.h" "$(@D)/python_include/pystrcmp.h" && cp "/opt/python3.6/include/python3.6m/fileutils.h" "$(@D)/python_include/fileutils.h" && cp "/opt/python3.6/include/python3.6m/dictobject.h" "$(@D)/python_include/dictobject.h" && cp "/opt/python3.6/include/python3.6m/pyarena.h" "$(@D)/python_include/pyarena.h" && cp "/opt/python3.6/include/python3.6m/osmodule.h" "$(@D)/python_include/osmodule.h" && cp "/opt/python3.6/include/python3.6m/objimpl.h" "$(@D)/python_include/objimpl.h" && cp "/opt/python3.6/include/python3.6m/bitset.h" "$(@D)/python_include/bitset.h" && cp "/opt/python3.6/include/python3.6m/memoryobject.h" "$(@D)/python_include/memoryobject.h" && cp "/opt/python3.6/include/python3.6m/bytearrayobject.h" "$(@D)/python_include/bytearrayobject.h" && cp "/opt/python3.6/include/python3.6m/pydebug.h" "$(@D)/python_include/pydebug.h" && cp "/opt/python3.6/include/python3.6m/pyerrors.h" "$(@D)/python_include/pyerrors.h" && cp "/opt/python3.6/include/python3.6m/weakrefobject.h" "$(@D)/python_include/weakrefobject.h" && cp "/opt/python3.6/include/python3.6m/grammar.h" "$(@D)/python_include/grammar.h" && cp "/opt/python3.6/include/python3.6m/symtable.h" "$(@D)/python_include/symtable.h" && cp "/opt/python3.6/include/python3.6m/longobject.h" "$(@D)/python_include/longobject.h" && cp "/opt/python3.6/include/python3.6m/structmember.h" "$(@D)/python_include/structmember.h" && cp "/opt/python3.6/include/python3.6m/enumobject.h" "$(@D)/python_include/enumobject.h" && cp "/opt/python3.6/include/python3.6m/pymacro.h" "$(@D)/python_include/pymacro.h" && cp "/opt/python3.6/include/python3.6m/classobject.h" "$(@D)/python_include/classobject.h" && cp "/opt/python3.6/include/python3.6m/unicodeobject.h" "$(@D)/python_include/unicodeobject.h" && cp "/opt/python3.6/include/python3.6m/sliceobject.h" "$(@D)/python_include/sliceobject.h" && cp "/opt/python3.6/include/python3.6m/pystrtod.h" "$(@D)/python_include/pystrtod.h" && cp "/opt/python3.6/include/python3.6m/genobject.h" "$(@D)/python_include/genobject.h" && cp "/opt/python3.6/include/python3.6m/compile.h" "$(@D)/python_include/compile.h" && cp "/opt/python3.6/include/python3.6m/pyexpat.h" "$(@D)/python_include/pyexpat.h" && cp "/opt/python3.6/include/python3.6m/asdl.h" "$(@D)/python_include/asdl.h" && cp "/opt/python3.6/include/python3.6m/codecs.h" "$(@D)/python_include/codecs.h" && cp "/opt/python3.6/include/python3.6m/dynamic_annotations.h" "$(@D)/python_include/dynamic_annotations.h" && cp "/opt/python3.6/include/python3.6m/pyctype.h" "$(@D)/python_include/pyctype.h" && cp "/opt/python3.6/include/python3.6m/sysmodule.h" "$(@D)/python_include/sysmodule.h" && cp "/opt/python3.6/include/python3.6m/methodobject.h" "$(@D)/python_include/methodobject.h" && cp "/opt/python3.6/include/python3.6m/graminit.h" "$(@D)/python_include/graminit.h" && cp "/opt/python3.6/include/python3.6m/bltinmodule.h" "$(@D)/python_include/bltinmodule.h" && cp "/opt/python3.6/include/python3.6m/intrcheck.h" "$(@D)/python_include/intrcheck.h" && cp "/opt/python3.6/include/python3.6m/pyport.h" "$(@D)/python_include/pyport.h" && cp "/opt/python3.6/include/python3.6m/warnings.h" "$(@D)/python_include/warnings.h" && cp "/opt/python3.6/include/python3.6m/osdefs.h" "$(@D)/python_include/osdefs.h" && cp "/opt/python3.6/include/python3.6m/pydtrace.h" "$(@D)/python_include/pydtrace.h" && cp "/opt/python3.6/include/python3.6m/pylifecycle.h" "$(@D)/python_include/pylifecycle.h" && cp "/opt/python3.6/include/python3.6m/fileobject.h" "$(@D)/python_include/fileobject.h" && cp "/opt/python3.6/include/python3.6m/pytime.h" "$(@D)/python_include/pytime.h" && cp "/opt/python3.6/include/python3.6m/traceback.h" "$(@D)/python_include/traceback.h" && cp "/opt/python3.6/include/python3.6m/ceval.h" "$(@D)/python_include/ceval.h" && cp "/opt/python3.6/include/python3.6m/bytes_methods.h" "$(@D)/python_include/bytes_methods.h" && cp "/opt/python3.6/include/python3.6m/namespaceobject.h" "$(@D)/python_include/namespaceobject.h" && cp "/opt/python3.6/include/python3.6m/pyconfig.h" "$(@D)/python_include/pyconfig.h" && cp "/opt/python3.6/include/python3.6m/Python.h" "$(@D)/python_include/Python.h" && cp "/opt/python3.6/include/python3.6m/moduleobject.h" "$(@D)/python_include/moduleobject.h" && cp "/opt/python3.6/include/python3.6m/pystate.h" "$(@D)/python_include/pystate.h" && cp "/opt/python3.6/include/python3.6m/descrobject.h" "$(@D)/python_include/descrobject.h" && cp "/opt/python3.6/include/python3.6m/odictobject.h" "$(@D)/python_include/odictobject.h" && cp "/opt/python3.6/include/python3.6m/ucnhash.h" "$(@D)/python_include/ucnhash.h" && cp "/opt/python3.6/include/python3.6m/pygetopt.h" "$(@D)/python_include/pygetopt.h" && cp "/opt/python3.6/include/python3.6m/pymem.h" "$(@D)/python_include/pymem.h" && cp "/opt/python3.6/include/python3.6m/complexobject.h" "$(@D)/python_include/complexobject.h" && cp "/opt/python3.6/include/python3.6m/structseq.h" "$(@D)/python_include/structseq.h" && cp "/opt/python3.6/include/python3.6m/datetime.h" "$(@D)/python_include/datetime.h" && cp "/opt/python3.6/include/python3.6m/pythonrun.h" "$(@D)/python_include/pythonrun.h" && cp "/opt/python3.6/include/python3.6m/pyhash.h" "$(@D)/python_include/pyhash.h" && cp "/opt/python3.6/include/python3.6m/pycapsule.h" "$(@D)/python_include/pycapsule.h" && cp "/opt/python3.6/include/python3.6m/setobject.h" "$(@D)/python_include/setobject.h" && cp "/opt/python3.6/include/python3.6m/listobject.h" "$(@D)/python_include/listobject.h" && cp "/opt/python3.6/include/python3.6m/bytesobject.h" "$(@D)/python_include/bytesobject.h" && cp "/opt/python3.6/include/python3.6m/pgen.h" "$(@D)/python_include/pgen.h" && cp "/opt/python3.6/include/python3.6m/patchlevel.h" "$(@D)/python_include/patchlevel.h" && cp "/opt/python3.6/include/python3.6m/opcode.h" "$(@D)/python_include/opcode.h" && cp "/opt/python3.6/include/python3.6m/parsetok.h" "$(@D)/python_include/parsetok.h" && cp "/opt/python3.6/include/python3.6m/pystrhex.h" "$(@D)/python_include/pystrhex.h" && cp "/opt/python3.6/include/python3.6m/marshal.h" "$(@D)/python_include/marshal.h" && cp "/opt/python3.6/include/python3.6m/token.h" "$(@D)/python_include/token.h" && cp "/opt/python3.6/include/python3.6m/iterobject.h" "$(@D)/python_include/iterobject.h" && cp "/opt/python3.6/include/python3.6m/abstract.h" "$(@D)/python_include/abstract.h" && cp "/opt/python3.6/include/python3.6m/py_curses.h" "$(@D)/python_include/py_curses.h" && cp "/opt/python3.6/include/python3.6m/metagrammar.h" "$(@D)/python_include/metagrammar.h" && cp "/opt/python3.6/include/python3.6m/Python-ast.h" "$(@D)/python_include/Python-ast.h"
+cp "/opt/python3.6/include/python3.6m/Python-ast.h" "$(@D)/python_include/Python-ast.h" && cp "/opt/python3.6/include/python3.6m/Python.h" "$(@D)/python_include/Python.h" && cp "/opt/python3.6/include/python3.6m/abstract.h" "$(@D)/python_include/abstract.h" && cp "/opt/python3.6/include/python3.6m/accu.h" "$(@D)/python_include/accu.h" && cp "/opt/python3.6/include/python3.6m/asdl.h" "$(@D)/python_include/asdl.h" && cp "/opt/python3.6/include/python3.6m/ast.h" "$(@D)/python_include/ast.h" && cp "/opt/python3.6/include/python3.6m/bitset.h" "$(@D)/python_include/bitset.h" && cp "/opt/python3.6/include/python3.6m/bltinmodule.h" "$(@D)/python_include/bltinmodule.h" && cp "/opt/python3.6/include/python3.6m/boolobject.h" "$(@D)/python_include/boolobject.h" && cp "/opt/python3.6/include/python3.6m/bytearrayobject.h" "$(@D)/python_include/bytearrayobject.h" && cp "/opt/python3.6/include/python3.6m/bytes_methods.h" "$(@D)/python_include/bytes_methods.h" && cp "/opt/python3.6/include/python3.6m/bytesobject.h" "$(@D)/python_include/bytesobject.h" && cp "/opt/python3.6/include/python3.6m/cellobject.h" "$(@D)/python_include/cellobject.h" && cp "/opt/python3.6/include/python3.6m/ceval.h" "$(@D)/python_include/ceval.h" && cp "/opt/python3.6/include/python3.6m/classobject.h" "$(@D)/python_include/classobject.h" && cp "/opt/python3.6/include/python3.6m/code.h" "$(@D)/python_include/code.h" && cp "/opt/python3.6/include/python3.6m/codecs.h" "$(@D)/python_include/codecs.h" && cp "/opt/python3.6/include/python3.6m/compile.h" "$(@D)/python_include/compile.h" && cp "/opt/python3.6/include/python3.6m/complexobject.h" "$(@D)/python_include/complexobject.h" && cp "/opt/python3.6/include/python3.6m/datetime.h" "$(@D)/python_include/datetime.h" && cp "/opt/python3.6/include/python3.6m/descrobject.h" "$(@D)/python_include/descrobject.h" && cp "/opt/python3.6/include/python3.6m/dictobject.h" "$(@D)/python_include/dictobject.h" && cp "/opt/python3.6/include/python3.6m/dtoa.h" "$(@D)/python_include/dtoa.h" && cp "/opt/python3.6/include/python3.6m/dynamic_annotations.h" "$(@D)/python_include/dynamic_annotations.h" && cp "/opt/python3.6/include/python3.6m/enumobject.h" "$(@D)/python_include/enumobject.h" && cp "/opt/python3.6/include/python3.6m/errcode.h" "$(@D)/python_include/errcode.h" && cp "/opt/python3.6/include/python3.6m/eval.h" "$(@D)/python_include/eval.h" && cp "/opt/python3.6/include/python3.6m/fileobject.h" "$(@D)/python_include/fileobject.h" && cp "/opt/python3.6/include/python3.6m/fileutils.h" "$(@D)/python_include/fileutils.h" && cp "/opt/python3.6/include/python3.6m/floatobject.h" "$(@D)/python_include/floatobject.h" && cp "/opt/python3.6/include/python3.6m/frameobject.h" "$(@D)/python_include/frameobject.h" && cp "/opt/python3.6/include/python3.6m/funcobject.h" "$(@D)/python_include/funcobject.h" && cp "/opt/python3.6/include/python3.6m/genobject.h" "$(@D)/python_include/genobject.h" && cp "/opt/python3.6/include/python3.6m/graminit.h" "$(@D)/python_include/graminit.h" && cp "/opt/python3.6/include/python3.6m/grammar.h" "$(@D)/python_include/grammar.h" && cp "/opt/python3.6/include/python3.6m/import.h" "$(@D)/python_include/import.h" && cp "/opt/python3.6/include/python3.6m/intrcheck.h" "$(@D)/python_include/intrcheck.h" && cp "/opt/python3.6/include/python3.6m/iterobject.h" "$(@D)/python_include/iterobject.h" && cp "/opt/python3.6/include/python3.6m/listobject.h" "$(@D)/python_include/listobject.h" && cp "/opt/python3.6/include/python3.6m/longintrepr.h" "$(@D)/python_include/longintrepr.h" && cp "/opt/python3.6/include/python3.6m/longobject.h" "$(@D)/python_include/longobject.h" && cp "/opt/python3.6/include/python3.6m/marshal.h" "$(@D)/python_include/marshal.h" && cp "/opt/python3.6/include/python3.6m/memoryobject.h" "$(@D)/python_include/memoryobject.h" && cp "/opt/python3.6/include/python3.6m/metagrammar.h" "$(@D)/python_include/metagrammar.h" && cp "/opt/python3.6/include/python3.6m/methodobject.h" "$(@D)/python_include/methodobject.h" && cp "/opt/python3.6/include/python3.6m/modsupport.h" "$(@D)/python_include/modsupport.h" && cp "/opt/python3.6/include/python3.6m/moduleobject.h" "$(@D)/python_include/moduleobject.h" && cp "/opt/python3.6/include/python3.6m/namespaceobject.h" "$(@D)/python_include/namespaceobject.h" && cp "/opt/python3.6/include/python3.6m/node.h" "$(@D)/python_include/node.h" && cp "/opt/python3.6/include/python3.6m/object.h" "$(@D)/python_include/object.h" && cp "/opt/python3.6/include/python3.6m/objimpl.h" "$(@D)/python_include/objimpl.h" && cp "/opt/python3.6/include/python3.6m/odictobject.h" "$(@D)/python_include/odictobject.h" && cp "/opt/python3.6/include/python3.6m/opcode.h" "$(@D)/python_include/opcode.h" && cp "/opt/python3.6/include/python3.6m/osdefs.h" "$(@D)/python_include/osdefs.h" && cp "/opt/python3.6/include/python3.6m/osmodule.h" "$(@D)/python_include/osmodule.h" && cp "/opt/python3.6/include/python3.6m/parsetok.h" "$(@D)/python_include/parsetok.h" && cp "/opt/python3.6/include/python3.6m/patchlevel.h" "$(@D)/python_include/patchlevel.h" && cp "/opt/python3.6/include/python3.6m/pgen.h" "$(@D)/python_include/pgen.h" && cp "/opt/python3.6/include/python3.6m/pgenheaders.h" "$(@D)/python_include/pgenheaders.h" && cp "/opt/python3.6/include/python3.6m/py_curses.h" "$(@D)/python_include/py_curses.h" && cp "/opt/python3.6/include/python3.6m/pyarena.h" "$(@D)/python_include/pyarena.h" && cp "/opt/python3.6/include/python3.6m/pyatomic.h" "$(@D)/python_include/pyatomic.h" && cp "/opt/python3.6/include/python3.6m/pycapsule.h" "$(@D)/python_include/pycapsule.h" && cp "/opt/python3.6/include/python3.6m/pyconfig.h" "$(@D)/python_include/pyconfig.h" && cp "/opt/python3.6/include/python3.6m/pyctype.h" "$(@D)/python_include/pyctype.h" && cp "/opt/python3.6/include/python3.6m/pydebug.h" "$(@D)/python_include/pydebug.h" && cp "/opt/python3.6/include/python3.6m/pydtrace.h" "$(@D)/python_include/pydtrace.h" && cp "/opt/python3.6/include/python3.6m/pyerrors.h" "$(@D)/python_include/pyerrors.h" && cp "/opt/python3.6/include/python3.6m/pyexpat.h" "$(@D)/python_include/pyexpat.h" && cp "/opt/python3.6/include/python3.6m/pyfpe.h" "$(@D)/python_include/pyfpe.h" && cp "/opt/python3.6/include/python3.6m/pygetopt.h" "$(@D)/python_include/pygetopt.h" && cp "/opt/python3.6/include/python3.6m/pyhash.h" "$(@D)/python_include/pyhash.h" && cp "/opt/python3.6/include/python3.6m/pylifecycle.h" "$(@D)/python_include/pylifecycle.h" && cp "/opt/python3.6/include/python3.6m/pymacconfig.h" "$(@D)/python_include/pymacconfig.h" && cp "/opt/python3.6/include/python3.6m/pymacro.h" "$(@D)/python_include/pymacro.h" && cp "/opt/python3.6/include/python3.6m/pymath.h" "$(@D)/python_include/pymath.h" && cp "/opt/python3.6/include/python3.6m/pymem.h" "$(@D)/python_include/pymem.h" && cp "/opt/python3.6/include/python3.6m/pyport.h" "$(@D)/python_include/pyport.h" && cp "/opt/python3.6/include/python3.6m/pystate.h" "$(@D)/python_include/pystate.h" && cp "/opt/python3.6/include/python3.6m/pystrcmp.h" "$(@D)/python_include/pystrcmp.h" && cp "/opt/python3.6/include/python3.6m/pystrhex.h" "$(@D)/python_include/pystrhex.h" && cp "/opt/python3.6/include/python3.6m/pystrtod.h" "$(@D)/python_include/pystrtod.h" && cp "/opt/python3.6/include/python3.6m/pythonrun.h" "$(@D)/python_include/pythonrun.h" && cp "/opt/python3.6/include/python3.6m/pythread.h" "$(@D)/python_include/pythread.h" && cp "/opt/python3.6/include/python3.6m/pytime.h" "$(@D)/python_include/pytime.h" && cp "/opt/python3.6/include/python3.6m/rangeobject.h" "$(@D)/python_include/rangeobject.h" && cp "/opt/python3.6/include/python3.6m/setobject.h" "$(@D)/python_include/setobject.h" && cp "/opt/python3.6/include/python3.6m/sliceobject.h" "$(@D)/python_include/sliceobject.h" && cp "/opt/python3.6/include/python3.6m/structmember.h" "$(@D)/python_include/structmember.h" && cp "/opt/python3.6/include/python3.6m/structseq.h" "$(@D)/python_include/structseq.h" && cp "/opt/python3.6/include/python3.6m/symtable.h" "$(@D)/python_include/symtable.h" && cp "/opt/python3.6/include/python3.6m/sysmodule.h" "$(@D)/python_include/sysmodule.h" && cp "/opt/python3.6/include/python3.6m/token.h" "$(@D)/python_include/token.h" && cp "/opt/python3.6/include/python3.6m/traceback.h" "$(@D)/python_include/traceback.h" && cp "/opt/python3.6/include/python3.6m/tupleobject.h" "$(@D)/python_include/tupleobject.h" && cp "/opt/python3.6/include/python3.6m/typeslots.h" "$(@D)/python_include/typeslots.h" && cp "/opt/python3.6/include/python3.6m/ucnhash.h" "$(@D)/python_include/ucnhash.h" && cp "/opt/python3.6/include/python3.6m/unicodeobject.h" "$(@D)/python_include/unicodeobject.h" && cp "/opt/python3.6/include/python3.6m/warnings.h" "$(@D)/python_include/warnings.h" && cp "/opt/python3.6/include/python3.6m/weakrefobject.h" "$(@D)/python_include/weakrefobject.h"
""",
)
genrule(
name = "numpy_include",
outs = [
- "numpy_include/numpy/oldnumeric.h",
- "numpy_include/numpy/npy_1_7_deprecated_api.h",
- "numpy_include/numpy/ufunc_api.txt",
- "numpy_include/numpy/multiarray_api.txt",
- "numpy_include/numpy/halffloat.h",
- "numpy_include/numpy/npy_common.h",
- "numpy_include/numpy/utils.h",
- "numpy_include/numpy/npy_interrupt.h",
- "numpy_include/numpy/npy_endian.h",
+ "numpy_include/numpy/__multiarray_api.h",
"numpy_include/numpy/__ufunc_api.h",
"numpy_include/numpy/_neighborhood_iterator_imp.h",
- "numpy_include/numpy/ufuncobject.h",
+ "numpy_include/numpy/_numpyconfig.h",
+ "numpy_include/numpy/arrayobject.h",
+ "numpy_include/numpy/arrayscalars.h",
+ "numpy_include/numpy/halffloat.h",
+ "numpy_include/numpy/multiarray_api.txt",
+ "numpy_include/numpy/ndarrayobject.h",
"numpy_include/numpy/ndarraytypes.h",
- "numpy_include/numpy/npy_math.h",
"numpy_include/numpy/noprefix.h",
+ "numpy_include/numpy/npy_1_7_deprecated_api.h",
"numpy_include/numpy/npy_3kcompat.h",
- "numpy_include/numpy/arrayscalars.h",
- "numpy_include/numpy/npy_os.h",
- "numpy_include/numpy/ndarrayobject.h",
- "numpy_include/numpy/npy_no_deprecated_api.h",
- "numpy_include/numpy/arrayobject.h",
- "numpy_include/numpy/_numpyconfig.h",
- "numpy_include/numpy/__multiarray_api.h",
+ "numpy_include/numpy/npy_common.h",
"numpy_include/numpy/npy_cpu.h",
- "numpy_include/numpy/old_defines.h",
+ "numpy_include/numpy/npy_endian.h",
+ "numpy_include/numpy/npy_interrupt.h",
+ "numpy_include/numpy/npy_math.h",
+ "numpy_include/numpy/npy_no_deprecated_api.h",
+ "numpy_include/numpy/npy_os.h",
"numpy_include/numpy/numpyconfig.h",
+ "numpy_include/numpy/old_defines.h",
+ "numpy_include/numpy/oldnumeric.h",
+ "numpy_include/numpy/ufunc_api.txt",
+ "numpy_include/numpy/ufuncobject.h",
+ "numpy_include/numpy/utils.h",
],
cmd = """
-cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/oldnumeric.h" "$(@D)/numpy_include/numpy/oldnumeric.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_1_7_deprecated_api.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/ufunc_api.txt" "$(@D)/numpy_include/numpy/ufunc_api.txt" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/multiarray_api.txt" "$(@D)/numpy_include/numpy/multiarray_api.txt" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/halffloat.h" "$(@D)/numpy_include/numpy/halffloat.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_common.h" "$(@D)/numpy_include/numpy/npy_common.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/utils.h" "$(@D)/numpy_include/numpy/utils.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_interrupt.h" "$(@D)/numpy_include/numpy/npy_interrupt.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_endian.h" "$(@D)/numpy_include/numpy/npy_endian.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/__ufunc_api.h" "$(@D)/numpy_include/numpy/__ufunc_api.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h" "$(@D)/numpy_include/numpy/_neighborhood_iterator_imp.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/ufuncobject.h" "$(@D)/numpy_include/numpy/ufuncobject.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/ndarraytypes.h" "$(@D)/numpy_include/numpy/ndarraytypes.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_math.h" "$(@D)/numpy_include/numpy/npy_math.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/noprefix.h" "$(@D)/numpy_include/numpy/noprefix.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_3kcompat.h" "$(@D)/numpy_include/numpy/npy_3kcompat.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/arrayscalars.h" "$(@D)/numpy_include/numpy/arrayscalars.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_os.h" "$(@D)/numpy_include/numpy/npy_os.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/ndarrayobject.h" "$(@D)/numpy_include/numpy/ndarrayobject.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_no_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_no_deprecated_api.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/arrayobject.h" "$(@D)/numpy_include/numpy/arrayobject.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/_numpyconfig.h" "$(@D)/numpy_include/numpy/_numpyconfig.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/__multiarray_api.h" "$(@D)/numpy_include/numpy/__multiarray_api.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_cpu.h" "$(@D)/numpy_include/numpy/npy_cpu.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/old_defines.h" "$(@D)/numpy_include/numpy/old_defines.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/numpyconfig.h" "$(@D)/numpy_include/numpy/numpyconfig.h"
+cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/__multiarray_api.h" "$(@D)/numpy_include/numpy/__multiarray_api.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/__ufunc_api.h" "$(@D)/numpy_include/numpy/__ufunc_api.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h" "$(@D)/numpy_include/numpy/_neighborhood_iterator_imp.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/_numpyconfig.h" "$(@D)/numpy_include/numpy/_numpyconfig.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/arrayobject.h" "$(@D)/numpy_include/numpy/arrayobject.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/arrayscalars.h" "$(@D)/numpy_include/numpy/arrayscalars.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/halffloat.h" "$(@D)/numpy_include/numpy/halffloat.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/multiarray_api.txt" "$(@D)/numpy_include/numpy/multiarray_api.txt" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/ndarrayobject.h" "$(@D)/numpy_include/numpy/ndarrayobject.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/ndarraytypes.h" "$(@D)/numpy_include/numpy/ndarraytypes.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/noprefix.h" "$(@D)/numpy_include/numpy/noprefix.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_1_7_deprecated_api.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_3kcompat.h" "$(@D)/numpy_include/numpy/npy_3kcompat.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_common.h" "$(@D)/numpy_include/numpy/npy_common.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_cpu.h" "$(@D)/numpy_include/numpy/npy_cpu.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_endian.h" "$(@D)/numpy_include/numpy/npy_endian.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_interrupt.h" "$(@D)/numpy_include/numpy/npy_interrupt.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_math.h" "$(@D)/numpy_include/numpy/npy_math.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_no_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_no_deprecated_api.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/npy_os.h" "$(@D)/numpy_include/numpy/npy_os.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/numpyconfig.h" "$(@D)/numpy_include/numpy/numpyconfig.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/old_defines.h" "$(@D)/numpy_include/numpy/old_defines.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/oldnumeric.h" "$(@D)/numpy_include/numpy/oldnumeric.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/ufunc_api.txt" "$(@D)/numpy_include/numpy/ufunc_api.txt" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/ufuncobject.h" "$(@D)/numpy_include/numpy/ufuncobject.h" && cp "/opt/python3.6/lib/python3.6/site-packages/numpy/core/include/numpy/utils.h" "$(@D)/numpy_include/numpy/utils.h"
""",
)