aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Gunhan Gulsoy <gunan@google.com>2018-07-02 10:09:16 -0700
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-07-02 10:09:16 -0700
commite35d9ae50c5bb9ebc6e8e52ab937410fba2030fd (patch)
tree8fbadbee72dba3cdf4b94fa16d6878fb4e5b3fde /tensorflow
parenta7b7aa856f34bf2e44fbeb91d817742c61483618 (diff)
parent28b8525b417d5b0a1d0a4905e5e3237ef5b502ef (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/BUILD16
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc20
-rw-r--r--tensorflow/compiler/tests/BUILD2
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py25
-rw-r--r--tensorflow/compiler/tf2xla/BUILD3
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/categorical_op.cc18
-rw-r--r--tensorflow/compiler/tf2xla/kernels/diag_op.cc32
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pooling_ops.cc11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc9
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.cc13
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.h1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/retval_op.cc19
-rw-r--r--tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc39
-rw-r--r--tensorflow/compiler/tf2xla/kernels/softmax_op.cc32
-rw-r--r--tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/topk_op.cc108
-rw-r--r--tensorflow/compiler/tf2xla/kernels/training_ops.cc72
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc146
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD5
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc9
-rw-r--r--tensorflow/compiler/tf2xla/lib/random.cc23
-rw-r--r--tensorflow/compiler/tf2xla/lib/random.h2
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc7
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc6
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.h3
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc61
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h4
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc3
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc8
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h11
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc190
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.h44
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc10
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD57
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc131
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.h15
-rw-r--r--tensorflow/compiler/xla/client/lib/constants.cc103
-rw-r--r--tensorflow/compiler/xla/client/lib/constants.h124
-rw-r--r--tensorflow/compiler/xla/client/lib/constants_test.cc159
-rw-r--r--tensorflow/compiler/xla/client/lib/math.cc152
-rw-r--r--tensorflow/compiler/xla/client/lib/math.h51
-rw-r--r--tensorflow/compiler/xla/client/lib/math_test.cc85
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc42
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h85
-rw-r--r--tensorflow/compiler/xla/python/BUILD2
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc7
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h6
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i6
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py6
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.cc14
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_thunk.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_thunk.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc14
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h10
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.cc11
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc65
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h48
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_thunk.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/memset_thunk.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/memset_thunk.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/sequential_thunk.cc13
-rw-r--r--tensorflow/compiler/xla/service/gpu/sequential_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/thunk.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.cc22
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.h4
-rw-r--r--tensorflow/compiler/xla/shape_util.cc67
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc13
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc80
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_simple_test.cc40
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc11
-rw-r--r--tensorflow/contrib/BUILD2
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py2
-rw-r--r--tensorflow/contrib/checkpoint/python/containers_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py39
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py6
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py10
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py71
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py82
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py47
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py180
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py12
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py9
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py5
-rw-r--r--tensorflow/contrib/distribute/python/values.py35
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py39
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/BUILD2
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/blocks.py147
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/blocks_test.py264
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/cifar_input.py2
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py89
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/config.py29
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/main.py204
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet.py65
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py67
-rw-r--r--tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb282
-rw-r--r--tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb1018
-rw-r--r--tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb443
-rw-r--r--tensorflow/contrib/lite/Makefile83
-rw-r--r--tensorflow/contrib/lite/allocation.cc6
-rw-r--r--tensorflow/contrib/lite/arena_planner.cc13
-rw-r--r--tensorflow/contrib/lite/arena_planner.h9
-rw-r--r--tensorflow/contrib/lite/arena_planner_test.cc29
-rw-r--r--tensorflow/contrib/lite/examples/android/BUILD1
-rw-r--r--tensorflow/contrib/lite/examples/android/app/build.gradle4
-rw-r--r--tensorflow/contrib/lite/examples/android/app/download-models.gradle5
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java13
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java210
-rw-r--r--tensorflow/contrib/lite/interpreter.cc16
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc30
-rw-r--r--tensorflow/contrib/lite/java/demo/app/build.gradle4
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/build.gradle4
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD1
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h1
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc11
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/svdf.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/svdf_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc111
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv_test.cc121
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc468
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc1767
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc244
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h6
-rw-r--r--tensorflow/contrib/lite/schema/BUILD1
-rw-r--r--tensorflow/contrib/lite/testing/BUILD1
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py10
-rw-r--r--tensorflow/contrib/lite/testing/generate_testspec.cc85
-rw-r--r--tensorflow/contrib/lite/testing/generate_testspec.h4
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc19
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_example_test.cc23
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_flags.h6
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.cc7
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.h6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc108
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc38
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc13
-rw-r--r--tensorflow/contrib/lite/toco/model.h2
-rw-r--r--tensorflow/contrib/lite/tools/BUILD1
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py9
-rw-r--r--tensorflow/core/BUILD4
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc2
-rw-r--r--tensorflow/core/framework/graph_to_functiondef.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc203
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h1
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc66
-rw-r--r--tensorflow/core/kernels/BUILD19
-rw-r--r--tensorflow/core/kernels/deserialize_sparse_string_op.cc293
-rw-r--r--tensorflow/core/kernels/quantize_and_dequantize_op.h8
-rw-r--r--tensorflow/core/kernels/serialize_sparse_op.cc255
-rw-r--r--tensorflow/python/client/session.py2
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py4
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py78
-rw-r--r--tensorflow/python/eager/function.py4
-rw-r--r--tensorflow/python/eager/function_test.py19
-rw-r--r--tensorflow/python/eager/graph_callable.py47
-rw-r--r--tensorflow/python/estimator/keras.py13
-rw-r--r--tensorflow/python/keras/backend.py20
-rw-r--r--tensorflow/python/keras/backend_test.py30
-rw-r--r--tensorflow/python/keras/callbacks.py77
-rw-r--r--tensorflow/python/keras/callbacks_test.py75
-rw-r--r--tensorflow/python/keras/engine/base_layer.py6
-rw-r--r--tensorflow/python/keras/engine/network.py56
-rw-r--r--tensorflow/python/keras/engine/sequential.py4
-rw-r--r--tensorflow/python/keras/engine/training.py10
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py6
-rw-r--r--tensorflow/python/keras/layers/normalization.py4
-rw-r--r--tensorflow/python/keras/model_subclassing_test.py6
-rw-r--r--tensorflow/python/keras/optimizers.py5
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py42
-rw-r--r--tensorflow/python/ops/distributions/distribution.py2
-rw-r--r--tensorflow/python/ops/gradients_impl.py129
-rw-r--r--tensorflow/python/ops/gradients_test.py90
-rw-r--r--tensorflow/python/ops/metrics_impl.py3
-rw-r--r--tensorflow/python/ops/variable_scope.py292
-rw-r--r--tensorflow/python/ops/variables.py6
-rw-r--r--tensorflow/python/training/checkpointable/BUILD1
-rw-r--r--tensorflow/python/training/checkpointable/base.py63
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py283
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py65
-rw-r--r--tensorflow/python/training/checkpointable/layer_utils.py8
-rw-r--r--tensorflow/python/training/checkpointable/tracking.py47
-rw-r--r--tensorflow/python/training/checkpointable/tracking_test.py123
-rw-r--r--tensorflow/python/training/checkpointable/util.py151
-rw-r--r--tensorflow/python/training/distribute.py73
-rw-r--r--tensorflow/python/training/distribute_test.py39
-rw-r--r--tensorflow/python/training/optimizer.py9
-rw-r--r--tensorflow/python/util/nest.py11
-rw-r--r--tensorflow/python/util/util.cc6
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt20
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt12
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh2
-rwxr-xr-xtensorflow/tools/ci_build/ci_sanity.sh2
-rw-r--r--tensorflow/workspace.bzl8
221 files changed, 8670 insertions, 3388 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 51eea94847..fb96738e33 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -603,3 +603,19 @@ py_library(
visibility = ["//visibility:public"],
deps = ["//tensorflow/python:no_contrib"],
)
+
+cc_library(
+ name = "grpc",
+ deps = select({
+ ":linux_s390x": ["@grpc//:grpc_unsecure"],
+ "//conditions:default": ["@grpc"],
+ }),
+)
+
+cc_library(
+ name = "grpc++",
+ deps = select({
+ ":linux_s390x": ["@grpc//:grpc++_unsecure"],
+ "//conditions:default": ["@grpc//:grpc++"],
+ }),
+)
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index e20f5aa837..3bbf97afad 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -56,9 +56,9 @@ XlaTransferManager::XlaTransferManager(
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {
if (!shape_representation_fn_) {
- shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) {
- return shape;
- };
+ shape_representation_fn_ =
+ [](const TensorShape& shape,
+ DataType dtype) -> xla::StatusOr<TensorShape> { return shape; };
}
}
@@ -136,9 +136,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
- TensorShape shape = shape_representation_fn_(device_tensor->shape(),
- device_tensor->dtype());
Status status;
+ xla::StatusOr<TensorShape> shape_or_status = shape_representation_fn_(
+ device_tensor->shape(), device_tensor->dtype());
+ if (!shape_or_status.ok()) {
+ done(shape_or_status.status());
+ return;
+ }
+ TensorShape shape = shape_or_status.ValueOrDie();
if (!xla_tensor->has_shaped_buffer()) {
status = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), shape, client_,
@@ -233,8 +238,9 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
CHECK(xla_src && xla_dst)
<< "Missing destination tensor for device-to-device copy";
if (!xla_dst->has_shaped_buffer()) {
- TensorShape shape =
- shape_representation_fn_(src_tensor.shape(), src_tensor.dtype());
+ TF_ASSIGN_OR_RETURN(
+ TensorShape shape,
+ shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()));
TF_RETURN_IF_ERROR(
xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
stream_->parent()->device_ordinal()));
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 366822f0b7..95fda489a1 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -887,6 +887,8 @@ tf_xla_py_test(
name = "sort_ops_test",
size = "small",
srcs = ["sort_ops_test.py"],
+ # Times out in fastbuild mode.
+ tags = ["optonly"],
deps = [
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 8ae579abda..9e2ef964a1 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -64,20 +64,29 @@ class XlaSortOpTest(xla_test.XLATestCase):
if self.device in ["XLA_CPU", "XLA_GPU"]:
return
- # Only bfloat16 is implemented.
- bfloat16 = dtypes.bfloat16.as_numpy_dtype
- if bfloat16 in self.numeric_types:
- for x in [np.arange(20)]:
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for dtype in supported_types.intersection(self.numeric_types):
+ # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+ # after conversion to bfloat16, so the possible resulting index array is
+ # no longer unique.
+ if dtype == dtypes.bfloat16.as_numpy_dtype:
+ array_size = 20
+ k_options = [0, 1, 2, 10, 20]
+ else:
+ array_size = 200 * 1000
+ k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+ for x in [np.arange(array_size)]:
np.random.shuffle(x)
- for k in [0, 1, 2, 10, 20]:
+ for k in k_options:
indices = x.argsort()[::-1][:k]
def topk(v, k=k):
return nn_ops.top_k(v, k=k, sorted=True)
self._assertOpOutputMatchesExpected(
- topk, [x.astype(bfloat16)],
- expected=[x[indices].astype(bfloat16), indices])
+ topk, [x.astype(dtype)],
+ expected=[x[indices].astype(dtype), indices])
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
@@ -99,7 +108,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
self.assertAllEqual(
np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
- self.assertEqual(list([3, 0, 1, 2]), list(results[1]))
+ self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
def testTopKInfinities(self):
"""Tests that positive and negative infinity sort correctly."""
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index aa9c0596d1..40e32f2e75 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -164,11 +164,14 @@ cc_library(
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index e6cbf2349d..a8eb7d942d 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -121,6 +121,8 @@ tf_kernel_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index c137d026bd..1784e712b5 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -74,16 +74,14 @@ class CategoricalOp : public XlaOpKernel {
// See:
// https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
// TODO(b/68769470): Switch to using a cumulative sum approach.
- auto softmax_entries =
- xla::Sub(logits, xla::Log(xla::Neg(xla::Log(uniforms))),
- /*broadcast_dimensions=*/{0, 2});
-
- TensorShape softmax_shape(uniform_shape_array);
- xla::XlaOp argmax;
- OP_REQUIRES_OK(
- ctx,
- XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape,
- input_type(0), output_type(0), /*axis=*/2, &argmax));
+ auto softmax_entries = xla::Sub(logits, xla::Log(-xla::Log(uniforms)),
+ /*broadcast_dimensions=*/{0, 2});
+
+ xla::PrimitiveType xla_output_type;
+ OP_REQUIRES_OK(ctx,
+ DataTypeToPrimitiveType(output_type(0), &xla_output_type));
+ xla::XlaOp argmax =
+ XlaHelpers::ArgMax(softmax_entries, xla_output_type, /*axis=*/2);
ctx->SetOutput(0, argmax);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 378b62c0d6..6dec414c53 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
@@ -27,10 +28,10 @@ namespace tensorflow {
namespace {
// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
-xla::StatusOr<xla::XlaOp> CreateDiagonal(
- const xla::XlaOp& input, int64 last_dim_size,
- tensorflow::gtl::ArraySlice<int64> other_dims, XlaOpKernelContext* ctx,
- xla::XlaBuilder* builder) {
+xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size,
+ gtl::ArraySlice<int64> other_dims,
+ xla::PrimitiveType element_type) {
+ xla::XlaBuilder* builder = input.builder();
// Create two matrices that have the following forms, and compare them:
//
// [[0, 0, 0, 0] [[0, 1, 2, 3]
@@ -67,12 +68,9 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal(
xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims);
broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
- xla::PrimitiveType element_type;
- TF_RETURN_IF_ERROR(
- DataTypeToPrimitiveType(ctx->input_type(0), &element_type));
auto broadcast_shape =
xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
- xla::XlaOp zeros = Zeros(builder, broadcast_shape);
+ xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape);
input_broadcast = xla::Add(input_broadcast, zeros);
return xla::Select(mask, input_broadcast, zeros);
@@ -83,8 +81,6 @@ class DiagOp : public XlaOpKernel {
explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("Diag op must have at an input"));
const TensorShape input_shape = ctx->InputShape(0);
@@ -107,10 +103,8 @@ class DiagOp : public XlaOpKernel {
input = xla::Reshape(input, {size});
// Create an R2 with the R1 diagonal.
- auto diag_or_status =
- CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder);
- OP_REQUIRES_OK(ctx, diag_or_status.status());
- xla::XlaOp diag = diag_or_status.ValueOrDie();
+ xla::XlaOp diag =
+ CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0));
// Reshapes to the final shape.
std::vector<int64> new_dims(dims.size() * 2);
@@ -197,8 +191,6 @@ class MatrixDiagOp : public XlaOpKernel {
explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("MatrixDiag op must have at an input"));
const TensorShape input_shape = ctx->InputShape(0);
@@ -208,17 +200,15 @@ class MatrixDiagOp : public XlaOpKernel {
errors::InvalidArgument("Expected 1 <= dims, got shape ",
input_shape.DebugString()));
- xla::XlaOp diag = ctx->Input(0);
int last_dim = dims.size() - 1;
int64 last_dim_size = input_shape.dim_size(last_dim);
tensorflow::gtl::ArraySlice<int64> other_dims(dims);
other_dims.pop_back();
- auto diag_or_status =
- CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder);
- OP_REQUIRES_OK(ctx, diag_or_status.status());
- diag = diag_or_status.ValueOrDie();
+ xla::XlaOp input = ctx->Input(0);
+ xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims,
+ ctx->input_xla_type(0));
ctx->SetOutput(0, diag);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
index 36eb4c7545..f396474858 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -60,19 +60,15 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
input_shape.DebugString()));
DataType index_type = output_type(0);
+ xla::PrimitiveType index_xla_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(index_type, &index_xla_type));
- xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
-
xla::XlaOp output;
if (is_min_) {
- OP_REQUIRES_OK(ctx,
- XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0),
- index_type, axis, &output));
+ output = XlaHelpers::ArgMin(input, index_xla_type, axis);
} else {
- OP_REQUIRES_OK(ctx,
- XlaHelpers::ArgMax(b, ctx, input, input_shape, input_type(0),
- index_type, axis, &output));
+ output = XlaHelpers::ArgMax(input, index_xla_type, axis);
}
ctx->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index 771dcbab21..a81f5fddf6 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -62,6 +63,9 @@ class PoolingOp : public XlaOpKernel {
Padding padding;
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
+
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
}
int num_dims() const { return num_spatial_dims_ + 2; }
@@ -128,6 +132,7 @@ class PoolingOp : public XlaOpKernel {
xla::Padding padding_;
TensorFormat data_format_ = FORMAT_NHWC;
DataType reduction_type_;
+ xla::PrimitiveType xla_reduction_type_;
};
class MaxPoolOp : public PoolingOp {
@@ -137,7 +142,7 @@ class MaxPoolOp : public PoolingOp {
/*reduction_type=*/ctx->input_type(0)) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
- return XlaHelpers::MinValue(b, reduction_type_);
+ return xla::MinValue(b, xla_reduction_type_);
}
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
@@ -236,7 +241,7 @@ class AvgPoolOp : public PoolingOp {
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
- return XlaHelpers::Zero(b, reduction_type_);
+ return xla::Zero(b, xla_reduction_type_);
}
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
@@ -628,7 +633,7 @@ class MaxPoolGradGradOp : public XlaOpKernel {
auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add.
auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add.
- auto init_value = XlaHelpers::MinValue(b, DT_FLOAT);
+ auto init_value = xla::MinValue(b, xla::F32);
// We will reduce by taking the maximal value up to 16 bits (ignoring the lo
// 16 bits of packed-in hi/lo backprop value).
auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
index 02293796e4..e88221e4f4 100644
--- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"
@@ -50,8 +51,8 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
} else {
const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type);
const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type);
- min_range = ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
- max_range = ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
+ min_range = ReduceAll(input, xla::MaxValue(b, xla_type), *fmin);
+ max_range = ReduceAll(input, xla::MinValue(b, xla_type), *fmax);
}
xla::XlaOp num_bits;
@@ -93,10 +94,10 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
// while keeping 0 unchanged.
xla::XlaOp scale_from_min_side =
Select(Gt(min_quantized * min_range, zero), min_quantized / min_range,
- XlaHelpers::MaxFiniteValue(b, data_type));
+ xla::MaxFiniteValue(b, xla_type));
xla::XlaOp scale_from_max_side =
Select(Gt(max_quantized * max_range, zero), max_quantized / max_range,
- XlaHelpers::MaxFiniteValue(b, data_type));
+ xla::MaxFiniteValue(b, xla_type));
// Note: Avoids changing the side of the range that determines scale.
xla::XlaOp cond = Lt(scale_from_min_side, scale_from_max_side);
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index d5b645d70a..9a0a7f9b90 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -211,7 +211,7 @@ class TruncatedNormalOp : public XlaOpKernel {
xla::XlaOp min_positive =
XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min());
auto uniform = xla::RngUniform(min_positive, one, xla_shape);
- ctx->SetOutput(0, TruncatedNormal(dtype, uniform));
+ ctx->SetOutput(0, TruncatedNormal(uniform));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index d3573bac3d..46fae59ad4 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -32,7 +33,7 @@ class SumOp : public XlaReductionOp {
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::Zero(builder, reduction_type_);
+ return xla::Zero(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
@@ -49,7 +50,7 @@ class ProdOp : public XlaReductionOp {
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::One(builder, reduction_type_);
+ return xla::One(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
@@ -67,7 +68,7 @@ class MinOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MaxValue(builder, reduction_type_);
+ return xla::MaxValue(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
@@ -84,7 +85,7 @@ class MaxOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MinValue(builder, reduction_type_);
+ return xla::MinValue(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
@@ -102,7 +103,7 @@ class MeanOp : public XlaReductionOp {
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::Zero(builder, reduction_type_);
+ return xla::Zero(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
@@ -114,7 +115,7 @@ class MeanOp : public XlaReductionOp {
int64 num_elements_reduced) override {
auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
num_elements_reduced);
- return xla::Div(reduce_output, divisor);
+ return reduce_output / divisor;
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
index 2ecfb854a1..8333f9b288 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
@@ -64,6 +64,7 @@ class XlaReductionOp : public XlaOpKernel {
protected:
DataType reduction_type_;
+ xla::PrimitiveType xla_reduction_type_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 14506d65c4..909783ecb3 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -32,6 +32,8 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
}
// Unless BuildFinalizer is overridden the reduction has no
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index db7ea775e2..5be70a4ded 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -62,10 +63,20 @@ class RetvalOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else {
TensorShape shape = ctx->InputShape(0);
- TensorShape representation_shape =
- tc.is_entry_computation()
- ? tc.RepresentationShape(shape, ctx->input_type(0))
- : shape;
+ ctx->SetStatus(is_constant.status());
+ TensorShape representation_shape;
+ if (tc.is_entry_computation()) {
+ xla::StatusOr<TensorShape> shape_or_status =
+ tc.RepresentationShape(shape, ctx->input_type(0));
+ if (!shape_or_status.ok()) {
+ ctx->SetStatus(shape_or_status.status());
+ return;
+ } else {
+ representation_shape = shape_or_status.ValueOrDie();
+ }
+ } else {
+ representation_shape = shape;
+ }
xla::XlaOp output = input;
if (tc.is_entry_computation()) {
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index db7e559420..e2ac7da2c2 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -14,9 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -25,15 +27,16 @@ namespace {
class UnsortedSegmentReduce : public XlaOpKernel {
public:
explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ DataType dtype;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype));
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_));
}
// The initial value to initialize elements of the output to.
virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
// A function to combine two scalars with the same index (e.g., sum).
- virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) = 0;
+ virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0;
void Compile(XlaOpKernelContext* ctx) override {
// output = unsorted_segment_sum(data, indices, num_segments)
@@ -78,9 +81,7 @@ class UnsortedSegmentReduce : public XlaOpKernel {
xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes());
auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) {
- return Combine(a, b, builder);
- };
+ xla::XlaBuilder* builder) { return Combine(a, b); };
auto result = XlaScatter(buffer, /*updates=*/data, indices,
/*indices_are_vectors=*/false, combiner, builder);
@@ -89,7 +90,7 @@ class UnsortedSegmentReduce : public XlaOpKernel {
}
protected:
- DataType dtype_;
+ xla::PrimitiveType type_;
};
class UnsortedSegmentSum : public UnsortedSegmentReduce {
@@ -98,12 +99,9 @@ class UnsortedSegmentSum : public UnsortedSegmentReduce {
: UnsortedSegmentReduce(ctx) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::Zero(builder, dtype_);
- };
- xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) override {
- return xla::Add(a, b);
+ return xla::Zero(builder, type_);
};
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; };
};
REGISTER_XLA_OP(
@@ -116,12 +114,9 @@ class UnsortedSegmentProd : public UnsortedSegmentReduce {
: UnsortedSegmentReduce(ctx) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::One(builder, dtype_);
- };
- xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) override {
- return xla::Mul(a, b);
+ return xla::One(builder, type_);
};
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; };
};
REGISTER_XLA_OP(
@@ -134,10 +129,9 @@ class UnsortedSegmentMin : public UnsortedSegmentReduce {
: UnsortedSegmentReduce(ctx) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MaxFiniteValue(builder, dtype_);
+ return xla::MaxFiniteValue(builder, type_);
};
- xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) override {
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override {
return xla::Min(a, b);
};
};
@@ -152,10 +146,9 @@ class UnsortedSegmentMax : public UnsortedSegmentReduce {
: UnsortedSegmentReduce(ctx) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MinFiniteValue(builder, dtype_);
+ return xla::MinFiniteValue(builder, type_);
};
- xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
- xla::XlaBuilder* builder) override {
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override {
return xla::Max(a, b);
};
};
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index d1c69f08b0..a71fbcd901 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -15,9 +15,11 @@ limitations under the License.
// XLA-specific Ops for softmax.
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -42,23 +44,27 @@ class SoftmaxOp : public XlaOpKernel {
const int kClassDim = 1;
const DataType type = input_type(0);
+ const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
auto logits = ctx->Input(0);
xla::XlaBuilder* const b = ctx->builder();
const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
// Find the max in each batch, resulting in a tensor of shape [batch]
- auto logits_max = xla::Reduce(logits, XlaHelpers::MinValue(b, type),
- max_func, {kClassDim});
+ auto logits_max =
+ xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b. Broadcasts
// along the batch dimension.
auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});
auto exp_shifted = xla::Exp(shifted_logits);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
+ xla::PrimitiveType xla_accumulation_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type,
+ &xla_accumulation_type));
auto converted =
- XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type);
+ xla::ConvertElementType(exp_shifted, xla_accumulation_type);
auto reduce =
- xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ xla::Reduce(converted, xla::Zero(b, xla_accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
auto sum = XlaHelpers::ConvertElementType(b, reduce, type);
auto softmax =
@@ -78,8 +84,8 @@ REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);
std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
- XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits,
- const xla::XlaOp& labels) {
+ XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type,
+ xla::XlaOp logits, xla::XlaOp labels) {
const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
const int kBatchDim = 0;
@@ -88,7 +94,7 @@ std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
xla::XlaBuilder* b = ctx->builder();
// Find the max in each batch, resulting in a tensor of shape [batch]
auto logits_max =
- xla::Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+ xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b.
// Broadcasts along the batch dimension.
@@ -148,12 +154,13 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
// check that "labels" is a matrix too.
const DataType type = input_type(0);
+ const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
auto logits = ctx->Input(0);
auto labels = ctx->Input(1);
xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
- CrossEntropyWithLogits(ctx, type, logits, labels);
+ CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);
ctx->SetOutput(0, loss);
ctx->SetOutput(1, backprop);
}
@@ -189,8 +196,9 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
int64 batch_size = logits_shape.dim_size(0);
int64 depth = logits_shape.dim_size(1);
- DataType logits_type = input_type(0);
- DataType indices_type = input_type(1);
+ const DataType logits_type = input_type(0);
+ const xla::PrimitiveType xla_logits_type = ctx->input_xla_type(0);
+ const DataType indices_type = input_type(1);
xla::XlaOp indices = ctx->Input(1);
@@ -218,8 +226,8 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
labels = xla::Add(labels, nan_or_zero, {0});
xla::XlaOp loss, backprop;
- std::tie(loss, backprop) =
- CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
+ std::tie(loss, backprop) = CrossEntropyWithLogits(
+ ctx, logits_type, xla_logits_type, ctx->Input(0), labels);
ctx->SetOutput(0, loss);
ctx->SetOutput(1, backprop);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index 50a455b520..a6f5769e7b 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -20,7 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -209,8 +210,8 @@ class StatelessRandomNormalOp : public XlaOpKernel {
RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0);
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
- auto normal = xla::Mul(xla::ConstantR0<float>(builder, std::sqrt(2.0)),
- ErfInv(uniform));
+ auto normal =
+ xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
ctx->SetOutput(0, normal);
}
@@ -231,8 +232,6 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- const DataType dtype = output_type(0);
-
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
@@ -245,7 +244,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
auto uniform =
RandomUniform(b, seed, shape, std::numeric_limits<float>::min(), 1.0);
- ctx->SetOutput(0, TruncatedNormal(dtype, uniform));
+ ctx->SetOutput(0, TruncatedNormal(uniform));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index 8a1377fc38..9962f1207d 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -52,107 +52,33 @@ class TopKOp : public XlaOpKernel {
errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ",
input_shape.DebugString()));
- const int64 n = input_shape.dim_size(0);
- OP_REQUIRES(context, n < (1 << 16),
- errors::Unimplemented(
- "TopK is implemented for sizes up to 2**16, got shape ",
- input_shape.DebugString()));
-
xla::XlaBuilder* const b = context->builder();
if (input_shape.dim_size(0) < k) {
k = input_shape.dim_size(0);
}
- const xla::XlaOp input_bf16 = context->Input(0);
- xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, n);
-
- // TODO(b/73891930): add a key-value sort to HLO, rather than using
- // bit-packing tricks here.
-
- xla::XlaOp zero = xla::ConstantR0<int32>(b, 0);
-
- // max can either be 0x7FFFFFFF or 0x8000000. Neither choice is totally
- // ideal. The implications of the choice are:
- //
- // 0x7FFFFFFF
- // 1. +0.0 > -0.0
- // 2. The elements of the inputs and outputs are bitwise identical.
- // 3. The sort is unstable since a later +0.0 will appear before an earlier
- // -0.0.
- //
- // 0x8000000
- // 1. +0.0 == -0.0
- // 2. All -0.0 in the input are replaced with +0.0 in the output.
- // 3. The sort is stable.
- xla::XlaOp max = xla::ConstantR0<int32>(b, 0x80000000);
- xla::XlaOp index_mask = xla::ConstantR0<int32>(b, 0x0000FFFF);
- xla::XlaOp value_mask = xla::ConstantR0<int32>(b, 0xFFFF0000);
-
- // Convert to from bf16 to f32. The lower 16-bits are zero due to the
- // definition of bf16.
- xla::XlaOp input_f32 = xla::ConvertElementType(input_bf16, xla::F32);
-
- // Negate the input to reverse sort it. The lower 16-bits are zero, because
- // negating a float is just inverting the high-bit.
- xla::XlaOp negative_input_f32 = xla::Neg(input_f32);
-
- // Convert to a sign magnitude integer. The lower 16-bits are zero, since
- // bitcast convert doesn't change any bits.
- xla::XlaOp negative_input_sm32 =
- xla::BitcastConvertType(negative_input_f32, xla::S32);
-
- // Convert from sign magnitude integer to two's complement integer. The
- // lower 16-bits are zero on both sides of the select. On the false side,
- // the value is unchanged, and on the true side, the lower 16-bits of max
- // are all zero, so the lower 16-bits of the result of the subtraction will
- // also be zero.
- xla::XlaOp negative_input_s32 =
- xla::Select(xla::Lt(negative_input_sm32, zero),
- xla::Sub(max, negative_input_sm32), negative_input_sm32);
-
- // In order for the Or with iota_s32 to to work properly, the lower 16-bits
- // of negative_input_32 must be zero.
-
- // Pack elements as:
- // * upper 16 bits are the value
- // * lower 16 bits are the index.
- xla::XlaOp packed_s32 = xla::Or(negative_input_s32, iota_s32);
-
- // TODO(phawkins): use a more efficient algorithm that does not require a
- // full sort.
- xla::XlaOp sorted_s32 = xla::Slice(xla::Sort(packed_s32),
- /*start_indices=*/{0},
- /*limit_indices=*/{k},
- /*strides=*/{1});
-
- // Unpack the value/index.
- xla::XlaOp indices_s32 = xla::And(sorted_s32, index_mask);
- xla::XlaOp negative_values_s32 = xla::And(sorted_s32, value_mask);
-
- // Convert from two's complement integer to sign magnitude integer.
- xla::XlaOp negative_values_sm32 =
- xla::Select(xla::Lt(negative_values_s32, zero),
- xla::Sub(max, negative_values_s32), negative_values_s32);
-
- xla::XlaOp negative_values_f32 =
- xla::BitcastConvertType(negative_values_sm32, xla::F32);
-
- // Negate the values to get back the original inputs.
- xla::XlaOp values_f32 = xla::Neg(negative_values_f32);
-
- // Convert from f32 to bf16.
- xla::XlaOp values_bf16 = xla::ConvertElementType(values_f32, xla::BF16);
-
- context->SetOutput(0, values_bf16);
- context->SetOutput(1, indices_s32);
+ const xla::XlaOp input = context->Input(0);
+ xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0));
+ xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32);
+ xla::XlaOp values =
+ xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0),
+ /*start_indices=*/{0},
+ /*limit_indices=*/{k},
+ /*strides=*/{1}));
+ xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
+ /*start_indices=*/{0},
+ /*limit_indices=*/{k},
+ /*strides=*/{1});
+ context->SetOutput(0, values);
+ context->SetOutput(1, indices);
}
private:
bool sorted_;
};
-REGISTER_XLA_OP(
- Name("TopKV2").CompileTimeConstInput("k").TypeConstraint("T", DT_BFLOAT16),
- TopKOp);
+REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstInput("k").TypeConstraint(
+ "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}),
+ TopKOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 2e5d61e111..f3e112c7b3 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -47,7 +49,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
var_shape.DebugString(), " vs ",
delta_shape.DebugString()));
- handle = xla::Sub(handle, xla::Mul(ctx->Input(1), ctx->Input(2)));
+ handle = handle - ctx->Input(1) * ctx->Input(2);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
@@ -94,14 +96,13 @@ class ResourceApplyMomentum : public XlaOpKernel {
xla::XlaOp grad = ctx->Input(3);
xla::XlaOp momentum = ctx->Input(4);
- accum = xla::Add(xla::Mul(accum, momentum), grad);
+ accum = accum * momentum + grad;
if (use_nesterov_) {
// See https://github.com/tensorflow/tensorflow/pull/2798 for an
// explanation of the reparameterization used here.
- var = xla::Sub(var, xla::Add(xla::Mul(grad, lr),
- xla::Mul(xla::Mul(accum, momentum), lr)));
+ var = var - (grad * lr + accum * momentum * lr);
} else {
- var = xla::Sub(var, xla::Mul(accum, lr));
+ var = var - accum * lr;
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
@@ -118,8 +119,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
-
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
@@ -146,12 +145,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
xla::XlaOp lr = ctx->Input(2);
xla::XlaOp grad = ctx->Input(3);
- accum =
- xla::Add(accum, xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
- var = xla::Sub(
- var,
- xla::Mul(xla::Mul(grad, lr),
- xla::Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5))));
+ accum = accum + xla::Square(grad);
+ var = var - grad * lr * xla::Rsqrt(accum);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
@@ -226,18 +221,12 @@ class ResourceApplyAdam : public XlaOpKernel {
// variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)
xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
- xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
- xla::XlaOp alpha =
- xla::Div(xla::Mul(lr, xla::Pow(xla::Sub(one, beta2_power), half)),
- xla::Sub(one, beta1_power));
- m = xla::Add(m, xla::Mul(xla::Sub(grad, m), xla::Sub(one, beta1)));
- v = xla::Add(
- v, xla::Mul(xla::Sub(xla::Pow(grad, two), v), xla::Sub(one, beta2)));
- var = xla::Sub(var, xla::Div(xla::Mul(m, alpha),
- xla::Add(xla::Pow(v, half), epsilon)));
+ xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power);
+ m = m + (grad - m) * (one - beta1);
+ v = v + (xla::Square(grad) - v) * (one - beta2);
+ var = var - m * alpha / (xla::Sqrt(v) + epsilon);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
@@ -255,8 +244,6 @@ class ResourceApplyRMSProp : public XlaOpKernel {
explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
-
DataType type = ctx->input_type(3);
TensorShape var_shape, ms_shape, mom_shape;
@@ -320,17 +307,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
// ms <- grad**2 (1 - rho) + ms * rho
//
// Which is the equation listed above.
- xla::XlaOp new_ms = xla::Add(
- ms, xla::Mul(
- xla::Sub(xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)),
- ms),
- xla::Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho)));
+ xla::XlaOp new_ms =
+ ms + (xla::Square(grad) - ms) * (xla::ScalarLike(ms, 1.0) - rho);
xla::XlaOp new_mom =
- xla::Add(xla::Mul(mom, momentum),
- xla::Mul(xla::Mul(grad, lr),
- xla::Pow(xla::Add(new_ms, epsilon),
- XlaHelpers::FloatLiteral(b, type, -0.5))));
- xla::XlaOp new_var = xla::Sub(var, new_mom);
+ mom * momentum + grad * lr * xla::Rsqrt(new_ms + epsilon);
+ xla::XlaOp new_var = var - new_mom;
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
@@ -425,23 +406,18 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
xla::XlaOp grad_to_use;
if (has_l2_shrinkage) {
- grad_to_use = xla::Add(grad, xla::Mul(two, xla::Mul(l2_shrinkage, var)));
+ grad_to_use = grad + two * l2_shrinkage * var;
} else {
grad_to_use = grad;
}
- xla::XlaOp new_accum = xla::Add(accum, xla::Pow(grad_to_use, two));
- xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, xla::Neg(lr_power));
- xla::XlaOp accum_lr_pow = xla::Pow(accum, xla::Neg(lr_power));
- linear = xla::Add(
- linear,
- xla::Sub(grad_to_use,
- xla::Mul(xla::Div(xla::Sub(new_accum_lr_pow, accum_lr_pow), lr),
- var)));
- xla::XlaOp linear_clipped = xla::Clamp(xla::Neg(l1), linear, l1);
- xla::XlaOp quadratic =
- xla::Add(xla::Div(new_accum_lr_pow, lr), xla::Mul(two, l2));
- var = xla::Div(xla::Sub(linear_clipped, linear), quadratic);
+ xla::XlaOp new_accum = accum + xla::Square(grad_to_use);
+ xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
+ xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
+ linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
+ xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1);
+ xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2;
+ var = (linear_clipped - linear) / quadratic;
accum = new_accum;
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var));
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index e996916461..116a020437 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -21,6 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -51,56 +53,36 @@ XLAJIT_MAKE_UNARY(Conj, xla::Conj(x));
XLAJIT_MAKE_UNARY(Abs, xla::Abs(x));
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
-XLAJIT_MAKE_UNARY(
- Acos,
- xla::Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
- xla::Atan2(xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)),
- xla::Mul(x, x)),
- XlaHelpers::FloatLiteral(b, input_type(0),
- 0.5)),
- xla::Add(XlaHelpers::One(b, input_type(0)), x))));
+XLAJIT_MAKE_UNARY(Acos,
+ xla::ScalarLike(x, 2.0) *
+ xla::Atan2(xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x),
+ xla::ScalarLike(x, 1.0) + x));
// acosh(x) = log(x + sqrt(x^2 - 1))
// = log(x + sqrt((x+1)*(x-1)))
-XLAJIT_MAKE_UNARY(
- Acosh,
- xla::Log(xla::Add(
- x, xla::Pow(xla::Mul(xla::Add(x, XlaHelpers::One(b, input_type(0))),
- xla::Sub(x, XlaHelpers::One(b, input_type(0)))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+XLAJIT_MAKE_UNARY(Acosh,
+ xla::Log(x + xla::Sqrt((x + xla::ScalarLike(x, 1.0)) *
+ (x - xla::ScalarLike(x, 1.0)))));
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
XLAJIT_MAKE_UNARY(
- Asin,
- xla::Mul(
- XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
- xla::Atan2(x,
- xla::Add(XlaHelpers::One(b, input_type(0)),
- xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)),
- xla::Mul(x, x)),
- XlaHelpers::FloatLiteral(b, input_type(0),
- 0.5))))));
+ Asin, xla::ScalarLike(x, 2.0) *
+ xla::Atan2(x, xla::ScalarLike(x, 1.0) +
+ xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x)));
// asinh(x) = log(x + sqrt(x^2 + 1))
-XLAJIT_MAKE_UNARY(
- Asinh,
- xla::Log(xla::Add(
- x, xla::Pow(xla::Add(xla::Mul(x, x), XlaHelpers::One(b, input_type(0))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+XLAJIT_MAKE_UNARY(Asinh,
+ xla::Log(x + xla::Sqrt(x * x + xla::ScalarLike(x, 1.0))));
-XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, XlaHelpers::One(b, input_type(0))));
+XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, xla::ScalarLike(x, 1.0)));
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
-XLAJIT_MAKE_UNARY(
- Atanh,
- xla::Mul(xla::Log(xla::Div(xla::Add(XlaHelpers::One(b, input_type(0)), x),
- xla::Sub(XlaHelpers::One(b, input_type(0)), x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Atanh, xla::Log((xla::ScalarLike(x, 1.0) + x) /
+ (xla::ScalarLike(x, 1.0) - x)) *
+ xla::ScalarLike(x, 0.5));
XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x));
XLAJIT_MAKE_UNARY(Cos, xla::Cos(x));
-XLAJIT_MAKE_UNARY(Cosh,
- xla::Mul(xla::Add(xla::Exp(x), xla::Exp(xla::Neg(x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Cosh, (xla::Exp(x) + xla::Exp(-x)) * xla::ScalarLike(x, 0.5));
XLAJIT_MAKE_UNARY(Sin, xla::Sin(x));
XLAJIT_MAKE_UNARY(Exp, xla::Exp(x));
@@ -108,59 +90,53 @@ XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x));
XLAJIT_MAKE_UNARY(Floor, xla::Floor(x));
XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x));
-XLAJIT_MAKE_UNARY(IsInf, xla::Eq(xla::Abs(x),
- XlaHelpers::FloatLiteral(
- b, input_type(0),
- std::numeric_limits<double>::infinity())));
+XLAJIT_MAKE_UNARY(
+ IsInf,
+ xla::Eq(xla::Abs(x),
+ xla::ScalarLike(x, std::numeric_limits<double>::infinity())));
XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x));
// Return 1/x
-XLAJIT_MAKE_UNARY(Inv, xla::Div(XlaHelpers::One(b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Reciprocal, xla::Div(XlaHelpers::One(b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x);
+XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x);
XLAJIT_MAKE_UNARY(Log, xla::Log(x));
XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x));
XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x));
-XLAJIT_MAKE_UNARY(Neg, xla::Neg(x));
+XLAJIT_MAKE_UNARY(Neg, -x);
// Implements Banker's rounding: numbers that are equidistant between two
// integers are rounded towards even.
-static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype,
- const xla::XlaOp& x) {
- auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
- auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
- auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
+xla::XlaOp RoundToEven(xla::XlaOp x) {
+ auto half = xla::ScalarLike(x, 0.5);
+ auto one = xla::ScalarLike(x, 1.0);
+ auto two = xla::ScalarLike(x, 2.0);
auto round_val = xla::Floor(x);
- auto fraction = xla::Sub(x, round_val);
- auto nearest_even_int =
- xla::Sub(round_val, xla::Mul(two, xla::Floor(xla::Mul(half, x))));
+ auto fraction = x - round_val;
+ auto nearest_even_int = round_val - two * xla::Floor(half * x);
auto is_odd = xla::Eq(nearest_even_int, one);
return xla::Select(xla::Or(xla::Gt(fraction, half),
xla::And(xla::Eq(fraction, half), is_odd)),
- xla::Add(round_val, one), round_val);
+ round_val + one, round_val);
}
-XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x));
-XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
+XLAJIT_MAKE_UNARY(Rint, RoundToEven(x));
+XLAJIT_MAKE_UNARY(Round, RoundToEven(x));
-XLAJIT_MAKE_UNARY(Rsqrt, xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0),
- -0.5)));
+XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x));
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
-static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype,
- const xla::XlaOp& x) {
- auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
- return xla::Add(half, xla::Mul(half, xla::Tanh(xla::Mul(half, x))));
+xla::XlaOp Sigmoid(xla::XlaOp x) {
+ auto half = xla::ScalarLike(x, 0.5);
+ return half + half * xla::Tanh(half * x);
}
-XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));
+XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, xla::Sign(x));
-XLAJIT_MAKE_UNARY(Sinh,
- xla::Mul(xla::Sub(xla::Exp(x), xla::Exp(xla::Neg(x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Sinh, (xla::Exp(x) - xla::Exp(-x)) * xla::ScalarLike(x, 0.5));
// softplus(x) = log(1 + exp(x))
//
@@ -170,18 +146,14 @@ XLAJIT_MAKE_UNARY(Sinh,
//
// This is equivalent to:
// max(x, 0) + log1p(exp(-abs(x)))
-XLAJIT_MAKE_UNARY(Softplus,
- xla::Add(xla::Max(x, XlaHelpers::Zero(b, input_type(0))),
- xla::Log1p(xla::Exp(xla::Neg(xla::Abs(x))))));
+XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) +
+ xla::Log1p(xla::Exp(-xla::Abs(x))));
// softsign(x) = x / (abs(x) + 1)
-XLAJIT_MAKE_UNARY(Softsign,
- xla::Div(x, xla::Add(xla::Abs(x),
- XlaHelpers::One(b, input_type(0)))));
-XLAJIT_MAKE_UNARY(Sqrt,
- xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Square, xla::Mul(x, x));
-XLAJIT_MAKE_UNARY(Tan, xla::Div(xla::Sin(x), xla::Cos(x)));
+XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0)));
+XLAJIT_MAKE_UNARY(Sqrt, xla::Sqrt(x));
+XLAJIT_MAKE_UNARY(Square, x* x);
+XLAJIT_MAKE_UNARY(Tan, xla::Sin(x) / xla::Cos(x));
XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x));
XLAJIT_MAKE_UNARY(Real, xla::Real(x));
@@ -195,18 +167,10 @@ class ErfOp : public XlaOpKernel {
public:
explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
- xla::PrimitiveType primitive_type;
- xla::XlaOp one = XlaHelpers::One(b, input_type(0));
xla::XlaOp x = ctx->Input(0);
- xla::XlaOp abs_x = xla::Abs(x);
-
- OP_REQUIRES_OK(ctx,
- DataTypeToPrimitiveType(input_type(0), &primitive_type));
-
+ xla::XlaOp one = xla::ScalarLike(x, 1.0);
auto y =
- xla::Select(xla::Gt(abs_x, one), xla::Sub(one, Erfc(x, primitive_type)),
- Erf(x, primitive_type));
+ xla::Select(xla::Gt(xla::Abs(x), one), one - xla::Erfc(x), xla::Erf(x));
ctx->SetOutput(0, y);
}
};
@@ -216,18 +180,10 @@ class ErfcOp : public XlaOpKernel {
public:
explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp one = XlaHelpers::One(b, input_type(0));
xla::XlaOp x = ctx->Input(0);
- xla::XlaOp abs_x = xla::Abs(x);
-
- xla::PrimitiveType primitive_type;
- OP_REQUIRES_OK(ctx,
- DataTypeToPrimitiveType(input_type(0), &primitive_type));
-
+ xla::XlaOp one = xla::ScalarLike(x, 1.0);
auto y =
- xla::Select(xla::Lt(abs_x, one), xla::Sub(one, Erf(x, primitive_type)),
- Erfc(x, primitive_type));
+ xla::Select(xla::Lt(xla::Abs(x), one), one - xla::Erf(x), xla::Erfc(x));
ctx->SetOutput(0, y);
}
};
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 04c600698c..dfa3c0595a 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -44,6 +44,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
@@ -58,7 +59,8 @@ cc_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:protos_all_cc",
],
@@ -95,6 +97,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index a90178c7d9..cc840de393 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -58,7 +59,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
/*pos=*/0,
/*len=*/n_dims - 2);
- xla::XlaOp l = Zeros(builder, a_shape);
+ xla::XlaOp l = xla::ZerosLike(a);
// Construct the for loop body to iterate over rows.
auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
@@ -73,12 +74,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
row_shape.add_dimensions(1);
row_shape.add_dimensions(n);
row_shape.set_element_type(a_shape.element_type());
- auto mask_zeros_row = Zeros(body_builder, row_shape);
+ auto mask_zeros_row = xla::Zeros(body_builder, row_shape);
col_shape.add_dimensions(n);
col_shape.add_dimensions(1);
col_shape.set_element_type(a_shape.element_type());
- auto mask_zeros_col = Zeros(body_builder, col_shape);
+ auto mask_zeros_col = xla::Zeros(body_builder, col_shape);
std::vector<int32> mask_vector(n);
std::iota(mask_vector.begin(), mask_vector.end(), 0);
@@ -170,7 +171,7 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
// Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
- xla::XlaOp l = Zeros(builder, a_shape);
+ xla::XlaOp l = xla::ZerosLike(a);
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
if (i > 0) {
diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc
index 3dfa66029c..8ff10fbd3f 100644
--- a/tensorflow/compiler/tf2xla/lib/random.cc
+++ b/tensorflow/compiler/tf2xla/lib/random.cc
@@ -19,14 +19,14 @@ limitations under the License.
#include <limits>
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
-#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
-xla::XlaOp TruncatedNormal(const DataType dtype, xla::XlaOp uniform) {
- xla::XlaBuilder* builder = uniform.builder();
+xla::XlaOp TruncatedNormal(xla::XlaOp uniform) {
auto normal_cdf = [](double x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
@@ -41,18 +41,15 @@ xla::XlaOp TruncatedNormal(const DataType dtype, xla::XlaOp uniform) {
const double kBetaNormalCdf = normal_cdf(kBeta);
const double kZ = kBetaNormalCdf - kAlphaNormalCdf;
- xla::XlaOp one = XlaHelpers::FloatLiteral(builder, dtype, 1.0);
- xla::XlaOp two = XlaHelpers::FloatLiteral(builder, dtype, 2.0);
- xla::XlaOp sqrt_2 = XlaHelpers::FloatLiteral(builder, dtype, std::sqrt(2.0));
-
- xla::XlaOp z = XlaHelpers::FloatLiteral(builder, dtype, kZ);
- xla::XlaOp alpha_normal_cdf =
- XlaHelpers::FloatLiteral(builder, dtype, kAlphaNormalCdf);
+ xla::XlaOp one = xla::ScalarLike(uniform, 1.0);
+ xla::XlaOp two = xla::ScalarLike(uniform, 2.0);
+ xla::XlaOp sqrt_2 = xla::ScalarLike(uniform, std::sqrt(2.0));
+ xla::XlaOp z = xla::ScalarLike(uniform, kZ);
+ xla::XlaOp alpha_normal_cdf = xla::ScalarLike(uniform, kAlphaNormalCdf);
+ auto p = alpha_normal_cdf + z * uniform;
// probit(p) = sqrt(2) * erfinv(2*p-1)
- auto p = xla::Add(alpha_normal_cdf, xla::Mul(z, uniform));
- auto erfinv_input = xla::Sub(xla::Mul(p, two), one);
- return xla::Mul(sqrt_2, ErfInv(erfinv_input));
+ return sqrt_2 * xla::ErfInv(two * p - one);
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h
index 39cbcf9c5e..2c573fd85b 100644
--- a/tensorflow/compiler/tf2xla/lib/random.h
+++ b/tensorflow/compiler/tf2xla/lib/random.h
@@ -28,7 +28,7 @@ namespace tensorflow {
//
// The "uniform" parameter must be an array of random numbers distributed in
// (0,1).
-xla::XlaOp TruncatedNormal(DataType dtype, xla::XlaOp uniform);
+xla::XlaOp TruncatedNormal(xla::XlaOp uniform);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 0d3ce129c7..4f97d1277c 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -131,7 +132,7 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
return &computation;
};
- xla::XlaOp output = Zeros(builder, b_shape);
+ xla::XlaOp output = xla::ZerosLike(b);
// Right-looking blocked triangular solve.
// For an explanation of the algorithm, see the TRSM discussion in:
@@ -342,7 +343,7 @@ xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b,
// output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
// else:
// output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
- xla::XlaOp output = Zeros(builder, b_shape);
+ xla::XlaOp output = xla::ZerosLike(b);
{
auto i = transpose_a ? m - 1 : 0;
auto a_slice = SliceInMinorDims(a, {i, i}, {i + 1, i + 1});
@@ -484,7 +485,7 @@ xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b,
}
// The main computation is performed in a While loop.
- xla::XlaOp output = Zeros(builder, b_shape);
+ xla::XlaOp output = xla::ZerosLike(b);
// Construct the initial loop carry tuple,
// if transpose_a:
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 6694729495..fdc8bfca49 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -28,12 +28,6 @@ limitations under the License.
namespace tensorflow {
-xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
- return xla::Broadcast(
- xla::ConstantLiteral(builder, xla::Literal::Zero(shape.element_type())),
- xla::AsInt64Slice(shape.dimensions()));
-}
-
xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
double value) {
switch (type) {
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index ac5d2940ff..6cb6c088e9 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -23,9 +23,6 @@ limitations under the License.
namespace tensorflow {
-// Returns a zero-filled tensor with shape `shape`.
-xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape);
-
// Returns a floating point scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 0c98c20805..319cbc74e9 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -231,10 +231,13 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
case XlaCompiler::Argument::kConstant:
LOG(FATAL) << "Unreachable case";
case XlaCompiler::Argument::kParameter: {
- TensorShape shape =
- is_entry_computation
- ? options_.shape_representation_fn(arg.shape, arg.type)
- : arg.shape;
+ TensorShape shape;
+ if (is_entry_computation) {
+ TF_ASSIGN_OR_RETURN(
+ shape, options_.shape_representation_fn(arg.shape, arg.type));
+ } else {
+ shape = arg.shape;
+ }
return TensorShapeToXLAShape(arg.type, shape, xla_shape);
}
case XlaCompiler::Argument::kResource: {
@@ -242,8 +245,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
switch (arg.resource_kind) {
case XlaResource::kVariable: {
- TensorShape representation_shape =
- options_.shape_representation_fn(arg.shape, arg.type);
+ TF_ASSIGN_OR_RETURN(
+ TensorShape representation_shape,
+ options_.shape_representation_fn(arg.shape, arg.type));
return TensorShapeToXLAShape(arg.type, representation_shape,
xla_shape);
}
@@ -664,20 +668,17 @@ Status XlaCompiler::CompileSingleOp(
namespace {
// Check that the ops of all non-functional nodes have been registered.
-string ValidateFunctionDef(const FunctionDef* fdef,
+Status ValidateFunctionDef(const FunctionDef* fdef,
const FunctionLibraryDefinition& flib_def) {
- std::vector<string> invalid_ops;
for (const NodeDef& node : fdef->node_def()) {
const string& op = node.op();
if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
continue;
}
const OpDef* op_def;
- if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) {
- invalid_ops.push_back(op);
- }
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
}
- return tensorflow::str_util::Join(invalid_ops, ", ");
+ return Status::OK();
}
// Check that the graph doesn't have any invalid nodes (e.g. incompatible with
@@ -685,35 +686,33 @@ string ValidateFunctionDef(const FunctionDef* fdef,
Status ValidateGraph(const Graph* graph,
const FunctionLibraryDefinition& flib_def,
const DeviceType& device_type, const string& name) {
- std::set<string> invalid_ops;
+ auto maybe_error = [&](const string& op, const Status& s) -> Status {
+ if (!s.ok()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Detected unsupported operations when trying to compile graph ", name,
+ " on ", device_type.type_string(), ": ", op, " (", s.error_message(),
+ ")"));
+ }
+ return Status::OK();
+ };
+
for (const Node* node : graph->nodes()) {
if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
continue;
}
const FunctionDef* fdef = flib_def.Find(node->def().op());
+ Status s;
if (fdef) {
- string error_msg = ValidateFunctionDef(fdef, flib_def);
- if (!error_msg.empty()) {
- invalid_ops.insert(
- strings::StrCat(node->def().op(), ":{", error_msg, "}"));
- }
+ s = ValidateFunctionDef(fdef, flib_def);
+ TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
continue;
}
const OpDef* op_def;
- if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) {
- invalid_ops.insert(node->def().op());
- continue;
- }
+ s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
+ TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
- if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
- invalid_ops.insert(node->def().op());
- }
- }
- if (!invalid_ops.empty()) {
- return errors::InvalidArgument(strings::StrCat(
- "Detected unsupported operations when trying to compile graph ", name,
- " on ", device_type.type_string(), ":",
- tensorflow::str_util::Join(invalid_ops, ", ")));
+ s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
+ TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 80593eaca5..079c99797e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -242,7 +243,8 @@ class XlaCompiler {
std::shared_ptr<xla::XlaComputation> computation;
};
- typedef std::function<TensorShape(const TensorShape&, DataType)>
+ typedef std::function<xla::StatusOr<TensorShape>(const TensorShape&,
+ DataType)>
ShapeRepresentationFn;
struct Options {
// Name of the compilation device to use. It must be set by the caller.
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 613230452b..07af8ef54b 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -1021,8 +1021,7 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp"))
<< status.error_message();
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index d0b5606907..fd39a58ce6 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -66,8 +66,8 @@ XlaContext::XlaContext(
XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
- const std::function<TensorShape(const TensorShape&, DataType)>*
- shape_representation_fn)
+ const std::function<xla::StatusOr<TensorShape>(
+ const TensorShape&, DataType)>* shape_representation_fn)
: compiler_(compiler),
builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls),
@@ -119,8 +119,8 @@ Status XlaContext::CreateResource(
return Status::OK();
}
-TensorShape XlaContext::RepresentationShape(const TensorShape& shape,
- DataType type) const {
+xla::StatusOr<TensorShape> XlaContext::RepresentationShape(
+ const TensorShape& shape, DataType type) const {
return (*shape_representation_fn_)(shape, type);
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 5960daaefd..38d8cd653c 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -47,8 +48,8 @@ class XlaContext : public ResourceBase {
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
- const std::function<TensorShape(const TensorShape&, DataType)>*
- shape_representation_fn);
+ const std::function<xla::StatusOr<TensorShape>(
+ const TensorShape&, DataType)>* shape_representation_fn);
// Virtual method defined by ResourceBase.
string DebugString() override;
@@ -101,8 +102,8 @@ class XlaContext : public ResourceBase {
// Returns the XLA shape to be used to represent a variable of TF `shape`
// and `type`, or of an argument or return value of a top-level computation.
- TensorShape RepresentationShape(const TensorShape& shape,
- DataType type) const;
+ xla::StatusOr<TensorShape> RepresentationShape(const TensorShape& shape,
+ DataType type) const;
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
@@ -160,7 +161,7 @@ class XlaContext : public ResourceBase {
// should be represented in XLA. Parameters/return values will be shaped
// according to this function, and reshaped back to/from their declared shapes
// for computations. Must be non-null.
- const std::function<TensorShape(const TensorShape&, DataType)>*
+ const std::function<xla::StatusOr<TensorShape>(const TensorShape&, DataType)>*
shape_representation_fn_;
// Cache of prebuilt computations indexed by their type.
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 81bdf139f5..edbc5e95a8 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
@@ -34,111 +36,61 @@ namespace tensorflow {
namespace {
-Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input, const TensorShape& input_shape,
- DataType input_type, DataType output_type, int axis,
- bool is_min, xla::XlaOp* argminmax) {
- xla::XlaOp init_value;
- const xla::XlaComputation* reducer;
- if (is_min) {
- init_value = XlaHelpers::MaxValue(builder, input_type);
- reducer = ctx->GetOrCreateMin(input_type);
- } else {
- init_value = XlaHelpers::MinValue(builder, input_type);
- reducer = ctx->GetOrCreateMax(input_type);
- }
-
- xla::PrimitiveType xla_output_type;
- TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type));
-
- xla::XlaOp input_max = xla::Reduce(input, init_value, *reducer,
- /*dimensions_to_reduce=*/{axis});
- std::vector<int64> broadcast_dims(input_shape.dims() - 1);
- std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
- std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
- // Compute a mask that has 1s for elements equal to the maximum.
- xla::XlaOp partial_mask = xla::ConvertElementType(
- xla::Eq(input, input_max, broadcast_dims), xla_output_type);
-
- // In order to make identity elements for a bitwise And, we:
- // Left shift the 1 to the leftmost bit, yielding 0x10...0
- // Arithmetic right shift the 1 back to the rightmost bit, yielding
- // 0xFF...F
- int32 bits_in_type =
- xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1;
- xla::XlaOp shift_amount =
- XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type);
- xla::XlaOp full_mask = xla::ShiftRightArithmetic(
- xla::ShiftLeft(partial_mask, shift_amount), shift_amount);
-
- // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
- // index.
-
- const int64 axis_size = input_shape.dim_size(axis);
- xla::XlaOp iota = xla::Iota(builder, xla_output_type, axis_size);
- xla::XlaOp product =
- xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis});
-
- // If there are multiple maximum elements, choose the one with the highest
- // index.
- xla::XlaOp output =
- xla::Reduce(product, XlaHelpers::MinValue(builder, output_type),
- *ctx->GetOrCreateMax(output_type),
- /*dimensions_to_reduce=*/{axis});
- *argminmax = output;
- return Status::OK();
+xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis,
+ bool is_min) {
+ xla::XlaBuilder* builder = input.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+ xla::XlaOp init_value;
+ xla::XlaComputation reducer;
+ if (is_min) {
+ init_value = xla::MaxValue(builder, input_shape.element_type());
+ reducer =
+ xla::CreateScalarMinComputation(input_shape.element_type(), builder);
+ } else {
+ init_value = xla::MinValue(builder, input_shape.element_type());
+ reducer =
+ xla::CreateScalarMaxComputation(input_shape.element_type(), builder);
+ }
+
+ xla::XlaOp input_max = xla::Reduce(input, init_value, reducer,
+ /*dimensions_to_reduce=*/{axis});
+ std::vector<int64> broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1);
+ std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
+ std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
+ // Compute a mask that has 1s for elements equal to the maximum.
+ xla::XlaOp partial_mask = xla::ConvertElementType(
+ xla::Eq(input, input_max, broadcast_dims), output_type);
+
+ // In order to make identity elements for a bitwise And, we:
+ // Left shift the 1 to the leftmost bit, yielding 0x10...0
+ // Arithmetic right shift the 1 back to the rightmost bit, yielding
+ // 0xFF...F
+ int32 bits_in_type =
+ xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
+ xla::XlaOp shift_amount =
+ xla::ConstantR0WithType(builder, output_type, bits_in_type);
+ xla::XlaOp full_mask = xla::ShiftRightArithmetic(
+ xla::ShiftLeft(partial_mask, shift_amount), shift_amount);
+
+ // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
+ // index.
+
+ const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis);
+ xla::XlaOp iota = xla::Iota(builder, output_type, axis_size);
+ xla::XlaOp product =
+ xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis});
+
+ // If there are multiple maximum elements, choose the one with the highest
+ // index.
+ return xla::Reduce(product, xla::MinValue(builder, output_type),
+ xla::CreateScalarMaxComputation(output_type, builder),
+ /*dimensions_to_reduce=*/{axis});
+ });
}
} // namespace
-xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) {
- xla::PrimitiveType type;
- TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return xla::ConstantLiteral(b, xla::Literal::MinValue(type));
-}
-
-xla::XlaOp XlaHelpers::MinFiniteValue(xla::XlaBuilder* b, DataType data_type) {
- xla::PrimitiveType type;
- TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- switch (type) {
- case xla::F16:
- return xla::ConstantR0<Eigen::half>(
- b, Eigen::NumTraits<Eigen::half>::lowest());
- case xla::BF16:
- return xla::ConstantR0<bfloat16>(b, bfloat16::lowest());
- case xla::F32:
- return xla::ConstantR0<float>(b, -std::numeric_limits<float>::max());
- case xla::F64:
- return xla::ConstantR0<double>(b, -std::numeric_limits<double>::max());
- default:
- return xla::ConstantLiteral(b, xla::Literal::MinValue(type));
- }
-}
-
-xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) {
- xla::PrimitiveType type;
- TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return xla::ConstantLiteral(b, xla::Literal::MaxValue(type));
-}
-
-xla::XlaOp XlaHelpers::MaxFiniteValue(xla::XlaBuilder* b, DataType data_type) {
- xla::PrimitiveType type;
- TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- switch (type) {
- case xla::F16:
- return xla::ConstantR0<Eigen::half>(
- b, Eigen::NumTraits<Eigen::half>::highest());
- case xla::BF16:
- return xla::ConstantR0<bfloat16>(b, bfloat16::highest());
- case xla::F32:
- return xla::ConstantR0<float>(b, std::numeric_limits<float>::max());
- case xla::F64:
- return xla::ConstantR0<double>(b, std::numeric_limits<double>::max());
- default:
- return xla::ConstantLiteral(b, xla::Literal::MaxValue(type));
- }
-}
-
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
@@ -151,24 +103,6 @@ xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
return xla::ConstantLiteral(b, xla::Literal::One(type));
}
-xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) {
- switch (data_type) {
- case DT_HALF:
- return xla::ConstantR0<Eigen::half>(
- b,
- static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
- case DT_BFLOAT16:
- return xla::ConstantR0<bfloat16>(b, bfloat16::epsilon());
- case DT_FLOAT:
- return xla::ConstantR0<float>(b, std::numeric_limits<float>::epsilon());
- case DT_DOUBLE:
- return xla::ConstantR0<double>(b, std::numeric_limits<double>::epsilon());
- default:
- LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: "
- << DataTypeString(data_type);
- }
-}
-
xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
int64 value) {
xla::PrimitiveType type;
@@ -214,20 +148,14 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
return linspace;
}
-Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input,
- const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis, xla::XlaOp* argmax) {
- return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
- axis, /*is_min=*/false, argmax);
+xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis) {
+ return ArgMinMax(input, output_type, axis, /*is_min=*/false);
}
-Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input,
- const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis, xla::XlaOp* argmin) {
- return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
- axis, /*is_min=*/true, argmin);
+xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis) {
+ return ArgMinMax(input, output_type, axis, /*is_min=*/true);
}
Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index 495bd2b8b6..d6ca4ab934 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -28,22 +28,6 @@ namespace tensorflow {
// Helper methods for building XLA computations.
class XlaHelpers {
public:
- // Returns a handle representing the minimum value of a scalar
- // element of data_type. -inf for floating-point types.
- static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type);
-
- // Returns a handle representing the minimum finite value of a scalar
- // element of data_type.
- static xla::XlaOp MinFiniteValue(xla::XlaBuilder* b, DataType data_type);
-
- // Returns a handle representing the maximum value of a scalar
- // element of data_type. inf for floating point types.
- static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type);
-
- // Returns a handle representing the maximum finite value of a scalar
- // element of data_type.
- static xla::XlaOp MaxFiniteValue(xla::XlaBuilder* b, DataType data_type);
-
// Returns a handle representing the zero value of a scalar
// element of data_type.
static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type);
@@ -52,10 +36,6 @@ class XlaHelpers {
// element of data_type.
static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type);
- // Returns the machine epsilon for floating-point type `data_type`, i.e.,
- // the difference between 1.0 and the next representable value.
- static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type);
-
// Returns a handle representing the given value of an integer scalar
// element of data_type.
// Note that unlike One and Zero, does not work on boolean types.
@@ -73,21 +53,15 @@ class XlaHelpers {
gtl::ArraySlice<int64> shape,
xla::Literal* output);
- // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and
- // `input_dtype` are the shape and dtype of `input` respectively, and
- // `output_type` is the dtype to use for `argmax`.
- static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input, const TensorShape& input_shape,
- DataType input_type, DataType output_type, int axis,
- xla::XlaOp* argmax);
-
- // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and
- // `input_dtype` are the shape and dtype of `input` respectively, and
- // `output_type` is the dtype to use for `argmin`.
- static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input, const TensorShape& input_shape,
- DataType input_type, DataType output_type, int axis,
- xla::XlaOp* argmin);
+ // Returns the argmax of `input` along `axis`. `output_type` is the type to
+ // use for the output.
+ static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis);
+
+ // Returns the argmin of `input` along `axis`. `output_type` is the type to
+ // use for the output.
+ static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis);
// Converts `indices` into a one-hot representation. `depth` is the size
// of the new axis to add. `axis` is the position at which to add the new
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 0eabfb3a52..359cb4c467 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
namespace tensorflow {
@@ -353,8 +354,9 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
}
XlaContext& xla_context = XlaContext::Get(context_);
- TensorShape representation_shape =
- xla_context.RepresentationShape(variable->shape(), variable->type());
+ TF_ASSIGN_OR_RETURN(
+ TensorShape representation_shape,
+ xla_context.RepresentationShape(variable->shape(), variable->type()));
if (representation_shape == variable->shape()) {
*value = variable->value();
} else {
@@ -474,8 +476,8 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
XlaContext& xla_context = XlaContext::Get(context_);
- TensorShape representation_shape =
- xla_context.RepresentationShape(shape, type);
+ TF_ASSIGN_OR_RETURN(TensorShape representation_shape,
+ xla_context.RepresentationShape(shape, type));
if (shape != representation_shape) {
handle = xla::Reshape(handle, representation_shape.dim_sizes());
}
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 273fa17371..a6b9b47253 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -24,6 +24,7 @@ cc_library(
srcs = ["arithmetic.cc"],
hdrs = ["arithmetic.h"],
deps = [
+ ":constants",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
@@ -35,6 +36,62 @@ cc_library(
)
cc_library(
+ name = "constants",
+ srcs = ["constants.cc"],
+ hdrs = ["constants.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ ],
+)
+
+xla_test(
+ name = "constants_test",
+ srcs = ["constants_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":constants",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
+ name = "math",
+ srcs = ["math.cc"],
+ hdrs = ["math.h"],
+ deps = [
+ ":constants",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ ],
+)
+
+xla_test(
+ name = "math_test",
+ srcs = ["math_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":math",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
name = "numeric",
srcs = ["numeric.cc"],
hdrs = ["numeric.h"],
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 8c314fa61b..978fc40f34 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -120,134 +121,4 @@ XlaOp Any(XlaOp predicates) {
});
}
-namespace {
-XlaOp FloatLiteral(XlaBuilder* b, PrimitiveType data_type, float value) {
- return ConvertElementType(ConstantR0(b, value), data_type);
-}
-
-// Polynomials for computing erf/erfc. Originally from cephes.
-// Note we use float for compatibility across devices, at the cost of some
-// precision for 64 bit computations.
-//
-// Coefficients are in descending order.
-std::array<float, 9> kErfcPCoefficient = {
- 2.46196981473530512524E-10, 5.64189564831068821977E-1,
- 7.46321056442269912687E0, 4.86371970985681366614E1,
- 1.96520832956077098242E2, 5.26445194995477358631E2,
- 9.34528527171957607540E2, 1.02755188689515710272E3,
- 5.57535335369399327526E2};
-std::array<float, 9> kErfcQCoefficient = {
- 1.00000000000000000000E0, 1.32281951154744992508E1,
- 8.67072140885989742329E1, 3.54937778887819891062E2,
- 9.75708501743205489753E2, 1.82390916687909736289E3,
- 2.24633760818710981792E3, 1.65666309194161350182E3,
- 5.57535340817727675546E2};
-std::array<float, 6> kErfcRCoefficient = {
- 5.64189583547755073984E-1, 1.27536670759978104416E0,
- 5.01905042251180477414E0, 6.16021097993053585195E0,
- 7.40974269950448939160E0, 2.97886665372100240670E0};
-std::array<float, 7> kErfcSCoefficient = {
- 1.00000000000000000000E0, 2.26052863220117276590E0,
- 9.39603524938001434673E0, 1.20489539808096656605E1,
- 1.70814450747565897222E1, 9.60896809063285878198E0,
- 3.36907645100081516050E0};
-std::array<float, 5> kErfTCoefficient = {
- 9.60497373987051638749E0, 9.00260197203842689217E1,
- 2.23200534594684319226E3, 7.00332514112805075473E3,
- 5.55923013010394962768E4};
-std::array<float, 6> kErfUCoefficient = {
- 1.00000000000000000000E0, 3.35617141647503099647E1,
- 5.21357949780152679795E2, 4.59432382970980127987E3,
- 2.26290000613890934246E4, 4.92673942608635921086E4};
-} // namespace
-
-// Evaluate the polynomial given coefficients and `x`.
-// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(XlaOp x,
- tensorflow::gtl::ArraySlice<float> coefficients,
- PrimitiveType data_type) {
- XlaBuilder* b = x.builder();
- XlaOp poly = FloatLiteral(b, data_type, 0.0);
- for (float c : coefficients) {
- poly = Add(Mul(poly, x), FloatLiteral(b, data_type, c));
- }
- return poly;
-}
-
-// Compute an approximation of the error function complement (1 - erf(x)).
-XlaOp Erfc(XlaOp x, PrimitiveType data_type) {
- XlaBuilder* b = x.builder();
- XlaOp zero = FloatLiteral(b, data_type, 0.0);
- XlaOp two = FloatLiteral(b, data_type, 2.0);
- XlaOp eight = FloatLiteral(b, data_type, 8.0);
-
- XlaOp abs_x = Abs(x);
- XlaOp z = Exp(Mul(Neg(x), x));
-
- XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient, data_type);
- XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient, data_type);
- XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient, data_type);
- XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient, data_type);
-
- XlaOp y = Select(Lt(abs_x, eight), Div(Mul(z, pp), pq), Div(Mul(z, pr), ps));
-
- return Select(Lt(x, zero), Sub(two, y), y);
-}
-
-// Compute a polynomial approximation of the error function.
-XlaOp Erf(XlaOp x, PrimitiveType data_type) {
- XlaOp z = Mul(x, x);
- XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient, data_type);
- XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient, data_type);
- return Div(Mul(x, pt), pu);
-}
-
-// Approximation for the inverse error function from
-// Giles, M., "Approximating the erfinv function".
-// The approximation has the form:
-// w = -log((1 - x) * (1 + x))
-// if ( w < 5 ) {
-// w = w - 2.5
-// p = sum_{i=1}^n lq[i]*w^i
-// } else {
-// w = sqrt(w) - 3
-// p = sum_{i=1}^n gq[i]*w^i
-// }
-// return p*x
-XlaOp ErfInv(XlaOp x) {
- XlaBuilder* b = x.builder();
- return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
- constexpr int kDegree = 9;
- constexpr std::array<float, 9> w_less_than_5_constants = {
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- constexpr std::array<float, 9> w_greater_than_5_constants = {
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
-
- auto one = ConstantR0<float>(b, 1.0);
- auto w = Neg(Log(Mul(Sub(one, x), Add(one, x))));
-
- auto lt = Lt(w, ConstantR0<float>(b, 5.0));
- auto coefficient = [&](int i) {
- return Select(
- lt,
- Broadcast(ConstantR0<float>(b, w_less_than_5_constants[i]),
- AsInt64Slice(shape.dimensions())),
- Broadcast(ConstantR0<float>(b, w_greater_than_5_constants[i]),
- AsInt64Slice(shape.dimensions())));
- };
- w = Select(lt, Sub(w, ConstantR0<float>(b, 2.5f)),
- Sub(SqrtF32(w), ConstantR0<float>(b, 3.0f)));
- auto p = coefficient(0);
- for (int i = 1; i < kDegree; ++i) {
- p = Add(coefficient(i), Mul(p, w));
- }
- return Mul(p, x);
- });
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index d0e04bbb5e..d0b916e8c8 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -55,21 +55,6 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
// Note: if predicates is zero-sized, Any() vacuously returns false.
XlaOp Any(XlaOp predicates);
-// Evaluate the polynomial given coefficients and `x`.
-// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(XlaOp x,
- tensorflow::gtl::ArraySlice<float> coefficients,
- PrimitiveType data_type);
-
-// Compute an approximation of the error function complement (1 - erf(x)).
-XlaOp Erfc(XlaOp x, PrimitiveType data_type);
-
-// Compute an approximation of the error function.
-XlaOp Erf(XlaOp x, PrimitiveType data_type);
-
-// Compute an approximation of the inverse of the error function.
-XlaOp ErfInv(XlaOp x);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_
diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc
new file mode 100644
index 0000000000..1686389a23
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/constants.cc
@@ -0,0 +1,103 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+
+XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
+ return ConstantLiteral(builder, Literal::Zero(type));
+}
+
+XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
+ return Broadcast(Zero(builder, shape.element_type()),
+ AsInt64Slice(shape.dimensions()));
+}
+
+XlaOp ZerosLike(XlaOp prototype) {
+ XlaBuilder* builder = prototype.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
+ return Zeros(builder, shape);
+ });
+}
+
+XlaOp One(XlaBuilder* builder, PrimitiveType type) {
+ return ConstantLiteral(builder, Literal::One(type));
+}
+
+XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
+ switch (type) {
+ case F16:
+ return ConstantR0<Eigen::half>(
+ builder,
+ static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
+ case BF16:
+ return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
+ case F32:
+ return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
+ case F64:
+ return ConstantR0<double>(builder,
+ std::numeric_limits<double>::epsilon());
+ default:
+ return builder->ReportError(InvalidArgument(
+ "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str()));
+ }
+}
+
+XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
+ return ConstantLiteral(builder, Literal::MinValue(type));
+}
+
+XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
+ switch (type) {
+ case F16:
+ return ConstantR0<Eigen::half>(builder,
+ Eigen::NumTraits<Eigen::half>::lowest());
+ case BF16:
+ return ConstantR0<bfloat16>(builder, bfloat16::lowest());
+ case F32:
+ return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
+ case F64:
+ return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
+ default:
+ return MinValue(builder, type);
+ }
+}
+
+XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
+ return ConstantLiteral(builder, Literal::MaxValue(type));
+}
+
+XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
+ switch (type) {
+ case F16:
+ return ConstantR0<Eigen::half>(builder,
+ Eigen::NumTraits<Eigen::half>::highest());
+ case BF16:
+ return ConstantR0<bfloat16>(builder, bfloat16::highest());
+ case F32:
+ return ConstantR0<float>(builder, std::numeric_limits<float>::max());
+ case F64:
+ return ConstantR0<double>(builder, std::numeric_limits<double>::max());
+ default:
+ return MaxValue(builder, type);
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h
new file mode 100644
index 0000000000..b47f5243f0
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/constants.h
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
+
+#include <type_traits>
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is
+// determined at C++ run-time, rather than C++ compile-time.
+// If 'value' is floating point but 'type' is not, or if 'value' is complex but
+// 'type' is not, an error will be returned. This is to catch accidental
+// truncation; in such cases, use an explicit cast.
+template <typename T>
+XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
+ if (std::is_floating_point<T>::value &&
+ !(primitive_util::IsFloatingPointType(type) ||
+ primitive_util::IsComplexType(type))) {
+ return builder->ReportError(InvalidArgument(
+ "Invalid cast from floating point type to %s in ConstantR0WithType.",
+ PrimitiveType_Name(type).c_str()));
+ }
+ if (std::is_same<T, complex64>::value &&
+ !primitive_util::IsComplexType(type)) {
+ return builder->ReportError(InvalidArgument(
+ "Invalid cast from complex type to %s in ConstantR0WithType.",
+ PrimitiveType_Name(type).c_str()));
+ }
+ switch (type) {
+ case F16:
+ return ConstantR0<half>(builder, static_cast<half>(value));
+ case BF16:
+ return ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
+ case F32:
+ return ConstantR0<float>(builder, static_cast<float>(value));
+ case F64:
+ return ConstantR0<double>(builder, static_cast<double>(value));
+ case C64:
+ return ConstantR0<complex64>(builder, static_cast<complex64>(value));
+ case U8:
+ return ConstantR0<uint8>(builder, static_cast<uint8>(value));
+ case U32:
+ return ConstantR0<uint32>(builder, static_cast<uint32>(value));
+ case U64:
+ return ConstantR0<uint64>(builder, static_cast<uint64>(value));
+ case S8:
+ return ConstantR0<int8>(builder, static_cast<int8>(value));
+ case S32:
+ return ConstantR0<int32>(builder, static_cast<int32>(value));
+ case S64:
+ return ConstantR0<int64>(builder, static_cast<int64>(value));
+ default:
+ return builder->ReportError(
+ InvalidArgument("Invalid type for ConstantR0WithType (%s).",
+ PrimitiveType_Name(type).c_str()));
+ }
+}
+
+// Returns a scalar containing 'value' cast to the same run-time type as
+// 'prototype'.
+// If 'value' is floating point but 'prototype' is not, or if 'value' is complex
+// 'prototype' is not, an error will be returned.
+template <typename T>
+XlaOp ScalarLike(XlaOp prototype, T value) {
+ XlaBuilder* builder = prototype.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
+ return ConstantR0WithType(builder, shape.element_type(), value);
+ });
+}
+
+// Returns a scalar with value '0' of 'type'.
+XlaOp Zero(XlaBuilder* builder, PrimitiveType type);
+
+// Returns a zero-filled tensor with shape `shape`.
+XlaOp Zeros(XlaBuilder* builder, const Shape& shape);
+
+// Returns a zero-filled tensor with the same shape as `prototype`.
+XlaOp ZerosLike(XlaOp prototype);
+
+// Returns a scalar with value '1' of 'type'.
+XlaOp One(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the machine epsilon for floating-point type `type`, i.e.,
+// the difference between 1.0 and the next representable value.
+XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the minimum representable finite or infinite value for 'type'.
+// Returns '-inf' for floating-point types.
+XlaOp MinValue(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the minimum representable finite value for 'type'. For a floating
+// point type, this is equal to -MaxFiniteValue().
+XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the maximum representable finite or infinite value for 'type'.
+// Returns 'inf' for floating-point types.
+XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the maximum representable finite value for 'type'.
+XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc
new file mode 100644
index 0000000000..f1e3439862
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/constants_test.cc
@@ -0,0 +1,159 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+using ConstantsTest = ClientLibraryTestBase;
+
+using ::testing::HasSubstr;
+
+XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32) {
+ XlaBuilder builder(TestName());
+ ConstantR0WithType(&builder, xla::S32, 4);
+ ComputeAndCompareR0<int32>(&builder, 4, {});
+}
+
+XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32DoesNotAcceptFloats) {
+ XlaBuilder builder(TestName());
+ ConstantR0WithType(&builder, xla::S32, 4.5);
+ auto statusor = builder.Build();
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid cast"));
+}
+
+XLA_TEST_F(ConstantsTest, ConstantR0WithTypeF32) {
+ XlaBuilder builder(TestName());
+ ConstantR0WithType(&builder, xla::F32, -7);
+ ComputeAndCompareR0<float>(&builder, -7, {});
+ ConstantR0WithType(&builder, xla::F32, 0.5);
+ ComputeAndCompareR0<float>(&builder, 0.5, {});
+}
+
+XLA_TEST_F(ConstantsTest, ScalarLikeS32) {
+ XlaBuilder builder(TestName());
+ ScalarLike(ConstantR0<int32>(&builder, 42), -3);
+ ComputeAndCompareR0<int32>(&builder, -3, {});
+}
+
+XLA_TEST_F(ConstantsTest, ScalarLikeF32) {
+ XlaBuilder builder(TestName());
+ ScalarLike(ConstantR0<float>(&builder, 42.75), -3.2);
+ ComputeAndCompareR0<float>(&builder, -3.2, {});
+}
+
+XLA_TEST_F(ConstantsTest, ZeroS32) {
+ XlaBuilder builder(TestName());
+ Zero(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, 0, {});
+}
+
+XLA_TEST_F(ConstantsTest, ZeroF32) {
+ XlaBuilder builder(TestName());
+ Zero(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, 0.0, {});
+}
+
+XLA_TEST_F(ConstantsTest, ZerosS32) {
+ XlaBuilder builder(TestName());
+ Zeros(&builder, ShapeUtil::MakeShape(S32, {2, 2}));
+ ComputeAndCompareR2<int32>(&builder, {{0, 0}, {0, 0}}, {});
+}
+
+XLA_TEST_F(ConstantsTest, ZerosLikeF32) {
+ XlaBuilder builder(TestName());
+ ZerosLike(ConstantR1<float>(&builder, {1., 2., 3.}));
+ ComputeAndCompareR1<float>(&builder, {0., 0., 0.}, {});
+}
+
+XLA_TEST_F(ConstantsTest, OneS32) {
+ XlaBuilder builder(TestName());
+ One(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, 1, {});
+}
+
+XLA_TEST_F(ConstantsTest, OneF32) {
+ XlaBuilder builder(TestName());
+ One(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, 1., {});
+}
+
+XLA_TEST_F(ConstantsTest, EpsilonF32) {
+ XlaBuilder builder(TestName());
+ Epsilon(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, std::numeric_limits<float>::epsilon(),
+ {});
+}
+
+XLA_TEST_F(ConstantsTest, MinFiniteValueS32) {
+ XlaBuilder builder(TestName());
+ MinFiniteValue(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, std::numeric_limits<int32>::min(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MaxFiniteValueS32) {
+ XlaBuilder builder(TestName());
+ MaxFiniteValue(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, std::numeric_limits<int32>::max(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MinFiniteValueF32) {
+ XlaBuilder builder(TestName());
+ MinFiniteValue(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, -std::numeric_limits<float>::max(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MaxFiniteValueF32) {
+ XlaBuilder builder(TestName());
+ MaxFiniteValue(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, std::numeric_limits<float>::max(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MinValueS32) {
+ XlaBuilder builder(TestName());
+ MinValue(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, std::numeric_limits<int32>::min(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MaxValueS32) {
+ XlaBuilder builder(TestName());
+ MaxValue(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, std::numeric_limits<int32>::max(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MinValueF32) {
+ XlaBuilder builder(TestName());
+ MinValue(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, -std::numeric_limits<float>::infinity(),
+ {});
+}
+
+XLA_TEST_F(ConstantsTest, MaxValueF32) {
+ XlaBuilder builder(TestName());
+ MaxValue(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, std::numeric_limits<float>::infinity(),
+ {});
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
new file mode 100644
index 0000000000..5587559040
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -0,0 +1,152 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/math.h"
+
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace xla {
+
+XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); }
+
+XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); }
+
+XlaOp Square(XlaOp operand) { return Pow(operand, ScalarLike(operand, 2.0)); }
+
+XlaOp Reciprocal(XlaOp operand) {
+ return Pow(operand, ScalarLike(operand, -1.0));
+}
+
+namespace {
+
+// Polynomials for computing erf/erfc. Originally from cephes.
+// Note we use float for compatibility across devices, at the cost of some
+// precision for 64 bit computations.
+//
+// Coefficients are in descending order.
+std::array<float, 9> kErfcPCoefficient = {
+ 2.46196981473530512524E-10, 5.64189564831068821977E-1,
+ 7.46321056442269912687E0, 4.86371970985681366614E1,
+ 1.96520832956077098242E2, 5.26445194995477358631E2,
+ 9.34528527171957607540E2, 1.02755188689515710272E3,
+ 5.57535335369399327526E2};
+std::array<float, 9> kErfcQCoefficient = {
+ 1.00000000000000000000E0, 1.32281951154744992508E1,
+ 8.67072140885989742329E1, 3.54937778887819891062E2,
+ 9.75708501743205489753E2, 1.82390916687909736289E3,
+ 2.24633760818710981792E3, 1.65666309194161350182E3,
+ 5.57535340817727675546E2};
+std::array<float, 6> kErfcRCoefficient = {
+ 5.64189583547755073984E-1, 1.27536670759978104416E0,
+ 5.01905042251180477414E0, 6.16021097993053585195E0,
+ 7.40974269950448939160E0, 2.97886665372100240670E0};
+std::array<float, 7> kErfcSCoefficient = {
+ 1.00000000000000000000E0, 2.26052863220117276590E0,
+ 9.39603524938001434673E0, 1.20489539808096656605E1,
+ 1.70814450747565897222E1, 9.60896809063285878198E0,
+ 3.36907645100081516050E0};
+std::array<float, 5> kErfTCoefficient = {
+ 9.60497373987051638749E0, 9.00260197203842689217E1,
+ 2.23200534594684319226E3, 7.00332514112805075473E3,
+ 5.55923013010394962768E4};
+std::array<float, 6> kErfUCoefficient = {
+ 1.00000000000000000000E0, 3.35617141647503099647E1,
+ 5.21357949780152679795E2, 4.59432382970980127987E3,
+ 2.26290000613890934246E4, 4.92673942608635921086E4};
+} // namespace
+
+// Evaluate the polynomial given coefficients and `x`.
+// N.B. Coefficients should be supplied in decreasing order.
+XlaOp EvaluatePolynomial(XlaOp x,
+ tensorflow::gtl::ArraySlice<float> coefficients) {
+ XlaOp poly = ScalarLike(x, 0.0);
+ for (float c : coefficients) {
+ poly = poly * x + ScalarLike(x, c);
+ }
+ return poly;
+}
+
+// Compute an approximation of the error function complement (1 - erf(x)).
+XlaOp Erfc(XlaOp x) {
+ XlaOp abs_x = Abs(x);
+ XlaOp z = Exp(-x * x);
+
+ XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient);
+ XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient);
+ XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient);
+ XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient);
+
+ XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps);
+
+ return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y);
+}
+
+// Compute a polynomial approximation of the error function.
+XlaOp Erf(XlaOp x) {
+ XlaOp z = x * x;
+ XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient);
+ XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient);
+ return x * pt / pu;
+}
+
+// Approximation for the inverse error function from
+// Giles, M., "Approximating the erfinv function".
+// The approximation has the form:
+// w = -log((1 - x) * (1 + x))
+// if ( w < 5 ) {
+// w = w - 2.5
+// p = sum_{i=1}^n lq[i]*w^i
+// } else {
+// w = sqrt(w) - 3
+// p = sum_{i=1}^n gq[i]*w^i
+// }
+// return p*x
+XlaOp ErfInv(XlaOp x) {
+ XlaBuilder* b = x.builder();
+ return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
+ constexpr int kDegree = 9;
+ constexpr std::array<float, 9> w_less_than_5_constants = {
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ constexpr std::array<float, 9> w_greater_than_5_constants = {
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+
+ auto one = ScalarLike(x, 1.0);
+ auto w = -Log((one - x) * (one + x));
+
+ auto lt = Lt(w, ScalarLike(x, 5.0));
+ auto coefficient = [&](int i) {
+ return Select(lt,
+ Broadcast(ScalarLike(x, w_less_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())),
+ Broadcast(ScalarLike(x, w_greater_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())));
+ };
+ w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0));
+ auto p = coefficient(0);
+ for (int i = 1; i < kDegree; ++i) {
+ p = coefficient(i) + p * w;
+ }
+ return p * x;
+ });
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
new file mode 100644
index 0000000000..e7c8b50273
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+namespace xla {
+
+// Computes the square root of 'operand'.
+XlaOp Sqrt(XlaOp operand);
+
+// Computes the reciprocal of the square root of 'operand'.
+XlaOp Rsqrt(XlaOp operand);
+
+// Computes the square of 'operand'.
+XlaOp Square(XlaOp operand);
+
+// Computes the reciprocal of 'operand'.
+XlaOp Reciprocal(XlaOp operand);
+
+// Evaluates a polynomial given coefficients and `x`.
+// N.B. Coefficients should be supplied in decreasing order.
+XlaOp EvaluatePolynomial(XlaOp x,
+ tensorflow::gtl::ArraySlice<float> coefficients);
+
+// Computes an approximation of the error function complement (1 - erf(x)).
+XlaOp Erfc(XlaOp x);
+
+// Computes an approximation of the error function.
+XlaOp Erf(XlaOp x);
+
+// Computes an approximation of the inverse of the error function.
+XlaOp ErfInv(XlaOp x);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
new file mode 100644
index 0000000000..1df4e6ea42
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -0,0 +1,85 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+class MathTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(MathTest, SqrtF32) {
+ XlaBuilder builder(TestName());
+ Literal zero_literal = Literal::Zero(PrimitiveType::F32);
+
+ std::unique_ptr<GlobalData> zero_data =
+ client_->TransferToServer(zero_literal).ConsumeValueOrDie();
+
+ XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero");
+ Sqrt(zero);
+
+ ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SquareTenValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Square(x);
+
+ std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41,
+ 5.29, 25., 0.81, 5.76, 2.56};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, ReciprocalTenValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Reciprocal(x);
+
+ std::vector<float> expected = {
+ 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048,
+ 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SqrtZeroes) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {0.0, -0.0});
+ Sqrt(x);
+
+ ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SqrtSixValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345});
+ Sqrt(x);
+
+ std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 4f683a4115..95342af6a7 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -532,6 +532,14 @@ XlaOp XlaBuilder::Broadcast(
});
}
+XlaOp XlaBuilder::BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return InDimBroadcast(shape, operand, broadcast_dimensions);
+ });
+}
+
StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
TF_RETURN_IF_ERROR(first_error_);
@@ -1369,11 +1377,6 @@ XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) {
});
}
-XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(0.5),
- /*broadcast_dimensions=*/{});
-}
-
XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions);
@@ -1404,16 +1407,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
});
}
-XlaOp XlaBuilder::SquareF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(2.0),
- /*broadcast_dimensions=*/{});
-}
-
-XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(-1.0),
- /*broadcast_dimensions=*/{});
-}
-
XlaOp XlaBuilder::Neg(const XlaOp& operand) {
return UnaryOp(HloOpcode::kNegate, operand);
}
@@ -2139,6 +2132,13 @@ XlaOp Broadcast(const XlaOp& operand,
return operand.builder()->Broadcast(operand, broadcast_sizes);
}
+XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return operand.builder()->BroadcastInDim(operand, shape,
+ broadcast_dimensions);
+}
+
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config) {
return operand.builder()->Pad(operand, padding_value, padding_config);
@@ -2497,14 +2497,6 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); }
XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); }
-XlaOp SqrtF32(const XlaOp& operand) {
- return operand.builder()->SqrtF32(operand);
-}
-
-XlaOp SquareF32(const XlaOp& operand) {
- return operand.builder()->SquareF32(operand);
-}
-
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions);
@@ -2522,10 +2514,6 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
return operand.builder()->BitcastConvertType(operand, new_element_type);
}
-XlaOp ReciprocalF32(const XlaOp& operand) {
- return operand.builder()->ReciprocalF32(operand);
-}
-
XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); }
XlaOp Transpose(const XlaOp& operand,
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index ac6ad87349..274aba8a31 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -317,6 +317,27 @@ class XlaBuilder {
XlaOp Broadcast(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ // Performs in-dimension-style broadcast.
+ //
+ // Operand specifies the input to be broadcast. "shape" is expected output
+ // shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
+ // Dimension numbers in broadcast_dimensions map to individual dimensions
+ // of the operand, and specify what dimension of the output shape they
+ // should be broadcast.
+ // e.g.
+ // Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
+ // and dimension of shape is [2,2].
+ // Specifying {1} as brodcast_dimension will generate output
+ // [1 , 2]
+ // [1 , 2]
+ // On the other hand, specifying {0} as broadcast_dimension
+ // will generate output
+ // [1 , 1]
+ // [2 , 2]
+ XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
@@ -730,16 +751,6 @@ class XlaBuilder {
// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);
- // Enqueues a float32 sqrt instruction onto the computation.
- // (float32 is specified as there is an implicit float32 0.5f constant
- // exponent).
- XlaOp SqrtF32(const XlaOp& operand);
-
- // Enqueues a float32 square instruction onto the computation.
- // (float32 is specified as there is an implicit float32 2.0f constant
- // exponent).
- XlaOp SquareF32(const XlaOp& operand);
-
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
@@ -762,14 +773,6 @@ class XlaBuilder {
XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
- // Enqueues a float32 reciprocal instruction onto the computation.
- // (float32 is specified as there is an implicit float32 -1.0f constant
- // exponent).
- //
- // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
- // shape of the operand.
- XlaOp ReciprocalF32(const XlaOp& operand);
-
// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);
@@ -1033,6 +1036,10 @@ class XlaBuilder {
friend XlaOp Broadcast(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ friend XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config);
@@ -1210,8 +1217,6 @@ class XlaBuilder {
friend XlaOp Tanh(const XlaOp& operand);
friend XlaOp Real(const XlaOp& operand);
friend XlaOp Imag(const XlaOp& operand);
- friend XlaOp SqrtF32(const XlaOp& operand);
- friend XlaOp SquareF32(const XlaOp& operand);
friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
friend XlaOp IsFinite(const XlaOp& operand);
@@ -1219,7 +1224,6 @@ class XlaBuilder {
PrimitiveType new_element_type);
friend XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
- friend XlaOp ReciprocalF32(const XlaOp& operand);
friend XlaOp Neg(const XlaOp& operand);
friend XlaOp Transpose(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> permutation);
@@ -1376,6 +1380,27 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
XlaOp Broadcast(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+// Performs in-dimension-style broadcast.
+//
+// Operand specifies the input to be broadcast. "shape" is expected output
+// shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
+// Dimension numbers in broadcast_dimensions map to individual dimensions
+// of the operand, and specify what dimension of the output shape they
+// should be broadcast.
+// e.g.
+// Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
+// and dimension of shape is [2,2].
+// Specifying {1} as brodcast_dimension will generate output
+// [1 , 2]
+// [1 , 2]
+// On the other hand, specifying {0} as broadcast_dimension
+// will generate output
+// [1 , 1]
+// [2 , 2]
+XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
@@ -1787,16 +1812,6 @@ XlaOp Real(const XlaOp& operand);
// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);
-// Enqueues a float32 sqrt instruction onto the computation.
-// (float32 is specified as there is an implicit float32 0.5f constant
-// exponent).
-XlaOp SqrtF32(const XlaOp& operand);
-
-// Enqueues a float32 square instruction onto the computation.
-// (float32 is specified as there is an implicit float32 2.0f constant
-// exponent).
-XlaOp SquareF32(const XlaOp& operand);
-
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
@@ -1817,14 +1832,6 @@ XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
// identical.
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
-// Enqueues a float32 reciprocal instruction onto the computation.
-// (float32 is specified as there is an implicit float32 -1.0f constant
-// exponent).
-//
-// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
-// shape of the operand.
-XlaOp ReciprocalF32(const XlaOp& operand);
-
// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 83834c1ff6..22cc4e2436 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -52,9 +52,9 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
- "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index b5ba4e2d42..be55d50b23 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/ptr_util.h"
@@ -626,11 +627,11 @@ _FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
-_FORWARD_UNOP(SqrtF32)
-_FORWARD_UNOP(SquareF32)
+_FORWARD_UNOP(Sqrt)
+_FORWARD_UNOP(Square)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
-_FORWARD_UNOP(ReciprocalF32)
+_FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index e920f8aecd..690ff277e8 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -346,11 +346,11 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
- _FORWARD_UNOP(SqrtF32)
- _FORWARD_UNOP(SquareF32)
+ _FORWARD_UNOP(Sqrt)
+ _FORWARD_UNOP(Square)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
- _FORWARD_UNOP(ReciprocalF32)
+ _FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 76e9e637cd..c44e69e615 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -1002,11 +1002,11 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin;
%unignore xla::swig::LocalComputationBuilder::Tanh;
-%unignore xla::swig::LocalComputationBuilder::SqrtF32;
-%unignore xla::swig::LocalComputationBuilder::SquareF32;
+%unignore xla::swig::LocalComputationBuilder::Sqrt;
+%unignore xla::swig::LocalComputationBuilder::Square;
%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::IsFinite;
-%unignore xla::swig::LocalComputationBuilder::ReciprocalF32;
+%unignore xla::swig::LocalComputationBuilder::Reciprocal;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
%unignore xla::swig::DestructureLocalShapedBufferTuple;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index abb97d0c6f..27aee634ba 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -99,10 +99,10 @@ _UNARY_OPS = [
'Cos',
'Sin',
'Tanh',
- 'SqrtF32',
- 'SquareF32',
+ 'Sqrt',
+ 'Square',
'IsFinite',
- 'ReciprocalF32',
+ 'Reciprocal',
'Neg',
'Sort',
]
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 88f994786a..d90b0fb57d 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -246,6 +246,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:pool",
"//tensorflow/core:lib",
+ "//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index 77a48965e0..5e4fe1dd39 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -43,7 +44,9 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
}
Status ConditionalThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
// Copy the predicate value from device.
bool predicate;
se::DeviceMemoryBase predicate_address =
@@ -59,10 +62,15 @@ Status ConditionalThunk::ExecuteOnStream(
// Execute the true or the false computation depending on the value of the
// predicate.
if (predicate) {
- TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ profiler->StartHloComputation();
+ TF_RETURN_IF_ERROR(
+ true_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->true_computation());
} else {
+ profiler->StartHloComputation();
TF_RETURN_IF_ERROR(
- false_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ false_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->false_computation());
}
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
index ee03865d17..aef24342c9 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -50,7 +51,8 @@ class ConditionalThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
BufferAllocation::Slice predicate_buffer_index_;
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index f088112412..7833a4077e 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -55,7 +56,8 @@ ConvolutionThunk::ConvolutionThunk(
tensor_ops_enabled_(tensor_ops_enabled) {}
Status ConvolutionThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase input_data =
buffer_allocations.GetDeviceAddress(input_buffer_);
se::DeviceMemoryBase filter_data =
@@ -68,6 +70,7 @@ Status ConvolutionThunk::ExecuteOnStream(
se::dnn::AlgorithmConfig algorithm_config(
se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
TF_RETURN_IF_ERROR(RunCudnnConvolution(
convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
filter_data, output_data, scratch, window_, dim_nums_, algorithm_config,
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 6d845025b1..d76ca6698d 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
@@ -66,7 +67,8 @@ class ConvolutionThunk : public Thunk {
// Does the convolution for the thunk on "stream".
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
class ScratchAllocator;
diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
index ee38c0318a..92e03f94c1 100644
--- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -30,9 +31,11 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk(
mem_size_(mem_size) {}
Status HostToDeviceCopyThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase destination_data =
buffer_allocations.GetDeviceAddress(destination_buffer_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemcpy(&destination_data, source_address_, mem_size_);
return Status::OK();
}
@@ -47,11 +50,13 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
mem_size_(mem_size) {}
Status DeviceToDeviceCopyThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase destination_data =
buffer_allocations.GetDeviceAddress(destination_buffer_);
se::DeviceMemoryBase source_data =
buffer_allocations.GetDeviceAddress(source_buffer_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemcpy(&destination_data, source_data, mem_size_);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h
index 8b128386f6..91564b520a 100644
--- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -40,7 +41,8 @@ class HostToDeviceCopyThunk : public Thunk {
HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const void* source_address_;
@@ -63,7 +65,8 @@ class DeviceToDeviceCopyThunk : public Thunk {
DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const BufferAllocation::Slice source_buffer_;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
index 68099fd638..7b172812c3 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
@@ -99,13 +100,15 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
}
Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
std::tie(operand_desc, scale_offset_desc) =
MakeDescriptors(hlo_instruction()->shape(), feature_index_);
se::DeviceMemory<float> output(buffer_allocations.GetDeviceAddress(output_));
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenBatchNormalizationForward(
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
@@ -123,6 +126,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
/*is_training=*/false, //
/*var_to_inv_var=*/nullptr, //
/*inv_var_to_var=*/nullptr);
+
if (!stream->ok()) {
return InternalError("BatchNormalizationForward call failed.");
}
@@ -158,7 +162,8 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
}
Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
// The BatchNormTraining HLO outputs a tuple of three elements: output data,
@@ -175,6 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
buffer_allocations.GetDeviceAddress(output_inv_stddev_));
se::DeviceMemory<float> null_device_ptr(nullptr);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenBatchNormalizationForward(
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
@@ -240,7 +246,8 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
}
Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
@@ -257,6 +264,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
se::DeviceMemory<float> output_grad_offset(
buffer_allocations.GetDeviceAddress(output_grad_offset_));
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenBatchNormalizationBackward(
se::DeviceMemory<float>(
buffer_allocations.GetDeviceAddress(grad_output_)),
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
index 874f85a863..d2143b3952 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
@@ -60,7 +61,8 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
const CudnnBatchNormForwardInferenceThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
BufferAllocation::Slice operand_;
@@ -90,7 +92,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
const CudnnBatchNormForwardTrainingThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
BufferAllocation::Slice operand_;
@@ -123,7 +126,8 @@ class CudnnBatchNormBackwardThunk : public Thunk {
delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
BufferAllocation::Slice operand_;
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index e14ee6918b..0cdddf8bcf 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -107,7 +108,8 @@ FftThunk::FftThunk(FftType fft_type,
output_shape_(output_shape) {}
Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
VLOG(3) << "Output shape: "
@@ -116,6 +118,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(),
buffer_allocations.memory_allocator());
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
if (fft_plan_ == nullptr) {
const int64 fft_rank = fft_length_.size();
CHECK_LE(fft_rank, 3);
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index b0a22564f3..8c53be5077 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
@@ -72,7 +73,8 @@ class FftThunk : public Thunk {
// Does the FFT for the thunk on "stream".
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const se::fft::Type fft_type_;
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index b36539e0cb..4fdc55909a 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -37,11 +38,15 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
}
Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
for (int64 i = 0; i < loop_limit_; ++i) {
+ profiler->StartHloComputation();
// Invoke loop body thunk sequence.
- TF_RETURN_IF_ERROR(
- body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
+ stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->while_body());
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h
index 41ddfe0ceb..c2d39071b2 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -39,7 +40,8 @@ class ForThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const int64 loop_limit_;
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 79fca43d02..dbc7754e25 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -252,7 +252,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
alpha_(alpha) {}
Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
VLOG(2) << "Executing a GemmThunk";
se::DeviceMemoryBase lhs_data =
@@ -352,6 +353,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
alpha_, stream);
};
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
bool launch_ok;
if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) {
launch_ok = launch(
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
index 7a4830d64e..939c7f85e3 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -48,7 +49,8 @@ class GemmThunk : public Thunk {
// Does the gemm operation for the thunk on "stream", which must be non-null.
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
// Returns true if we'll perform autotuning if run on the given stream. If
// so, we want the GPU to be quiescent during autotuning, so as not to
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index f20a828bc1..0cad2958c7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -136,18 +136,17 @@ Status GpuExecutable::ExecuteThunks(
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
}
- profiler.StartOperation();
VLOG(2) << "Executing the thunk for "
<< thunk->hlo_instruction()->ToString() << " on stream "
<< stream_no;
- TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
+ TF_RETURN_IF_ERROR(
+ thunk->ExecuteOnStream(buffer_allocations, stream, &profiler));
if (thunk_schedule_->Depended(thunk)) {
auto finish_event = MakeUnique<se::Event>(main_stream->parent());
finish_event->Init();
stream->ThenRecordEvent(finish_event.get());
thunk_to_finish_event[thunk] = std::move(finish_event);
}
- profiler.FinishOperation(thunk->hlo_instruction());
}
main_stream->ThenWaitFor(&sub_streams);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 8bf62dde8b..09ef62c87f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -51,7 +51,7 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
// H <=> Y
// W <=> X
//
- // Therefore kOutputInputYX means NHWC; kBatchDepthYX means NCHW.
+ // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
// As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x
// fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
index daddd3738e..3e96beb575 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include <memory>
+#include <stack>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -24,9 +25,30 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace xla {
namespace gpu {
+namespace {
+void InitAndStartTimer(std::stack<std::unique_ptr<se::Timer>>* timers,
+ se::Stream* stream) {
+ timers->push(MakeUnique<se::Timer>(stream->parent()));
+ stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get());
+}
+
+uint64 GetCyclesTaken(
+ std::stack<std::unique_ptr<se::Timer>>* timers,
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
+ se::Stream* stream, double clock_rate_ghz) {
+ CHECK_GT(timers->size(), 0);
+ stream->ThenWaitFor(&sub_streams);
+ stream->ThenStopTimer(timers->top().get());
+ stream->BlockHostUntilDone().IgnoreError();
+ double nanoseconds = timers->top()->Nanoseconds();
+ timers->pop();
+ return static_cast<uint64>(nanoseconds * clock_rate_ghz);
+}
+} // namespace
HloExecutionProfiler::HloExecutionProfiler(
bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
@@ -39,11 +61,7 @@ HloExecutionProfiler::HloExecutionProfiler(
computation_(computation) {
if (do_profile_) {
clock_rate_ghz_ = stream->parent()->GetDeviceDescription().clock_rate_ghz();
- execution_timer_.reset(new se::Timer(stream->parent()));
- per_op_timer_.reset(new se::Timer(stream->parent()));
- stream->InitTimer(execution_timer_.get())
- .ThenStartTimer(execution_timer_.get());
- stream->InitTimer(per_op_timer_.get());
+ InitAndStartTimer(&timers_, stream);
}
}
@@ -51,32 +69,47 @@ void HloExecutionProfiler::FinishExecution() {
CHECK(!finished_execution_) << "Call FinishExecution only once!";
finished_execution_ = true;
if (do_profile_) {
- stream_->ThenWaitFor(&sub_streams_);
- stream_->ThenStopTimer(execution_timer_.get());
- stream_->BlockHostUntilDone().IgnoreError();
profile_->set_total_cycles_executed(
*computation_,
- static_cast<uint64>(execution_timer_->Nanoseconds() * clock_rate_ghz_));
+ GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
+ }
+}
+
+void HloExecutionProfiler::StartHloComputation() {
+ if (do_profile_) {
+ InitAndStartTimer(&timers_, stream_);
}
}
-void HloExecutionProfiler::StartOperation() {
+void HloExecutionProfiler::FinishHloComputation(
+ const HloComputation* computation) {
if (do_profile_) {
- stream_->ThenStartTimer(per_op_timer_.get());
+ profile_->set_total_cycles_executed(
+ *computation,
+ GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
+ }
+}
+
+void HloExecutionProfiler::StartHloInstruction() {
+ if (do_profile_) {
+ InitAndStartTimer(&timers_, stream_);
}
}
-void HloExecutionProfiler::FinishOperation(
+void HloExecutionProfiler::FinishHloInstruction(
const HloInstruction* hlo_instruction) {
if (do_profile_) {
- stream_->ThenWaitFor(&sub_streams_);
- stream_->ThenStopTimer(per_op_timer_.get());
- stream_->BlockHostUntilDone().IgnoreError();
profile_->SetCyclesTakenBy(
hlo_instruction,
- static_cast<uint64>(per_op_timer_->Nanoseconds() * clock_rate_ghz_));
+ GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
}
}
+std::unique_ptr<ScopedInstructionProfiler>
+HloExecutionProfiler::MakeScopedInstructionProfiler(
+ const HloInstruction* hlo_instruction) {
+ return MakeUnique<ScopedInstructionProfiler>(this, hlo_instruction);
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
index c9b882ff80..e5c655edc6 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_
#include <memory>
+#include <stack>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -28,6 +29,8 @@ limitations under the License.
namespace xla {
namespace gpu {
+class ScopedInstructionProfiler;
+
// A helper class for profiling HLO in the course of GPU program execution.
// All of the profiling is guarded internally, to avoid the caller needing to
// have lots of conditionals sprinkled around.
@@ -43,12 +46,25 @@ class HloExecutionProfiler {
// execution timer.
void FinishExecution();
- // If profiling is enabled, starts the per-operation timer.
- void StartOperation();
+ // If profiling is enabled, starts a timer for a (sub)computation.
+ void StartHloComputation();
+
+ // If profiling is enabled stops the timer for a (sub)computation and records
+ // the time that the computation took to execute in the profile.
+ void FinishHloComputation(const HloComputation* computation);
+
+ // If profiling is enabled, starts a per-operation timer.
+ void StartHloInstruction();
// If profiling is enabled, stops the per-operation timer and records the time
// that the hlo_instruction took to execute in the profile.
- void FinishOperation(const HloInstruction* hlo_instruction);
+ void FinishHloInstruction(const HloInstruction* hlo_instruction);
+
+ // Returns a ScopedInstructionProfiler and triggers a call to
+ // StartHloInstruction(). Once the returned ScopedInstructionProfiler goes
+ // out of scope, it triggers a call to FinishHloInstruction().
+ std::unique_ptr<ScopedInstructionProfiler> MakeScopedInstructionProfiler(
+ const HloInstruction* hlo_instruction);
private:
const bool do_profile_;
@@ -57,11 +73,33 @@ class HloExecutionProfiler {
se::Stream* stream_;
const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams_;
const HloComputation* computation_;
- std::unique_ptr<se::Timer> execution_timer_;
- std::unique_ptr<se::Timer> per_op_timer_;
+ std::stack<std::unique_ptr<se::Timer>> timers_;
bool finished_execution_ = false;
};
+// This class can be used within the ExecuteOnStream() implementations of
+// Thunks. It ensures that we always have a pair of matching
+// StartHloInstruction() and FinishHloInstruction() calls to the profiler.
+class ScopedInstructionProfiler {
+ public:
+ ScopedInstructionProfiler(HloExecutionProfiler* profiler,
+ const HloInstruction* hlo_instruction)
+ : profiler_(profiler), hlo_instruction_(hlo_instruction) {
+ if (hlo_instruction != nullptr) {
+ profiler->StartHloInstruction();
+ }
+ }
+ ~ScopedInstructionProfiler() {
+ if (hlo_instruction_ != nullptr) {
+ profiler_->FinishHloInstruction(hlo_instruction_);
+ }
+ }
+
+ private:
+ HloExecutionProfiler* profiler_;
+ const HloInstruction* hlo_instruction_;
+};
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
index 2b63d8727c..62915febb1 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -27,9 +28,11 @@ InfeedThunk::InfeedThunk(
: Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
VLOG(2) << "Infeeding to GPU ";
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
// First copy the infeed data which is element 0 of the infeed instruction's
// two-tuple output (the other element is a token).
se::DeviceMemoryBase data_address =
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
index cb9a6232f3..59487e245b 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -40,7 +41,8 @@ class InfeedThunk : public Thunk {
InfeedThunk& operator=(const InfeedThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const ShapeTree<BufferAllocation::Slice> infeed_slices_;
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index f56c1ce69f..e76823ad10 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -75,7 +76,8 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) {
}
Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
// Load the kernel.
se::StreamExecutor* executor = stream->parent();
LaunchDimensions launch_dimensions;
@@ -100,6 +102,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " ("
<< buf.size() << "B)";
}
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
if (!stream->parent()->Launch(
stream, se::ThreadDim(launch_dimensions.threads_per_block()),
se::BlockDim(launch_dimensions.block_count()), *kernel,
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index 7def27e189..d751de50ad 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -62,7 +63,8 @@ class KernelThunk : public Thunk {
// Executes the kernel for the thunk on "stream", which must be non-null.
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
// Buffers passed to the kernel as arguments.
diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc
index d4100a898b..9fd6cf7157 100644
--- a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc
@@ -14,21 +14,27 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
+
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
namespace gpu {
Status MemzeroThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemZero(&dest_data, dest_data.size());
return Status::OK();
}
Status Memset32BitValueThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemset32(&dest_data, value_, dest_data.size());
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h
index 51c332d287..d1fec0bd76 100644
--- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status.h"
@@ -36,7 +37,8 @@ class MemzeroThunk : public Thunk {
: Thunk(Kind::kMemzero, hlo), dest_(dest) {}
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const BufferAllocation::Slice dest_;
@@ -52,7 +54,8 @@ class Memset32BitValueThunk : public Thunk {
: Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {}
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
uint32 value_;
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
index 88cb10883e..dfdba7d7d9 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -33,9 +34,17 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,
}
Status SequentialThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ // TODO(b/71544591): We need to potentially measure the total time of the
+ // sequential thunk. This happens for a reduce op which consists of
+ // SequentialThunk with a thunk that initializes the output, and another thunk
+ // that does the actual reduce. Right now, in this case we would only measure
+ // the time of the last thunk, because both thunks would have the same
+ // HloInstruction.
for (const auto& thunk : thunks_) {
- TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
+ TF_RETURN_IF_ERROR(
+ thunk->ExecuteOnStream(buffer_allocations, stream, profiler));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
index 135f79e413..3c4de1d1a6 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -41,7 +42,8 @@ class SequentialThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
// The list of sub-thunks.
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 931c0bffab..14d41033c2 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -94,11 +95,12 @@ class Thunk {
// Execute the kernel for the thunk on the given stream. This method must be
// called after Initialize and can be called multiple times over Thunk's
- // lifetime. Stream argument must be non-null.
+ // lifetime. 'stream' and 'profiler' must be non-null.
//
// Precondition: Initialize(stream->parent()) has been called.
virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) = 0;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) = 0;
private:
Kind kind_;
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
index 97cb04c38f..a10e40451c 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
@@ -15,13 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace gpu {
Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
std::vector<void*> tuple_element_buffer_addresses;
for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) {
tuple_element_buffer_addresses.push_back(
@@ -31,6 +33,7 @@ Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
buffer_allocations.GetDeviceAddress(dest_buffer_));
auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
if (!stream
->ThenMemcpy(&dest_buffer_address,
tuple_element_buffer_addresses.data(), host_size)
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
index 951f809b51..2d5735d6c4 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -46,7 +47,8 @@ class TupleThunk : public Thunk {
TupleThunk& operator=(const TupleThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const std::vector<BufferAllocation::Slice> tuple_element_buffers_;
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index 30b9640c4c..5e13f989c2 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -43,14 +44,18 @@ Status WhileThunk::Initialize(const GpuExecutable& executable,
}
Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase condition_result_data =
buffer_allocations.GetDeviceAddress(condition_result_buffer_index_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
while (true) {
// Invoke thunk sequence for while 'condition' computation.
- TF_RETURN_IF_ERROR(
- condition_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ profiler->StartHloComputation();
+ TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(
+ buffer_allocations, stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->while_condition());
// Copy the result of condition computation and break the loop if 'false'.
bool condition_result;
@@ -66,9 +71,14 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
break;
}
- // Invoke thunk sequence for while 'body' computation.
- TF_RETURN_IF_ERROR(
- body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ // We measure the time of one execution of the while body computation. The
+ // while body may be executed more than once, the last measurement "wins".
+ profiler->StartHloComputation();
+ // Invoke thunk sequence for while 'body' computation, and pass on
+ // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
+ stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->while_body());
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h
index 22176685a9..9270f95ee6 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -48,7 +49,8 @@ class WhileThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const BufferAllocation::Slice condition_result_buffer_index_;
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 2166c34358..56d24423c4 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -891,44 +891,51 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);
- auto invalid_argument =
- InvalidArgument("Shape %s size may overflow int64.",
- ShapeUtil::HumanString(shape).c_str());
+
if (!IsArray(shape)) {
return Status::OK();
}
- int64 shape_size;
- if (LayoutUtil::IsSparseArray(shape)) {
- shape_size = LayoutUtil::MaxSparseElements(shape.layout());
- if (shape_size < 0) {
- return invalid_argument;
- }
- shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape));
- if (shape_size < 0) {
- return invalid_argument;
+
+ int64 shape_size = [&shape]() {
+ int64 shape_size;
+ if (LayoutUtil::IsSparseArray(shape)) {
+ shape_size = LayoutUtil::MaxSparseElements(shape.layout());
+ if (shape_size < 0) {
+ return shape_size;
+ }
+ shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape));
+ if (shape_size < 0) {
+ return shape_size;
+ }
+ shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64));
+ if (shape_size < 0) {
+ return shape_size;
+ }
}
- shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64));
- if (shape_size < 0) {
- return invalid_argument;
+
+ shape_size = 1;
+
+ // This is intentionally unconditional: even if the shape is sparse, we want
+ // to verify the densified version has a reasonable size.
+ if (shape.dimensions().empty()) {
+ return shape_size;
}
- }
- // This is intentionally unconditional: even if the shape is sparse, we want
- // to verify the densified version has a reasonable size.
- if (shape.dimensions().empty()) {
- return Status::OK();
- }
- shape_size = 1;
- for (int64 dim : shape.dimensions()) {
- shape_size = MultiplyWithoutOverflow(shape_size, dim);
- if (shape_size < 0) {
- return invalid_argument;
+ for (int64 dim : shape.dimensions()) {
+ shape_size = MultiplyWithoutOverflow(shape_size, dim);
+ if (shape_size < 0) {
+ return shape_size;
+ }
}
- }
- shape_size = MultiplyWithoutOverflow(
- shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
+ shape_size = MultiplyWithoutOverflow(
+ shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
+
+ return shape_size;
+ }();
+
if (shape_size < 0) {
- return invalid_argument;
+ return InvalidArgument("Shape %s size may overflow int64.",
+ ShapeUtil::HumanString(shape).c_str());
}
VLOG(3) << "Shape size is valid: " << shape_size;
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 20b2885e90..77d398e5e2 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -886,6 +886,7 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index d9d7ba1362..217673c8cb 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
@@ -118,7 +119,7 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
XlaBuilder builder("square_tesseract_elementwise");
auto x = ConstantLiteral(&builder, input_literal_);
- SquareF32(x);
+ Square(x);
using tensorflow::MathUtil;
@@ -150,7 +151,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
auto activation_deviations = Sub(input_activations, set_means,
/*broadcast_dimensions=*/{1});
XlaComputation add = CreateScalarAddComputation(F32, &builder);
- auto dev_squares = SquareF32(activation_deviations);
+ auto dev_squares = Square(activation_deviations);
Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});
std::vector<float> expected = {18, 0.06};
@@ -160,7 +161,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
XlaBuilder builder("variance_to_stddev");
auto variance = ConstantR1<float>(&builder, {6.f, .02f});
- SqrtF32(variance);
+ Sqrt(variance);
std::vector<float> expected = {2.44948974f, 0.14142136f};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
@@ -195,20 +196,20 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
auto epsilon2 = ConstantR1<float>(&builder, {kEpsilon, kEpsilon});
auto activation_deviations = Sub(input_activations, set_means,
/*broadcast_dimensions=*/{1});
- auto dev_squares = SquareF32(activation_deviations);
+ auto dev_squares = Square(activation_deviations);
auto sum_of_squares =
CheckShape(&builder,
Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add,
/*dimensions_to_reduce=*/{0, 2, 3}),
TwoElementVectorF32);
auto variance = Div(sum_of_squares, count);
- auto standard_deviation = SqrtF32(variance);
+ auto standard_deviation = Sqrt(variance);
auto standard_deviation_above_epsilon =
CheckShape(&builder, Gt(standard_deviation, epsilon),
ShapeUtil::MakeShape(PRED, {2}));
auto gt_eps =
Select(standard_deviation_above_epsilon, standard_deviation, epsilon2);
- auto normalization_factors = ReciprocalF32(gt_eps);
+ auto normalization_factors = Reciprocal(gt_eps);
auto normalized_input_activations =
Mul(activation_deviations, normalization_factors,
/*broadcast_dimensions=*/{1});
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 5fdd1018a4..91aba9a8de 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -156,6 +156,86 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR1<float>(&b, {1, 2}),
+ ShapeUtil::MakeShape(F32, {2, 2}), {1});
+
+ Array2D<float> expected(2, 2);
+ expected(0, 0) = 1;
+ expected(0, 1) = 2;
+ expected(1, 0) = 1;
+ expected(1, 1) = 2;
+
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR1<float>(&b, {1, 2}),
+ ShapeUtil::MakeShape(F32, {2, 2}), {0});
+
+ Array2D<float> expected(2, 2);
+ expected(0, 0) = 1;
+ expected(0, 1) = 1;
+ expected(1, 0) = 2;
+ expected(1, 1) = 2;
+
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}),
+ ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1});
+
+ Array3D<float> expected(2, 2, 2);
+ expected(0, 0, 0) = 1.0;
+ expected(1, 0, 0) = 2.0;
+ expected(0, 0, 1) = 1.0;
+ expected(1, 0, 1) = 2.0;
+ expected(0, 1, 0) = 5.0;
+ expected(1, 1, 0) = 6.0;
+ expected(1, 1, 1) = 6.0;
+ expected(0, 1, 1) = 5.0;
+
+ ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}),
+ ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2});
+
+ Array3D<float> expected(2, 2, 2);
+ expected(0, 0, 0) = 1.0;
+ expected(1, 0, 0) = 2.0;
+ expected(0, 0, 1) = 5.0;
+ expected(1, 0, 1) = 6.0;
+ expected(0, 1, 0) = 1.0;
+ expected(1, 1, 0) = 2.0;
+ expected(1, 1, 1) = 6.0;
+ expected(0, 1, 1) = 5.0;
+
+ ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR1<float>(&b, {1, 2}),
+ ShapeUtil::MakeShape(F32, {3, 2}), {1});
+
+ Array2D<float> expected(3, 2);
+ expected(0, 0) = 1;
+ expected(0, 1) = 2;
+ expected(1, 0) = 1;
+ expected(1, 1) = 2;
+ expected(2, 0) = 1;
+ expected(2, 1) = 2;
+
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
// Tests implicit broadcasting of PREDs.
XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
XlaBuilder b(TestName());
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index bc994315c3..3afd8c8fc8 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -897,18 +897,6 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
ComputeAndCompareR0<int32>(&b, 10, {});
}
-XLA_TEST_F(ScalarComputationsTest, SqrtF320) {
- XlaBuilder builder(TestName());
- Literal zero_literal = Literal::Zero(PrimitiveType::F32);
-
- std::unique_ptr<GlobalData> zero_data =
- client_->TransferToServer(zero_literal).ConsumeValueOrDie();
-
- XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero");
- SqrtF32(zero);
-
- ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
-}
XLA_TEST_F(ScalarComputationsTest, RoundScalar) {
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index c11df7cdf5..79bae22dac 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -135,46 +135,6 @@ XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) {
ComputeAndCompareR1<uint32>(&builder, expected, {});
}
-XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) {
- XlaBuilder builder(TestName());
- auto x = ConstantR1<float>(
- &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- SquareF32(x);
-
- std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41,
- 5.29, 25., 0.81, 5.76, 2.56};
- ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
-}
-
-XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
- XlaBuilder builder(TestName());
- auto x = ConstantR1<float>(
- &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- ReciprocalF32(x);
-
- std::vector<float> expected = {
- 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048,
- 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625};
- ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
-}
-
-XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) {
- XlaBuilder builder(TestName());
- auto x = ConstantR1<float>(&builder, {0.0, -0.0});
- SqrtF32(x);
-
- ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_);
-}
-
-XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) {
- XlaBuilder builder(TestName());
- auto x = ConstantR1<float>(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345});
- SqrtF32(x);
-
- std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080};
- ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
-}
-
XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(&builder,
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index c0616809f9..7dba058d40 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -240,9 +240,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
EXPECT_TRUE(HasTrops(tanh_profile));
}
-// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo
-// instructions "interior" to while nodes.
-XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
+XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
const int64 size = 256;
Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size});
Shape while_result_shape =
@@ -337,8 +335,11 @@ static std::pair<int, char**> AddXlaHloProfileFlag(int argc, char** argv) {
new_argv[argc] = strdup("--xla_hlo_profile");
// Fusion can change the Hlo instructions that show up in the final Hlo
- // executable, so block it here.
- new_argv[argc + 1] = strdup("--xla_disable_hlo_passes=fusion");
+ // executable, so block it here. Also block the WhileLoopInvariantCodeMotion
+ // pass, otherwise a while loop is transformed and we could not match the
+ // original name in the ProfileWhileComputation test.
+ new_argv[argc + 1] = strdup(
+ "--xla_disable_hlo_passes=fusion,while-loop-invariant-code-motion");
return {argc + 2, new_argv};
}
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index fa69efa3f6..717533a13f 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -134,6 +134,8 @@ py_library(
"//tensorflow/contrib/bigtable",
"//tensorflow/contrib/cloud:cloud_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
+ "//tensorflow/contrib/cloud:cloud_py", # depends on bigtable
+ "//tensorflow/contrib/bigtable", # doesn't compile on Windows
"//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code
]),
)
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 8c1ce5c2a2..2fbaa31d5e 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -44,8 +44,8 @@ from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import Checkpointa
from tensorflow.python.training.checkpointable.base import CheckpointableBase
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping
+from tensorflow.python.training.checkpointable.data_structures import NoDependency
from tensorflow.python.training.checkpointable.tracking import Checkpointable
-from tensorflow.python.training.checkpointable.tracking import NoDependency
from tensorflow.python.training.checkpointable.util import capture_dependencies
from tensorflow.python.training.checkpointable.util import list_objects
from tensorflow.python.training.checkpointable.util import object_metadata
diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py
index 64d056bd68..ac85c7be80 100644
--- a/tensorflow/contrib/checkpoint/python/containers_test.py
+++ b/tensorflow/contrib/checkpoint/python/containers_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
@@ -79,7 +80,7 @@ class UniqueNameTrackerTests(test.TestCase):
resource_variable_ops.ResourceVariable(4.), "y"))
slots.append(slotdeps.track(
resource_variable_ops.ResourceVariable(5.), "x"))
- self.slots = slots
+ self.slots = data_structures.NoDependency(slots)
manager = SlotManager()
self.evaluate([v.initializer for v in manager.slots])
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 5fc7e51d81..2022c1f2bd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -616,7 +616,44 @@ class BucketBySequenceLength(test.TestCase):
batch_sizes = batch_sizes[:-1]
self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(boundaries), sorted(lengths_val))
+ self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
+ sorted(lengths_val))
+
+ def testPadToBoundaryNoExtraneousPadding(self):
+
+ boundaries = [3, 7, 11]
+ batch_sizes = [2, 2, 2, 2]
+ lengths = range(1, 11)
+
+ def element_gen():
+ for length in lengths:
+ yield ([1] * length,)
+
+ element_len = lambda element: array_ops.shape(element)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.test_session() as sess:
+ batches = []
+ for _ in range(5):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+
+ self.assertAllEqual(batches[0], [[1, 0],
+ [1, 1]])
+ self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1]])
+ self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
def testTupleElements(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 20ed639750..40a8e46676 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -237,9 +237,9 @@ class PrefetchingKernelsOpsTest(test.TestCase):
buffer_resource_handle, ignore_lookup_error=True)
with self.test_session() as sess:
- self.assertEqual(["a"], sess.run(prefetch_op))
- self.assertEqual(["b"], sess.run(prefetch_op))
- self.assertEqual(["c"], sess.run(prefetch_op))
+ self.assertEqual([b"a"], sess.run(prefetch_op))
+ self.assertEqual([b"b"], sess.run(prefetch_op))
+ self.assertEqual([b"c"], sess.run(prefetch_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index ca9540bf13..5d9640a768 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -149,9 +149,9 @@ def bucket_by_sequence_length(element_length_func,
@{tf.data.Dataset.padded_batch}. Defaults to padding with 0.
pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
size to maximum length in batch. If `True`, will pad dimensions with
- unknown size to bucket boundary, and caller must ensure that the source
- `Dataset` does not contain any elements with length longer than
- `max(bucket_boundaries)`.
+ unknown size to bucket boundary minus 1 (i.e., the maximum length in each
+ bucket), and caller must ensure that the source `Dataset` does not contain
+ any elements with length longer than `max(bucket_boundaries)`.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -203,7 +203,7 @@ def bucket_by_sequence_length(element_length_func,
none_filler = None
if pad_to_bucket_boundary:
err_msg = ("When pad_to_bucket_boundary=True, elements must have "
- "length <= max(bucket_boundaries).")
+ "length < max(bucket_boundaries).")
check = check_ops.assert_less(
bucket_id,
constant_op.constant(len(bucket_batch_sizes) - 1,
@@ -213,7 +213,7 @@ def bucket_by_sequence_length(element_length_func,
boundaries = constant_op.constant(bucket_boundaries,
dtype=dtypes.int64)
bucket_boundary = boundaries[bucket_id]
- none_filler = bucket_boundary
+ none_filler = bucket_boundary - 1
shapes = make_padded_shapes(
padded_shapes or grouped_dataset.output_shapes,
none_filler=none_filler)
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 0261ce43fa..06555c6760 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_util
@@ -88,7 +89,7 @@ def _simple_broadcast(value, destinations):
def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
- method_string):
+ aggregation):
# pylint: disable=g-missing-docstring
all_values = []
count = 0
@@ -112,11 +113,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
all_values, accumulation_fn)
- if method_string == "mean":
+ if aggregation == vs.VariableAggregation.MEAN:
reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
reduced, count)
- elif method_string != "sum":
- raise ValueError("`method_string` must be 'sum' or 'mean'")
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise ValueError("`aggregation` must be `sum`(VariableAggregation.SUM) "
+ "or `mean`(VariableAggregation.MEAN).")
return reduced
@@ -126,14 +128,15 @@ class CrossTowerOps(object):
def __init__(self):
pass
- def reduce(self, method_string, per_device_value, destinations=None):
+ def reduce(self, aggregation, per_device_value, destinations=None):
"""Reduce `per_device_value` to `destinations`.
- It runs the reduction operation defined by `method_string` and put the
+ It runs the reduction operation defined by `aggregation` and put the
result on `destinations`.
Args:
- method_string: either 'sum' or 'mean' specifying the reduction method.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
per_device_value: a PerDevice object.
destinations: the reduction destinations.
@@ -147,16 +150,17 @@ class CrossTowerOps(object):
raise ValueError("`per_device_value` must be a `PerDevice` object.")
if destinations is not None:
validate_destinations(destinations)
- return self._reduce(method_string, per_device_value, destinations)
+ return self._reduce(aggregation, per_device_value, destinations)
- def batch_reduce(self, method_string, value_destination_pairs):
+ def batch_reduce(self, aggregation, value_destination_pairs):
"""Reduce PerDevice objects in a batch.
Reduce each first element in `value_destination_pairs` to each second
element which indicates the destinations.
Args:
- method_string: either 'sum' or 'mean' specifying the reduction method.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value_destination_pairs: a list or a tuple of tuples of PerDevice objects
and destinations. If a destination is None, then the destinations
are set to match the devices of the input PerDevice object.
@@ -175,7 +179,7 @@ class CrossTowerOps(object):
if d is not None:
validate_destinations(d)
- return self._batch_reduce(method_string, value_destination_pairs)
+ return self._batch_reduce(aggregation, value_destination_pairs)
def broadcast(self, tensor, destinations):
"""Broadcast the `tensor` to destinations.
@@ -190,11 +194,11 @@ class CrossTowerOps(object):
validate_destinations(destinations)
return self._broadcast(tensor, destinations)
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
raise NotImplementedError(
"_reduce method must be implemented in descendants.")
- def _batch_reduce(self, method_string, value_destination_pairs):
+ def _batch_reduce(self, aggregation, value_destination_pairs):
raise NotImplementedError(
"_batch_reduce method must be implemented in descendants.")
@@ -220,16 +224,18 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
self.accumulation_fn = accumulation_fn
super(ReductionToOneDeviceCrossTowerOps, self).__init__()
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
devices = get_devices_from(destinations or per_device_value)
reduce_to_device = self.reduce_to_device or devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
- self.accumulation_fn, method_string)
+ self.accumulation_fn, aggregation)
return self.broadcast(reduced, devices)
- def _batch_reduce(self, method_string, value_destination_pairs):
- return [self._reduce(method_string, t, destinations=v)
- for t, v in value_destination_pairs]
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self._reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def _group_value_by_device(per_device_values):
@@ -260,18 +266,19 @@ def _group_value_by_device(per_device_values):
return grouped
-def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
+def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
"""Ungroup results from all-reduce and make Mirrored objects.
Each all-reduce result will be divided by the number of destinations before
- Mirrored objects are created if method_string is "mean".
+ Mirrored objects are created if aggregation is "mean".
Args:
grouped_reduced: a list of lists, each sublist has components for each
device, paired with a None. It is the result from
cross_tower_utils.aggregate_gradients_using*.
destinations: a list of device strings for returned Mirrored objects.
- method_string: "mean" or "sum".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
Returns:
a list of Mirrored objects.
@@ -279,7 +286,7 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
index = [{} for _ in range(len(grouped_reduced[0]))]
for d, per_device_reduced in enumerate(grouped_reduced):
for i, (v, _) in enumerate(per_device_reduced):
- if method_string == "mean":
+ if aggregation == vs.VariableAggregation.MEAN:
index[i][destinations[d]] = v / len(destinations)
else:
index[i][destinations[d]] = v
@@ -488,13 +495,13 @@ class AllReduceCrossTowerOps(CrossTowerOps):
self._agg_small_grads_max_group = agg_small_grads_max_group
super(AllReduceCrossTowerOps, self).__init__()
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
per_device_value)
if ((destinations is None or _devices_match(per_device_value, destinations))
and not context.executing_eagerly()
and not contains_indexed_slices):
- return self._batch_all_reduce(method_string, [per_device_value])[0]
+ return self._batch_all_reduce(aggregation, [per_device_value])[0]
else:
if contains_indexed_slices:
logging.log_first_n(
@@ -504,16 +511,16 @@ class AllReduceCrossTowerOps(CrossTowerOps):
devices = get_devices_from(destinations or per_device_value)
reduce_to_device = devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
- math_ops.add_n, method_string)
+ math_ops.add_n, aggregation)
return self.broadcast(reduced, devices)
- def _batch_reduce(self, method_string, value_destination_pairs):
+ def _batch_reduce(self, aggregation, value_destination_pairs):
all_devices_match = _all_devices_match(value_destination_pairs)
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
value_destination_pairs)
if (all_devices_match and not context.executing_eagerly()
and not contains_indexed_slices):
- return self._batch_all_reduce(method_string,
+ return self._batch_all_reduce(aggregation,
[v[0] for v in value_destination_pairs])
else:
if not all_devices_match:
@@ -521,11 +528,11 @@ class AllReduceCrossTowerOps(CrossTowerOps):
"destinations are different.")
return [
- self._reduce(method_string, t, destinations=v)
+ self._reduce(aggregation, t, destinations=v)
for t, v in value_destination_pairs
]
- def _batch_all_reduce(self, method_string, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_device_values):
"""All reduce algorithm in a batch."""
logging.info(
"batch_all_reduce invoked for batches size = %d with "
@@ -556,7 +563,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
reduced = _unpack_tensors(reduced, tensor_packer)
return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
- method_string)
+ aggregation)
AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
@@ -635,7 +642,7 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
validate_and_complete_spec(spec) for spec in all_reduce_spec
]
- def _batch_all_reduce(self, method_string, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_device_values):
"""All reduce algorithm in a batch."""
logging.info(
"distributed batch_all_reduce invoked for batches size = %d with "
@@ -682,7 +689,7 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
assert not remaining_grads
return _ungroup_and_make_mirrored(aggregated_grads, destinations,
- method_string)
+ aggregation)
_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index c540ea0d23..6a780ff60f 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
@@ -129,32 +130,45 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
# test reduce()
for destinations in all_destinations:
self._assert_values_equal(
- cross_tower_ops.reduce("mean", per_device, destinations=destinations),
+ cross_tower_ops.reduce(
+ vs.VariableAggregation.MEAN,
+ per_device,
+ destinations=destinations),
_fake_mirrored(mean, destinations or per_device))
self._assert_values_equal(
cross_tower_ops.reduce(
- "mean", per_device_2, destinations=destinations),
+ vs.VariableAggregation.MEAN,
+ per_device_2,
+ destinations=destinations),
_fake_mirrored(mean_2, destinations or per_device))
self._assert_values_equal(
- cross_tower_ops.reduce("sum", per_device, destinations=destinations),
+ cross_tower_ops.reduce(
+ vs.VariableAggregation.SUM, per_device,
+ destinations=destinations),
_fake_mirrored(mean * len(devices), destinations or per_device))
self._assert_values_equal(
cross_tower_ops.reduce(
- "sum", per_device_2, destinations=destinations),
+ vs.VariableAggregation.SUM,
+ per_device_2,
+ destinations=destinations),
_fake_mirrored(mean_2 * len(devices), destinations or per_device))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
self._assert_values_equal(
- cross_tower_ops.batch_reduce(
- "mean", [(per_device, d1), (per_device_2, d2)]),
- [_fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)])
+ cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN,
+ [(per_device, d1), (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean, d1 or per_device),
+ _fake_mirrored(mean_2, d2 or per_device_2)
+ ])
self._assert_values_equal(
- cross_tower_ops.batch_reduce(
- "sum", [(per_device, d1), (per_device_2, d2)]),
- [_fake_mirrored(mean * len(devices), d1 or per_device),
- _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)])
+ cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM,
+ [(per_device, d1), (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean * len(devices), d1 or per_device),
+ _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)
+ ])
# test broadcast()
for destinations in all_destinations:
@@ -255,8 +269,8 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
- result = cross_tower_ops_lib._simple_reduce(per_device, devices[0],
- math_ops.add_n, "sum")
+ result = cross_tower_ops_lib._simple_reduce(
+ per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM)
# Test that the result is semantically equal to both the concatenated
# IndexedSlices with and without duplicate indices.
@@ -267,21 +281,22 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
self._assert_indexed_slices_equal(total_with_dups, result)
self._assert_indexed_slices_equal(total_without_dups, result)
- @combinations.generate(combinations.combine(
- cross_tower_ops_instance=[
- combinations.NamedObject(
- "ReductionToOneDeviceCrossTowerOps",
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
- combinations.NamedObject(
- "AllReduceCrossTowerOps",
- cross_tower_ops_lib.AllReduceCrossTowerOps())
- ],
- method_string=["sum", "mean"],
- batch_reduce=[True, False],
- mode=["graph", "eager"],
- required_gpus=1))
- def testIndexedSlicesAllReduce(self, cross_tower_ops_instance,
- method_string, batch_reduce):
+ @combinations.generate(
+ combinations.combine(
+ cross_tower_ops_instance=[
+ combinations.NamedObject(
+ "ReductionToOneDeviceCrossTowerOps",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
+ combinations.NamedObject(
+ "AllReduceCrossTowerOps",
+ cross_tower_ops_lib.AllReduceCrossTowerOps())
+ ],
+ aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN],
+ batch_reduce=[True, False],
+ mode=["graph", "eager"],
+ required_gpus=1))
+ def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation,
+ batch_reduce):
devices = ["/cpu:0", "/gpu:0"]
dense_shape = [5, 2]
t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
@@ -290,20 +305,19 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
if batch_reduce:
- result = cross_tower_ops_instance.batch_reduce(method_string,
+ result = cross_tower_ops_instance.batch_reduce(aggregation,
[(per_device, devices)])
else:
- result = cross_tower_ops_instance.reduce(method_string, per_device,
- devices)
+ result = cross_tower_ops_instance.reduce(aggregation, per_device, devices)
total_indices_with_dups = [1, 1, 3]
total_indices_without_dups = [1, 3]
- if method_string == "sum":
+ if aggregation == vs.VariableAggregation.SUM:
total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
total_values_without_dups = [[4., 6.], [5., 6.]]
else:
- assert method_string == "mean"
+ assert aggregation == vs.VariableAggregation.MEAN
total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
total_values_without_dups = [[2., 3.], [2.5, 3.]]
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index d269bed1e5..14c02ab1ad 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -104,9 +104,32 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
colocate_with = kwargs.pop("colocate_with", None)
devices = self._get_devices_from(colocate_with)
- tower_local = kwargs.pop("tower_local_reduce_method", None)
- if tower_local is not None:
+ # Get synchronization value
+ synchronization = kwargs.get(
+ "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
+ if synchronization == variable_scope.VariableSynchronization.NONE:
+ raise ValueError("`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please"
+ " change the `synchronization` for variable: " +
+ kwargs["name"])
+ elif synchronization == variable_scope.VariableSynchronization.ON_READ:
+ # Variables that are to be synced on read are tower local.
+ is_tower_local = True
kwargs["trainable"] = False
+ elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
+ synchronization == variable_scope.VariableSynchronization.AUTO):
+ # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
+ is_tower_local = False
+ else:
+ raise ValueError("Invalid variable synchronization mode: " +
+ synchronization + " for variable: " + kwargs["name"])
+
+ # Get aggregation value
+ aggregation = kwargs.pop("aggregation",
+ variable_scope.VariableAggregation.NONE)
+ if aggregation not in [a for a in variable_scope.VariableAggregation]:
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
@@ -139,11 +162,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
assert not isinstance(v, values.DistributedVariable)
index[d] = v
- if tower_local is None:
- result = values.MirroredVariable(index, index[devices[0]])
+ if is_tower_local:
+ result = values.TowerLocalVariable(index, index[devices[0]],
+ aggregation)
else:
- result = values.TowerLocalVariable(
- index, index[devices[0]], tower_local)
+ result = values.MirroredVariable(index, index[devices[0]], aggregation)
if not context.executing_eagerly():
g = ops.get_default_graph()
@@ -308,12 +331,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps())
return self._cross_tower_ops
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
assert not isinstance(value, values.Mirrored)
if not isinstance(value, values.PerDevice):
if value == 0:
return 0
- if method_string == "mean":
+ if aggregation == variable_scope.VariableAggregation.MEAN:
return self._broadcast(value, destinations)
cross_tower_ops_lib.validate_destinations(destinations)
@@ -331,13 +354,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
value_updates[d] = array_ops.identity(value)
return values.Mirrored(value_updates)
raise ValueError("A non PerDevice value cannot be reduced with the given "
- "method_string.")
+ "aggregation.")
return self._get_cross_tower_ops().reduce(
- method_string, value, destinations=destinations)
+ aggregation, value, destinations=destinations)
- def _batch_reduce(self, method_string, value_destination_pairs):
- return self._get_cross_tower_ops().batch_reduce(method_string,
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
def _update(self, var, fn, *args, **kwargs):
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 8d474124b7..c02817f461 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -114,7 +114,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
dist = self._get_distribution_strategy()
with dist.scope():
result = dist.call_for_each_tower(run_fn, dist.worker_device_index)
- reduced = dist.reduce("sum", result, destinations="/device:CPU:0")
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.SUM,
+ result,
+ destinations="/device:CPU:0")
unwrapped = dist.unwrap(reduced)
self.assertEqual(1, len(unwrapped))
expected = sum(range(len(dist.worker_devices)))
@@ -132,8 +135,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
dist = mirrored_strategy.MirroredStrategy(devices)
with dist.scope():
- reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0",
- "/device:GPU:0"])
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.SUM,
+ 1.0,
+ destinations=["/device:CPU:0", "/device:GPU:0"])
unwrapped = dist.unwrap(reduced)
self.assertEqual(2, len(unwrapped))
self.assertEqual(1.0, self.evaluate(unwrapped[0]))
@@ -284,18 +289,68 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("common/dense" + suffix + "/bias:0", bias.name)
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testWithVariableAndVariableScope(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn():
+ v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
+ with variable_scope.variable_scope("common"):
+ v1 = variable_scope.variable(1.0, name="var1")
+ # This will pause the current thread, and execute the other thread.
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ v2 = variable_scope.variable(
+ 1.0,
+ name="var2",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v3 = variable_scope.variable(
+ 1.0,
+ name="var3",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
+
+ return v0, v1, v2, v3
+
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ v = variable_scope.variable(1.0, name="var-main0")
+ self.assertEquals("var-main0:0", v.name)
+
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertEquals(4, len(result))
+ v0, v1, v2, v3 = result
+ self.assertIsInstance(v0, values.MirroredVariable)
+ self.assertEquals("var0:0", v0.name)
+ self.assertIsInstance(v1, values.MirroredVariable)
+ self.assertEquals("common/var1:0", v1.name)
+ self.assertIsInstance(v2, values.TowerLocalVariable)
+ self.assertEquals("common/var2:0", v2.name)
+ self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation)
+ self.assertIsInstance(v3, values.MirroredVariable)
+ self.assertEquals("common/var3:0", v3.name)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testWithGetVariableAndVariableScope(self):
self._skip_eager_if_gpus_less_than(1)
def model_fn():
- v0 = variable_scope.get_variable("var-thread0", [1])
+ v0 = variable_scope.get_variable("var0", [1])
with variable_scope.variable_scope("common"):
- v1 = variable_scope.get_variable("var-thread1", [1])
+ v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
distribute_lib.get_tower_context().merge_call(lambda _: _)
- v2 = variable_scope.get_variable("var-thread2", [1])
+ v2 = variable_scope.get_variable(
+ "var2", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v3 = variable_scope.get_variable(
+ "var3", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
- return v0, v1, v2
+ return v0, v1, v2, v3
devices = ["/device:CPU:0", "/device:GPU:0"]
dist = mirrored_strategy.MirroredStrategy(devices)
@@ -305,14 +360,78 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("main/var-main0:0", v.name)
result = dist.call_for_each_tower(model_fn, run_concurrently=False)
- self.assertEquals(3, len(result))
- v0, v1, v2 = result
+ self.assertEquals(4, len(result))
+ v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
- self.assertEquals("main/var-thread0:0", v0.name)
+ self.assertEquals("main/var0:0", v0.name)
self.assertIsInstance(v1, values.MirroredVariable)
- self.assertEquals("main/common/var-thread1:0", v1.name)
- self.assertIsInstance(v2, values.MirroredVariable)
- self.assertEquals("main/common/var-thread2:0", v2.name)
+ self.assertEquals("main/common/var1:0", v1.name)
+ self.assertIsInstance(v2, values.TowerLocalVariable)
+ self.assertEquals("main/common/var2:0", v2.name)
+ self.assertEquals(variable_scope.VariableAggregation.SUM,
+ v2.aggregation)
+ self.assertIsInstance(v3, values.MirroredVariable)
+ self.assertEquals("main/common/var3:0", v3.name)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN,
+ v3.aggregation)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidSynchronizationWithGetVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
+ variable_scope.get_variable(
+ "v", [1],
+ synchronization=variable_scope.VariableSynchronization.NONE)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidSynchronizationWithVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
+ variable_scope.variable(
+ 1.0,
+ name="v",
+ synchronization=variable_scope.VariableSynchronization.NONE)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidAggregationWithGetVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "Invalid variable aggregation mode: invalid for "
+ "variable: v"):
+ variable_scope.get_variable(
+ "v", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation="invalid")
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidAggregationWithVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "Invalid variable aggregation mode: invalid for "
+ "variable: v"):
+ variable_scope.variable(
+ 1.0,
+ name="v",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation="invalid")
@test_util.run_in_graph_and_eager_modes(config=config)
def testThreeDevices(self):
@@ -362,9 +481,11 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(device_id):
tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope("sum"):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.SUM):
v_sum = variable_scope.variable(1.0)
- with tower_context.tower_local_var_scope("mean"):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.MEAN):
v_mean = variable_scope.variable(4.0)
self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
self.assertTrue(isinstance(v_mean, values.TowerLocalVariable))
@@ -569,7 +690,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope("sum"):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.SUM):
v_sum = variable_scope.variable(1.0)
self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
return v_sum
@@ -642,7 +764,8 @@ class MirroredVariableUpdateTest(test.TestCase):
# aggregation type.
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- v = variable_scope.variable(1.0, name="foo")
+ v = variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -650,9 +773,6 @@ class MirroredVariableUpdateTest(test.TestCase):
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "sum"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
@@ -661,7 +781,7 @@ class MirroredVariableUpdateTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, "A non PerDevice value cannot be reduced with the given "
- "method_string."):
+ "aggregation."):
self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
@test_util.run_in_graph_and_eager_modes(config=config)
@@ -685,16 +805,14 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignMirroredVarTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- return variable_scope.variable(1.0, name="foo")
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "mean"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
@@ -729,16 +847,14 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignAddMirroredVarTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- return variable_scope.variable(1.0, name="foo")
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "mean"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
@@ -773,16 +889,14 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignSubMirroredVarTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- return variable_scope.variable(5.0, name="foo")
+ return variable_scope.variable(
+ 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "mean"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(5.0, self.evaluate(mirrored_var))
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index a580dac96c..dbd3514aec 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import values
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import distribute as distribute_lib
@@ -43,11 +44,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
self._default_device = device
def _create_variable(self, next_creator, *args, **kwargs):
- # No need to distinguish tower-local variables when not mirroring,
- # we just enforce that they are not trainable.
- if kwargs.pop("tower_local_reduce_method", None) is not None:
- kwargs["trainable"] = False
-
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
with ops.device(self._device):
@@ -80,15 +76,15 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
with ops.device(self._device):
return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
if not isinstance(value, values.MapOutput):
return value
l = value.get()
assert l
with ops.device(self._device):
- if method_string == "sum":
+ if aggregation == vs.VariableAggregation.SUM:
return math_ops.add_n(l)
- elif method_string == "mean":
+ elif aggregation == vs.VariableAggregation.MEAN:
return math_ops.add_n(l) / len(l)
else:
assert False
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index d2fe8b3b1e..baed0ebaae 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer
@@ -110,7 +111,8 @@ class DistributionTestBase(test.TestCase):
before_list.append(fetched)
# control_dependencies irrelevant but harmless in eager execution
with ops.control_dependencies([fetched]):
- g = d.reduce("sum", g, destinations=v)
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -162,7 +164,8 @@ class DistributionTestBase(test.TestCase):
fetched = d.read_var(v)
before_list.append(fetched)
with ops.control_dependencies([fetched]):
- g = d.reduce("sum", g, destinations=v)
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -184,7 +187,7 @@ class DistributionTestBase(test.TestCase):
with d.scope():
map_in = [constant_op.constant(i) for i in range(10)]
map_out = d.map(map_in, lambda x, y: x * y, 2)
- observed = d.reduce("sum", map_out)
+ observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out)
expected = 90 # 2 * (0 + 1 + ... + 9)
self.assertEqual(expected, observed.numpy())
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 1ae12ae98a..bc53898539 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
@@ -137,9 +138,9 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def get_finalize_ops(self):
return [tpu.shutdown_system()]
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
del destinations # TPU is graph mode only. Rely on implicit Send/Recv.
- if method_string == 'mean':
+ if aggregation == vs.VariableAggregation.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
value *= (1. / self._num_cores_per_host)
return tpu_ops.cross_replica_sum(value)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 95390041f4..b36ac563d2 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
@@ -290,13 +291,13 @@ class MirroredVariable(DistributedVariable, Mirrored,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are kept in sync."""
- def __init__(self, index, primary_var, aggregation_method=None):
+ def __init__(self, index, primary_var, aggregation):
# Use a weakref to make it easy to map from the contained values
# to the container without introducing a reference cycle.
for v in six.itervalues(index):
v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
- self._aggregation_method = aggregation_method
+ self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
# The arguments to update() are automatically unwrapped so the update()
@@ -325,17 +326,16 @@ class MirroredVariable(DistributedVariable, Mirrored,
# handle the different use cases can be found in the _reduce method.
# We call the function on each of the mirrored variables with the reduced
# value.
- if not self._aggregation_method:
+ if self._aggregation == vs.VariableAggregation.NONE:
raise ValueError("You must specify an aggregation method to update a "
"MirroredVariable in Tower Context.")
def merge_fn(strategy, value):
- return strategy.update(self,
- f,
- strategy.reduce(
- method_string=self._aggregation_method,
- value=value,
- destinations=self))
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self))
+
return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
**kwargs)
@@ -348,6 +348,10 @@ class MirroredVariable(DistributedVariable, Mirrored,
def assign(self, *args, **kwargs):
return self._assign_func(f=state_ops.assign, *args, **kwargs)
+ @property
+ def aggregation(self):
+ return self._aggregation
+
def _get_cross_tower(self):
device = device_util.canonicalize(device_util.current())
if device in self._index:
@@ -411,7 +415,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
- if self._tower_local_variable.reduce_method == "sum":
+ if self._tower_local_variable.aggregation == vs.VariableAggregation.SUM:
tensor *= 1. / len(self._tower_local_variable.devices)
return control_flow_ops.group([
_assign_on_device(d, v, tensor)
@@ -428,9 +432,9 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are reduced on save."""
- def __init__(self, index, primary_var, reduce_method):
+ def __init__(self, index, primary_var, aggregation):
self._primary_var = primary_var
- self._reduce_method = reduce_method
+ self._aggregation = aggregation
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
@@ -446,14 +450,14 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self.get().assign(*args, **kwargs)
@property
- def reduce_method(self):
- return self._reduce_method
+ def aggregation(self):
+ return self._aggregation
def _get_cross_tower(self):
all_components = tuple(self._index.values())
# TODO(josh11b): Use a strategy-specific method.
total = math_ops.add_n(all_components)
- if self._reduce_method == "mean":
+ if self._aggregation == vs.VariableAggregation.MEAN:
return total * (1./ len(all_components))
return total
@@ -929,4 +933,3 @@ class MultiStepContext(object):
assert o.dtype == i.dtype, (
"Dtype {} of left {} doesn't match dtype {} of right {}.".
format(o.dtype, o, i.dtype, i))
-
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index c5b246e804..8e44f2fea1 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -158,7 +158,8 @@ def _make_mirrored():
v.append(variable_scope.get_variable(
name=n, initializer=init, use_resource=True))
index[d] = v[-1]
- mirrored = values.MirroredVariable(index, v[0])
+ mirrored = values.MirroredVariable(index, v[0],
+ variable_scope.VariableAggregation.SUM)
return v, devices, mirrored
@@ -277,7 +278,8 @@ class RegroupAndSelectDeviceTest(test.TestCase):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
index = {d: v}
- mirrored = values.MirroredVariable(index, v)
+ mirrored = values.MirroredVariable(index, v,
+ variable_scope.VariableAggregation.SUM)
result = values.regroup(index)
self.assertIs(mirrored, result)
@@ -581,7 +583,8 @@ class MirroredVariableTest(test.TestCase):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
index = {"/job:foo/device:CPU:0": v}
- mirrored = values.MirroredVariable(index, v)
+ mirrored = values.MirroredVariable(index, v,
+ variable_scope.VariableAggregation.MEAN)
self.assertEquals(v.name, mirrored.name)
self.assertEquals(v.dtype, mirrored.dtype)
@@ -716,7 +719,9 @@ class MirroredVariableTest(test.TestCase):
with ops.device("/device:GPU:0"):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
- mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
+ mirrored = values.MirroredVariable({
+ "/device:GPU:0": v
+ }, v, variable_scope.VariableAggregation.MEAN)
sess.run(variables_lib.global_variables_initializer())
sess.run({"complicated": mirrored})
@@ -746,24 +751,27 @@ class TowerLocalVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
self.assertEquals(v[0].name, tower_local.name)
self.assertEquals(v[0].dtype, tower_local.dtype)
self.assertEquals(v[0].shape, tower_local.shape)
- self.assertEquals("sum", tower_local.reduce_method)
+ self.assertEquals(variable_scope.VariableAggregation.SUM,
+ tower_local.aggregation)
@test_util.run_in_graph_and_eager_modes(config=config)
def testVariableOnAnotherDevice(self):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
index = {"/job:foo/device:CPU:0": v}
- tower_local = values.TowerLocalVariable(index, v, "mean")
+ tower_local = values.TowerLocalVariable(
+ index, v, variable_scope.VariableAggregation.MEAN)
self.assertEquals(v.name, tower_local.name)
self.assertEquals(v.dtype, tower_local.dtype)
self.assertEquals(v.shape, tower_local.shape)
- self.assertEquals("mean", tower_local.reduce_method)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN,
+ tower_local.aggregation)
def _assign_tower_local(self, devices, v, new):
for d, var, n in zip(devices, v, new):
@@ -789,7 +797,7 @@ class TowerLocalVariableTest(test.TestCase):
self.skipTest("A GPU is not available for this test in eager mode.")
with self.test_session() as sess:
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -812,7 +820,8 @@ class TowerLocalVariableTest(test.TestCase):
self.skipTest("A GPU is not available for this test in eager mode.")
with self.test_session() as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -831,7 +840,8 @@ class TowerLocalVariableTest(test.TestCase):
def _save_tower_local_mean(self):
"""Save variables with mirroring, returns save_path."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -893,7 +903,8 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_mean(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [7., 8.])
@@ -907,7 +918,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_sum(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [7., 8.])
@@ -968,7 +979,7 @@ class TowerLocalVariableTest(test.TestCase):
def testTensorConversion(self):
with context.graph_mode():
- _, tower_local = _make_tower_local("sum")
+ _, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
converted = ops.internal_convert_to_tensor(tower_local, as_ref=False)
self.assertIsInstance(converted, ops.Tensor)
self.assertEqual(converted.dtype, tower_local.dtype)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD
index 432bb546f8..16636620a5 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/BUILD
+++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD
@@ -72,6 +72,7 @@ cuda_py_test(
size = "large",
srcs = ["revnet_test.py"],
additional_deps = [
+ ":blocks_test",
":config",
":revnet",
"//tensorflow:tensorflow_py",
@@ -87,7 +88,6 @@ py_library(
srcs = ["cifar_input.py"],
srcs_version = "PY2AND3",
deps = [
- ":revnet",
"//tensorflow:tensorflow_py",
],
)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
index 74c1825a49..306096e9f8 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -24,7 +24,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import six
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import ops
@@ -44,7 +43,8 @@ class RevBlock(tf.keras.Model):
batch_norm_first=False,
data_format="channels_first",
bottleneck=False,
- fused=True):
+ fused=True,
+ dtype=tf.float32):
"""Initialize RevBlock.
Args:
@@ -56,6 +56,7 @@ class RevBlock(tf.keras.Model):
data_format: tensor data format, "NCHW"/"NHWC"
bottleneck: use bottleneck residual if True
fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
"""
super(RevBlock, self).__init__()
self.blocks = tf.contrib.checkpoint.List()
@@ -69,7 +70,8 @@ class RevBlock(tf.keras.Model):
batch_norm_first=curr_batch_norm_first,
data_format=data_format,
bottleneck=bottleneck,
- fused=fused)
+ fused=fused,
+ dtype=dtype)
self.blocks.append(block)
if data_format == "channels_first":
@@ -95,19 +97,22 @@ class RevBlock(tf.keras.Model):
for i in reversed(range(len(self.blocks))):
block = self.blocks[i]
if i == 0:
- y_inv = x
+ # First block usually contains downsampling that can't be reversed
+ with tf.GradientTape() as tape:
+ x = tf.identity(x)
+ tape.watch(x)
+ y = block(x, training=training)
+
+ grads_combined = tape.gradient(
+ y, [x] + block.trainable_variables, output_gradients=dy)
+ dy = grads_combined[0]
+ grads_all += grads_combined[1:]
+ vars_all += block.trainable_variables
else:
- # Don't update running stats when reconstructing activations
- vars_and_vals = block.get_moving_stats()
- y_inv = block.backward(y, training=training)
- block.restore_moving_stats(vars_and_vals)
-
- # Update running stats when computing gradients during training
- dy, grads, vars_ = block.backward_grads_and_vars(
- y_inv, dy, training=training)
-
- grads_all += grads
- vars_all += vars_
+ y, dy, grads, vars_ = block.backward_grads_and_vars(
+ y, dy, training=training)
+ grads_all += grads
+ vars_all += vars_
return dy, grads_all, vars_all
@@ -125,6 +130,7 @@ class _Residual(tf.keras.Model):
data_format: tensor data format, "NCHW"/"NHWC",
bottleneck: use bottleneck residual if True
fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
"""
def __init__(self,
@@ -134,7 +140,8 @@ class _Residual(tf.keras.Model):
batch_norm_first=True,
data_format="channels_first",
bottleneck=False,
- fused=True):
+ fused=True,
+ dtype=tf.float32):
super(_Residual, self).__init__()
self.filters = filters
@@ -156,21 +163,22 @@ class _Residual(tf.keras.Model):
input_shape=f_input_shape,
batch_norm_first=batch_norm_first,
data_format=data_format,
- fused=fused)
+ fused=fused,
+ dtype=dtype)
self.g = factory(
filters=filters // 2,
strides=(1, 1),
input_shape=g_input_shape,
batch_norm_first=batch_norm_first,
data_format=data_format,
- fused=fused)
+ fused=fused,
+ dtype=dtype)
def call(self, x, training=True, concat=True):
"""Apply residual block to inputs."""
x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
f_x2 = self.f(x2, training=training)
- # TODO(lxuechen): Replace with simpler downsampling
x1_down = ops.downsample(
x1, self.filters // 2, self.strides, axis=self.axis)
x2_down = ops.downsample(
@@ -183,65 +191,40 @@ class _Residual(tf.keras.Model):
return tf.concat([y1, y2], axis=self.axis)
- def backward(self, y, training=True):
- """Reconstruct inputs from outputs; only valid when stride 1."""
-
- assert self.strides == (1, 1)
-
- y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis)
- g_y1 = self.g(y1, training=training)
- x2 = y2 - g_y1
- f_x2 = self.f(x2, training=training)
- x1 = y1 - f_x2
-
- return tf.concat([x1, x2], axis=self.axis)
-
- def backward_grads_and_vars(self, x, dy, training=True):
+ def backward_grads_and_vars(self, y, dy, training=True):
"""Manually compute backward gradients given input and output grads."""
+ dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
with tf.GradientTape(persistent=True) as tape:
- x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed
- x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
- tape.watch([x1, x2])
- # Stitch back x for `call` so tape records correct grads
- x = tf.concat([x1, x2], axis=self.axis)
- dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
- y1, y2 = self.call(x, training=training, concat=False)
- x2_down = ops.downsample(
- x2, self.filters // 2, self.strides, axis=self.axis)
+ y = tf.identity(y)
+ tape.watch(y)
+ y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis)
+ z1 = y1
+ gz1 = self.g(z1, training=training)
+ x2 = y2 - gz1
+ fx2 = self.f(x2, training=training)
+ x1 = z1 - fx2
grads_combined = tape.gradient(
- y2, [y1] + self.g.trainable_variables, output_gradients=[dy2])
- dy2_y1, dg = grads_combined[0], grads_combined[1:]
- dy1_plus = dy2_y1 + dy1
+ gz1, [z1] + self.g.trainable_variables, output_gradients=dy2)
+ dz1 = dy1 + grads_combined[0]
+ dg = grads_combined[1:]
+ dx1 = dz1
grads_combined = tape.gradient(
- y1, [x1, x2] + self.f.trainable_variables, output_gradients=[dy1_plus])
- dx1, dx2, df = grads_combined[0], grads_combined[1], grads_combined[2:]
- dx2 += tape.gradient(x2_down, [x2], output_gradients=[dy2])[0]
+ fx2, [x2] + self.f.trainable_variables, output_gradients=dz1)
+ dx2 = dy2 + grads_combined[0]
+ df = grads_combined[1:]
del tape
grads = df + dg
vars_ = self.f.trainable_variables + self.g.trainable_variables
- return tf.concat([dx1, dx2], axis=self.axis), grads, vars_
+ x = tf.concat([x1, x2], axis=self.axis)
+ dx = tf.concat([dx1, dx2], axis=self.axis)
- def get_moving_stats(self):
- vars_and_vals = {}
-
- def _is_moving_var(v): # pylint: disable=invalid-name
- n = v.name
- return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
-
- for v in filter(_is_moving_var, self.f.variables + self.g.variables):
- vars_and_vals[v] = v.read_value()
-
- return vars_and_vals
-
- def restore_moving_stats(self, vars_and_vals):
- for var_, val in six.iteritems(vars_and_vals):
- var_.assign(val)
+ return x, dx, grads, vars_
def _BottleneckResidualInner(filters,
@@ -249,7 +232,8 @@ def _BottleneckResidualInner(filters,
input_shape,
batch_norm_first=True,
data_format="channels_first",
- fused=True):
+ fused=True,
+ dtype=tf.float32):
"""Single bottleneck residual inner function contained in _Resdual.
Corresponds to the `F`/`G` functions in the paper.
@@ -262,6 +246,7 @@ def _BottleneckResidualInner(filters,
batch_norm_first: whether to apply activation and batch norm before conv
data_format: tensor data format, "NCHW"/"NHWC"
fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
Returns:
A keras model
@@ -272,7 +257,7 @@ def _BottleneckResidualInner(filters,
if batch_norm_first:
model.add(
tf.keras.layers.BatchNormalization(
- axis=axis, input_shape=input_shape, fused=fused))
+ axis=axis, input_shape=input_shape, fused=fused, dtype=dtype))
model.add(tf.keras.layers.Activation("relu"))
model.add(
tf.keras.layers.Conv2D(
@@ -282,9 +267,11 @@ def _BottleneckResidualInner(filters,
input_shape=input_shape,
data_format=data_format,
use_bias=False,
- padding="SAME"))
+ padding="SAME",
+ dtype=dtype))
- model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused))
+ model.add(
+ tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype))
model.add(tf.keras.layers.Activation("relu"))
model.add(
tf.keras.layers.Conv2D(
@@ -293,9 +280,11 @@ def _BottleneckResidualInner(filters,
strides=(1, 1),
data_format=data_format,
use_bias=False,
- padding="SAME"))
+ padding="SAME",
+ dtype=dtype))
- model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused))
+ model.add(
+ tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype))
model.add(tf.keras.layers.Activation("relu"))
model.add(
tf.keras.layers.Conv2D(
@@ -304,7 +293,8 @@ def _BottleneckResidualInner(filters,
strides=(1, 1),
data_format=data_format,
use_bias=False,
- padding="SAME"))
+ padding="SAME",
+ dtype=dtype))
return model
@@ -314,7 +304,8 @@ def _ResidualInner(filters,
input_shape,
batch_norm_first=True,
data_format="channels_first",
- fused=True):
+ fused=True,
+ dtype=tf.float32):
"""Single residual inner function contained in _ResdualBlock.
Corresponds to the `F`/`G` functions in the paper.
@@ -326,6 +317,7 @@ def _ResidualInner(filters,
batch_norm_first: whether to apply activation and batch norm before conv
data_format: tensor data format, "NCHW"/"NHWC"
fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
Returns:
A keras model
@@ -336,7 +328,7 @@ def _ResidualInner(filters,
if batch_norm_first:
model.add(
tf.keras.layers.BatchNormalization(
- axis=axis, input_shape=input_shape, fused=fused))
+ axis=axis, input_shape=input_shape, fused=fused, dtype=dtype))
model.add(tf.keras.layers.Activation("relu"))
model.add(
tf.keras.layers.Conv2D(
@@ -346,9 +338,11 @@ def _ResidualInner(filters,
input_shape=input_shape,
data_format=data_format,
use_bias=False,
- padding="SAME"))
+ padding="SAME",
+ dtype=dtype))
- model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused))
+ model.add(
+ tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype))
model.add(tf.keras.layers.Activation("relu"))
model.add(
tf.keras.layers.Conv2D(
@@ -357,6 +351,7 @@ def _ResidualInner(filters,
strides=(1, 1),
data_format=data_format,
use_bias=False,
- padding="SAME"))
+ padding="SAME",
+ dtype=dtype))
return model
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
index a28ca6e3e0..d74785c8fe 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
@@ -22,6 +22,27 @@ import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import blocks
+def compute_degree(g1, g2, eps=1e-7):
+ """Compute the degree between two vectors using their usual inner product."""
+
+ def _dot(u, v):
+ return tf.reduce_sum(u * v)
+
+ g1_norm = tf.sqrt(_dot(g1, g1))
+ g2_norm = tf.sqrt(_dot(g2, g2))
+ if g1_norm.numpy() == 0 and g2_norm.numpy() == 0:
+ cosine = 1. - eps
+ else:
+ g1_norm = 1. if g1_norm.numpy() == 0 else g1_norm
+ g2_norm = 1. if g2_norm.numpy() == 0 else g2_norm
+ cosine = _dot(g1, g2) / g1_norm / g2_norm
+ # Restrict to arccos range
+ cosine = tf.minimum(tf.maximum(cosine, eps - 1.), 1. - eps)
+ degree = tf.acos(cosine) * 180. / 3.141592653589793
+
+ return degree
+
+
def _validate_block_call_channels_last(block_factory, test):
"""Generic testing function for `channels_last` data format.
@@ -33,30 +54,30 @@ def _validate_block_call_channels_last(block_factory, test):
test: tf.test.TestCase object
"""
with tf.device("/cpu:0"): # NHWC format
- input_shape = (224, 224, 32)
+ input_shape = (8, 8, 128)
data_shape = (16,) + input_shape
x = tf.random_normal(shape=data_shape)
# Stride 1
block = block_factory(
- filters=64,
+ filters=128,
strides=(1, 1),
input_shape=input_shape,
data_format="channels_last")
y_tr, y_ev = block(x, training=True), block(x, training=False)
test.assertEqual(y_tr.shape, y_ev.shape)
- test.assertEqual(y_ev.shape, (16, 224, 224, 64))
+ test.assertEqual(y_ev.shape, (16, 8, 8, 128))
test.assertNotAllClose(y_tr, y_ev)
# Stride of 2
block = block_factory(
- filters=64,
+ filters=128,
strides=(2, 2),
input_shape=input_shape,
data_format="channels_last")
y_tr, y_ev = block(x, training=True), block(x, training=False)
test.assertEqual(y_tr.shape, y_ev.shape)
- test.assertEqual(y_ev.shape, (16, 112, 112, 64))
+ test.assertEqual(y_ev.shape, (16, 4, 4, 128))
test.assertNotAllClose(y_tr, y_ev)
@@ -74,22 +95,22 @@ def _validate_block_call_channels_first(block_factory, test):
test.skipTest("GPU not available")
with tf.device("/gpu:0"): # Default NCHW format
- input_shape = (32, 224, 224)
+ input_shape = (128, 8, 8)
data_shape = (16,) + input_shape
x = tf.random_normal(shape=data_shape)
# Stride of 1
- block = block_factory(filters=64, strides=(1, 1), input_shape=input_shape)
+ block = block_factory(filters=128, strides=(1, 1), input_shape=input_shape)
y_tr, y_ev = block(x, training=True), block(x, training=False)
test.assertEqual(y_tr.shape, y_ev.shape)
- test.assertEqual(y_ev.shape, (16, 64, 224, 224))
+ test.assertEqual(y_ev.shape, (16, 128, 8, 8))
test.assertNotAllClose(y_tr, y_ev)
# Stride of 2
- block = block_factory(filters=64, strides=(2, 2), input_shape=input_shape)
+ block = block_factory(filters=128, strides=(2, 2), input_shape=input_shape)
y_tr, y_ev = block(x, training=True), block(x, training=False)
test.assertEqual(y_tr.shape, y_ev.shape)
- test.assertEqual(y_ev.shape, (16, 64, 112, 112))
+ test.assertEqual(y_ev.shape, (16, 128, 4, 4))
test.assertNotAllClose(y_tr, y_ev)
@@ -101,121 +122,116 @@ class RevBlockTest(tf.test.TestCase):
self.skipTest("GPU not available")
with tf.device("/gpu:0"): # Default NCHW format
- input_shape = (32, 224, 224)
+ input_shape = (128, 8, 8)
data_shape = (16,) + input_shape
x = tf.random_normal(shape=data_shape)
# Stride of 1
block = blocks.RevBlock(
- n_res=3, filters=64, strides=(1, 1), input_shape=input_shape)
+ n_res=3, filters=128, strides=(1, 1), input_shape=input_shape)
y_tr, y_ev = block(x, training=True), block(x, training=False)
self.assertEqual(y_tr.shape, y_ev.shape)
- self.assertEqual(y_ev.shape, (16, 64, 224, 224))
+ self.assertEqual(y_ev.shape, (16, 128, 8, 8))
self.assertNotAllClose(y_tr, y_ev)
# Stride of 2
block = blocks.RevBlock(
- n_res=3, filters=64, strides=(2, 2), input_shape=input_shape)
+ n_res=3, filters=128, strides=(2, 2), input_shape=input_shape)
y_tr, y_ev = block(x, training=True), block(x, training=False)
self.assertEqual(y_tr.shape, y_ev.shape)
- self.assertEqual(y_ev.shape, [16, 64, 112, 112])
+ self.assertEqual(y_ev.shape, [16, 128, 4, 4])
self.assertNotAllClose(y_tr, y_ev)
def test_call_channels_last(self):
"""Test `call` function with `channels_last` data format."""
with tf.device("/cpu:0"): # NHWC format
- input_shape = (224, 224, 32)
+ input_shape = (8, 8, 128)
data_shape = (16,) + input_shape
x = tf.random_normal(shape=data_shape)
# Stride 1
block = blocks.RevBlock(
n_res=3,
- filters=64,
+ filters=128,
strides=(1, 1),
input_shape=input_shape,
data_format="channels_last")
y_tr, y_ev = block(x, training=True), block(x, training=False)
self.assertEqual(y_tr.shape, y_ev.shape)
- self.assertEqual(y_ev.shape, (16, 224, 224, 64))
+ self.assertEqual(y_ev.shape, (16, 8, 8, 128))
self.assertNotAllClose(y_tr, y_ev)
# Stride of 2
block = blocks.RevBlock(
n_res=3,
- filters=64,
+ filters=128,
strides=(2, 2),
input_shape=input_shape,
data_format="channels_last")
y_tr, y_ev = block(x, training=True), block(x, training=False)
self.assertEqual(y_tr.shape, y_ev.shape)
- self.assertEqual(y_ev.shape, (16, 112, 112, 64))
+ self.assertEqual(y_ev.shape, (16, 4, 4, 128))
self.assertNotAllClose(y_tr, y_ev)
+ def _check_grad_angle(self, grads, grads_true, atol=1e0):
+ """Check the angle between two list of vectors are all close."""
+ for g1, g2 in zip(grads, grads_true):
+ degree = compute_degree(g1, g2)
+ self.assertLessEqual(degree, atol)
+
def test_backward_grads_and_vars_channels_first(self):
"""Test `backward` function with `channels_first` data format."""
if not tf.test.is_gpu_available():
self.skipTest("GPU not available")
with tf.device("/gpu:0"): # Default NCHW format
- input_shape = (32, 224, 224)
- data_shape = (16,) + input_shape
- x = tf.random_normal(shape=data_shape)
-
# Stride 1
- y = tf.random_normal(shape=data_shape)
- dy = tf.random_normal(shape=data_shape)
- block = blocks.RevBlock(
- n_res=3, filters=32, strides=(1, 1), input_shape=input_shape)
- dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
- self.assertEqual(dy.shape, x.shape)
- self.assertTrue(isinstance(grads, list))
- self.assertTrue(isinstance(vars_, list))
-
- # Stride 2
- y = tf.random_normal(shape=(16, 32, 112, 112))
- dy = tf.random_normal(shape=(16, 32, 112, 112))
- block = blocks.RevBlock(
- n_res=3, filters=32, strides=(2, 2), input_shape=input_shape)
- dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
- self.assertEqual(dy.shape, x.shape)
- self.assertTrue(isinstance(grads, list))
- self.assertTrue(isinstance(vars_, list))
-
- def test_backward_grads_and_vars_channels_last(self):
- """Test `backward` function with `channels_last` data format."""
- with tf.device("/cpu:0"): # NHWC format
- input_shape = (224, 224, 32)
+ input_shape = (128, 8, 8)
data_shape = (16,) + input_shape
- x = tf.random_normal(shape=data_shape)
-
- # Stride 1
- y = tf.random_normal(shape=data_shape)
- dy = tf.random_normal(shape=data_shape)
+ x = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
block = blocks.RevBlock(
n_res=3,
- filters=32,
+ filters=128,
strides=(1, 1),
input_shape=input_shape,
- data_format="channels_last")
- dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
- self.assertEqual(dy.shape, x.shape)
- self.assertTrue(isinstance(grads, list))
- self.assertTrue(isinstance(vars_, list))
+ fused=False,
+ dtype=tf.float64)
+ with tf.GradientTape() as tape:
+ tape.watch(x)
+ y = block(x, training=True)
+ # Compute grads from reconstruction
+ dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+ # Compute true grads
+ grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
+ dx_true, dw_true = grads[0], grads[1:]
+ self.assertAllClose(dx_true, dx)
+ self.assertAllClose(dw_true, dw)
+ self._check_grad_angle(dx_true, dx)
+ self._check_grad_angle(dw_true, dw)
# Stride 2
- y = tf.random_normal(shape=(16, 112, 112, 32))
- dy = tf.random_normal(shape=(16, 112, 112, 32))
+ x = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy = tf.random_normal(shape=(16, 128, 4, 4), dtype=tf.float64)
block = blocks.RevBlock(
n_res=3,
- filters=32,
+ filters=128,
strides=(2, 2),
input_shape=input_shape,
- data_format="channels_last")
- dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
- self.assertEqual(dy.shape, x.shape)
- self.assertTrue(isinstance(grads, list))
- self.assertTrue(isinstance(vars_, list))
+ fused=False,
+ dtype=tf.float64)
+ with tf.GradientTape() as tape:
+ tape.watch(x)
+ y = block(x, training=True)
+ # Compute grads from reconstruction
+ dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+ # Compute true grads
+ grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
+ dx_true, dw_true = grads[0], grads[1:]
+ self.assertAllClose(dx_true, dx)
+ self.assertAllClose(dw_true, dw)
+ self._check_grad_angle(dx_true, dx)
+ self._check_grad_angle(dw_true, dw)
class _ResidualTest(tf.test.TestCase):
@@ -229,112 +245,40 @@ class _ResidualTest(tf.test.TestCase):
_validate_block_call_channels_first(blocks._Residual, self)
_validate_block_call_channels_last(blocks._Residual, self)
- def test_backward_channels_first(self):
- """Test `backward` function with `channels_first` data format."""
- if not tf.test.is_gpu_available():
- self.skipTest("GPU not available")
-
- with tf.device("/gpu:0"): # Default NCHW format
- input_shape = (16, 224, 224)
- data_shape = (16,) + input_shape
- x = tf.random_normal(shape=data_shape)
- residual = blocks._Residual(
- filters=16, strides=(1, 1), input_shape=input_shape)
-
- y_tr, y_ev = residual(x, training=True), residual(x, training=False)
- x_ = residual.backward(y_ev, training=False)
- self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1)
- x_ = residual.backward(y_tr, training=True) # This updates moving avg
- self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1)
-
- def test_backward_channels_last(self):
- """Test `backward` function with `channels_last` data format."""
- with tf.device("/cpu:0"): # NHWC format
- input_shape = (224, 224, 16)
- data_shape = (16,) + input_shape
- x = tf.random_normal(shape=data_shape)
- residual = blocks._Residual(
- filters=16,
- strides=(1, 1),
- input_shape=input_shape,
- data_format="channels_last")
-
- y_tr, y_ev = residual(x, training=True), residual(x, training=False)
- x_ = residual.backward(y_ev, training=False)
- self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1)
- x_ = residual.backward(y_tr, training=True) # This updates moving avg
- self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1)
-
def test_backward_grads_and_vars_channels_first(self):
"""Test `backward_grads` function with `channels_first` data format."""
if not tf.test.is_gpu_available():
self.skipTest("GPU not available")
with tf.device("/gpu:0"): # Default NCHW format
- input_shape = (16, 224, 224)
+ input_shape = (128, 8, 8)
data_shape = (16,) + input_shape
- x = tf.random_normal(shape=data_shape)
- dy = tf.random_normal(shape=data_shape)
+ # Use double precision for testing
+ x_true = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
residual = blocks._Residual(
- filters=16, strides=(1, 1), input_shape=input_shape)
-
- vars_and_vals = residual.get_moving_stats()
- dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars(
- x, dy=dy, training=True)
- dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars(
- x, dy=dy, training=False)
- self.assertNotAllClose(dx_tr, dx_ev)
- self.assertTrue(isinstance(grads_tr, list))
- self.assertTrue(isinstance(grads_ev, list))
- self.assertTrue(isinstance(vars_tr, list))
- self.assertTrue(isinstance(vars_ev, list))
- for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev,
- vars_ev):
- self.assertEqual(grad_tr.shape, grad_ev.shape)
- self.assertEqual(var_tr.shape, var_ev.shape)
- self.assertEqual(grad_tr.shape, var_tr.shape)
-
- # Compare against the true gradient computed by the tape
- residual.restore_moving_stats(vars_and_vals)
- with tf.GradientTape(persistent=True) as tape:
- tape.watch(x)
- y = residual(x, training=True)
- grads = tape.gradient(
- y, [x] + residual.trainable_variables, output_gradients=[dy])
- dx_tr_true, grads_tr_true = grads[0], grads[1:]
+ filters=128,
+ strides=(1, 1),
+ input_shape=input_shape,
+ fused=False,
+ dtype=tf.float64)
- del tape
+ with tf.GradientTape() as tape:
+ x_true = tf.identity(x_true)
+ tape.watch(x_true)
+ y = residual(x_true, training=True)
- self.assertAllClose(dx_tr, dx_tr_true, rtol=1e-1, atol=1e-1)
- self.assertAllClose(grads_tr, grads_tr_true, rtol=1e-1, atol=1e-1)
+ # Gradients computed due to reversibility
+ x, dx, dw, vars_ = residual.backward_grads_and_vars(
+ y, dy=dy, training=True)
- def test_backward_grads_and_vars_channels_last(self):
- """Test `backward_grads` function with `channels_last` data format."""
- with tf.device("/cpu:0"): # NHWC format
- input_shape = (224, 224, 16)
- data_shape = (16,) + input_shape
- x = tf.random_normal(shape=data_shape)
- dy = tf.random_normal(shape=data_shape)
- residual = blocks._Residual(
- filters=16,
- strides=(1, 1),
- input_shape=input_shape,
- data_format="channels_last")
+ # True gradients computed by the tape
+ grads = tape.gradient(y, [x_true] + vars_, output_gradients=dy)
+ dx_true, dw_true = grads[0], grads[1:]
- dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars(
- x, dy=dy, training=True)
- dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars(
- x, dy=dy, training=False)
- self.assertNotAllClose(dx_tr, dx_ev)
- self.assertTrue(isinstance(grads_tr, list))
- self.assertTrue(isinstance(grads_ev, list))
- self.assertTrue(isinstance(vars_tr, list))
- self.assertTrue(isinstance(vars_ev, list))
- for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev,
- vars_ev):
- self.assertEqual(grad_tr.shape, grad_ev.shape)
- self.assertEqual(var_tr.shape, var_ev.shape)
- self.assertEqual(grad_tr.shape, var_tr.shape)
+ self.assertAllClose(x_true, x)
+ self.assertAllClose(dx_true, dx)
+ self.assertAllClose(dw_true, dw)
class _ResidualInnerTest(tf.test.TestCase):
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
index e1d8b3a055..b6d4c35bfd 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
@@ -35,7 +35,7 @@ def get_ds_from_tfrecords(data_dir,
epochs=None,
shuffle=True,
data_format="channels_first",
- num_parallel_calls=8,
+ num_parallel_calls=12,
prefetch=0,
div255=True,
dtype=tf.float32):
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py
index f79428b2a9..377844ad8f 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Read CIFAR-10 data from pickled numpy arrays and writes TFRecords.
+"""Read CIFAR data from pickled numpy arrays and writes TFRecords.
Generates tf.train.Example protos and writes them to TFRecord files from the
-python version of the CIFAR-10 dataset downloaded from
+python version of the CIFAR dataset downloaded from
https://www.cs.toronto.edu/~kriz/cifar.html.
"""
@@ -32,20 +32,22 @@ from six.moves import cPickle as pickle
from six.moves import urllib
import tensorflow as tf
-CIFAR_FILENAME = 'cifar-10-python.tar.gz'
-CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
-CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'
+BASE_URL = 'https://www.cs.toronto.edu/~kriz/'
+CIFAR_FILE_NAMES = ['cifar-10-python.tar.gz', 'cifar-100-python.tar.gz']
+CIFAR_DOWNLOAD_URLS = [BASE_URL + name for name in CIFAR_FILE_NAMES]
+CIFAR_LOCAL_FOLDERS = ['cifar-10', 'cifar-100']
+EXTRACT_FOLDERS = ['cifar-10-batches-py', 'cifar-100-python']
-def download_and_extract(data_dir):
- """Download CIFAR-10 if not already downloaded."""
- filepath = os.path.join(data_dir, CIFAR_FILENAME)
+def download_and_extract(data_dir, file_name, url):
+ """Download CIFAR if not already downloaded."""
+ filepath = os.path.join(data_dir, file_name)
if tf.gfile.Exists(filepath):
return filepath
if not tf.gfile.Exists(data_dir):
tf.gfile.MakeDirs(data_dir)
- urllib.request.urlretrieve(CIFAR_DOWNLOAD_URL, filepath)
+ urllib.request.urlretrieve(url, filepath)
tarfile.open(os.path.join(filepath), 'r:gz').extractall(data_dir)
return filepath
@@ -58,12 +60,22 @@ def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
-def _get_file_names():
+def _get_file_names(folder):
"""Returns the file names expected to exist in the input_dir."""
+ assert folder in ['cifar-10', 'cifar-100']
+
file_names = {}
- file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)]
- file_names['validation'] = ['data_batch_5']
- file_names['test'] = ['test_batch']
+ if folder == 'cifar-10':
+ file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)]
+ file_names['validation'] = ['data_batch_5']
+ file_names['train_all'] = ['data_batch_%d' % i for i in range(1, 6)]
+ file_names['test'] = ['test_batch']
+ else:
+ file_names['train_all'] = ['train']
+ file_names['test'] = ['test']
+ # Split in `convert_to_tfrecord` function
+ file_names['train'] = ['train']
+ file_names['validation'] = ['train']
return file_names
@@ -76,14 +88,28 @@ def read_pickle_from_file(filename):
return data_dict
-def convert_to_tfrecord(input_files, output_file):
+def convert_to_tfrecord(input_files, output_file, folder):
"""Converts files with pickled data to TFRecords."""
+ assert folder in ['cifar-10', 'cifar-100']
+
print('Generating %s' % output_file)
with tf.python_io.TFRecordWriter(output_file) as record_writer:
for input_file in input_files:
data_dict = read_pickle_from_file(input_file)
data = data_dict[b'data']
- labels = data_dict[b'labels']
+ try:
+ labels = data_dict[b'labels']
+ except KeyError:
+ labels = data_dict[b'fine_labels']
+
+ if folder == 'cifar-100' and input_file.endswith('train.tfrecords'):
+ data = data[:40000]
+ labels = labels[:40000]
+ elif folder == 'cifar-100' and input_file.endswith(
+ 'validation.tfrecords'):
+ data = data[40000:]
+ labels = labels[40000:]
+
num_entries_in_batch = len(labels)
for i in range(num_entries_in_batch):
@@ -97,19 +123,24 @@ def convert_to_tfrecord(input_files, output_file):
def main(_):
- print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
- download_and_extract(FLAGS.data_dir)
- file_names = _get_file_names()
- input_dir = os.path.join(FLAGS.data_dir, CIFAR_LOCAL_FOLDER)
-
- for mode, files in file_names.items():
- input_files = [os.path.join(input_dir, f) for f in files]
- output_file = os.path.join(FLAGS.data_dir, mode + '.tfrecords')
- try:
- os.remove(output_file)
- except OSError:
- pass
- convert_to_tfrecord(input_files, output_file)
+ for file_name, url, folder, extract_folder in zip(
+ CIFAR_FILE_NAMES, CIFAR_DOWNLOAD_URLS, CIFAR_LOCAL_FOLDERS,
+ EXTRACT_FOLDERS):
+ print('Download from {} and extract.'.format(url))
+ data_dir = os.path.join(FLAGS.data_dir, folder)
+ download_and_extract(data_dir, file_name, url)
+ file_names = _get_file_names(folder)
+ input_dir = os.path.join(data_dir, extract_folder)
+
+ for mode, files in file_names.items():
+ input_files = [os.path.join(input_dir, f) for f in files]
+ output_file = os.path.join(data_dir, mode + '.tfrecords')
+ try:
+ os.remove(output_file)
+ except OSError:
+ pass
+ convert_to_tfrecord(input_files, output_file, folder)
+
print('Done!')
@@ -118,6 +149,6 @@ if __name__ == '__main__':
flags.DEFINE_string(
'data_dir',
default=None,
- help='Directory to download and extract CIFAR-10 to.')
+ help='Directory to download, extract and store TFRecords.')
tf.app.run(main)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py
index 30b0edbf43..3d93fa955a 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/config.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/config.py
@@ -66,7 +66,7 @@ def get_hparams_cifar_38():
config.add_hparam("dtype", tf.float32)
config.add_hparam("eval_batch_size", 1000)
config.add_hparam("div255", True)
- # TODO(lxuechen): This is imprecise, when training with validation set,
+ # This is imprecise, when training with validation set,
# we only have 40k images in training data
config.add_hparam("iters_per_epoch", 50000 // config.batch_size)
config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch)
@@ -74,6 +74,26 @@ def get_hparams_cifar_38():
return config
+def get_hparams_cifar_110():
+ config = get_hparams_cifar_38()
+ config.filters = [32, 64, 128]
+ config.n_res = [9, 9, 9]
+
+ return config
+
+
+def get_hparams_cifar_164():
+ config = get_hparams_cifar_38()
+ config.filters = [32, 64, 128]
+ config.n_res = [9, 9, 9]
+ config.use_bottleneck = True
+ # Due to bottleneck residual blocks
+ filters = [f * 4 for f in config.filters]
+ config.filters = filters
+
+ return config
+
+
def get_hparams_imagenet_56():
"""RevNet-56 configurations for ImageNet."""
@@ -113,9 +133,8 @@ def get_hparams_imagenet_56():
# TODO(lxuechen): Update this according to ImageNet data
config.add_hparam("iters_per_epoch", 50000 // config.batch_size)
config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch)
-
- if config.bottleneck:
- filters = [f * 4 for f in config.filters]
- config.filters = filters
+ # Due to bottleneck residual blocks
+ filters = [f * 4 for f in config.filters]
+ config.filters = filters
return config
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py
index 1065592509..e2f43b03f9 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main.py
@@ -23,7 +23,6 @@ import sys
from absl import flags
import tensorflow as tf
-from tqdm import tqdm
from tensorflow.contrib.eager.python.examples.revnet import cifar_input
from tensorflow.contrib.eager.python.examples.revnet import config as config_
from tensorflow.contrib.eager.python.examples.revnet import revnet
@@ -32,19 +31,110 @@ tfe = tf.contrib.eager
def main(_):
"""Eager execution workflow with RevNet trained on CIFAR-10."""
+ config = get_config()
+ ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(config)
+ model = revnet.RevNet(config=config)
+ global_step = tf.train.get_or_create_global_step() # Ensure correct summary
+ global_step.assign(1)
+ learning_rate = tf.train.piecewise_constant(
+ global_step, config.lr_decay_steps, config.lr_list)
+ optimizer = tf.train.MomentumOptimizer(
+ learning_rate, momentum=config.momentum)
+ checkpointer = tf.train.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=global_step)
+
+ if FLAGS.train_dir:
+ summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
+ if FLAGS.restore:
+ latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
+ checkpointer.restore(latest_path)
+ print("Restored latest checkpoint at path:\"{}\" "
+ "with global_step: {}".format(latest_path, global_step.numpy()))
+ sys.stdout.flush()
+
+ if FLAGS.manual_grad:
+ print("Using manual gradients.")
+ else:
+ print("Not using manual gradients.")
+ sys.stdout.flush()
+
+ for x, y in ds_train:
+ train_one_iter(model, x, y, optimizer, global_step=global_step)
+
+ if global_step.numpy() % config.log_every == 0:
+ it_train = ds_train_one_shot.make_one_shot_iterator()
+ it_test = ds_test.make_one_shot_iterator()
+ acc_train, loss_train = evaluate(model, it_train)
+ acc_test, loss_test = evaluate(model, it_test)
+
+ if FLAGS.validate:
+ it_validation = ds_validation.make_one_shot_iterator()
+ acc_validation, loss_validation = evaluate(model, it_validation)
+ print("Iter {}, "
+ "training set accuracy {:.4f}, loss {:.4f}; "
+ "validation set accuracy {:.4f}, loss {:4.f}"
+ "test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_train, loss_train, acc_validation,
+ loss_validation, acc_test, loss_test))
+ else:
+ print("Iter {}, "
+ "training set accuracy {:.4f}, loss {:.4f}; "
+ "test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_train, loss_train, acc_test,
+ loss_test))
+ sys.stdout.flush()
+
+ if FLAGS.train_dir:
+ with summary_writer.as_default():
+ with tf.contrib.summary.always_record_summaries():
+ tf.contrib.summary.scalar("Training accuracy", acc_train)
+ tf.contrib.summary.scalar("Test accuracy", acc_test)
+ tf.contrib.summary.scalar("Training loss", loss_train)
+ tf.contrib.summary.scalar("Test loss", loss_test)
+ if FLAGS.validate:
+ tf.contrib.summary.scalar("Validation accuracy", acc_validation)
+ tf.contrib.summary.scalar("Validation loss", loss_validation)
+
+ if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
+ saved_path = checkpointer.save(
+ file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
+ print("Saved checkpoint at path: \"{}\" "
+ "with global_step: {}".format(saved_path, global_step.numpy()))
+ sys.stdout.flush()
+
+
+def get_config():
+ """Return configuration."""
+ print("Config: {}".format(FLAGS.config))
+ sys.stdout.flush()
+ config = {
+ "revnet-38": config_.get_hparams_cifar_38(),
+ "revnet-110": config_.get_hparams_cifar_110(),
+ "revnet-164": config_.get_hparams_cifar_164(),
+ }[FLAGS.config]
+
+ if FLAGS.dataset == "cifar-100":
+ config.n_classes = 100
+
+ return config
+
+
+def get_datasets(config):
+ """Return dataset."""
if FLAGS.data_dir is None:
raise ValueError("No supplied data directory")
-
if not os.path.exists(FLAGS.data_dir):
raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))
+ if FLAGS.dataset not in ["cifar-10", "cifar-100"]:
+ raise ValueError("Unknown dataset {}".format(FLAGS.dataset))
- tf.enable_eager_execution()
- config = config_.get_hparams_cifar_38()
-
+ print("Training on {} dataset.".format(FLAGS.dataset))
+ sys.stdout.flush()
+ data_dir = os.path.join(FLAGS.data_dir, FLAGS.dataset)
if FLAGS.validate:
# 40k Training set
ds_train = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="train",
data_aug=True,
batch_size=config.batch_size,
@@ -55,7 +145,7 @@ def main(_):
prefetch=config.batch_size)
# 10k Training set
ds_validation = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="validation",
data_aug=False,
batch_size=config.eval_batch_size,
@@ -67,7 +157,7 @@ def main(_):
else:
# 50k Training set
ds_train = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="train_all",
data_aug=True,
batch_size=config.batch_size,
@@ -76,10 +166,11 @@ def main(_):
data_format=config.data_format,
dtype=config.dtype,
prefetch=config.batch_size)
+ ds_validation = None
# Always compute loss and accuracy on whole training and test set
ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="train_all",
data_aug=False,
batch_size=config.eval_batch_size,
@@ -90,7 +181,7 @@ def main(_):
prefetch=config.eval_batch_size)
ds_test = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="test",
data_aug=False,
batch_size=config.eval_batch_size,
@@ -100,91 +191,19 @@ def main(_):
dtype=config.dtype,
prefetch=config.eval_batch_size)
- model = revnet.RevNet(config=config)
- global_step = tfe.Variable(1, trainable=False)
- learning_rate = tf.train.piecewise_constant(
- global_step, config.lr_decay_steps, config.lr_list)
- optimizer = tf.train.MomentumOptimizer(
- learning_rate, momentum=config.momentum)
- checkpointer = tf.train.Checkpoint(
- optimizer=optimizer, model=model, optimizer_step=global_step)
-
- if FLAGS.train_dir:
- summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
- if FLAGS.restore:
- latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
- checkpointer.restore(latest_path)
- print("Restored latest checkpoint at path:\"{}\" "
- "with global_step: {}".format(latest_path, global_step.numpy()))
- sys.stdout.flush()
+ return ds_train, ds_train_one_shot, ds_validation, ds_test
- warmup(model, config)
- for x, y in ds_train:
- loss = train_one_iter(model, x, y, optimizer, global_step=global_step)
-
- if global_step.numpy() % config.log_every == 0:
- it_train = ds_train_one_shot.make_one_shot_iterator()
- acc_train, loss_train = evaluate(model, it_train)
- it_test = ds_test.make_one_shot_iterator()
- acc_test, loss_test = evaluate(model, it_test)
- if FLAGS.validate:
- it_validation = ds_validation.make_one_shot_iterator()
- acc_validation, loss_validation = evaluate(model, it_validation)
- print("Iter {}, "
- "training set accuracy {:.4f}, loss {:.4f}; "
- "validation set accuracy {:.4f}, loss {:4.f}"
- "test accuracy {:.4f}, loss {:.4f}".format(
- global_step.numpy(), acc_train, loss_train, acc_validation,
- loss_validation, acc_test, loss_test))
- else:
- print("Iter {}, "
- "training set accuracy {:.4f}, loss {:.4f}; "
- "test accuracy {:.4f}, loss {:.4f}".format(
- global_step.numpy(), acc_train, loss_train, acc_test,
- loss_test))
- sys.stdout.flush()
-
- if FLAGS.train_dir:
- with summary_writer.as_default():
- with tf.contrib.summary.always_record_summaries():
- tf.contrib.summary.scalar("Training loss", loss)
- tf.contrib.summary.scalar("Test accuracy", acc_test)
- if FLAGS.validate:
- tf.contrib.summary.scalar("Validation accuracy", acc_validation)
-
- if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
- saved_path = checkpointer.save(
- file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
- print("Saved checkpoint at path: \"{}\" "
- "with global_step: {}".format(saved_path, global_step.numpy()))
- sys.stdout.flush()
-
-
-def warmup(model, config, steps=1):
- mock_input = tf.random_normal((config.batch_size,) + config.input_shape)
- for _ in range(steps):
- model(mock_input, training=False)
-
-
-def train_one_iter(model,
- inputs,
- labels,
- optimizer,
- global_step=None,
- verbose=False):
+def train_one_iter(model, inputs, labels, optimizer, global_step=None):
"""Train for one iteration."""
if FLAGS.manual_grad:
- if verbose:
- print("Using manual gradients")
- grads, vars_, loss = model.compute_gradients(inputs, labels)
+ grads, vars_, loss = model.compute_gradients(inputs, labels, training=True)
optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
else: # For correctness validation
- if verbose:
- print("Not using manual gradients")
with tf.GradientTape() as tape:
logits, _ = model(inputs, training=True)
loss = model.compute_loss(logits=logits, labels=labels)
+ tf.logging.info("Logits are placed on device: {}".format(logits.device))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(
zip(grads, model.trainable_variables), global_step=global_step)
@@ -196,7 +215,7 @@ def evaluate(model, iterator):
"""Compute accuracy with the given dataset iterator."""
mean_loss = tfe.metrics.Mean()
accuracy = tfe.metrics.Accuracy()
- for x, y in tqdm(iterator):
+ for x, y in iterator:
logits, _ = model(x, training=False)
loss = model.compute_loss(logits=logits, labels=y)
accuracy(
@@ -209,11 +228,11 @@ def evaluate(model, iterator):
if __name__ == "__main__":
flags.DEFINE_string(
+ "data_dir", default=None, help="Directory to load tfrecords")
+ flags.DEFINE_string(
"train_dir",
default=None,
help="[Optional] Directory to store the training information")
- flags.DEFINE_string(
- "data_dir", default=None, help="Directory to load tfrecords")
flags.DEFINE_boolean(
"restore",
default=False,
@@ -226,5 +245,12 @@ if __name__ == "__main__":
"manual_grad",
default=False,
help="[Optional] Use manual gradient graph to save memory")
+ flags.DEFINE_string(
+ "dataset",
+ default="cifar-10",
+ help="[Optional] The dataset used; either `cifar-10` or `cifar-100`")
+ flags.DEFINE_string(
+ "config", default="revnet-38", help="[Optional] Architecture of network.")
FLAGS = flags.FLAGS
+ tf.enable_eager_execution()
tf.app.run(main)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
index 0228bff6fa..af0d20fa72 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -48,7 +48,6 @@ class RevNet(tf.keras.Model):
self._init_block = self._construct_init_block()
self._block_list = self._construct_intermediate_blocks()
self._final_block = self._construct_final_block()
- self._moving_stats_vars = None
def _construct_init_block(self):
init_block = tf.keras.Sequential(
@@ -60,9 +59,12 @@ class RevNet(tf.keras.Model):
data_format=self.config.data_format,
use_bias=False,
padding="SAME",
- input_shape=self.config.input_shape),
+ input_shape=self.config.input_shape,
+ dtype=self.config.dtype),
tf.keras.layers.BatchNormalization(
- axis=self.axis, fused=self.config.fused),
+ axis=self.axis,
+ fused=self.config.fused,
+ dtype=self.config.dtype),
tf.keras.layers.Activation("relu"),
],
name="init")
@@ -72,7 +74,8 @@ class RevNet(tf.keras.Model):
pool_size=(3, 3),
strides=(2, 2),
padding="SAME",
- data_format=self.config.data_format))
+ data_format=self.config.data_format,
+ dtype=self.config.dtype))
return init_block
def _construct_final_block(self):
@@ -97,11 +100,13 @@ class RevNet(tf.keras.Model):
tf.keras.layers.BatchNormalization(
axis=self.axis,
input_shape=input_shape,
- fused=self.config.fused),
+ fused=self.config.fused,
+ dtype=self.config.dtype),
tf.keras.layers.Activation("relu"),
tf.keras.layers.GlobalAveragePooling2D(
- data_format=self.config.data_format),
- tf.keras.layers.Dense(self.config.n_classes)
+ data_format=self.config.data_format, dtype=self.config.dtype),
+ tf.keras.layers.Dense(
+ self.config.n_classes, dtype=self.config.dtype)
],
name="final")
return final_block
@@ -139,7 +144,8 @@ class RevNet(tf.keras.Model):
batch_norm_first=(i != 0), # Only skip on first block
data_format=self.config.data_format,
bottleneck=self.config.bottleneck,
- fused=self.config.fused)
+ fused=self.config.fused,
+ dtype=self.config.dtype)
block_list.append(rev_block)
# Precompute input shape for the next block
@@ -174,21 +180,30 @@ class RevNet(tf.keras.Model):
def compute_loss(self, logits, labels):
"""Compute cross entropy loss."""
- cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
- logits=logits, labels=labels)
+ if self.config.dtype == tf.float32 or self.config.dtype == tf.float16:
+ cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=logits, labels=labels)
+ else:
+ # `sparse_softmax_cross_entropy_with_logits` does not have a GPU kernel
+ # for float64, int32 pairs
+ labels = tf.one_hot(
+ labels, depth=self.config.n_classes, axis=1, dtype=self.config.dtype)
+ cross_ent = tf.nn.softmax_cross_entropy_with_logits(
+ logits=logits, labels=labels)
return tf.reduce_mean(cross_ent)
- def compute_gradients(self, inputs, labels, training=True):
+ def compute_gradients(self, inputs, labels, training=True, l2_reg=True):
"""Manually computes gradients.
- This method also SILENTLY updates the running averages of batch
- normalization when `training` is set to True.
+ When eager execution is enabled, this method also SILENTLY updates the
+ running averages of batch normalization when `training` is set to True.
Args:
inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
labels: One-hot labels for classification
training: Use the mini-batch stats in batch norm if set to True
+ l2_reg: Apply l2 regularization
Returns:
list of tuples each being (grad, var) for optimizer to use
@@ -205,7 +220,7 @@ class RevNet(tf.keras.Model):
# Manually backprop through last block
x = saved_hidden[-1]
with tf.GradientTape() as tape:
- x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed
+ x = tf.identity(x)
tape.watch(x)
# Running stats updated below
logits = self._final_block(x, training=training)
@@ -232,16 +247,17 @@ class RevNet(tf.keras.Model):
assert not saved_hidden # Cleared after backprop
with tf.GradientTape() as tape:
- x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed
+ x = tf.identity(x)
# Running stats updated below
y = self._init_block(x, training=training)
grads_all += tape.gradient(
- y, self._init_block.trainable_variables, output_gradients=[dy])
+ y, self._init_block.trainable_variables, output_gradients=dy)
vars_all += self._init_block.trainable_variables
# Apply weight decay
- grads_all = self._apply_weight_decay(grads_all, vars_all)
+ if l2_reg:
+ grads_all = self._apply_weight_decay(grads_all, vars_all)
return grads_all, vars_all, loss
@@ -254,6 +270,14 @@ class RevNet(tf.keras.Model):
]
def get_moving_stats(self):
+ """Get moving averages of batch normalization.
+
+ This is needed to avoid updating the running average twice in one iteration.
+
+ Returns:
+ A dictionary mapping variables for batch normalization moving averages
+ to their current values.
+ """
vars_and_vals = {}
def _is_moving_var(v):
@@ -266,5 +290,12 @@ class RevNet(tf.keras.Model):
return vars_and_vals
def restore_moving_stats(self, vars_and_vals):
+ """Restore moving averages of batch normalization.
+
+ This is needed to avoid updating the running average twice in one iteration.
+
+ Args:
+ vars_and_vals: The dictionary mapping variables to their previous values.
+ """
for var_, val in six.iteritems(vars_and_vals):
var_.assign(val)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index a5f240436a..b2ac4b67c9 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -22,6 +22,7 @@ import gc
import time
import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import blocks_test
from tensorflow.contrib.eager.python.examples.revnet import config as config_
from tensorflow.contrib.eager.python.examples.revnet import revnet
from tensorflow.python.client import device_lib
@@ -40,16 +41,18 @@ class RevNetTest(tf.test.TestCase):
def setUp(self):
super(RevNetTest, self).setUp()
- tf.set_random_seed(1)
- config = config_.get_hparams_imagenet_56()
+ config = config_.get_hparams_cifar_38()
+ # Reconstruction could cause numerical error, use double precision for tests
+ config.dtype = tf.float64
+ config.fused = False # Fused batch norm does not support tf.float64
shape = (config.batch_size,) + config.input_shape
self.model = revnet.RevNet(config=config)
- self.x = tf.random_normal(shape=shape)
+ self.x = tf.random_normal(shape=shape, dtype=tf.float64)
self.t = tf.random_uniform(
shape=[config.batch_size],
minval=0,
maxval=config.n_classes,
- dtype=tf.int32)
+ dtype=tf.int64)
self.config = config
def tearDown(self):
@@ -65,21 +68,51 @@ class RevNetTest(tf.test.TestCase):
y, _ = self.model(self.x, training=False)
self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes])
+ def _check_grad_angle_combined(self, grads, grads_true):
+ """Verify that the reconstructed gradients has correct direction.
+
+ Due to numerical imprecision, the magnitude may be slightly different.
+ Yet according to the paper, the angle should be roughly the same.
+
+ Args:
+ grads: list of gradients from reconstruction
+ grads_true: list of true gradients
+ """
+
+ def _combine(gs):
+ return [tf.reshape(g, [-1]) for g in gs]
+
+ g1_all = tf.concat(_combine(grads), axis=0)
+ g2_all = tf.concat(_combine(grads_true), axis=0)
+
+ self.assertEqual(len(g1_all.shape), 1)
+ self.assertEqual(len(g2_all.shape), 1)
+
+ degree = blocks_test.compute_degree(g1_all, g2_all)
+ self.assertLessEqual(degree, 1e0)
+
def test_compute_gradients(self):
"""Test `compute_gradients` function."""
-
- grads, vars_, _ = self.model.compute_gradients(
- inputs=self.x, labels=self.t, training=True)
+ self.model(self.x, training=False) # Initialize model
+ grads, vars_, loss = self.model.compute_gradients(
+ inputs=self.x, labels=self.t, training=True, l2_reg=True)
self.assertTrue(isinstance(grads, list))
self.assertTrue(isinstance(vars_, list))
self.assertEqual(len(grads), len(vars_))
for grad, var in zip(grads, vars_):
- if grad is not None:
- self.assertEqual(grad.shape, var.shape)
+ self.assertEqual(grad.shape, var.shape)
+
+ # Compare against the true gradient computed by the tape
+ with tf.GradientTape() as tape:
+ logits, _ = self.model(self.x, training=True)
+ loss_true = self.model.compute_loss(logits=logits, labels=self.t)
+ grads_true = tape.gradient(loss_true, vars_)
+ self.assertAllClose(loss, loss_true)
+ self.assertAllClose(grads, grads_true, rtol=1e-4, atol=1e-4)
+ self._check_grad_angle_combined(grads, grads_true)
def test_call_defun(self):
"""Test `call` function with defun."""
-
y, _ = tfe.defun(self.model.call)(self.x, training=False)
self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes])
@@ -96,8 +129,8 @@ class RevNetTest(tf.test.TestCase):
def test_training_graph(self):
"""Test model training in graph mode."""
-
with tf.Graph().as_default():
+ config = config_.get_hparams_cifar_38()
x = tf.random_normal(
shape=(self.config.batch_size,) + self.config.input_shape)
t = tf.random_uniform(
@@ -106,12 +139,14 @@ class RevNetTest(tf.test.TestCase):
maxval=self.config.n_classes,
dtype=tf.int32)
global_step = tfe.Variable(0., trainable=False)
- model = revnet.RevNet(config=self.config)
- grads_all, vars_all, _ = model.compute_gradients(x, t, training=True)
- optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
+ model = revnet.RevNet(config=config)
+ model(x)
updates = model.get_updates_for(x)
- self.assertEqual(len(updates), 192)
- with tf.control_dependencies(model.get_updates_for(x)):
+
+ x_ = tf.identity(x)
+ grads_all, vars_all, _ = model.compute_gradients(x_, t, training=True)
+ optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
+ with tf.control_dependencies(updates):
train_op = optimizer.apply_gradients(
zip(grads_all, vars_all), global_step=global_step)
diff --git a/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb
new file mode 100644
index 0000000000..3e7abe952d
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb
@@ -0,0 +1,282 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "TFE Workshop: control flow",
+ "version": "0.3.2",
+ "provenance": [],
+ "include_colab_link": true
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "[View in Colaboratory](https://colab.research.google.com/gist/alextp/664b2f8700485ff6801f4d26293bd567/tfe-workshop-control-flow.ipynb)"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "9BpQzh9BvJlj",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 37
+ },
+ "outputId": "0b336886-8204-4815-89fa-5291a49d5784"
+ },
+ "cell_type": "code",
+ "source": [
+ "import tensorflow as tf\n",
+ "import numpy as np\n",
+ "tf.enable_eager_execution()"
+ ],
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "0roIB19GvOjI",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Eager execution basics\n",
+ "\n",
+ "When eager execution is enabled TensorFlow immediately executes operations, and Tensors are always available. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "jeO8F-V-vN24",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "aeb3bdec-50b7-440d-93d8-5a171f091081"
+ },
+ "cell_type": "code",
+ "source": [
+ "t = tf.constant([[1, 2], [3, 4]])\n",
+ "t"
+ ],
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=0, shape=(2, 2), dtype=int32, numpy=\n",
+ "array([[1, 2],\n",
+ " [3, 4]], dtype=int32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 2
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Y17RwSFxvlDL",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "cfcc10c7-707b-4997-99b3-a5f382c5166b"
+ },
+ "cell_type": "code",
+ "source": [
+ "tf.matmul(t, t)"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=2, shape=(2, 2), dtype=int32, numpy=\n",
+ "array([[ 7, 10],\n",
+ " [15, 22]], dtype=int32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Dab1bS3TvmRE",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "8a624f3d-a658-4359-c586-1c5f6bf4c8b7"
+ },
+ "cell_type": "code",
+ "source": [
+ "# It's also possible to have Python control flow which depends on the value of tensors.\n",
+ "if t[0, 0] > 0.5:\n",
+ " print(\"T is bigger\")\n",
+ "else:\n",
+ " print(\"T is smaller\")"
+ ],
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "T is bigger\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "dPgptJcGwIon",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "c4f27f2b-0848-4475-dde5-2534dac65a5c"
+ },
+ "cell_type": "code",
+ "source": [
+ "# Tensors are also usable as numpy arrays\n",
+ "np.prod(t)"
+ ],
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "24"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 6
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "p3DTfQXnwXzj",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Exercise\n",
+ "\n",
+ "The algorithm for bisecting line search is a pretty simple way to find a zero of a continuous scalar function in an interval [a,b] where f(a) and f(b) have different signs. Simply evaluate f((a+b)/2), and narrow the interval by replacing either a or b with (a+b)/2 such that the function when applied on the boundary of the interval still has different signs.\n",
+ "\n",
+ "Implement a python function `bisecting_line_search(f, a, b, epsilon)` which returns a value such that `tf.abs(f(value)) < epsilon`.\n",
+ "\n",
+ "One thing to keep in mind: python's `==` opertor is not overloaded on Tensors, so you need to use `tf.equal` to compare for equality."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "6eq0YuI6ykm5",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "# Example test harness to get you going\n",
+ "\n",
+ "def test_f(x):\n",
+ " return x - 0.1234\n",
+ "def bisecting_line_search(f, a, b, epsilon):\n",
+ " # Return x such that f(x) <= epsilon.\n",
+ " pass\n",
+ "a = tf.constant(0.0)\n",
+ "b = tf.constant(1.0)\n",
+ "epsilon = tf.constant(0.001)\n",
+ "x = bisecting_line_search(test_f, a, b, epsilon)\n",
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "LcMmEfd_xvej",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 170
+ },
+ "outputId": "f402aa50-8ce3-4416-f755-8bbcd1af7809"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Double-click to see the solution\n",
+ "\n",
+ "def bisecting_line_search(f, a, b, epsilon):\n",
+ " f_a = f(a)\n",
+ " f_b = f(b)\n",
+ " probe = (a + b) / 2\n",
+ " f_probe = f(probe)\n",
+ " while tf.abs(f_probe) > epsilon:\n",
+ " if tf.equal(tf.sign(f_probe), tf.sign(f_a)):\n",
+ " a = probe\n",
+ " f_a = f_probe\n",
+ " else:\n",
+ " b = probe\n",
+ " f_b = f_probe\n",
+ " probe = (a + b) / 2\n",
+ " f_probe = f(probe)\n",
+ " print(\"new probe\", probe)\n",
+ " return probe\n",
+ "\n",
+ "bisecting_line_search(test_f, 0., 1., 0.001)"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "('new probe', 0.25)\n",
+ "('new probe', 0.125)\n",
+ "('new probe', 0.0625)\n",
+ "('new probe', 0.09375)\n",
+ "('new probe', 0.109375)\n",
+ "('new probe', 0.1171875)\n",
+ "('new probe', 0.12109375)\n",
+ "('new probe', 0.123046875)\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "0.123046875"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 8
+ }
+ ]
+ }
+ ]
+} \ No newline at end of file
diff --git a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
new file mode 100644
index 0000000000..4f1410e00b
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
@@ -0,0 +1,1018 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "TFE Workshop: Models.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "collapsed_sections": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "[View in Colaboratory](https://colab.research.google.com/gist/alextp/5cfcffd408bd5103f5ae747bc97ab0b5/tfe-workshop-models.ipynb)"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "BMxv1O6Q0SJL",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ },
+ "outputId": "8be9c556-ac7f-4142-e35e-19dc2b097121"
+ },
+ "cell_type": "code",
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "tfe = tf.contrib.eager"
+ ],
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "lE1vJhxp0WR9",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Variables\n",
+ "\n",
+ "TensorFlow variables are useful to store the state in your program. They are integrated with other parts of the API (taking gradients, checkpointing, graph functions)."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "C4ztQNgc0VpW",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "8b63ae1f-2670-49c0-a31b-8cf7fc4194a1"
+ },
+ "cell_type": "code",
+ "source": [
+ "# Creating variables\n",
+ "v = tfe.Variable(1.0)\n",
+ "v"
+ ],
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 2
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "H0daItGg1IAp",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "e47d5aab-16a1-4e29-c27d-7fbc0b94b5d3"
+ },
+ "cell_type": "code",
+ "source": [
+ "v.assign_add(1.0)\n",
+ "v"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "BJvBzcIG1hyK",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Layers: common sets of useful operations\n",
+ "\n",
+ "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n",
+ "\n",
+ "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n",
+ "\n",
+ "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "iSQTS3QW1YQQ",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ },
+ "outputId": "c5d8aa10-dcad-44f7-f0eb-0faf5249fd7e"
+ },
+ "cell_type": "code",
+ "source": [
+ "# In the tf.keras.layers package, layers are objects. To construct a layer,\n",
+ "# simply construct the object. Most layers take as a first argument the number\n",
+ "# of output dimensions / channels.\n",
+ "layer = tf.keras.layers.Dense(100)\n",
+ "\n",
+ "# The number of input dimensions is often unnecessary, as it can be inferred\n",
+ "# the first time the layer is used, but it can be provided if you want to \n",
+ "# specify it manually, which is useful in some complex models.\n",
+ "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))\n"
+ ],
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "nRuUogoS1liV",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "c352ce79-d519-45e4-a12e-1eaba76871a2"
+ },
+ "cell_type": "code",
+ "source": [
+ "layer(tf.zeros([2, 2]))"
+ ],
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=43, shape=(2, 10), dtype=float32, numpy=\n",
+ "array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 5
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "JH4Kf4ka1mht",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 136
+ },
+ "outputId": "c34e2378-f83d-42c5-d30a-ebe55620368a"
+ },
+ "cell_type": "code",
+ "source": [
+ "layer.variables"
+ ],
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[<tf.Variable 'dense/kernel:0' shape=(2, 10) dtype=float32, numpy=\n",
+ " array([[-0.42494273, -0.2067694 , 0.4519381 , 0.6842533 , 0.04131705,\n",
+ " 0.70547956, 0.4021917 , -0.5939298 , -0.5671462 , 0.5586321 ],\n",
+ " [ 0.3709975 , -0.64126074, -0.5386696 , -0.42212513, 0.6550072 ,\n",
+ " 0.70081085, 0.08859557, -0.30801034, -0.31450653, 0.02522504]],\n",
+ " dtype=float32)>,\n",
+ " <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>]"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 6
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "DSI4NF0_1vn-",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n",
+ "Conv2D, LSTM, BatchNormalization, Dropout, and many others."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "hMgDBftJ12Bp",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Models: composing layers\n",
+ "\n",
+ "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n",
+ "\n",
+ "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "K3gVY6gj1nbe",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 190
+ },
+ "outputId": "6e9be0c4-960e-46c2-cdd9-7e94ad09d46b"
+ },
+ "cell_type": "code",
+ "source": [
+ "class ResnetIdentityBlock(tf.keras.Model):\n",
+ " def __init__(self, kernel_size, filters):\n",
+ " super(ResnetIdentityBlock, self).__init__(name='')\n",
+ " filters1, filters2, filters3 = filters\n",
+ "\n",
+ " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n",
+ " self.bn2a = tf.keras.layers.BatchNormalization()\n",
+ "\n",
+ " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n",
+ " self.bn2b = tf.keras.layers.BatchNormalization()\n",
+ "\n",
+ " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n",
+ " self.bn2c = tf.keras.layers.BatchNormalization()\n",
+ "\n",
+ " def call(self, input_tensor, training=False):\n",
+ " x = self.conv2a(input_tensor)\n",
+ " x = self.bn2a(x, training=training)\n",
+ " x = tf.nn.relu(x)\n",
+ "\n",
+ " x = self.conv2b(x)\n",
+ " x = self.bn2b(x, training=training)\n",
+ " x = tf.nn.relu(x)\n",
+ "\n",
+ " x = self.conv2c(x)\n",
+ " x = self.bn2c(x, training=training)\n",
+ "\n",
+ " x += input_tensor\n",
+ " return tf.nn.relu(x)\n",
+ " \n",
+ "block = ResnetIdentityBlock(1, [1, 2, 3])\n",
+ "print(block(tf.zeros([1, 2, 3, 3])))\n",
+ "print([x.name for x in block.variables])"
+ ],
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "tf.Tensor(\n",
+ "[[[[0. 0. 0.]\n",
+ " [0. 0. 0.]\n",
+ " [0. 0. 0.]]\n",
+ "\n",
+ " [[0. 0. 0.]\n",
+ " [0. 0. 0.]\n",
+ " [0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n",
+ "['resnet_identity_block/conv2d/kernel:0', 'resnet_identity_block/conv2d/bias:0', 'resnet_identity_block/batch_normalization/gamma:0', 'resnet_identity_block/batch_normalization/beta:0', 'resnet_identity_block/conv2d_1/kernel:0', 'resnet_identity_block/conv2d_1/bias:0', 'resnet_identity_block/batch_normalization_1/gamma:0', 'resnet_identity_block/batch_normalization_1/beta:0', 'resnet_identity_block/conv2d_2/kernel:0', 'resnet_identity_block/conv2d_2/bias:0', 'resnet_identity_block/batch_normalization_2/gamma:0', 'resnet_identity_block/batch_normalization_2/beta:0', 'resnet_identity_block/batch_normalization/moving_mean:0', 'resnet_identity_block/batch_normalization/moving_variance:0', 'resnet_identity_block/batch_normalization_1/moving_mean:0', 'resnet_identity_block/batch_normalization_1/moving_variance:0', 'resnet_identity_block/batch_normalization_2/moving_mean:0', 'resnet_identity_block/batch_normalization_2/moving_variance:0']\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "LPXhHUIc1-sO",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "5pXgzNAU17xk",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 173
+ },
+ "outputId": "03b7eaf8-9b35-482b-bcf0-a99af6c2c6a4"
+ },
+ "cell_type": "code",
+ "source": [
+ " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n",
+ " tf.keras.layers.BatchNormalization(),\n",
+ " tf.keras.layers.Conv2D(2, 1, \n",
+ " padding='same'),\n",
+ " tf.keras.layers.BatchNormalization(),\n",
+ " tf.keras.layers.Conv2D(3, (1, 1)),\n",
+ " tf.keras.layers.BatchNormalization()])\n",
+ "my_seq(tf.zeros([1, 2, 3, 3]))\n"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=493, shape=(1, 2, 3, 3), dtype=float32, numpy=\n",
+ "array([[[[0., 0., 0.],\n",
+ " [0., 0., 0.],\n",
+ " [0., 0., 0.]],\n",
+ "\n",
+ " [[0., 0., 0.],\n",
+ " [0., 0., 0.],\n",
+ " [0., 0., 0.]]]], dtype=float32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 8
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "MZrns6p22GEQ",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Exercise!\n",
+ "\n",
+ "Make a simple convolutional neural network model, useful for things such as MNIST which don't need too many parameters. A sequence of two or three convolutions with small output channels (say, 32 and 64) plus one or two fully connected layers is probably enough.\n",
+ "\n",
+ "The input shape should be [batch_size, 28, 28, 1]."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "8CAUa3KNN916",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ },
+ "outputId": "97c0ff3c-c962-4c13-eee8-406101465761"
+ },
+ "cell_type": "code",
+ "source": [
+ "# TODO: Implement a convolutional model as described above, and assign it to\n",
+ "# model.\n",
+ "model = tf.keras.Sequential([\n",
+ " \n",
+ "])"
+ ],
+ "execution_count": 9,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "vLDDduR32E82",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "09bb1d43-b4c6-44b5-916e-0d2903d10cf4"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Click to see the answer\n",
+ "\n",
+ "max_pool = tf.keras.layers.MaxPooling2D(\n",
+ " (2, 2), (2, 2), padding='same')\n",
+ " # The model consists of a sequential chain of layers, so tf.keras.Sequential\n",
+ " # (a subclass of tf.keras.Model) makes for a compact description.\n",
+ "model = tf.keras.Sequential(\n",
+ " [\n",
+ " tf.keras.layers.Conv2D(\n",
+ " 32,\n",
+ " 5,\n",
+ " padding='same',\n",
+ " activation=tf.nn.relu),\n",
+ " max_pool,\n",
+ " tf.keras.layers.Conv2D(\n",
+ " 64,\n",
+ " 5,\n",
+ " padding='same',\n",
+ " activation=tf.nn.relu),\n",
+ " max_pool,\n",
+ " tf.keras.layers.Flatten(),\n",
+ " tf.keras.layers.Dense(1024, activation=tf.nn.relu),\n",
+ " tf.keras.layers.Dropout(0.4),\n",
+ " tf.keras.layers.Dense(10)\n",
+ " ])\n",
+ "\n",
+ "model(tf.zeros([1, 28, 28, 1]))"
+ ],
+ "execution_count": 10,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=625, shape=(1, 10), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 10
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "H_CKVBroik4M",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Stop here for now"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "_yRwuE6MMmzC",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Training\n",
+ "\n",
+ "When eager execution is enabled, you can write Pythonic training loops. Simply\n",
+ "\n",
+ "1. load your data into a `tf.data.Dataset`, which lets you construct functional pipelines for processing, shuffling, and batching your data,\n",
+ "2. iterate over the dataset using a Python `for` loop, and\n",
+ "3. perform an optimization step in the body of your `for` loop.\n",
+ "\n",
+ "This workflow is exemplified in the following exercise."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gj0-EkTc_Xt1",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "\n",
+ "## Exercise!\n",
+ "\n",
+ "In this exercise, you'll train the convolutional model you implemented for the previous exericse on the MNIST dataset. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "WOGm9HHn_byR",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ },
+ "outputId": "bbccc7ad-33cd-446e-bcda-f358c7547e1b"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Utilities for downloading MNIST data (double-click to show code)\n",
+ "import gzip\n",
+ "import os\n",
+ "import tempfile\n",
+ "from six.moves import urllib\n",
+ "import shutil\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "def read32(bytestream):\n",
+ " \"\"\"Read 4 bytes from bytestream as an unsigned 32-bit integer.\"\"\"\n",
+ " dt = np.dtype(np.uint32).newbyteorder('>')\n",
+ " return np.frombuffer(bytestream.read(4), dtype=dt)[0]\n",
+ "\n",
+ "\n",
+ "def check_image_file_header(filename):\n",
+ " \"\"\"Validate that filename corresponds to images for the MNIST dataset.\"\"\"\n",
+ " with tf.gfile.Open(filename, 'rb') as f:\n",
+ " magic = read32(f)\n",
+ " read32(f) # num_images, unused\n",
+ " rows = read32(f)\n",
+ " cols = read32(f)\n",
+ " if magic != 2051:\n",
+ " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n",
+ " f.name))\n",
+ " if rows != 28 or cols != 28:\n",
+ " raise ValueError(\n",
+ " 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %\n",
+ " (f.name, rows, cols))\n",
+ "\n",
+ "\n",
+ "def check_labels_file_header(filename):\n",
+ " \"\"\"Validate that filename corresponds to labels for the MNIST dataset.\"\"\"\n",
+ " with tf.gfile.Open(filename, 'rb') as f:\n",
+ " magic = read32(f)\n",
+ " read32(f) # num_items, unused\n",
+ " if magic != 2049:\n",
+ " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n",
+ " f.name))\n",
+ " \n",
+ "def download(directory, filename):\n",
+ " \"\"\"Download (and unzip) a file from the MNIST dataset if not already done.\"\"\"\n",
+ " filepath = os.path.join(directory, filename)\n",
+ " if tf.gfile.Exists(filepath):\n",
+ " return filepath\n",
+ " if not tf.gfile.Exists(directory):\n",
+ " tf.gfile.MakeDirs(directory)\n",
+ " # CVDF mirror of http://yann.lecun.com/exdb/mnist/\n",
+ " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
+ " _, zipped_filepath = tempfile.mkstemp(suffix='.gz')\n",
+ " print('Downloading %s to %s' % (url, zipped_filepath))\n",
+ " urllib.request.urlretrieve(url, zipped_filepath)\n",
+ " with gzip.open(zipped_filepath, 'rb') as f_in, \\\n",
+ " tf.gfile.Open(filepath, 'wb') as f_out:\n",
+ " shutil.copyfileobj(f_in, f_out)\n",
+ " os.remove(zipped_filepath)\n",
+ " return filepath\n",
+ "\n",
+ "\n",
+ "def dataset(directory, images_file, labels_file):\n",
+ " \"\"\"Download and parse MNIST dataset.\"\"\"\n",
+ "\n",
+ " images_file = download(directory, images_file)\n",
+ " labels_file = download(directory, labels_file)\n",
+ "\n",
+ " check_image_file_header(images_file)\n",
+ " check_labels_file_header(labels_file)\n",
+ "\n",
+ " def decode_image(image):\n",
+ " # Normalize from [0, 255] to [0.0, 1.0]\n",
+ " image = tf.decode_raw(image, tf.uint8)\n",
+ " image = tf.cast(image, tf.float32)\n",
+ " image = tf.reshape(image, [28, 28, 1])\n",
+ " return image / 255.0\n",
+ "\n",
+ " def decode_label(label):\n",
+ " label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]\n",
+ " label = tf.reshape(label, []) # label is a scalar\n",
+ " return tf.to_int32(label)\n",
+ "\n",
+ " images = tf.data.FixedLengthRecordDataset(\n",
+ " images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
+ " labels = tf.data.FixedLengthRecordDataset(\n",
+ " labels_file, 1, header_bytes=8).map(decode_label)\n",
+ " return tf.data.Dataset.zip((images, labels))\n",
+ "\n",
+ "\n",
+ "def get_training_data(directory):\n",
+ " \"\"\"tf.data.Dataset object for MNIST training data.\"\"\"\n",
+ " return dataset(directory, 'train-images-idx3-ubyte',\n",
+ " 'train-labels-idx1-ubyte').take(1024)\n",
+ "\n",
+ "def get_test_data(directory):\n",
+ " \"\"\"tf.data.Dataset object for MNIST test data.\"\"\"\n",
+ " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')"
+ ],
+ "execution_count": 11,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "4ejmJ2dv_f0R",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 85
+ },
+ "outputId": "274c0381-e505-4e69-f910-3def6f8572a7"
+ },
+ "cell_type": "code",
+ "source": [
+ "# Don't forget to run the cell above!\n",
+ "training_data = get_training_data(\"/tmp/mnist/train\")\n",
+ "test_data = get_test_data(\"/tmp/mnist/test\")"
+ ],
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/tmp4ull1xwa.gz\n",
+ "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/tmp1eikhj1v.gz\n",
+ "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/tmpcp8xah9c.gz\n",
+ "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/tmpqww_1e74.gz\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "TANpFS6GKLMC",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Fill in the implementation of `train_one_epoch` below and run the cell to train your model. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "btKL0Ss9_rmC",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 102
+ },
+ "outputId": "56858516-86fc-424a-f00d-6f088f98bf9b"
+ },
+ "cell_type": "code",
+ "source": [
+ "EPOCHS = 5\n",
+ "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n",
+ "\n",
+ "def loss_fn(logits, labels):\n",
+ " return tf.reduce_mean(\n",
+ " tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
+ " logits=tf.squeeze(logits), labels=labels))\n",
+ "\n",
+ "def train_one_epoch(model, training_data, optimizer):\n",
+ " # TODO: Implement an optimization step and return the average loss.\n",
+ " #\n",
+ " # Hint: Use `tf.GradientTape` to compute the gradient of the loss, and use\n",
+ " # `optimizer.apply_gradients` to update the model's variables, which are\n",
+ " # accessible as `model.variables`\n",
+ " average_loss = tfe.metrics.Mean('loss')\n",
+ " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n",
+ " pass\n",
+ " return average_loss.result()\n",
+ "\n",
+ "for epoch in range(EPOCHS):\n",
+ " loss = train_one_epoch(model, training_data, optimizer)\n",
+ " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))"
+ ],
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Average loss after epoch 0: 2.2847\n",
+ "Average loss after epoch 1: 2.2305\n",
+ "Average loss after epoch 2: 2.1334\n",
+ "Average loss after epoch 3: 1.9115\n",
+ "Average loss after epoch 4: 1.4285\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "yAOFupJN_htg",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 102
+ },
+ "outputId": "67e711e4-76c9-4e3f-bb49-a14955dba03a"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Double-click to see a solution.\n",
+ "EPOCHS = 5\n",
+ "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n",
+ "\n",
+ "def _loss_fn(logits, labels):\n",
+ " return tf.reduce_mean(\n",
+ " tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
+ " logits=tf.squeeze(logits), labels=labels))\n",
+ "\n",
+ "def _train_one_epoch(model, training_data):\n",
+ " average_loss = tfe.metrics.Mean(\"loss\")\n",
+ " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n",
+ " with tf.GradientTape() as tape:\n",
+ " logits = model(images, training=True)\n",
+ " loss = _loss_fn(logits, labels)\n",
+ " average_loss(loss)\n",
+ " gradients = tape.gradient(loss, model.variables)\n",
+ " optimizer.apply_gradients(zip(gradients, model.variables))\n",
+ " return average_loss.result()\n",
+ " \n",
+ "for epoch in range(EPOCHS):\n",
+ " loss = _train_one_epoch(model, training_data)\n",
+ " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))"
+ ],
+ "execution_count": 15,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Average loss after epoch 0: 1.0563\n",
+ "Average loss after epoch 1: 0.8013\n",
+ "Average loss after epoch 2: 0.6306\n",
+ "Average loss after epoch 3: 0.5543\n",
+ "Average loss after epoch 4: 0.5037\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "uDy1DrYA_2Jz",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Run the below cell to qualitatively evaluate your model. Note how eager execution interoperates seamlessly with `matplotlib`."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "vR7rMtpu_3nB",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1752
+ },
+ "outputId": "b212aefa-f4b3-425c-f34d-2491429fa521"
+ },
+ "cell_type": "code",
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "sampled_data = test_data.batch(1).shuffle(buffer_size=10000).take(5)\n",
+ "for image, label in sampled_data:\n",
+ " plt.figure()\n",
+ " plt.imshow(tf.reshape(image, (28, 28)))\n",
+ " plt.show()\n",
+ " logits = model(image, training=False)\n",
+ " prediction = tf.argmax(logits, axis=1, output_type=tf.int64)\n",
+ " print(\"Prediction: %d\" % prediction)"
+ ],
+ "execution_count": 16,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEwpJREFUeJzt3X1Ilff/x/HXmScxV2GZOmLVohXK\nKmLQjbUsy+pbI7rbaEm1IFhRSU1aE+kO3LqxCGrBMlsNkq0zZIM2Cu1mUTg1itXQbVnBQqKZNtcN\n2d3J3x9ffpLrNN/ndM65jn6fj7/m5cfrvI9XPHedc7zOcTU3NzcLAPCvXnJ6AABoD4glABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAG7kB/cOPGjbpw4YJcLpdyc3M1ZMiQYM4FABEloFieOXNGV69e\nlcfj0ZUrV5SbmyuPxxPs2QAgYgT0MLy8vFwZGRmSpP79++vWrVu6e/duUAcDgEgSUCwbGhrUvXv3\nlq979Oih+vr6oA0FAJEmKC/w8F4cADq6gGKZmJiohoaGlq9v3LihhISEoA0FAJEmoFiOHj1aJSUl\nkqTq6molJiaqS5cuQR0MACJJQK+Gv/nmm3rjjTf03nvvyeVyaf369cGeCwAiios3/wWAtnEFDwAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMHA7\nPQAQiAcPHpjX3rlzx+f2nj17qqGhodW2kydPmvb566+/mm//xx9/NK+13r4kjRgx4pltFRUVGjly\nZKttP/30k3mfL73E+dPz8JsBAIOAziwrKyu1YsUKDRgwQJI0cOBArV27NqiDAUAkCfhh+PDhw7Vz\n585gzgIAEYuH4QBgEHAsL1++rCVLlmju3LkqKysL5kwAEHFczc3Nzf7+UF1dnc6dO6cpU6aotrZW\nCxYsUGlpqaKjo0MxIwA4LqDnLJOSkjR16lRJUp8+fdSzZ0/V1dWpd+/eQR0OeB7+dIg/HQq3gH4z\nhw4d0hdffCFJqq+v182bN5WUlBTUwQAgkgR0Zjl+/HitWrVKx48f16NHj7RhwwYeggPo0AKKZZcu\nXbR79+5gzwIAESugF3gAf1RVVZnXfvfdd6Z1hw8fNu/zzJkzPrd7vV5FRUWZ99Me+LpPDx8+NP98\nR/t9BBPP5gKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAM+3RGtPO/qV5fL\n1ep7BQUF5n1mZWWZ1z558sS8NhRcLpdpnT9vZebPJYT9+vUzry0pKfG5/Y8//mj1NW+7Fhz8FgHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgCt40MrBgwd9bp87d26r7y1btsy8z1de\necW89q233jKte//99837/Dfff/99q68TExNNP/fqq6+ab8Of+x8MvXv3Duvt/a/gzBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4mp/3CVXoMB49emRe+/rrr/vcfvXqVfXt\n27fl68zMTPM+P/74Y/PauLg481ognDizBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABny6YztVX19vXjthwgTz2oEDB5q+l5eXZ96n223/Z/b48WPTuuvXr5v3efz4cZ/bFy5c\nqC+//NK8n0CNHTvWvLZfv34hnAQvwnRmWVNTo4yMDBUVFUn67z/U+fPnKzMzUytWrNDDhw9DOiQA\nOK3NWN67d095eXlKTU1t2bZz505lZmbqq6++Ut++fVVcXBzSIQHAaW3GMjo6WoWFha0+fL6ysrLl\noV16errKy8tDNyEARIA2n0xyu93PPOfU1NSk6OhoSVJ8fLxfz58BQHv0wi/w8HaYzkhISDCv/eWX\nX4Jym0ePHg3Kfv6N9cWg3r17m/e5cOHCgL4HPC2gWMbGxur+/fuKiYlRXV1dq4foCI9QvRqelJTk\nc/vRo0c1ceLElq+PHDli3ievhvNqeEcQ0N9Zjho1SiUlJZKk0tJSjRkzJqhDAUCkafN/+VVVVdqy\nZYuuXbsmt9utkpISbdu2TTk5OfJ4POrVq5dmzJgRjlkBwDFtxnLQoEE6cODAM9v3798fkoEAIBLx\ngWXt1A8//GBeO3v2bPPa572Ik5aWplOnTrV8ff78efM+J02aZF5rnfX333837/N5vF6voqKiAvrZ\nd99917x20KBB5rWrVq0yr42JiTGvxYvj2nAAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGDA5Y7tlD+X23377bcvfHv/vDTQn7cS8+ft1NLS0kzr/Ln/o0aN8rk9OTn5mcsmO3Xq\nZNrn7du3zbc/YsQI89q9e/ea1y5YsMC8Fi+OM0sAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyI\nJQAYEEsAMCCWAGDQ5kfhIjItXrzYvHb06NHmtRcvXnzu9z744IOW//bnUruhQ4ea11ovN3S7g/NP\nNzk5OaCfe/qTLtvi9XrNa/351E4udwwvziwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIAreNqpjIyMkKz9N59//nlQ9tMRPHjwwOkREGacWQKAAbEEAANiCQAGxBIADIglABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADEyxrKmpUUZGhoqKiiRJOTk5mjZtmubPn6/58+fr5MmT\noZwRABzX5rsO3bt3T3l5eUpNTW21PTs7W+np6SEbDAAiSZtnltHR0SosLFRiYmI45gGAiNTmmaXb\n7Zbb/eyyoqIi7d+/X/Hx8Vq7dq169OgRkgGBSDRx4kTzWq/XG8JJEC4Bvfnv9OnTFRcXp5SUFO3Z\ns0e7du3SunXrgj0bELGOHj1qXvuf//zHvHb27Nnmtd988415LV5cQK+Gp6amKiUlRZI0fvx41dTU\nBHUoAIg0AcUyKytLtbW1kqTKykoNGDAgqEMBQKRp82F4VVWVtmzZomvXrsntdqukpETz5s3TypUr\n1blzZ8XGxmrTpk3hmBUAHNNmLAcNGqQDBw48s33y5MkhGQgAIhGf7ggEgAsx/vdwuSMAGBBLADAg\nlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgckcgAKdPnw7JfqdNmxaS/eLFcWYJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAZcwQM85dSpU6Z1P//8s3mfL7/8snntuHHj\nzGsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAAy53RIf3999/+9we\nFxf3zPcyMjJM+/R6vebbP3jwoHlt7969zWsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAAy53DIMnT56Y1+bm5prWbdiwwbzPmJgY89r24u7du+a1b7/9ts/tZWVlz3zP\nehnjO++8Y7792bNnm9cicplimZ+fr3Pnzunx48davHixBg8erNWrV8vr9SohIUFbt25VdHR0qGcF\nAMe0GcuKigpdunRJHo9HjY2NmjlzplJTU5WZmakpU6Zo+/btKi4uVmZmZjjmBQBHtPmc5bBhw7Rj\nxw5JUrdu3dTU1KTKykpNmDBBkpSenq7y8vLQTgkADmszllFRUYqNjZUkFRcXKy0tTU1NTS0Pu+Pj\n41VfXx/aKQHAYeYXeI4dO6bi4mLt27dPkyZNatne3NwcksE6kpdesv/RwebNm0M4ScfRpUsX89qy\nsrKAvgc8zRTL06dPa/fu3dq7d6+6du2q2NhY3b9/XzExMaqrq1NiYmKo52zXeDU8+Px5NXzy5Mk+\nt5eVlWn06NGttlVUVJj26c+r4V9//bV5rT//Y0V4tXlk7ty5o/z8fBUUFCguLk6SNGrUKJWUlEiS\nSktLNWbMmNBOCQAOa/PM8vDhw2psbNTKlStbtm3evFlr1qyRx+NRr169NGPGjJAOCQBOazOWc+bM\n0Zw5c57Zvn///pAMBACRyNXMKzQh58+HW1n/uP/TTz817zM7Ozvotx8qv/32m2nd0qVLzfs8deqU\nz+1er1dRUVHm/TyturravDY5OTmg20Bk4dlkADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgwOWOYeDP5Y4JCQmmdbdu3TLvc+LEiea148aN87k9Jycn4PfavH//vnntJ598Ylrn\nzz/bbt26+dze2Nio7t27t9p28eJF0z6tx0mSXC6XeS0iF2eWAGBALAHAgFgCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgMsdI0xxcbFp3bJly8z7bGhoCHScFi/ySYj++Oflh88zefJk8z4/\n+ugjn9uHDh2q8+fPP7MN8IUzSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAw4Aqe\ndqqmpsa8Njs727z2yJEjPre/yBU8q1evNq8dPHiwaV1mZmZAswCB4swSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYcLkjABi4LYvy8/N17tw5PX78WIsXL9aJEydUXV2tuLg4\nSdKiRYs0bty4UM4JAI5qM5YVFRW6dOmSPB6PGhsbNXPmTI0cOVLZ2dlKT08Px4wA4Lg2Yzls2DAN\nGTJEktStWzc1NTXJ6/WGfDAAiCR+PWfp8Xh09uxZRUVFqb6+Xo8ePVJ8fLzWrl2rHj16hHJOAHCU\nOZbHjh1TQUGB9u3bp6qqKsXFxSklJUV79uzRn3/+qXXr1oV6VgBwjOlPh06fPq3du3ersLBQXbt2\nVWpqqlJSUiRJ48eP9+uNaAGgPWozlnfu3FF+fr4KCgpaXv3OyspSbW2tJKmyslIDBgwI7ZQA4LA2\nX+A5fPiwGhsbtXLlypZts2bN0sqVK9W5c2fFxsZq06ZNIR0SAJzGH6UDgAGXOwKAAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4nbjRjRs3\n6sKFC3K5XMrNzdWQIUOcGCOoKisrtWLFCg0YMECSNHDgQK1du9bhqQJXU1OjpUuXauHChZo3b56u\nX7+u1atXy+v1KiEhQVu3blV0dLTTY/rln/cpJydH1dXViouLkyQtWrRI48aNc3ZIP+Xn5+vcuXN6\n/PixFi9erMGDB7f74yQ9e79OnDjh+LEKeyzPnDmjq1evyuPx6MqVK8rNzZXH4wn3GCExfPhw7dy5\n0+kxXti9e/eUl5en1NTUlm07d+5UZmampkyZou3bt6u4uFiZmZkOTukfX/dJkrKzs5Wenu7QVC+m\noqJCly5dksfjUWNjo2bOnKnU1NR2fZwk3/dr5MiRjh+rsD8MLy8vV0ZGhiSpf//+unXrlu7evRvu\nMfAvoqOjVVhYqMTExJZtlZWVmjBhgiQpPT1d5eXlTo0XEF/3qb0bNmyYduzYIUnq1q2bmpqa2v1x\nknzfL6/X6/BUDsSyoaFB3bt3b/m6R48eqq+vD/cYIXH58mUtWbJEc+fOVVlZmdPjBMztdismJqbV\ntqamppaHc/Hx8e3umPm6T5JUVFSkBQsW6MMPP9Rff/3lwGSBi4qKUmxsrCSpuLhYaWlp7f44Sb7v\nV1RUlOPHypHnLJ/W3Nzs9AhB8dprr2n58uWaMmWKamtrtWDBApWWlrbL54va0lGO2fTp0xUXF6eU\nlBTt2bNHu3bt0rp165wey2/Hjh1TcXGx9u3bp0mTJrVsb+/H6en7VVVV5fixCvuZZWJiohoaGlq+\nvnHjhhISEsI9RtAlJSVp6tSpcrlc6tOnj3r27Km6ujqnxwqa2NhY3b9/X5JUV1fXIR7OpqamKiUl\nRZI0fvx41dTUODyR/06fPq3du3ersLBQXbt27TDH6Z/3KxKOVdhjOXr0aJWUlEiSqqurlZiYqC5d\nuoR7jKA7dOiQvvjiC0lSfX29bt68qaSkJIenCp5Ro0a1HLfS0lKNGTPG4YleXFZWlmprayX99znZ\n//9Lhvbizp07ys/PV0FBQcurxB3hOPm6X5FwrFzNDpyrb9u2TWfPnpXL5dL69euVnJwc7hGC7u7d\nu1q1apVu376tR48eafny5Ro7dqzTYwWkqqpKW7Zs0bVr1+R2u5WUlKRt27YpJydHDx48UK9evbRp\n0yZ16tTJ6VHNfN2nefPmac+ePercubNiY2O1adMmxcfHOz2qmcfj0WeffaZ+/fq1bNu8ebPWrFnT\nbo+T5Pt+zZo1S0VFRY4eK0diCQDtDVfwAIABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwOD/\nAKCzFeFbFn4BAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd61cfd1e80>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 5\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEQ1JREFUeJzt3W9Ilff/x/HXSSd2VmKaRwiqjTBy\nq9gfap2iliaFQfRvsCXW1rpRRJGTCJG0MSHLIpbF8M9qN3L7cjZvNQiOVAQt7LQcBLqB1Y0QaXYs\naUa2mZ3fjS9ff7Vcvj2ec65jez7ueZ1P57wPlzy7Li8vjysUCoUEAHihcU4PAABjAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcAgMdx/uH//fl27dk0ul0ulpaWaO3duJOcCgLgSViyvXLmiW7du\nyefz6ebNmyotLZXP54v0bAAQN8I6DW9ublZeXp4kacaMGbp//74ePHgQ0cEAIJ6EFcvu7m5NmjRp\n8Ou0tDQFg8GIDQUA8SYiF3j4WxwAXnZhxdLj8ai7u3vw6zt37igjIyNiQwFAvAkrlosWLZLf75ck\ntbW1yePxaMKECREdDADiSVhXw9955x29+eab+uijj+RyubRv375IzwUAccXFH/8FgOFxBw8AGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMEp0eAIgnP/30k2nd+vXrzc+Zl5dnXvvtt9+a1yK2OLIEAANiCQAGxBIADIgl\nABgQSwAwIJYAYEAsAcCAWAKAAbEEAAPu4AGecuzYMdO6YDBofk6XyxXuOIgjHFkCgEFYR5aBQEC7\ndu1SVlaWJGnmzJkqKyuL6GAAEE/CPg2fP3++qqurIzkLAMQtTsMBwCDsWN64cUPbtm3Thg0bdOnS\npUjOBABxxxUKhUIj/UddXV1qaWlRfn6+Ojo6tGnTJjU1NSkpKSkaMwKA48L6mWVmZqZWrlwpSZo2\nbZomT56srq4uTZ06NaLDAbH24Ycfmtb98MMP5ucsKCgwr21oaDCvRWyFdRp++vRpnThxQtJ/f9/s\n7t27yszMjOhgABBPwjqyzM3N1e7du3Xu3Dn19/fr888/5xQcwEstrFhOmDBBNTU1kZ4FAOIWtzsC\nT7lw4ULEn3PVqlURf07EHr9nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDLjdES89v98/5PYVK1Y899hIPrXRqre3N+LPidjjyBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLiDB2NSKBQyr21oaBhy+4oVK/7xsUh6++23o/4aiD6OLAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgIErNJL7xoA40dnZaV47derUIbc/efJE48aFd7zw7rvv\nmtf+/PPPYb0G4gtHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBPd8SY\nVFlZ6ejrb9682dHXR+yZjizb29uVl5c3+LGht2/f1saNG1VQUKBdu3bpr7/+iuqQAOC0YWP58OFD\nVVRUyOv1Dm6rrq5WQUGBvvvuO02fPl2NjY1RHRIAnDZsLJOSklRfXy+PxzO4LRAIaNmyZZKknJwc\nNTc3R29CAIgDw/7MMjExUYmJzy7r6+tTUlKSJCk9PV3BYDA60wFAnBj1BR7+HCaccPz48YisffLk\nSSTGwb9AWLF0u9169OiRkpOT1dXV9cwpOhALO3bsMK/96quvhtw+mj/+O5JYb9++PazXQHwJ6ztl\n4cKF8vv9kqSmpiYtXrw4okMBQLwZ9siytbVVBw8eVGdnpxITE+X3+3X48GGVlJTI5/NpypQpWrNm\nTSxmBQDHDBvL2bNn69SpU89t/+abb6IyEADEI+7gQVyxXnCJ1oeAWX/+XlhYGJXXR/zi3nAAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDA7Y6IKxUVFaZ10brd8dVXXzWt6+3t\nNT9nSkpKuOMgjnBkCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADLjdEXHl\nyy+/dPT1BwYGTOv8fr/5OT/99NNwx0Ec4cgSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAy4gwdR99tvv5nXjuSDwKzcbrf5sV9++cX0nGlpaaOaCWMPR5YAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYEAsAcCA2x0RFusHe0kj+xCyJ0+ehDPOC507d878GLcx4p9wZAkA\nBqZYtre3Ky8vTw0NDZKkkpISrVq1Shs3btTGjRt14cKFaM4IAI4b9jT84cOHqqiokNfrfWZ7cXGx\ncnJyojYYAMSTYY8sk5KSVF9fL4/HE4t5ACAuuUKhUMiy8NixY5o0aZIKCwtVUlKiYDCo/v5+paen\nq6ysjB+MA3iphXU1fPXq1UpNTVV2drbq6up0/PhxlZeXR3o2xLGRXA3fvn27eW19fX0447xQc3Pz\nkNvfe+89BQKB57YBQwnrarjX61V2drYkKTc3V+3t7REdCgDiTVix3Llzpzo6OiRJgUBAWVlZER0K\nAOLNsKfhra2tOnjwoDo7O5WYmCi/36/CwkIVFRVp/PjxcrvdqqysjMWsAOCYYWM5e/ZsnTp16rnt\nK1asiMpAABCPzFfDgafdu3fPvHby5MkRf/0PPvjAvPY///nPkNsTEhKeu1CVkJAwqrnw8uJ2RwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMCnO+IZ//TpiuPGjXvmsc2bN0fl\n9V0ul2ndF198YX7OF93CyO2NsOLIEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\nuIMHz/jf58H/3fTp05957Mcff4zK6xcWFprWzZo1KyqvD/wTjiwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABtzviGRcuXBhy+8cff/zMY6FQKCqvX15eHpXnBUaLI0sAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDgCkXrvjXEjV9//dW8ds6cOUNuHxgY\nUEJCwuDXI/m2Wb9+vXmtz+czrRs3jv/nEVume8OrqqrU0tKix48fa+vWrZozZ4727NmjgYEBZWRk\n6NChQ0pKSor2rADgmGFjefnyZV2/fl0+n089PT1au3atvF6vCgoKlJ+fryNHjqixsVEFBQWxmBcA\nHDHsucy8efN09OhRSVJKSor6+voUCAS0bNkySVJOTo6am5ujOyUAOGzYWCYkJMjtdkuSGhsbtWTJ\nEvX19Q2edqenpysYDEZ3SgBwmPnvWZ49e1aNjY06efKkli9fPrid60Px74033jCvHRgYCOsx4GVn\niuXFixdVU1Ojr7/+WhMnTpTb7dajR4+UnJysrq4ueTyeaM+JUeBqODB6w37H9fb2qqqqSrW1tUpN\nTZUkLVy4UH6/X5LU1NSkxYsXR3dKAHDYsEeWZ86cUU9Pj4qKiga3HThwQHv37pXP59OUKVO0Zs2a\nqA4JAE7jl9L/BTgNB0aPDyz7F7AGSHpxBJ9+LCUlxfycJ06cMK8lgohXfGcCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADbnf8F7hx44Z5rfV2x+TkZPNzjuTWSCBecWQJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMuN3xX6C4uNi89vvvv//HxxIT///b\n5a233hrVTMBYw5ElABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4Qi/6hCoAgCSO\nLAHAhFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAAD06c7VlVVqaWlRY8f\nP9bWrVt1/vx5tbW1KTU1VZK0ZcsWLV26NJpzAoCjho3l5cuXdf36dfl8PvX09Gjt2rVasGCBiouL\nlZOTE4sZAcBxw8Zy3rx5mjt3riQpJSVFfX19GhgYiPpgABBPRvQn2nw+n65evaqEhAQFg0H19/cr\nPT1dZWVlSktLi+acAOAocyzPnj2r2tpanTx5Uq2trUpNTVV2drbq6ur0+++/q7y8PNqzAoBjTFfD\nL168qJqaGtXX12vixInyer3Kzs6WJOXm5qq9vT2qQwKA04aNZW9vr6qqqlRbWzt49Xvnzp3q6OiQ\nJAUCAWVlZUV3SgBw2LAXeM6cOaOenh4VFRUNblu3bp2Kioo0fvx4ud1uVVZWRnVIAHAan8EDAAbc\nwQMABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGCQ6MSL7t+/X9euXZPL5VJpaanmzp3rxBgRFQgEtGvXLmVlZUmSZs6cqbKyMoenCl97e7u2\nb9+uTz75RIWFhbp9+7b27NmjgYEBZWRk6NChQ0pKSnJ6zBH5+3sqKSlRW1ubUlNTJUlbtmzR0qVL\nnR1yhKqqqtTS0qLHjx9r69atmjNnzpjfT9Lz7+v8+fOO76uYx/LKlSu6deuWfD6fbt68qdLSUvl8\nvliPERXz589XdXW102OM2sOHD1VRUSGv1zu4rbq6WgUFBcrPz9eRI0fU2NiogoICB6ccmaHekyQV\nFxcrJyfHoalG5/Lly7p+/bp8Pp96enq0du1aeb3eMb2fpKHf14IFCxzfVzE/DW9ublZeXp4kacaM\nGbp//74ePHgQ6zHwAklJSaqvr5fH4xncFggEtGzZMklSTk6OmpubnRovLEO9p7Fu3rx5Onr0qCQp\nJSVFfX19Y34/SUO/r4GBAYenciCW3d3dmjRp0uDXaWlpCgaDsR4jKm7cuKFt27Zpw4YNunTpktPj\nhC0xMVHJycnPbOvr6xs8nUtPTx9z+2yo9yRJDQ0N2rRpkz777DPdu3fPgcnCl5CQILfbLUlqbGzU\nkiVLxvx+koZ+XwkJCY7vK0d+Zvm0UCjk9AgR8dprr2nHjh3Kz89XR0eHNm3apKampjH586LhvCz7\nbPXq1UpNTVV2drbq6up0/PhxlZeXOz3WiJ09e1aNjY06efKkli9fPrh9rO+np99Xa2ur4/sq5keW\nHo9H3d3dg1/fuXNHGRkZsR4j4jIzM7Vy5Uq5XC5NmzZNkydPVldXl9NjRYzb7dajR48kSV1dXS/F\n6azX61V2drYkKTc3V+3t7Q5PNHIXL15UTU2N6uvrNXHixJdmP/39fcXDvop5LBctWiS/3y9Jamtr\nk8fj0YQJE2I9RsSdPn1aJ06ckCQFg0HdvXtXmZmZDk8VOQsXLhzcb01NTVq8eLHDE43ezp071dHR\nIem/P5P9328yjBW9vb2qqqpSbW3t4FXil2E/DfW+4mFfuUIOHKsfPnxYV69elcvl0r59+zRr1qxY\njxBxDx480O7du/XHH3+ov79fO3bs0Pvvv+/0WGFpbW3VwYMH1dnZqcTERGVmZurw4cMqKSnRn3/+\nqSlTpqiyslKvvPKK06OaDfWeCgsLVVdXp/Hjx8vtdquyslLp6elOj2rm8/l07Ngxvf7664PbDhw4\noL17947Z/SQN/b7WrVunhoYGR/eVI7EEgLGGO3gAwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAg\nlgBg8H/nb4OLnfGqVAAAAABJRU5ErkJggg==\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd61bade5c0>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 1\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAE1ZJREFUeJzt3X1olfX/x/HXccc1DyrLuY1GaRGL\nRqZSaE7zZmqKgnhDsVwqkYGRE29QW8tp4M102solNJ03fzSqgyPoBmFDIlg1Jw0xNsrZDbKGranD\nG5x3x33/+NF+rp153js751znrOfjv13n43Xex4NPrrPL61yujo6ODgEA7muA0wMAQCwglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAMiCUAGLiD/YM7duzQ6dOn5XK5lJ+fr9GjR4dyLgCIKkHF8uTJkzp3\n7py8Xq9+++035efny+v1hno2AIgaQX0Mr6mp0cyZMyVJjz/+uC5fvqxr166FdDAAiCZBxfLChQt6\n8MEHO38eNmyYWltbQzYUAESbkJzg4bs4APR3QcUyJSVFFy5c6Pz577//VnJycsiGAoBoE1QsJ02a\npMrKSklSQ0ODUlJSNHjw4JAOBgDRJKiz4c8884yeeuopvfzyy3K5XNqyZUuo5wKAqOLiy38BIDCu\n4AEAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJ\nAAbEEgAMiCUAGLiD+UO1tbVavXq10tPTJUlPPPGECgoKQjoYAESToGIpSePHj1dJSUkoZwGAqMXH\ncAAwCDqWv/76q9544w0tXrxY33//fShnAoCo4+ro6Ojo7R9qaWlRXV2d5syZo6amJi1btkxVVVWK\nj48Px4wA4LigjixTU1M1d+5cuVwujRgxQsOHD1dLS0uoZwOAqBFULL/88ksdOnRIktTa2qqLFy8q\nNTU1pIMBQDQJ6mP4tWvXtH79el25ckW3b99Wbm6upk6dGo75ACAqBBVLAPivCfr/WQL90alTp0zr\nSktLzfssKysLdpz78nec09HRIZfL1WVbbm6ueZ+9+b/T/36e/o7/ZwkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAy4Nhz93tmzZ/1uT09P7/bY4sWLTfu0XhYZaT6fT3FxcUH/\n+Vu3bpnX9uV5YhFHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgwA3LEHa9uUjs\nzJkzpnXz588377Opqcnv9uvXr2vMmDFdtt28edO8Xyu32/7PrKCgwLw2Pj7e7/bCwsIuPz/77LPm\nfQ4YwPFTT/ibAQADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtywDEG5ffu2\nee1bb71lXrt3795gxgmKv5t7PfTQQ6Y/u3r1avPzLF++3Lz2yJEj5rW5ubndtj3wwAPdLtl84IEH\nzPtEzziyBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzdEV3cvXvX7/YB\nAwZ0eSwvL8+8z0hewujPokWLzI999NFHpn16PB7z8y9evNi89uuvvzavbW5u7ratuLhYb7/9drdt\n6DvTkWVjY6Nmzpyp8vJySdL58+e1dOlS5eTkaPXq1bp161ZYhwQApwWM5fXr17V161ZlZmZ2bisp\nKVFOTo4++eQTjRw5UhUVFWEdEgCcFjCW8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpqa8E0IAFEg\n4O8s3W633O6uy9rb2xUfHy9JSkpKUmtra3imA4Ao0ecTPHwdZv8yYEDPHzbufey9994z77M3ayPt\n6NGjYX+OL774IuzPcS9O6IRHULH0eDy6ceOGEhIS1NLS0uUjOmKb9Wz4hg0bzPv84IMP+jxXX/R0\nNvzo0aN66aWXumyLpbPh/r6AuLi4WOvWreu2DX0X1P+znDhxoiorKyVJVVVVmjx5ckiHAoBoE/DI\nsr6+Xrt27VJzc7PcbrcqKyu1Z88e5eXlyev1Ki0tTQsWLIjErADgmICxHDVqlD7++ONu23tzrxAA\niHVcwfMf8Ndff5nXzpo1y+/2n376SWPHju38uaGhoc9z+TN06FDTutLSUvM+X3zxxR4f++yzz7r8\nfL8TXPf69NNPzc/fm99D9kZaWlqvtqNvuDYcAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYuDr4QsqYdPXqVfPaUaNGmdf++eeffrf7fD7FxcWZ93Ovf75V3+LQoUOmdY888khQ\nswRivUVKdnZ2WJ7/ny/Vtjh16lS3bU8++aR++eWXbtvQdxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA+7uGKPKy8vNa3u6hLEvlixZYl67Z88e89rk5GTTupaWFvM+X3/9\ndb/bv/rqK82bN6/LtsrKSvN+w6E3d43s6TJGLm8MD44sAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCAG5ZFmbt375rWvfDCC+Z9fvvtt+a1Pd0wq729XYMGDer8ubGx0bzPtLQ089qf\nf/7ZtG7Dhg3mfVZVVfnd3pebsIXLjRs3zGsHDhwYxknwbxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA25YFmWsV5/25hLG3vD5fKbHiouLzfv8448/zGu/+uor89pYsWDB\nAvPaaLv8Ev+PI0sAMDDFsrGxUTNnzuy8/WpeXp7mzZunpUuXaunSpWE7ygGAaBHwY/j169e1detW\nZWZmdtm+bt06ZWVlhW0wAIgmAY8s4+PjVVZWppSUlEjMAwBRKeCRpdvtltvdfVl5ebmOHDmipKQk\nFRQUaNiwYWEZ8L/G+gv++52ICZdbt25F/DnDzYm/R8SmoM6Gz58/X4mJicrIyNCBAwe0b98+bd68\nOdSz/SdZ//H29CW9fdVTrG/dutXlOVeuXGneZ7SeDY/Ul//25mz40aNHzWsHDOD8bCQF9bedmZmp\njIwMSdL06dN79a3ZABCLgorlqlWr1NTUJEmqra1Venp6SIcCgGgT8GN4fX29du3apebmZrndblVW\nVmrJkiVas2aNBg0aJI/Ho8LCwkjMCgCOCRjLUaNG6eOPP+62ffbs2WEZCACiEZc7ogvr5Y4lJSWR\nGKdf6M0JHk7aRC/eGQAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMDljlHG\nernbsWPHzPvszeV24fiC3958MfT69etN6/Lz84MdJyS2bdtmXvvKK6+EcRJECkeWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAFTxRxuVymdb15u6ap06dMq+9dOlSj49VV1eb93Ov\nsWPHmtfW1dUF9RyhMmbMGNO6lStXmvfJTcj6B95FADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg4Oro6Ohwegj0b21tbea1kyZNMq07c+ZMsON08vl8iouL67Lthx9+MP3Z5557\nrs/Pj9jCkSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDg7o4Iu5MnT5rX\nhuIyxn/Ly8szPzZ+/PiQPz/6B1Msi4qKVFdXpzt37mjFihV6+umntXHjRvl8PiUnJ2v37t2Kj48P\n96wA4JiAsTxx4oTOnj0rr9ertrY2LVy4UJmZmcrJydGcOXNUXFysiooK5eTkRGJeAHBEwN9Zjhs3\nTnv37pUkDR06VO3t7aqtrdWMGTMkSVlZWaqpqQnvlADgsICxjIuLk8fjkSRVVFRoypQpam9v7/zY\nnZSUpNbW1vBOCQAOM5/gOX78uCoqKnT48GHNmjWrcztfh4lAZs+ebV7r8/nCOEl327dvj+jzIXaZ\nYlldXa3S0lIdPHhQQ4YMkcfj0Y0bN5SQkKCWlhalpKSEe07EsMrKSvPauXPnhvz5ezobvn37dr3z\nzjtdtm3bts20T5fL1ee5EFsCfgy/evWqioqKtH//fiUmJkqSJk6c2PkPoKqqSpMnTw7vlADgsIBH\nlseOHVNbW5vWrFnTuW3nzp3atGmTvF6v0tLStGDBgrAOCQBOCxjL7OxsZWdnd9t+5MiRsAwEANGI\nG5YhKL25CVlGRoZ5bTj+Z8Xvv//ud/vIkSN17ty5btsAf7g2HAAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGHDDMgSlrKzMvDYclzDm5uaa16alpQX1GHAvjiwBwIBYAoABsQQA\nA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABlzuiizt37vjd7na7uzz2+eefh+X5V61aZVr3\n/vvvm/fpcrl6fGzgwIHm/eC/jSNLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADBw\ndXR0dDg9BKLHd99953f7888/3+WxqVOnmvf58MMPm9eeOXPGtC4hIcG8TyAUOLIEAANiCQAGxBIA\nDIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG3LAMXQwZMiSox+5ny5Yt5rVcxohoZYplUVGR\n6urqdOfOHa1YsULffPONGhoalJiYKElavny5pk2bFs45AcBRAWN54sQJnT17Vl6vV21tbVq4cKEm\nTJigdevWKSsrKxIzAoDjAsZy3LhxGj16tCRp6NCham9vl8/nC/tgABBNAp7giYuLk8fjkSRVVFRo\nypQpiouLU3l5uZYtW6a1a9fq0qVLYR8UAJxk/j7L48ePa//+/Tp8+LDq6+uVmJiojIwMHThwQH/9\n9Zc2b94c7lkBwDGmEzzV1dUqLS3VwYMHNWTIEGVmZnY+Nn36dL377rvhmg8Rdvr0ab/bx4wZ0+Wx\nZ555xrzPsrIy89rXXnvNvBaIpIAfw69evaqioiLt37+/8+z3qlWr1NTUJEmqra1Venp6eKcEAIcF\nPLI8duyY2tratGbNms5tixYt0po1azRo0CB5PB4VFhaGdUgAcFrAWGZnZys7O7vb9oULF4ZlIACI\nRlzuCAAG3N0RAAw4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQA\nA2IJAAbEEgAM3E486Y4dO3T69Gm5XC7l5+dr9OjRTowRUrW1tVq9erXS09MlSU888YQKCgocnip4\njY2NevPNN/Xqq69qyZIlOn/+vDZu3Cifz6fk5GTt3r1b8fHxTo/ZK/9+TXl5eWpoaFBiYqIkafny\n5Zo2bZqzQ/ZSUVGR6urqdOfOHa1YsUJPP/10zL9PUvfX9c033zj+XkU8lidPntS5c+fk9Xr122+/\nKT8/X16vN9JjhMX48eNVUlLi9Bh9dv36dW3dulWZmZmd20pKSpSTk6M5c+aouLhYFRUVysnJcXDK\n3vH3miRp3bp1ysrKcmiqvjlx4oTOnj0rr9ertrY2LVy4UJmZmTH9Pkn+X9eECRMcf68i/jG8pqZG\nM2fOlCQ9/vjjunz5sq5duxbpMXAf8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpoap8YLir/XFOvG\njRunvXv3SpKGDh2q9vb2mH+fJP+vy+fzOTyVA7G8cOGCHnzwwc6fhw0bptbW1kiPERa//vqr3njj\nDS1evFjff/+90+MEze12KyEhocu29vb2zo9zSUlJMfee+XtNklReXq5ly5Zp7dq1unTpkgOTBS8u\nLk4ej0eSVFFRoSlTpsT8+yT5f11xcXGOv1eO/M7yXh0dHU6PEBKPPvqocnNzNWfOHDU1NWnZsmWq\nqqqKyd8XBdJf3rP58+crMTFRGRkZOnDggPbt26fNmzc7PVavHT9+XBUVFTp8+LBmzZrVuT3W36d7\nX1d9fb3j71XEjyxTUlJ04cKFzp///vtvJScnR3qMkEtNTdXcuXPlcrk0YsQIDR8+XC0tLU6PFTIe\nj0c3btyQJLW0tPSLj7OZmZnKyMiQJE2fPl2NjY0OT9R71dXVKi0tVVlZmYYMGdJv3qd/v65oeK8i\nHstJkyapsrJSktTQ0KCUlBQNHjw40mOE3JdffqlDhw5JklpbW3Xx4kWlpqY6PFXoTJw4sfN9q6qq\n0uTJkx2eqO9WrVqlpqYmSf/3O9l//idDrLh69aqKioq0f//+zrPE/eF98ve6ouG9cnU4cKy+Z88e\n/fjjj3K5XNqyZYuefPLJSI8QcteuXdP69et15coV3b59W7m5uZo6darTYwWlvr5eu3btUnNzs9xu\nt1JTU7Vnzx7l5eXp5s2bSktLU2FhoQYOHOj0qGb+XtOSJUt04MABDRo0SB6PR4WFhUpKSnJ6VDOv\n16sPP/xQjz32WOe2nTt3atOmTTH7Pkn+X9eiRYtUXl7u6HvlSCwBINZwBQ8AGBBLADAglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAM/gepgR0uaefKmwAAAABJRU5ErkJggg==\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd6199ef278>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 4\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEelJREFUeJzt3W9MlfX/x/HXEWJyhg5BIG1ZfR0u\nKr3hhopOE2Q23FxiN0xCdNmGa5pG6hhTtNn8g85NtI0/aS1Z29moG96wILM2dYDKDRu0hrpyzCkC\nkUocDeH8brQfk8R4czyH64DPx624+Hid99nFnl2H61wHl8/n8wkA8J/GOD0AAIwExBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAINwf//h7t27denSJblcLhUUFGjGjBmBnAsAQopfsTx//ryuXbsm\nj8ejq1evqqCgQB6PJ9CzAUDI8OtleE1NjdLT0yVJU6dO1e3bt9XZ2RnQwQAglPgVy7a2Nk2YMKHv\n65iYGLW2tgZsKAAINQG5wMNncQAY7fyKZXx8vNra2vq+vnXrluLi4gI2FACEGr9iOW/ePFVVVUmS\nGhsbFR8fr6ioqIAOBgChxK+r4TNnztSrr76qt99+Wy6XSzt27Aj0XAAQUlx8+C8ADI47eADAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAz8+lO4AJz3yy+/PLLtlVdeeWT777//bt7ne++9Z147f/58\n0zqPx2PeZyjjzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4fD6fz+kh\ngNHsr7/+Mq+tr683r33rrbce2dba2qq4uLh+29rb2837XL16tXntp59+alrndrvN+wxlnFkCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF/sAzww/37981rMzMzzWtPnTplXvu4O2O8\nXm+/rysrK837XLJkiXnt2LFjzWtHA84sAcDArzPLuro6bdy4UYmJiZKkadOmafv27QEdDABCid8v\nw2fNmqXi4uJAzgIAIYuX4QBg4Hcsr1y5onXr1mnlypU6d+5cIGcCgJDj1+dZtrS0qL6+XhkZGWpu\nblZOTo6qq6sVERERjBkBwHF+/c4yISGh7y0GU6ZM0cSJE9XS0qLnn38+oMMBoWoobx1aunSpee2T\nvnWos7NTUVFR/bZ9+eWX5n3y1qHH8+tl+IkTJ3T06FFJ/3wyc3t7uxISEgI6GACEEr/OLNPS0rR5\n82b98MMP6u7u1s6dO3kJDmBU8yuWUVFRKikpCfQsABCyuN0ReIj1vcNbtmwx77O7u9u8dii/9//x\nxx8H3P7zzz/3+/p///ufeZ94PN5nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG\nxBIADPz6PEvAaT09Pea1x48fH3D7mjVr9MUXX/Tblpuba9pnb2+v+fE/+eQT89qcnBzz2kmTJpnX\n4slxZgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzBgxHpcXflDGT16tUDbu/t\n7dWYMf6dL+zcudO8trCw0K/HQGjhzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDIglABhwuyNCSnFxsWndRx99ZN7n4/642UC3O77zzjumff77D539l7CwMPNahC7OLAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAG3OyLovF6vee2kSZNM6+7cuePvOH0Gut2x\npqbG9G9nz579xI+PkcV0ZtnU1KT09HRVVFRIkm7cuKFVq1YpKytLGzdu1N9//x3UIQHAaYPGsqur\nS7t27VJKSkrftuLiYmVlZemrr77SCy+8oMrKyqAOCQBOGzSWERERKi8vV3x8fN+2uro6LVq0SJKU\nmppqfukCACNV+KALwsMVHt5/mdfrVUREhCQpNjZWra2twZkOAELEoLEcDNeHMJjIyEjz2j///DOI\nkzyqt7d3WB8PI5dfsXS73bp3757Gjh2rlpaWfi/RgX/jajhGA7/eZzl37lxVVVVJkqqrqzV//vyA\nDgUAoWbQM8uGhgbt27dP169fV3h4uKqqqnTgwAHl5+fL4/Fo8uTJWrZs2XDMCgCO4U3pCDpehmM0\neOILPHg6ffvtt+a1hw4dMq8NRASfRElJiWkdsXz6cG84ABgQSwAwIJYAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAw4HZH+MV6W6D0zydTWU2ZMsW07v79++Z9trS0mNcCj8OZJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMOB2R/Rz4cKFAbcnJyf3+15tbW1QHv/77783\nrRvKX4FMTk72dxygD2eWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAHTzoZ8GC\nBQNu93q9/b43lD8YNhTWP1jm9XqD8vjA43BmCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLjd8Slw5MgR89r/uo3R31scZ8yYYV7rcrn8eoxAuXnzpmldV1eXeZ9ut9vfcRBC\nOLMEAANTLJuampSenq6KigpJUn5+vpYuXapVq1Zp1apV+umnn4I5IwA4btCX4V1dXdq1a5dSUlL6\nbc/Ly1NqamrQBgOAUDLomWVERITKy8sVHx8/HPMAQEhy+Xw+n2Xh4cOHNWHCBGVnZys/P1+tra3q\n7u5WbGystm/frpiYmGDPCgCO8etq+Jtvvqno6GglJSWprKxMR44cUWFhYaBnQ4AM5Wr4Bx98MOD2\n3t5ejRnj3/XAoVwNP3/+vGndUK5GP+5/5AM9pzfeeMO0z6+//tr8+FwNHx38+ulPSUlRUlKSJCkt\nLU1NTU0BHQoAQo1fsdywYYOam5slSXV1dUpMTAzoUAAQagZ9Gd7Q0KB9+/bp+vXrCg8PV1VVlbKz\ns7Vp0yZFRkbK7XZrz549wzErADhm0Fi+9tprOn78+CPbrb/bAYDRgNsdnwLt7e2OPv6WLVvMayMi\nIkzrhnKBZyiqqqpM63799VfzPmfOnOnvOAgh3O4IAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMuN0RfomNjTWvTU5ODvjjnz17NuD7lNT30YODee6554Ly+AhdnFkCgAGxBAAD\nYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF38MAv48ePN6999tlnA/74FRUVAd+nJM2aNcu0\nLiEhISiPj9DFmSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgdkf45bff\nfjOv/eabb8xrs7OzTet6e3vN+/T5fH59D3gYZ5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCA2x0RdO+++25Q1lq5XC6/vgc8zBTLoqIi1dfX68GDB8rNzdX06dO1detW9fT0\nKC4uTvv371dERESwZwUAxwway9raWl2+fFkej0cdHR3KzMxUSkqKsrKylJGRoYMHD6qyslJZWVnD\nMS8AOGLQ31kmJyfr0KFDkqTx48fL6/Wqrq5OixYtkiSlpqaqpqYmuFMCgMMGjWVYWJjcbrckqbKy\nUgsWLJDX6+172R0bG6vW1tbgTgkADjNf4Dl16pQqKyt17NgxLV68uG87nwcY+nbs2BGQtUP5DMmR\nYjQ+JwSHKZZnzpxRSUmJPvvsM40bN05ut1v37t3T2LFj1dLSovj4+GDPiSfw8ccfP/Ha3t5ejRkz\nut5pNtBzWr16tenffv7558EYCSFs0J/+u3fvqqioSKWlpYqOjpYkzZ07V1VVVZKk6upqzZ8/P7hT\nAoDDBj2zPHnypDo6OrRp06a+bXv37tW2bdvk8Xg0efJkLVu2LKhDAoDTBo3lihUrtGLFike28zIE\nwNOEO3ieAnl5eea1Fy5ceOz3lixZ0vffZ8+eNe/zzp075rVAqBpdv7EHgCAhlgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg4PLxgZTww3fffWde+/Btkk543I+4z+d75A+W1dbWmvY5\ne/bsJ54LIwtnlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIDbHQHAgDNL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAg3DLoqKiItXX1+vBgwfKzc3V6dOn1djYqOjoaEnS\n2rVrtXDhwmDOCQCOGjSWtbW1unz5sjwejzo6OpSZmak5c+YoLy9PqampwzEjADhu0FgmJydrxowZ\nkqTx48fL6/Wqp6cn6IMBQChx+Xw+n3Wxx+PRxYsXFRYWptbWVnV3dys2Nlbbt29XTExMMOcEAEeZ\nY3nq1CmVlpbq2LFjamhoUHR0tJKSklRWVqabN2+qsLAw2LMCgGNMV8PPnDmjkpISlZeXa9y4cUpJ\nSVFSUpIkKS0tTU1NTUEdEgCcNmgs7969q6KiIpWWlvZd/d6wYYOam5slSXV1dUpMTAzulADgsEEv\n8Jw8eVIdHR3atGlT37bly5dr06ZNioyMlNvt1p49e4I6JAA4bUgXeADgacUdPABgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbhTjzo7t27\ndenSJblcLhUUFGjGjBlOjBFQdXV12rhxoxITEyVJ06ZN0/bt2x2eyn9NTU16//33tWbNGmVnZ+vG\njRvaunWrenp6FBcXp/379ysiIsLpMYfk388pPz9fjY2Nio6OliStXbtWCxcudHbIISoqKlJ9fb0e\nPHig3NxcTZ8+fcQfJ+nR53X69GnHj9Wwx/L8+fO6du2aPB6Prl69qoKCAnk8nuEeIyhmzZql4uJi\np8d4Yl1dXdq1a5dSUlL6thUXFysrK0sZGRk6ePCgKisrlZWV5eCUQzPQc5KkvLw8paamOjTVk6mt\nrdXly5fl8XjU0dGhzMxMpaSkjOjjJA38vObMmeP4sRr2l+E1NTVKT0+XJE2dOlW3b99WZ2fncI+B\n/xAREaHy8nLFx8f3baurq9OiRYskSampqaqpqXFqPL8M9JxGuuTkZB06dEiSNH78eHm93hF/nKSB\nn1dPT4/DUzkQy7a2Nk2YMKHv65iYGLW2tg73GEFx5coVrVu3TitXrtS5c+ecHsdv4eHhGjt2bL9t\nXq+37+VcbGzsiDtmAz0nSaqoqFBOTo4+/PBD/fHHHw5M5r+wsDC53W5JUmVlpRYsWDDij5M08PMK\nCwtz/Fg58jvLh/l8PqdHCIgXX3xR69evV0ZGhpqbm5WTk6Pq6uoR+fuiwYyWY/bmm28qOjpaSUlJ\nKisr05EjR1RYWOj0WEN26tQpVVZW6tixY1q8eHHf9pF+nB5+Xg0NDY4fq2E/s4yPj1dbW1vf17du\n3VJcXNxwjxFwCQkJWrJkiVwul6ZMmaKJEyeqpaXF6bECxu126969e5KklpaWUfFyNiUlRUlJSZKk\ntLQ0NTU1OTzR0J05c0YlJSUqLy/XuHHjRs1x+vfzCoVjNeyxnDdvnqqqqiRJjY2Nio+PV1RU1HCP\nEXAnTpzQ0aNHJUmtra1qb29XQkKCw1MFzty5c/uOW3V1tebPn+/wRE9uw4YNam5ulvTP72T//50M\nI8Xdu3dVVFSk0tLSvqvEo+E4DfS8QuFYuXwOnKsfOHBAFy9elMvl0o4dO/Tyyy8P9wgB19nZqc2b\nN+vOnTvq7u7W+vXr9frrrzs9ll8aGhq0b98+Xb9+XeHh4UpISNCBAweUn5+v+/fva/LkydqzZ4+e\neeYZp0c1G+g5ZWdnq6ysTJGRkXK73dqzZ49iY2OdHtXM4/Ho8OHDeumll/q27d27V9u2bRuxx0ka\n+HktX75cFRUVjh4rR2IJACMNd/AAgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA4P8ALqDX\nN3rmU3AAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd62944c6d8>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 1\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEqVJREFUeJzt3W9Ilff/x/HX+eWkpMQ0dQRrZdgm\nq24Miiz6Y0nrFKPVjZqiMgiW/SMX0ZxlDYJMiyALZrnqRlKc4a1u5B9cjIWZUbDA7ljWQqJMm1iR\nbSbne2P8/H7NY77P8Ryvoz0f97y8us777BpPrnMuP+e4vF6vVwCAd/o/pwcAgNGAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYBAR6D88dOiQbt++LZfLpYKCAs2dOzeYcwFAWAkoljdu3NDDhw/l\n8XjU0tKigoICeTyeYM8GAGEjoJfhDQ0NSk9PlyTNnDlTXV1devnyZVAHA4BwElAsOzo6NHny5L6f\nY2Nj1d7eHrShACDcBOUGD5/FAWCsCyiWCQkJ6ujo6Pv56dOnio+PD9pQABBuAorlokWLVFNTI0m6\nc+eOEhISNHHixKAOBgDhJKC74Z9//rk+++wzff3113K5XDpw4ECw5wKAsOLiw38BYGis4AEAA2IJ\nAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGBBLADAI6KtwgVC5ePGiab+9e/eaj/ngwQOf271er1wul/k4gWpp\naTHvm5SUFMJJMBxcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAOWOyIg\n9+/fD8lxMzMzTfutWrXKfMzBljv6MmPGjKAf88mTJ+Z9We4YvriyBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAgFgCgAGxBAADVvAgIOnp6eZ9/VntYrV06VLzvh6PZ9DfdXV19fs5OjradMwtW7aY\nH3/27NnmfRG+uLIEAIOAriwbGxu1c+dOJScnS5JmzZqlwsLCoA4GAOEk4Jfh8+fPV2lpaTBnAYCw\nxctwADAIOJb37t1Tbm6uMjIyVF9fH8yZACDsuLxer9fff9TW1qZbt27J7XartbVVOTk5qq2tVWRk\nZChmBADHBfSeZWJiolavXi1JmjZtmqZMmaK2tjZ99NFHQR0O4cufD6kNxZ8OFRUVmffdunWrz+3R\n0dF6/vz5gG0W/vzpUHFxsXlf6+Nj5AX0MvzSpUs6c+aMJKm9vV3Pnj1TYmJiUAcDgHAS0JXl8uXL\ntXv3bv3666/q6enRjz/+yEtwAGNaQLGcOHGiysrKgj0LAIStgG7wYHR5+325d9m4caPP7VVVVXK7\n3X0/V1dXD3suX6zvRebn54fk8YHB8HeWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgOWO7wF/Pk5ssDX/Xq9XLpcroMf35+PUWMaIcMWVJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYsIJnlLp27Zp530WLFg378d5ewXPhwgXzv83IyBj24wNO48oSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYRDg9APp7/vy5ab9gLGH0JTc31/Q7ljDi\nfcOVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMODbHcOM2+027VddXW0+\n5qpVq8z7ejwen9ujo6P7LcWMjo42HxMYC0xXls3NzUpPT1dFRYUk6fHjx8rOzlZmZqZ27typf/75\nJ6RDAoDThozlq1evdPDgQaWmpvZtKy0tVWZmpi5cuKCPP/5YlZWVIR0SAJw2ZCwjIyNVXl6uhISE\nvm2NjY1asWKFJCktLU0NDQ2hmxAAwsCQH9EWERGhiIj+u3V3dysyMlKSFBcXp/b29tBMBwBhYtif\nZ8n9oeCqqqpyeoRBcVMH77OAYhkVFaXXr19r/Pjxamtr6/cSHcPD3XAgPAX0d5YLFy5UTU2NJKm2\ntlaLFy8O6lAAEG6GvLJsampScXGxHj16pIiICNXU1Ojo0aPKz8+Xx+PR1KlT9dVXX43ErADgmCFj\nOXv2bJ0/f37A9nPnzoVkIAAIR6zgGQH379837ztz5sygP35LS4t536SkpKA/PjAWsDYcAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYDPvzLDG0I0eOBP2Yubm55n1ZwggMH1eW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgOWOI6Cmpibox8zOzg76Mceq\nwb5dMykpacDvrEtT//zzT/PjT58+3byvP/+vfPLJJwO2VVVVye1299uWk5NjPuaaNWvM+0ZHR5v3\nHQu4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA5fX6/U6PcRY588Xhj148MC0\nX0tLS0ge32kXL1407bd3717zMQf7b+r1euVyuczHGQ2G+5xWrVpl3tfj8Zj2GysrfbiyBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABix3HAFbtmwx71tWVmbabzSdtlAs9wyG\n4SwN9GdZYHV1dUCPEYiRXMJpXXI7mpbbvgtXlgBgYIplc3Oz0tPTVVFRIUnKz8/Xl19+qezsbGVn\nZ+u3334L5YwA4Lghvzf81atXOnjwoFJTU/tt37Vrl9LS0kI2GACEkyGvLCMjI1VeXq6EhISRmAcA\nwpL5Bs+JEyc0efJkZWVlKT8/X+3t7erp6VFcXJwKCwsVGxsb6lkBwDFDvgz3Ze3atYqJiVFKSopO\nnz6tkydPav/+/cGebczgbjh3w0cKd8NDJ6C74ampqUpJSZEkLV++XM3NzUEdCgDCTUCx3LFjh1pb\nWyVJjY2NSk5ODupQABBuhnwZ3tTUpOLiYj169EgRERGqqalRVlaW8vLyNGHCBEVFRamoqGgkZgUA\nxwwZy9mzZ+v8+fMDtn/xxRchGQgAwlFAN3gAt9tt3jcUN238eTWzYcOGQX/39k2KKVOmBDzTYEL1\n7YbPnz/3ub2rq6vfz99//735mNYbjJK0bds2035VVVXmY4YzljsCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADljuOUteuXTPvu3DhwmEfd+HChf1+F6rPaKyvrzft589zepfR\n/FmLgy2jfHv7Tz/9ZD6mP8sd3zdcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nASt4RkBxcbF535qaGtN+WVlZ5mP+8ccf5n19fZOn9O+KmcF+NxR/vlwsWCtz8F/+rPbyR2FhYUiO\nG664sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAYur9frdXoI/Jd1adqi\nRYtCPEl/Xq9XLpcroH/b1dVl3newL+HCQBcvXhywLSMjY8D2zMxM8zEvXLhg3nfNmjWm/cbKOeXK\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGLDccZTy5xv7grE0cjjLHf35\ndseHDx+a9svOzjYf88MPP/S5PSkpSffv3++37ZdffjEdc8mSJebH98fBgwfN+1ZXVw/YNpzzJEn1\n9fXmfd+3b+I0fRVuSUmJbt26pTdv3mjz5s2aM2eO9uzZo97eXsXHx+vIkSOKjIwM9awA4JghY3n9\n+nXdvXtXHo9HnZ2dWrdunVJTU5WZmSm3261jx46psrLSr8X6ADDaDPme5bx583T8+HFJ/356SHd3\ntxobG7VixQpJUlpamhoaGkI7JQA4bMhYjhs3TlFRUZKkyspKLVmyRN3d3X0vu+Pi4tTe3h7aKQHA\nYab3LCWprq5OlZWVOnv2rFauXNm3nftDzvDnzfVgnaOxeK6TkpL6/Zyfn+/QJP+qqqoa9jHG4nkK\nB6ZYXr16VWVlZfr55581adIkRUVF6fXr1xo/frza2tqUkJAQ6jnxFu6Gczecu+Eja8iX4S9evFBJ\nSYlOnTqlmJgYSf/+R6qpqZEk1dbWavHixaGdEgAcNuSV5eXLl9XZ2am8vLy+bYcPH9a+ffvk8Xg0\ndepUffXVVyEdEgCcNmQsN27cqI0bNw7Yfu7cuZAMBADhiBU874G335d7l23btvncXlVVJbfb3fez\nr/fLRpvhvr/ntBkzZgzYdv/+/QE3rerq6szHnDJlinnfsfJFZFasDQcAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAYsd0RA/PmIuPPnz5v3tX702u+//24+5g8//OBzu6/ljr6W\nEPry7bffmh9/w4YN5n398fayRoQWV5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYA\nYEAsAcCA5Y4AYMCVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgEGEZaeSkhLdunVLb9680ebN\nm3XlyhXduXNHMTExkqRNmzZp2bJloZwTABw1ZCyvX7+uu3fvyuPxqLOzU+vWrdOCBQu0a9cupaWl\njcSMAOC4IWM5b948zZ07V5IUHR2t7u5u9fb2hnwwAAgnLq/X67Xu7PF4dPPmTY0bN07t7e3q6elR\nXFycCgsLFRsbG8o5AcBR5ljW1dXp1KlTOnv2rJqamhQTE6OUlBSdPn1aT5480f79+0M9KwA4xnQ3\n/OrVqyorK1N5ebkmTZqk1NRUpaSkSJKWL1+u5ubmkA4JAE4bMpYvXrxQSUmJTp061Xf3e8eOHWpt\nbZUkNTY2Kjk5ObRTAoDDhrzBc/nyZXV2diovL69v2/r165WXl6cJEyYoKipKRUVFIR0SAJzm1w0e\nAHhfsYIHAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAIMKJBz106JBu374tl8ulgoICzZ0714kxgqqxsVE7d+5UcnKyJGnWrFkqLCx0eKrA\nNTc3a+vWrfrmm2+UlZWlx48fa8+ePert7VV8fLyOHDmiyMhIp8f0y9vPKT8/X3fu3FFMTIwkadOm\nTVq2bJmzQ/qppKREt27d0ps3b7R582bNmTNn1J8naeDzunLliuPnasRjeePGDT18+FAej0ctLS0q\nKCiQx+MZ6TFCYv78+SotLXV6jGF79eqVDh48qNTU1L5tpaWlyszMlNvt1rFjx1RZWanMzEwHp/SP\nr+ckSbt27VJaWppDUw3P9evXdffuXXk8HnV2dmrdunVKTU0d1edJ8v28FixY4Pi5GvGX4Q0NDUpP\nT5ckzZw5U11dXXr58uVIj4F3iIyMVHl5uRISEvq2NTY2asWKFZKktLQ0NTQ0ODVeQHw9p9Fu3rx5\nOn78uCQpOjpa3d3do/48Sb6fV29vr8NTORDLjo4OTZ48ue/n2NhYtbe3j/QYIXHv3j3l5uYqIyND\n9fX1To8TsIiICI0fP77ftu7u7r6Xc3FxcaPunPl6TpJUUVGhnJwcfffdd/rrr78cmCxw48aNU1RU\nlCSpsrJSS5YsGfXnSfL9vMaNG+f4uXLkPcv/5fV6nR4hKKZPn67t27fL7XartbVVOTk5qq2tHZXv\nFw1lrJyztWvXKiYmRikpKTp9+rROnjyp/fv3Oz2W3+rq6lRZWamzZ89q5cqVfdtH+3n63+fV1NTk\n+Lka8SvLhIQEdXR09P389OlTxcfHj/QYQZeYmKjVq1fL5XJp2rRpmjJlitra2pweK2iioqL0+vVr\nSVJbW9uYeDmbmpqqlJQUSdLy5cvV3Nzs8ET+u3r1qsrKylReXq5JkyaNmfP09vMKh3M14rFctGiR\nampqJEl37txRQkKCJk6cONJjBN2lS5d05swZSVJ7e7uePXumxMREh6cKnoULF/adt9raWi1evNjh\niYZvx44dam1tlfTve7L//5cMo8WLFy9UUlKiU6dO9d0lHgvnydfzCodz5fI6cK1+9OhR3bx5Uy6X\nSwcOHNCnn3460iME3cuXL7V79249f/5cPT092r59u5YuXer0WAFpampScXGxHj16pIiICCUmJuro\n0aPKz8/X33//ralTp6qoqEgffPCB06Oa+XpOWVlZOn36tCZMmKCoqCgVFRUpLi7O6VHNPB6PTpw4\noRkzZvRtO3z4sPbt2zdqz5Pk+3mtX79eFRUVjp4rR2IJAKMNK3gAwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg8B9OkjtgR8VvdgAAAABJRU5ErkJggg==\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd619a40b00>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 6\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "4SJizeJtNaAs",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Profiling\n",
+ "\n",
+ "If you want to drill down into the performance characteristics of your code, you can use native Python profilers like [`cProfile`](https://docs.python.org/3/library/profile.html). In the next exercise, you'll do just that."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "_2v0QnG8__PJ",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Exercise!\n",
+ "\n",
+ "This exercise does not require coding. If you have not completed the training exercise, replace `train_one_epoch` below with `_train_one_epoch`.\n",
+ "\n",
+ "Run the below cell and inspect the printed profiles. What parts of the code appear to be hotspots or\n",
+ "bottlenecks? How does sorting the profile by total time compare to sorting it\n",
+ "by cumulative time?\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "IFypaYbG_9fB",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 714
+ },
+ "outputId": "d9c3596b-a165-4edd-fc6b-53ccd0d01d19"
+ },
+ "cell_type": "code",
+ "source": [
+ "import cProfile\n",
+ "import pstats\n",
+ "\n",
+ "cProfile.run(\"train_one_epoch(model, training_data, optimizer)\", \"training_profile\")\n",
+ "\n",
+ "stats = pstats.Stats(\"training_profile\").strip_dirs().sort_stats(\"tottime\")\n",
+ "stats.print_stats(10)\n",
+ "\n",
+ "stats.sort_stats(\"cumtime\").print_stats(10)"
+ ],
+ "execution_count": 17,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Thu Jun 7 12:25:04 2018 training_profile\n",
+ "\n",
+ " 92209 function calls (91817 primitive calls) in 3.446 seconds\n",
+ "\n",
+ " Ordered by: internal time\n",
+ " List reduced from 672 to 10 due to restriction <10>\n",
+ "\n",
+ " ncalls tottime percall cumtime percall filename:lineno(function)\n",
+ " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n",
+ " 83 0.753 0.009 0.753 0.009 {built-in method _pywrap_tensorflow_internal.TFE_Py_Execute}\n",
+ " 16 0.006 0.000 1.019 0.064 network.py:736(_run_internal_graph)\n",
+ " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n",
+ " 2321 0.004 0.000 0.007 0.000 abc.py:178(__instancecheck__)\n",
+ " 288 0.004 0.000 0.009 0.000 inspect.py:2092(_signature_from_function)\n",
+ " 878 0.004 0.000 0.005 0.000 ops.py:5936(__enter__)\n",
+ " 288 0.004 0.000 0.016 0.000 inspect.py:1079(getfullargspec)\n",
+ " 11006 0.003 0.000 0.005 0.000 {built-in method builtins.isinstance}\n",
+ " 768 0.003 0.000 0.008 0.000 {built-in method _pywrap_tensorflow_internal.Flatten}\n",
+ "\n",
+ "\n",
+ "Thu Jun 7 12:25:04 2018 training_profile\n",
+ "\n",
+ " 92209 function calls (91817 primitive calls) in 3.446 seconds\n",
+ "\n",
+ " Ordered by: cumulative time\n",
+ " List reduced from 672 to 10 due to restriction <10>\n",
+ "\n",
+ " ncalls tottime percall cumtime percall filename:lineno(function)\n",
+ " 1 0.000 0.000 3.446 3.446 {built-in method builtins.exec}\n",
+ " 1 0.000 0.000 3.446 3.446 <string>:1(<module>)\n",
+ " 1 0.001 0.001 3.446 3.446 <ipython-input-14-bcffed60b545>:9(train_one_epoch)\n",
+ " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n",
+ " 16 0.000 0.000 2.255 0.141 backprop.py:739(gradient)\n",
+ " 16 0.000 0.000 2.253 0.141 imperative_grad.py:31(imperative_grad)\n",
+ " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n",
+ " 400 0.002 0.000 2.246 0.006 backprop.py:145(grad_fn)\n",
+ " 400 0.002 0.000 2.239 0.006 backprop.py:95(_magic_gradient_function)\n",
+ " 32 0.001 0.000 1.601 0.050 nn_grad.py:497(_Conv2DGrad)\n",
+ "\n",
+ "\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<pstats.Stats at 0x7fd61f841710>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 17
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "8ixpnyCNNTI4",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+} \ No newline at end of file
diff --git a/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb
new file mode 100644
index 0000000000..64d19ec5c9
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb
@@ -0,0 +1,443 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "Debugging \"graph-first\" models with eager execution",
+ "version": "0.3.2",
+ "provenance": [],
+ "include_colab_link": true
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "[View in Colaboratory](https://colab.research.google.com/gist/alextp/9568ab40f6ed6f9a3ba4736f6aef6127/debugging-graph-first-models-with-eager-execution.ipynb)"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "mm-t0GuIu1Dt",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "This colab uses eager execution and the Python debugger to modify the execution of a translation model. This combination lets you quickly explore counterfactuals when researching and designing modifications to a model.\n",
+ "\n",
+ "The model, Transformer from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), was originally written with graph building in mind. Executing it eagerly can still be helpful!"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gxb1DvIDg4sv",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title License (double click to show)\n",
+ "# Copyright 2018 The TensorFlow Authors.\n",
+ "\n",
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "Gx3HA9N1ui64",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 37
+ },
+ "outputId": "f6986f34-f3e1-44e1-c902-2eb33081acad"
+ },
+ "cell_type": "code",
+ "source": [
+ "import tensorflow as tf\n",
+ "import pdb\n",
+ "tfe = tf.contrib.eager\n",
+ "\n",
+ "tf.enable_eager_execution()"
+ ],
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "3LkOm2ct-Lmc",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 37
+ },
+ "outputId": "2edc74d9-6bc0-4e78-ab4e-83bf96099ef4"
+ },
+ "cell_type": "code",
+ "source": [
+ "!pip install -q -U tensor2tensor\n",
+ "from tensor2tensor.models import transformer"
+ ],
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "1Z3oMsqV0zB6",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 170
+ },
+ "outputId": "0a8186ee-c688-457f-c9f6-9a6c1477a93b"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Create a tensor2tensor translation model, fetch a checkpoint (double click to show)\n",
+ "from tensor2tensor import problems\n",
+ "from tensor2tensor.utils import trainer_lib\n",
+ "from tensor2tensor.utils import registry\n",
+ "\n",
+ "import numpy as np\n",
+ "import os\n",
+ "\n",
+ "# Setup some directories\n",
+ "data_dir = os.path.expanduser(\"~/t2t/data\")\n",
+ "tmp_dir = os.path.expanduser(\"~/t2t/tmp\")\n",
+ "train_dir = os.path.expanduser(\"~/t2t/train\")\n",
+ "checkpoint_dir = os.path.expanduser(\"~/t2t/checkpoints\")\n",
+ "tf.gfile.MakeDirs(data_dir)\n",
+ "tf.gfile.MakeDirs(tmp_dir)\n",
+ "tf.gfile.MakeDirs(train_dir)\n",
+ "tf.gfile.MakeDirs(checkpoint_dir)\n",
+ "gs_data_dir = \"gs://tensor2tensor-data\"\n",
+ "gs_ckpt_dir = \"gs://tensor2tensor-checkpoints/\"\n",
+ "\n",
+ "# Fetch the problem\n",
+ "ende_problem = problems.problem(\"translate_ende_wmt32k\")\n",
+ "\n",
+ "# Copy the vocab file locally so we can encode inputs and decode model outputs\n",
+ "# All vocabs are stored on GCS\n",
+ "vocab_name = \"vocab.ende.32768\"\n",
+ "vocab_file = os.path.join(gs_data_dir, vocab_name)\n",
+ "!gsutil cp {vocab_file} {data_dir}\n",
+ "\n",
+ "# Get the encoders from the problem\n",
+ "encoders = ende_problem.feature_encoders(data_dir)\n",
+ "\n",
+ "# Setup helper functions for encoding and decoding\n",
+ "def encode(input_str, output_str=None):\n",
+ " \"\"\"Input str to features dict, ready for inference\"\"\"\n",
+ " inputs = encoders[\"inputs\"].encode(input_str) + [1] # add EOS id\n",
+ " batch_inputs = tf.reshape(inputs, [1, -1, 1]) # Make it 3D.\n",
+ " return {\"inputs\": batch_inputs}\n",
+ "\n",
+ "def decode(integers):\n",
+ " \"\"\"List of ints to str\"\"\"\n",
+ " integers = list(np.squeeze(integers))\n",
+ " if 1 in integers:\n",
+ " integers = integers[:integers.index(1)]\n",
+ " return encoders[\"inputs\"].decode(np.squeeze(integers))\n",
+ "\n",
+ "# Copy the pretrained checkpoint locally\n",
+ "ckpt_name = \"transformer_ende_test\"\n",
+ "gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)\n",
+ "!gsutil -q cp -R {gs_ckpt} {checkpoint_dir}\n",
+ "checkpoint_path = tf.train.latest_checkpoint(\n",
+ " os.path.join(checkpoint_dir, ckpt_name))\n",
+ "\n",
+ "# Create hparams and the model\n",
+ "model_name = \"transformer\"\n",
+ "hparams_set = \"transformer_base\"\n",
+ "\n",
+ "hparams = trainer_lib.create_hparams(hparams_set, data_dir=data_dir, problem_name=\"translate_ende_wmt32k\")\n",
+ "\n",
+ "# NOTE: Only create the model once when restoring from a checkpoint; it's a\n",
+ "# Layer and so subsequent instantiations will have different variable scopes\n",
+ "# that will not match the checkpoint.\n",
+ "translate_model = registry.model(model_name)(hparams, tf.estimator.ModeKeys.EVAL)"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Copying gs://tensor2tensor-data/vocab.ende.32768...\n",
+ "/ [1 files][316.4 KiB/316.4 KiB] \n",
+ "Operation completed over 1 objects/316.4 KiB. \n",
+ "INFO:tensorflow:Setting T2TModel mode to 'eval'\n",
+ "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n",
+ "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n",
+ "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n",
+ "INFO:tensorflow:Setting hparams.dropout to 0.0\n",
+ "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "4IblPXLGjuCl",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "We've created a Transformer model and fetched an existing training checkpoint. It hasn't created variables yet, and we want to load them from the checkpoint before they're used (restore-on-create) so the first run of the model outputs the correct value. The `tfe.restore_variables_on_create` API looks up variables by name on creation and restores their values."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "o3MWxcAqJoqG",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "fbc1b1bf-ffbe-4621-b3cb-5eb855fec3a8"
+ },
+ "cell_type": "code",
+ "source": [
+ "with tfe.restore_variables_on_create(checkpoint_path):\n",
+ " model_output = translate_model.infer(encode(\"Eager execution\"))\n",
+ "print(decode(model_output[\"outputs\"]))"
+ ],
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "INFO:tensorflow:Greedy Decoding\n",
+ "Hinrichtung\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "xk5HV9Hhu9zO",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Using global variable names can get somewhat fragile, so for new code we recommend the object-based `tf.keras.Model.save_weights` or `tf.train.Checkpoint`. However, these require some small code changes to work with existing graph building code.\n",
+ "\n",
+ "The Transformer model translates \"Eager execution\" in English to \"Hinrichtung\" in German, which refers to capital punishment rather than getting things done. Transformer first encodes the English, then decodes to German. We'll add a debugging hook at the start of the decode phase (once the encodings have been finalized) and see if we can correct the translation."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "GUGwbYvXZ9-7",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "previous_fast_decode = transformer.fast_decode\n",
+ "def debug_fn(*args, **kwargs):\n",
+ " pdb.set_trace()\n",
+ " return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n",
+ "transformer.fast_decode = debug_fn # Add our debugging hook to Transformer"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "f61HlvECxJn0",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Now that we've \"monkey patched\" the model, we'll drop into a debugger just before decoding starts. In most cases it'd be simpler to add the `pdb.set_trace()` call to the code directly, but in this case we're working with prepackaged library code.\n",
+ "\n",
+ "First, let's find an encoding which represents the correct sense of \"execution\". Then we'll patch part of that encoding into the encoding of \"Eager execution\" to fix the translation. Feel free to poke around with the debugger (e.g. print a Tensor's value), but your main task is to save the encodings by assigning them to an attribute of the function:\n",
+ "\n",
+ "```\n",
+ "(running the next cell drops you into a pdb shell)\n",
+ "step\n",
+ "fast_decode.previous_encoding = encoder_output\n",
+ "continue\n",
+ "\n",
+ "```\n",
+ "\n",
+ "You can type `next` (or `n`) a few times before `continue` to watch the decoding ops run."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "dX4CPOGSpZrb",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 179
+ },
+ "outputId": "6de38c31-836f-40ef-b701-e42908172619"
+ },
+ "cell_type": "code",
+ "source": [
+ "model_output = translate_model.infer(encode(\"Immediate running\"))\n",
+ "print(decode(model_output[\"outputs\"]))"
+ ],
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "> <ipython-input-6-ee9b4225ba2a>(4)debug_fn()\n",
+ "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n",
+ "(Pdb) step\n",
+ "--Call--\n",
+ "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n",
+ "-> def fast_decode(encoder_output,\n",
+ "(Pdb) fast_decode.previous_encoding = encoder_output\n",
+ "(Pdb) continue\n",
+ "Sofortige Durchführung\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "-ZEZciV4FpLo",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Now we have an encoding saved which gets the correct sense for \"execution\"."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "QeC_oDVqHD_v",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 179
+ },
+ "outputId": "253c9af1-003e-46bd-8bf5-db968cf6a8cf"
+ },
+ "cell_type": "code",
+ "source": [
+ "# Assumes you followed the pdb instructions above!\n",
+ "transformer.fast_decode.previous_encoding"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=9528, shape=(1, 4, 512), dtype=float32, numpy=\n",
+ "array([[[-0.15239455, 0.12273102, -0.11209048, ..., -0.12478986,\n",
+ " 0.37216735, -0.40987235],\n",
+ " [-0.2686283 , 0.51448774, 0.03650613, ..., 0.08731575,\n",
+ " 0.51110077, -0.6646815 ],\n",
+ " [-0.24441548, 0.36622533, 0.11685672, ..., 0.21941349,\n",
+ " -0.03304008, -0.579611 ],\n",
+ " [-0.03339856, -0.01185844, 0.00579634, ..., 0.00294734,\n",
+ " 0.00136655, -0.01362935]]], dtype=float32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 8
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "bC9JjeDcHEav",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Let's replace part of the encoding for \"Eager execution\" with the encoding of \"Immediate running\".\n",
+ "\n",
+ "Again we'll drop into a pdb shell. This time we'll run some TensorFlow operations to patch the encodings while the model is running.\n",
+ "\n",
+ "```\n",
+ "(running the next cell again drops you into a pdb shell)\n",
+ "step\n",
+ "encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n",
+ "continue\n",
+ "```"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "t2as_Kn1h65G",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 179
+ },
+ "outputId": "5b4e546e-3bb4-4761-c545-467b631e3ffe"
+ },
+ "cell_type": "code",
+ "source": [
+ "model_output = translate_model.infer(encode(\"Eager execution\"))\n",
+ "print(decode(model_output[\"outputs\"]))"
+ ],
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "> <ipython-input-6-ee9b4225ba2a>(4)debug_fn()\n",
+ "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n",
+ "(Pdb) step\n",
+ "--Call--\n",
+ "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n",
+ "-> def fast_decode(encoder_output,\n",
+ "(Pdb) encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n",
+ "(Pdb) continue\n",
+ "sofortige Ausführung\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "rK6tYZ23I2cm",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "We get a different decoding, with the correct sense of \"execution\". Likely we're keeping just the encoding of \"tion\" from \"Eager execution\", so no great breakthrough in translation modeling.\n",
+ "\n",
+ "Similarly it's possible to modify attention vectors, or change words during decoding to help debug a beam search."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Nb-4ipYNRWxA",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "This colab was adapted from the [Tensor2Tensor colab](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb). Credit to Ankur Taly for its concept."
+ ]
+ }
+ ]
+} \ No newline at end of file
diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile
index 2b6997146e..a616138d33 100644
--- a/tensorflow/contrib/lite/Makefile
+++ b/tensorflow/contrib/lite/Makefile
@@ -17,7 +17,29 @@ else
endif
endif
-ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi)
+HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi)
+
+# Self-hosting
+TARGET_ARCH := ${HOST_ARCH}
+
+# Cross compiling
+ifeq ($(CROSS),rpi)
+ TARGET_ARCH := armv7l
+ TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf-
+endif
+
+ifeq ($(CROSS),riscv)
+ TARGET_ARCH := riscv
+ TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf-
+endif
+ifeq ($(CROSS),stm32f7)
+ TARGET_ARCH := armf7
+ TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+endif
+ifeq ($(CROSS),stm32f1)
+ TARGET_ARCH := armm1
+ TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+endif
# Where compiled objects are stored.
OBJDIR := $(MAKEFILE_DIR)/gen/obj/
@@ -25,11 +47,46 @@ BINDIR := $(MAKEFILE_DIR)/gen/bin/
LIBDIR := $(MAKEFILE_DIR)/gen/lib/
GENDIR := $(MAKEFILE_DIR)/gen/obj/
+LIBS :=
+ifeq ($(TARGET_ARCH),x86_64)
+ CXXFLAGS += -fPIC -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -pthread # -msse4.2
+endif
+
+ifeq ($(TARGET_ARCH),armv7l)
+ CXXFLAGS += -mfpu=neon -pthread -fPIC
+ LIBS += -ldl
+endif
+
+ifeq ($(TARGET_ARCH),riscv)
+# CXXFLAGS += -march=gap8
+ CXXFLAGS += -DTFLITE_MCU
+ LIBS += -ldl
+ BUILD_TYPE := micro
+endif
+
+ifeq ($(TARGET_ARCH),armf7)
+ CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -DTFLITE_MCU
+ CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections
+ CXXFLAGS += -funsigned-char -MMD
+ CXXFLAGS += -mcpu=cortex-m7 -mthumb -mfpu=fpv5-sp-d16 -mfloat-abi=softfp
+ CXXFLAGS += '-std=gnu++11' '-fno-rtti' '-Wvla' '-c' '-Wall' '-Wextra' '-Wno-unused-parameter' '-Wno-missing-field-initializers' '-fmessage-length=0' '-fno-exceptions' '-fno-builtin' '-ffunction-sections' '-fdata-sections' '-funsigned-char' '-MMD' '-fno-delete-null-pointer-checks' '-fomit-frame-pointer' '-Os'
+ LIBS += -ldl
+ BUILD_TYPE := micro
+endif
+ifeq ($(TARGET_ARCH),armm1)
+ CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -mcpu=cortex-m1 -mthumb -DTFLITE_MCU
+ CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections
+ CXXFLAGS += -funsigned-char -MMD
+ LIBS += -ldl
+endif
+
# Settings for the host compiler.
-CXX := $(CC_PREFIX)gcc
-CXXFLAGS := --std=c++11 -O3 -DNDEBUG
-CC := $(CC_PREFIX)gcc
-CCFLAGS := -O3 -DNDEBUG
+CXX := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}g++
+CXXFLAGS += --std=c++11 -O3 -DNDEBUG
+CCFLAGS := ${CXXFLAGS}
+CC := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}ar
+CFLAGS :=
LDOPTS :=
LDOPTS += -L/usr/local/lib
ARFLAGS := -r
@@ -48,7 +105,7 @@ INCLUDES := \
# override local versions in the source tree.
INCLUDES += -I/usr/local/include
-LIBS := \
+LIBS += \
-lstdc++ \
-lpthread \
-lm \
@@ -92,18 +149,21 @@ PROFILE_SUMMARIZER_SRCS := \
CORE_CC_ALL_SRCS := \
$(wildcard tensorflow/contrib/lite/*.cc) \
+$(wildcard tensorflow/contrib/lite/*.c)
+ifneq ($(BUILD_TYPE),micro)
+CORE_CC_ALL_SRCS += \
$(wildcard tensorflow/contrib/lite/kernels/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \
$(PROFILER_SRCS) \
-$(wildcard tensorflow/contrib/lite/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \
$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \
$(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c)
+endif
# Remove any duplicates.
CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS))
CORE_CC_EXCLUDE_SRCS := \
@@ -113,6 +173,11 @@ $(wildcard tensorflow/contrib/lite/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \
$(MINIMAL_SRCS)
+ifeq ($(BUILD_TYPE),micro)
+CORE_CC_EXCLUDE_SRCS += \
+tensorflow/contrib/lite/model.cc \
+tensorflow/contrib/lite/nnapi_delegate.cc
+endif
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
# File names of the intermediate files target compilation generates.
@@ -120,7 +185,6 @@ TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
LIB_OBJS := $(TF_LITE_CC_OBJS)
-
# Benchmark sources
BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark
BENCHMARK_ALL_SRCS := $(TFLITE_CC_SRCS) \
@@ -146,6 +210,9 @@ $(OBJDIR)%.o: %.c
# The target that's compiled if there's no command-line arguments.
all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY)
+# The target that's compiled for micro-controllers
+micro: $(LIB_PATH)
+
# Gathers together all the objects we've compiled into a single '.a' archive.
$(LIB_PATH): $(LIB_OBJS)
@mkdir -p $(dir $@)
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
index a4772731ec..c42622ff02 100644
--- a/tensorflow/contrib/lite/allocation.cc
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include <fcntl.h>
+#ifndef TFLITE_MCU
#include <sys/mman.h>
+#endif
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
@@ -27,10 +29,13 @@ limitations under the License.
#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
+#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#endif
namespace tflite {
+#ifndef TFLITE_MCU
MMAPAllocation::MMAPAllocation(const char* filename,
ErrorReporter* error_reporter)
: Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) {
@@ -111,6 +116,7 @@ MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,
buffer_ = ptr;
buffer_size_bytes_ = num_bytes;
}
+#endif
MemoryAllocation::~MemoryAllocation() {}
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc
index 22be64d6ff..4257e754ad 100644
--- a/tensorflow/contrib/lite/arena_planner.cc
+++ b/tensorflow/contrib/lite/arena_planner.cc
@@ -35,12 +35,13 @@ struct AllocationInfo {
};
ArenaPlanner::ArenaPlanner(TfLiteContext* context,
- std::unique_ptr<GraphInfo> graph_info)
+ std::unique_ptr<GraphInfo> graph_info,
+ bool preserve_inputs)
: context_(context),
graph_info_(std::move(graph_info)),
arena_(kDefaultArenaAlignment),
- persistent_arena_(kDefaultArenaAlignment) {}
-
+ persistent_arena_(kDefaultArenaAlignment),
+ preserve_inputs_(preserve_inputs) {}
ArenaPlanner::~ArenaPlanner() {}
int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) {
@@ -112,9 +113,13 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
refcounts[tensor_index]++;
}
- // Queue all graph inputs for allocation.
+ // Queue all graph inputs for allocation. If preserve_inputs_ is true, make
+ // sure they never be overwritten.
for (int tensor_index : graph_info_->inputs()) {
if (tensor_index != kOptionalTensor) {
+ if (preserve_inputs_) {
+ refcounts[tensor_index]++;
+ }
TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
}
}
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index e9d0fbc5a9..1d84950e91 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -43,8 +43,11 @@ struct AllocationInfo;
class ArenaPlanner : public MemoryPlanner {
public:
// Ownership of 'context' is not taken and it must remain util the
- // ArenaPlanner is destroyed.
- ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info);
+ // ArenaPlanner is destroyed. If 'preserve_inputs' is true the inputs to the
+ // graph will not share memory with any other tensor, effectively preserving
+ // them until the end of inference.
+ ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info,
+ bool preserve_inputs);
~ArenaPlanner() override;
ArenaPlanner(const ArenaPlanner&) = delete;
ArenaPlanner& operator=(const ArenaPlanner&) = delete;
@@ -100,6 +103,8 @@ class ArenaPlanner : public MemoryPlanner {
// Raw memory buffer that is allocated for persistent tensors that are
// declared as kTfLiteArenaRwPersistent.
SimpleMemoryArena persistent_arena_;
+
+ bool preserve_inputs_;
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc
index f0fd35216f..f5bd1932f9 100644
--- a/tensorflow/contrib/lite/arena_planner_test.cc
+++ b/tensorflow/contrib/lite/arena_planner_test.cc
@@ -151,11 +151,12 @@ void ReportError(TfLiteContext* context, const char* format, ...) {
class ArenaPlannerTest : public ::testing::Test {
protected:
- void SetGraph(TestGraph* graph) {
+ void SetGraph(TestGraph* graph, bool preserve_inputs = false) {
graph_ = graph;
context_.ReportError = ReportError;
planner_.reset(new ArenaPlanner(
- &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph))));
+ &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
+ preserve_inputs));
CHECK(planner_->ResetAllocations() == kTfLiteOk);
CHECK(planner_->PlanAllocations() == kTfLiteOk);
}
@@ -243,6 +244,30 @@ TEST_F(ArenaPlannerTest, SimpleGraph) {
EXPECT_EQ(GetOffset(3), 0);
}
+TEST_F(ArenaPlannerTest, SimpleGraphInputsPreserved) {
+ TestGraph graph({0, 1},
+ {
+ /* in, out, tmp */
+ {{0, 1}, {2}, {}}, // First op
+ {{2, 0}, {4, 5}, {}}, // Second op
+ {{4, 5}, {3}, {}} // Third op
+ },
+ {3});
+ SetGraph(&graph, /*preserve_inputs=*/true);
+ Execute(0, 10);
+
+ // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5
+ EXPECT_EQ(GetOffset(0), 0);
+ EXPECT_EQ(GetOffset(1), GetOffsetAfter(0));
+ EXPECT_EQ(GetOffset(2), GetOffsetAfter(1));
+ EXPECT_EQ(GetOffset(4), GetOffsetAfter(2));
+ EXPECT_EQ(GetOffset(5), GetOffsetAfter(4));
+ // Because we are keeping the inputs alive until the end (due to
+ // preserve_inputs=true), the output tensor will not be able to use that
+ // space. It will end up using the same are as tensor #2.
+ EXPECT_EQ(GetOffset(3), GetOffsetAfter(1));
+}
+
TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) {
TestGraph graph({0, 1},
{
diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD
index dd2cd17324..4d2437e7d3 100644
--- a/tensorflow/contrib/lite/examples/android/BUILD
+++ b/tensorflow/contrib/lite/examples/android/BUILD
@@ -37,6 +37,7 @@ android_binary(
"@tflite_conv_actions_frozen//:conv_actions_frozen.tflite",
"//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt",
"@tflite_mobilenet_ssd//:mobilenet_ssd.tflite",
+ "@tflite_mobilenet_ssd_quant//:detect.tflite",
"//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt",
"//tensorflow/contrib/lite/examples/android/app/src/main/assets:coco_labels_list.txt",
],
diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle
index 8e0a98ed63..1ffb9dd377 100644
--- a/tensorflow/contrib/lite/examples/android/app/build.gradle
+++ b/tensorflow/contrib/lite/examples/android/app/build.gradle
@@ -9,7 +9,7 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -51,7 +51,7 @@ apply from: "download-models.gradle"
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
+ androidTestCompile('com.androidx.test.espresso:espresso-core:2.2.2', {
exclude group: 'com.android.support', module: 'support-annotations'
})
compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
diff --git a/tensorflow/contrib/lite/examples/android/app/download-models.gradle b/tensorflow/contrib/lite/examples/android/app/download-models.gradle
index 8e65dc076f..c100e37c16 100644
--- a/tensorflow/contrib/lite/examples/android/app/download-models.gradle
+++ b/tensorflow/contrib/lite/examples/android/app/download-models.gradle
@@ -12,8 +12,9 @@
def models = ['conv_actions_tflite.zip',
'mobilenet_ssd_tflite_v1.zip',
- 'mobilenet_v1_224_android_quant_2017_11_08.zip']
-// LINT.ThenChange(//tensorflow/examples/android/BUILD)
+ 'mobilenet_v1_224_android_quant_2017_11_08.zip',
+ 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip']
+// LINT.ThenChange(//tensorflow/contrib/lite/examples/android/BUILD)
// Root URL for model archives
def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite'
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
index de997e454a..87160f6b3f 100644
--- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -50,9 +50,10 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
// Configuration values for the prepackaged SSD model.
private static final int TF_OD_API_INPUT_SIZE = 300;
- private static final String TF_OD_API_MODEL_FILE = "mobilenet_ssd.tflite";
+ private static final boolean TF_OD_API_IS_QUANTIZED = true;
+ private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt";
-
+
// Which detection model to use: by default uses Tensorflow Object Detection API frozen
// checkpoints.
private enum DetectorMode {
@@ -107,7 +108,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
try {
detector =
TFLiteObjectDetectionAPIModel.create(
- getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE);
+ getAssets(),
+ TF_OD_API_MODEL_FILE,
+ TF_OD_API_LABELS_FILE,
+ TF_OD_API_INPUT_SIZE,
+ TF_OD_API_IS_QUANTIZED);
cropSize = TF_OD_API_INPUT_SIZE;
} catch (final IOException e) {
LOGGER.e("Exception initializing classifier!", e);
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
index 580206943b..9eb21de9d0 100644
--- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
@@ -30,12 +30,9 @@ import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.PriorityQueue;
-import java.util.StringTokenizer;
import java.util.Vector;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.lite.Interpreter;
@@ -48,40 +45,35 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
private static final Logger LOGGER = new Logger();
// Only return this many results.
- private static final int NUM_RESULTS = 1917;
- private static final int NUM_CLASSES = 91;
-
- private static final float Y_SCALE = 10.0f;
- private static final float X_SCALE = 10.0f;
- private static final float H_SCALE = 5.0f;
- private static final float W_SCALE = 5.0f;
-
+ private static final int NUM_DETECTIONS = 10;
+ private boolean isModelQuantized;
// Float model
private static final float IMAGE_MEAN = 128.0f;
private static final float IMAGE_STD = 128.0f;
-
- //Number of threads in the java app
+ // Number of threads in the java app
private static final int NUM_THREADS = 4;
-
-
// Config values.
private int inputSize;
-
- private final float[][] boxPriors = new float[4][NUM_RESULTS];
-
// Pre-allocated buffers.
private Vector<String> labels = new Vector<String>();
private int[] intValues;
+ // outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4]
+ // contains the location of detected boxes
private float[][][] outputLocations;
- private float[][][] outputClasses;
-
- private ByteBuffer imgData = null;
+ // outputClasses: array of shape [Batchsize, NUM_DETECTIONS]
+ // contains the classes of detected boxes
+ private float[][] outputClasses;
+ // outputScores: array of shape [Batchsize, NUM_DETECTIONS]
+ // contains the scores of detected boxes
+ private float[][] outputScores;
+ // numDetections: array of shape [Batchsize]
+ // contains the number of detected boxes
+ private float[] numDetections;
+
+ private ByteBuffer imgData;
private Interpreter tfLite;
- private float expit(final float x) {
- return (float) (1. / (1. + Math.exp(-x)));
- }
/** Memory-map the model file in Assets. */
private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
@@ -94,77 +86,24 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
- private void loadCoderOptions(
- final AssetManager assetManager, final String locationFilename, final float[][] boxPriors)
- throws IOException {
- // Try to be intelligent about opening from assets or sdcard depending on prefix.
- final String assetPrefix = "file:///android_asset/";
- InputStream is;
- if (locationFilename.startsWith(assetPrefix)) {
- is = assetManager.open(locationFilename.split(assetPrefix, -1)[1]);
- } else {
- is = new FileInputStream(locationFilename);
- }
-
- final BufferedReader reader = new BufferedReader(new InputStreamReader(is));
-
- for (int lineNum = 0; lineNum < 4; ++lineNum) {
- String line = reader.readLine();
- final StringTokenizer st = new StringTokenizer(line, ", ");
- int priorIndex = 0;
- while (st.hasMoreTokens()) {
- final String token = st.nextToken();
- try {
- final float number = Float.parseFloat(token);
- boxPriors[lineNum][priorIndex++] = number;
- } catch (final NumberFormatException e) {
- // Silently ignore.
- }
- }
- if (priorIndex != NUM_RESULTS) {
- throw new RuntimeException(
- "BoxPrior length mismatch: " + priorIndex + " vs " + NUM_RESULTS);
- }
- }
-
- LOGGER.i("Loaded box priors!");
- }
-
- void decodeCenterSizeBoxes(float[][][] predictions) {
- for (int i = 0; i < NUM_RESULTS; ++i) {
- float ycenter = predictions[0][i][0] / Y_SCALE * boxPriors[2][i] + boxPriors[0][i];
- float xcenter = predictions[0][i][1] / X_SCALE * boxPriors[3][i] + boxPriors[1][i];
- float h = (float) Math.exp(predictions[0][i][2] / H_SCALE) * boxPriors[2][i];
- float w = (float) Math.exp(predictions[0][i][3] / W_SCALE) * boxPriors[3][i];
-
- float ymin = ycenter - h / 2.f;
- float xmin = xcenter - w / 2.f;
- float ymax = ycenter + h / 2.f;
- float xmax = xcenter + w / 2.f;
-
- predictions[0][i][0] = ymin;
- predictions[0][i][1] = xmin;
- predictions[0][i][2] = ymax;
- predictions[0][i][3] = xmax;
- }
- }
-
/**
* Initializes a native TensorFlow session for classifying images.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the model GraphDef protocol buffer.
* @param labelFilename The filepath of label file for classes.
+ * @param inputSize The size of image input
+ * @param isQuantized Boolean representing model is quantized or not
*/
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final String labelFilename,
- final int inputSize) throws IOException {
+ final int inputSize,
+ final boolean isQuantized)
+ throws IOException {
final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel();
- d.loadCoderOptions(assetManager, "file:///android_asset/box_priors.txt", d.boxPriors);
-
InputStream labelsInput = null;
String actualFilename = labelFilename.split("file:///android_asset/")[1];
labelsInput = assetManager.open(actualFilename);
@@ -185,15 +124,23 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
throw new RuntimeException(e);
}
+ d.isModelQuantized = isQuantized;
// Pre-allocate buffers.
- int numBytesPerChannel = 4; // Floating point
+ int numBytesPerChannel;
+ if (isQuantized) {
+ numBytesPerChannel = 1; // Quantized
+ } else {
+ numBytesPerChannel = 4; // Floating point
+ }
d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
d.imgData.order(ByteOrder.nativeOrder());
d.intValues = new int[d.inputSize * d.inputSize];
d.tfLite.setNumThreads(NUM_THREADS);
- d.outputLocations = new float[1][NUM_RESULTS][4];
- d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
+ d.outputLocations = new float[1][NUM_DETECTIONS][4];
+ d.outputClasses = new float[1][NUM_DETECTIONS];
+ d.outputScores = new float[1][NUM_DETECTIONS];
+ d.numDetections = new float[1];
return d;
}
@@ -209,26 +156,37 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
// on the provided parameters.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+ imgData.rewind();
for (int i = 0; i < inputSize; ++i) {
for (int j = 0; j < inputSize; ++j) {
int pixelValue = intValues[i * inputSize + j];
- // Float model
- imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
- imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
- imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ if (isModelQuantized) {
+ // Quantized model
+ imgData.put((byte) ((pixelValue >> 16) & 0xFF));
+ imgData.put((byte) ((pixelValue >> 8) & 0xFF));
+ imgData.put((byte) (pixelValue & 0xFF));
+ } else { // Float model
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ }
}
}
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
- outputLocations = new float[1][NUM_RESULTS][4];
- outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
+ outputLocations = new float[1][NUM_DETECTIONS][4];
+ outputClasses = new float[1][NUM_DETECTIONS];
+ outputScores = new float[1][NUM_DETECTIONS];
+ numDetections = new float[1];
Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, outputLocations);
outputMap.put(1, outputClasses);
+ outputMap.put(2, outputScores);
+ outputMap.put(3, numDetections);
Trace.endSection();
// Run the inference call.
@@ -236,56 +194,26 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
Trace.endSection();
- decodeCenterSizeBoxes(outputLocations);
-
- // Find the best detections.
- final PriorityQueue<Recognition> pq =
- new PriorityQueue<Recognition>(
- 1,
- new Comparator<Recognition>() {
- @Override
- public int compare(final Recognition lhs, final Recognition rhs) {
- // Intentionally reversed to put high confidence at the head of the queue.
- return Float.compare(rhs.getConfidence(), lhs.getConfidence());
- }
- });
-
- // Scale them back to the input size.
- for (int i = 0; i < NUM_RESULTS; ++i) {
- float topClassScore = -1000f;
- int topClassScoreIndex = -1;
-
- // Skip the first catch-all class.
- for (int j = 1; j < NUM_CLASSES; ++j) {
- float score = expit(outputClasses[0][i][j]);
-
- if (score > topClassScore) {
- topClassScoreIndex = j;
- topClassScore = score;
- }
- }
-
- if (topClassScore > 0.001f) {
- final RectF detection =
- new RectF(
- outputLocations[0][i][1] * inputSize,
- outputLocations[0][i][0] * inputSize,
- outputLocations[0][i][3] * inputSize,
- outputLocations[0][i][2] * inputSize);
-
- pq.add(
- new Recognition(
- "" + i,
- labels.get(topClassScoreIndex),
- outputClasses[0][i][topClassScoreIndex],
- detection));
- }
- }
-
- final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
- for (int i = 0; i < Math.min(pq.size(), 10); ++i) {
- Recognition recog = pq.poll();
- recognitions.add(recog);
+ // Show the best detections.
+ // after scaling them back to the input size.
+ final ArrayList<Recognition> recognitions = new ArrayList<>(NUM_DETECTIONS);
+ for (int i = 0; i < NUM_DETECTIONS; ++i) {
+ final RectF detection =
+ new RectF(
+ outputLocations[0][i][1] * inputSize,
+ outputLocations[0][i][0] * inputSize,
+ outputLocations[0][i][3] * inputSize,
+ outputLocations[0][i][2] * inputSize);
+ // SSD Mobilenet V1 Model assumes class 0 is background class
+ // in label file and class labels start from 1 to number_of_classes+1,
+ // while outputClasses correspond to class index from 0 to number_of_classes
+ int labelOffset = 1;
+ recognitions.add(
+ new Recognition(
+ "" + i,
+ labels.get((int) outputClasses[0][i] + labelOffset),
+ outputScores[0][i],
+ detection));
}
Trace.endSection(); // "recognizeImage"
return recognitions;
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 62a0b1ff08..dcb4ef593e 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -24,15 +24,22 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
+#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#endif
#include "tensorflow/contrib/lite/memory_planner.h"
+#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#endif
#include "tensorflow/contrib/lite/profiling/profiler.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
+#ifdef TFLITE_MCU
+class NNAPIDelegate {};
+#endif
namespace {
@@ -531,7 +538,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt(
TfLiteStatus Interpreter::PrepareOpsAndTensors() {
if (!memory_planner_) {
memory_planner_.reset(new ArenaPlanner(
- &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this))));
+ &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)),
+ /*preserve_inputs=*/true));
memory_planner_->PlanAllocations();
}
@@ -557,6 +565,7 @@ TfLiteStatus Interpreter::Invoke() {
}
TfLiteStatus status = kTfLiteOk;
+#ifndef TFLITE_MCU
if (nnapi_delegate_) {
if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) {
TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
@@ -570,6 +579,7 @@ TfLiteStatus Interpreter::Invoke() {
return kTfLiteError;
}
}
+#endif
// Invocations are always done in node order.
// Note that calling Invoke repeatedly will cause the original memory plan to
@@ -826,6 +836,7 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
}
void Interpreter::UseNNAPI(bool enable) {
+#ifndef TFLITE_MCU
// TODO(aselle): This is a workaround for finding if NNAPI exists.
// We also need to make sure getLibraryHandle() is renamed to be NNAPI
// prefixed.
@@ -835,6 +846,7 @@ void Interpreter::UseNNAPI(bool enable) {
} else if (!nnapi_delegate_) {
nnapi_delegate_.reset(new NNAPIDelegate);
}
+#endif
}
void Interpreter::SetNumThreads(int num_threads) {
@@ -842,8 +854,10 @@ void Interpreter::SetNumThreads(int num_threads) {
// TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to
// be required in order to compile the framework.
+#ifndef TFLITE_MCU
gemm_support::SetNumThreads(&context_, num_threads);
eigen_support::SetNumThreads(&context_, num_threads);
+#endif
}
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 21cdf87d1e..6f13b43ebf 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -231,32 +231,16 @@ TEST(BasicInterpreter, CheckArenaAllocation) {
ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
- ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw);
- ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw);
- ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr);
-
- ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw);
- ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw);
ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw);
-
- ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(2)->data.raw);
ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw);
-
- ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(4)->data.raw);
ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(5)->data.raw, interpreter.tensor(7)->data.raw);
+ ASSERT_EQ(interpreter.tensor(6)->data.raw, interpreter.tensor(2)->data.raw);
+ // #7 is the one with the largest pointer.
+ ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
}
TEST(BasicInterpreter, BufferAccess) {
diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle
index 192162cfce..288a5f73c5 100644
--- a/tensorflow/contrib/lite/java/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/demo/app/build.gradle
@@ -10,7 +10,7 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -44,7 +44,7 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
+ androidTestCompile('com.androidx.test.espresso:espresso-core:2.2.2', {
exclude group: 'com.android.support', module: 'support-annotations'
})
compile 'com.android.support:appcompat-v7:25.2.0'
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
index c5d19bad89..3f32d62e5c 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
@@ -9,7 +9,7 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -43,7 +43,7 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
+ androidTestCompile('com.androidx.test.espresso:espresso-core:2.2.2', {
exclude group: 'com.android.support', module: 'support-annotations'
})
compile 'com.android.support:appcompat-v7:25.2.0'
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 61d5af3478..27b8a16e15 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -964,6 +964,7 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index 9410bead5e..0ba170a4da 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -94,7 +94,7 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* lookup, const TfLiteTensor* value,
TfLiteTensor* output) {
const int row_size = SizeOfDimension(value, 0);
- const double scaling_factor = 1.0 / value->params.scale;
+ const double scaling_factor = value->params.scale;
// col_size after we flatten tensor into 2D.
int col_size = 1;
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 36c25388e8..a0e382edb6 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -416,7 +416,7 @@ void LstmStep(
if (!use_cifg) {
if (use_peephole && !is_cell_state_all_zeros) {
VectorMultiply(cell_to_input_weights_ptr, n_cell,
- 1. / cell_to_input_weights_scale, recovered_cell_weights);
+ cell_to_input_weights_scale, recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
@@ -428,7 +428,7 @@ void LstmStep(
// For each batch and cell: update forget gate.
if (use_peephole && !is_cell_state_all_zeros) {
VectorMultiply(cell_to_forget_weights_ptr, n_cell,
- 1. / cell_to_forget_weights_scale, recovered_cell_weights);
+ cell_to_forget_weights_scale, recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
@@ -460,7 +460,7 @@ void LstmStep(
// For each batch and cell: update the output gate.
if (use_peephole && !is_cell_state_all_zeros) {
VectorMultiply(cell_to_output_weights_ptr, n_cell,
- 1. / cell_to_output_weights_scale, recovered_cell_weights);
+ cell_to_output_weights_scale, recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 4cfaa0f36d..0ce64f8c70 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -3242,6 +3242,7 @@ inline void DepthwiseConv3x3Filter(
int32 output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
DepthwiseConvParams params;
params.input_depth = ArraySize(input_dims, 0);
params.input_width = ArraySize(input_dims, 1);
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 38ad32c734..5ba7e2af9b 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -162,7 +162,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int batch, row, col;
for (batch = 0; batch < n_batch; ++batch) {
- const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
+ const float batch_scaling_factor = scaling_factors[batch];
// Copy the vector data to an aligned vector.
memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols);
// Compute dot-product for every column.
@@ -232,7 +232,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int32 neon_sum =
vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
- *result += ((neon_sum + postable_sum) * batch_scaling_factor_inv);
+ *result += ((neon_sum + postable_sum) * batch_scaling_factor);
} // for row
} // for batch
@@ -418,13 +418,14 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
*scaling_factor = 1;
return;
}
- *scaling_factor = kScale / range;
+ *scaling_factor = range / kScale;
+ const float scaling_factor_inv = 1.0f / *scaling_factor;
const int postamble_start =
size - (size & (2 * kFloatWeightsPerNeonLane - 1));
// Vectorized constants.
- const float32x4_t q_factor_f32x4 = vmovq_n_f32(*scaling_factor);
+ const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
@@ -476,7 +477,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
for (int i = postamble_start; i < size; ++i) {
const int32 quantized_value =
- static_cast<int32>(TfLiteRound(*scaling_factor * values[i]));
+ static_cast<int32>(TfLiteRound(scaling_factor_inv * values[i]));
quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
}
}
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index f8c6f341f7..ccf112c990 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -51,10 +51,11 @@ void PortableSymmetricQuantizeFloats(const float* values, const int size,
*scaling_factor = 1;
return;
}
- *scaling_factor = kScale / range;
+ *scaling_factor = range / kScale;
+ const float scaling_factor_inv = 1.0f / *scaling_factor;
for (int i = 0; i < size; ++i) {
const int32_t quantized_value =
- static_cast<int32_t>(TfLiteRound(*scaling_factor * values[i]));
+ static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
// Clamp: just in case some odd numeric offset.
quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
}
@@ -85,7 +86,7 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
float* __restrict__ result, int result_stride) {
int batch, row, col;
for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
- const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
+ const float batch_scaling_factor = scaling_factors[batch];
// Get the address of the first row.
const int8_t* row_ptr = matrix;
for (row = 0; row < m_rows; ++row, result += result_stride) {
@@ -98,7 +99,7 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
for (col = 0; col < m_cols; ++col, ++row_ptr) {
dotprod += (*row_ptr) * (vectors[col]);
} // for col
- *result += (dotprod * batch_scaling_factor_inv);
+ *result += (dotprod * batch_scaling_factor);
} // for row
} // for batch
}
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index 14ee528394..aa0d49ae4d 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -63,7 +63,8 @@ TEST(uKernels, SymmetricQuantizeFloatsTest) {
EXPECT_EQ(min, -640);
EXPECT_EQ(max, 1000);
- EXPECT_NEAR(scaling_factor, 0.127, 1e-6); // EQ won't work due to fpoint.
+ // EQ won't work due to fpoint.
+ EXPECT_NEAR(scaling_factor, 1000 / 127.0, 1e-6);
EXPECT_THAT(output,
testing::ElementsAreArray({-81, -81, -80, 1, 0, -1, -1, 0, 127}));
}
@@ -95,7 +96,7 @@ TEST(uKernels, SymmetricQuantizeFloatsAllAlmostZeroTest) {
EXPECT_NEAR(min, -9e-05, 1e-6);
EXPECT_NEAR(max, 0.0002, 1e-6);
- EXPECT_EQ(scaling_factor, 635000);
+ EXPECT_NEAR(scaling_factor, 1.57e-6, 1e-6);
EXPECT_THAT(output,
testing::ElementsAreArray({-6, 19, -4, -57, 1, 25, 6, 127, 0}));
}
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 43ac3a2ce8..22eebdd4ce 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -382,11 +382,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// the Eval function.
// TODO(alanchiao): refactor logic out into dequantize function.
if (!op_data->float_weights_time_initialized) {
- const float inv_scale = 1.0 / weights_time->params.scale;
+ const float dequantization_scale = weights_time->params.scale;
const int8_t* weights_time_ptr =
reinterpret_cast<int8_t*>(weights_time->data.uint8);
for (int i = 0; i < NumElements(float_weights_time); ++i) {
- float_weights_time->data.f[i] = weights_time_ptr[i] * inv_scale;
+ float_weights_time->data.f[i] =
+ weights_time_ptr[i] * dequantization_scale;
}
op_data->float_weights_time_initialized = true;
}
diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc
index 06df509d32..5af3ff8500 100644
--- a/tensorflow/contrib/lite/kernels/svdf_test.cc
+++ b/tensorflow/contrib/lite/kernels/svdf_test.cc
@@ -342,7 +342,7 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) {
svdf.ResetState();
VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
&svdf,
- /*tolerance=*/0.00294435);
+ /*tolerance=*/0.002945);
}
TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) {
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index 8b9deeed20..7182374a6f 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -38,9 +39,35 @@ constexpr int kWeightsTensor = 1;
constexpr int kDataInputTensor = 2;
constexpr int kOutputTensor = 0;
-TfLiteStatus ResizeOutputShape(TfLiteContext* context,
- const TfLiteTensor* output_shape,
- TfLiteTensor* output) {
+const int kTensorNotAllocated = -1;
+
+struct OpData {
+ // IDs are the arbitrary identifiers used by TF Lite to identify and access
+ // memory buffers.
+ int im2col_id = kTensorNotAllocated;
+
+ // im2col is the only temporary currently tracked, therefore always index 0.
+ // If more temporaries are added, they should be properly tracked.
+ int32_t im2col_index = 0;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to use as scratch space for im2col, and
+ // to carry information from Prepare() to Eval().
+ auto* data = new OpData;
+ eigen_support::IncrementUsageCounter(context);
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ eigen_support::DecrementUsageCounter(context);
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ const TfLiteTensor* output_shape,
+ TfLiteTensor* output) {
// Currently only support int32 for output shape.
if (output_shape->type != kTfLiteInt32) {
context->ReportError(context, "Output shape is %d, not int32.",
@@ -56,15 +83,60 @@ TfLiteStatus ResizeOutputShape(TfLiteContext* context,
return context->ResizeTensor(context, output, output_shape_array);
}
+// Allocate temporary im2col tensor.
+static TfLiteStatus AllocateIm2colTensor(TfLiteContext* context,
+ TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+ if (data->im2col_id == kTensorNotAllocated) {
+ context->AddTensors(context, 1, &data->im2col_id);
+ }
+
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[data->im2col_index] = data->im2col_id;
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context,
+ const TfLiteTensor* output_shape,
+ const TfLiteTensor* weights,
+ const TfLiteTensor* input,
+ TfLiteTensor* im2col) {
+ if (output_shape->type != kTfLiteInt32) {
+ context->ReportError(context, "im2col shape is %d, not int32.",
+ output_shape->type);
+ return kTfLiteError;
+ }
+ TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4);
+ TfLiteIntArray* im2col_shape_array = TfLiteIntArrayCreate(4);
+ im2col_shape_array->data[0] = output_shape->data.i32[0];
+ im2col_shape_array->data[1] = output_shape->data.i32[1];
+ im2col_shape_array->data[2] = output_shape->data.i32[2];
+ const int input_depth = SizeOfDimension(input, 3);
+ const int filter_width = SizeOfDimension(weights, 1);
+ const int filter_height = SizeOfDimension(weights, 2);
+ im2col_shape_array->data[3] = input_depth * filter_height * filter_width;
+
+ im2col->type = input->type;
+ im2col->allocation_type = kTfLiteArenaRw;
+ return context->ResizeTensor(context, im2col, im2col_shape_array);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TF_LITE_ENSURE_STATUS(AllocateIm2colTensor(context, node));
+
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
+ TfLiteTensor* im2col =
+ &context->tensors[node->temporaries->data[user_data->im2col_index]];
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -81,11 +153,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
SizeOfDimension(weights, 3));
- if (!IsConstantTensor(output_shape)) {
+ if (IsConstantTensor(output_shape)) {
+ TF_LITE_ENSURE_STATUS(ResizeOutputTensor(context, output_shape, output));
+ TF_LITE_ENSURE_STATUS(
+ ResizeIm2ColTensor(context, output_shape, weights, input, im2col));
+ } else {
+ // Defer resizing until Eval().
SetTensorToDynamic(output);
- return kTfLiteOk;
}
- return ResizeOutputShape(context, output_shape, output);
+ return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -94,13 +170,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
+ OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
+ TfLiteTensor* im2col =
+ &context->tensors[node->temporaries->data[user_data->im2col_index]];
const auto* params =
reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
- ResizeOutputShape(context, output_shape, output));
+ ResizeOutputTensor(context, output_shape, output));
+ }
+ if (IsDynamicTensor(im2col)) {
+ TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape,
+ weights, input, im2col));
}
// Get height and width of the output image.
@@ -119,17 +201,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently only support float32.
switch (input->type) {
case kTfLiteFloat32:
- reference_ops::TransposeConv(
+ optimized_ops::TransposeConv(
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
stride_height, padding_size.width, padding_size.height,
GetTensorData<float>(output), GetTensorDims(output),
- // Last two args specify im2col which reference_ops ignores.
- // (Note this does not lead to a performance regression, as the
- // previous optimized version was just a copy of the reference code.)
- // TODO(b/110208176): Allocate im2col tensors and switch to
- // optimized_ops.
- GetTensorData<float>(output), GetTensorDims(output));
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
break;
default:
context->ReportError(context, "Type %d, not currently supported.",
@@ -142,8 +219,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace transpose_conv
TfLiteRegistration* Register_TRANSPOSE_CONV() {
- static TfLiteRegistration r = {nullptr, nullptr, transpose_conv::Prepare,
- transpose_conv::Eval};
+ static TfLiteRegistration r = {transpose_conv::Init, transpose_conv::Free,
+ transpose_conv::Prepare, transpose_conv::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
index 55df897180..c741df19de 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
@@ -24,9 +25,49 @@ namespace {
using ::testing::ElementsAreArray;
+class ConstTransposeConvOpModel : public SingleOpModel {
+ // Just to be extra confusing, transpose_conv has an _input_ named
+ // "output_shape". This input sets the shape of the output tensor of the op.
+ // In this version of the test class, "output_shape" is a constant that must
+ // be specified in the constructor.
+ public:
+ ConstTransposeConvOpModel(TfLiteRegistration* registration,
+ std::initializer_list<int> input_shape,
+ std::initializer_list<int> filter_shape,
+ std::initializer_list<int> output_shape_data,
+ Padding padding, int stride_w, int stride_h) {
+ output_shape_ = AddConstInput(TensorType_INT32, output_shape_data,
+ {static_cast<int>(output_shape_data.size())});
+ filter_ = AddInput(TensorType_FLOAT32);
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
+ CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
+ .Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_TRANSPOSE_CONV, registration);
+ BuildInterpreter({{4}, filter_shape, input_shape});
+ }
+
+ int output_shape() { return output_shape_; }
+ int filter() { return filter_; }
+ int input() { return input_; }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int output_shape_;
+ int filter_;
+ int input_;
+ int output_;
+};
+
class TransposeConvOpModel : public SingleOpModel {
public:
- TransposeConvOpModel(std::initializer_list<int> input_shape,
+ TransposeConvOpModel(TfLiteRegistration* registration,
+ std::initializer_list<int> input_shape,
std::initializer_list<int> filter_shape, Padding padding,
int stride_w, int stride_h) {
output_shape_ = AddInput(TensorType_INT32);
@@ -37,6 +78,8 @@ class TransposeConvOpModel : public SingleOpModel {
BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
.Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_TRANSPOSE_CONV, registration);
BuildInterpreter({{4}, filter_shape, input_shape});
}
@@ -54,6 +97,15 @@ class TransposeConvOpModel : public SingleOpModel {
int output_;
};
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({});
+
+class TransposeConvOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
// Test case:
// output = tf.nn.conv2d_backprop_input(
// tf.constant([ 1, 4, 4, 1 ]),
@@ -61,8 +113,9 @@ class TransposeConvOpModel : public SingleOpModel {
// tf.constant(np.arange(1, 17), shape=[ 1, 4, 4, 1 ], dtype=tf.float32),
// [1, 1, 1, 1 ],
// "SAME")
-TEST(TransposeConvOpModelTest, SimpleTest) {
- TransposeConvOpModel m({1, 4, 4, 1}, {1, 3, 3, 1}, Padding_SAME, 1, 1);
+TEST_P(TransposeConvOpTest, SimpleTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 3, 3, 1},
+ Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
m.PopulateTensor<float>(
@@ -75,6 +128,21 @@ TEST(TransposeConvOpModelTest, SimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
+// Test case: Same as above, but with a const "output_shape"
+TEST_P(TransposeConvOpTest, ConstSimpleTest) {
+ ConstTransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 4, 4, 1},
+ {1, 3, 3, 1}, Padding_SAME, 1, 1);
+ m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ m.PopulateTensor<float>(
+ m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({29, 62, 83, 75, 99, 192, 237, 198, 207, 372,
+ 417, 330, 263, 446, 485, 365}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
// Test case:
// filter = tf.constant(np.arange(1, 19),
// shape=[ 3, 3, 1, 2 ],
@@ -87,8 +155,9 @@ TEST(TransposeConvOpModelTest, SimpleTest) {
// "SAME")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1])
-TEST(TransposeConvOpModelTest, TwoFiltersTest) {
- TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1);
+TEST_P(TransposeConvOpTest, TwoFiltersTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2},
+ Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18});
@@ -116,8 +185,9 @@ TEST(TransposeConvOpModelTest, TwoFiltersTest) {
// "VALID")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18])
-TEST(TransposeConvOpModelTest, PaddingValidTest) {
- TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1);
+TEST_P(TransposeConvOpTest, PaddingValidTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2},
+ Padding_VALID, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 6, 6, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18});
@@ -146,8 +216,9 @@ TEST(TransposeConvOpModelTest, PaddingValidTest) {
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
// [1, 2, 2, 1 ],
// "VALID")
-TEST(TransposeConvOpModelTest, StrideValidTest) {
- TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 1}, Padding_VALID, 2, 2);
+TEST_P(TransposeConvOpTest, StrideValidTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {1, 3, 3, 1},
+ Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
@@ -170,8 +241,9 @@ TEST(TransposeConvOpModelTest, StrideValidTest) {
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
// [1, 2, 2, 1 ],
// "VALID")
-TEST(TransposeConvOpModelTest, MultiChannelTest) {
- TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2);
+TEST_P(TransposeConvOpTest, MultiChannelTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1},
+ Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 2});
m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
8, 10, 12, 14, 16, 18});
@@ -187,6 +259,24 @@ TEST(TransposeConvOpModelTest, MultiChannelTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
}
+// Test case: Same as above, but with a const "output_shape"
+TEST_P(TransposeConvOpTest, ConstMultiChannelTest) {
+ ConstTransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1},
+ {1, 5, 5, 2}, Padding_VALID, 2, 2);
+ m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
+ 8, 10, 12, 14, 16, 18});
+ m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({1, 2, 3, 4, 7, 10, 6, 8, 10, 12, 7, 8, 9,
+ 10, 25, 28, 18, 20, 22, 24, 16, 20, 24, 28, 62, 72,
+ 42, 48, 54, 60, 21, 24, 27, 30, 61, 68, 36, 40, 44,
+ 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
+}
+
// Test case:
// filter = tf.constant(np.random.randint(1, 10, size=9),
// shape=[ 3, 3, 1, 1 ],
@@ -199,8 +289,9 @@ TEST(TransposeConvOpModelTest, MultiChannelTest) {
// "SAME")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1])
-TEST(TransposeConvOpModelTest, AccuracyTest) {
- TransposeConvOpModel m({1, 1, 2, 1}, {1, 3, 3, 1}, Padding_SAME, 3, 3);
+TEST_P(TransposeConvOpTest, AccuracyTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 1, 2, 1}, {1, 3, 3, 1},
+ Padding_SAME, 3, 3);
m.PopulateTensor<int>(m.output_shape(), {1, 3, 4, 1});
m.PopulateTensor<float>(m.filter(), {9, 5, 6, 9, 8, 5, 3, 1, 4});
m.PopulateTensor<float>(m.input(), {323, 521});
@@ -212,6 +303,10 @@ TEST(TransposeConvOpModelTest, AccuracyTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 4, 1}));
}
+INSTANTIATE_TEST_CASE_P(
+ TransposeConvOpTest, TransposeConvOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 1c28123a24..32daf2bb02 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -70,9 +70,21 @@ constexpr int kOutputStateTensor = 0;
constexpr int kCellStateTensor = 1;
constexpr int kOutputTensor = 2;
+// Temporary tensors
+enum TemporaryTensor {
+ kScratchBuffer = 0,
+ kInputQuantized = 1,
+ kOutputStateQuantized = 2,
+ kCellStateQuantized = 3,
+ kScalingFactors = 4,
+ kProductScalingFactors = 5,
+ kRecoveredCellWeights = 6,
+ kNumTemporaryTensors = 7
+};
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, 1, scratch_tensor_index);
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -84,7 +96,7 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -242,6 +254,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
@@ -288,86 +301,156 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
- // Create a scratch buffer tensor.
+ // Mark state tensors as persistent tensors.
+ output_state->allocation_type = kTfLiteArenaRwPersistent;
+ cell_state->allocation_type = kTfLiteArenaRwPersistent;
+
+ // The weights are of consistent type, so it suffices to check one.
+ // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
+
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
node->temporaries->data[0] = *scratch_tensor_index;
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ // Create a scratch buffer tensor.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
if (use_cifg) {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 3;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
} else {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Input, Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 4;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // output_state and cell_state tensors.
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[kOutputStateQuantized] =
+ *scratch_tensor_index + kOutputStateQuantized;
+ TfLiteTensor* output_state_quantized =
+ GetTemporary(context, node, kOutputStateQuantized);
+ output_state_quantized->type = kTfLiteUInt8;
+ output_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(output_state_quantized->dims,
+ output_state->dims)) {
+ TfLiteIntArray* output_state_quantized_size =
+ TfLiteIntArrayCopy(output_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output_state_quantized,
+ output_state_quantized_size));
+ }
+ node->temporaries->data[kCellStateQuantized] =
+ *scratch_tensor_index + kCellStateQuantized;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, kCellStateQuantized);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is a convenience storage which allows to quantize
+ // a vector once (which produces the scaling factors) and multiply it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of the matrix).
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[kProductScalingFactors] =
+ *scratch_tensor_index + kProductScalingFactors;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered cell weights. Since
+ // this is used for diagonal matrices, only need to store n_cell values.
+ node->temporaries->data[kRecoveredCellWeights] =
+ *scratch_tensor_index + kRecoveredCellWeights;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
}
return kTfLiteOk;
}
// The LSTM Op engine.
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-
- const TfLiteTensor* input_to_input_weights =
- GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- const TfLiteTensor* input_to_forget_weights =
- GetInput(context, node, kInputToForgetWeightsTensor);
- const TfLiteTensor* input_to_cell_weights =
- GetInput(context, node, kInputToCellWeightsTensor);
- const TfLiteTensor* input_to_output_weights =
- GetInput(context, node, kInputToOutputWeightsTensor);
-
- const TfLiteTensor* recurrent_to_input_weights =
- GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- const TfLiteTensor* recurrent_to_forget_weights =
- GetInput(context, node, kRecurrentToForgetWeightsTensor);
- const TfLiteTensor* recurrent_to_cell_weights =
- GetInput(context, node, kRecurrentToCellWeightsTensor);
- const TfLiteTensor* recurrent_to_output_weights =
- GetInput(context, node, kRecurrentToOutputWeightsTensor);
-
- const TfLiteTensor* cell_to_input_weights =
- GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
- const TfLiteTensor* cell_to_forget_weights =
- GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
- const TfLiteTensor* cell_to_output_weights =
- GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
-
- const TfLiteTensor* input_gate_bias =
- GetOptionalInputTensor(context, node, kInputGateBiasTensor);
- const TfLiteTensor* forget_gate_bias =
- GetInput(context, node, kForgetGateBiasTensor);
- const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
- const TfLiteTensor* output_gate_bias =
- GetInput(context, node, kOutputGateBiasTensor);
-
- const TfLiteTensor* projection_weights =
- GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- const TfLiteTensor* projection_bias =
- GetOptionalInputTensor(context, node, kProjectionBiasTensor);
-
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -380,8 +463,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
- // Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
@@ -432,6 +513,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
float* output_state_ptr = output_state->data.f;
float* cell_state_ptr = cell_state->data.f;
+ // Feed the sequence into the LSTM step-by-step.
for (int t = 0; t < max_time; t++) {
const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
float* output_ptr_batch = output->data.f + t * n_batch * n_output;
@@ -452,6 +534,262 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+ TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+ TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* output_state_ptr = output_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_output_state_ptr =
+ reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+ // Feed the sequence into the LSTM step-by-step.
+ for (int t = 0; t < max_time; t++) {
+ const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
+ float* output_ptr_batch = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStep(
+ input_ptr_batch, input_to_input_weights_ptr,
+ input_to_input_weights_scale, input_to_forget_weights_ptr,
+ input_to_forget_weights_scale, input_to_cell_weights_ptr,
+ input_to_cell_weights_scale, input_to_output_weights_ptr,
+ input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+ n_input, n_output, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, scaling_factors_ptr,
+ prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+ quantized_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr_batch);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+ TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights,
+ cell_to_output_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, params,
+ scratch_buffer, output_state, cell_state, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* output_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, params, scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, output_state_quantized, cell_state_quantized,
+ output_state, cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
} // namespace unidirectional_sequence_lstm
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index 5881ced7c7..de38bdef6f 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
// Unit test for TFLite Sequential LSTM op.
-#include <iomanip>
#include <memory>
#include <vector>
@@ -37,7 +36,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip,
float proj_clip,
- const std::vector<std::vector<int>>& input_shapes)
+ const std::vector<std::vector<int>>& input_shapes,
+ const TensorType& weights_type = TensorType_FLOAT32)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
@@ -48,31 +48,31 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
if (use_cifg) {
input_to_input_weights_ = AddNullInput();
} else {
- input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_input_weights_ = AddInput(weights_type);
}
- input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_forget_weights_ = AddInput(weights_type);
+ input_to_cell_weights_ = AddInput(weights_type);
+ input_to_output_weights_ = AddInput(weights_type);
if (use_cifg) {
recurrent_to_input_weights_ = AddNullInput();
} else {
- recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_input_weights_ = AddInput(weights_type);
}
- recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_forget_weights_ = AddInput(weights_type);
+ recurrent_to_cell_weights_ = AddInput(weights_type);
+ recurrent_to_output_weights_ = AddInput(weights_type);
if (use_peephole) {
if (use_cifg) {
cell_to_input_weights_ = AddNullInput();
} else {
- cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_input_weights_ = AddInput(weights_type);
}
- cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_forget_weights_ = AddInput(weights_type);
+ cell_to_output_weights_ = AddInput(weights_type);
} else {
cell_to_input_weights_ = AddNullInput();
cell_to_forget_weights_ = AddNullInput();
@@ -89,7 +89,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
- projection_weights_ = AddInput(TensorType_FLOAT32);
+ projection_weights_ = AddInput(weights_type);
if (use_projection_bias) {
projection_bias_ = AddInput(TensorType_FLOAT32);
} else {
@@ -196,8 +196,9 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
zero_buffer.get() + zero_buffer_size);
}
- void SetInput(int offset, float* begin, float* end) {
- PopulateTensor(input_, offset, begin, end);
+ void SetInput(int offset, const float* begin, const float* end) {
+ PopulateTensor(input_, offset, const_cast<float*>(begin),
+ const_cast<float*>(end));
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -208,7 +209,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int num_batches() { return n_batch_; }
int sequence_length() { return sequence_length_; }
- private:
+ protected:
int input_;
int input_to_input_weights_;
int input_to_forget_weights_;
@@ -243,7 +244,183 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int sequence_length_;
};
-TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
+// The hybrid model has quantized weights.
+class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
+ public:
+ HybridUnidirectionalLSTMOpModel(
+ int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
+ bool use_cifg, bool use_peephole, bool use_projection_weights,
+ bool use_projection_bias, float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : UnidirectionalLSTMOpModel(
+ n_batch, n_input, n_cell, n_output, sequence_length, use_cifg,
+ use_peephole, use_projection_weights, use_projection_bias,
+ cell_clip, proj_clip, input_shapes, TensorType_UINT8) {}
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(projection_weights_, f);
+ }
+};
+
+class BaseLstmTest : public ::testing::Test {
+ protected:
+ // Weights of the LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_;
+ std::initializer_list<float> input_to_cell_weights_;
+ std::initializer_list<float> input_to_forget_weights_;
+ std::initializer_list<float> input_to_output_weights_;
+ std::initializer_list<float> input_gate_bias_;
+ std::initializer_list<float> cell_gate_bias_;
+ std::initializer_list<float> forget_gate_bias_;
+ std::initializer_list<float> output_gate_bias_;
+ std::initializer_list<float> recurrent_to_input_weights_;
+ std::initializer_list<float> recurrent_to_cell_weights_;
+ std::initializer_list<float> recurrent_to_forget_weights_;
+ std::initializer_list<float> recurrent_to_output_weights_;
+ std::initializer_list<float> cell_to_input_weights_;
+ std::initializer_list<float> cell_to_forget_weights_;
+ std::initializer_list<float> cell_to_output_weights_;
+ std::initializer_list<float> projection_weights_;
+
+ // LSTM input is stored as num_batch x num_inputs vector.
+ std::vector<std::vector<float>> lstm_input_;
+ // LSTM output is stored as num_batch x num_outputs vector.
+ std::vector<std::vector<float>> lstm_golden_output_;
+
+ // Compares output up to tolerance to the result of the lstm given the input.
+ void VerifyGoldens(const std::vector<std::vector<float>>& input,
+ const std::vector<std::vector<float>>& output,
+ UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
+ const int num_batches = input.size();
+ EXPECT_GT(num_batches, 0);
+ const int num_inputs = lstm->num_inputs();
+ EXPECT_GT(num_inputs, 0);
+ const int input_sequence_size = input[0].size() / num_inputs;
+ EXPECT_GT(input_sequence_size, 0);
+ // Feed the whole sequence as input.
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* batch_start = input[b].data() + i * num_inputs;
+ const float* batch_end = batch_start + num_inputs;
+
+ lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(),
+ batch_start, batch_end);
+ }
+ }
+
+ lstm->Invoke();
+
+ const int num_outputs = lstm->num_outputs();
+ EXPECT_GT(num_outputs, 0);
+ std::vector<float> expected;
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* golden_start_batch = output[b].data() + i * num_outputs;
+ const float* golden_end_batch = golden_start_batch + num_outputs;
+
+ expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+ }
+ }
+
+ EXPECT_THAT(lstm->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+};
+
+class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524};
+ input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113, -0.29909778};
+ input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212};
+ input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077,
+ -0.1556896, 0.19487578};
+ input_gate_bias_ = {0., 0., 0., 0.};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_input_weights_ = {
+ -0.0063535, -0.2042388, 0.31454784, -0.35746509,
+ 0.28902304, 0.08183324, -0.16555229, 0.02286911,
+ -0.13566875, 0.03034258, 0.48091322, -0.12528998,
+ 0.24077177, -0.51332325, -0.33502164, 0.10629296};
+
+ recurrent_to_cell_weights_ = {
+ -0.3407414, 0.24443203, -0.2078532, 0.26320225,
+ 0.05695659, -0.00123841, -0.4744786, -0.35869038,
+ -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064};
+
+ recurrent_to_forget_weights_ = {
+ -0.48684245, -0.06655136, 0.42224967, 0.2112639,
+ 0.27654213, 0.20864892, -0.07646349, 0.45877004,
+ 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004};
+
+ recurrent_to_output_weights_ = {
+ 0.43385774, -0.17194885, 0.2718237, 0.09215671,
+ 0.24107647, -0.39835793, 0.18212086, 0.01301402,
+ 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136};
+
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
+ }
+};
+
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -252,9 +429,11 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
const int sequence_length = 3;
UnidirectionalLSTMOpModel lstm(
- n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
- /*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -281,77 +460,138 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
- -0.34550029, 0.04266912, -0.15680569,
- -0.34856534, 0.43890524});
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
- -0.20583314, 0.44344562, 0.22077113,
- -0.29909778});
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
- -0.31343272, -0.40032279, 0.44781327,
- 0.01387155, -0.35593212});
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
- 0.40525138, 0.44272184, 0.03897077, -0.1556896,
- 0.19487578});
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
- lstm.SetInputGateBias({0., 0., 0., 0.});
+ HybridUnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
- lstm.SetCellBias({0., 0., 0., 0.});
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
- lstm.SetForgetGateBias({1., 1., 1., 1.});
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
- lstm.SetOutputGateBias({0., 0., 0., 0.});
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
- lstm.SetRecurrentToInputWeights(
- {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
- -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
- -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
- lstm.SetRecurrentToCellWeights(
- {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
- -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
- -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
- lstm.SetRecurrentToForgetWeights(
- {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
- -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
- 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetRecurrentToOutputWeights(
- {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
- 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
- -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- // Input should have n_input * sequence_length many values.
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126,
- -0.15358765, -0.03716109, 0.12507336,
- 0.41193449, -0.20860538, -0.15053082,
- 0.09120187, 0.24278517, -0.12222792};
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
- float* batch0_start = lstm_input;
- float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
+ /*tolerance=*/0.0157651);
+}
- lstm.SetInput(0, batch0_start, batch0_end);
+class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
+ 0.05100781, 0.04717243, 0.48944736,
+ -0.38535351, -0.17212132};
- lstm.Invoke();
+ input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698,
+ 0.24407166, 0.33826375};
- float* golden_start = lstm_golden_output;
- float* golden_end =
- golden_start + lstm.num_outputs() * lstm.sequence_length();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
-}
+ input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_cell_weights_ = {
+ 0.54066205, -0.32668582, -0.43562764, -0.56094903,
+ 0.42957711, 0.01841056, -0.32764608, -0.33027974,
+ -0.10826075, 0.20675004, 0.19069612, -0.03026325,
+ -0.54532051, 0.33003211, 0.44901288, 0.21193194};
+
+ recurrent_to_forget_weights_ = {
+ -0.13832897, -0.0515101, -0.2359007, -0.16661474,
+ -0.14340827, 0.36986142, 0.23414481, 0.55899,
+ 0.10798943, -0.41174671, 0.17751795, -0.34484994,
+ -0.35874045, -0.11352962, 0.27268326, 0.54058349};
+
+ recurrent_to_output_weights_ = {
+ 0.41613156, 0.42610586, -0.16495961, -0.5663873,
+ 0.30579174, -0.05115908, -0.33941799, 0.23364776,
+ 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987};
+
+ cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
+ 0.31544167};
+ cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
+ -0.77109635};
+
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302}};
+ }
+};
-TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -360,9 +600,11 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
const int sequence_length = 3;
UnidirectionalLSTMOpModel lstm(
- n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
- /*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -389,71 +631,690 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
- 0.04717243, 0.48944736, -0.38535351,
- -0.17212132});
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
- -0.3633365, -0.22755712, 0.28253698, 0.24407166,
- 0.33826375});
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
- -0.09426838, -0.44257352, 0.54939759,
- 0.01533556, 0.42751634});
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
+
+TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ HybridUnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
- lstm.SetCellBias({0., 0., 0., 0.});
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
- lstm.SetForgetGateBias({1., 1., 1., 1.});
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
- lstm.SetOutputGateBias({0., 0., 0., 0.});
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
- lstm.SetRecurrentToCellWeights(
- {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
- 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
- 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
- 0.21193194});
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
- lstm.SetRecurrentToForgetWeights(
- {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
- 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
- -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetRecurrentToOutputWeights(
- {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
- -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
- 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetCellToForgetWeights(
- {0.47485286, -0.51955009, -0.24458408, 0.31544167});
- lstm.SetCellToOutputWeights(
- {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
- -0.05163646, -0.42312205, -0.01218222,
- 0.24201041, -0.08124574, -0.358325,
- -0.04621704, 0.21641694, -0.06471302};
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
- float* batch0_start = lstm_input;
- float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
-
- lstm.SetInput(0, batch0_start, batch0_end);
-
- lstm.Invoke();
-
- float* golden_start = lstm_golden_output;
- float* golden_end =
- golden_start + lstm.num_outputs() * lstm.sequence_length();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
-TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
+class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {
+ 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
+ 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
+ -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
+ -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
+ -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
+ -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
+ -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
+ 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
+ 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
+ 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
+ -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
+ 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
+ -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
+ -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
+ -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
+ 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
+ -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
+ -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
+ -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
+ -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677};
+
+ input_to_forget_weights_ = {
+ -0.0018401089, -0.004852237, 0.03698424, 0.014181704,
+ 0.028273236, -0.016726194, -0.05249759, -0.10204261,
+ 0.00861066, -0.040979505, -0.009899187, 0.01923892,
+ -0.028177269, -0.08535103, -0.14585495, 0.10662567,
+ -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
+ 0.0030784295, 0.076784775, 0.07463696, 0.094531395,
+ 0.0814421, -0.12257899, -0.033945758, -0.031303465,
+ 0.045630626, 0.06843887, -0.13492945, -0.012480007,
+ -0.0811829, -0.07224499, -0.09628791, 0.045100946,
+ 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
+ 0.06958324, 0.034257296, 0.0482646, 0.06267997,
+ 0.052625068, 0.12784666, 0.07077897, 0.025725935,
+ 0.04165009, 0.07241905, 0.018668644, -0.037377294,
+ -0.06277783, -0.08833636, -0.040120605, -0.011405586,
+ -0.007808335, -0.010301386, -0.005102167, 0.027717464,
+ 0.05483423, 0.11449111, 0.11289652, 0.10939839,
+ 0.13396506, -0.08402166, -0.01901462, -0.044678304,
+ -0.07720565, 0.014350063, -0.11757958, -0.0652038,
+ -0.08185733, -0.076754324, -0.092614375, 0.10405491,
+ 0.052960336, 0.035755895, 0.035839386, -0.012540553,
+ 0.036881298, 0.02913376, 0.03420159, 0.05448447,
+ -0.054523353, 0.02582715, 0.02327355, -0.011857179,
+ -0.0011980024, -0.034641717, -0.026125094, -0.17582615,
+ -0.15923657, -0.27486774, -0.0006143371, 0.0001771948,
+ -8.470171e-05, 0.02651807, 0.045790765, 0.06956496};
+
+ input_to_cell_weights_ = {
+ -0.04580283, -0.09549462, -0.032418985, -0.06454633,
+ -0.043528453, 0.043018587, -0.049152344, -0.12418144,
+ -0.078985475, -0.07596889, 0.019484362, -0.11434962,
+ -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
+ -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
+ 0.10665918, -0.032036792, -0.08505916, -0.10843358,
+ -0.13002433, -0.036816437, -0.02130134, -0.016518239,
+ 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
+ -0.10652836, -0.1037554, -0.13056071, -0.03266643,
+ -0.033702414, -0.006473424, -0.04611692, 0.014419339,
+ -0.025174323, 0.0396852, 0.081777506, 0.06157468,
+ 0.10210095, -0.009658194, 0.046511717, 0.03603906,
+ 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
+ 0.053568836, 0.06408714, 0.12835667, -0.008714329,
+ -0.20211966, -0.12093674, 0.029450472, 0.2849013,
+ -0.029227901, 0.1164364, -0.08560263, 0.09941786,
+ -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
+ -0.09720865, -0.11193351, -0.029155117, -0.017936034,
+ -0.009768936, -0.04223324, -0.036159635, 0.06505112,
+ -0.021742892, -0.023377212, -0.07221364, -0.06430552,
+ 0.05453865, 0.091149814, 0.06387331, 0.007518393,
+ 0.055960953, 0.069779344, 0.046411168, 0.10509911,
+ 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
+ 0.056955688, 0.06555285, 0.050801456, -0.009862683,
+ 0.00826772, -0.026555609, -0.0073611983, -0.0014897042};
+
+ input_to_output_weights_ = {
+ -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
+ -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
+ 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
+ -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
+ -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
+ 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
+ -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
+ -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
+ -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
+ -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
+ 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
+ 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
+ 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
+ -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
+ 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
+ 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
+ -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
+ 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
+ -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
+ -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956};
+
+ input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666,
+ 0.053110216, -0.06928846, -0.13942584, -0.11816189,
+ 0.19483899, 0.03652339, -0.10250295, 0.036714908,
+ -0.18426876, 0.036065217, 0.21810818, 0.02383196,
+ -0.043370757, 0.08690144, -0.04444982, 0.00030581196};
+
+ forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
+ 0.11098921, 0.15378423, 0.09263801, 0.09790885,
+ 0.09508917, 0.061199076, 0.07665568, -0.015443159,
+ -0.03499149, 0.046190713, 0.08895977, 0.10899629,
+ 0.40694186, 0.06030037, 0.012413437, -0.06108739};
+
+ cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
+ -0.1483596, -0.10639995, -0.091433935, 0.058573797,
+ -0.06809782, -0.07889636, -0.043246906, -0.09829136,
+ -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
+ 0.016178843, 0.1749513, 0.13975595, 0.92058027};
+
+ output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113,
+ 0.027195795, 0.35373217, -0.018957434, 0.008907322,
+ -0.0762701, 0.12018895, 0.04216877, 0.0022856654,
+ 0.040952638, 0.3147856, 0.08225149, -0.057416286,
+ -0.14995944, -0.008040261, 0.13208859, 0.029760877};
+
+ recurrent_to_input_weights_ = {
+ -0.001374326, -0.078856036, 0.10672688, 0.029162422,
+ -0.11585556, 0.02557986, -0.13446963, -0.035785314,
+ -0.01244275, 0.025961924, -0.02337298, -0.044228926,
+ -0.055839065, -0.046598054, -0.010546039, -0.06900766,
+ 0.027239809, 0.022582639, -0.013296484, -0.05459212,
+ 0.08981, -0.045407712, 0.08682226, -0.06867011,
+ -0.14390695, -0.02916037, 0.000996957, 0.091420636,
+ 0.14283475, -0.07390571, -0.06402044, 0.062524505,
+ -0.093129106, 0.04860203, -0.08364217, -0.08119002,
+ 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
+ -0.13732095, 0.012405723, -0.07551853, 0.06343048,
+ 0.12162708, -0.031923793, -0.014335606, 0.01790974,
+ -0.10650317, -0.0724401, 0.08554849, -0.05727212,
+ 0.06556731, -0.042729504, -0.043227166, 0.011683251,
+ -0.013082158, -0.029302018, -0.010899579, -0.062036745,
+ -0.022509435, -0.00964907, -0.01567329, 0.04260106,
+ -0.07787477, -0.11576462, 0.017356863, 0.048673786,
+ -0.017577527, -0.05527947, -0.082487635, -0.040137455,
+ -0.10820036, -0.04666372, 0.022746278, -0.07851417,
+ 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
+ 0.08944216, -0.0685835, 0.010513544, 0.07228705,
+ 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
+ 0.040414046, -0.1380399, 0.094208956, -0.05722982,
+ 0.012092817, -0.04989123, -0.086576, -0.003399834,
+ -0.04696032, -0.045747425, 0.10091314, 0.048676282,
+ -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
+ 0.09504992, 0.041799378, -0.049185462, -0.031518843,
+ -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
+ -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
+ -0.10167381, 0.042500053, -0.01447153, 0.06464186,
+ -0.017142897, 0.03312627, 0.009205989, 0.024138335,
+ -0.011337001, 0.035530265, -0.010912711, 0.0706555,
+ -0.005894094, 0.051841937, -0.1401738, -0.02351249,
+ 0.0365468, 0.07590991, 0.08838724, 0.021681072,
+ -0.10086113, 0.019608743, -0.06195883, 0.077335775,
+ 0.023646897, -0.095322326, 0.02233014, 0.09756986,
+ -0.048691444, -0.009579111, 0.07595467, 0.11480546,
+ -0.09801813, 0.019894179, 0.08502348, 0.004032281,
+ 0.037211012, 0.068537936, -0.048005626, -0.091520436,
+ -0.028379958, -0.01556313, 0.06554592, -0.045599163,
+ -0.01672207, -0.020169014, -0.011877351, -0.20212261,
+ 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
+ -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
+ 0.015963363, 0.00871737, 0.060130805, 0.028611384,
+ 0.10109069, -0.015060172, -0.07894427, 0.06401885,
+ 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
+ 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
+ 0.019899689, 0.006106124, -0.027092824, 0.0786356,
+ 0.05052217, -0.058925, -0.011402121, -0.024987547,
+ -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
+ -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
+ -0.033664223, -0.07978348, -0.025200296, -0.017207067,
+ -0.058403496, -0.055697463, 0.005798788, 0.12965427,
+ -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
+ 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
+ 0.013806405, -0.017858358, -0.01008298, -0.07700066,
+ -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
+ 0.062634714, -0.02338735, -0.039547626, -0.02050681,
+ 0.03385117, -0.083611414, 0.002862572, -0.09421313,
+ 0.058618143, -0.08598433, 0.00972939, 0.023867095,
+ -0.053934585, -0.023203006, 0.07452513, -0.048767887,
+ -0.07314807, -0.056307215, -0.10433547, -0.06440842,
+ 0.04328182, 0.04389765, -0.020006588, -0.09076438,
+ -0.11652589, -0.021705797, 0.03345259, -0.010329105,
+ -0.025767034, 0.013057034, -0.07316461, -0.10145612,
+ 0.06358255, 0.18531723, 0.07759293, 0.12006465,
+ 0.1305557, 0.058638252, -0.03393652, 0.09622831,
+ -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
+ -0.005644518, 0.06857898, -0.12598175, -0.035084512,
+ 0.03156317, -0.12794146, -0.031963028, 0.04692781,
+ 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
+ 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
+ 0.08184801, -0.019164372, 0.06791302, 0.034257166,
+ -0.10307039, 0.021943003, 0.046745934, 0.0790918,
+ -0.0265588, -0.007824208, 0.042546265, -0.00977924,
+ -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
+ -0.014512694, -0.08251313, 0.08861942, 0.13589665,
+ 0.026351685, 0.012641483, 0.07466548, 0.044301085,
+ -0.045414884, -0.051112458, 0.03444247, -0.08502782,
+ -0.04106223, -0.028126027, 0.028473156, 0.10467447};
+
+ recurrent_to_cell_weights_ = {
+ -0.037322544, 0.018592842, 0.0056175636, -0.06253426,
+ 0.055647098, -0.05713207, -0.05626563, 0.005559383,
+ 0.03375411, -0.025757805, -0.088049285, 0.06017052,
+ -0.06570978, 0.007384076, 0.035123326, -0.07920549,
+ 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
+ 0.08089997, 0.05143358, 0.038261272, 0.03339287,
+ -0.027673481, 0.044746667, 0.028349208, 0.020090483,
+ -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
+ -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
+ -0.10893326, 0.076739706, -0.08509834, -0.027997585,
+ 0.037871376, 0.01449768, -0.09002357, -0.06111149,
+ -0.046195522, 0.0422062, -0.005683705, -0.1253618,
+ -0.012925729, -0.04890792, 0.06985068, 0.037654128,
+ 0.03398274, -0.004781977, 0.007032333, -0.031787455,
+ 0.010868644, -0.031489216, 0.09525667, 0.013939797,
+ 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
+ -0.048885044, -0.12722108, 0.035304096, 0.06554885,
+ 0.00972396, -0.039238118, -0.05159735, -0.11329045,
+ 0.1613692, -0.03750952, 0.06529313, -0.071974665,
+ -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
+ 0.02786344, -0.014179351, 0.005264273, 0.14376344,
+ 0.015983658, 0.03406988, -0.06939408, 0.040699873,
+ 0.02111075, 0.09669095, 0.041345075, -0.08316494,
+ -0.07684199, -0.045768797, 0.032298047, -0.041805092,
+ 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
+ -0.024950314, 0.11574242, 0.04508852, -0.04335324,
+ 0.06760663, -0.027437469, 0.07216407, 0.06977076,
+ -0.05438599, 0.034033038, -0.028602652, 0.05346137,
+ 0.043184172, -0.037189785, 0.10420091, 0.00882477,
+ -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
+ 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
+ 0.04361412, -0.007001822, 0.09631092, -0.06702025,
+ -0.042049985, -0.035070654, -0.04103342, -0.10273396,
+ 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
+ -0.008264958, 0.042035464, 0.05891794, 0.029673764,
+ 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
+ -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
+ -0.04043371, -0.017094059, 0.07229206, -0.023670016,
+ -0.052195564, -0.025616996, -0.01520939, 0.045104615,
+ -0.007376126, 0.003533447, 0.006570588, 0.056037236,
+ 0.12436656, 0.051817212, 0.028532185, -0.08686856,
+ 0.11868599, 0.07663395, -0.07323171, 0.03463402,
+ -0.050708205, -0.04458982, -0.11590894, 0.021273347,
+ 0.1251325, -0.15313013, -0.12224372, 0.17228661,
+ 0.023029093, 0.086124025, 0.006445803, -0.03496501,
+ 0.028332196, 0.04449512, -0.042436164, -0.026587414,
+ -0.006041347, -0.09292539, -0.05678812, 0.03897832,
+ 0.09465633, 0.008115513, -0.02171956, 0.08304309,
+ 0.071401566, 0.019622514, 0.032163795, -0.004167056,
+ 0.02295182, 0.030739572, 0.056506045, 0.004612461,
+ 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
+ -0.1335546, -0.030136576, 0.11584653, -0.014678886,
+ 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
+ -0.0329582, 0.07922767, 0.029322514, 0.026405897,
+ 0.04207835, -0.07073373, 0.063781224, 0.0859677,
+ -0.10925287, -0.07011058, 0.048005477, 0.03438226,
+ -0.09606514, -0.006669445, -0.043381985, 0.04240257,
+ -0.06955775, -0.06769346, 0.043903265, -0.026784198,
+ -0.017840602, 0.024307009, -0.040079936, -0.019946516,
+ 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
+ 0.15978073, 0.10185836, 0.10298046, -0.015476589,
+ -0.039390966, -0.072174534, 0.0739445, -0.1211869,
+ -0.0347889, -0.07943156, 0.014809798, -0.12412325,
+ -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
+ -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
+ -0.01514876, -0.056505352, -0.012800942, -0.06994386,
+ 0.012962922, -0.031234352, 0.07029052, 0.016418684,
+ 0.03618972, 0.055686004, -0.08663945, -0.017404709,
+ -0.054761406, 0.029065743, 0.052404847, 0.020238016,
+ 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
+ 0.06262858, 0.009184685, 0.020785125, -0.043904778,
+ -0.0270329, -0.03299152, -0.060088247, -0.015162964,
+ -0.001828936, 0.12642565, -0.056757294, 0.013586685,
+ 0.09232601, -0.035886683, 0.06000002, 0.05229691,
+ -0.052580316, -0.082029596, -0.010794592, 0.012947712,
+ -0.036429964, -0.085508935, -0.13127148, -0.017744139,
+ 0.031502828, 0.036232427, -0.031581745, 0.023051167,
+ -0.05325106, -0.03421577, 0.028793324, -0.034633752,
+ -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
+ -0.008799762, 0.056595087, 0.0022273948, 0.055752404};
+
+ recurrent_to_forget_weights_ = {
+ -0.057784554, -0.026057621, -0.068447545, -0.022581743,
+ 0.14811787, 0.10826372, 0.09471067, 0.03987225,
+ -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
+ 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
+ 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
+ -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
+ -0.06193199, 0.055729095, 0.03736828, 0.020123724,
+ 0.061878487, -0.04729229, 0.034919553, -0.07585433,
+ -0.04421272, -0.044019096, 0.085488975, 0.04058006,
+ -0.06890133, -0.030951202, -0.024628663, -0.07672815,
+ 0.034293607, 0.08556707, -0.05293577, -0.033561368,
+ -0.04899627, 0.0241671, 0.015736353, -0.095442444,
+ -0.029564252, 0.016493602, -0.035026584, 0.022337519,
+ -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
+ 0.016435321, -0.03263031, -0.09543275, -0.047392778,
+ 0.013454138, 0.028934088, 0.01685226, -0.086110644,
+ -0.046250615, -0.01847454, 0.047608484, 0.07339695,
+ 0.034546845, -0.04881143, 0.009128804, -0.08802852,
+ 0.03761666, 0.008096139, -0.014454086, 0.014361001,
+ -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
+ -0.06509276, -0.006021153, -0.08570962, -0.1451793,
+ 0.060212336, 0.055259194, 0.06974018, 0.049454916,
+ -0.027794661, -0.08077226, -0.016179763, 0.1169753,
+ 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
+ -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
+ 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
+ -0.05695512, 0.047233116, 0.038937137, -0.06542224,
+ 0.014429736, -0.09719407, 0.13908425, -0.05379757,
+ 0.012321099, 0.082840554, -0.029899208, 0.044217527,
+ 0.059855383, 0.07711018, -0.045319796, 0.0948846,
+ -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
+ -0.13873616, 0.040668588, 0.034832682, -0.015319203,
+ -0.018715994, 0.046002675, 0.0599172, -0.043107376,
+ 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
+ 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
+ 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
+ 0.052958444, 0.07558703, 0.04817258, 0.044462286,
+ -0.015213451, -0.08783778, -0.0561384, -0.003008196,
+ 0.047060397, -0.002058388, 0.03429439, -0.018839769,
+ 0.024734668, 0.024614193, -0.042046934, 0.09597743,
+ -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
+ -0.02558259, -0.022822596, -0.023273505, -0.02464396,
+ -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
+ 0.04383914, -0.046476185, 0.028658995, 0.060410924,
+ 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
+ 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
+ 0.015898481, 0.021362653, -0.030262267, 0.016587038,
+ -0.011442813, 0.041154444, -0.007631438, -0.03423484,
+ -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
+ 0.02318443, -0.041350313, 0.021485701, -0.10906167,
+ -0.028218046, -0.00954771, 0.020531068, -0.11995105,
+ -0.03672871, 0.024019798, 0.014255957, -0.05221243,
+ -0.00661567, -0.04630967, 0.033188973, 0.10107534,
+ -0.014027541, 0.030796422, -0.10270911, -0.035999842,
+ 0.15443139, 0.07684145, 0.036571592, -0.035900835,
+ -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
+ -0.03858649, 0.01849943, 0.13872518, 0.01503974,
+ 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
+ -0.047401894, 0.03100163, -0.041533746, -0.10430945,
+ 0.044574402, -0.01425562, -0.024290353, 0.034563623,
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947,
+ 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
+ 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
+ 0.03532124, -0.016341697, 0.09685456, -0.016764693,
+ 0.051808182, 0.05875331, -0.04536488, 0.001626336,
+ -0.028892258, -0.01048663, -0.009793449, -0.017093895,
+ 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
+ -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
+ -0.01769146, 0.040995963, 0.02235177, -0.060430344,
+ 0.11475477, -0.023854522, 0.10071741, 0.0686208,
+ -0.014250481, 0.034261297, 0.047418304, 0.08562733,
+ -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
+ 0.04096551, 0.032249358, -0.08355519, -0.026823482,
+ 0.056386515, -0.010401743, -0.028396193, 0.08507674,
+ 0.014410365, 0.020995233, 0.17040324, 0.11511526,
+ 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
+ -0.081302024, 0.017264642, -0.009585969, 0.09491168,
+ -0.051313367, 0.054532815, -0.014298593, 0.10657464,
+ 0.007076659, 0.10964551, 0.0409152, 0.008275321,
+ -0.07283536, 0.07937492, 0.04192024, -0.1075027};
+
+ recurrent_to_output_weights_ = {
+ 0.025825322, -0.05813119, 0.09495884, -0.045984812,
+ -0.01255415, -0.0026479573, -0.08196161, -0.054914974,
+ -0.0046604523, -0.029587349, -0.044576716, -0.07480124,
+ -0.082868785, 0.023254942, 0.027502948, -0.0039728214,
+ -0.08683098, -0.08116779, -0.014675607, -0.037924774,
+ -0.023314456, -0.007401714, -0.09255757, 0.029460307,
+ -0.08829125, -0.005139627, -0.08989442, -0.0555066,
+ 0.13596267, -0.025062224, -0.048351806, -0.03850004,
+ 0.07266485, -0.022414139, 0.05940088, 0.075114764,
+ 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
+ -0.025980417, 0.072999895, 0.11091378, -0.081685916,
+ 0.014416728, 0.043229222, 0.034178585, -0.07530371,
+ 0.035837382, -0.085607, -0.007721233, -0.03287832,
+ -0.043848954, -0.06404588, -0.06632928, -0.073643476,
+ 0.008214239, -0.045984086, 0.039764922, 0.03474462,
+ 0.060612556, -0.080590084, 0.049127717, 0.04151091,
+ -0.030063879, 0.008801774, -0.023021035, -0.019558564,
+ 0.05158114, -0.010947698, -0.011825728, 0.0075720972,
+ 0.0699727, -0.0039981045, 0.069350146, 0.08799282,
+ 0.016156472, 0.035502106, 0.11695009, 0.006217345,
+ 0.13392477, -0.037875112, 0.025745004, 0.08940699,
+ -0.00924166, 0.0046702605, -0.036598757, -0.08811812,
+ 0.10522024, -0.032441203, 0.008176899, -0.04454919,
+ 0.07058152, 0.0067963637, 0.039206743, 0.03259838,
+ 0.03725492, -0.09515802, 0.013326398, -0.052055415,
+ -0.025676316, 0.03198509, -0.015951829, -0.058556724,
+ 0.036879618, 0.043357447, 0.028362012, -0.05908629,
+ 0.0059240665, -0.04995891, -0.019187413, 0.0276265,
+ -0.01628143, 0.0025863599, 0.08800015, 0.035250366,
+ -0.022165963, -0.07328642, -0.009415526, -0.07455109,
+ 0.11690406, 0.0363299, 0.07411125, 0.042103454,
+ -0.009660886, 0.019076364, 0.018299393, -0.046004917,
+ 0.08891175, 0.0431396, -0.026327137, -0.051502608,
+ 0.08979574, -0.051670972, 0.04940282, -0.07491107,
+ -0.021240504, 0.022596184, -0.034280192, 0.060163025,
+ -0.058211457, -0.051837247, -0.01349775, -0.04639988,
+ -0.035936575, -0.011681591, 0.064818054, 0.0073146066,
+ -0.021745546, -0.043124277, -0.06471268, -0.07053354,
+ -0.029321948, -0.05330136, 0.016933719, -0.053782392,
+ 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
+ 0.05693899, -0.053219706, 0.063698, 0.07977434,
+ -0.07924483, 0.06936997, 0.0034815092, -0.007305279,
+ -0.037325785, -0.07251102, -0.033633437, -0.08677009,
+ 0.091591336, -0.14165086, 0.021752775, 0.019683983,
+ 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
+ -0.0024567875, -0.14345716, 0.010955264, -0.10234828,
+ 0.1183656, -0.0010731248, -0.023590032, -0.072285876,
+ -0.0724771, -0.026382286, -0.0014920527, 0.042667855,
+ 0.0018776858, 0.02986552, 0.009814309, 0.0733756,
+ 0.12289186, 0.018043943, -0.0458958, 0.049412545,
+ 0.033632483, 0.05495232, 0.036686596, -0.013781798,
+ -0.010036754, 0.02576849, -0.08307328, 0.010112348,
+ 0.042521734, -0.05869831, -0.071689695, 0.03876447,
+ -0.13275425, -0.0352966, -0.023077697, 0.10285965,
+ 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
+ -0.10292561, -0.032401145, 0.10053256, -0.026142767,
+ -0.08271222, -0.0030240538, -0.016368777, 0.1070414,
+ 0.042672627, 0.013456989, -0.0437609, -0.022309763,
+ 0.11576483, 0.04108048, 0.061026827, -0.0190714,
+ -0.0869359, 0.037901703, 0.0610107, 0.07202949,
+ 0.01675338, 0.086139716, -0.08795751, -0.014898893,
+ -0.023771819, -0.01965048, 0.007955471, -0.043740474,
+ 0.03346837, -0.10549954, 0.090567775, 0.042013682,
+ -0.03176985, 0.12569028, -0.02421228, -0.029526481,
+ 0.023851605, 0.031539805, 0.05292009, -0.02344001,
+ -0.07811758, -0.08834428, 0.10094801, 0.16594367,
+ -0.06861939, -0.021256343, -0.041093912, -0.06669611,
+ 0.035498552, 0.021757556, -0.09302526, -0.015403468,
+ -0.06614931, -0.051798206, -0.013874718, 0.03630673,
+ 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
+ 0.03541868, -0.094149634, -0.034814864, 0.003128424,
+ -0.020674974, -0.03944324, -0.008110165, -0.11113267,
+ 0.08484226, 0.043586485, 0.040582247, 0.0968012,
+ -0.065249965, -0.028036479, 0.0050708856, 0.0017462453,
+ 0.0326779, 0.041296225, 0.09164146, -0.047743853,
+ -0.015952192, -0.034451712, 0.084197424, -0.05347844,
+ -0.11768019, 0.085926116, -0.08251791, -0.045081906,
+ 0.0948852, 0.068401024, 0.024856757, 0.06978981,
+ -0.057309967, -0.012775832, -0.0032452994, 0.01977615,
+ -0.041040014, -0.024264973, 0.063464895, 0.05431621,
+ };
+
+ cell_to_input_weights_ = {
+ 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
+ -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
+ -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
+ 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175};
+
+ cell_to_forget_weights_ = {
+ -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
+ -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
+ -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
+ 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355};
+
+ cell_to_output_weights_ = {
+ 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
+ -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
+ -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
+ 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733};
+
+ projection_weights_ = {
+ -0.009802181, 0.09401916, 0.0717386, -0.13895074,
+ 0.09641832, 0.060420845, 0.08539281, 0.054285463,
+ 0.061395317, 0.034448683, -0.042991187, 0.019801661,
+ -0.16840284, -0.015726732, -0.23041931, -0.024478018,
+ -0.10959692, -0.013875541, 0.18600968, -0.061274476,
+ 0.0138165, -0.08160894, -0.07661644, 0.032372914,
+ 0.16169067, 0.22465782, -0.03993472, -0.004017731,
+ 0.08633481, -0.28869787, 0.08682067, 0.17240396,
+ 0.014975425, 0.056431185, 0.031037588, 0.16702051,
+ 0.0077946745, 0.15140012, 0.29405436, 0.120285,
+ -0.188994, -0.027265169, 0.043389652, -0.022061434,
+ 0.014777949, -0.20203483, 0.094781205, 0.19100232,
+ 0.13987629, -0.036132768, -0.06426278, -0.05108664,
+ 0.13221376, 0.009441198, -0.16715929, 0.15859416,
+ -0.040437475, 0.050779544, -0.022187516, 0.012166504,
+ 0.027685808, -0.07675938, -0.0055694645, -0.09444123,
+ 0.0046453946, 0.050794356, 0.10770313, -0.20790008,
+ -0.07149004, -0.11425117, 0.008225835, -0.035802525,
+ 0.14374903, 0.15262283, 0.048710253, 0.1847461,
+ -0.007487823, 0.11000021, -0.09542012, 0.22619456,
+ -0.029149994, 0.08527916, 0.009043713, 0.0042746216,
+ 0.016261552, 0.022461696, 0.12689082, -0.043589946,
+ -0.12035478, -0.08361797, -0.050666027, -0.1248618,
+ -0.1275799, -0.071875185, 0.07377272, 0.09944291,
+ -0.18897448, -0.1593054, -0.06526116, -0.040107165,
+ -0.004618631, -0.067624845, -0.007576253, 0.10727444,
+ 0.041546922, -0.20424393, 0.06907816, 0.050412357,
+ 0.00724631, 0.039827548, 0.12449835, 0.10747581,
+ 0.13708383, 0.09134148, -0.12617786, -0.06428341,
+ 0.09956831, 0.1208086, -0.14676677, -0.0727722,
+ 0.1126304, 0.010139365, 0.015571211, -0.038128063,
+ 0.022913318, -0.042050496, 0.16842307, -0.060597885,
+ 0.10531834, -0.06411776, -0.07451711, -0.03410368,
+ -0.13393489, 0.06534304, 0.003620307, 0.04490757,
+ 0.05970546, 0.05197996, 0.02839995, 0.10434969,
+ -0.013699693, -0.028353551, -0.07260381, 0.047201227,
+ -0.024575593, -0.036445823, 0.07155557, 0.009672501,
+ -0.02328883, 0.009533515, -0.03606021, -0.07421458,
+ -0.028082801, -0.2678904, -0.13221288, 0.18419984,
+ -0.13012612, -0.014588381, -0.035059117, -0.04824723,
+ 0.07830115, -0.056184657, 0.03277091, 0.025466874,
+ 0.14494097, -0.12522776, -0.098633975, -0.10766018,
+ -0.08317623, 0.08594209, 0.07749552, 0.039474737,
+ 0.1776665, -0.07409566, -0.0477268, 0.29323658,
+ 0.10801441, 0.1154011, 0.013952499, 0.10739139,
+ 0.10708251, -0.051456142, 0.0074137426, -0.10430189,
+ 0.10034707, 0.045594677, 0.0635285, -0.0715442,
+ -0.089667566, -0.10811871, 0.00026344223, 0.08298446,
+ -0.009525053, 0.006585689, -0.24567553, -0.09450807,
+ 0.09648481, 0.026996298, -0.06419476, -0.04752702,
+ -0.11063944, -0.23441927, -0.17608605, -0.052156363,
+ 0.067035615, 0.19271925, -0.0032889997, -0.043264326,
+ 0.09663576, -0.057112187, -0.10100678, 0.0628376,
+ 0.04447668, 0.017961001, -0.10094388, -0.10190601,
+ 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
+ 0.10539724, -0.04383912, -0.042349473, 0.08438151,
+ -0.1947263, 0.02251204, 0.11216432, -0.10307853,
+ 0.17351969, -0.039091777, 0.08066188, -0.00561982,
+ 0.12633002, 0.11335965, -0.0088127935, -0.019777594,
+ 0.06864014, -0.059751723, 0.016233567, -0.06894641,
+ -0.28651384, -0.004228674, 0.019708522, -0.16305895,
+ -0.07468996, -0.0855457, 0.099339016, -0.07580735,
+ -0.13775392, 0.08434318, 0.08330512, -0.12131499,
+ 0.031935584, 0.09180414, -0.08876437, -0.08049874,
+ 0.008753825, 0.03498998, 0.030215185, 0.03907079,
+ 0.089751154, 0.029194152, -0.03337423, -0.019092513,
+ 0.04331237, 0.04299654, -0.036394123, -0.12915532,
+ 0.09793732, 0.07512415, -0.11319543, -0.032502122,
+ 0.15661901, 0.07671967, -0.005491124, -0.19379048,
+ -0.218606, 0.21448623, 0.017840758, 0.1416943,
+ -0.07051762, 0.19488361, 0.02664691, -0.18104725,
+ -0.09334311, 0.15026465, -0.15493552, -0.057762887,
+ -0.11604192, -0.262013, -0.01391798, 0.012185008,
+ 0.11156489, -0.07483202, 0.06693364, -0.26151478,
+ 0.046425626, 0.036540434, -0.16435726, 0.17338543,
+ -0.21401681, -0.11385144, -0.08283257, -0.069031075,
+ 0.030635102, 0.010969227, 0.11109743, 0.010919218,
+ 0.027526086, 0.13519906, 0.01891392, -0.046839405,
+ -0.040167913, 0.017953383, -0.09700955, 0.0061885654,
+ -0.07000971, 0.026893595, -0.038844477, 0.14543656};
+
+ lstm_input_ = {
+ {// Batch0: 4 (input_sequence_size) * 5 (n_input)
+ 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0
+ 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1
+ 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2
+ 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3
+
+ {// Batch1: 4 (input_sequence_size) * 5 (n_input)
+ 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0
+ 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1
+ 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2
+ 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3
+ };
+
+ lstm_golden_output_ = {
+ {// Batch0: 4 (input_sequence_size) * 16 (n_output)
+ -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
+ -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
+ -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
+ 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
+ -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
+ -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
+ 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
+ 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
+ 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
+ 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
+ -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
+ -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
+ 0.0286833, 0.00824207, 0.0264887, 0.0305169},
+ {// Batch1: 4 (input_sequence_size) * 16 (n_output)
+ -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
+ -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
+ 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
+ 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
+ -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
+ -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
+ 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
+ 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
+ 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
+ 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
+ -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
+ -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
+ 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+ }
+};
+
+TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -461,8 +1322,9 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
const int sequence_length = 4;
UnidirectionalLSTMOpModel lstm(
- n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
- /*use_peephole=*/true, /*use_projection_weights=*/true,
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
@@ -491,588 +1353,99 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToInputWeights(
- {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
- 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
- -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
- -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
- -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
- -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
- -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
- 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
- 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
- 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
- -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
- 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
- -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
- -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
- -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
- 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
- -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
- -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
- -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
- -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
-
- lstm.SetInputToForgetWeights(
- {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
- -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
- -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
- 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
- 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
- -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
- -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
- 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
- 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
- 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
- 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
- -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
- 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
- -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
- -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
- 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
- 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
- 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
- -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
- 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
-
- lstm.SetInputToCellWeights(
- {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
- -0.043528453, 0.043018587, -0.049152344, -0.12418144,
- -0.078985475, -0.07596889, 0.019484362, -0.11434962,
- -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
- -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
- 0.10665918, -0.032036792, -0.08505916, -0.10843358,
- -0.13002433, -0.036816437, -0.02130134, -0.016518239,
- 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
- -0.10652836, -0.1037554, -0.13056071, -0.03266643,
- -0.033702414, -0.006473424, -0.04611692, 0.014419339,
- -0.025174323, 0.0396852, 0.081777506, 0.06157468,
- 0.10210095, -0.009658194, 0.046511717, 0.03603906,
- 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
- 0.053568836, 0.06408714, 0.12835667, -0.008714329,
- -0.20211966, -0.12093674, 0.029450472, 0.2849013,
- -0.029227901, 0.1164364, -0.08560263, 0.09941786,
- -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
- -0.09720865, -0.11193351, -0.029155117, -0.017936034,
- -0.009768936, -0.04223324, -0.036159635, 0.06505112,
- -0.021742892, -0.023377212, -0.07221364, -0.06430552,
- 0.05453865, 0.091149814, 0.06387331, 0.007518393,
- 0.055960953, 0.069779344, 0.046411168, 0.10509911,
- 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
- 0.056955688, 0.06555285, 0.050801456, -0.009862683,
- 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
-
- lstm.SetInputToOutputWeights(
- {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
- -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
- 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
- -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
- -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
- 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
- -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
- -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
- -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
- -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
- 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
- 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
- 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
- -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
- 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
- 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
- -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
- 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
- -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
- -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
-
- lstm.SetInputGateBias(
- {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
- -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
- -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
- 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
-
- lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
- 0.11098921, 0.15378423, 0.09263801, 0.09790885,
- 0.09508917, 0.061199076, 0.07665568, -0.015443159,
- -0.03499149, 0.046190713, 0.08895977, 0.10899629,
- 0.40694186, 0.06030037, 0.012413437, -0.06108739});
-
- lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
- -0.1483596, -0.10639995, -0.091433935, 0.058573797,
- -0.06809782, -0.07889636, -0.043246906, -0.09829136,
- -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
- 0.016178843, 0.1749513, 0.13975595, 0.92058027});
-
- lstm.SetOutputGateBias(
- {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
- 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
- 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
- -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
-
- lstm.SetRecurrentToInputWeights(
- {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
- -0.11585556, 0.02557986, -0.13446963, -0.035785314,
- -0.01244275, 0.025961924, -0.02337298, -0.044228926,
- -0.055839065, -0.046598054, -0.010546039, -0.06900766,
- 0.027239809, 0.022582639, -0.013296484, -0.05459212,
- 0.08981, -0.045407712, 0.08682226, -0.06867011,
- -0.14390695, -0.02916037, 0.000996957, 0.091420636,
- 0.14283475, -0.07390571, -0.06402044, 0.062524505,
- -0.093129106, 0.04860203, -0.08364217, -0.08119002,
- 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
- -0.13732095, 0.012405723, -0.07551853, 0.06343048,
- 0.12162708, -0.031923793, -0.014335606, 0.01790974,
- -0.10650317, -0.0724401, 0.08554849, -0.05727212,
- 0.06556731, -0.042729504, -0.043227166, 0.011683251,
- -0.013082158, -0.029302018, -0.010899579, -0.062036745,
- -0.022509435, -0.00964907, -0.01567329, 0.04260106,
- -0.07787477, -0.11576462, 0.017356863, 0.048673786,
- -0.017577527, -0.05527947, -0.082487635, -0.040137455,
- -0.10820036, -0.04666372, 0.022746278, -0.07851417,
- 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
- 0.08944216, -0.0685835, 0.010513544, 0.07228705,
- 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
- 0.040414046, -0.1380399, 0.094208956, -0.05722982,
- 0.012092817, -0.04989123, -0.086576, -0.003399834,
- -0.04696032, -0.045747425, 0.10091314, 0.048676282,
- -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
- 0.09504992, 0.041799378, -0.049185462, -0.031518843,
- -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
- -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
- -0.10167381, 0.042500053, -0.01447153, 0.06464186,
- -0.017142897, 0.03312627, 0.009205989, 0.024138335,
- -0.011337001, 0.035530265, -0.010912711, 0.0706555,
- -0.005894094, 0.051841937, -0.1401738, -0.02351249,
- 0.0365468, 0.07590991, 0.08838724, 0.021681072,
- -0.10086113, 0.019608743, -0.06195883, 0.077335775,
- 0.023646897, -0.095322326, 0.02233014, 0.09756986,
- -0.048691444, -0.009579111, 0.07595467, 0.11480546,
- -0.09801813, 0.019894179, 0.08502348, 0.004032281,
- 0.037211012, 0.068537936, -0.048005626, -0.091520436,
- -0.028379958, -0.01556313, 0.06554592, -0.045599163,
- -0.01672207, -0.020169014, -0.011877351, -0.20212261,
- 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
- -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
- 0.015963363, 0.00871737, 0.060130805, 0.028611384,
- 0.10109069, -0.015060172, -0.07894427, 0.06401885,
- 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
- 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
- 0.019899689, 0.006106124, -0.027092824, 0.0786356,
- 0.05052217, -0.058925, -0.011402121, -0.024987547,
- -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
- -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
- -0.033664223, -0.07978348, -0.025200296, -0.017207067,
- -0.058403496, -0.055697463, 0.005798788, 0.12965427,
- -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
- 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
- 0.013806405, -0.017858358, -0.01008298, -0.07700066,
- -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
- 0.062634714, -0.02338735, -0.039547626, -0.02050681,
- 0.03385117, -0.083611414, 0.002862572, -0.09421313,
- 0.058618143, -0.08598433, 0.00972939, 0.023867095,
- -0.053934585, -0.023203006, 0.07452513, -0.048767887,
- -0.07314807, -0.056307215, -0.10433547, -0.06440842,
- 0.04328182, 0.04389765, -0.020006588, -0.09076438,
- -0.11652589, -0.021705797, 0.03345259, -0.010329105,
- -0.025767034, 0.013057034, -0.07316461, -0.10145612,
- 0.06358255, 0.18531723, 0.07759293, 0.12006465,
- 0.1305557, 0.058638252, -0.03393652, 0.09622831,
- -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
- -0.005644518, 0.06857898, -0.12598175, -0.035084512,
- 0.03156317, -0.12794146, -0.031963028, 0.04692781,
- 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
- 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
- 0.08184801, -0.019164372, 0.06791302, 0.034257166,
- -0.10307039, 0.021943003, 0.046745934, 0.0790918,
- -0.0265588, -0.007824208, 0.042546265, -0.00977924,
- -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
- -0.014512694, -0.08251313, 0.08861942, 0.13589665,
- 0.026351685, 0.012641483, 0.07466548, 0.044301085,
- -0.045414884, -0.051112458, 0.03444247, -0.08502782,
- -0.04106223, -0.028126027, 0.028473156, 0.10467447});
-
- lstm.SetRecurrentToForgetWeights(
- {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
- 0.14811787, 0.10826372, 0.09471067, 0.03987225,
- -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
- 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
- 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
- -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
- -0.06193199, 0.055729095, 0.03736828, 0.020123724,
- 0.061878487, -0.04729229, 0.034919553, -0.07585433,
- -0.04421272, -0.044019096, 0.085488975, 0.04058006,
- -0.06890133, -0.030951202, -0.024628663, -0.07672815,
- 0.034293607, 0.08556707, -0.05293577, -0.033561368,
- -0.04899627, 0.0241671, 0.015736353, -0.095442444,
- -0.029564252, 0.016493602, -0.035026584, 0.022337519,
- -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
- 0.016435321, -0.03263031, -0.09543275, -0.047392778,
- 0.013454138, 0.028934088, 0.01685226, -0.086110644,
- -0.046250615, -0.01847454, 0.047608484, 0.07339695,
- 0.034546845, -0.04881143, 0.009128804, -0.08802852,
- 0.03761666, 0.008096139, -0.014454086, 0.014361001,
- -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
- -0.06509276, -0.006021153, -0.08570962, -0.1451793,
- 0.060212336, 0.055259194, 0.06974018, 0.049454916,
- -0.027794661, -0.08077226, -0.016179763, 0.1169753,
- 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
- -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
- 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
- -0.05695512, 0.047233116, 0.038937137, -0.06542224,
- 0.014429736, -0.09719407, 0.13908425, -0.05379757,
- 0.012321099, 0.082840554, -0.029899208, 0.044217527,
- 0.059855383, 0.07711018, -0.045319796, 0.0948846,
- -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
- -0.13873616, 0.040668588, 0.034832682, -0.015319203,
- -0.018715994, 0.046002675, 0.0599172, -0.043107376,
- 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
- 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
- 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
- 0.052958444, 0.07558703, 0.04817258, 0.044462286,
- -0.015213451, -0.08783778, -0.0561384, -0.003008196,
- 0.047060397, -0.002058388, 0.03429439, -0.018839769,
- 0.024734668, 0.024614193, -0.042046934, 0.09597743,
- -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
- -0.02558259, -0.022822596, -0.023273505, -0.02464396,
- -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
- 0.04383914, -0.046476185, 0.028658995, 0.060410924,
- 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
- 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
- 0.015898481, 0.021362653, -0.030262267, 0.016587038,
- -0.011442813, 0.041154444, -0.007631438, -0.03423484,
- -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
- 0.02318443, -0.041350313, 0.021485701, -0.10906167,
- -0.028218046, -0.00954771, 0.020531068, -0.11995105,
- -0.03672871, 0.024019798, 0.014255957, -0.05221243,
- -0.00661567, -0.04630967, 0.033188973, 0.10107534,
- -0.014027541, 0.030796422, -0.10270911, -0.035999842,
- 0.15443139, 0.07684145, 0.036571592, -0.035900835,
- -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
- -0.03858649, 0.01849943, 0.13872518, 0.01503974,
- 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
- -0.047401894, 0.03100163, -0.041533746, -0.10430945,
- 0.044574402, -0.01425562, -0.024290353, 0.034563623,
- 0.05866852, 0.023947537, -0.09445152, 0.035450947,
- 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
- 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
- 0.03532124, -0.016341697, 0.09685456, -0.016764693,
- 0.051808182, 0.05875331, -0.04536488, 0.001626336,
- -0.028892258, -0.01048663, -0.009793449, -0.017093895,
- 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
- -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
- -0.01769146, 0.040995963, 0.02235177, -0.060430344,
- 0.11475477, -0.023854522, 0.10071741, 0.0686208,
- -0.014250481, 0.034261297, 0.047418304, 0.08562733,
- -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
- 0.04096551, 0.032249358, -0.08355519, -0.026823482,
- 0.056386515, -0.010401743, -0.028396193, 0.08507674,
- 0.014410365, 0.020995233, 0.17040324, 0.11511526,
- 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
- -0.081302024, 0.017264642, -0.009585969, 0.09491168,
- -0.051313367, 0.054532815, -0.014298593, 0.10657464,
- 0.007076659, 0.10964551, 0.0409152, 0.008275321,
- -0.07283536, 0.07937492, 0.04192024, -0.1075027});
-
- lstm.SetRecurrentToCellWeights(
- {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
- 0.055647098, -0.05713207, -0.05626563, 0.005559383,
- 0.03375411, -0.025757805, -0.088049285, 0.06017052,
- -0.06570978, 0.007384076, 0.035123326, -0.07920549,
- 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
- 0.08089997, 0.05143358, 0.038261272, 0.03339287,
- -0.027673481, 0.044746667, 0.028349208, 0.020090483,
- -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
- -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
- -0.10893326, 0.076739706, -0.08509834, -0.027997585,
- 0.037871376, 0.01449768, -0.09002357, -0.06111149,
- -0.046195522, 0.0422062, -0.005683705, -0.1253618,
- -0.012925729, -0.04890792, 0.06985068, 0.037654128,
- 0.03398274, -0.004781977, 0.007032333, -0.031787455,
- 0.010868644, -0.031489216, 0.09525667, 0.013939797,
- 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
- -0.048885044, -0.12722108, 0.035304096, 0.06554885,
- 0.00972396, -0.039238118, -0.05159735, -0.11329045,
- 0.1613692, -0.03750952, 0.06529313, -0.071974665,
- -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
- 0.02786344, -0.014179351, 0.005264273, 0.14376344,
- 0.015983658, 0.03406988, -0.06939408, 0.040699873,
- 0.02111075, 0.09669095, 0.041345075, -0.08316494,
- -0.07684199, -0.045768797, 0.032298047, -0.041805092,
- 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
- -0.024950314, 0.11574242, 0.04508852, -0.04335324,
- 0.06760663, -0.027437469, 0.07216407, 0.06977076,
- -0.05438599, 0.034033038, -0.028602652, 0.05346137,
- 0.043184172, -0.037189785, 0.10420091, 0.00882477,
- -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
- 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
- 0.04361412, -0.007001822, 0.09631092, -0.06702025,
- -0.042049985, -0.035070654, -0.04103342, -0.10273396,
- 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
- -0.008264958, 0.042035464, 0.05891794, 0.029673764,
- 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
- -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
- -0.04043371, -0.017094059, 0.07229206, -0.023670016,
- -0.052195564, -0.025616996, -0.01520939, 0.045104615,
- -0.007376126, 0.003533447, 0.006570588, 0.056037236,
- 0.12436656, 0.051817212, 0.028532185, -0.08686856,
- 0.11868599, 0.07663395, -0.07323171, 0.03463402,
- -0.050708205, -0.04458982, -0.11590894, 0.021273347,
- 0.1251325, -0.15313013, -0.12224372, 0.17228661,
- 0.023029093, 0.086124025, 0.006445803, -0.03496501,
- 0.028332196, 0.04449512, -0.042436164, -0.026587414,
- -0.006041347, -0.09292539, -0.05678812, 0.03897832,
- 0.09465633, 0.008115513, -0.02171956, 0.08304309,
- 0.071401566, 0.019622514, 0.032163795, -0.004167056,
- 0.02295182, 0.030739572, 0.056506045, 0.004612461,
- 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
- -0.1335546, -0.030136576, 0.11584653, -0.014678886,
- 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
- -0.0329582, 0.07922767, 0.029322514, 0.026405897,
- 0.04207835, -0.07073373, 0.063781224, 0.0859677,
- -0.10925287, -0.07011058, 0.048005477, 0.03438226,
- -0.09606514, -0.006669445, -0.043381985, 0.04240257,
- -0.06955775, -0.06769346, 0.043903265, -0.026784198,
- -0.017840602, 0.024307009, -0.040079936, -0.019946516,
- 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
- 0.15978073, 0.10185836, 0.10298046, -0.015476589,
- -0.039390966, -0.072174534, 0.0739445, -0.1211869,
- -0.0347889, -0.07943156, 0.014809798, -0.12412325,
- -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
- -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
- -0.01514876, -0.056505352, -0.012800942, -0.06994386,
- 0.012962922, -0.031234352, 0.07029052, 0.016418684,
- 0.03618972, 0.055686004, -0.08663945, -0.017404709,
- -0.054761406, 0.029065743, 0.052404847, 0.020238016,
- 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
- 0.06262858, 0.009184685, 0.020785125, -0.043904778,
- -0.0270329, -0.03299152, -0.060088247, -0.015162964,
- -0.001828936, 0.12642565, -0.056757294, 0.013586685,
- 0.09232601, -0.035886683, 0.06000002, 0.05229691,
- -0.052580316, -0.082029596, -0.010794592, 0.012947712,
- -0.036429964, -0.085508935, -0.13127148, -0.017744139,
- 0.031502828, 0.036232427, -0.031581745, 0.023051167,
- -0.05325106, -0.03421577, 0.028793324, -0.034633752,
- -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
- -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
-
- lstm.SetRecurrentToOutputWeights({
- 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
- -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
- -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
- -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
- -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
- -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
- -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
- 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
- -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
- 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
- -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
- -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
- 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
- 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
- -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
- 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
- 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
- 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
- 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
- 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
- -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
- 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
- -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
- 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
- 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
- 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
- -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
- -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
- -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
- -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
- -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
- -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
- 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
- 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
- -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
- 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
- -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
- -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
- -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
- 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
- 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
- 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
- -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
- 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
- -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
- -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
- -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
- -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
- 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
- -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
- 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
- -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
- -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
- -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
- -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
- 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
- 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
- -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
- 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
- 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
- -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
- 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
- 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
- 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
- });
-
- lstm.SetCellToInputWeights(
- {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
- -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
- -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
- 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
-
- lstm.SetCellToForgetWeights(
- {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
- -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
- -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
- 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
-
- lstm.SetCellToOutputWeights(
- {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
- -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
- -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
- 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
-
- lstm.SetProjectionWeights(
- {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
- 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
- -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
- -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
- 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
- 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
- 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
- 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
- -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
- -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
- -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
- 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
- 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
- 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
- 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
- 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
- -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
- 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
- -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
- 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
- -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
- -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
- 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
- -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
- 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
- -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
- -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
- 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
- -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
- -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
- -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
- 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
- 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
- -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
- 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
- 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
- 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
- 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
- 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
- -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
- -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
- 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
- -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
- -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
- 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
- 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
- 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
- -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
- -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
- -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
- 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
- -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
- 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
- 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
- -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
- -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
- -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
- 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
- -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
- -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
- -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
- 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
- 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
- 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
-
- static float lstm_input[][20] = {
- {// Batch0: 4 (input_sequence_size) * 5 (n_input)
- 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
- 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
- 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
-
- {// Batch1: 4 (input_sequence_size) * 5 (n_input)
- 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
- 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
- 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
-
- static float lstm_golden_output[][64] = {
- {// Batch0: 4 (input_sequence_size) * 16 (n_output)
- -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
- -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
- -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
- 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
- -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
- -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
- 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
- 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
- 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
- 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
- -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
- -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
- 0.0286833, 0.00824207, 0.0264887, 0.0305169},
- {// Batch1: 4 (input_sequence_size) * 16 (n_output)
- -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
- -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
- 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
- 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
- -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
- -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
- 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
- 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
- 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
- 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
- -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
- -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
- 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
- for (int i = 0; i < lstm.sequence_length(); i++) {
- float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end);
+TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 20;
+ const int n_output = 16;
+ const int sequence_length = 4;
- float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
- float* batch1_end = batch1_start + lstm.num_inputs();
- lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end);
- }
+ HybridUnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
- lstm.Invoke();
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
- std::vector<float> expected;
- for (int i = 0; i < lstm.sequence_length(); i++) {
- float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
- float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
- float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
- float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
- expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
- expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
- }
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
} // namespace
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 7627d89c09..44cef80ac3 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -29,27 +29,46 @@ limitations under the License.
namespace tflite {
-// TODO(aselle): FATAL leaves resources hanging.
-void FATAL(const char* format, ...) {
+void logError(const char* format, ...) {
+ // TODO(mikie): use android logging, stderr is not captured for Java
+ // applications
va_list args;
va_start(args, format);
vfprintf(stderr, format, args);
va_end(args);
+ fprintf(stderr, "\n");
fflush(stderr);
- exit(1);
}
+#define FATAL(...) \
+ logError(__VA_ARGS__); \
+ exit(1);
+
// TODO(aselle): Change the error model to use status codes.
-#define CHECK_TFLITE_SUCCESS(x) \
- if (x != kTfLiteOk) { \
- FATAL("Aborting since tflite returned failure."); \
+#define CHECK_TFLITE_SUCCESS(x) \
+ if (x != kTfLiteOk) { \
+ FATAL("Aborting since tflite returned failure nnapi_delegate.cc:%d.", \
+ __LINE__); \
+ }
+
+#define CHECK_NN(x) \
+ if (x != ANEURALNETWORKS_NO_ERROR) { \
+ FATAL("Aborting since NNAPI returned failure nnapi_delegate.cc:%d", \
+ __LINE__); \
}
-#define CHECK_NN(x) \
- if (x != ANEURALNETWORKS_NO_ERROR) { \
- FATAL("Aborting since tflite returned failure."); \
+#define RETURN_ERROR_IF_NN_FAILED(x) \
+ if (x != ANEURALNETWORKS_NO_ERROR) { \
+ logError( \
+ "Returning error since NNAPI returned failure nnapi_delegate.cc:%d.", \
+ __LINE__); \
+ return kTfLiteError; \
}
+// Tracking of NNAPI operand ids
+static const int64_t kOperandIdNotSet = -1;
+static const int64_t kOperandNotNeeded = -2;
+
namespace {
int32_t GetAndroidSdkVersion() {
@@ -104,21 +123,16 @@ NNAPIDelegate::~NNAPIDelegate() {
}
// Adds the tensors of the interpreter to the NN API model.
-// Returns the number of operands added.
-uint32_t addTensorOperands(tflite::Interpreter* interpreter,
- ANeuralNetworksModel* nn_model,
- const std::vector<uint32_t>& skip_list) {
+TfLiteStatus addTensorOperands(tflite::Interpreter* interpreter,
+ ANeuralNetworksModel* nn_model,
+ uint32_t* no_of_operands_added,
+ std::vector<int64_t>* nnapi_ids) {
uint32_t next_id = 0;
for (size_t i = 0; i < interpreter->tensors_size(); i++) {
- // skip temporaries tensors.
- bool shouldSkip = false;
- for (auto skip_idx : skip_list) {
- if (i == skip_idx) {
- shouldSkip = true;
- break;
- }
- }
- if (shouldSkip) continue;
+ // Skip temporaries and RNN back-edges.
+ if ((*nnapi_ids)[i] == kOperandNotNeeded) continue;
+
+ (*nnapi_ids)[i] = int64_t(next_id);
int32_t nn_type = 0;
// NNAPI requires 32-bit float scale to be zero, tflite doesn't care
@@ -144,7 +158,18 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
zeroPoint = tensor->params.zero_point;
break;
default:
- FATAL("Unsupported type.");
+ logError("Unsupported tensor type %d", tensor->type);
+ return kTfLiteError;
+ }
+ if (tensor->dims->size == 0) {
+ logError("NNAPI doesn't support tensors with rank 0 (index %d name %s)",
+ i, tensor->name);
+ return kTfLiteError;
+ }
+ if (tensor->dims->size > 4) {
+ logError("NNAPI doesn't support tensors with rank > 4 (index %d name %s)",
+ i, tensor->name);
+ return kTfLiteError;
}
// TODO(aselle): Note, many of these are intermediate results. Do I need
// to ever specify these sizes. I am currently below doing setValue
@@ -154,36 +179,53 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
ANeuralNetworksOperandType operand_type{
nn_type, static_cast<uint32_t>(tensor->dims->size),
reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
+ RETURN_ERROR_IF_NN_FAILED(
+ ANeuralNetworksModel_addOperand(nn_model, &operand_type));
// TODO(aselle): Based on Michael's suggestion, limiting this to read
// only memory
if (tensor->allocation_type == kTfLiteMmapRo) {
if (const NNAPIAllocation* alloc = dynamic_cast<const NNAPIAllocation*>(
static_cast<const Allocation*>(tensor->allocation))) {
- CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory(
- nn_model, next_id, alloc->memory(), alloc->offset(tensor->data.raw),
- tensor->bytes));
+ RETURN_ERROR_IF_NN_FAILED(
+ ANeuralNetworksModel_setOperandValueFromMemory(
+ nn_model, next_id, alloc->memory(),
+ alloc->offset(tensor->data.raw), tensor->bytes));
} else {
- CHECK_NN(ANeuralNetworksModel_setOperandValue(
+ RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_setOperandValue(
nn_model, next_id, tensor->data.raw, tensor->bytes));
}
} else if (tensor->bytes == 0) {
// These size 0 tensors are optional tensors reserved.
- CHECK_NN(
+ RETURN_ERROR_IF_NN_FAILED(
ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0));
}
++next_id;
}
- return next_id;
+ *no_of_operands_added = next_id;
+ return kTfLiteOk;
+}
+
+void MapAndAddTensorIds(const int* from_ids_buf, size_t from_ids_count,
+ std::vector<uint32_t>* into,
+ const std::vector<int64_t>& map) {
+ for (size_t i = 0; i < from_ids_count; i++) {
+ int from_id = from_ids_buf[i];
+ if (from_id == kOptionalTensor) {
+ into->push_back(from_id);
+ } else {
+ into->push_back(map[from_id]);
+ }
+ }
}
// Adds the operations and their parameters to the NN API model.
// 'next-id' is the operand ID of the next operand of the model.
-void AddOpsAndParams(tflite::Interpreter* interpreter,
- ANeuralNetworksModel* nn_model, uint32_t next_id,
- std::vector<int>* model_state_inputs,
- std::vector<int>* model_state_outputs) {
+TfLiteStatus AddOpsAndParams(
+ tflite::Interpreter* interpreter, ANeuralNetworksModel* nn_model,
+ uint32_t next_id, std::vector<int>* model_state_inputs,
+ std::vector<int>* model_state_outputs,
+ const std::vector<int64_t>& tensor_id_to_nnapi_id) {
for (size_t i = 0; i < interpreter->nodes_size(); i++) {
const auto* node_and_registration = interpreter->node_and_registration(i);
const TfLiteNode& node = node_and_registration->first;
@@ -192,10 +234,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
static_cast<tflite::BuiltinOperator>(registration.builtin_code);
// Add the parameters.
- std::vector<uint32_t> augmented_inputs(
- node.inputs->data, node.inputs->data + node.inputs->size);
- std::vector<uint32_t> augmented_outputs(
- node.outputs->data, node.outputs->data + node.outputs->size);
+ std::vector<uint32_t> augmented_inputs, augmented_outputs;
+ MapAndAddTensorIds(node.inputs->data, node.inputs->size, &augmented_inputs,
+ tensor_id_to_nnapi_id);
+ MapAndAddTensorIds(node.outputs->data, node.outputs->size,
+ &augmented_outputs, tensor_id_to_nnapi_id);
auto add_scalar_int32 = [&nn_model, &augmented_inputs,
&next_id](int value) {
@@ -244,42 +287,54 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
model_state_outputs->push_back(tensor_id);
next_id++;
};
+ auto check_and_add_activation = [&add_scalar_int32](int activation) {
+ if (activation > kTfLiteActRelu6) {
+ FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ }
+ add_scalar_int32(activation);
+ };
auto add_add_params = [&add_scalar_int32](void* data) {
auto* builtin = reinterpret_cast<TfLiteAddParams*>(data);
+ if (builtin->activation > kTfLiteActRelu6) {
+ FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ }
add_scalar_int32(builtin->activation);
};
- auto add_pooling_params = [&add_scalar_int32](void* data) {
+ auto add_pooling_params = [&add_scalar_int32,
+ &check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLitePoolParams*>(data);
add_scalar_int32(builtin->padding);
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
add_scalar_int32(builtin->filter_width);
add_scalar_int32(builtin->filter_height);
- add_scalar_int32(builtin->activation);
+ check_and_add_activation(builtin->activation);
};
- auto add_convolution_params = [&add_scalar_int32](void* data) {
+ auto add_convolution_params = [&add_scalar_int32,
+ &check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLiteConvParams*>(data);
add_scalar_int32(builtin->padding);
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
- add_scalar_int32(builtin->activation);
+ check_and_add_activation(builtin->activation);
};
- auto add_depthwise_conv_params = [&add_scalar_int32](void* data) {
+ auto add_depthwise_conv_params = [&add_scalar_int32,
+ &check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLiteDepthwiseConvParams*>(data);
add_scalar_int32(builtin->padding);
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
add_scalar_int32(builtin->depth_multiplier);
- add_scalar_int32(builtin->activation);
+ check_and_add_activation(builtin->activation);
};
- auto add_fully_connected_params = [&add_scalar_int32](void* data) {
+ auto add_fully_connected_params = [&check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(data);
- add_scalar_int32(builtin->activation);
+ check_and_add_activation(builtin->activation);
};
auto add_concatenation_params = [&add_scalar_int32](void* data) {
@@ -311,6 +366,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
// LSTM in NNAPI requires scratch tensor as an output operand.
auto add_lstm_scratch_tensor_float32 = [interpreter, &node, &nn_model,
&next_id, &augmented_outputs]() {
+ if (node.temporaries->size == 0) return;
int scratch_buffer_index = node.temporaries->data[0];
const TfLiteTensor* tensor = interpreter->tensor(scratch_buffer_index);
ANeuralNetworksOperandType operand_type{
@@ -385,7 +441,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
add_pooling_params(node.builtin_data);
nn_op_type = ANEURALNETWORKS_L2_POOL_2D;
break;
- case tflite::BuiltinOperator_CONV_2D:
+ case tflite::BuiltinOperator_CONV_2D: {
+ auto builtin = reinterpret_cast<TfLiteConvParams*>(node.builtin_data);
+ if (builtin->dilation_width_factor != 1 ||
+ builtin->dilation_height_factor != 1 || node.inputs->size != 3) {
+ logError("NNAPI does not support dilated Conv2D.");
+ return kTfLiteError;
+ }
+ }
add_convolution_params(node.builtin_data);
nn_op_type = ANEURALNETWORKS_CONV_2D;
break;
@@ -429,6 +492,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
nn_op_type = ANEURALNETWORKS_SPACE_TO_DEPTH;
break;
case tflite::BuiltinOperator_LSTM: {
+ if (node.inputs->size + /* no of params */ 3 != 21) {
+ logError("NNAPI only supports 21-input LSTMs");
+ return kTfLiteError;
+ }
duplicate_state_tensor_float32(
node.outputs->data[/*kOutputStateTensor*/ 0]);
duplicate_state_tensor_float32(
@@ -528,12 +595,12 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_RSQRT:
case tflite::BuiltinOperator_SHAPE:
case tflite::BuiltinOperator_POW:
- FATAL("Op code %d is currently not delegated to NNAPI", builtin);
- nn_op_type = -1; // set to invalid
+ logError("Op code %d is currently not delegated to NNAPI", builtin);
+ return kTfLiteError;
break;
case tflite::BuiltinOperator_CUSTOM:
- FATAL("Custom operations are not supported when using NNAPI.");
- nn_op_type = -1; // set to invalid
+ logError("Custom operations are not supported when using NNAPI.");
+ return kTfLiteError;
break;
}
@@ -542,47 +609,70 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
}
// Add the operation.
- CHECK_NN(ANeuralNetworksModel_addOperation(
+ RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_addOperation(
nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()),
augmented_inputs.data(),
static_cast<uint32_t>(augmented_outputs.size()),
reinterpret_cast<uint32_t*>(augmented_outputs.data())));
}
+ return kTfLiteOk;
}
TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
- // TODO(aselle): This is not correct. need to handle resize invalidation.
- if (nn_model_ && nn_compiled_model_) return kTfLiteOk;
+ if (nn_model_ && nn_compiled_model_) return model_status_;
+ // TODO(aselle): This is not correct. need to handle resize invalidation.
if (!nn_model_) {
CHECK_NN(ANeuralNetworksModel_create(&nn_model_));
- // Find all the temporary tensors and put them in a skip_list.
- std::vector<uint32_t> skip_list;
+ // Find which tensors should be added to NNAPI. TFLite has temporaries
+ // and RNN back-edges which are are not valid for NNAPI. We look through all
+ // inputs and outputs and mark the mapping in tensor_id_to_nnapi_id with
+ // kOperandIdNotSet. addTensorOperands will replace those with the
+ // corresponding NNAPI operand ids and skip kOperandNotNeeded entries.
+ std::vector<int64_t> tensor_id_to_nnapi_id(interpreter->tensors_size(),
+ kOperandNotNeeded);
+ auto set_ids_to_not_set = [&tensor_id_to_nnapi_id](const int* buf,
+ size_t count) {
+ for (int j = 0; j < count; j++) {
+ auto tensor_id = buf[j];
+ if (tensor_id != kOptionalTensor) {
+ tensor_id_to_nnapi_id[tensor_id] = kOperandIdNotSet;
+ }
+ }
+ };
for (size_t i = 0; i < interpreter->nodes_size(); i++) {
const auto* node_and_registration = interpreter->node_and_registration(i);
const TfLiteNode& node = node_and_registration->first;
- if (node.temporaries != nullptr) {
- for (int j = 0; j < node.temporaries->size; j++) {
- skip_list.push_back(static_cast<uint32_t>(node.temporaries->data[j]));
- }
- }
+ set_ids_to_not_set(node.inputs->data, node.inputs->size);
+ set_ids_to_not_set(node.outputs->data, node.outputs->size);
}
-
- uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list);
- AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_,
- &model_states_outputs_);
-
- std::vector<int> augmented_inputs = interpreter->inputs();
- std::vector<int> augmented_outputs = interpreter->outputs();
-
- // All state tensors input/output need to be treated as model input/output.
+ set_ids_to_not_set(interpreter->inputs().data(),
+ interpreter->inputs().size());
+ set_ids_to_not_set(interpreter->outputs().data(),
+ interpreter->outputs().size());
+
+ uint32_t next_id = 0;
+ RETURN_ERROR_IF_NN_FAILED(addTensorOperands(
+ interpreter, nn_model_, &next_id, &tensor_id_to_nnapi_id));
+ RETURN_ERROR_IF_NN_FAILED(
+ AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_,
+ &model_states_outputs_, tensor_id_to_nnapi_id));
+
+ std::vector<uint32_t> augmented_inputs;
+ MapAndAddTensorIds(interpreter->inputs().data(),
+ interpreter->inputs().size(), &augmented_inputs,
+ tensor_id_to_nnapi_id);
augmented_inputs.insert(augmented_inputs.end(),
model_states_inputs_.begin(),
model_states_inputs_.end());
- augmented_outputs.insert(augmented_outputs.end(),
- model_states_outputs_.begin(),
- model_states_outputs_.end());
+ std::vector<uint32_t> augmented_outputs;
+ MapAndAddTensorIds(interpreter->outputs().data(),
+ interpreter->outputs().size(), &augmented_outputs,
+ tensor_id_to_nnapi_id);
+ MapAndAddTensorIds(model_states_outputs_.data(),
+ model_states_outputs_.size(), &augmented_outputs,
+ tensor_id_to_nnapi_id);
CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs(
nn_model_, static_cast<uint32_t>(augmented_inputs.size()),
@@ -600,7 +690,13 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
if (!nn_model_) {
- TF_LITE_ENSURE_STATUS(BuildGraph(interpreter));
+ model_status_ = BuildGraph(interpreter);
+ if (model_status_ != kTfLiteOk) {
+ logError("Failed to build graph for NNAPI");
+ }
+ }
+ if (model_status_ != kTfLiteOk) {
+ return model_status_;
}
ANeuralNetworksExecution* execution = nullptr;
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index 94dea4f9b2..8dc7d38a30 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -59,14 +59,16 @@ class NNAPIDelegate {
ANeuralNetworksModel* nn_model_ = nullptr;
// The NN API compilation handle
ANeuralNetworksCompilation* nn_compiled_model_ = nullptr;
+ // Model status
+ TfLiteStatus model_status_ = kTfLiteOk;
// List of state tensors for LSTM, RNN, SVDF.
// NN API does not allow ops to maintain states across multiple
// invocations. We need to manually create state input tensors from
// corresponding state output tensors of TFLite operations, and map them
// correctly.
- std::vector<int> model_states_inputs_;
- std::vector<int> model_states_outputs_;
+ std::vector<int> model_states_inputs_; // holds NNAPI operand ids
+ std::vector<int> model_states_outputs_; // holds TFLite tensor ids
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index 9717a4a1a4..f095151cae 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -65,6 +65,7 @@ cc_test(
],
tags = [
"tflite_not_portable_android",
+ "tflite_not_portable_ios",
],
deps = [
"//tensorflow/core:lib_platform",
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index b823c97f38..789bc695f8 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -172,6 +172,7 @@ cc_test(
data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
tags = [
"tflite_not_portable_android",
+ "tflite_not_portable_ios",
],
deps = [
":tflite_driver",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 1360f1a273..50237ed792 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -94,8 +94,8 @@ KNOWN_BUGS = {
r"sigmoid.*input_shape=\[\]": "67645668",
# Concat doesn't work with a single input tensor
r"concat.*num_tensors=1": "67378344",
- # Transposition in MatMul is not supported.
- r"fully_connected.*transpose_.=True": "67586970",
+ # Transposition in MatMul is not fully supported.
+ "fully_connected.*transpose_a=True": "67586970",
# Softmax graphs are too complex.
r"softmax.*dim=0": "67749831",
# BatchToSpaceND only supports 4D tensors.
@@ -1325,6 +1325,12 @@ def make_fully_connected_tests(zip_path):
"transpose_a": [False],
"transpose_b": [False],
"constant_filter": [True, False],
+ }, {
+ "shape1": [[40, 37]],
+ "shape2": [[40, 37]],
+ "transpose_a": [False],
+ "transpose_b": [True],
+ "constant_filter": [True, False],
}]
def build_graph(parameters):
diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc
index c0c861ff6d..c1092e4d25 100644
--- a/tensorflow/contrib/lite/testing/generate_testspec.cc
+++ b/tensorflow/contrib/lite/testing/generate_testspec.cc
@@ -25,7 +25,7 @@ namespace testing {
template <typename T>
void GenerateCsv(const std::vector<int>& shape, float min, float max,
string* out) {
- auto random_float = [](int min, int max) {
+ auto random_float = [](float min, float max) {
static unsigned int seed;
return min + (max - min) * static_cast<float>(rand_r(&seed)) / RAND_MAX;
};
@@ -37,16 +37,10 @@ void GenerateCsv(const std::vector<int>& shape, float min, float max,
*out = Join(data.data(), data.size(), ",");
}
-bool GenerateTestSpecFromTensorflowModel(
- std::iostream& stream, const string& tensorflow_model_path,
- const string& tflite_model_path, const std::vector<string>& input_layer,
+std::vector<string> GenerateInputValues(
+ const std::vector<string>& input_layer,
const std::vector<string>& input_layer_type,
- const std::vector<string>& input_layer_shape,
- const std::vector<string>& output_layer) {
- CHECK_EQ(input_layer.size(), input_layer_type.size());
- CHECK_EQ(input_layer.size(), input_layer_shape.size());
-
- // Generate inputs.
+ const std::vector<string>& input_layer_shape) {
std::vector<string> input_values;
input_values.resize(input_layer.size());
for (int i = 0; i < input_layer.size(); i++) {
@@ -73,9 +67,22 @@ bool GenerateTestSpecFromTensorflowModel(
default:
fprintf(stderr, "Unsupported type %d (%s) when generating testspec.\n",
type, input_layer_type[i].c_str());
- return false;
+ input_values.clear();
+ return input_values;
}
}
+ return input_values;
+}
+
+bool GenerateTestSpecFromTensorflowModel(
+ std::iostream& stream, const string& tensorflow_model_path,
+ const string& tflite_model_path, int num_invocations,
+ const std::vector<string>& input_layer,
+ const std::vector<string>& input_layer_type,
+ const std::vector<string>& input_layer_shape,
+ const std::vector<string>& output_layer) {
+ CHECK_EQ(input_layer.size(), input_layer_type.size());
+ CHECK_EQ(input_layer.size(), input_layer_shape.size());
// Invoke tensorflow model.
TfDriver runner(input_layer, input_layer_type, input_layer_shape,
@@ -91,39 +98,51 @@ bool GenerateTestSpecFromTensorflowModel(
return false;
}
- for (int i = 0; i < input_values.size(); i++) {
- runner.SetInput(i, input_values[i]);
- if (!runner.IsValid()) {
- cerr << runner.GetErrorMessage() << endl;
- return false;
- }
- }
-
- runner.Invoke();
- if (!runner.IsValid()) {
- cerr << runner.GetErrorMessage() << endl;
- return false;
- }
-
- // Write test spec.
+ // Write first part of test spec, defining model and input shapes.
stream << "load_model: " << tflite_model_path << "\n";
stream << "reshape {\n";
for (const auto& shape : input_layer_shape) {
stream << " input: \"" << shape << "\"\n";
}
stream << "}\n";
- stream << "invoke {\n";
- for (const auto& value : input_values) {
- stream << " input: \"" << value << "\"\n";
- }
- for (int i = 0; i < output_layer.size(); i++) {
- stream << " output: \"" << runner.ReadOutput(i) << "\"\n";
+
+ // Generate inputs.
+ for (int i = 0; i < num_invocations; ++i) {
+ // Note that the input values are random, so each invocation will have a
+ // different set.
+ std::vector<string> input_values =
+ GenerateInputValues(input_layer, input_layer_type, input_layer_shape);
+ if (input_values.empty()) return false;
+
+ // Run TensorFlow.
+ for (int j = 0; j < input_values.size(); j++) {
+ runner.SetInput(j, input_values[j]);
+ if (!runner.IsValid()) {
+ cerr << runner.GetErrorMessage() << endl;
+ return false;
+ }
+ }
+
+ runner.Invoke();
if (!runner.IsValid()) {
cerr << runner.GetErrorMessage() << endl;
return false;
}
+
+ // Write second part of test spec, with inputs and outputs.
+ stream << "invoke {\n";
+ for (const auto& value : input_values) {
+ stream << " input: \"" << value << "\"\n";
+ }
+ for (int j = 0; j < output_layer.size(); j++) {
+ stream << " output: \"" << runner.ReadOutput(j) << "\"\n";
+ if (!runner.IsValid()) {
+ cerr << runner.GetErrorMessage() << endl;
+ return false;
+ }
+ }
+ stream << "}\n";
}
- stream << "}\n";
return true;
}
diff --git a/tensorflow/contrib/lite/testing/generate_testspec.h b/tensorflow/contrib/lite/testing/generate_testspec.h
index 6e31a853c3..bfaf5e7ec8 100644
--- a/tensorflow/contrib/lite/testing/generate_testspec.h
+++ b/tensorflow/contrib/lite/testing/generate_testspec.h
@@ -30,13 +30,15 @@ namespace testing {
// stream: mutable iostream that contains the contents of test spec.
// tensorflow_model_path: path to TensorFlow model.
// tflite_model_path: path to tflite_model_path that the test spec runs
+// num_invocations: how many pairs of inputs and outputs will be generated.
// against. input_layer: names of input tensors. Example: input1
// input_layer_type: datatypes of input tensors. Example: float
// input_layer_shape: shapes of input tensors, separated by comma. example:
// 1,3,4 output_layer: names of output tensors. Example: output
bool GenerateTestSpecFromTensorflowModel(
std::iostream& stream, const string& tensorflow_model_path,
- const string& tflite_model_path, const std::vector<string>& input_layer,
+ const string& tflite_model_path, int num_invocations,
+ const std::vector<string>& input_layer,
const std::vector<string>& input_layer_type,
const std::vector<string>& input_layer_shape,
const std::vector<string>& output_layer);
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index a86cd5c6cc..c4e20312d8 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -42,6 +42,7 @@ string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip");
string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
#endif
bool FLAGS_use_nnapi = false;
+bool FLAGS_ignore_unsupported_nnapi = false;
} // namespace
// TensorFlow system environment for file system called.
@@ -225,16 +226,21 @@ TEST_P(OpsTest, RunZipTests) {
}
bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver);
+ string message = test_driver.GetErrorMessage();
if (bug_number.empty()) {
- EXPECT_TRUE(result) << test_driver.GetErrorMessage();
+ if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
+ EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
+ } else {
+ EXPECT_TRUE(result) << message;
+ }
} else {
if (FLAGS_ignore_known_bugs) {
EXPECT_FALSE(result) << "Test was expected to fail but is now passing; "
"you can mark http://b/"
<< bug_number << " as fixed! Yay!";
} else {
- EXPECT_TRUE(result) << test_driver.GetErrorMessage()
- << ": Possibly due to http://b/" << bug_number;
+ EXPECT_TRUE(result) << message << ": Possibly due to http://b/"
+ << bug_number;
}
}
}
@@ -277,8 +283,11 @@ int main(int argc, char** argv) {
tflite::testing::FLAGS_unzip_binary_path,
"Required: Location of a suitable unzip binary."),
tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi,
- "Whether to enable the NNAPI delegate")};
-
+ "Whether to enable the NNAPI delegate"),
+ tensorflow::Flag("ignore_unsupported_nnapi",
+ &tflite::testing::FLAGS_ignore_unsupported_nnapi,
+ "Don't fail tests just because delegation to NNAPI "
+ "is not possible")};
bool success = tensorflow::Flags::Parse(&argc, argv, flags);
if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc
index 5afa0f800c..f2c49fe389 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc
+++ b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc
@@ -20,12 +20,29 @@ int main(int argc, char** argv) {
::tflite::testing::DiffOptions options =
::tflite::testing::ParseTfliteDiffFlags(&argc, argv);
if (options.tensorflow_model.empty()) return 1;
+
int failure_count = 0;
- for (int i = 0; i < 100; i++) {
- if (!tflite::testing::RunDiffTest(options)) {
+ for (int i = 0; i < options.num_runs_per_pass; i++) {
+ if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/1)) {
++failure_count;
}
}
- fprintf(stderr, "Num errors: %d\n", failure_count);
+ int failures_in_first_pass = failure_count;
+
+ if (failure_count == 0) {
+ // Let's try again with num_invocations > 1 to make sure we can do multiple
+ // invocations without resetting the interpreter.
+ for (int i = 0; i < options.num_runs_per_pass; i++) {
+ if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/2)) {
+ ++failure_count;
+ }
+ }
+ }
+
+ fprintf(stderr, "Num errors in single-inference pass: %d\n",
+ failures_in_first_pass);
+ fprintf(stderr, "Num errors in multi-inference pass : %d\n",
+ failure_count - failures_in_first_pass);
+
return failure_count != 0 ? 1 : 0;
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
index 706108ed73..7a57e8d3fb 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
@@ -30,6 +30,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
string input_layer_type;
string input_layer_shape;
string output_layer;
+ int32_t num_runs_per_pass = 100;
} values;
std::vector<tensorflow::Flag> flags = {
@@ -49,6 +50,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
tensorflow::Flag("output_layer", &values.output_layer,
"Names of output tensors, separated by comma. Example "
"output_1,output_2"),
+ tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass,
+ "Number of full runs in each pass."),
};
bool no_inputs = *argc == 1;
@@ -63,7 +66,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
Split<string>(values.input_layer, ","),
Split<string>(values.input_layer_type, ","),
Split<string>(values.input_layer_shape, ":"),
- Split<string>(values.output_layer, ",")};
+ Split<string>(values.output_layer, ","),
+ values.num_runs_per_pass};
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
index f601d3752d..19f34c0a51 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
@@ -25,13 +25,14 @@ limitations under the License.
namespace tflite {
namespace testing {
-bool RunDiffTest(const DiffOptions& options) {
+bool RunDiffTest(const DiffOptions& options, int num_invocations) {
std::stringstream tflite_stream;
if (!GenerateTestSpecFromTensorflowModel(
tflite_stream, options.tensorflow_model, options.tflite_model,
- options.input_layer, options.input_layer_type,
- options.input_layer_shape, options.output_layer))
+ num_invocations, options.input_layer, options.input_layer_type,
+ options.input_layer_shape, options.output_layer)) {
return false;
+ }
TfLiteDriver tflite_driver(/*use_nnapi=*/true);
tflite_driver.LoadModel(options.tflite_model);
return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h
index 326fa6c3e2..4ab2f230fd 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h
@@ -40,10 +40,14 @@ struct DiffOptions {
// Names of output tensors.
// Example output_1,output_2
std::vector<string> output_layer;
+ // Number of full runs (from building interpreter to checking outputs) in
+ // each of the passes. The first pass has a single inference, while the
+ // second pass does multiple inferences back to back.
+ int num_runs_per_pass;
};
// Run a single TensorFLow Lite diff test with a given options.
-bool RunDiffTest(const DiffOptions& options);
+bool RunDiffTest(const DiffOptions& options, int num_invocations);
} // namespace testing
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 39f55208e4..2f1bb8f0ad 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -228,6 +228,14 @@ bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
return true;
}
+bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) {
+ const double magnitude =
+ std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min);
+ const double tolerated = 1e-6 * magnitude;
+ return std::abs(minmax1.min - minmax2.min) < tolerated &&
+ std::abs(minmax1.max - minmax2.max) < tolerated;
+}
+
// Propagates MinMax from any of the listed arrays, to all others.
// If multiple of these arrays have MinMax, then these are required
// to agree with each other.
@@ -250,7 +258,7 @@ bool PropagateMinMaxAmongArrays(Model* model,
for (const string& array_name : array_names) {
auto& array = model->GetArray(array_name);
if (array.minmax) {
- CHECK(*array.minmax == *reference_minmax)
+ CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax))
<< "Both the following arrays have minmax, and they disagree: "
<< reference_array_name << " (" << reference_minmax->min << ","
<< reference_minmax->max << ") and " << array_name << " ("
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
index 3ca7f53512..c0b014b45e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
@@ -35,6 +35,26 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
return it;
}
+bool ValidateSourceOp(const Model& model, const string& array_name,
+ OperatorType op_type, Operator** source_op) {
+ if (op_type == OperatorType::kNone) {
+ CHECK(!source_op);
+ } else {
+ CHECK(source_op);
+ *source_op = GetOpWithOutput(model, array_name);
+ if (*source_op == nullptr) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((*source_op)->type != op_type) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
// Returns true if the given operator has exactly 1 input, and is connected to
// the given op_type.
// We use kNone to indicate an input unattached to an operator output. Usually
@@ -47,24 +67,10 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
}
// Check if first input is disconnected/connected to an operator
- Operator* x = GetOpWithOutput(model, op.inputs[0]);
- if ((op_type == OperatorType::kNone) && (x != nullptr)) {
- return false;
- }
- if ((op_type != OperatorType::kNone) && (x == nullptr)) {
+ if (!ValidateSourceOp(model, op.inputs[0], op_type, connected_op)) {
return false;
}
- // Check that first operator, if connected, is of correct type
- if ((x != nullptr) && (x->type != op_type)) {
- return false;
- }
-
- // Successfully matched. Optionally return matching input operators.
- if (connected_op) {
- *connected_op = x;
- }
-
return true;
}
@@ -81,40 +87,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
}
// Check if first input is disconnected/connected to an operator
- Operator* x = GetOpWithOutput(model, op.inputs[0]);
- if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
- return false;
- }
- if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
- return false;
- }
-
- // Check that first operator, if connected, is of correct type
- if ((x != nullptr) && (x->type != a_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) {
return false;
}
// Check if second input is disconnected/connected to an operator
- Operator* y = GetOpWithOutput(model, op.inputs[1]);
- if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
- return false;
- }
- if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
+ if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) {
return false;
}
- // Check that second operator, if connected, is of correct type
- if ((y != nullptr) && (y->type != b_op_type)) {
- return false;
- }
-
- // Successfully matched. Optionally return matching input operators.
- if (a_op != nullptr) {
- *a_op = x;
- }
- if (b_op != nullptr) {
- *b_op = y;
- }
return true;
}
@@ -132,57 +113,20 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
}
// Check if first input is disconnected/connected to an operator
- Operator* x = GetOpWithOutput(model, op.inputs[0]);
- if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
- return false;
- }
- if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
- return false;
- }
-
- // Check that first operator, if connected, is of correct type
- if ((x != nullptr) && (x->type != a_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) {
return false;
}
// Check if second input is disconnected/connected to an operator
- Operator* y = GetOpWithOutput(model, op.inputs[1]);
- if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
- return false;
- }
- if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
- return false;
- }
-
- // Check that second operator, if connected, is of correct type
- if ((y != nullptr) && (y->type != b_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) {
return false;
}
// Check if third input is disconnected/connected to an operator
- Operator* z = GetOpWithOutput(model, op.inputs[2]);
- if ((c_op_type == OperatorType::kNone) && (z != nullptr)) {
- return false;
- }
- if ((c_op_type != OperatorType::kNone) && (z == nullptr)) {
+ if (!ValidateSourceOp(model, op.inputs[2], c_op_type, c_op)) {
return false;
}
- // Check that third operator, if connected, is of correct type
- if ((z != nullptr) && (z->type != c_op_type)) {
- return false;
- }
-
- // Successfully matched. Optionally return matching input operators.
- if (a_op != nullptr) {
- *a_op = x;
- }
- if (b_op != nullptr) {
- *b_op = y;
- }
- if (c_op != nullptr) {
- *c_op = z;
- }
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index d496f5ae5e..fcf30bd347 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -32,21 +32,34 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
const auto* matmul_op =
static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
+ // Handling transposition of the first input here isn't very simple because
+ // we need to know the actual shape in order to produce a proper
+ // TransposeOperator. However, the second input is supposed to be 2D, so we
+ // can actually handle transposition of that matrix, which happens to be more
+ // common anyway.
+ CHECK(!matmul_op->transpose_a);
+
// Reorder the axes on the second input. TensorFlow uses row-major ordering
// on both inputs, however this is inefficient for the FullyConnected
// operator. We'll transpose the second input to be in column-major order now
// and let constant propagation optimize things (if possible).
- auto* transpose_op = new TransposeOperator;
- transpose_op->inputs = {
- matmul_op->inputs[1],
- CreateInt32Array(
- model,
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose/perm"),
- {1, 0})};
- transpose_op->outputs = {
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
- model->GetOrCreateArray(transpose_op->outputs[0]);
- model->operators.emplace(matmul_it, transpose_op);
+ string input_lhs = matmul_op->inputs[0];
+ string input_rhs = matmul_op->inputs[1];
+ if (!matmul_op->transpose_b) {
+ auto* transpose_op = new TransposeOperator;
+ transpose_op->inputs = {
+ matmul_op->inputs[1],
+ CreateInt32Array(model,
+ AvailableArrayName(
+ *model, matmul_op->inputs[1] + "/transpose/perm"),
+ {1, 0})};
+ transpose_op->outputs = {
+ AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
+ model->GetOrCreateArray(transpose_op->outputs[0]);
+ model->operators.emplace(matmul_it, transpose_op);
+
+ input_rhs = transpose_op->outputs[0];
+ }
// Refresh iterator.
matmul_it = model->operators.begin();
@@ -57,9 +70,6 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
}
DCHECK_EQ(matmul_it->get(), matmul_op);
- string input_lhs = matmul_op->inputs[0];
- string input_rhs = transpose_op->outputs[0];
-
// Construct the new FullyConnectedOperator.
auto* fc_op = new FullyConnectedOperator;
fc_op->outputs = matmul_op->outputs;
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 55e39d963f..5c32a39035 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -984,18 +984,19 @@ tensorflow::Status ConvertMatMulOperator(
Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- // Transpose flags should be easy to support, but we don't have a
- // GraphDef with them to test on at the moment.
- CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"),
- false);
- CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"),
- false);
CHECK(!HasAttr(node, "adjoint_a") ||
(GetBoolAttr(node, "adjoint_a") == false));
CHECK(!HasAttr(node, "adjoint_b") ||
(GetBoolAttr(node, "adjoint_b") == false));
auto* matmul = new TensorFlowMatMulOperator;
+ if (HasAttr(node, "transpose_a")) {
+ matmul->transpose_a = GetBoolAttr(node, "transpose_a");
+ }
+ if (HasAttr(node, "transpose_b")) {
+ matmul->transpose_b = GetBoolAttr(node, "transpose_b");
+ }
+
matmul->inputs = {node.input(0), node.input(1)};
matmul->outputs = {node.name()};
model->operators.emplace_back(matmul);
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index abe0bf3c54..3a1d243f87 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -837,6 +837,8 @@ struct BatchMatMulOperator : Operator {
// TensorFlow equivalent: MatMul
struct TensorFlowMatMulOperator : Operator {
TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {}
+ bool transpose_a = false;
+ bool transpose_b = false;
};
// Padding operator. Pads a tensor with zeros.
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index 5913847329..a3df37358f 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -53,6 +53,7 @@ cc_test(
],
tags = [
"tflite_not_portable_android",
+ "tflite_not_portable_ios",
],
deps = [
":gen_op_registration",
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index c6f3bd6ee1..8c11d8bcfd 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -766,7 +766,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
# *after* loss() is evaluated, so we know what loss reduction it uses.
if scale_loss_by_num_towers is None:
scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() == "mean")
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
@@ -784,7 +785,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Scale loss for number of towers (non-callable-loss case).
if scale_loss_by_num_towers is None:
scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() == "mean")
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
@@ -896,7 +898,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
"""`apply_gradients` for use with a `DistributionStrategy`."""
- reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
var_list = [v for _, v in grads_and_vars]
grads_and_vars = zip(reduced_grads, var_list)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 0e6bc03c0b..1c6111a748 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -234,7 +234,6 @@ tf_proto_library(
srcs = [],
cc_api_version = 2,
default_header = True,
- j2objc_api_version = 1,
java_api_version = 2,
js_api_version = 2,
protodeps = [
@@ -1263,6 +1262,7 @@ cc_library(
"//tensorflow/core/kernels:fake_quant_ops",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:functional_ops",
+ "//tensorflow/core/kernels:grappler",
"//tensorflow/core/kernels:histogram_op",
"//tensorflow/core/kernels:image",
"//tensorflow/core/kernels:io",
@@ -2252,7 +2252,6 @@ tf_proto_library(
srcs = ERROR_CODES_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- j2objc_api_version = 1,
java_api_version = 2,
js_api_version = 2,
provide_cc_alias = True,
@@ -2274,7 +2273,6 @@ tf_proto_library(
srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- j2objc_api_version = 1,
java_api_version = 2,
js_api_version = 2,
protodeps = [
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 39bda9119c..7a2b477845 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -128,7 +128,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
auto pre_time = Env::Default()->NowMicros();
- TensorHandle* result_handle;
+ TensorHandle* result_handle = nullptr;
Status status = EagerCopyToDevice(
*handle, ctx, expected_device->name().c_str(), &result_handle);
if (run_metadata != nullptr) {
diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc
index 4ffa503379..b2bc414c49 100644
--- a/tensorflow/core/framework/graph_to_functiondef.cc
+++ b/tensorflow/core/framework/graph_to_functiondef.cc
@@ -153,7 +153,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
const string normalized = node_names.Normalize(node->name());
argdef->set_name(normalized);
Edge const* edge;
- TF_CHECK_OK(node->input_edge(0, &edge));
+ TF_RETURN_IF_ERROR(node->input_edge(0, &edge));
return_values[normalized] =
strings::StrCat(edge->src()->name(), ":", edge->src_output());
continue;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 72ca3c3fa2..28072c2df3 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -263,6 +263,27 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
ctx().nodes_to_preserve->end();
}
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool IsDrivenByControlDependency(const NodeDef& node) const {
+ return std::any_of(node.input().begin(), node.input().end(),
+ IsControlInput);
+ }
+
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool DrivesControlDependency(const NodeDef& node) const {
+ int position;
+ for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
+ for (int i = 0; i < output->input_size(); ++i) {
+ auto input = output->input(i);
+ string name = ParseNodeName(input, &position);
+ if (name == node.name() && /*control input*/ position < 0) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
private:
// Extended context required for ArithmeticOptimizer.
const ArithmeticOptimizerContext ctx_ext_;
@@ -393,27 +414,6 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
is_broadcastable);
}
- // TODO(ezhulenev): move to GraphOptimizerStage?
- bool IsDrivenByControlDependency(const NodeDef& node) const {
- return std::any_of(node.input().begin(), node.input().end(),
- IsControlInput);
- }
-
- // TODO(ezhulenev): move to GraphOptimizerStage?
- bool DrivesControlDependency(const NodeDef& node) const {
- int position;
- for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
- for (int i = 0; i < output->input_size(); ++i) {
- auto input = output->input(i);
- string name = ParseNodeName(input, &position);
- if (name == node.name() && /*control input*/ position < 0) {
- return true;
- }
- }
- }
- return false;
- }
-
string ShapeSignature(const TensorShapeProto& shape) const {
string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
for (int i = 0; i < shape.dim_size(); ++i)
@@ -2719,6 +2719,165 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
}
};
+// Replace a chain of type&shape preserving unary ops with a
+// '_UnaryOpsComposition' node.
+// TODO(ezhulenev): It should be a part of remapper optimizer because it doesn't
+// have to do much with arithmetic (together with FoldMultiplyIntoConv stage?).
+class UnaryOpsComposition : public ArithmeticOptimizerStage {
+ public:
+ explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
+ // WARN: This should be consistent with unary_ops_composition.cc.
+ // clang-format off
+ supported_ops_ = {// Ops defined via Eigen scalar ops.
+ {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Acos", {DT_FLOAT, DT_DOUBLE}},
+ {"Acosh", {DT_FLOAT, DT_DOUBLE}},
+ {"Asin", {DT_FLOAT, DT_DOUBLE}},
+ {"Asinh", {DT_FLOAT, DT_DOUBLE}},
+ {"Atan", {DT_FLOAT, DT_DOUBLE}},
+ {"Atanh", {DT_FLOAT, DT_DOUBLE}},
+ {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Cosh", {DT_FLOAT, DT_DOUBLE}},
+ {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Rint", {DT_FLOAT, DT_DOUBLE}},
+ {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sinh", {DT_FLOAT, DT_DOUBLE}},
+ {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Tan", {DT_FLOAT, DT_DOUBLE}},
+ {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ // Additional ops that are not part of the Eigen.
+ {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
+ // clang-format on
+ }
+ ~UnaryOpsComposition() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return CanOptimize(*node);
+ }
+
+ Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
+ DataType dtype = root->attr().at("T").type();
+
+ // Keep a trace of all supported input nodes that can be fused together.
+ std::vector<string> op_nodes = {root->name()};
+ std::vector<string> op_names = {root->op()};
+
+ // Check if we should follow input(0) while building an op composition.
+ const auto predicate_fn = [&](const NodeDef& input) {
+ if (input.name() == root->name()) return true;
+
+ bool follow_input_node =
+ dtype == GetDataTypeFromAttr(input, "T") &&
+ NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
+ CanOptimize(input);
+
+ if (follow_input_node) {
+ op_nodes.push_back(input.name());
+ op_names.push_back(input.op());
+ }
+
+ return follow_input_node;
+ };
+
+ NodeDef* last_op = GetTailOfChain(
+ *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);
+
+ // We were not able to find a chain that can be replaced.
+ if (op_names.size() == 1) return Status::OK();
+
+ // Do not add fused nodes to any other chain.
+ std::for_each(op_nodes.begin(), op_nodes.end(),
+ [this](const string& name) { AddToFusedNodes(name); });
+
+ // Reverse the trace to get correct composition computation order.
+ std::reverse(op_names.begin(), op_names.end());
+
+ VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
+ << str_util::Join(op_names, ", ") << "]";
+
+ NodeDef* composition_node = ctx().optimized_graph->add_node();
+ composition_node->set_name(
+ strings::StrCat(root->name(), "/unary_ops_composition"));
+ composition_node->set_op("_UnaryOpsComposition");
+ composition_node->add_input(last_op->input(0));
+ composition_node->set_device(root->device());
+
+ auto attr = composition_node->mutable_attr();
+ SetAttrValue(dtype, &(*attr)["T"]);
+ SetAttrValue(op_names, &(*attr)["op_names"]);
+
+ ctx().node_map->AddNode(composition_node->name(), composition_node);
+ ctx().node_map->AddOutput(NodeName(last_op->input(0)),
+ composition_node->name());
+
+ *simplified_node_name = composition_node->name();
+
+ return Status::OK();
+ }
+
+ private:
+ bool CanOptimize(const NodeDef& node) const {
+ DataType dtype = GetDataTypeFromAttr(node, "T");
+ if (!IsSupported(node.op(), dtype)) {
+ return false;
+ }
+ if (IsInPreserveSet(node)) {
+ return false;
+ }
+ if (!NodeIsOnCpu(node)) {
+ return false;
+ }
+ if (NodeIsAlreadyFused(node)) {
+ return false;
+ }
+ return !(IsDrivenByControlDependency(node) ||
+ DrivesControlDependency(node));
+ }
+
+ // UnaryOpsComposition is defined only for CPU.
+ bool NodeIsOnCpu(const NodeDef& node) const {
+ using str_util::StartsWith;
+
+ string task;
+ string device;
+
+ return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
+ StartsWith(device, DEVICE_CPU);
+ }
+
+ bool NodeIsAlreadyFused(const NodeDef& node) const {
+ return fused_nodes_.count(node.name()) > 0;
+ }
+
+ void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }
+
+ // Check if an op is supported by the _UnaryOpsComposition for the given type.
+ bool IsSupported(const string& op_name, DataType dtype) const {
+ const auto it = supported_ops_.find(op_name);
+ return it != supported_ops_.end() && it->second.count(dtype) > 0;
+ }
+
+ std::unordered_map<string, std::set<DataType>> supported_ops_;
+ std::unordered_set<string> fused_nodes_;
+};
+
} // namespace
class UniqueNodes {
@@ -3001,6 +3160,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
if (options_.convert_expm1)
pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
+ if (options_.unary_ops_composition)
+ pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 45a5f65b81..551c3652bf 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -78,6 +78,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool convert_pow = true;
bool convert_log1p = true;
bool convert_expm1 = true;
+ bool unary_ops_composition = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 3f6c04a5b5..54fdc01adb 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -141,6 +141,9 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.dedup_computations = false;
options.combine_add_to_addn = false;
options.convert_sqrt_div_to_rsqrt_mul = false;
+ options.convert_pow = false;
+ options.convert_log1p = false;
+ options.optimize_max_or_min_of_monotonic = false;
options.fold_conjugate_into_transpose = false;
options.fold_multiply_into_conv = false;
options.fold_transpose_into_matmul = false;
@@ -158,6 +161,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.reorder_cast_and_transpose = false;
options.replace_mul_with_square = false;
options.simplify_aggregation = false;
+ options.unary_ops_composition = false;
optimizer->options_ = options;
}
@@ -279,6 +283,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.convert_expm1 = true;
}
+
+ void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.unary_ops_composition = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -3201,5 +3210,62 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
EXPECT_EQ(2, required_node_count);
}
+TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output log = ops::Log(s.WithOpName("log"), sqrt);
+ Output relu = ops::Relu(s.WithOpName("relu"), log);
+ Output final_out = ops::Identity(s.WithOpName("final_out"), relu);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ // Place all nodes on CPU.
+ for (int i = 0; i < item.graph.node_size(); ++i) {
+ item.graph.mutable_node(i)->set_device("/device:CPU:0");
+ }
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyUnaryOpsComposition(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ EXPECT_EQ(3, output.node_size());
+
+ // Check that Sqrt/Log/Relu were replaced with a single op.
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "final_out") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("relu/unary_ops_composition", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "relu/unary_ops_composition") {
+ EXPECT_EQ("_UnaryOpsComposition", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+
+ auto op_names = node.attr().at("op_names").list().s();
+ EXPECT_EQ(3, op_names.size());
+ EXPECT_EQ("Sqrt", op_names[0]);
+ EXPECT_EQ("Log", op_names[1]);
+ EXPECT_EQ("Relu", op_names[2]);
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 3e66d6412a..07360d594b 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3385,6 +3385,14 @@ cc_library(
],
)
+# Kernels for the nodes intented to be added to the graph by the Grappler optimizers.
+cc_library(
+ name = "grappler",
+ deps = [
+ ":unary_ops_composition",
+ ],
+)
+
NN_DEPS = [
":bounds_check",
":conv_2d",
@@ -3921,6 +3929,7 @@ tf_cc_test(
cc_library(
name = "sparse",
deps = [
+ ":deserialize_sparse_string_op",
":deserialize_sparse_variant_op",
":serialize_sparse_op",
":sparse_add_grad_op",
@@ -4076,10 +4085,18 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "deserialize_sparse_string_op",
+ prefix = "deserialize_sparse_string_op",
+ deps = SPARSE_DEPS + [
+ ":reshape_util",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
name = "deserialize_sparse_variant_op",
prefix = "deserialize_sparse_variant_op",
deps = SPARSE_DEPS + [
- ":reshape_util",
"//tensorflow/core:protos_all_cc",
],
)
diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
new file mode 100644
index 0000000000..6fb07c11e9
--- /dev/null
+++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
@@ -0,0 +1,293 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/kernels/reshape_util.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+
+namespace {
+
+using sparse::SparseTensor;
+
+class DeserializeSparseOp : public OpKernel {
+ public:
+ explicit DeserializeSparseOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& serialized_sparse = context->input(0);
+ const int ndims = serialized_sparse.shape().dims();
+
+ OP_REQUIRES(
+ context, ndims > 0,
+ errors::InvalidArgument("Serialized sparse should have non-zero rank ",
+ serialized_sparse.shape().DebugString()));
+
+ OP_REQUIRES(context, serialized_sparse.shape().dim_size(ndims - 1) == 3,
+ errors::InvalidArgument(
+ "Serialized sparse should have 3 as the last dimension ",
+ serialized_sparse.shape().DebugString()));
+
+ int num_sparse_tensors = 1;
+ for (int i = 0; i < ndims - 1; ++i) {
+ num_sparse_tensors *= serialized_sparse.shape().dim_size(i);
+ }
+
+ OP_REQUIRES(
+ context, num_sparse_tensors > 0,
+ errors::InvalidArgument(
+ "Serialized sparse should have at least 1 serialized tensor, "
+ "but has a zero dimension ",
+ serialized_sparse.shape().DebugString()));
+
+ if (num_sparse_tensors == 1 && ndims == 1) {
+ // Special case with a single sparse tensor. We can avoid data
+ // motion in the Concat and Reshape.
+ const auto& serialized_sparse_t = serialized_sparse.vec<string>();
+
+ Tensor output_indices;
+ Tensor output_values;
+ Tensor output_shape;
+ OP_REQUIRES_OK(context,
+ this->GetAndValidateSparseTensor(
+ serialized_sparse_t(0), serialized_sparse_t(1),
+ serialized_sparse_t(2), dtype_, 0 /* index */,
+ &output_indices, &output_values, &output_shape));
+ context->set_output(0, output_indices);
+ context->set_output(1, output_values);
+ context->set_output(2, output_shape);
+ return;
+ }
+
+ std::vector<Tensor> indices;
+ std::vector<Tensor> values;
+ TensorShape shape;
+ indices.reserve(num_sparse_tensors);
+ values.reserve(num_sparse_tensors);
+
+ const auto& serialized_sparse_t =
+ serialized_sparse.flat_inner_dims<string, 2>();
+ for (int i = 0; i < num_sparse_tensors; ++i) {
+ Tensor output_indices;
+ Tensor output_values;
+ Tensor output_shape;
+ OP_REQUIRES_OK(context,
+ this->GetAndValidateSparseTensor(
+ serialized_sparse_t(i, 0), serialized_sparse_t(i, 1),
+ serialized_sparse_t(i, 2), dtype_, i, &output_indices,
+ &output_values, &output_shape));
+ int64 num_entries = output_indices.dim_size(0);
+ int rank = output_indices.dim_size(1);
+
+ // Now we expand each SparseTensors' indices and shape by
+ // prefixing a dimension
+ Tensor expanded_indices(DT_INT64, TensorShape({num_entries, 1 + rank}));
+ const auto& output_indices_t = output_indices.matrix<int64>();
+ auto expanded_indices_t = expanded_indices.matrix<int64>();
+ expanded_indices_t.chip<1>(0).setZero();
+ if (rank > 0) {
+ Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
+ Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
+ expanded_indices_t.slice(indices_start, indices_sizes) =
+ output_indices_t;
+ }
+ Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
+ const auto& output_shape_t = output_shape.vec<int64>();
+ auto expanded_shape_t = expanded_shape.vec<int64>();
+ expanded_shape_t(0) = 1;
+ std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
+
+ TensorShape expanded_tensor_shape(expanded_shape.vec<int64>());
+
+ indices.push_back(expanded_indices);
+ values.push_back(output_values);
+ if (i == 0) {
+ shape = expanded_tensor_shape;
+ } else {
+ OP_REQUIRES(
+ context, shape.dims() == expanded_tensor_shape.dims(),
+ errors::InvalidArgument(
+ "Inconsistent shape across SparseTensors: rank prior to "
+ "SparseTensor[",
+ i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
+ "] is: ", expanded_tensor_shape.dims() - 1));
+ for (int j = 1; j < shape.dims(); ++j) {
+ // NOTE(mrry): For compatibility with the implementations of
+ // DeserializeManySparse, and many ops that generate
+ // SparseTensors to batch that do not have a fixed
+ // dense_shape (e.g. `tf.parse_single_example()`), we
+ // compute the maximum in each dimension to find the
+ // smallest dense_shape that bounds all of the input
+ // SparseTensors.
+ shape.set_dim(j, std::max(shape.dim_size(j),
+ expanded_tensor_shape.dim_size(j)));
+ }
+ }
+ }
+
+ // Dimension 0 is the primary dimension.
+ int rank = shape.dims();
+ gtl::InlinedVector<int64, 8> std_order(rank);
+ std::iota(std_order.begin(), std_order.end(), 0);
+
+ std::vector<SparseTensor> tensors;
+ tensors.reserve(num_sparse_tensors);
+ for (int i = 0; i < num_sparse_tensors; ++i) {
+ tensors.emplace_back(indices[i], values[i], shape, std_order);
+ }
+
+ gtl::optional<SparseTensor> maybe_output;
+#define HANDLE_TYPE(T) \
+ case DataTypeToEnum<T>::value: { \
+ maybe_output = SparseTensor::Concat<T>(tensors); \
+ break; \
+ }
+
+ switch (dtype_) {
+ TF_CALL_ALL_TYPES(HANDLE_TYPE);
+ TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ default:
+ OP_REQUIRES(context, false,
+ errors::Unimplemented(
+ "DeserializeSparse Unhandled data type: ", dtype_));
+ }
+ DCHECK(maybe_output);
+ SparseTensor& output = maybe_output.value();
+
+ // Compute the input shape for the reshape operation.
+ Tensor input_shape(DT_INT64, TensorShape({output.dims()}));
+ std::copy_n(output.shape().data(), output.dims(),
+ input_shape.vec<int64>().data());
+
+ // Compute the target shape for the reshape operation.
+ Tensor target_shape(DT_INT64, TensorShape({ndims + output.dims() - 2}));
+ for (int i = 0; i < ndims - 1; ++i) {
+ target_shape.vec<int64>()(i) = serialized_sparse.shape().dim_size(i);
+ }
+ for (int i = 0; i < output.dims() - 1; ++i) {
+ target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
+ }
+
+ Tensor output_indices;
+ Tensor output_shape;
+ Reshape(context, output.indices(), input_shape, target_shape,
+ 0 /* output indices index */, 2 /* output shape index */);
+ context->set_output(1, output.values());
+ }
+
+ private:
+ Status Deserialize(const string& serialized, Tensor* result) {
+ TensorProto proto;
+ if (!ParseProtoUnlimited(&proto, serialized)) {
+ return errors::InvalidArgument("Could not parse serialized proto");
+ }
+ Tensor tensor;
+ if (!tensor.FromProto(proto)) {
+ return errors::InvalidArgument("Could not construct tensor from proto");
+ }
+ *result = tensor;
+ return Status::OK();
+ }
+
+ Status GetAndValidateSparseTensor(
+ const string& serialized_indices, const string& serialized_values,
+ const string& serialized_shape, DataType values_dtype, int index,
+ Tensor* output_indices, Tensor* output_values, Tensor* output_shape) {
+ // Deserialize and validate the indices.
+ TF_RETURN_IF_ERROR(this->Deserialize(serialized_indices, output_indices));
+ if (!TensorShapeUtils::IsMatrix(output_indices->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 0] to represent an index matrix but received shape ",
+ output_indices->shape().DebugString());
+ }
+ int64 num_entries = output_indices->dim_size(0);
+ int rank = output_indices->dim_size(1);
+
+ // Deserialize and validate the values.
+ TF_RETURN_IF_ERROR(this->Deserialize(serialized_values, output_values));
+ if (!TensorShapeUtils::IsVector(output_values->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 1] to represent a values vector but received shape ",
+ output_values->shape().DebugString());
+ }
+ if (values_dtype != output_values->dtype()) {
+ return errors::InvalidArgument(
+ "Requested SparseTensor of type ", DataTypeString(values_dtype),
+ " but SparseTensor[", index,
+ "].values.dtype() == ", DataTypeString(output_values->dtype()));
+ }
+ if (num_entries != output_values->dim_size(0)) {
+ return errors::InvalidArgument(
+ "Expected row counts of SparseTensor[", index,
+ "].indices and SparseTensor[", index,
+ "].values to match but they do not: ", num_entries, " vs. ",
+ output_values->dim_size(0));
+ }
+
+ // Deserialize and validate the shape.
+ TF_RETURN_IF_ERROR(this->Deserialize(serialized_shape, output_shape));
+ if (!TensorShapeUtils::IsVector(output_shape->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 1] to be a shape vector but its shape is ",
+ output_shape->shape().DebugString());
+ }
+ if (rank != output_shape->dim_size(0)) {
+ return errors::InvalidArgument("Expected column counts of SparseTensor[",
+ index,
+ "].indices to match size of SparseTensor[",
+ index, "].shape but they do not: ", rank,
+ " vs. ", output_shape->dim_size(0));
+ }
+ return Status::OK();
+ }
+
+ DataType dtype_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<string>("Tserialized"),
+ DeserializeSparseOp)
+
+REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU),
+ DeserializeSparseOp)
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.h b/tensorflow/core/kernels/quantize_and_dequantize_op.h
index 906d507c8a..782263e4e9 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.h
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.h
@@ -47,9 +47,13 @@ struct QuantizeAndDequantizeOneScaleImpl {
if (!range_given) {
input_min.device(d) = input.minimum();
input_max.device(d) = input.maximum();
+ d.memcpyDeviceToHost(&min_range, input_min.data(), sizeof(T));
+ d.memcpyDeviceToHost(&max_range, input_max.data(), sizeof(T));
+ } else {
+ // Copy the range values from their respective tensors on the host.
+ min_range = input_min_tensor->scalar<T>()();
+ max_range = input_max_tensor->scalar<T>()();
}
- d.memcpyDeviceToHost(&min_range, input_min.data(), sizeof(T));
- d.memcpyDeviceToHost(&max_range, input_max.data(), sizeof(T));
// Calculate the range for the simulated integer quantization:
// e.g. [-128,127] for signed = true, num_bits = 8,
diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc
index 4fea57e6b7..852cef29c7 100644
--- a/tensorflow/core/kernels/serialize_sparse_op.cc
+++ b/tensorflow/core/kernels/serialize_sparse_op.cc
@@ -36,6 +36,8 @@ limitations under the License.
namespace tensorflow {
+namespace {
+
using sparse::SparseTensor;
template <typename T>
@@ -306,257 +308,6 @@ Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input,
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
-template <typename T>
-class DeserializeSparseOp : public OpKernel {
- public:
- explicit DeserializeSparseOp(OpKernelConstruction* context)
- : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
- }
-
- void Compute(OpKernelContext* context) override {
- const Tensor& serialized_sparse = context->input(0);
- const int ndims = serialized_sparse.shape().dims();
-
- OP_REQUIRES(
- context, ndims > 0,
- errors::InvalidArgument("Serialized sparse should have non-zero rank ",
- serialized_sparse.shape().DebugString()));
-
- OP_REQUIRES(context, serialized_sparse.shape().dim_size(ndims - 1) == 3,
- errors::InvalidArgument(
- "Serialized sparse should have 3 as the last dimension ",
- serialized_sparse.shape().DebugString()));
-
- int num_sparse_tensors = 1;
- for (int i = 0; i < ndims - 1; ++i) {
- num_sparse_tensors *= serialized_sparse.shape().dim_size(i);
- }
-
- OP_REQUIRES(
- context, num_sparse_tensors > 0,
- errors::InvalidArgument(
- "Serialized sparse should have at least 1 serialized tensor, "
- "but has a zero dimension ",
- serialized_sparse.shape().DebugString()));
-
- if (num_sparse_tensors == 1 && ndims == 1) {
- // Special case with a single sparse tensor. We can avoid data
- // motion in the Concat and Reshape.
- const auto& serialized_sparse_t = serialized_sparse.vec<T>();
-
- Tensor output_indices;
- Tensor output_values;
- Tensor output_shape;
- OP_REQUIRES_OK(context,
- this->GetAndValidateSparseTensor(
- serialized_sparse_t(0), serialized_sparse_t(1),
- serialized_sparse_t(2), dtype_, 0 /* index */,
- &output_indices, &output_values, &output_shape));
- context->set_output(0, output_indices);
- context->set_output(1, output_values);
- context->set_output(2, output_shape);
- return;
- }
-
- std::vector<Tensor> indices;
- std::vector<Tensor> values;
- TensorShape shape;
- indices.reserve(num_sparse_tensors);
- values.reserve(num_sparse_tensors);
-
- const auto& serialized_sparse_t = serialized_sparse.flat_inner_dims<T, 2>();
- for (int i = 0; i < num_sparse_tensors; ++i) {
- Tensor output_indices;
- Tensor output_values;
- Tensor output_shape;
- OP_REQUIRES_OK(context,
- this->GetAndValidateSparseTensor(
- serialized_sparse_t(i, 0), serialized_sparse_t(i, 1),
- serialized_sparse_t(i, 2), dtype_, i, &output_indices,
- &output_values, &output_shape));
- int64 num_entries = output_indices.dim_size(0);
- int rank = output_indices.dim_size(1);
-
- // Now we expand each SparseTensors' indices and shape by
- // prefixing a dimension
- Tensor expanded_indices(DT_INT64, TensorShape({num_entries, 1 + rank}));
- const auto& output_indices_t = output_indices.matrix<int64>();
- auto expanded_indices_t = expanded_indices.matrix<int64>();
- expanded_indices_t.chip<1>(0).setZero();
- if (rank > 0) {
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
- expanded_indices_t.slice(indices_start, indices_sizes) =
- output_indices_t;
- }
- Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
- const auto& output_shape_t = output_shape.vec<int64>();
- auto expanded_shape_t = expanded_shape.vec<int64>();
- expanded_shape_t(0) = 1;
- std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
-
- TensorShape expanded_tensor_shape(expanded_shape.vec<int64>());
-
- indices.push_back(expanded_indices);
- values.push_back(output_values);
- if (i == 0) {
- shape = expanded_tensor_shape;
- } else {
- OP_REQUIRES(
- context, shape.dims() == expanded_tensor_shape.dims(),
- errors::InvalidArgument(
- "Inconsistent shape across SparseTensors: rank prior to "
- "SparseTensor[",
- i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
- "] is: ", expanded_tensor_shape.dims() - 1));
- for (int j = 1; j < shape.dims(); ++j) {
- // NOTE(mrry): For compatibility with the implementations of
- // DeserializeManySparse, and many ops that generate
- // SparseTensors to batch that do not have a fixed
- // dense_shape (e.g. `tf.parse_single_example()`), we
- // compute the maximum in each dimension to find the
- // smallest dense_shape that bounds all of the input
- // SparseTensors.
- shape.set_dim(j, std::max(shape.dim_size(j),
- expanded_tensor_shape.dim_size(j)));
- }
- }
- }
-
- // Dimension 0 is the primary dimension.
- int rank = shape.dims();
- gtl::InlinedVector<int64, 8> std_order(rank);
- std::iota(std_order.begin(), std_order.end(), 0);
-
- std::vector<SparseTensor> tensors;
- tensors.reserve(num_sparse_tensors);
- for (int i = 0; i < num_sparse_tensors; ++i) {
- tensors.emplace_back(indices[i], values[i], shape, std_order);
- }
-
- gtl::optional<SparseTensor> maybe_output;
-#define HANDLE_TYPE(T) \
- case DataTypeToEnum<T>::value: { \
- maybe_output = SparseTensor::Concat<T>(tensors); \
- break; \
- }
-
- switch (dtype_) {
- TF_CALL_ALL_TYPES(HANDLE_TYPE);
- TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
-#undef HANDLE_TYPE
- default:
- OP_REQUIRES(context, false,
- errors::Unimplemented(
- "DeserializeSparse Unhandled data type: ", dtype_));
- }
- DCHECK(maybe_output);
- SparseTensor& output = maybe_output.value();
-
- // Compute the input shape for the reshape operation.
- Tensor input_shape(DT_INT64, TensorShape({output.dims()}));
- std::copy_n(output.shape().data(), output.dims(),
- input_shape.vec<int64>().data());
-
- // Compute the target shape for the reshape operation.
- Tensor target_shape(DT_INT64, TensorShape({ndims + output.dims() - 2}));
- for (int i = 0; i < ndims - 1; ++i) {
- target_shape.vec<int64>()(i) = serialized_sparse.shape().dim_size(i);
- }
- for (int i = 0; i < output.dims() - 1; ++i) {
- target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
- }
-
- Tensor output_indices;
- Tensor output_shape;
- Reshape(context, output.indices(), input_shape, target_shape,
- 0 /* output indices index */, 2 /* output shape index */);
- context->set_output(1, output.values());
- }
-
- protected:
- Status Deserialize(const T& serialized, Tensor* result);
-
- Status GetAndValidateSparseTensor(
- const T& serialized_indices, const T& serialized_values,
- const T& serialized_shape, DataType values_dtype, int index,
- Tensor* output_indices, Tensor* output_values, Tensor* output_shape) {
- // Deserialize and validate the indices.
- TF_RETURN_IF_ERROR(this->Deserialize(serialized_indices, output_indices));
- if (!TensorShapeUtils::IsMatrix(output_indices->shape())) {
- return errors::InvalidArgument(
- "Expected serialized_sparse[", index,
- ", 0] to represent an index matrix but received shape ",
- output_indices->shape().DebugString());
- }
- int64 num_entries = output_indices->dim_size(0);
- int rank = output_indices->dim_size(1);
-
- // Deserialize and validate the values.
- TF_RETURN_IF_ERROR(this->Deserialize(serialized_values, output_values));
- if (!TensorShapeUtils::IsVector(output_values->shape())) {
- return errors::InvalidArgument(
- "Expected serialized_sparse[", index,
- ", 1] to represent a values vector but received shape ",
- output_values->shape().DebugString());
- }
- if (values_dtype != output_values->dtype()) {
- return errors::InvalidArgument(
- "Requested SparseTensor of type ", DataTypeString(values_dtype),
- " but SparseTensor[", index,
- "].values.dtype() == ", DataTypeString(output_values->dtype()));
- }
- if (num_entries != output_values->dim_size(0)) {
- return errors::InvalidArgument(
- "Expected row counts of SparseTensor[", index,
- "].indices and SparseTensor[", index,
- "].values to match but they do not: ", num_entries, " vs. ",
- output_values->dim_size(0));
- }
-
- // Deserialize and validate the shape.
- TF_RETURN_IF_ERROR(this->Deserialize(serialized_shape, output_shape));
- if (!TensorShapeUtils::IsVector(output_shape->shape())) {
- return errors::InvalidArgument(
- "Expected serialized_sparse[", index,
- ", 1] to be a shape vector but its shape is ",
- output_shape->shape().DebugString());
- }
- if (rank != output_shape->dim_size(0)) {
- return errors::InvalidArgument("Expected column counts of SparseTensor[",
- index,
- "].indices to match size of SparseTensor[",
- index, "].shape but they do not: ", rank,
- " vs. ", output_shape->dim_size(0));
- }
- return Status::OK();
- }
-
- DataType dtype_;
-};
-
-template <>
-Status DeserializeSparseOp<string>::Deserialize(const string& serialized,
- Tensor* result) {
- TensorProto proto;
- if (!ParseProtoUnlimited(&proto, serialized)) {
- return errors::InvalidArgument("Could not parse serialized proto");
- }
- Tensor tensor;
- if (!tensor.FromProto(proto)) {
- return errors::InvalidArgument("Could not construct tensor from proto");
- }
- *result = tensor;
- return Status::OK();
-}
-
-REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
- .Device(DEVICE_CPU)
- .TypeConstraint<string>("Tserialized"),
- DeserializeSparseOp<string>)
-
-REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU),
- DeserializeSparseOp<string>)
+} // namespace
} // namespace tensorflow
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index f3b788f931..e037925961 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -361,7 +361,7 @@ class _ListFetchMapper(_FetchMapper):
for m, vi in zip(self._mappers, self._value_indices):
results.append(m.build_results([values[j] for j in vi]))
# Return a value of the original type of the fetches.
- if self._fetch_type == list:
+ if issubclass(self._fetch_type, list):
return results
elif self._fetch_type == tuple:
return tuple(results)
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index c3d42b49af..89de55dd4f 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -278,7 +278,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
result = sess.run(get_next)
padded_len = padded_shapes[0]
if padded_len is None or padded_len == -1:
- padded_len = np.max(result)
+ padded_len = np.max(result) if result.size > 0 else 0
self.assertEqual((batch_size, padded_len), result.shape)
for j in range(batch_size):
seq_len = seq_lens[(i * batch_size) + j]
@@ -288,7 +288,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
if not drop_remainder and len(seq_lens) % batch_size > 0:
result = sess.run(get_next)
- padded_len = np.max(result)
+ padded_len = np.max(result) if result.size > 0 else 0
self.assertEqual((len(seq_lens) % batch_size, padded_len),
result.shape)
for j in range(len(seq_lens) % batch_size):
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 7cb6627615..89265d9575 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -888,7 +888,83 @@ class Dataset(object):
drop_remainder)
def map(self, map_func, num_parallel_calls=None):
- """Maps `map_func` across this dataset.
+ """Maps `map_func` across the elements of this dataset.
+
+ This transformation applies `map_func` to each element of this dataset, and
+ returns a new dataset containing the transformed elements, in the same
+ order as they appeared in the input.
+
+ For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { 1, 2, 3, 4, 5 }
+
+ a.map(lambda x: x + 1) = { 2, 3, 4, 5, 6 }
+ ```
+
+ The input signature of `map_func` is determined by the structure of each
+ element in this dataset. For example:
+
+ ```python
+ # Each element is a `tf.Tensor` object.
+ a = { 1, 2, 3, 4, 5 }
+ # `map_func` takes a single argument of type `tf.Tensor` with the same
+ # shape and dtype.
+ result = a.map(lambda x: ...)
+
+ # Each element is a tuple containing two `tf.Tensor` objects.
+ b = { (1, "foo"), (2, "bar"), (3, "baz") }
+ # `map_func` takes two arguments of type `tf.Tensor`.
+ result = b.map(lambda x_int, y_str: ...)
+
+ # Each element is a dictionary mapping strings to `tf.Tensor` objects.
+ c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} }
+ # `map_func` takes a single argument of type `dict` with the same keys as
+ # the elements.
+ result = c.map(lambda d: ...)
+ ```
+
+ The value or values returned by `map_func` determine the structure of each
+ element in the returned dataset.
+
+ ```python
+ # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`.
+ def f(...):
+ return tf.constant(37.0)
+ result = dataset.map(f)
+ result.output_classes == tf.Tensor
+ result.output_types == tf.float32
+ result.output_shapes == [] # scalar
+
+ # `map_func` returns two `tf.Tensor` objects.
+ def g(...):
+ return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
+ result = dataset.map(g)
+ result.output_classes == (tf.Tensor, tf.Tensor)
+ result.output_types == (tf.float32, tf.string)
+ result.output_shapes == ([], [3])
+
+ # Python primitives, lists, and NumPy arrays are implicitly converted to
+ # `tf.Tensor`.
+ def h(...):
+ return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0] dtype=np.float64)
+ result = dataset.map(h)
+ result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor)
+ result.output_types == (tf.float32, tf.string, tf.float64)
+ result.output_shapes == ([], [3], [2])
+
+ # `map_func` can return nested structures.
+ def i(...):
+ return {"a": 37.0, "b": [42, 16]}, "foo"
+ result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor)
+ result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string)
+ result.output_shapes == ({"a": [], "b": [2]}, [])
+ ```
+
+ In addition to `tf.Tensor` objects, `map_func` can accept as arguments and
+ return `tf.SparseTensor` objects.
Args:
map_func: A function mapping a nested structure of tensors (having
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 7edcb0931d..08470f65b0 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -801,6 +801,10 @@ class _PolymorphicFunction(object):
graph_function, inputs = self._maybe_define_function(*args, **kwds)
return graph_function(*inputs)
+ def call_python_function(self, *args, **kwargs):
+ """Directly calls the wrapped python function."""
+ return self._python_function(*args, **kwargs)
+
@property
def variables(self):
"""Returns a list of variables used in any of the defined functions."""
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index cf32f6e7fb..e1801b7ec6 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -829,6 +829,25 @@ class FunctionTest(test.TestCase):
out = foo.two(t)
self.assertEqual(float(out), 1.0)
+ def testPythonCallWithSideEffects(self):
+ state = []
+
+ @function.defun
+ def side_effecting_function():
+ state.append(0)
+
+ side_effecting_function()
+ self.assertAllEqual(state, [0])
+
+ # The second invocation should call the graph function, which shouldn't
+ # trigger the list append.
+ side_effecting_function()
+ self.assertAllEqual(state, [0])
+
+ # Whereas calling the python function directly should create a side-effect.
+ side_effecting_function.call_python_function()
+ self.assertAllEqual(state, [0, 0])
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 760a148552..848adf4fd3 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -110,13 +110,25 @@ class _VariableCapturingScope(object):
"""
# TODO(apassos) ignoring the regularizer and partitioner here; figure out
# how to deal with these.
- def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None, validate_shape=True,
- use_resource=None):
+ def _custom_getter( # pylint: disable=missing-docstring
+ getter=None,
+ name=None,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None, # pylint: disable=redefined-outer-name
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ aggregation=variable_scope.VariableAggregation.NONE,
+ synchronization=variable_scope.VariableSynchronization.AUTO):
del getter, regularizer, partitioner, validate_shape, use_resource, dtype
- del collections, initializer, trainable, reuse, caching_device, shape,
+ del collections, initializer, trainable, reuse, caching_device, shape
+ del aggregation, synchronization
assert name in self.variables
v = self.variables[name]
return v.variable
@@ -136,13 +148,24 @@ class _VariableCapturingScope(object):
"""
# TODO(apassos) ignoring the regularizer and partitioner here; figure out
# how to deal with these.
- def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None, validate_shape=True,
- use_resource=None):
+ def _custom_getter( # pylint: disable=missing-docstring
+ getter=None,
+ name=None,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None, # pylint: disable=redefined-outer-name
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ aggregation=variable_scope.VariableAggregation.NONE,
+ synchronization=variable_scope.VariableSynchronization.AUTO):
del getter, regularizer, collections, caching_device, partitioner
- del use_resource, validate_shape
+ del use_resource, validate_shape, aggregation, synchronization
if name in self.tf_variables:
if reuse:
return self.tf_variables[name].initialized_value()
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 5769f5739c..cb37f99704 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -45,6 +45,8 @@ from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -241,8 +243,17 @@ def _in_place_subclassed_model_state_restoration(model):
# Restore layers and build attributes
if (hasattr(model, '_original_attributes_cache') and
model._original_attributes_cache is not None):
- model._layers = []
+ # Models have sticky attribute assignment, so we want to be careful to add
+ # back the previous attributes and track Layers by their original names
+ # without adding dependencies on "utility" attributes which Models exempt
+ # when they're constructed.
+ model._layers = data_structures.NoDependency([])
for name, value in model._original_attributes_cache.items():
+ if not isinstance(value, checkpointable.CheckpointableBase):
+ # If this value is not already checkpointable, it's probably that way
+ # for a reason; we don't want to start tracking data structures that the
+ # original Model didn't.
+ value = data_structures.NoDependency(value)
setattr(model, name, value)
model._original_attributes_cache = None
else:
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 11f99c030f..824513dce0 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -2795,10 +2795,15 @@ class Function(object):
if not isinstance(self.fetches, list):
self.fetches = [self.fetches]
# The main use case of `fetches` being passed to a model is the ability
- # to run custom updates (since the outputs of fetches are never returned).
+ # to run custom updates
# This requires us to wrap fetches in `identity` ops.
self.fetches = [array_ops.identity(x) for x in self.fetches]
self.session_kwargs = session_kwargs
+ # This mapping keeps track of the function that should receive the
+ # output from a fetch in `fetches`: { fetch: function(fetch_output) }
+ # A Callback can use this to register a function with access to the
+ # output values for a fetch it added.
+ self.fetch_callbacks = dict()
if session_kwargs:
raise ValueError('Some keys in session_kwargs are not supported at this '
@@ -2808,6 +2813,7 @@ class Function(object):
self._feed_arrays = None
self._feed_symbols = None
self._symbol_vals = None
+ self._fetches = None
self._session = None
def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
@@ -2853,8 +2859,14 @@ class Function(object):
self._feed_arrays = feed_arrays
self._feed_symbols = feed_symbols
self._symbol_vals = symbol_vals
+ self._fetches = list(self.fetches)
self._session = session
+ def _call_fetch_callbacks(self, fetches_output):
+ for fetch, output in zip(self._fetches, fetches_output):
+ if fetch in self.fetch_callbacks:
+ self.fetch_callbacks[fetch](output)
+
def __call__(self, inputs):
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` should be a list or tuple.')
@@ -2891,14 +2903,14 @@ class Function(object):
np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name))
# Refresh callable if anything has changed.
- if (self._callable_fn is None or
- feed_arrays != self._feed_arrays or
+ if (self._callable_fn is None or feed_arrays != self._feed_arrays or
symbol_vals != self._symbol_vals or
- feed_symbols != self._feed_symbols or
+ feed_symbols != self._feed_symbols or self.fetches != self._fetches or
session != self._session):
self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
fetched = self._callable_fn(*array_vals)
+ self._call_fetch_callbacks(fetched[-len(self._fetches):])
return fetched[:len(self.outputs)]
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 0ddffa61a4..36478ea089 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -276,6 +276,36 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(
keras.backend.get_session().run(fetches=[x, y]), [30., 40.])
+ def test_function_fetch_callbacks(self):
+
+ class CallbackStub(object):
+
+ def __init__(self):
+ self.times_called = 0
+ self.callback_result = 0
+
+ def _fetch_callback(self, result):
+ self.times_called += 1
+ self.callback_result = result
+
+ with self.test_session():
+ callback = CallbackStub()
+ x_placeholder = keras.backend.placeholder(shape=())
+ y_placeholder = keras.backend.placeholder(shape=())
+
+ callback_op = x_placeholder * y_placeholder
+
+ f = keras.backend.function(
+ inputs=[x_placeholder, y_placeholder],
+ outputs=[x_placeholder + y_placeholder])
+ f.fetches.append(callback_op)
+ f.fetch_callbacks[callback_op] = callback._fetch_callback
+
+ _ = f([10., 20.])
+
+ self.assertEqual(callback.times_called, 1)
+ self.assertEqual(callback.callback_result, 200)
+
class BackendVariableTest(test.TestCase):
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 00a9c479fb..3ae06d7ab8 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -24,6 +24,7 @@ from collections import Iterable
from collections import OrderedDict
import csv
import json
+import math
import os
import time
@@ -723,8 +724,13 @@ class TensorBoard(Callback):
self.write_grads = write_grads
self.write_images = write_images
self.batch_size = batch_size
+ self._current_batch = 0
+ # abstracted writer class to be able to stub for testing
+ self._writer_class = tf_summary.FileWriter
def set_model(self, model):
+ """Sets Keras model and creates summary ops."""
+
self.model = model
self.sess = K.get_session()
if self.histogram_freq and self.merged is None:
@@ -775,54 +781,41 @@ class TensorBoard(Callback):
self.merged = tf_summary.merge_all()
if self.write_graph:
- self.writer = tf_summary.FileWriter(self.log_dir, self.sess.graph)
+ self.writer = self._writer_class(self.log_dir, self.sess.graph)
else:
- self.writer = tf_summary.FileWriter(self.log_dir)
+ self.writer = self._writer_class(self.log_dir)
- def on_epoch_end(self, epoch, logs=None):
- logs = logs or {}
+ def _fetch_callback(self, summary):
+ self.writer.add_summary(
+ summary, self._epoch + self._current_batch / self._batches_per_epoch)
+ self._current_batch += 1
+
+ def on_epoch_begin(self, epoch, logs=None):
+ """Add histogram op to Model test_function callbacks, reset batch count."""
if not self.validation_data and self.histogram_freq:
raise ValueError('If printing histograms, validation_data must be '
'provided, and cannot be a generator.')
- if self.validation_data and self.histogram_freq:
- if epoch % self.histogram_freq == 0:
-
- val_data = self.validation_data
- tensors = (
- self.model.inputs + self.model.targets + self.model.sample_weights)
-
- if self.model.uses_learning_phase:
- tensors += [K.learning_phase()]
-
- assert len(val_data) == len(tensors)
- val_size = val_data[0].shape[0]
- i = 0
- while i < val_size:
- step = min(self.batch_size, val_size - i)
- batch_val = []
- batch_val.append(val_data[0][i:i + step]
- if val_data[0] is not None else None)
- batch_val.append(val_data[1][i:i + step]
- if val_data[1] is not None else None)
- batch_val.append(val_data[2][i:i + step]
- if val_data[2] is not None else None)
- if self.model.uses_learning_phase:
- # do not slice the learning phase
- batch_val = [x[i:i + step] if x is not None else None
- for x in val_data[:-1]]
- batch_val.append(val_data[-1])
- else:
- batch_val = [x[i:i + step] if x is not None else None
- for x in val_data]
- feed_dict = {}
- for key, val in zip(tensors, batch_val):
- if val is not None:
- feed_dict[key] = val
- result = self.sess.run([self.merged], feed_dict=feed_dict)
- summary_str = result[0]
- self.writer.add_summary(summary_str, epoch)
- i += self.batch_size
+ if self.histogram_freq and epoch % self.histogram_freq == 0:
+ self._epoch = epoch
+ self._current_batch = 0
+ self._batches_per_epoch = math.ceil(
+ self.validation_data[0].shape[0] / self.batch_size)
+ if self.merged not in self.model.test_function.fetches:
+ self.model.test_function.fetches.append(self.merged)
+ self.model.test_function.fetch_callbacks[
+ self.merged] = self._fetch_callback
+
+ def on_epoch_end(self, epoch, logs=None):
+ """Checks if summary ops should run next epoch, logs scalar summaries."""
+
+ logs = logs or {}
+
+ if self.histogram_freq and self.histogram_freq > 1:
+ if self.merged in self.model.test_function.fetches:
+ self.model.test_function.fetches.remove(self.merged)
+ if self.merged in self.model.test_function.fetch_callbacks:
+ self.model.test_function.fetch_callbacks.pop(self.merged)
for name, value in logs.items():
if name in ['batch', 'size']:
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 92d66c95f6..d56f2f5bfc 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -27,6 +27,7 @@ import unittest
import numpy as np
+from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
@@ -901,6 +902,80 @@ class KerasCallbacksTest(test.TestCase):
callbacks=callbacks_factory(histogram_freq=1))
assert os.path.isdir(filepath)
+ def test_Tensorboard_histogram_summaries_in_test_function(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+ self.steps_seen = []
+
+ def add_summary(self, summary, global_step):
+ summary_obj = summary_pb2.Summary()
+
+ # ensure a valid Summary proto is being sent
+ if isinstance(summary, bytes):
+ summary_obj.ParseFromString(summary)
+ else:
+ assert isinstance(summary, summary_pb2.Summary)
+ summary_obj = summary
+
+ # keep track of steps seen for the merged_summary op,
+ # which contains the histogram summaries
+ if len(summary_obj.value) > 1:
+ self.steps_seen.append(global_step)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ np.random.seed(1337)
+ tmpdir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, tmpdir)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Dense(
+ NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
+ # non_trainable_weights: moving_variance, moving_mean
+ model.add(keras.layers.BatchNormalization())
+ model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ tsb = keras.callbacks.TensorBoard(
+ log_dir=tmpdir,
+ histogram_freq=1,
+ write_images=True,
+ write_grads=True,
+ batch_size=5)
+ tsb._writer_class = FileWriterStub
+ cbks = [tsb]
+
+ # fit with validation data
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=3,
+ verbose=0)
+
+ self.assertAllEqual(tsb.writer.steps_seen, [0, 0.5, 1, 1.5, 2, 2.5])
+
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 4814275fd5..361778570b 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -116,6 +116,7 @@ class Layer(checkpointable.CheckpointableBase):
constraints on inputs that can be accepted by the layer.
"""
+ @checkpointable.no_automatic_dependency_tracking
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
# These properties should be set by the user via keyword arguments.
# note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -217,7 +218,7 @@ class Layer(checkpointable.CheckpointableBase):
@activity_regularizer.setter
def activity_regularizer(self, regularizer):
"""Optional regularizer function for the output of this layer."""
- self._activity_regularizer = regularizer
+ self._activity_regularizer = self._no_dependency(regularizer)
@property
def trainable_weights(self):
@@ -658,7 +659,8 @@ class Layer(checkpointable.CheckpointableBase):
self._compute_previous_mask):
previous_mask = collect_previous_mask(inputs)
if not hasattr(self, '_call_fn_args'):
- self._call_fn_args = function_utils.fn_args(self.call)
+ self._call_fn_args = self._no_dependency(
+ function_utils.fn_args(self.call))
if ('mask' in self._call_fn_args and 'mask' not in kwargs and
not generic_utils.is_all_none(previous_mask)):
# The previous layer generated a mask, and mask was not explicitly pass
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index aa84eaa8ab..a4d96de74f 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -81,6 +81,20 @@ class Network(base_layer.Layer):
# Subclassed network
self._init_subclassed_network(**kwargs)
+ # Several Network methods have "no_automatic_dependency_tracking"
+ # annotations. Since Network does automatic dependency tracking on attribute
+ # assignment, including for common data structures such as lists, by default
+ # we'd have quite a few empty dependencies which users don't care about (or
+ # would need some way to ignore dependencies automatically, which is confusing
+ # when applied to user code). Some attributes, such as _layers, would cause
+ # structural issues (_layers being the place where Layers assigned to tracked
+ # attributes are stored).
+ #
+ # Aside from these aesthetic and structural issues, useless dependencies on
+ # empty lists shouldn't cause issues; adding or removing them will not break
+ # checkpoints, but may cause "all Python objects matched" assertions to fail
+ # (in which case less strict assertions may be substituted if necessary).
+ @checkpointable.no_automatic_dependency_tracking
def _base_init(self, name=None):
# The following are implemented as property functions:
# self.trainable_weights
@@ -135,6 +149,7 @@ class Network(base_layer.Layer):
# restore operations when graph building.
self._in_progress_restore_finalizer = None
+ @checkpointable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
# Normalize and set self.inputs, self.outputs.
@@ -293,6 +308,7 @@ class Network(base_layer.Layer):
for layer in self._output_layers:
self.output_names.append(layer.name)
+ @checkpointable.no_automatic_dependency_tracking
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
@@ -362,10 +378,31 @@ class Network(base_layer.Layer):
self._track_checkpointable(
layer, name='layer-%d' % layer_index, overwrite=True)
+ def _no_dependency(self, value):
+ """Override to allow `Layer` to disable dependency tracking.
+
+ `CheckpointableBase` defines this method, whose semantics are "if a subclass
+ does dependency tracking, this method exempts `value`." Layer uses
+ `_no_dependency` to exempt some of its attribute assignments (conditional on
+ attribute assignment causing tracking in the subclass).
+
+ Args:
+ value: An object which will be assigned to an object attribute, whose
+ value should not be tracked.
+
+ Returns:
+ A wrapped object which, when assigned to an attribute, will not be
+ tracked (`value` will be stored in the attribute).
+ """
+ return data_structures.NoDependency(value)
+
def __setattr__(self, name, value):
- no_dependency = isinstance(value, checkpointable.NoDependency)
- if no_dependency:
- value = value.value
+ if not getattr(self, '_setattr_tracking', True):
+ super(Network, self).__setattr__(name, value)
+ return
+ no_dependency = isinstance(value, data_structures.NoDependency)
+ value = data_structures.sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
if isinstance(value, (
base_layer.Layer,
Network,
@@ -377,7 +414,9 @@ class Network(base_layer.Layer):
'forgot to call `super(YourClass, self).__init__()`.'
' Always start with this line.')
if not is_graph_network:
- if value not in self._layers:
+ # We need to check object identity to avoid de-duplicating empty
+ # container types which compare equal.
+ if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, '_use_resource_variables'):
# In subclassed models, legacy layers (tf.layers) must always use
@@ -385,12 +424,6 @@ class Network(base_layer.Layer):
value._use_resource_variables = True
if (not no_dependency
and isinstance(value, checkpointable.CheckpointableBase)):
- # Layer (and therefore Network/Model) inherit from CheckpointableBase
- # rather than Checkpointable, which means there is no Checkpointable
- # __setattr__ override (it would be a performance issue for functional
- # layers). Therefore Model tracks Checkpointable objects itself.
- self._track_checkpointable(
- checkpointable=value, name=name, overwrite=True)
if ( # For subclassed models only, users may add extra weights/variables
# simply by assigning them to attributes.
not self._is_graph_network
@@ -493,7 +526,8 @@ class Network(base_layer.Layer):
@property
def layers(self):
- return self._layers
+ return checkpointable_layer_utils.filter_empty_layer_containers(
+ self._layers)
def get_layer(self, name=None, index=None):
"""Retrieves a layer based on either its name (unique) or index.
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index cd76f08a32..371504a503 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -29,6 +29,7 @@ from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.training import Model
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -108,6 +109,7 @@ class Sequential(Model):
return self._layers[1:]
return self._layers
+ @checkpointable.no_automatic_dependency_tracking
def add(self, layer):
"""Adds a layer instance on top of the layer stack.
@@ -191,6 +193,7 @@ class Sequential(Model):
else:
self._layers.append(layer)
+ @checkpointable.no_automatic_dependency_tracking
def pop(self):
"""Removes the last layer in the model.
@@ -210,6 +213,7 @@ class Sequential(Model):
self.outputs = [self.layers[-1].output]
self.build()
+ @checkpointable.no_automatic_dependency_tracking
def build(self, input_shape=None):
if input_shape and not self.inputs:
batch_shape = tuple(input_shape)
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index fce6cbdb7a..8e632651fa 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -42,6 +42,7 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -115,6 +116,7 @@ class Model(Network):
# Create a cache for dataset - uninitialized iterators
self._dataset_iterator_cache = weakref.WeakKeyDictionary()
+ @checkpointable.no_automatic_dependency_tracking
def compile(self,
optimizer,
loss=None,
@@ -178,6 +180,11 @@ class Model(Network):
raise ValueError('Only TF native optimizers are supported in Eager mode.')
self.optimizer = optimizers.get(optimizer)
+ # We've disabled automatic dependency tracking for this method, but do want
+ # to add a checkpoint dependency on the optimizer if it's checkpointable.
+ if isinstance(self.optimizer, checkpointable.CheckpointableBase):
+ self._track_checkpointable(
+ self.optimizer, name='optimizer', overwrite=True)
self.loss = loss
self.metrics = metrics or []
self.loss_weights = loss_weights
@@ -941,6 +948,7 @@ class Model(Network):
str(x[0].shape[0]) + ' samples')
return x, y, sample_weights
+ @checkpointable.no_automatic_dependency_tracking
def _set_inputs(self, inputs, training=None):
"""Set model's input and output specs based on the input data received.
@@ -989,6 +997,7 @@ class Model(Network):
else:
self._symbolic_set_inputs(inputs, training=training)
+ @checkpointable.no_automatic_dependency_tracking
def _eager_set_inputs(self, inputs):
"""Set model's input and output specs based on the input data received.
@@ -1041,6 +1050,7 @@ class Model(Network):
'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
self.built = True
+ @checkpointable.no_automatic_dependency_tracking
def _symbolic_set_inputs(self, inputs, outputs=None, training=None):
"""Set model's inputs and output specs based.
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index 281ad9bd50..e82f5c0332 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -124,6 +124,12 @@ def fit_loop(model,
callback_metrics = copy.copy(out_labels) + [
'val_' + n for n in out_labels
]
+ if callbacks is not None and any(
+ [isinstance(callback, cbks.TensorBoard) for callback in callbacks]):
+ # need to create the test_function before start of the first epoch
+ # because TensorBoard callback on_epoch_begin adds summary to the
+ # list of fetches of the test_function
+ model._make_test_function()
else:
callback_metrics = copy.copy(out_labels)
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index d4c213eedd..8b894ca6b1 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util.tf_export import tf_export
@@ -182,7 +183,8 @@ class BatchNormalization(Layer):
def _add_tower_local_variable(self, *args, **kwargs):
tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope('mean'):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.MEAN):
return self.add_weight(*args, **kwargs)
def build(self, input_shape):
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index b7e16a41dd..3ac4852eff 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.rmsprop import RMSPropOptimizer
try:
@@ -679,8 +679,8 @@ class ModelSubclassingTest(test.TestCase):
def __init__(self):
super(Foo, self).__init__()
self.isdep = keras.layers.Dense(1)
- self.notdep = checkpointable.NoDependency(keras.layers.Dense(2))
- self.notdep_var = checkpointable.NoDependency(
+ self.notdep = data_structures.NoDependency(keras.layers.Dense(2))
+ self.notdep_var = data_structures.NoDependency(
resource_variable_ops.ResourceVariable(1., name='notdep_var'))
m = Foo()
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index b02cafcf61..0b440185ca 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -31,7 +31,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import tracking as checkpointable
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -688,12 +688,13 @@ class Nadam(Optimizer):
return dict(list(base_config.items()) + list(config.items()))
-class TFOptimizer(Optimizer, checkpointable.Checkpointable):
+class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
"""Wrapper class for native TensorFlow optimizers.
"""
def __init__(self, optimizer): # pylint: disable=super-init-not-called
self.optimizer = optimizer
+ self._track_checkpointable(optimizer, name='optimizer')
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 1e59a8c9bf..054c6f9dd7 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -1253,6 +1253,31 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertEqual(v3, v4)
self.assertEqual(3, called[0]) # skipped one in the first new_scope
+ def testSynchronizationAndAggregationWithCustomGetter(self):
+ called = [0]
+ synchronization = variable_scope.VariableSynchronization.AUTO
+ aggregation = variable_scope.VariableAggregation.NONE
+
+ def custom_getter(getter, *args, **kwargs):
+ called[0] += 1
+
+ # Verify synchronization and aggregation kwargs are as expected.
+ self.assertEqual(kwargs["synchronization"], synchronization)
+ self.assertEqual(kwargs["aggregation"], aggregation)
+ return getter(*args, **kwargs)
+
+ with variable_scope.variable_scope("scope", custom_getter=custom_getter):
+ variable_scope.get_variable("v", [1])
+ self.assertEqual(1, called[0])
+
+ with variable_scope.variable_scope("scope", custom_getter=custom_getter):
+ synchronization = variable_scope.VariableSynchronization.ON_READ
+ aggregation = variable_scope.VariableAggregation.MEAN
+ variable_scope.get_variable(
+ "v1", [1], synchronization=synchronization, aggregation=aggregation)
+
+ self.assertEqual(2, called[0])
+
def testCustomGetterWithReuse(self):
# Custom getter can choose to behave differently on reused variables.
def custom_getter(getter, *args, **kwargs):
@@ -1355,6 +1380,23 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertAllEqual(variable_names, ["forced_name"])
+ called = [False]
+
+ def creater_c(next_creator, **kwargs):
+ called[0] = True
+ self.assertEqual(kwargs["synchronization"],
+ variable_scope.VariableSynchronization.ON_WRITE)
+ self.assertEqual(kwargs["aggregation"],
+ variable_scope.VariableAggregation.MEAN)
+ return next_creator(**kwargs)
+
+ with variable_scope.variable_creator_scope(creater_c):
+ variable_scope.get_variable(
+ "v", [],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
+ self.assertTrue(called[0])
+
class PartitionInfoTest(test.TestCase):
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 41dcd40188..c03ef967e6 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -212,7 +212,7 @@ class ReparameterizationType(object):
reparameterized, and straight-through gradients are either partially
unsupported or are not supported at all. In this case, for purposes of
e.g. RL or variational inference, it is generally safest to wrap the
- sample results in a `stop_gradients` call and instead use policy
+ sample results in a `stop_gradients` call and use policy
gradients / surrogate loss instead.
"""
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 889a00190e..713a8ab2cc 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -31,6 +31,7 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -113,12 +114,14 @@ ops.register_tensor_conversion_function(ops.IndexedSlices,
_IndexedSlicesToTensor)
-def _MarkReachedOps(from_ops, reached_ops):
+def _MarkReachedOps(from_ops, reached_ops, func_graphs):
"""Mark all ops reached from "from_ops".
Args:
from_ops: list of Operations.
reached_ops: set of Operations.
+ func_graphs: list of function._FuncGraphs. This method will traverse through
+ these functions if they capture from_ops or any reachable ops.
"""
queue = collections.deque()
queue.extend(from_ops)
@@ -128,10 +131,11 @@ def _MarkReachedOps(from_ops, reached_ops):
reached_ops.add(op)
for output in op.outputs:
if _IsBackpropagatable(output):
- queue.extend(output.consumers())
+ queue.extend(_Consumers(output, func_graphs))
-def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
+def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
+ xs):
"""Initialize the pending count for ops between two lists of Operations.
'pending_count[op]' indicates the number of backprop inputs
@@ -141,6 +145,11 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
to_ops: list of Operations.
from_ops: list of Operations.
colocate_gradients_with_ops: Python bool. See docstring of gradients().
+ func_graphs: list of function._FuncGraphs. This method will traverse through
+ these functions if they capture from_ops or any reachable ops. This is
+ useful if to_ops occur in a function and from_ops are in an outer function
+ or graph.
+ xs: list of Tensors.
Returns:
A tuple containing: (1) the subset of to_ops reachable from from_ops by a
@@ -151,7 +160,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
"""
# Mark reachable ops from from_ops.
reached_ops = set()
- _MarkReachedOps(from_ops, reached_ops)
+ _MarkReachedOps(from_ops, reached_ops, func_graphs)
# X in reached_ops iff X is reachable from from_ops by a path of zero or more
# backpropagatable tensors.
@@ -170,7 +179,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
between_op_list.append(op)
# Clear the boolean so we won't add the inputs again.
reached_ops.remove(op)
- for inp in op.inputs:
+ for inp in _Inputs(op, xs):
queue.append(inp.op)
# X in between_ops iff X is on a path of zero or more backpropagatable tensors
# between from_ops and to_ops
@@ -182,7 +191,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
# Initialize pending count for between ops.
pending_count = collections.defaultdict(int)
for op in between_op_list:
- for x in op.inputs:
+ for x in _Inputs(op, xs):
if x.op in between_ops:
pending_count[x.op] += 1
@@ -303,7 +312,7 @@ def _VerifyGeneratedGradients(grads, op):
"inputs %d" % (len(grads), op.node_def, len(op.inputs)))
-def _StopOps(from_ops, stop_gradient_ops, pending_count):
+def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
"""The set of ops that terminate the gradient computation.
This computes the frontier of the forward graph *before* which backprop
@@ -319,6 +328,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
from_ops: list of Operations.
stop_gradient_ops: list of Operations never to backprop through.
pending_count: mapping from operation to number of backprop inputs.
+ xs: list of Tensors.
Returns:
The set of operations.
@@ -326,7 +336,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
stop_ops = set()
for op in from_ops:
is_stop_op = True
- for inp in op.inputs:
+ for inp in _Inputs(op, xs):
if pending_count[inp.op] > 0:
is_stop_op = False
break
@@ -346,10 +356,10 @@ def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pyli
yield
-def _SymGrad(op, out_grads):
+def _SymGrad(op, out_grads, xs):
"""Backprop through a function call node op given its outputs' gradients."""
- f_in = [x for x in op.inputs] + out_grads
- f_types = [x.dtype for x in op.inputs]
+ f_in = [x for x in _Inputs(op, xs)] + out_grads
+ f_types = [x.dtype for x in _Inputs(op, xs)]
f = attr_value_pb2.NameAttrList()
f.name = op.type
for k in op.node_def.attr:
@@ -399,7 +409,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
return grad_fn()
-def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
"""Raises an error if we backprop through a loop var."""
# Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
# message.
@@ -413,7 +423,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
if curr_op in from_ops:
target_op = curr_op
break
- queue.extend(t.op for t in curr_op.inputs)
+ queue.extend(t.op for t in _Inputs(curr_op, xs))
assert target_op
raise ValueError(
"Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -423,6 +433,68 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
% target_op.name)
+def _MaybeCaptured(t):
+ """If t is a captured value placeholder, returns the original captured value.
+
+ Args:
+ t: Tensor
+
+ Returns:
+ A tensor, potentially from a different Graph/function._FuncGraph.
+ """
+ # pylint: disable=protected-access
+ if isinstance(t.op.graph, function._FuncGraph) and t.op.type == "Placeholder":
+ for input_t, placeholder_t in t.op.graph._captured.items():
+ if t == placeholder_t:
+ return _MaybeCaptured(input_t)
+ # pylint: enable=protected-access
+ return t
+
+
+# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
+# _GradientsHelper a class with xs as a member variable.
+def _Inputs(op, xs):
+ """Returns the inputs of op, crossing closure boundaries where necessary.
+
+ Args:
+ op: Operation
+ xs: list of Tensors we are differentiating w.r.t.
+
+ Returns:
+ A list of tensors. The tensors may be from multiple
+ Graph/function._FuncGraphs if op is in a function._FuncGraph and has
+ captured inputs.
+ """
+ if isinstance(op.graph, function._FuncGraph): # pylint: disable=protected-access
+ # If we're differentiating w.r.t. `t`, do not attempt to traverse through it
+ # to a captured value. The algorithm needs to "see" `t` in this case, even
+ # if it's a function input for a captured value, whereas usually we'd like
+ # to traverse through these closures as if the captured value was the direct
+ # input to op.
+ return [t if (t in xs) else _MaybeCaptured(t) for t in op.inputs]
+ else:
+ return op.inputs
+
+
+def _Consumers(t, func_graphs):
+ """Returns the consumers of t, crossing closure boundaries where necessary.
+
+ Args:
+ t: Tensor
+ func_graphs: a list of function._FuncGraphs that may have captured t.
+
+ Returns:
+ A list of tensors. The tensors will be from the current graph and/or
+ func_graphs.
+ """
+ consumers = t.consumers()
+ for func in func_graphs:
+ for input_t, placeholder in func._captured.items(): # pylint: disable=protected-access
+ if input_t == t:
+ consumers.extend(_Consumers(placeholder, func_graphs))
+ return consumers
+
+
@tf_export("gradients")
def gradients(ys,
xs,
@@ -532,6 +604,14 @@ def _GradientsHelper(ys,
if src_graph is None:
src_graph = ops.get_default_graph()
+ # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
+ # ancestor graphs. This is necessary for correctly handling captured values.
+ func_graphs = []
+ curr_graph = src_graph
+ while isinstance(curr_graph, function._FuncGraph): # pylint: disable=protected-access
+ func_graphs.append(curr_graph)
+ curr_graph = curr_graph._outer_graph # pylint: disable=protected-access
+
ys = _AsList(ys)
xs = _AsList(xs)
stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
@@ -566,12 +646,13 @@ def _GradientsHelper(ys,
# Initialize the pending count for ops in the connected subgraph from ys
# to the xs.
if len(ys) > 1:
- ys = [array_ops.identity(y) if y.consumers() else y for y in ys]
+ ys = [array_ops.identity(y) if _Consumers(y, func_graphs) else y
+ for y in ys]
to_ops = [t.op for t in ys]
from_ops = [t.op for t in xs]
stop_gradient_ops = [t.op for t in stop_gradients]
reachable_to_ops, pending_count, loop_state = _PendingCount(
- to_ops, from_ops, colocate_gradients_with_ops)
+ to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
# Iterate over the collected ops.
#
@@ -605,7 +686,7 @@ def _GradientsHelper(ys,
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
queue.append(y.op)
- stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
+ stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
@@ -654,7 +735,7 @@ def _GradientsHelper(ys,
op._control_flow_context.IsWhileContext() and
op._control_flow_context ==
ops.get_default_graph()._get_control_flow_context()):
- _RaiseNoGradWrtInitialLoopValError(op, from_ops)
+ _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
# pylint: enable=protected-access
if (grad_fn or is_func_call) and has_out_grads:
@@ -686,7 +767,7 @@ def _GradientsHelper(ys,
# For function call ops, we add a 'SymbolicGradient'
# node to the graph to compute gradients.
in_grads = _MaybeCompile(grad_scope, op, func_call,
- lambda: _SymGrad(op, out_grads))
+ lambda: _SymGrad(op, out_grads, xs))
in_grads = _AsList(in_grads)
_VerifyGeneratedGradients(in_grads, op)
if gate_gradients and len([x for x in in_grads
@@ -701,8 +782,8 @@ def _GradientsHelper(ys,
else:
# If no grad_fn is defined or none of out_grads is available,
# just propagate a list of None backwards.
- in_grads = [None] * len(op.inputs)
- for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)):
+ in_grads = [None] * len(_Inputs(op, xs))
+ for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
if in_grad is not None:
if (isinstance(in_grad, ops.Tensor) and
t_in.dtype != dtypes.resource):
@@ -720,7 +801,8 @@ def _GradientsHelper(ys,
loop_state.ExitGradWhileContext(op, before=False)
# Update pending count for the inputs of op and enqueue ready ops.
- _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state)
+ _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
+ xs)
if loop_state:
loop_state.PostProcessing()
@@ -739,9 +821,10 @@ def _HasAnyNotNoneGrads(grads, op):
return False
-def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
+def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
+ xs):
"""Update pending count for the inputs of op and enqueue ready ops."""
- for x in op.inputs:
+ for x in _Inputs(op, xs):
pending_count[x.op] -= 1
ready = (pending_count[x.op] == 0)
if loop_state and not ready:
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index d70cd088c9..d02fcf4ee2 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -437,6 +437,96 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
grad_func=grad_func, python_grad_func=self._PythonGradient)
f.add_to_graph(ops.Graph())
+ def testGradientWrtCaptured(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+
+ @function.Defun()
+ def Foo():
+ y = math_ops.multiply(x, 2.0, name="y")
+ g = gradients_impl.gradients(y, x)
+ return g[0]
+
+ f = Foo()
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(f), 2.0)
+
+ def testGradientOfCaptured(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+ y = math_ops.multiply(x, 2.0, name="y")
+
+ @function.Defun()
+ def Foo():
+ g = gradients_impl.gradients(y, x)
+ return g[0]
+
+ f = Foo()
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(f), 2.0)
+
+ def testCapturedResourceVariable(self):
+ with ops.Graph().as_default():
+ var = resource_variable_ops.ResourceVariable(1.0, name="var")
+
+ @function.Defun()
+ def Foo():
+ y = math_ops.multiply(var, 2.0, name="y")
+ g = gradients_impl.gradients(y, var)
+ return g[0]
+
+ f = Foo()
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ self.assertEqual(sess.run(f), 2.0)
+
+ def testCapturedNested(self):
+ with ops.Graph().as_default():
+ x1 = constant_op.constant(1.0, name="x1")
+ x2 = constant_op.constant(2.0, name="x2")
+ x3 = math_ops.multiply(x1, x2, name="x3")
+
+ @function.Defun()
+ def Outer():
+ outer1 = array_ops.identity(x1, name="outer1")
+
+ @function.Defun()
+ def Inner():
+ inner1 = array_ops.identity(outer1, name="inner1")
+ inner2 = array_ops.identity(x2, name="inner2")
+ inner3 = array_ops.identity(x3, name="inner3")
+ return gradients_impl.gradients([inner1, inner2, inner3, x1],
+ [x1, x2])
+
+ return Inner()
+
+ x1_grad, x2_grad = Outer()
+ with self.test_session() as sess:
+ # 1.0 + None + 2.0 + 1.0 = 4.0
+ self.assertEqual(sess.run(x1_grad), 4.0)
+ # None + 1.0 + 1.0 + None = 2.0
+ self.assertEqual(sess.run(x2_grad), 2.0)
+
+ def testCapturedFromFunction(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+
+ @function.Defun()
+ def Outer():
+ y = math_ops.multiply(x, 2.0, name="y")
+
+ @function.Defun()
+ def Inner():
+ z = math_ops.multiply(y, 3.0, name="z")
+ g = gradients_impl.gradients(z, y)
+ return g[0]
+
+ return Inner()
+
+ z_grad = Outer()
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(z_grad), 3.0)
+
class StopGradientTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 5eab12c41d..bfd225b0d8 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -73,7 +73,8 @@ def metric_variable(shape, dtype, validate_shape=True, name=None):
A (non-trainable) variable initialized to zero, or if inside a
`DistributionStrategy` scope a tower-local variable container.
"""
- with distribute_lib.get_tower_context().tower_local_var_scope('sum'):
+ with distribute_lib.get_tower_context().tower_local_var_scope(
+ variable_scope.VariableAggregation.SUM):
# Note that "tower local" implies trainable=False.
return variable_scope.variable(
lambda: array_ops.zeros(shape, dtype),
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 47414c28af..f862b62fad 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1,4 +1,4 @@
- # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -44,9 +44,11 @@ from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-__all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope",
- "get_variable", "get_local_variable", "variable_scope",
- "variable_op_scope", "no_regularizer"]
+__all__ = [
+ "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
+ "get_local_variable", "variable_scope", "variable_op_scope",
+ "no_regularizer", "VariableSynchronization", "VariableAggregation"
+]
class _PartitionInfo(object):
@@ -188,6 +190,38 @@ class _ReuseMode(enum.Enum):
# REUSE_FALSE = 2
# REUSE_TRUE = 3
+
+@tf_export("VariableSynchronization")
+class VariableSynchronization(enum.Enum):
+ """Indicates when a distributed variable will be synced."""
+
+ # Indicates that the synchronization will be determined by the current
+ # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
+ # `ON_WRITE`).
+ AUTO = 0
+
+ # Indicates that there will only be one copy of the variable, so there is no
+ # need to sync.
+ NONE = 1
+
+ # Indicates that the variable will be aggregated across devices
+ # every time it is updated.
+ ON_WRITE = 2
+
+ # Indicates that the variable will be aggregated across devices
+ # when it is read (eg. when checkpointing or when evaluating an op that uses
+ # the variable).
+ ON_READ = 3
+
+
+@tf_export("VariableAggregation")
+class VariableAggregation(enum.Enum):
+ """Indicates how a distributed variable will be aggregated."""
+ NONE = 0
+ SUM = 1
+ MEAN = 2
+
+
AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
@@ -214,11 +248,23 @@ class _VariableStore(object):
self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
self._store_eager_variables = False
- def get_variable(self, name, shape=None, dtype=dtypes.float32,
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- custom_getter=None, constraint=None):
+ def get_variable(self,
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ custom_getter=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with these parameters or create a new one.
If a variable with the given name is already stored, we return the stored
@@ -291,6 +337,14 @@ class _VariableStore(object):
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a
@@ -343,11 +397,22 @@ class _VariableStore(object):
# it to custom_getter.
# Note: the parameters of _true_getter, and their documentation, match
# *exactly* item-for-item with the docstring of this method.
- def _true_getter(name, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- constraint=None):
+ def _true_getter( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
is_scalar = (shape is not None
and isinstance(shape, collections_lib.Sequence)
and not shape)
@@ -397,11 +462,20 @@ class _VariableStore(object):
"name was already created with partitioning?" % name)
return self._get_single_variable(
- name=name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, reuse=reuse,
- trainable=trainable, collections=collections,
- caching_device=caching_device, validate_shape=validate_shape,
- use_resource=use_resource, constraint=constraint)
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
if custom_getter is not None:
# Handle backwards compatibility with getter arguments that were added
@@ -420,6 +494,8 @@ class _VariableStore(object):
"partitioner": partitioner,
"validate_shape": validate_shape,
"use_resource": use_resource,
+ "synchronization": synchronization,
+ "aggregation": aggregation,
}
# `fn_args` can handle functions, `functools.partial`, `lambda`.
if "constraint" in function_utils.fn_args(custom_getter):
@@ -427,12 +503,21 @@ class _VariableStore(object):
return custom_getter(**custom_getter_kwargs)
else:
return _true_getter(
- name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer,
- reuse=reuse, trainable=trainable, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- constraint=constraint)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
def _get_partitioned_variable(
self, name, partitioner, shape=None, dtype=dtypes.float32,
@@ -693,7 +778,9 @@ class _VariableStore(object):
caching_device=None,
validate_shape=True,
use_resource=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Get or create a single Variable (e.g. a shard or entire variable).
See the documentation of get_variable above (ignore partitioning components)
@@ -713,6 +800,8 @@ class _VariableStore(object):
validate_shape: see get_variable.
use_resource: see get_variable.
constraint: see get_variable.
+ synchronization: see get_variable.
+ aggregation: see get_variable.
Returns:
A Variable. See documentation of get_variable above.
@@ -793,7 +882,9 @@ class _VariableStore(object):
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
if context.executing_eagerly() and self._store_eager_variables:
if collections:
ops.add_to_collections(collections, v)
@@ -1052,7 +1143,9 @@ class VariableScope(object):
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with this name or create a new one."""
if regularizer is None:
regularizer = self._regularizer
@@ -1090,12 +1183,22 @@ class VariableScope(object):
if dtype is None:
dtype = self._dtype
return var_store.get_variable(
- full_name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, reuse=reuse, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ full_name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
def _get_partitioned_variable(self,
var_store,
@@ -1326,14 +1429,28 @@ def get_variable(name,
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
return get_variable_scope().get_variable(
- _get_default_variable_store(), name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ _get_default_variable_store(),
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+
get_variable_or_local_docstring = (
"""%s
@@ -1430,29 +1547,44 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
@tf_export("get_local_variable")
-def get_local_variable(name,
- shape=None,
- dtype=None,
- initializer=None,
- regularizer=None,
- trainable=False, # pylint: disable=unused-argument
- collections=None,
- caching_device=None,
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- custom_getter=None,
- constraint=None):
+def get_local_variable( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=False, # pylint: disable=unused-argument
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE,
+ custom_getter=None,
+ constraint=None):
if collections:
collections += [ops.GraphKeys.LOCAL_VARIABLES]
else:
collections = [ops.GraphKeys.LOCAL_VARIABLES]
return get_variable(
- name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, trainable=False, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- custom_getter=custom_getter, constraint=constraint)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=False,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
+ custom_getter=custom_getter,
+ constraint=constraint)
+
+
get_local_variable.__doc__ = get_variable_or_local_docstring % (
"Gets an existing *local* variable or creates a new one.",
"Behavior is the same as in `get_variable`, except that variables are\n"
@@ -2214,6 +2346,12 @@ def default_variable_creator(next_creator=None, **kwargs):
dtype = kwargs.get("dtype", None)
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
+
+ # Enforce `ON_READ` variables to be not trainable.
+ synchronization = kwargs.pop("synchronization", VariableSynchronization.AUTO)
+ if synchronization == VariableSynchronization.ON_READ:
+ trainable = False
+
if use_resource is None:
use_resource = get_variable_scope().use_resource
if use_resource or (use_resource is None and context.executing_eagerly()):
@@ -2248,18 +2386,28 @@ def variable(initial_value=None,
name=None,
dtype=None,
constraint=None,
- use_resource=None):
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
previous_getter = _make_getter(getter, previous_getter)
- return previous_getter(initial_value=initial_value,
- trainable=trainable,
- collections=collections,
- validate_shape=validate_shape,
- caching_device=caching_device,
- name=name, dtype=dtype,
- constraint=constraint,
- use_resource=use_resource)
+
+ # Reset `aggregation` that is explicitly set as `None` to the enum None value.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ collections=collections,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ dtype=dtype,
+ constraint=constraint,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
@tf_contextlib.contextmanager
@@ -2311,6 +2459,14 @@ def variable_creator_scope(variable_creator):
constraint: A constraint function to be applied to the variable after
updates by some algorithms.
use_resource: if True, a ResourceVariable is always created.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
This set may grow over time, so it's important the signature of creators is as
mentioned above.
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index d3172838a4..9a09cdaa52 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1723,6 +1723,8 @@ def report_uninitialized_variables(var_list=None,
var_list.append(op.outputs[0])
with ops.name_scope(name):
# Run all operations on CPU
+ if var_list:
+ init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
with ops.device("/cpu:0"):
if not var_list:
# Return an empty tensor so we only need to check for returned tensor
@@ -1730,9 +1732,7 @@ def report_uninitialized_variables(var_list=None,
return array_ops.constant([], dtype=dtypes.string)
else:
# Get a 1-D boolean tensor listing whether each variable is initialized.
- variables_mask = math_ops.logical_not(
- array_ops.stack(
- [state_ops.is_variable_initialized(v) for v in var_list]))
+ variables_mask = math_ops.logical_not(array_ops.stack(init_vars))
# Get a 1-D string tensor containing all the variable names.
variable_names_tensor = array_ops.constant(
[s.op.name for s in var_list])
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index 54f359489e..35007653a0 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -47,6 +47,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":base",
+ ":data_structures",
],
)
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 99c8098eca..e9c8c21905 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -33,6 +33,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saveable_object
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
+from tensorflow.python.util import tf_decorator
# Key where the object graph proto is saved in a TensorBundle
@@ -340,6 +341,34 @@ _SlotVariableRestoration = collections.namedtuple(
])
+def no_automatic_dependency_tracking(method):
+ """Disables automatic dependency tracking on attribute assignment.
+
+ Use to decorate any method of a Checkpointable object. Attribute assignment in
+ that method will not add dependencies (also respected in Model). Harmless if
+ used in a class which does not do automatic dependency tracking (which means
+ it's safe to use in base classes which may have subclasses which also inherit
+ from Checkpointable).
+
+ Args:
+ method: The method to decorate.
+ Returns:
+ A decorated method which sets and un-sets automatic dependency tracking for
+ the object the method is called on (not thread safe).
+ """
+
+ def _method_wrapper(self, *args, **kwargs):
+ previous_value = getattr(self, "_setattr_tracking", True)
+ self._setattr_tracking = False # pylint: disable=protected-access
+ try:
+ method(self, *args, **kwargs)
+ finally:
+ self._setattr_tracking = previous_value # pylint: disable=protected-access
+
+ return tf_decorator.make_decorator(
+ target=method, decorator_func=_method_wrapper)
+
+
class CheckpointableBase(object):
"""Base class for `Checkpointable` objects without automatic dependencies.
@@ -349,6 +378,11 @@ class CheckpointableBase(object):
checks.
"""
+ # CheckpointableBase does not do automatic dependency tracking, but uses the
+ # no_automatic_dependency_tracking decorator so it can avoid adding
+ # dependencies if a subclass is Checkpointable / inherits from Model (both of
+ # which have __setattr__ overrides).
+ @no_automatic_dependency_tracking
def _maybe_initialize_checkpointable(self):
"""Initialize dependency management.
@@ -386,6 +420,10 @@ class CheckpointableBase(object):
# building.
self._name_based_restores = set()
+ def _no_dependency(self, value):
+ """If automatic dependency tracking is enabled, ignores `value`."""
+ return value
+
def _name_based_attribute_restore(self, checkpoint):
"""Restore the object's attributes from a name-based checkpoint."""
self._name_based_restores.add(checkpoint)
@@ -733,28 +771,3 @@ class CheckpointableBase(object):
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
PythonStringStateSaveable,
state_callback=_state_callback)}
-
-
-class NoDependency(object):
- """Allows attribute assignment to `Checkpointable` objects with no dependency.
-
- Example usage:
- ```python
- obj = Checkpointable()
- obj.has_dependency = tf.Variable(0., name="dep")
- obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
- assert obj.no_dependency.name == "nodep:0"
- ```
-
- `obj` in this example has a dependency on the variable "dep", and both
- attributes contain un-wrapped `Variable` objects.
-
- `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
- dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
- `Layer` to the attribute without a checkpoint dependency, but the `Model` will
- still track the `Layer` (so it will appear in `Model.layers`, and its
- variables will appear in `Model.variables`).
- """
-
- def __init__(self, value):
- self.value = value
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index c46585b417..019d43f09c 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -22,49 +22,126 @@ import collections
import six
from tensorflow.python.ops import variables
-from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import layer_utils
-# TODO(allenl): We could track regular Python data structures which get assigned
-# to Checkpointable objects. Making this work with restore-on-create would be
-# tricky; we'd need to re-create nested structures with our own wrapped objects
-# on assignment to an attribute, and track the user's original structure to make
-# sure they don't modify it except through the wrappers (since we could save the
-# user's updated structure, but would have no way to support restore-on-create
-# for those modifications).
-# TODO(allenl): A dictionary data structure would be good too.
-class CheckpointableDataStructure(checkpointable_lib.CheckpointableBase):
+class NoDependency(object):
+ """Allows attribute assignment to `Checkpointable` objects with no dependency.
+
+ Example usage:
+ ```python
+ obj = Checkpointable()
+ obj.has_dependency = tf.Variable(0., name="dep")
+ obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
+ assert obj.no_dependency.name == "nodep:0"
+ ```
+
+ `obj` in this example has a dependency on the variable "dep", and both
+ attributes contain un-wrapped `Variable` objects.
+
+ `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
+ dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
+ `Layer` to the attribute without a checkpoint dependency, but the `Model` will
+ still track the `Layer` (so it will appear in `Model.layers`, and its
+ variables will appear in `Model.variables`).
+ """
+
+ def __init__(self, value):
+ self.value = value
+
+
+def _wrap_or_unwrap(value):
+ """Wraps basic data structures, unwraps NoDependency objects."""
+ if isinstance(value, NoDependency):
+ return value.value
+ if isinstance(value, base.CheckpointableBase):
+ return value # Skip conversion for already checkpointable objects.
+ elif isinstance(value, list):
+ return _ListWrapper(value)
+ else:
+ return value
+ # TODO(allenl): Handle other common data structures. Tuples will require
+ # special casing (tuple subclasses are not weak referenceable, so replacement
+ # with a wrapper that subclasses tuple on attribute assignment works poorly,
+ # and replacement with a wrapper that isn't a tuple is also problematic),
+ # probably a tree traversal where the leaves are non-tuples(/namedtuples) to
+ # come up with names. Dictionaries should look like lists.
+
+
+def sticky_attribute_assignment(checkpointable, name, value):
+ """Adds dependencies, generally called from __setattr__.
+
+ This behavior is shared between Checkpointable and Model.
+
+ Respects NoDependency indicators, but otherwise makes checkpointable objects
+ out of common data structures and tracks objects by their attribute names.
+
+ Args:
+ checkpointable: The object to add dependencies to (generally the one having
+ an attribute assigned).
+ name: The attribute name being assigned.
+ value: The value being assigned. Not necessarily a checkpointable object.
+
+ Returns:
+ The value which should be stored in the attribute (unwrapped from a
+ NoDependency object if necessary).
+ """
+ if isinstance(value, NoDependency):
+ add_dependency = False
+ else:
+ add_dependency = True
+ value = _wrap_or_unwrap(value)
+ if not add_dependency:
+ return value
+ if isinstance(value, base.CheckpointableBase):
+ checkpointable._track_checkpointable( # pylint: disable=protected-access
+ value, name=name,
+ # Allow the user to switch the Checkpointable which is tracked by this
+ # name, since assigning a new variable to an attribute has
+ # historically been fine (e.g. Adam did this).
+ overwrite=True)
+ return value
+
+
+class CheckpointableDataStructure(base.CheckpointableBase):
"""Base class for data structures which contain checkpointable objects."""
def __init__(self):
+ # An append-only ordered set
self._layers = []
+
self.trainable = True
self._extra_variables = []
def _track_value(self, value, name):
"""Add a dependency on `value`."""
- if isinstance(value, checkpointable_lib.CheckpointableBase):
- self._track_checkpointable(value, name=name)
- if isinstance(value, variables.Variable):
- self._extra_variables.append(value)
- else:
+ value = sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+ if isinstance(value, variables.Variable):
+ self._extra_variables.append(value)
+ if not isinstance(value, base.CheckpointableBase):
raise ValueError(
("Only checkpointable objects (such as Layers or Optimizers) may be "
"stored in a List object. Got %s, which does not inherit from "
"CheckpointableBase.") % (value,))
if (isinstance(value, CheckpointableDataStructure)
or layer_utils.is_layer(value)):
- if value not in self._layers:
+ # Check for object-identity rather than with __eq__ to avoid
+ # de-duplicating empty container types. Automatically generated list
+ # wrappers keep things like "[] == []" true, which means "[] in [[]]" is
+ # also true. This becomes not true once one of the lists is mutated.
+ if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, "_use_resource_variables"):
# In subclassed models, legacy layers (tf.layers) must always use
# resource variables.
value._use_resource_variables = True # pylint: disable=protected-access
+ return value
@property
def layers(self):
- return self._layers
+ return layer_utils.filter_empty_layer_containers(self._layers)
@property
def trainable_weights(self):
@@ -164,24 +241,28 @@ class List(CheckpointableDataStructure, collections.Sequence):
def __init__(self, *args, **kwargs):
"""Construct a new sequence. Arguments are passed to `list()`."""
super(List, self).__init__()
- self._storage = list(*args, **kwargs)
+ self._storage = self._make_storage(*args, **kwargs)
for index, element in enumerate(self._storage):
- self._track_value(element, name=self._name_element(index))
+ self._storage[index] = self._track_value(
+ element, name=self._name_element(index))
+
+ def _make_storage(self, *args, **kwargs):
+ """Determines the backing storage (overridden in subclasses)."""
+ return list(*args, **kwargs)
def _name_element(self, index):
return "%d" % (index,)
def append(self, value):
"""Add a new checkpointable value."""
- self._track_value(value, self._name_element(len(self._storage)))
+ value = self._track_value(value, self._name_element(len(self._storage)))
self._storage.append(value)
def extend(self, values):
"""Add a sequence of checkpointable values."""
- for index_offset, value in enumerate(values):
- self._track_value(
- value, name=self._name_element(len(self._storage) + index_offset))
- self._storage.extend(values)
+ for value in values:
+ self._storage.append(self._track_value(
+ value, name=self._name_element(len(self._storage))))
def __iadd__(self, values):
self.extend(values)
@@ -189,9 +270,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
def __add__(self, other):
if isinstance(other, List):
- return List(self._storage + other._storage) # pylint: disable=protected-access
+ return self.__class__(self._storage + other._storage) # pylint: disable=protected-access
else:
- return List(self._storage + other)
+ return self.__class__(self._storage + other)
+
+ def __radd__(self, other):
+ return self + other
def __getitem__(self, key):
return self._storage[key]
@@ -203,6 +287,144 @@ class List(CheckpointableDataStructure, collections.Sequence):
return "List(%s)" % (repr(self._storage),)
+class _ListWrapper(List, collections.MutableSequence,
+ # Shadowed, but there for isinstance checks.
+ list):
+ """Wraps the built-in `list` to support restore-on-create for variables.
+
+ Unlike `List`, this sequence type is mutable in the same ways built-in lists
+ are. Instead of throwing an error immediately like `List`, it records
+ problematic mutations (e.g. assigning a new element to a position already
+ occupied, meaning both elements get the same names at different times) and
+ refuses to save.
+
+ On assignment to an attribute of a Model or Checkpointable object, Python
+ lists are replaced with _ListWrapper. Wrapping a list in a
+ `tf.contrib.checkpoint.NoDependency` object prevents this.
+ """
+
+ def __init__(self, wrapped_list):
+ """Construct a new list wrapper.
+
+ Args:
+ wrapped_list: The initial value of the data structure. A shallow copy may
+ be maintained for error checking. `wrapped_list` itself should not be
+ modified directly after constructing the `_ListWrapper`, and if changes
+ are detected the `_ListWrapper` will throw an exception on save.
+ """
+ # Monotonic flags which indicate this object would not be restored properly,
+ # and therefore should throw an error on save to avoid giving the impression
+ # that restoring it will work.
+ self._non_append_mutation = False
+ self._external_modification = False
+ super(_ListWrapper, self).__init__(wrapped_list)
+ self._last_wrapped_list_snapshot = list(self._storage)
+
+ def _make_storage(self, wrapped_list):
+ """Use the user's original list for storage."""
+ return wrapped_list
+
+ def _check_external_modification(self):
+ """Checks for any changes to the wrapped list not through the wrapper."""
+ if self._external_modification or self._non_append_mutation:
+ return
+ if self._storage != self._last_wrapped_list_snapshot:
+ self._external_modification = True
+ self._last_wrapped_list_snapshot = None
+
+ def _update_snapshot(self):
+ """Acknowledges tracked changes to the wrapped list."""
+ if self._external_modification or self._non_append_mutation:
+ return
+ self._last_wrapped_list_snapshot = list(self._storage)
+
+ @property
+ def _checkpoint_dependencies(self):
+ self._check_external_modification()
+ if self._non_append_mutation:
+ raise ValueError(
+ ("Unable to save the object %s (a list wrapper constructed to track "
+ "checkpointable TensorFlow objects). A list element was replaced "
+ "(__setitem__), deleted, or inserted. In order to support "
+ "restoration on object creation, tracking is exclusively for "
+ "append-only data structures.\n\nIf you don't need this list "
+ "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
+ "object; it will be automatically un-wrapped and subsequently "
+ "ignored." % (self,)))
+ if self._external_modification:
+ raise ValueError(
+ ("Unable to save the object %s (a list wrapper constructed to track "
+ "checkpointable TensorFlow objects). The wrapped list was modified "
+ "outside the wrapper (its final value was %s, its value when a "
+ "checkpoint dependency was added was %s), which breaks restoration "
+ "on object creation.\n\nIf you don't need this list checkpointed, "
+ "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be "
+ "automatically un-wrapped and subsequently ignored." % (
+ self, self._storage, self._last_wrapped_list_snapshot)))
+ return super(_ListWrapper, self)._checkpoint_dependencies
+
+ def __delitem__(self, key):
+ self._non_append_mutation = True
+ del self._storage[key]
+
+ def __setitem__(self, key, value):
+ self._non_append_mutation = True
+ self._storage[key] = value
+
+ def append(self, value):
+ """Add a new checkpointable value."""
+ self._check_external_modification()
+ super(_ListWrapper, self).append(value)
+ self._update_snapshot()
+
+ def extend(self, values):
+ """Add a sequence of checkpointable values."""
+ self._check_external_modification()
+ super(_ListWrapper, self).extend(values)
+ self._update_snapshot()
+
+ def __eq__(self, other):
+ return self._storage == getattr(other, "_storage", other)
+
+ def __ne__(self, other):
+ return self._storage != getattr(other, "_storage", other)
+
+ def __lt__(self, other):
+ return self._storage < getattr(other, "_storage", other)
+
+ def __le__(self, other):
+ return self._storage <= getattr(other, "_storage", other)
+
+ def __gt__(self, other):
+ return self._storage > getattr(other, "_storage", other)
+
+ def __ge__(self, other):
+ return self._storage >= getattr(other, "_storage", other)
+
+ def __hash__(self):
+ # List wrappers need to compare like regular lists, and so like regular
+ # lists they don't belong in hash tables.
+ raise TypeError("unhashable type: 'ListWrapper'")
+
+ def insert(self, index, obj):
+ self._non_append_mutation = True
+ self._storage.insert(index, obj)
+
+ def _track_value(self, value, name):
+ """Allows storage of non-checkpointable objects."""
+ try:
+ value = super(_ListWrapper, self)._track_value(value=value, name=name)
+ except ValueError:
+ # Even if this value isn't checkpointable, we need to make sure
+ # NoDependency objects get unwrapped.
+ value = sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+ return value
+
+ def __repr__(self):
+ return "ListWrapper(%s)" % (repr(self._storage),)
+
+
class Mapping(CheckpointableDataStructure, collections.Mapping):
"""An append-only checkpointable mapping data structure with string keys.
@@ -217,8 +439,10 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
"""Construct a new sequence. Arguments are passed to `dict()`."""
super(Mapping, self).__init__()
self._storage = dict(*args, **kwargs)
- for key, value in self._storage.items():
- self._track_value(value, name=self._name_element(key))
+ self._storage.update(
+ {key: self._track_value(
+ value, name=self._name_element(key))
+ for key, value in self._storage.items()})
def _name_element(self, key):
if not isinstance(key, six.string_types):
@@ -228,13 +452,14 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
return str(key)
def __setitem__(self, key, value):
+ name = self._name_element(key)
+ value = self._track_value(value, name=name)
current_value = self._storage.setdefault(key, value)
if current_value is not value:
raise ValueError(
("Mappings are an append-only data structure. Tried to overwrite the "
"key '%s' with value %s, but it already contains %s")
% (key, value, current_value))
- self._track_value(value, name=self._name_element(key))
def update(self, *args, **kwargs):
for key, value in dict(*args, **kwargs).items():
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index ce5852dd6e..ec8c9da809 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import tracking
class HasList(training.Model):
@@ -113,6 +114,19 @@ class ListTests(test.TestCase):
model(model_input)
self.assertEqual(2, len(model.losses))
+ def testModelContainersCompareEqual(self):
+ class HasEqualContainers(training.Model):
+
+ def __init__(self):
+ super(HasEqualContainers, self).__init__()
+ self.l1 = []
+ self.l2 = []
+
+ model = HasEqualContainers()
+ model.l1.append(HasEqualContainers())
+ model.l2.append(HasEqualContainers())
+ self.assertEqual([model.l1, model.l2], model.layers)
+
def testNotCheckpointable(self):
class NotCheckpointable(object):
pass
@@ -158,11 +172,62 @@ class ListTests(test.TestCase):
self.assertEqual([v], l.trainable_weights)
self.assertEqual([v2], l.non_trainable_weights)
+ def testListWrapperBasic(self):
+ # _ListWrapper, unlike List, compares like the built-in list type (since it
+ # is used to automatically replace lists).
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ self.assertEqual([a, a],
+ [a, a])
+ self.assertEqual(data_structures._ListWrapper([a, a]),
+ data_structures._ListWrapper([a, a]))
+ self.assertEqual([a, a],
+ data_structures._ListWrapper([a, a]))
+ self.assertEqual(data_structures._ListWrapper([a, a]),
+ [a, a])
+ self.assertNotEqual([a, a],
+ [b, a])
+ self.assertNotEqual(data_structures._ListWrapper([a, a]),
+ data_structures._ListWrapper([b, a]))
+ self.assertNotEqual([a, a],
+ data_structures._ListWrapper([b, a]))
+ self.assertLess([a], [a, b])
+ self.assertLess(data_structures._ListWrapper([a]),
+ data_structures._ListWrapper([a, b]))
+ self.assertLessEqual([a], [a, b])
+ self.assertLessEqual(data_structures._ListWrapper([a]),
+ data_structures._ListWrapper([a, b]))
+ self.assertGreater([a, b], [a])
+ self.assertGreater(data_structures._ListWrapper([a, b]),
+ data_structures._ListWrapper([a]))
+ self.assertGreaterEqual([a, b], [a])
+ self.assertGreaterEqual(data_structures._ListWrapper([a, b]),
+ data_structures._ListWrapper([a]))
+ self.assertEqual([a], data_structures._ListWrapper([a]))
+ self.assertEqual([a], list(data_structures.List([a])))
+ self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a])
+ self.assertEqual([a, a], [a] + data_structures._ListWrapper([a]))
+ self.assertIsInstance(data_structures._ListWrapper([a]), list)
+
+ def testWrapperChangesList(self):
+ l = []
+ l_wrapper = data_structures._ListWrapper(l)
+ l_wrapper.append(1)
+ self.assertEqual([1], l)
+
+ def testListChangesWrapper(self):
+ l = []
+ l_wrapper = data_structures._ListWrapper(l)
+ l.append(1)
+ self.assertEqual([1], l_wrapper)
+
def testHashing(self):
has_sequences = set([data_structures.List(),
data_structures.List()])
self.assertEqual(2, len(has_sequences))
self.assertNotIn(data_structures.List(), has_sequences)
+ with self.assertRaises(TypeError):
+ has_sequences.add(data_structures._ListWrapper([]))
class HasMapping(training.Model):
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py
index fdcf963d32..978fcb2252 100644
--- a/tensorflow/python/training/checkpointable/layer_utils.py
+++ b/tensorflow/python/training/checkpointable/layer_utils.py
@@ -30,6 +30,14 @@ def is_layer(obj):
and hasattr(obj, "variables"))
+def filter_empty_layer_containers(layer_list):
+ """Filter out empty Layer-like containers."""
+ return [layer for layer in layer_list
+ # Filter out only empty Checkpointable data structures. Empty Networks
+ # will still show up in Model.layers.
+ if is_layer(layer) or getattr(layer, "layers", True)]
+
+
def gather_trainable_weights(trainable, sub_layers, extra_variables):
"""Lists the trainable weights for an object with sub-layers.
diff --git a/tensorflow/python/training/checkpointable/tracking.py b/tensorflow/python/training/checkpointable/tracking.py
index 00e14ac982..bd0bed9d46 100644
--- a/tensorflow/python/training/checkpointable/tracking.py
+++ b/tensorflow/python/training/checkpointable/tracking.py
@@ -18,31 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.training.checkpointable import base
-
-
-class NoDependency(object):
- """Allows attribute assignment to `Checkpointable` objects with no dependency.
-
- Example usage:
- ```python
- obj = Checkpointable()
- obj.has_dependency = tf.Variable(0., name="dep")
- obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
- assert obj.no_dependency.name == "nodep:0"
- ```
-
- `obj` in this example has a dependency on the variable "dep", and both
- attributes contain un-wrapped `Variable` objects.
-
- `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
- dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
- `Layer` to the attribute without a checkpoint dependency, but the `Model` will
- still track the `Layer` (so it will appear in `Model.layers`, and its
- variables will appear in `Model.variables`).
- """
-
- def __init__(self, value):
- self.value = value
+from tensorflow.python.training.checkpointable import data_structures
class NotCheckpointable(object):
@@ -86,18 +62,11 @@ class Checkpointable(base.CheckpointableBase):
def __setattr__(self, name, value):
"""Support self.foo = checkpointable syntax."""
- # Perform the attribute assignment, and potentially call other __setattr__
- # overrides such as that for tf.keras.Model.
- no_dependency = isinstance(value, NoDependency)
- if no_dependency:
- value = value.value
+ if getattr(self, "_setattr_tracking", True):
+ value = data_structures.sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
super(Checkpointable, self).__setattr__(name, value)
- if not no_dependency and isinstance(value, base.CheckpointableBase):
- self._track_checkpointable(
- value, name=name,
- # Allow the user to switch the Checkpointable which is tracked by this
- # name, since assigning a new variable to an attribute has
- # historically been fine (e.g. Adam did this).
- # TODO(allenl): Should this be a warning once Checkpointable save/load
- # is usable?
- overwrite=True)
+
+ def _no_dependency(self, value):
+ """Override to allow CheckpointableBase to disable dependency tracking."""
+ return data_structures.NoDependency(value)
diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py
index baf6f57efb..f0178b074d 100644
--- a/tensorflow/python/training/checkpointable/tracking_test.py
+++ b/tensorflow/python/training/checkpointable/tracking_test.py
@@ -16,8 +16,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
+import numpy
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
+from tensorflow.python.util import nest
class InterfaceTests(test.TestCase):
@@ -27,7 +38,7 @@ class InterfaceTests(test.TestCase):
root.leaf = tracking.Checkpointable()
root.leaf = root.leaf
duplicate_name_dep = tracking.Checkpointable()
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(ValueError, "already declared"):
root._track_checkpointable(duplicate_name_dep, name="leaf")
# No error; we're overriding __setattr__, so we can't really stop people
# from doing this while maintaining backward compatibility.
@@ -39,11 +50,119 @@ class InterfaceTests(test.TestCase):
hasdep = tracking.Checkpointable()
root.hasdep = hasdep
nodep = tracking.Checkpointable()
- root.nodep = tracking.NoDependency(nodep)
+ root.nodep = data_structures.NoDependency(nodep)
self.assertEqual(1, len(root._checkpoint_dependencies))
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
self.assertIs(root.hasdep, hasdep)
self.assertIs(root.nodep, nodep)
+ class NoDependencyModel(training.Model):
+
+ @base.no_automatic_dependency_tracking
+ def __init__(self):
+ super(NoDependencyModel, self).__init__()
+ self.a = []
+ self.b = tracking.Checkpointable()
+
+ nodeps = NoDependencyModel()
+ self.assertEqual([nodeps], util.list_objects(nodeps))
+
+ def testListBasic(self):
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ a.l = [b]
+ c = tracking.Checkpointable()
+ a.l.append(c)
+ a_deps = util.list_objects(a)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ direct_a_dep, = a._checkpoint_dependencies
+ self.assertEqual("l", direct_a_dep.name)
+ self.assertIn(b, direct_a_dep.ref)
+ self.assertIn(c, direct_a_dep.ref)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testMutationDirtiesList(self):
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ a.l = [b]
+ c = tracking.Checkpointable()
+ a.l.insert(0, c)
+ checkpoint = util.Checkpoint(a=a)
+ with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testOutOfBandEditDirtiesList(self):
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ held_reference = [b]
+ a.l = held_reference
+ c = tracking.Checkpointable()
+ held_reference.append(c)
+ checkpoint = util.Checkpoint(a=a)
+ with self.assertRaisesRegexp(ValueError, "The wrapped list was modified"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNestedLists(self):
+ a = tracking.Checkpointable()
+ a.l = []
+ b = tracking.Checkpointable()
+ a.l.append([b])
+ c = tracking.Checkpointable()
+ a.l[0].append(c)
+ a_deps = util.list_objects(a)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ a.l[0].append(1)
+ d = tracking.Checkpointable()
+ a.l[0].append(d)
+ a_deps = util.list_objects(a)
+ self.assertIn(d, a_deps)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ self.assertNotIn(1, a_deps)
+ e = tracking.Checkpointable()
+ f = tracking.Checkpointable()
+ a.l1 = [[], [e]]
+ a.l1[0].append(f)
+ a_deps = util.list_objects(a)
+ self.assertIn(e, a_deps)
+ self.assertIn(f, a_deps)
+ checkpoint = util.Checkpoint(a=a)
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ a.l[0].append(data_structures.NoDependency([]))
+ a.l[0][-1].append(5)
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ # Dirtying the inner list means the root object is unsaveable.
+ a.l[0][1] = 2
+ with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoDepList(self):
+ a = training.Model()
+ a.l1 = data_structures.NoDependency([])
+ a.l1.insert(1, 0)
+ self.assertTrue(isinstance(a.l1, list))
+ checkpoint = util.Checkpoint(a=a)
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ a.l2 = []
+ a.l2.insert(1, 0)
+ with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAssertions(self):
+ a = tracking.Checkpointable()
+ a.l = [numpy.zeros([2, 2])]
+ self.assertAllEqual([numpy.zeros([2, 2])], a.l)
+ self.assertAllClose([numpy.zeros([2, 2])], a.l)
+ nest.map_structure(self.assertAllClose, a.l, [numpy.zeros([2, 2])])
+ a.tensors = [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]
+ self.assertAllClose([numpy.ones([2, 2]), numpy.zeros([3, 3])],
+ self.evaluate(a.tensors))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index e0f61137b1..6ae5765b13 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -40,6 +40,7 @@ from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import saveable_object as saveable_object_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_contextlib
@@ -93,7 +94,7 @@ class _CheckpointRestoreCoordinator(object):
# use them (for example because of inconsistent references when
# loading). Used to make status assertions fail when loading checkpoints
# that don't quite match.
- self.all_python_objects = weakref.WeakSet()
+ self.all_python_objects = _ObjectIdentityWeakSet()
self.save_path = save_path
self.dtype_map = dtype_map
# When graph building, contains a list of ops to run to restore objects from
@@ -272,11 +273,129 @@ def object_metadata(save_path):
return object_graph_proto
+class _ObjectIdentityWrapper(object):
+ """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
+
+ Since __eq__ is based on object identity, it's safe to also define __hash__
+ based on object ids. This lets us add unhashable types like checkpointable
+ _ListWrapper objects to object-identity collections.
+ """
+
+ def __init__(self, wrapped):
+ self._wrapped = wrapped
+
+ @property
+ def unwrapped(self):
+ return self._wrapped
+
+ def __eq__(self, other):
+ if isinstance(other, _ObjectIdentityWrapper):
+ return self._wrapped is other._wrapped # pylint: disable=protected-access
+ return self._wrapped is other
+
+ def __hash__(self):
+ # Wrapper id() is also fine for weakrefs. In fact, we rely on
+ # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
+ # weakref.ref(a) in _WeakObjectIdentityWrapper.
+ return id(self._wrapped)
+
+
+class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
+
+ def __init__(self, wrapped):
+ super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped))
+
+ @property
+ def unwrapped(self):
+ return self._wrapped()
+
+
+class _ObjectIdentityDictionary(collections.MutableMapping):
+ """A mutable mapping data structure which compares using "is".
+
+ This is necessary because we have checkpointable objects (_ListWrapper) which
+ have behavior identical to built-in Python lists (including being unhashable
+ and comparing based on the equality of their contents by default).
+ """
+
+ def __init__(self):
+ self._storage = {}
+
+ def _wrap_key(self, key):
+ return _ObjectIdentityWrapper(key)
+
+ def __getitem__(self, key):
+ return self._storage[self._wrap_key(key)]
+
+ def __setitem__(self, key, value):
+ self._storage[self._wrap_key(key)] = value
+
+ def __delitem__(self, key):
+ del self._storage[self._wrap_key(key)]
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __iter__(self):
+ for key in self._storage:
+ yield key.unwrapped
+
+
+class _ObjectIdentityWeakKeyDictionary(_ObjectIdentityDictionary):
+ """Like weakref.WeakKeyDictionary, but compares objects with "is"."""
+
+ def _wrap_key(self, key):
+ return _WeakObjectIdentityWrapper(key)
+
+ def __len__(self):
+ # Iterate, discarding old weak refs
+ return len(list(self._storage))
+
+ def __iter__(self):
+ keys = self._storage.keys()
+ for key in keys:
+ unwrapped = key.unwrapped
+ if unwrapped is None:
+ del self[key]
+ else:
+ yield unwrapped
+
+
+class _ObjectIdentityWeakSet(collections.MutableSet):
+ """Like weakref.WeakSet, but compares objects with "is"."""
+
+ def __init__(self):
+ self._storage = set()
+
+ def __contains__(self, key):
+ return _WeakObjectIdentityWrapper(key) in self._storage
+
+ def discard(self, key):
+ self._storage.discard(_WeakObjectIdentityWrapper(key))
+
+ def add(self, key):
+ self._storage.add(_WeakObjectIdentityWrapper(key))
+
+ def __len__(self):
+ # Iterate, discarding old weak refs
+ return len(list(self))
+
+ def __iter__(self):
+ keys = list(self._storage)
+ for key in keys:
+ unwrapped = key.unwrapped
+ if unwrapped is None:
+ self.discard(key)
+ else:
+ yield unwrapped
+
+
def _breadth_first_checkpointable_traversal(root_checkpointable):
"""Find shortest paths to all variables owned by dependencies of root."""
bfs_sorted = []
to_visit = collections.deque([root_checkpointable])
- path_to_root = {root_checkpointable: ()}
+ path_to_root = _ObjectIdentityDictionary()
+ path_to_root[root_checkpointable] = ()
while to_visit:
current_checkpointable = to_visit.popleft()
if isinstance(current_checkpointable, tracking.NotCheckpointable):
@@ -337,7 +456,7 @@ def _slot_variable_naming_for_optimizer(optimizer_path):
def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
"""Gather and name slot variables."""
non_slot_objects = list(checkpointable_objects)
- slot_variables = {}
+ slot_variables = _ObjectIdentityDictionary()
for checkpointable in non_slot_objects:
if isinstance(checkpointable, optimizer_lib.Optimizer):
naming_scheme = _slot_variable_naming_for_optimizer(
@@ -500,11 +619,12 @@ def _serialize_object_graph(root_checkpointable, saveables_cache):
"""
checkpointable_objects, path_to_root = (
_breadth_first_checkpointable_traversal(root_checkpointable))
- object_names = {
- obj: _object_prefix_from_path(path)
- for obj, path in path_to_root.items()}
- node_ids = {node: node_id for node_id, node
- in enumerate(checkpointable_objects)}
+ object_names = _ObjectIdentityDictionary()
+ for obj, path in path_to_root.items():
+ object_names[obj] = _object_prefix_from_path(path)
+ node_ids = _ObjectIdentityDictionary()
+ for node_id, node in enumerate(checkpointable_objects):
+ node_ids[node] = node_id
slot_variables = _serialize_slot_variables(
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
@@ -535,11 +655,12 @@ def list_objects(root_checkpointable):
# to run.
checkpointable_objects, path_to_root = (
_breadth_first_checkpointable_traversal(root_checkpointable))
- object_names = {
- obj: _object_prefix_from_path(path)
- for obj, path in path_to_root.items()}
- node_ids = {node: node_id for node_id, node
- in enumerate(checkpointable_objects)}
+ object_names = _ObjectIdentityDictionary()
+ for obj, path in path_to_root.items():
+ object_names[obj] = _object_prefix_from_path(path)
+ node_ids = _ObjectIdentityDictionary()
+ for node_id, node in enumerate(checkpointable_objects):
+ node_ids[node] = node_id
_serialize_slot_variables(
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
@@ -988,7 +1109,7 @@ class CheckpointableSaver(object):
else:
# Maps Checkpointable objects -> attribute names -> SaveableObjects, to
# avoid re-creating SaveableObjects when graph building.
- self._saveable_object_cache = weakref.WeakKeyDictionary()
+ self._saveable_object_cache = _ObjectIdentityWeakKeyDictionary()
@property
def _root_checkpointable(self):
@@ -1310,7 +1431,7 @@ class Checkpoint(tracking.Checkpointable):
with ops.device("/cpu:0"):
# add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter".
- self._save_counter = tracking.NoDependency(
+ self._save_counter = data_structures.NoDependency(
add_variable(self, name="save_counter", initializer=0,
dtype=dtypes.int64))
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 6a326b65bb..562ad3bb02 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -221,11 +221,12 @@ def has_distribution_strategy():
def get_loss_reduction():
- """Reduce `method_string` corresponding to the last loss reduction."""
+ """Reduce `aggregation` corresponding to the last loss reduction."""
loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
+ print(loss_reduction)
if loss_reduction == losses_impl.Reduction.SUM:
- return "sum"
- return "mean"
+ return variable_scope.VariableAggregation.SUM
+ return variable_scope.VariableAggregation.MEAN
# ------------------------------------------------------------------------------
@@ -539,8 +540,8 @@ class DistributionStrategy(object):
1. Wrap your input dataset in `d.distribute_dataset()` and create an iterator.
2. Define each tower `d.call_for_each_tower()` up to the point of
getting a list of gradient, variable pairs.
- 3. Call `d.reduce("sum", t, v)` or `d.batch_reduce()` to sum the
- gradients (with locality T) into values with locality V(`v`).
+ 3. Call `d.reduce(VariableAggregation.SUM, t, v)` or `d.batch_reduce()` to sum
+ the gradients (with locality T) into values with locality V(`v`).
4. Call `d.update(v)` for each variable to update its value.
Steps 3 and 4 are done automatically by class `Optimizer` if you call
@@ -614,7 +615,7 @@ class DistributionStrategy(object):
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
- def tower_local_var_scope(self, reduce_method):
+ def tower_local_var_scope(self, aggregation):
"""Inside this scope, new variables will not be mirrored.
There will still be one component variable per tower, but there is
@@ -636,16 +637,21 @@ class DistributionStrategy(object):
random numbers.
Args:
- reduce_method: String used as a `method_string` to `reduce()`
- to get the value to save when checkpointing.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
Returns:
A context manager.
"""
+ # TODO(psv): Remove this after adding support for synchronization and
+ # aggregation parameters in get_variable() and mirrored strategy.
def create_tower_local_variable(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
kwargs["use_resource"] = True
- kwargs["tower_local_reduce_method"] = reduce_method
+
+ # Set synchronization to be ON_READ for tower local variables.
+ kwargs["synchronization"] = variable_scope.VariableSynchronization.ON_READ
+ kwargs["aggregation"] = aggregation
return next_creator(*args, **kwargs)
_require_distribution_strategy_scope(self)
@@ -816,12 +822,12 @@ class DistributionStrategy(object):
def _call_for_each_tower(self, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
- def reduce(self, method_string, value, destinations=None):
+ def reduce(self, aggregation, value, destinations=None):
"""Combine (via e.g. sum or mean) values across towers.
Args:
- method_string: A string indicating how to combine values, either
- "sum" or "mean".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value: A per-device value with one value per tower.
destinations: An optional mirrored variable, a device string,
list of device strings. The return value will be copied to all
@@ -836,18 +842,21 @@ class DistributionStrategy(object):
# TODO(josh11b): Return an unwrapped value if colocate_with is a
# single device.
_require_cross_tower_context(self)
- assert method_string in ("sum", "mean")
- return self._reduce(method_string, value, destinations)
+ assert aggregation in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]
+ return self._reduce(aggregation, value, destinations)
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
raise NotImplementedError("must be implemented in descendants")
- def batch_reduce(self, method_string, value_destination_pairs):
+ def batch_reduce(self, aggregation, value_destination_pairs):
"""Combine multiple `reduce` calls into one for faster execution.
Args:
- method_string: A string indicating how to combine values, either
- "sum" or "mean".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value_destination_pairs: A sequence of (value, destinations)
pairs. See `reduce()` for a description.
@@ -856,12 +865,17 @@ class DistributionStrategy(object):
"""
# TODO(josh11b): More docstring
_require_cross_tower_context(self)
- assert method_string in ("sum", "mean")
- return self._batch_reduce(method_string, value_destination_pairs)
-
- def _batch_reduce(self, method_string, value_destination_pairs):
- return [self.reduce(method_string, t, destinations=v)
- for t, v in value_destination_pairs]
+ assert aggregation in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]
+ return self._batch_reduce(aggregation, value_destination_pairs)
+
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self.reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def update(self, var, fn, *args, **kwargs):
"""Run `fn` to update `var` using inputs mirrored to the same devices.
@@ -1090,9 +1104,9 @@ class TowerContext(object):
finally:
_pop_per_thread_mode()
- def tower_local_var_scope(self, reduce_method):
+ def tower_local_var_scope(self, aggregation):
"""Alias for distribution_strategy.tower_local_var_scope()."""
- return self._distribution_strategy.tower_local_var_scope(reduce_method)
+ return self._distribution_strategy.tower_local_var_scope(aggregation)
@property
def is_single_tower(self):
@@ -1140,13 +1154,12 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
- kwargs.pop("tower_local_reduce_method", None)
return next_creator(*args, **kwargs)
return _CurrentDistributionContext(
self, variable_scope.variable_creator_scope(creator))
- def tower_local_var_scope(self, reduce_method):
+ def tower_local_var_scope(self, aggregation):
"""Does not set to resource variables."""
def create_tower_local_variable(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
@@ -1176,9 +1189,9 @@ class _DefaultDistributionStrategy(DistributionStrategy):
with TowerContext(self, tower_id=0):
return fn(*args, **kwargs)
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
# TODO(josh11b): Use destinations?
- del method_string, destinations
+ del aggregation, destinations
return value
def _update(self, var, fn, *args, **kwargs):
diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py
index 0a4f19c31f..694145ede7 100644
--- a/tensorflow/python/training/distribute_test.py
+++ b/tensorflow/python/training/distribute_test.py
@@ -29,6 +29,14 @@ class _TestTowerContext(distribute.TowerContext):
return kwargs["test_arg"]
+def _get_test_variable(name, synchronization, aggregation):
+ return {
+ "name": name,
+ "synchronization": synchronization,
+ "aggregation": aggregation
+ }
+
+
class _TestStrategy(distribute.DistributionStrategy):
def _call_for_each_tower(self, fn, *args, **kwargs):
@@ -36,7 +44,8 @@ class _TestStrategy(distribute.DistributionStrategy):
return fn(*args, **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):
- return kwargs["name"]
+ return _get_test_variable(kwargs["name"], kwargs["synchronization"],
+ kwargs["aggregation"])
def _assert_in_default_state(t):
@@ -61,7 +70,11 @@ class TestStrategyTest(test.TestCase):
self.assertTrue(distribute.has_distribution_strategy())
self.assertIs(dist, distribute.get_distribution_strategy())
self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
- self.assertEqual("bar", variable_scope.variable(1.0, name="bar"))
+ expected_value = _get_test_variable(
+ "bar", variable_scope.VariableSynchronization.AUTO,
+ variable_scope.VariableAggregation.NONE)
+ self.assertDictEqual(expected_value,
+ variable_scope.variable(1.0, name="bar"))
with self.assertRaises(RuntimeError):
dist.call_for_each_tower(run_fn)
@@ -77,7 +90,27 @@ class TestStrategyTest(test.TestCase):
self.assertIs(dist, distribute.get_cross_tower_context())
self.assertTrue(distribute.has_distribution_strategy())
self.assertIs(dist, distribute.get_distribution_strategy())
- self.assertEqual("baz", variable_scope.variable(1.0, name="baz"))
+ expected_value = _get_test_variable(
+ "baz", variable_scope.VariableSynchronization.AUTO,
+ variable_scope.VariableAggregation.NONE)
+ self.assertDictEqual(expected_value,
+ variable_scope.variable(1.0, name="baz"))
+ _assert_in_default_state(self)
+
+ def testSettingSynchronizationAndAggregation(self):
+ _assert_in_default_state(self)
+ dist = _TestStrategy()
+ with dist.scope():
+ expected_value = _get_test_variable(
+ "baz", variable_scope.VariableSynchronization.ON_WRITE,
+ variable_scope.VariableAggregation.MEAN)
+ self.assertDictEqual(
+ expected_value,
+ variable_scope.variable(
+ 1.0,
+ name="baz",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN))
_assert_in_default_state(self)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index fe9ffde11c..784c9ddd1b 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -461,7 +461,8 @@ class Optimizer(
# Have to be careful to call distribute_lib.get_loss_reduction()
# *after* loss() is evaluated, so we know what loss reduction it uses.
# TODO(josh11b): Test that we handle weight decay in a reasonable way.
- if distribute_lib.get_loss_reduction() == "mean":
+ if distribute_lib.get_loss_reduction(
+ ) == variable_scope.VariableAggregation.MEAN:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
loss_value *= (1. / num_towers)
@@ -478,7 +479,8 @@ class Optimizer(
"be a function when eager execution is enabled.")
# Scale loss if using a "mean" loss reduction and multiple towers.
- if distribute_lib.get_loss_reduction() == "mean":
+ if distribute_lib.get_loss_reduction(
+ ) == variable_scope.VariableAggregation.MEAN:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
loss *= (1. / num_towers)
@@ -649,7 +651,8 @@ class Optimizer(
towers. If `global_step` was not None, that operation also
increments `global_step`.
"""
- reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
var_list = [v for _, v in grads_and_vars]
grads_and_vars = zip(reduced_grads, var_list)
# Note that this is called in a cross-tower context.
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 1104768ae8..d63f59a8c8 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -167,11 +167,14 @@ def assert_same_structure(nest1, nest2, check_types=True):
Args:
nest1: an arbitrarily nested structure.
nest2: an arbitrarily nested structure.
- check_types: if `True` (default) types of sequences are checked as
- well, including the keys of dictionaries. If set to `False`, for example
- a list and a tuple of objects will look the same if they have the same
+ check_types: if `True` (default) types of sequences are checked as well,
+ including the keys of dictionaries. If set to `False`, for example a
+ list and a tuple of objects will look the same if they have the same
size. Note that namedtuples with identical name and fields are always
- considered to have the same shallow structure.
+ considered to have the same shallow structure. Two types will also be
+ considered the same if they are both list subtypes (which allows "list"
+ and "_ListWrapper" from checkpointable dependency tracking to compare
+ equal).
Raises:
ValueError: If the two structures do not have the same number of elements or
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index c79d8a8445..366f8a0deb 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -394,7 +394,11 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
type2->tp_name);
return true;
}
- } else if (type1 != type2) {
+ } else if (type1 != type2
+ /* If both sequences are list types, don't complain. This allows
+ one to be a list subclass (e.g. _ListWrapper used for automatic
+ dependency tracking.) */
+ && !(PyList_Check(o1) && PyList_Check(o2))) {
*is_type_error = true;
*error_msg = tensorflow::strings::StrCat(
"The two namedtuples don't have the same sequence type. "
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt
new file mode 100644
index 0000000000..36b534af36
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt
@@ -0,0 +1,16 @@
+path: "tensorflow.VariableAggregation"
+tf_class {
+ is_instance: "<enum \'VariableAggregation\'>"
+ member {
+ name: "MEAN"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
+ name: "NONE"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
+ name: "SUM"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
index 8e539069da..ec1f72453f 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
@@ -56,7 +56,7 @@ tf_class {
}
member_method {
name: "get_variable"
- argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "global_variables"
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt
new file mode 100644
index 0000000000..7589bb2888
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt
@@ -0,0 +1,20 @@
+path: "tensorflow.VariableSynchronization"
+tf_class {
+ is_instance: "<enum \'VariableSynchronization\'>"
+ member {
+ name: "AUTO"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+ member {
+ name: "NONE"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+ member {
+ name: "ON_READ"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+ member {
+ name: "ON_WRITE"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index adab5399b2..9ec20f0955 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -261,10 +261,18 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "VariableAggregation"
+ mtype: "<class \'enum.EnumMeta\'>"
+ }
+ member {
name: "VariableScope"
mtype: "<type \'type\'>"
}
member {
+ name: "VariableSynchronization"
+ mtype: "<class \'enum.EnumMeta\'>"
+ }
+ member {
name: "WholeFileReader"
mtype: "<type \'type\'>"
}
@@ -1150,7 +1158,7 @@ tf_module {
}
member_method {
name: "get_local_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], "
}
member_method {
name: "get_seed"
@@ -1166,7 +1174,7 @@ tf_module {
}
member_method {
name: "get_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "get_variable_scope"
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index d49d4b0c49..08e2c3edd2 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -131,7 +131,7 @@ BAZEL_CMD="bazel test"
BAZEL_BUILD_ONLY_CMD="bazel build"
BAZEL_CLEAN_CMD="bazel clean"
-DEFAULT_BAZEL_CONFIGS="--config=gcp --config=hdfs"
+DEFAULT_BAZEL_CONFIGS=""
PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh"
PIP_TEST_TUTORIALS_FLAG="--test_tutorials"
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index f0a437c183..db37edf809 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -543,7 +543,7 @@ SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "d
SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency" "Check file names for cases")
INCREMENTAL_FLAG=""
-DEFAULT_BAZEL_CONFIGS="--config=hdfs --config=gcp"
+DEFAULT_BAZEL_CONFIGS=""
# Parse command-line arguments
BAZEL_FLAGS=${DEFAULT_BAZEL_CONFIGS}
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 2fe0b6f072..178e7f08c8 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -754,6 +754,14 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
],
build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
+ tf_http_archive(
+ name = "tflite_mobilenet_ssd_quant",
+ sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
+ urls = ["https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
+ ],
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ )
tf_http_archive(
name = "tflite_conv_actions_frozen",