-rw-r--r--  CODEOWNERS | 2
-rw-r--r--  tensorflow/BUILD | 1
-rw-r--r--  tensorflow/c/BUILD | 1
-rw-r--r--  tensorflow/c/c_api_experimental.h | 4
-rwxr-xr-x  tensorflow/c/eager/c_api.cc | 15
-rw-r--r--  tensorflow/c/eager/c_api_internal.h | 11
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_cond.cc | 787
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_cond.h | 166
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_cond_test.cc | 118
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc | 4
-rw-r--r--  tensorflow/compiler/xla/reference_util.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 2
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 50
-rw-r--r--  tensorflow/compiler/xla/service/batch_dot_simplification.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/convolution_feature_group_converter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/dot_decomposer.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc | 112
-rw-r--r--  tensorflow/compiler/xla/service/graphviz_example.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation_test.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_map.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_map.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_metadata.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 112
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 43
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 57
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc | 35
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.cc | 27
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc | 76
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding_test.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/tests/multioutput_fusion_test.cc | 12
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc | 86
-rw-r--r--  tensorflow/contrib/BUILD | 9
-rw-r--r--  tensorflow/contrib/autograph/converters/builtin_functions.py | 41
-rw-r--r--  tensorflow/contrib/autograph/converters/builtin_functions_test.py | 9
-rw-r--r--  tensorflow/contrib/autograph/impl/api.py | 4
-rw-r--r--  tensorflow/contrib/autograph/operators/BUILD | 11
-rw-r--r--  tensorflow/contrib/autograph/operators/__init__.py | 5
-rw-r--r--  tensorflow/contrib/autograph/operators/control_flow.py | 6
-rw-r--r--  tensorflow/contrib/autograph/operators/py_builtins.py | 225
-rw-r--r--  tensorflow/contrib/autograph/operators/py_builtins_test.py | 131
-rw-r--r--  tensorflow/contrib/autograph/utils/BUILD | 23
-rw-r--r--  tensorflow/contrib/autograph/utils/__init__.py | 3
-rw-r--r--  tensorflow/contrib/autograph/utils/builtins.py | 143
-rw-r--r--  tensorflow/contrib/autograph/utils/builtins_test.py | 145
-rw-r--r--  tensorflow/contrib/autograph/utils/tensors.py | 41
-rw-r--r--  tensorflow/contrib/autograph/utils/tensors_test.py | 57
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py | 13
-rw-r--r--  tensorflow/contrib/data/python/ops/interleave_ops.py | 17
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py | 6
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py | 2
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py | 4
-rw-r--r--  tensorflow/contrib/distribute/python/parameter_server_strategy.py | 2
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py | 13
-rw-r--r--  tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb | 4
-rw-r--r--  tensorflow/contrib/eager/python/metrics_test.py | 45
-rw-r--r--  tensorflow/contrib/factorization/python/ops/wals.py | 70
-rw-r--r--  tensorflow/contrib/factorization/python/ops/wals_test.py | 10
-rw-r--r--  tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py | 10
-rw-r--r--  tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py | 2
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers_test.py | 2
-rw-r--r--  tensorflow/contrib/lite/RELEASE.md | 8
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl | 1
-rw-r--r--  tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h | 18
-rw-r--r--  tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc | 13
-rw-r--r--  tensorflow/contrib/lite/g3doc/README.md | 4
-rw-r--r--  tensorflow/contrib/lite/g3doc/api_docs/python/index.md | 10
-rw-r--r--  tensorflow/contrib/lite/g3doc/apis.md | 43
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD | 53
-rw-r--r--  tensorflow/contrib/lite/kernels/activations.cc | 36
-rw-r--r--  tensorflow/contrib/lite/kernels/activations_test.cc | 70
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 699
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h | 12
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 13
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h | 8
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/quantization_util.cc | 210
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/quantization_util.h | 38
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc | 133
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc | 36
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h | 22
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 477
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_utils.h | 10
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc | 90
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/types.h | 7
-rw-r--r--  tensorflow/contrib/lite/kernels/layer_norm_lstm.cc | 1316
-rw-r--r--  tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc | 664
-rw-r--r--  tensorflow/contrib/lite/kernels/pad.cc | 34
-rw-r--r--  tensorflow/contrib/lite/kernels/pad_test.cc | 13
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/relu1.cc | 59
-rw-r--r--  tensorflow/contrib/lite/kernels/relu1_test.cc | 79
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py | 25
-rw-r--r--  tensorflow/contrib/lite/testing/generated_examples_zip_test.cc | 6
-rw-r--r--  tensorflow/contrib/lite/toco/args.h | 4
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc | 29
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc | 10
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.h | 5
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export.cc | 52
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export.h | 51
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export_test.cc | 9
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc | 39
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.h | 8
-rw-r--r--  tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 18
-rw-r--r--  tensorflow/contrib/lite/toco/toco_flags.proto | 15
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc | 24
-rw-r--r--  tensorflow/contrib/lite/tools/optimize/quantize_weights.cc | 74
-rw-r--r--  tensorflow/contrib/lite/tools/optimize/quantize_weights.h | 17
-rw-r--r--  tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc | 30
-rw-r--r--  tensorflow/contrib/opt/python/training/elastic_average_optimizer.py | 14
-rw-r--r--  tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py | 63
-rw-r--r--  tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py | 17
-rw-r--r--  tensorflow/contrib/opt/python/training/model_average_optimizer.py | 8
-rw-r--r--  tensorflow/contrib/opt/python/training/model_average_optimizer_test.py | 40
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 73
-rw-r--r--  tensorflow/contrib/saved_model/cc/saved_model/BUILD | 2
-rw-r--r--  tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc | 81
-rw-r--r--  tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h | 3
-rw-r--r--  tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc | 123
-rw-r--r--  tensorflow/contrib/tensor_forest/BUILD | 5
-rw-r--r--  tensorflow/contrib/tpu/profiler/tf_op_stats.proto | 5
-rw-r--r--  tensorflow/contrib/tpu/proto/optimization_parameters.proto | 4
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py | 132
-rw-r--r--  tensorflow/core/BUILD | 2
-rw-r--r--  tensorflow/core/common_runtime/bfc_allocator.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/eager/context.cc | 24
-rw-r--r--  tensorflow/core/common_runtime/eager/context.h | 19
-rw-r--r--  tensorflow/core/common_runtime/graph_runner.cc | 4
-rw-r--r--  tensorflow/core/common_runtime/placer.cc | 52
-rw-r--r--  tensorflow/core/common_runtime/placer.h | 2
-rw-r--r--  tensorflow/core/common_runtime/placer_test.cc | 50
-rw-r--r--  tensorflow/core/common_runtime/pool_allocator.cc | 1
-rw-r--r--  tensorflow/core/common_runtime/session_state.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/step_stats_collector.cc | 6
-rw-r--r--  tensorflow/core/graph/testlib.cc | 27
-rw-r--r--  tensorflow/core/graph/testlib.h | 9
-rw-r--r--  tensorflow/core/grappler/op_types.cc | 37
-rw-r--r--  tensorflow/core/grappler/op_types.h | 2
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 10
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 42
-rw-r--r--  tensorflow/core/grappler/utils/functions.cc | 32
-rw-r--r--  tensorflow/core/grappler/utils/functions.h | 13
-rw-r--r--  tensorflow/core/kernels/data/BUILD | 37
-rw-r--r--  tensorflow/core/kernels/data/captured_function.cc | 20
-rw-r--r--  tensorflow/core/kernels/data/captured_function.h | 13
-rw-r--r--  tensorflow/core/kernels/data/map_dataset_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/data/single_threaded_executor.cc | 378
-rw-r--r--  tensorflow/core/kernels/data/single_threaded_executor.h | 60
-rw-r--r--  tensorflow/core/kernels/data/single_threaded_executor_test.cc | 330
-rw-r--r--  tensorflow/core/kernels/debug_ops.h | 4
-rw-r--r--  tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h | 304
-rw-r--r--  tensorflow/core/kernels/eigen_benchmark.h | 74
-rw-r--r--  tensorflow/core/kernels/eigen_benchmark_cpu_test.cc | 15
-rw-r--r--  tensorflow/core/kernels/gpu_utils.h | 3
-rw-r--r--  tensorflow/core/kernels/list_kernels.h | 21
-rw-r--r--  tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc | 4
-rw-r--r--  tensorflow/core/kernels/remote_fused_graph_execute_utils.cc | 26
-rw-r--r--  tensorflow/core/kernels/save_restore_tensor.cc | 9
-rw-r--r--  tensorflow/core/kernels/save_restore_v2_ops.cc | 4
-rw-r--r--  tensorflow/core/kernels/string_strip_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/tensor_array_ops.cc | 2
-rw-r--r--  tensorflow/core/kernels/whole_file_read_ops.cc | 2
-rw-r--r--  tensorflow/core/lib/core/errors.h | 4
-rw-r--r--  tensorflow/core/lib/gtl/inlined_vector.h | 665
-rw-r--r--  tensorflow/core/lib/gtl/inlined_vector_test.cc | 898
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt | 43
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc | 1
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 7
-rw-r--r--  tensorflow/core/platform/cloud/curl_http_request.cc | 4
-rw-r--r--  tensorflow/core/platform/cloud/gcs_file_system.cc | 14
-rw-r--r--  tensorflow/core/platform/cloud/oauth_client.cc | 4
-rw-r--r--  tensorflow/core/platform/cloud/oauth_client_test.cc | 6
-rw-r--r--  tensorflow/core/platform/default/build_config.bzl | 1
-rw-r--r--  tensorflow/core/protobuf/config.proto | 9
-rw-r--r--  tensorflow/python/client/session.py | 6
-rw-r--r--  tensorflow/python/compat/compat.py | 2
-rw-r--r--  tensorflow/python/data/kernel_tests/iterator_ops_test.py | 87
-rw-r--r--  tensorflow/python/data/kernel_tests/map_dataset_op_test.py | 107
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py | 4
-rw-r--r--  tensorflow/python/eager/backprop_test.py | 8
-rw-r--r--  tensorflow/python/estimator/canned/dnn.py | 14
-rw-r--r--  tensorflow/python/framework/error_interpolation.py | 14
-rw-r--r--  tensorflow/python/framework/error_interpolation_test.py | 22
-rw-r--r--  tensorflow/python/framework/tensor_util.py | 2
-rw-r--r--  tensorflow/python/framework/test_util.py | 19
-rw-r--r--  tensorflow/python/keras/backend.py | 18
-rw-r--r--  tensorflow/python/keras/engine/distributed_training_utils.py | 22
-rw-r--r--  tensorflow/python/keras/engine/network.py | 4
-rw-r--r--  tensorflow/python/keras/engine/training.py | 87
-rw-r--r--  tensorflow/python/keras/engine/training_distributed.py | 71
-rw-r--r--  tensorflow/python/kernel_tests/check_ops_test.py | 80
-rw-r--r--  tensorflow/python/kernel_tests/distributions/bernoulli_test.py | 196
-rw-r--r--  tensorflow/python/kernel_tests/distributions/beta_test.py | 462
-rw-r--r--  tensorflow/python/kernel_tests/distributions/bijector_test.py | 13
-rw-r--r--  tensorflow/python/kernel_tests/distributions/dirichlet_test.py | 262
-rw-r--r--  tensorflow/python/kernel_tests/distributions/exponential_test.py | 187
-rw-r--r--  tensorflow/python/kernel_tests/distributions/gamma_test.py | 529
-rw-r--r--  tensorflow/python/kernel_tests/distributions/laplace_test.py | 439
-rw-r--r--  tensorflow/python/kernel_tests/distributions/normal_test.py | 613
-rw-r--r--  tensorflow/python/kernel_tests/distributions/special_math_test.py | 35
-rw-r--r--  tensorflow/python/kernel_tests/distributions/student_t_test.py | 505
-rw-r--r--  tensorflow/python/kernel_tests/distributions/uniform_test.py | 354
-rw-r--r--  tensorflow/python/kernel_tests/distributions/util_test.py | 230
-rw-r--r--  tensorflow/python/kernel_tests/functional_ops_test.py | 414
-rw-r--r--  tensorflow/python/kernel_tests/list_ops_test.py | 53
-rw-r--r--  tensorflow/python/kernel_tests/py_func_test.py | 87
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py | 39
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py | 18
-rw-r--r--  tensorflow/python/ops/cond_v2_impl.py | 51
-rw-r--r--  tensorflow/python/ops/distributions/distribution.py | 18
-rw-r--r--  tensorflow/stream_executor/blas.h | 1
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt | 10
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt | 10
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt | 10
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt | 10
-rw-r--r--  tensorflow/tools/ci_build/Dockerfile.gpu | 1
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_deb_packages.sh | 6
-rwxr-xr-x  tensorflow/tools/ci_build/linux/libtensorflow_docker.sh | 1
-rw-r--r--  tensorflow/tools/docs/parser.py | 26
-rw-r--r--  tensorflow/tools/docs/parser_test.py | 46
-rw-r--r--  tensorflow/tools/docs/pretty_docs.py | 2
-rw-r--r--  tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc | 2
-rw-r--r--  tensorflow/tools/graph_transforms/sparsify_gather_test.cc | 4
-rw-r--r--  tensorflow/tools/graph_transforms/transform_graph.cc | 15
-rw-r--r--  tensorflow/tools/graph_transforms/transform_utils.cc | 2
-rwxr-xr-x  tensorflow/workspace.bzl | 8
252 files changed, 10421 insertions, 6541 deletions
diff --git a/CODEOWNERS b/CODEOWNERS
index 1725a5c471..78f80c8d71 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -60,3 +60,5 @@
/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj
/tensorflow/contrib/training/ @joel-shor @ebrevdo
/tensorflow/contrib/util/ @sherrym
+
+/third_party/systemlibs/ @perfinion
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index b5e0a4e98b..661cba5ff0 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -433,6 +433,7 @@ package_group(
"-//third_party/tensorflow/python/estimator",
"//learning/meta_rank/...",
"//tensorflow/...",
+ "//tensorflow_estimator/...",
"//tensorflow_fold/llgtm/...",
"//third_party/py/tensor2tensor/...",
],
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 2c3a877edf..109b3b37aa 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -117,6 +117,7 @@ tf_cuda_library(
deps = [
":c_api",
":c_api_internal",
+ "//tensorflow/c/eager:c_api",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/contrib/tpu:all_ops",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 6617c5a572..09d482d6df 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
// --------------------------------------------------------------------------
// Experimental C API for TensorFlow.
@@ -131,6 +132,9 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
TF_Tensor* tensor,
TF_Status* status);
+TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
+ const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 1ccae3f138..77e3878a94 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options, opts->policy,
- opts->async, std::move(device_mgr), r);
+ opts->async, device_mgr.release(),
+ /*device_mgr_owned*/ true, r);
+}
+
+TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
+ TF_Session* sess, TF_Status* status) {
+ const tensorflow::DeviceMgr* device_mgr = nullptr;
+ status->status = sess->session->LocalDeviceManager(&device_mgr);
+ if (!status->status.ok()) return nullptr;
+ tensorflow::Rendezvous* r =
+ new tensorflow::IntraProcessRendezvous(device_mgr);
+ return new TFE_Context(opts->session_options.options, opts->policy,
+ opts->async, device_mgr, /*device_mgr_owned*/ false,
+ r);
}
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index a5c0681e2e..104d52430c 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -62,15 +62,14 @@ struct TFE_ContextOptions {
};
struct TFE_Context {
- explicit TFE_Context(const tensorflow::SessionOptions& opts,
- TFE_ContextDevicePlacementPolicy default_policy,
- bool async,
- std::unique_ptr<tensorflow::DeviceMgr> device_mgr,
- tensorflow::Rendezvous* rendezvous)
+ TFE_Context(const tensorflow::SessionOptions& opts,
+ TFE_ContextDevicePlacementPolicy default_policy, bool async,
+ const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
+ tensorflow::Rendezvous* rendezvous)
: context(opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
- async, std::move(device_mgr), rendezvous) {}
+ async, device_mgr, device_mgr_owned, rendezvous) {}
tensorflow::EagerContext context;
};
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index b5667ca0d3..e2affee51f 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -40,26 +40,11 @@ using xla::StatusOr;
namespace tensorflow {
namespace functionalize_cond {
-string DebugString(const CondStateMap::CondNode& node) {
- return node.ToString();
-}
-
// TODO(jpienaar): Move to OutputTensor.
string DebugString(const OutputTensor& tensor) {
return strings::StrCat(tensor.node->name(), ":", tensor.index);
}
-string DebugString(CondStateMap::CondId cond_state) {
- if (cond_state == nullptr || cond_state->empty()) return "[]";
- return strings::StrCat(
- "[",
- absl::StrJoin(*cond_state, ", ",
- [](string* output, const CondStateMap::CondNode& node) {
- strings::StrAppend(output, node.ToString());
- }),
- "]");
-}
-
string Branch_Name(BranchType b) {
switch (b) {
case BranchType::kElseBranch:
@@ -73,6 +58,24 @@ string Branch_Name(BranchType b) {
}
}
+string DebugString(StateMap::CondId cond_state) {
+ if (cond_state == nullptr || cond_state->empty()) return "{}";
+ using value_type = StateMap::CondState::value_type;
+ return strings::StrCat(
+ "{",
+ absl::StrJoin(*cond_state, ", ",
+ [](string* output, const value_type& pred_branch) {
+ const OutputTensor& pred = pred_branch.first;
+ const BranchType& branch = pred_branch.second;
+ if (branch == BranchType::kNeither)
+ strings::StrAppend(output, "d");
+ else
+ strings::StrAppend(output, "s(", DebugString(pred), ",",
+ Branch_Name(branch), ")");
+ }),
+ "}");
+}
+
// Returns the predicate of a switch.
Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
const Edge* pred_edge;
@@ -86,64 +89,65 @@ Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
return Status::OK();
}
-CondStateMap::CondNode::CondNode(Type type, Node* switch_node,
- BranchType branch)
- : type(type), branch(branch) {
- if (type == Type::kSwitch) {
- TF_CHECK_OK(GetSwitchPredicate(*switch_node, &predicate));
- }
-}
-
-string CondStateMap::CondNode::ToString() const {
- switch (type) {
- case Type::kSwitch:
- return strings::StrCat("s(", DebugString(predicate), ",",
- Branch_Name(branch), ")");
- case Type::kMerge:
- return "m";
- case Type::kDead:
- return "d";
- }
+Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
+ const Edge* val_edge;
+ TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge));
+ *val = OutputTensor(val_edge->src(), val_edge->src_output());
+ return Status::OK();
}
-bool CondStateMap::CondNode::operator==(const CondNode& other) const {
- if (type != Type::kSwitch) return type == other.type;
- return type == other.type && predicate == other.predicate &&
- branch == other.branch;
+bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
+ const OutputTensor& rhs) const {
+ return (lhs.node->id() < rhs.node->id()) ||
+ (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
}
-bool CondStateMap::CondNode::operator!=(const CondNode& other) const {
- return !(*this == other);
-}
+struct CondStateLess {
+ bool operator()(const StateMap::CondState::value_type& lhs,
+ const StateMap::CondState::value_type& rhs) const {
+ if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
+ return true;
+ if (lhs.first.node->id() == rhs.first.node->id() &&
+ lhs.first.index == rhs.first.index)
+ return lhs.second < rhs.second;
+ return false;
+ }
+};
-CondStateMap::CondStateMap(Graph* graph) {
+StateMap::StateMap(Graph* graph) {
node_to_condid_map_.resize(graph->num_node_ids());
+ node_to_ancestorid_map_.resize(graph->num_node_ids());
// Initialize the dead state (empty state is designated with a nullptr).
- dead_id_ = GetUniqueId({CondNode(CondStateMap::CondNode::Type::kDead)});
+ dead_id_ = GetCondId(
+ {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
}
-bool CondStateMap::IsDead(CondStateMap::CondId id) const {
- return id == dead_id_;
-}
+bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; }
-bool CondStateMap::IsEmpty(CondStateMap::CondId id) const {
- return id == nullptr;
-}
+bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; }
-size_t CondStateMap::CondHash::operator()(
- const CondStateMap::CondNode& item) const {
- return Hash64Combine(Hash64Combine(OutputTensor::Hash()(item.predicate),
- hash<BranchType>()(item.branch)),
- hash<CondStateMap::CondNode::Type>()(item.type));
+size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
+ if (map.empty()) return 0;
+ // Compute hash of the front element.
+ auto it = map.begin();
+ size_t h = Hash64Combine(OutputTensor::Hash()(it->first),
+ hash<BranchType>()(it->second));
+ for (++it; it != map.end(); ++it) {
+ // Combine the hash with the different elements in the map.
+ h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first),
+ hash<BranchType>()(it->second)));
+ }
+ return h;
}
-size_t CondStateMap::CondHash::operator()(
- const CondStateMap::CondState& vec) const {
- if (vec.empty()) return 0;
- size_t h = (*this)(vec.front());
- auto it = vec.begin();
- for (++it; it != vec.end(); ++it) {
- h = Hash64Combine(h, (*this)(*it));
+size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
+ if (map.empty()) return 0;
+ // Compute hash of the front element.
+ auto it = map.begin();
+ size_t h = hash<Node*>()(*it);
+ for (++it; it != map.end(); ++it) {
+ // Combine the hash with the different elements in the map.
+ h = Hash64Combine(h, hash<Node*>()(*it));
}
return h;
}
@@ -176,49 +180,71 @@ string DebugString(const CondArgNodes& nodes) {
"]");
}
-CondStateMap::CondId CondStateMap::LookupId(const Node* node) const {
+StateMap::CondId StateMap::LookupCondId(const Node* node) const {
if (node->id() < node_to_condid_map_.size())
return node_to_condid_map_[node->id()];
- return added_node_mapping_.at(node->id());
+ return added_node_condid_mapping_.at(node->id());
}
-CondStateMap::CondId CondStateMap::GetUniqueId(
- const CondStateMap::CondState& state) {
+StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
if (state.empty()) return nullptr;
return &*condstate_set_.insert(state).first;
}
-const CondStateMap::CondState& CondStateMap::LookupState(
- const Node* node) const {
- return *LookupId(node);
-}
-
-void CondStateMap::ResetId(const Node* node, CondStateMap::CondId id) {
+void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
if (node->id() < node_to_condid_map_.size())
node_to_condid_map_[node->id()] = id;
else
- added_node_mapping_[node->id()] = id;
+ added_node_condid_mapping_[node->id()] = id;
+}
+
+StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
+ if (node->id() < node_to_ancestorid_map_.size())
+ return node_to_ancestorid_map_[node->id()];
+ return added_node_ancestorid_mapping_.at(node->id());
}
-void CondStateMap::MarkDead(const Node* node) { ResetId(node, dead_id_); }
+StateMap::AncestorId StateMap::GetAncestorId(
+ const StateMap::AncestorState& state) {
+ if (state.empty()) return nullptr;
+ return &*ancestorstate_set_.insert(state).first;
+}
-string CondStateMap::CondStateToString(const Node* node) const {
- return CondStateToString(LookupId(node));
+void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
+ if (node->id() < node_to_ancestorid_map_.size())
+ node_to_ancestorid_map_[node->id()] = id;
+ else
+ added_node_ancestorid_mapping_[node->id()] = id;
}
-string CondStateMap::CondStateToString(CondStateMap::CondId id) const {
+const StateMap::CondState& StateMap::LookupState(const Node* node) const {
+ return *LookupCondId(node);
+}
+
+void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
+
+string StateMap::CondStateToString(const Node* node) const {
+ return CondStateToString(LookupCondId(node));
+}
+
+string StateMap::CondStateToString(StateMap::CondId id) const {
return DebugString(id);
}
+string StateMap::AncestorStateToString(const Node* node) const {
+ if (auto id = LookupAncestorId(node)) return NodesToString(*id);
+ return "{}";
+}
+
FunctionalizeCond::FunctionalizeCond(Graph* graph,
FunctionLibraryDefinition* library)
- : cond_state_map_(graph), library_(library), graph_(graph) {}
+ : state_map_(graph), library_(library), graph_(graph) {}
// Class representing the merge/switch nodes that will become a conditional.
class Conditional {
public:
Conditional(OutputTensor predicate, FunctionalizeCond* parent,
- CondStateMap* cond_state_map);
+ StateMap* cond_state_map);
// Adds merge node that is part of this conditional.
Status AddMerge(Node* m);
@@ -247,6 +273,10 @@ class Conditional {
// Adds switch node that is part of this conditional.
Status AddSwitch(Node* s);
+ // Adds a switch node along the edge and rewires the edge to go via the switch.
+ Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
+ Graph* graph);
+
// Internal name of conditional. The name is based on the first merge node
// added.
string name() const;
@@ -255,7 +285,7 @@ class Conditional {
FunctionalizeCond* parent_;
// Mapping between nodes and their cond state.
- CondStateMap* cond_state_map_;
+ StateMap* state_map_;
// The predicate of the conditional.
OutputTensor predicate_;
@@ -292,8 +322,8 @@ class Conditional {
};
Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
- CondStateMap* cond_state_map)
- : parent_(parent), cond_state_map_(cond_state_map), predicate_(predicate) {}
+ StateMap* cond_state_map)
+ : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {}
Status Conditional::AddMerge(Node* m) {
merges_.insert(m);
@@ -397,6 +427,35 @@ Status Conditional::BuildArgumentNodes() {
return Status::OK();
}
+Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
+ Graph* graph) {
+ // Previously we had edge:
+ // src:src_output ---- edge ----> dst:dst_input
+ // post this we have (in graph)
+ // src:src_output --> switch<pred> --- new_edge --> dst:dst_input
+
+ // TODO(jpienaar): One could keep a map caching the extra switch nodes added
+ // to avoid adding another switch to feed a value for which a switch was
+ // already added.
+ Node* switch_node;
+ Node* src = edge->src();
+ int src_output = edge->src_output();
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(graph->NewName(strings::StrCat(src->name(), "_added_switch")),
+ "Switch")
+ .Input(src, src_output)
+ .Input(const_cast<Node*>(predicate_.node), predicate_.index)
+ .Finalize(graph, &switch_node));
+ state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
+ state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));
+
+ Node* dst = edge->dst();
+ int dst_input = edge->dst_input();
+ graph->RemoveEdge(edge);
+ graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
+ return AddSwitch(switch_node);
+}
+
Status Conditional::ExtractBodies(Graph* graph) {
VLOG(2) << "Extracting bodies for " << name();
for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
@@ -405,16 +464,16 @@ Status Conditional::ExtractBodies(Graph* graph) {
}
auto find_branch = [&](const Edge* e) {
- const auto& id = cond_state_map_->LookupId(e->src());
+ const auto& id = state_map_->LookupCondId(e->src());
return IsSwitch(e->src()) ? BranchType(e->src_output())
- : cond_state_map_->FindBranchOf(id, predicate_);
+ : state_map_->FindBranchOf(id, predicate_);
};
std::array<std::vector<Node*>, 2> stacks;
VLOG(5) << "Merges: " << NodesToString(merges_);
for (Node* m : merges_) {
VLOG(5) << "For merge: " << m->DebugString() << " "
- << cond_state_map_->CondStateToString(m);
+ << state_map_->CondStateToString(m);
for (auto e : m->in_edges()) {
if (e->IsControlEdge()) continue;
BranchType branch = find_branch(e);
@@ -422,7 +481,8 @@ Status Conditional::ExtractBodies(Graph* graph) {
branch == BranchType::kElseBranch)
<< "Error: " << e->src()->name()
<< " is not on either then or else branch (" << Branch_Name(branch)
- << ").";
+ << ") for predicate " << DebugString(predicate_) << " ["
+ << DebugString(state_map_->LookupCondId(e->src())) << "].";
Node* src = e->src();
if (IsSwitch(src)) {
// Switch node outputs and dependencies are handled separately.
@@ -456,8 +516,8 @@ Status Conditional::ExtractBodies(Graph* graph) {
if (IsMerge(dst)) continue;
Node* src = e->src();
- auto dst_id = cond_state_map_->LookupId(dst);
- auto src_id = cond_state_map_->LookupId(src);
+ auto dst_id = state_map_->LookupCondId(dst);
+ auto src_id = state_map_->LookupCondId(src);
if (dst_id != src_id) {
if (e->IsControlEdge()) {
external_control_outputs_.push_back(e->src());
@@ -480,8 +540,11 @@ Status Conditional::ExtractBodies(Graph* graph) {
}
}
- // Copying incomming edges to dst node.
- for (const Edge* e : n->in_edges()) {
+ // Copying incoming edges to dst node. Iterate over a copy of the edges
+ // as they could be mutated during iteration.
+ std::vector<const Edge*> in_edges(n->in_edges().begin(),
+ n->in_edges().end());
+ for (const Edge* e : in_edges) {
Node* src = e->src();
// Skip src/dst node.
if (!src->IsOp()) continue;
@@ -494,8 +557,8 @@ Status Conditional::ExtractBodies(Graph* graph) {
}
// Verify input is from the same context.
- auto src_id = cond_state_map_->LookupId(src);
- auto dst_id = cond_state_map_->LookupId(dst);
+ auto src_id = state_map_->LookupCondId(src);
+ auto dst_id = state_map_->LookupCondId(dst);
if (IsMerge(dst) || src_id == dst_id) {
// TODO(jpienaar): The merge case can be more strict.
if (node_map.at(src->id()) == nullptr) {
@@ -506,18 +569,25 @@ Status Conditional::ExtractBodies(Graph* graph) {
external_control_inputs_.push_back(src);
} else {
// This shouldn't happen, this means we have an external data input
- // not entering via a switch node. Work around this for constant
- // nodes as some constant nodes are inserted without the required
- // control context dominance.
+ // not entering via a switch node. Work around this as follows:
+ // * constant nodes: copy them;
+ // * non-constant nodes: insert a switch along the edge;
if (IsConstant(src)) {
node_map.at(src->id()) = output->CopyNode(src);
} else {
- return errors::InvalidArgument(
- "Graph contains node ", FormatNodeForError(*src),
- " that feeds into node ", FormatNodeForError(*dst),
- " but these nodes are in different control contexts (",
- DebugString(src_id), " vs ", DebugString(dst_id),
- " (detected during in edge testing)");
+ StateMap::CondState state = *dst_id;
+ state.erase(predicate_);
+ if (state_map_->GetCondId(state) == src_id) {
+ TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
+ continue;
+ } else {
+ return errors::InvalidArgument(
+ "Graph contains node ", FormatNodeForError(*src),
+ " that feeds into node ", FormatNodeForError(*dst),
+ " but these nodes are in different control contexts (",
+ DebugString(src_id), " vs ", DebugString(dst_id),
+ " (detected during in edge testing)");
+ }
}
}
@@ -639,7 +709,8 @@ Status Conditional::BuildIfNode(Graph* graph,
VLOG(3) << "Build If node";
NodeDef if_def;
TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
- TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin()));
+ TF_ASSIGN_OR_RETURN(if_node_,
+ parent_->AddIfNode(if_def, *merges_.begin(), predicate_));
return Status::OK();
}
@@ -699,7 +770,8 @@ Status Conditional::AddOutputEdges(Graph* graph) {
Status Conditional::BuildAndReplace(Graph* graph,
FunctionLibraryDefinition* library) {
- VLOG(1) << "Build If and replace merge nodes " << name();
+ VLOG(1) << "Build If and replace merge nodes "
+ << NodesToString(this->merges_);
if (replaced_) return Status::OK();
TF_RETURN_IF_ERROR(ExtractBodies(graph));
@@ -719,7 +791,7 @@ Status Conditional::BuildAndReplace(Graph* graph,
TF_RETURN_IF_ERROR(AddInputEdges(graph));
TF_RETURN_IF_ERROR(AddOutputEdges(graph));
TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));
- for (Node* m : merges_) cond_state_map_->MarkDead(m);
+ for (Node* m : merges_) state_map_->MarkDead(m);
// Check that the if_node doesn't feed into itself.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
@@ -735,55 +807,41 @@ string Conditional::name() const {
return strings::StrCat((*merges_.begin())->name(), "_if");
}
-bool CondStateMap::ScopeIn(CondStateMap::CondId id,
- CondStateMap::CondId* scope) {
- if (id == nullptr) {
- *scope = nullptr;
- return true;
- }
- CondState state;
- for (const CondNode& node : *id) {
- if (node.type == CondNode::Type::kSwitch) {
- state.push_back(node);
- }
- if (node.type == CondNode::Type::kMerge) {
- if (state.empty()) {
- return false;
- }
- DCHECK(state.back().type == CondNode::Type::kSwitch &&
- state.back().branch == BranchType::kBoth);
- state.pop_back();
- }
- }
- *scope = GetUniqueId(state);
- return true;
-}
-
Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
int port) {
Node* id;
TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity")
.Input(if_node, port)
.Finalize(graph_, &id));
- cond_state_map_.ResetId(id, cond_state_map_.LookupId(if_node));
+ state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
+ state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
return Status::OK();
}
StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
- const Node* replacee) {
+ const Node* replacee,
+ const OutputTensor& predicate) {
Status status;
Node* ret = graph_->AddNode(def, &status);
TF_RETURN_IF_ERROR(status);
- CondStateMap::CondState state = cond_state_map_.LookupState(replacee);
- state.pop_back();
VLOG(1) << "Adding If for " << replacee->name();
- cond_state_map_.ResetId(ret, cond_state_map_.GetUniqueId(state));
+ StateMap::CondId id = state_map_.LookupCondId(replacee);
+ if (id) {
+ StateMap::CondState state = *id;
+ state.erase(predicate);
+ state_map_.ResetCondId(ret, state_map_.GetCondId(state));
+ } else {
+ state_map_.ResetCondId(ret, nullptr);
+ }
+
+ state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));
+
return ret;
}
Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
VLOG(2) << "Propagating update state for " << replacee->name() << " "
- << cond_state_map_.CondStateToString(replacee);
+ << state_map_.CondStateToString(replacee);
// Redo topological sort as the order could have changed.
// TODO(jpienaar): The original topological order could also be updated
// dynamically if needed.
@@ -801,10 +859,10 @@ Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
if (changed.find(*it) != changed.end()) {
// Update the node state.
Node* n = *it;
- CondStateMap::CondId old_state = cond_state_map_.LookupId(n);
- cond_state_map_.ResetId(n, nullptr);
+ StateMap::CondId old_state = state_map_.LookupCondId(n);
+ state_map_.ResetCondId(n, nullptr);
TF_RETURN_IF_ERROR(DetermineCondState(n));
- if (cond_state_map_.LookupId(n) != old_state) {
+ if (state_map_.LookupCondId(n) != old_state) {
for (auto out : n->out_nodes())
if (out->IsOp()) changed.insert(out);
}
@@ -825,127 +883,44 @@ BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
return BranchType::kNeither;
}
-CondStateMap::ContainsResult CondStateMap::LhsHoldsWhereverRhsHolds(
- CondStateMap::CondId lhs, CondStateMap::CondId rhs) {
- CondId lhs_scope;
- CondId rhs_scope;
- bool could_determine_scope = ScopeIn(lhs, &lhs_scope);
- could_determine_scope = could_determine_scope && ScopeIn(rhs, &rhs_scope);
- if (!could_determine_scope) return kIncomparable;
-
- // Returns whether a contains b.
- auto contains = [&](CondId a, CondId b) {
- // Handle empty states.
- if (a == nullptr && b != nullptr) return true;
- if (a == nullptr && b == nullptr) return true;
- if (a != nullptr && b == nullptr) return false;
-
- if (a->size() > b->size()) return false;
- auto a_it = a->begin();
- auto b_it = b->begin();
- while (a_it != a->end()) {
- if (*a_it != *b_it) {
- if (!(a_it->predicate == b_it->predicate)) return false;
- BranchType mb = MeetBranch(a_it->branch, b_it->branch);
- if (mb != b_it->branch) return false;
- }
- ++a_it;
- ++b_it;
- }
- return true;
- };
-
- bool lhs_contains_rhs = contains(lhs_scope, rhs_scope);
- bool rhs_contains_lhs = contains(rhs_scope, lhs_scope);
- if (lhs_contains_rhs && rhs_contains_lhs) return kEqual;
- if (lhs_contains_rhs) return kLhsContainsRhs;
- if (rhs_contains_lhs) return kRhsContainsLhs;
- return kIncomparable;
-}
-
-BranchType CondStateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
+BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
if (IsEmpty(id)) return BranchType::kNeither;
- absl::optional<BranchType> b;
const CondState& nodes = *id;
- for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
- if (it->type == CondStateMap::CondNode::Type::kSwitch &&
- it->predicate == predicate) {
- if (b.has_value()) {
- b = MeetBranch(*b, it->branch);
- } else {
- b = it->branch;
- }
- if (*b == BranchType::kNeither) {
- LOG(FATAL) << "Inconsistent state for node: " << DebugString(id);
- }
- }
- }
- return b.has_value() ? *b : BranchType::kNeither;
+ auto it = nodes.find(predicate);
+ if (it == nodes.end()) return BranchType::kNeither;
+ return it->second;
}
-StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
- CondStateMap::CondId src, CondStateMap::CondId dst) {
- VLOG(4) << "Joining src=" << DebugString(src) << " [" << src
+StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
+ StateMap::CondId src, StateMap::CondId dst) {
+ VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
<< "] and dst=" << DebugString(dst) << " [" << dst << "]";
- if (cond_state_map_.IsEmpty(dst) || cond_state_map_.IsDead(src)) return src;
- if (cond_state_map_.IsDead(dst)) return dst;
+ if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
+ if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;
// Nothing to do if the CondState is the same.
if (src == dst) return src;
- CondStateMap::CondId src_scope;
- CondStateMap::CondId dst_scope;
- if (!cond_state_map_.ScopeIn(src, &src_scope))
- return errors::Unimplemented(
- "Predicates that must hold for node to execute are invalid! ",
- DebugString(src));
- if (!cond_state_map_.ScopeIn(dst, &dst_scope))
- return errors::Unimplemented(
- "Predicates that must hold for node to execute are invalid! ",
- DebugString(dst));
-
- auto result = cond_state_map_.LhsHoldsWhereverRhsHolds(src_scope, dst_scope);
- switch (result) {
- case CondStateMap::kIncomparable:
- return errors::InvalidArgument(
- "Graph contains node with inputs predicated on incompatible "
- "predicates: ",
- DebugString(src), " and ", DebugString(dst));
- case CondStateMap::kEqual:
- // If both respect the same predicates, propagate the longer constraint.
- if ((src != nullptr && dst == nullptr) ||
- (src != nullptr && dst != nullptr && src->size() > dst->size()))
- return src;
- else
- return dst;
- case CondStateMap::kLhsContainsRhs:
- // src contains dst, so dst is already more restrictive.
- return dst;
- case CondStateMap::kRhsContainsLhs:
- // dst contains src, so src is more restrictive.
- return src;
- }
-}
-
-StatusOr<CondStateMap::CondState::const_iterator>
-FindThenElseSwitchForPredicate(const OutputTensor& pred,
- CondStateMap::CondId id) {
- for (auto it = id->begin(); it != id->end(); ++it) {
- // Along every path one there can be only one instance of a then or else
- // switch for a given predicate, so return once found.
- if (it->type == CondStateMap::CondNode::Type::kSwitch &&
- it->predicate == pred &&
- (it->branch == BranchType::kThenBranch ||
- it->branch == BranchType::kElseBranch))
- return it;
+ StateMap::CondState both = *src;
+ for (const auto& kv : *dst) {
+ auto it = both.find(kv.first);
+ if (it == both.end()) {
+ both.insert(kv);
+ } else {
+ if (it->second != kv.second) {
+ return errors::InvalidArgument(
+ "Graph contains node with inputs predicated on incompatible "
+ "predicates: ",
+ DebugString(src), " and ", DebugString(dst));
+ }
+ }
}
- return errors::Internal("Unable to find then/else branch with predicate ",
- DebugString(pred), " for ", DebugString(id));
+ return state_map_.GetCondId(both);
}
-StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
- CondStateMap::CondId src, CondStateMap::CondId dst) {
+StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
+ Node* merge, StateMap::CondId src, StateMap::CondId dst) {
// Determine the flow state when joining two states for a merge
// node. Combining the two states for a merge node is effectively performing a
// disjunction of the states along the different input edges. For a merge that
@@ -956,91 +931,56 @@ StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
// followed by s(p, both).
VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
<< DebugString(dst);
- if (cond_state_map_.IsEmpty(dst)) return src;
-
- if (cond_state_map_.IsDead(src)) return src;
- if (cond_state_map_.IsDead(dst)) return dst;
-
- CondStateMap::CondId src_scope;
- CondStateMap::CondId dst_scope;
- if (!cond_state_map_.ScopeIn(src, &src_scope))
- return errors::Unimplemented(
- "Predicates that must hold for node to execute are invalid! ",
- DebugString(src));
- if (!cond_state_map_.ScopeIn(dst, &dst_scope))
- return errors::Unimplemented(
- "Predicates that must hold for node to execute are invalid! ",
- DebugString(dst));
-
- TF_RET_CHECK(src_scope != nullptr && dst_scope != nullptr)
- << "Illegal merge inputs from outer scope: src=" << DebugString(src)
- << " dst=" << DebugString(dst);
- auto src_it = src_scope->begin();
- auto dst_it = dst_scope->begin();
-
- // Find branch divergent condition.
- OutputTensor pred;
- while (src_it != src_scope->end() && dst_it != dst_scope->end()) {
- if (*src_it != *dst_it) {
- VLOG(5) << "Diverges with: " << DebugString(*src_it) << " and "
- << DebugString(*dst_it);
- if (!(src_it->predicate == dst_it->predicate)) {
- return errors::InvalidArgument(
- "Unable to find common predicate which holds for one input "
- "but not the other of the merge node.");
- }
- pred = src_it->predicate;
- break;
- }
- ++src_it;
- ++dst_it;
- }
-
- if (pred.node == nullptr)
- return errors::InvalidArgument("Unable to determine predicate for merge.");
-
- TF_ASSIGN_OR_RETURN(auto div_src_it,
- FindThenElseSwitchForPredicate(pred, src));
- TF_ASSIGN_OR_RETURN(auto div_dst_it,
- FindThenElseSwitchForPredicate(pred, dst));
- TF_RET_CHECK(*div_src_it != *div_dst_it);
-
- CondStateMap::CondState result;
- // Populate result with the longest/most restrictive path up to the divergent
- // node. For example, if the one input is `[switch(pred:0, then)]` and the
- // other is `[switch(pred:0, both), merge, switch(pred:0, else)]` (as created
- // in gradient of cond test), then the resultant state here should be
- // `[switch(pred:0, both), merge, switch(pred:0, both)]`.
- if (std::distance(src->begin(), div_src_it) >
- std::distance(dst->begin(), div_dst_it)) {
- result.assign(src->begin(), std::next(div_src_it));
+ if (state_map_.IsEmpty(dst)) return src;
+
+ if (state_map_.IsDead(src)) return src;
+ if (state_map_.IsDead(dst)) return dst;
+
+ std::vector<StateMap::CondState::value_type> diff;
+ StateMap::CondState merged;
+ std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
+ dst->end(), std::back_inserter(diff),
+ CondStateLess());
+ std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
+ std::inserter(merged, merged.begin()), CondStateLess());
+
+ // Update mapping from merge node to predicate.
+ if (diff.size() == 2) {
+ auto pred = diff[0].first;
+ bool different_branches = (diff[0].second != diff[1].second) &&
+ (diff[0].second == BranchType::kThenBranch ||
+ diff[0].second == BranchType::kElseBranch) &&
+ (diff[1].second == BranchType::kThenBranch ||
+ diff[1].second == BranchType::kElseBranch);
+ if (!(pred == diff[1].first) || !different_branches)
+ return errors::InvalidArgument(
+ "Unable to determine predicate for merge node");
+ merge_to_predicate_[merge] = pred;
} else {
- result.assign(dst->begin(), std::next(div_dst_it));
+ return errors::InvalidArgument(
+ "Merge of two inputs that differ on more than one predicate ",
+ DebugString(src), " and ", DebugString(dst));
}
- result.back().branch = BranchType::kBoth;
- return cond_state_map_.GetUniqueId(result);
+
+ return state_map_.GetCondId(merged);
}
-CondStateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
+StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
Node* src = e->src();
- CondStateMap::CondId id = cond_state_map_.LookupId(e->src());
- if (IsMerge(src)) {
- CondStateMap::CondState state;
- if (id != nullptr) state = *id;
- state.emplace_back(CondStateMap::CondNode::Type::kMerge);
- return cond_state_map_.GetUniqueId(state);
- }
+ StateMap::CondId id = state_map_.LookupCondId(e->src());
+
+ // Dead nodes only propagate dead state.
+ if (state_map_.IsDead(id)) return id;
+
if (IsSwitch(src)) {
- CondStateMap::CondState state;
+ StateMap::CondState state;
if (id != nullptr) state = *id;
- if (e->IsControlEdge()) {
- state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src,
- BranchType::kBoth);
- } else {
- state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src,
- BranchType(e->src_output()));
+ OutputTensor predicate;
+ TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
+ if (!e->IsControlEdge()) {
+ state[predicate] = BranchType(e->src_output());
}
- return cond_state_map_.GetUniqueId(state);
+ return state_map_.GetCondId(state);
}
return id;
}
@@ -1049,22 +989,21 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
// Only Merge nodes with two inputs are supported, but if this is a redundant
// merge, then the dead edge may already have been removed (if due to a
// switch) and so the input count would be incorrect.
- if (cond_state_map_.IsDead(cond_state_map_.LookupId(dst)))
- return Status::OK();
+ if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK();
int data_inputs = 0;
for (auto e : dst->in_edges()) {
Node* src = e->src();
VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
- << cond_state_map_.CondStateToString(src);
+ << state_map_.CondStateToString(src);
if (!src->IsOp()) continue;
if (!e->IsControlEdge()) ++data_inputs;
- CondStateMap::CondId prop = StateAlongEdge(e);
- auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst));
+ StateMap::CondId prop = StateAlongEdge(e);
+ auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
FormatNodeForError(*dst));
- cond_state_map_.ResetId(dst, id_or.ValueOrDie());
+ state_map_.ResetCondId(dst, id_or.ValueOrDie());
}
// Incomplete Merge nodes are not supported.
@@ -1076,27 +1015,20 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
return Status::OK();
}
-Status FunctionalizeCond::DetermineCondState(Node* dst) {
- // The logic for the merge and non-merge case differ: for non-merge it is
- // the most restrictive CondState, while for merge nodes the
- // resultant state is less restrictive than either.
- if (IsMerge(dst)) {
- TF_RETURN_IF_ERROR(DetermineCondStateMerge(dst));
- } else {
- // Handle non-merge join.
- for (auto e : dst->in_edges()) {
- VLOG(5) << "Processing forward flow for: " << e->DebugString() << " "
- << cond_state_map_.CondStateToString(dst);
- Node* src = e->src();
- if (!src->IsOp()) continue;
-
- // Joining the state between the current and propagated state.
- CondStateMap::CondId prop = StateAlongEdge(e);
- auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst));
- TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
- FormatNodeForError(*dst));
- cond_state_map_.ResetId(dst, id_or.ValueOrDie());
- }
+Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
+ // Handle non-merge join.
+ for (auto e : dst->in_edges()) {
+ VLOG(4) << "Processing forward flow for: " << e->DebugString() << " "
+ << state_map_.CondStateToString(dst);
+ Node* src = e->src();
+ if (!src->IsOp()) continue;
+
+ // Joining the state between the current and propagated state.
+ StateMap::CondId prop = StateAlongEdge(e);
+ auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst));
+ state_map_.ResetCondId(dst, id_or.ValueOrDie());
}
return Status::OK();
}
@@ -1104,8 +1036,7 @@ Status FunctionalizeCond::DetermineCondState(Node* dst) {
Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
// Handle redundant merge nodes. A merge node is considered redundant if
// one input edge is dead while the other has a value.
- if (!cond_state_map_.IsDead(cond_state_map_.LookupId(node)))
- return Status::OK();
+ if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK();
const Edge* non_dead_edge = nullptr;
for (auto e : node->in_edges()) {
@@ -1113,8 +1044,8 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
Node* src = e->src();
// Handle merge with dead state.
- const auto& src_id = cond_state_map_.LookupId(src);
- if (!cond_state_map_.IsDead(src_id)) {
+ const auto& src_id = state_map_.LookupCondId(src);
+ if (!state_map_.IsDead(src_id)) {
non_dead_edge = e;
break;
}
@@ -1124,7 +1055,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
" has no non-dead inputs.");
}
- cond_state_map_.MarkDead(node);
+ state_map_.MarkDead(node);
delete_nodes_.push_back(node->id());
VLOG(5) << "removing redundant merge: " << node->name();
while (!node->out_edges().empty()) {
@@ -1149,16 +1080,33 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
// along one. The checking of predicate is based on the exact predicate
// (rather than boolean equivalence) and aimed at redundant switches as
// currently generated by gradient code.
+ StateMap::CondId dst_id = state_map_.LookupCondId(node);
+ if (state_map_.IsDead(dst_id)) return Status::OK();
+
+ BranchType b;
OutputTensor pred;
TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));
- auto dst_id = cond_state_map_.LookupId(node);
- BranchType b = cond_state_map_.FindBranchOf(dst_id, pred);
+
// Determine if we are already on a branch where the switch predicate is
- // true/false.
- if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
- return Status::OK();
+ // true/false. Consider both the data and predicate to determine if the
+ // node is redundant (skipping over identity node).
+ b = state_map_.FindBranchOf(dst_id, pred);
+ if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
+ OutputTensor val;
+ const Edge* e;
+ TF_RETURN_IF_ERROR(node->input_edge(0, &e));
+ val = OutputTensor(e->src(), e->src_output());
+ while (IsIdentity(val.node)) {
+ TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
+ val = OutputTensor(e->src(), e->src_output());
+ }
+ b = state_map_.FindBranchOf(dst_id, val);
+ if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
+ return Status::OK();
+ }
- VLOG(5) << "Redundant switch " << node->name();
+ VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " "
+ << DebugString(dst_id);
const Edge* value_edge;
TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
Node* val_node = value_edge->src();
@@ -1171,19 +1119,19 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
graph_->RemoveEdge(e);
if (switch_branch == Graph::kControlSlot) {
if (IsMerge(dst_node)) {
- auto id_or =
- JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node));
+ auto id_or = JoinCondStatesMerge(dst_node, dst_id,
+ state_map_.LookupCondId(dst_node));
TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
FormatNodeForError(*dst_node));
- cond_state_map_.ResetId(dst_node, id_or.ValueOrDie());
+ state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
} else {
auto id_or =
- JoinCondStatesNonMerge(dst_id, cond_state_map_.LookupId(dst_node));
+ JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
TF_RETURN_IF_ERROR(id_or.status());
- cond_state_map_.ResetId(dst_node, id_or.ValueOrDie());
+ state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
}
} else if (BranchType(switch_branch) != b) {
- cond_state_map_.MarkDead(dst_node);
+ state_map_.MarkDead(dst_node);
delete_nodes_.push_back(dst_node->id());
continue;
}
@@ -1195,20 +1143,47 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
return Status::OK();
}
-Status FunctionalizeCond::DetermineCondStates(
- std::vector<Node*> rev_topo_order) {
+Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
// The state that is propagated along the given edge.
for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
Node* dst = *it;
TF_RETURN_IF_ERROR(DetermineCondState(dst));
+ TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));
- VLOG(5) << dst->name() << " :: " << cond_state_map_.CondStateToString(dst);
+ VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
+ << " @ " << state_map_.AncestorStateToString(dst);
+ if (VLOG_IS_ON(10)) DumpGraphWithCondState("cond_it");
}
return Status::OK();
}
+Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
+ StateMap::AncestorId id = nullptr;
+ StateMap::AncestorState state;
+
+ auto insert = [&](StateMap::AncestorId id, Node* src) {
+ auto other_id = state_map_.LookupAncestorId(src);
+ if (other_id != id && other_id != nullptr) {
+ state.insert(other_id->begin(), other_id->end());
+ }
+ if (IsSwitch(src) || IsMerge(src)) {
+ state.insert(src);
+ }
+ return state_map_.GetAncestorId(state);
+ };
+
+ // Compute the union of all the switch/merge nodes that affect the input of
+ // dst.
+ for (auto e : dst->in_edges()) {
+ Node* src = e->src();
+ id = insert(id, src);
+ }
+ state_map_.ResetAncestorId(dst, id);
+ return Status::OK();
+}
+
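A minimal self-contained sketch of the pattern DetermineAncestorState uses above: union the inputs' interned ancestor sets, add any input that is itself a Switch or Merge, and intern the result so equal sets compare by pointer. Plain ints stand in for node pointers here; this is illustrative only and not the TensorFlow code.

    #include <set>
    #include <vector>

    using AncestorState = std::set<int>;      // stand-in for std::set<Node*>
    using AncestorId = const AncestorState*;  // interned: equal sets share one pointer

    std::set<AncestorState>& InternTable() {
      static std::set<AncestorState> table;   // element addresses are stable in std::set
      return table;
    }

    AncestorId Intern(const AncestorState& s) { return &*InternTable().insert(s).first; }

    // Union the interned ancestor sets of a node's inputs and add the inputs that
    // are themselves switch/merge nodes (given here just by id).
    AncestorId UnionAncestors(const std::vector<AncestorId>& input_states,
                              const std::vector<int>& switch_or_merge_inputs) {
      AncestorState state;
      for (AncestorId id : input_states)
        if (id != nullptr) state.insert(id->begin(), id->end());
      for (int n : switch_or_merge_inputs) state.insert(n);
      return Intern(state);
    }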
void FunctionalizeCond::DeleteReachableNodes() {
// Delete all nodes that have been extracted or are reachable from
// deleted/dead nodes. The input and outgoing edges should have already been
@@ -1239,16 +1214,8 @@ void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
inner_to_outer_merge_order.reserve(merge_order->size());
for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
Node* merge = *it;
- CondStateMap::CondId id = cond_state_map_.LookupId(merge);
- int depth = 0;
- for (auto cond_node_it = id->begin(); cond_node_it != id->end();
- ++cond_node_it) {
- if (cond_node_it->type == CondStateMap::CondNode::Type::kSwitch &&
- (cond_node_it->branch == BranchType::kThenBranch ||
- cond_node_it->branch == BranchType::kElseBranch)) {
- ++depth;
- }
- }
+ StateMap::CondId id = state_map_.LookupCondId(merge);
+ int depth = id != nullptr ? id->size() : 0;
inner_to_outer_merge_order.emplace_back(depth, merge);
}
std::stable_sort(
@@ -1271,10 +1238,10 @@ Status FunctionalizeCond::FunctionalizeInternal() {
// determine deeper equivalence). We shall refer to this structure as the
// CondState;
// 3. Sort the merge nodes by nesting depth;
- // 4. Extract merge nodes together that have the same CondState and whose
- // input nodes have the same state from the innermost to the outermost into
- // IfOps; Note: In the above only nodes paths that converge to a merge node
- // will be considered for removal.
+ // 4. Extract together the merge nodes that have the same CondState and
+ // AncestorState, from the innermost to the outermost, into IfOps;
+ // Note: in the above, only nodes that feed into a merge node will be
+ // considered for functionalization.
// Perform a DFS over the graph and
// * Determine the reverse topological order of the nodes (there should be no
@@ -1306,40 +1273,40 @@ Status FunctionalizeCond::FunctionalizeInternal() {
return Status::OK();
}
- TF_RETURN_IF_ERROR(DetermineCondStates(std::move(rev_topo_order)));
-
+ TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id");
// Sort the merge nodes from innermost outwards.
SortMergeNodes(&merge_order);
- // Extract from innermost out.
- for (auto it = merge_order.begin(); it != merge_order.end(); ++it) {
- Node* merge = *it;
- auto id = cond_state_map_.LookupId(merge);
- if (cond_state_map_.IsDead(id)) continue;
-
- // Construct a Conditional with the predicate of the merge (which is the
- // last entry of the CondState for the merge) and this as parent.
- DCHECK(id->back().predicate.node != nullptr);
- Conditional cond(id->back().predicate, this, &cond_state_map_);
- TF_RETURN_IF_ERROR(cond.AddMerge(merge));
-
- // Find all merge nodes with the same CondId. This is done repeatedly as
- // the CondId can change due replaced conditionals. E.g., the one branch
- // could previously have had a conditional nested in it, and so would have
- // had CondState with sub-state [switch(p,b),m] (where p is some predicate),
- // post removing the nested conditional that sub-state would no longer be
- // path of the propagated state along that path.
- auto end = merge_order.end();
- for (auto merge_candidate_it = std::next(it); merge_candidate_it != end;
- ++merge_candidate_it) {
- auto merge_candidate_it_id =
- cond_state_map_.LookupId(*merge_candidate_it);
- if (merge_candidate_it_id != id) continue;
- TF_RETURN_IF_ERROR(cond.AddMerge(*merge_candidate_it));
+ // Cluster merge nodes by CondId and AncestorId in order of nesting.
+ using ClusterPair = std::pair<StateMap::CondId, StateMap::AncestorId>;
+ std::deque<std::vector<Node*>> merge_clusters;
+ std::map<ClusterPair, int> merge_cluster_index;
+ for (Node* merge : merge_order) {
+ auto cond_id = state_map_.LookupCondId(merge);
+ if (state_map_.IsDead(cond_id)) continue;
+
+ ClusterPair key =
+ std::make_pair(cond_id, state_map_.LookupAncestorId(merge));
+ auto idx = merge_cluster_index.find(key);
+ if (idx == merge_cluster_index.end()) {
+ merge_cluster_index[key] = merge_clusters.size();
+ merge_clusters.push_back({merge});
+ } else {
+ merge_clusters[idx->second].emplace_back(merge);
}
+ }
+ // Extract the conditionals from innermost to outermost. Extracting from
+ // innermost to outermost enables the extraction pass to stop once it
+ // encounters a Switch node instead of having to keep track of Switch/Merge
+ // nodes seen.
+ for (const auto& cluster : merge_clusters) {
+ // Construct a Conditional with the predicate of the merge.
+ Conditional cond(merge_to_predicate_.at(cluster.front()), this,
+ &state_map_);
+ for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_));
if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
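The cluster map above keys each merge by the pair (CondId, AncestorId), so merges that agree on both are extracted into the same If. A small sketch of that grouping step in isolation, with opaque stand-in types and assuming the merges arrive already sorted innermost-first (illustrative only):

    #include <deque>
    #include <map>
    #include <utility>
    #include <vector>

    struct Node;  // opaque stand-in for a graph node
    using CondId = const void*;
    using AncestorId = const void*;
    using ClusterKey = std::pair<CondId, AncestorId>;

    std::deque<std::vector<Node*>> ClusterMerges(
        const std::vector<std::pair<Node*, ClusterKey>>& merges) {
      std::deque<std::vector<Node*>> clusters;
      std::map<ClusterKey, int> cluster_index;
      for (const auto& [merge, key] : merges) {
        auto it = cluster_index.find(key);
        if (it == cluster_index.end()) {
          cluster_index[key] = clusters.size();   // new (CondId, AncestorId) combination
          clusters.push_back({merge});
        } else {
          clusters[it->second].push_back(merge);  // joins an existing cluster
        }
      }
      return clusters;
    }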
@@ -1359,7 +1326,9 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
for (Node* n : graph_->nodes()) {
n->ClearAttr(kCondGroupDebugAttr);
- n->AddAttr(kCondGroupDebugAttr, cond_state_map_.CondStateToString(n));
+ n->AddAttr(kCondGroupDebugAttr,
+ strings::StrCat(state_map_.CondStateToString(n), "_",
+ state_map_.AncestorStateToString(n)));
}
LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
<< dump_graph::DumpGraphToFile(
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h
index 86436011c6..28301150ea 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.h
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.h
@@ -43,105 +43,88 @@ enum class BranchType {
kNeither = 3,
};
-// CondStateMap is responsible for mapping from each graph Node to a CondState,
-// where each CondState is the array of CondNodes (corresponding to switch,
-// merge or dead states) as described below. For efficiency, this class interns
-// the CondState, so that CondState equality comparisons are simply pointer
+// StateMap is responsible for mapping from each graph Node to
+// * a CondState, where each CondState is a map from predicate to branch (i.e.,
+// what predicates have to hold or not hold).
+// * an AncestorState, where each AncestorState is a set of switch/merge nodes
+// that are ancestors of the node in the graph;
+// For efficiency, this class interns the CondState (AncestorState), so that
+// CondState (AncestorState) equality comparisons are simply pointer
// comparisons.
-class CondStateMap {
+class StateMap {
public:
- explicit CondStateMap(Graph* graph);
-
- // Represents an entry in the CondState. An entry can either be the
- // switch (along with predicate), merge, or dead:
- // * switch node indicates a node that is executed along a branch with the
- // given predicate - a branch can be then, else or both;
- // * merge node indicates that the node is executed as output of a merge;
- // * dead indicates that this node can never be executed;
- struct CondNode {
- enum class Type { kSwitch = 1, kMerge = 2, kDead = 3 };
-
- CondNode(Type type, Node* switch_node = nullptr,
- BranchType branch = BranchType::kNeither);
-
- string ToString() const;
- bool operator==(const CondNode& other) const;
- bool operator!=(const CondNode& other) const;
-
- // Type of node.
- Type type;
-
- // Predicate and branch, only used when type is kSwitch.
- OutputTensor predicate;
- BranchType branch;
+ explicit StateMap(Graph* graph);
+
+ // Compare two OutputTensors by (node id, index).
+ struct OutputTensorLess {
+ bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const;
};
- // A node in the graph is executed when multiple conditions hold. The order
- // represents the nesting of the predicates that hold and is used when
- // extracting the nested conditionals.
- using CondState = std::vector<CondNode>;
+ // A node in the graph is executed when multiple conditions hold. Keep track
+ // of the predicates that must hold for a node to execute.
+ using CondState = std::map<OutputTensor, BranchType, OutputTensorLess>;
// Every unique ID is mapped to a CondState.
using CondId = const CondState*;
+ // Keep track of which switch/merge nodes feed into a node's values.
+ using AncestorState = std::set<Node*>;
+
+ // Every unique ID is mapped to an AncestorState.
+ using AncestorId = const AncestorState*;
+
// Returns the CondId for a given node.
- CondId LookupId(const Node* node) const;
+ CondId LookupCondId(const Node* node) const;
// Returns the unique CondId for CondState.
- CondId GetUniqueId(const CondState& state);
+ CondId GetCondId(const CondState& state);
+
+ // Resets the CondId for a given node.
+ void ResetCondId(const Node* node, CondId id);
+
+ // Returns the AncestorId for a given node.
+ AncestorId LookupAncestorId(const Node* node) const;
+
+ // Returns the unique AncestorId for an AncestorState.
+ AncestorId GetAncestorId(const AncestorState& state);
+
+ // Resets the AncestorId for a given node.
+ void ResetAncestorId(const Node* node, AncestorId id);
// Returns the CondState for a Node.
// REQUIRES: node has a non-empty CondState.
const CondState& LookupState(const Node* node) const;
- // Resets the CondId for a given node.
- void ResetId(const Node* node, CondId id);
-
// Marks `node` as dead.
void MarkDead(const Node* node);
// Determine branch execution of CondState.
BranchType FindBranchOf(CondId id, OutputTensor predicate) const;
- // Enum to represent whether one cond flow state contains another.
- enum ContainsResult {
- kIncomparable,
- kEqual,
- kLhsContainsRhs,
- kRhsContainsLhs
- };
-
- // Returns whether the lhs CondState holds wherever rhs CondState hols. I.e.,
- // [(p,t)] contains [(p,t), (r,t)].
- ContainsResult LhsHoldsWhereverRhsHolds(CondId lhs, CondId rhs);
-
// Returns textual representation of node's CondState.
string CondStateToString(const Node* node) const;
string CondStateToString(CondId id) const;
+ // Returns textual representation of node's AncestorState.
+ string AncestorStateToString(const Node* node) const;
+
// Returns whether the cond state is the dead state.
bool IsDead(CondId id) const;
// Returns whether the cond state is the empty state.
bool IsEmpty(CondId id) const;
- // Computes the predicates that have to hold for a node to execute and returns
- // whether it was possible to determine the predicates that must hold. `scope`
- // is populated with these predicates. Scope differs from state in that it
- // does not include merge and both nodes.
- bool ScopeIn(CondId id, CondId* scope);
-
private:
- // Hash for CondNode and CondState.
- struct CondHash {
- size_t operator()(const CondNode& item) const;
- size_t operator()(const CondState& vec) const;
+ // Hash for CondState and AncestorState.
+ struct Hash {
+ size_t operator()(const CondState& map) const;
+ size_t operator()(const AncestorState& map) const;
};
// Set to keep track of unique CondStates.
// Pointers to the entries in the unordered set are used as identifiers:
// unordered_set guarantees that the pointers remain the same.
- std::unordered_set<CondState, CondHash> condstate_set_;
+ std::unordered_set<CondState, Hash> condstate_set_;
// Mapping from Node id to CondId.
std::vector<CondId> node_to_condid_map_;
@@ -150,7 +133,12 @@ class CondStateMap {
// from Node id in the original graph to the CondId, but there will be nodes
// added to the original graph (such as If nodes) whose CondState needs to be
// tracked too.
- std::unordered_map<int, CondId> added_node_mapping_;
+ std::unordered_map<int, CondId> added_node_condid_mapping_;
+
+ // AncestorId variants of the CondId members.
+ std::unordered_set<AncestorState, Hash> ancestorstate_set_;
+ std::vector<AncestorId> node_to_ancestorid_map_;
+ std::unordered_map<int, AncestorId> added_node_ancestorid_mapping_;
// Identifier of the dead flow state. The empty flow state is represented with
// a nullptr.
@@ -173,7 +161,8 @@ class FunctionalizeCond {
// Add an If node to the graph defined by def that will, amongst other things, replace
// replacee in the graph.
- xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee);
+ xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee,
+ const OutputTensor& predicate);
// Propagates the state of a newly inserted node.
Status PropagateUpdatedState(const Node* replacee);
@@ -185,35 +174,42 @@ class FunctionalizeCond {
FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
// Performs the actual cond functionalization. Iterate over groups of merge
- // nodes (linked by common predicate & CondIds of the incomming edges),
- // from innermost to outermost, and extract into If nodes.
+ // nodes (linked by common predicates & ancestor IDs), from innermost to
+ // outermost, and extract into If nodes.
Status FunctionalizeInternal();
// Returns the forward flow state propagated along edge `e`.
- // This may modify cond_state_map_.
- CondStateMap::CondId StateAlongEdge(const Edge* e);
+ // This may modify state_map_.
+ StateMap::CondId StateAlongEdge(const Edge* e);
- // Determines the CondState of all the nodes in the given vector where
- // the input is expected in reverse topological order.
- // This populates the cond_state_map_.
- Status DetermineCondStates(std::vector<Node*> rev_topo_order);
+ // Determines the CondState and AncestorState of all the nodes in the given
+ // vector where the input is expected in reverse topological order.
+ // This populates the state_map_.
+ Status DetermineStates(std::vector<Node*> rev_topo_order);
// Determine the CondState for a given node using the incoming edges
// to the node. Note: it is expected that this node's CondState is only
// determined once its input's CondState is.
- Status DetermineCondState(Node* dst);
+ Status DetermineCondState(Node* dst) {
+ if (IsMerge(dst)) return DetermineCondStateMerge(dst);
+ return DetermineCondStateNonMerge(dst);
+ }
// Helper functions for DetermineCondState.
+ Status DetermineCondStateNonMerge(Node* dst);
Status DetermineCondStateMerge(Node* dst);
- // Helper functions for DetermineCondStates. Determines the dst node's
- // CondState by joining the src and dst's CondState where either
- // the dst node is a merge or not.
- // These may modify cond_state_map_.
- xla::StatusOr<CondStateMap::CondId> JoinCondStatesMerge(
- CondStateMap::CondId src, CondStateMap::CondId dst);
- xla::StatusOr<CondStateMap::CondId> JoinCondStatesNonMerge(
- CondStateMap::CondId src, CondStateMap::CondId dst);
+ // Determines the dst node's CondState by joining the src and dst's CondState,
+ // depending on whether the dst node is a merge or not.
+ // These may modify state_map_.
+ xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* merge,
+ StateMap::CondId src,
+ StateMap::CondId dst);
+ xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src,
+ StateMap::CondId dst);
+
+ // Determines which switch/merge nodes are ancestors of this node.
+ Status DetermineAncestorState(Node* dst);
// Checks if a merge node is redundant and if so removes it from the graph.
Status RemoveRedundantMerge(Node* node);
@@ -228,9 +224,13 @@ class FunctionalizeCond {
// Deletes all nodes in/consumers of `delete_nodes_`.
void DeleteReachableNodes();
- // Member used to unique the CondState to a unique CondId and keep track of
- // CondState/CondId per Node.
- CondStateMap cond_state_map_;
+ // Member used to unique the CondState to a unique CondId (AncestorState to a
+ // unique AncestorId) and keep track of CondState/CondId
+ // (AncestorState/AncestorId) per Node.
+ StateMap state_map_;
+
+ // Mapping from merge nodes to predicate.
+ std::unordered_map<Node*, OutputTensor> merge_to_predicate_;
// Nodes to be deleted.
std::deque<int> delete_nodes_;
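To make the new header concrete: a CondState is now an ordered map from predicate to BranchType (keyed by OutputTensorLess, i.e. by node id and output index), and a CondId of nullptr stands for the empty state. A minimal sketch of the branch lookup that FindBranchOf performs, with a simplified (node id, output index) pair in place of OutputTensor (illustrative only, not the real class):

    #include <map>
    #include <utility>

    enum class BranchType { kElseBranch, kThenBranch, kBoth, kNeither };

    using Predicate = std::pair<int, int>;              // stand-in for (node id, output index)
    using CondState = std::map<Predicate, BranchType>;  // ordered, like the OutputTensorLess map
    using CondId = const CondState*;                     // nullptr represents the empty state

    // Which branch, if any, does the state `id` pin `pred` to?
    BranchType FindBranchOf(CondId id, const Predicate& pred) {
      if (id == nullptr) return BranchType::kNeither;
      auto it = id->find(pred);
      return it == id->end() ? BranchType::kNeither : it->second;
    }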
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
index a27f889392..b0aabd63bb 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
@@ -37,28 +37,23 @@ class FunctionalizeCondTest : public ::testing::Test {
flib_def_.get()));
}
- CondStateMap::CondId GetUniqueId(
- const CondStateMap::CondStateMap::CondState& state) {
- return fc_->cond_state_map_.GetUniqueId(state);
+ StateMap::CondId GetUniqueId(const StateMap::StateMap::CondState& state) {
+ return fc_->state_map_.GetCondId(state);
}
- xla::StatusOr<CondStateMap::CondId> JoinCondStatesNonMerge(
- CondStateMap::CondId src, CondStateMap::CondId dst) {
- return fc_->JoinCondStatesNonMerge(src, dst);
- }
-
- xla::StatusOr<CondStateMap::CondId> JoinCondStatesMerge(
- CondStateMap::CondId src, CondStateMap::CondId dst) {
- return fc_->JoinCondStatesMerge(src, dst);
+ string GetString(const StateMap::StateMap::CondId id) {
+ return fc_->state_map_.CondStateToString(id);
}
- bool ScopeIn(CondStateMap::CondId ff, CondStateMap::CondId* scope) {
- return fc_->cond_state_map_.ScopeIn(ff, scope);
+ xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src,
+ StateMap::CondId dst) {
+ return fc_->JoinCondStatesNonMerge(src, dst);
}
- CondStateMap::ContainsResult LhsHoldsWhereverRhsHolds(
- CondStateMap::CondId lhs, CondStateMap::CondId rhs) {
- return fc_->cond_state_map_.LhsHoldsWhereverRhsHolds(lhs, rhs);
+ xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* n,
+ StateMap::CondId src,
+ StateMap::CondId dst) {
+ return fc_->JoinCondStatesMerge(n, src, dst);
}
FunctionDefLibrary fdef_lib_;
@@ -69,50 +64,6 @@ class FunctionalizeCondTest : public ::testing::Test {
namespace {
-TEST_F(FunctionalizeCondTest, ScopeIn) {
- Tensor pred_tensor(DT_BOOL, TensorShape());
- pred_tensor.flat<bool>().setZero();
- Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred");
- Tensor val_tensor(DT_INT32, TensorShape());
- val_tensor.flat<int>().setZero();
- Node* val = test::graph::Constant(graph_.get(), val_tensor, "val");
- Node* s = test::graph::Switch(graph_.get(), val, pred);
-
- {
- CondStateMap::CondStateMap::CondState ss;
- ss.emplace_back(CondStateMap::CondNode(
- CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch));
- CondStateMap::CondId id = GetUniqueId(ss);
- CondStateMap::CondId scope;
- ASSERT_TRUE(ScopeIn(id, &scope));
- ASSERT_TRUE(id == scope);
- }
-
- CondStateMap::CondState empty;
- {
- CondStateMap::CondState ss;
- ss.emplace_back(CondStateMap::CondNode(
- CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth));
- ss.emplace_back(
- CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge));
- CondStateMap::CondId id = GetUniqueId(ss);
- CondStateMap::CondId scope_1;
- ASSERT_TRUE(ScopeIn(id, &scope_1));
- ASSERT_TRUE(scope_1 == GetUniqueId(empty));
- ASSERT_TRUE(id != scope_1);
-
- ss.clear();
- ss.emplace_back(CondStateMap::CondNode(
- CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth));
- id = GetUniqueId(ss);
- CondStateMap::CondId scope_2;
- ASSERT_TRUE(ScopeIn(id, &scope_2));
-
- ASSERT_TRUE(LhsHoldsWhereverRhsHolds(scope_1, scope_2) ==
- CondStateMap::ContainsResult::kLhsContainsRhs);
- }
-}
-
TEST_F(FunctionalizeCondTest, JoinCondStates) {
Tensor pred_tensor(DT_BOOL, TensorShape());
pred_tensor.flat<bool>().setZero();
@@ -120,22 +71,18 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) {
Tensor val_tensor(DT_INT32, TensorShape());
val_tensor.flat<int>().setZero();
Node* val = test::graph::Constant(graph_.get(), val_tensor, "val");
- Node* s = test::graph::Switch(graph_.get(), val, pred);
+ Node* m = test::graph::Merge(graph_.get(), val, val);
- CondStateMap::CondId empty = GetUniqueId({});
-
- CondStateMap::CondId then_branch;
+ StateMap::CondId then_branch;
{
- CondStateMap::CondState ss;
- ss.emplace_back(CondStateMap::CondNode(
- CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch));
+ StateMap::CondState ss;
+ ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kThenBranch));
then_branch = GetUniqueId(ss);
}
- CondStateMap::CondId else_branch;
+ StateMap::CondId else_branch;
{
- CondStateMap::CondState ss;
- ss.emplace_back(CondStateMap::CondNode(
- CondStateMap::CondNode::Type::kSwitch, s, BranchType::kElseBranch));
+ StateMap::CondState ss;
+ ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kElseBranch));
else_branch = GetUniqueId(ss);
}
@@ -144,39 +91,14 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) {
EXPECT_TRUE(errors::IsInvalidArgument(status));
// Merge between then and else branch.
- auto joined_or = JoinCondStatesMerge(then_branch, else_branch);
+ auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch);
TF_EXPECT_OK(joined_or.status());
- CondStateMap::CondId joined = joined_or.ValueOrDie();
+ StateMap::CondId joined = joined_or.ValueOrDie();
// Merge between then branch and both branch.
auto t = JoinCondStatesNonMerge(then_branch, joined);
// Note: this is OK in terms of constraint predication, but
TF_EXPECT_OK(t.status());
-
- // Post merge the propagated forward flow state has an additional merge.
- CondStateMap::CondId post_merge;
- {
- CondStateMap::CondState ss;
- ss = *joined;
- ss.emplace_back(
- CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge));
- post_merge = GetUniqueId(ss);
- }
-
- t = JoinCondStatesNonMerge(post_merge, joined);
- TF_EXPECT_OK(t.status());
- EXPECT_TRUE(joined == t.ValueOrDie());
-
- // No predicate that results in two paths predicated on different conditions
- // merge.
- t = JoinCondStatesMerge(post_merge, joined);
- EXPECT_FALSE(t.ok());
-
- // Post the merge we are effectively in the root scope and merging should
- // result in the more restrictive post merge state.
- t = JoinCondStatesNonMerge(post_merge, empty);
- TF_EXPECT_OK(t.status());
- EXPECT_TRUE(post_merge == t.ValueOrDie());
}
} // namespace
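For intuition on what the surviving JoinCondStates tests exercise: a non-merge join has to combine two predicate maps without contradiction. The sketch below shows one plausible union-with-conflict-check in simplified types; it is illustrative only and not necessarily the exact rule the pass implements:

    #include <map>
    #include <optional>
    #include <utility>

    enum class BranchType { kElseBranch, kThenBranch, kBoth, kNeither };
    using Predicate = std::pair<int, int>;              // stand-in for OutputTensor
    using CondState = std::map<Predicate, BranchType>;

    // Union of two states; nullopt when they pin the same predicate to different branches.
    std::optional<CondState> JoinNonMerge(const CondState& a, const CondState& b) {
      CondState joined = a;
      for (const auto& [pred, branch] : b) {
        auto [it, inserted] = joined.insert({pred, branch});
        if (!inserted && it->second != branch) return std::nullopt;  // contradictory constraints
      }
      return joined;
    }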
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index e639028ccd..7f2125f74c 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -990,8 +990,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferConvolveShape(
- lhs_shape, rhs_shape, instr.window(),
- dimension_numbers, feature_group_count));
+ lhs_shape, rhs_shape, feature_group_count,
+ instr.window(), dimension_numbers));
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
instr.set_feature_group_count(feature_group_count);
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index a4854f593f..8a05d1b0d7 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -564,18 +564,22 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
dim2.set_base_dilation(lhs_dilation.second);
*window.add_dimensions() = dim2;
- const Shape& shape =
- ShapeInference::InferConvolveShape(lhs_literal->shape(),
- rhs_literal->shape(), window, dnums)
- .ConsumeValueOrDie();
+ const Shape& shape = ShapeInference::InferConvolveShape(
+ lhs_literal->shape(), rhs_literal->shape(),
+ /*feature_group_count=*/1, window, dnums)
+ .ConsumeValueOrDie();
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfigProto::DEFAULT);
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, precision_config));
HloModuleConfig config;
HloModule module("ReferenceUtil", config);
auto computation = module.AddEntryComputation(b.Build());
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 26b48cf419..f6cfac6537 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3289,6 +3289,8 @@ tf_cc_test(
size = "small",
srcs = ["hlo_parser_test.cc"],
deps = [
+ ":hlo",
+ ":hlo_casting_utils",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:window_util",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 7c078f07d7..3d18fe3be2 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -950,9 +950,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
new_dot_rhs = rhs_slice;
}
- auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums));
- new_dot->set_precision_config(dot.precision_config());
+ auto* new_dot = computation_->AddInstruction(
+ HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
+ new_dot_dnums, dot.precision_config()));
if (add_result) {
add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -1053,9 +1053,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
const int n =
right_operand->shape().dimensions(1 - rhs_contracting_dimension);
auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
- auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
- memoized_shape, left_operand, right_operand, dnums));
- memoized_inst->set_precision_config(dot->precision_config());
+ auto* memoized_inst = computation_->AddInstruction(
+ HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
+ dnums, dot->precision_config()));
// Get pair {start, 0} or {0, start}.
HloInstruction* original_start_indices =
lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
@@ -1151,9 +1151,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
- rhs->mutable_operand(0), lhs->mutable_operand(0),
- dot_dimension_numbers));
- new_dot->set_precision_config(dot->precision_config());
+ rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
+ dot->precision_config()));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
}
@@ -2477,8 +2476,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_dimension_numbers.add_lhs_contracting_dimensions(1);
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers));
- dot->set_precision_config(convolution->precision_config());
+ dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
+ convolution->precision_config()));
return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 43a891e4fa..019840b476 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1013,6 +1013,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
1);
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -1044,7 +1051,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
dim->set_window_reversal(false);
// Create add computation.
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
+ ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -2260,9 +2268,11 @@ TEST_P(ConvInputPaddingTest, DoTest) {
.ValueOrDie();
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
- window, dnums)
+ /*feature_group_count=*/1, window,
+ dnums)
.ValueOrDie(),
- lhs_pad, filter, window, dnums));
+ lhs_pad, filter, /*feature_group_count=*/1, window, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -2368,9 +2378,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
.ValueOrDie();
auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
- window, dnums)
+ /*feature_group_count=*/1, window,
+ dnums)
.ValueOrDie(),
- input, rhs_pad, window, dnums));
+ input, rhs_pad, /*feature_group_count=*/1, window, dnums,
+ DefaultPrecisionConfig(2)));
// Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
// after the transformation.
@@ -2522,8 +2534,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
HloInstruction* filter =
b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
- b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
- window, dnums));
+ b.AddInstruction(HloInstruction::CreateConvolve(
+ out_shape, input, filter,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewModule();
@@ -2901,7 +2914,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums,
+ DefaultPrecisionConfig(2)));
std::unique_ptr<HloComputation> dot_computation(builder.Build());
HloComputation::Builder call_builder(TestName() + ".Call");
@@ -3253,8 +3267,8 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -3329,8 +3343,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
dot_dnums.add_rhs_contracting_dimensions(0);
Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3393,8 +3407,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
dot_dnums.add_rhs_contracting_dimensions(0);
Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3511,8 +3525,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 dot_row_size = 1;
int64 dot_col_size = spec.n;
Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3581,8 +3595,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 dot_row_size = spec.m;
int64 dot_col_size = 1;
Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index a16b85a0a5..eda026ac56 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -63,8 +63,8 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size());
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
- MakeDotHlo(new_lhs, new_rhs, new_dim_numbers));
- new_dot->set_precision_config(batch_dot->precision_config());
+ MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
+ batch_dot->precision_config()));
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
MakeReshapeHlo(batch_dot->shape(), new_dot));
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index b08705d4c2..d480d72297 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -308,8 +308,11 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 8bd1533972..7398f105a0 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1490,10 +1490,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot_ab = builder.AddInstruction(
- HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums));
- auto dot_bc = builder.AddInstruction(
- HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums));
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
+ auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
+ shape_2x4, param_a, param_b, dot_dnums, precision_config));
+ auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
+ shape_3x4, param_b, param_c, dot_dnums, precision_config));
builder.AddInstruction(
HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 9c81a86bbb..0826380f65 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -223,8 +223,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
filter_mask, expanded_filter, zero_filter));
auto new_convolution = HloInstruction::CreateConvolve(
convolution->shape(), convolution->mutable_operand(0), new_filter,
- convolution->window(), dim_numbers, /*feature_group_count=*/1);
- new_convolution->set_precision_config(convolution->precision_config());
+ /*feature_group_count=*/1, convolution->window(), dim_numbers,
+ convolution->precision_config());
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
convolution, std::move(new_convolution)));
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index 098ce17a56..2d9978404c 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -130,9 +130,9 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
// change the dimension mapping but not the dimension sizes. For
// example, input height and width are the same as before the reshapes.
HloInstruction* new_conv = module->entry_computation()->AddInstruction(
- HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
- hlo->window(), new_dnums));
- new_conv->set_precision_config(hlo->precision_config());
+ HloInstruction::CreateConvolve(
+ new_conv_shape, new_input, new_kernel, hlo->feature_group_count(),
+ hlo->window(), new_dnums, hlo->precision_config()));
// Reshape the output back to the shape of the original convolution.
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 547d4c696d..616c453750 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -56,6 +56,13 @@ class ConvCanonicalizationTest : public HloTestBase {
static constexpr int kOutputFeatureCount = 64;
};
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
auto builder = HloComputation::Builder(TestName());
// The input dimensions are in CNHW order.
@@ -84,7 +91,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(
F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
- input, kernel, conv_window_, dnums));
+ input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -146,7 +154,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(
F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}),
- input, kernel, conv_window_, dnums));
+ input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 284929ca07..6bd0a2dd90 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -38,7 +38,11 @@ std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
+ return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
+ precision_config);
}
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index 09cb10d6ee..b2ba261790 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -134,9 +134,9 @@ Status DecomposeBatchDot(HloInstruction* dot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot(
- dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums));
- dot_r2->set_precision_config(dot->precision_config());
+ auto dot_r2 = computation->AddInstruction(
+ HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2,
+ dot_dnums, dot->precision_config()));
// Reshape Dot to R3 so we can concat along batch dimension.
auto dot_r3 = computation->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index 46c23db465..9b46bfc098 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -95,6 +95,13 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase {
ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
};
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
HloComputation::Builder builder(TestName());
HloInstruction* activations =
@@ -107,12 +114,12 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
conv_window.mutable_dimensions(1)->set_size(2);
conv_window.mutable_dimensions(1)->set_window_dilation(2);
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(activations->shape(),
- gradients->shape(), conv_window,
- tf_default_dnums_for_backward_filter_)
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_filter_)
.ConsumeValueOrDie(),
- activations, gradients, conv_window,
- tf_default_dnums_for_backward_filter_));
+ activations, gradients, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -135,12 +142,12 @@ TEST_F(CudnnConvolutionRewriterTest,
Window conv_window = default_conv_window_;
conv_window.mutable_dimensions(1)->set_size(3);
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(activations->shape(),
- gradients->shape(), conv_window,
- tf_default_dnums_for_backward_filter_)
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_filter_)
.ConsumeValueOrDie(),
- activations, gradients, conv_window,
- tf_default_dnums_for_backward_filter_));
+ activations, gradients, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -170,7 +177,8 @@ TEST_F(CudnnConvolutionRewriterTest,
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -200,7 +208,8 @@ TEST_F(CudnnConvolutionRewriterTest,
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -228,7 +237,8 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -272,13 +282,14 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
- /*rhs=*/reverse_kernel, conv_window, conv_dnums));
+ /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window,
+ conv_dnums, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(),
- ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window, conv_dnums)
- .ValueOrDie()));
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(),
+ /*feature_group_count=*/1, conv_window, conv_dnums)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -319,11 +330,11 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
- conv_window,
+ /*feature_group_count=*/1, conv_window,
tf_default_dnums_for_backward_input_)
.ConsumeValueOrDie(),
- /*lhs=*/output, /*rhs=*/kernel, conv_window,
- tf_default_dnums_for_backward_input_));
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -350,12 +361,13 @@ TEST_F(CudnnConvolutionRewriterTest,
1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
- default_conv_window_,
- tf_default_dnums_for_backward_input_)
+ ShapeInference::InferConvolveShape(
+ output->shape(), kernel->shape(), /*feature_group_count=*/1,
+ default_conv_window_, tf_default_dnums_for_backward_input_)
.ConsumeValueOrDie(),
- /*lhs=*/output, /*rhs=*/kernel, default_conv_window_,
- tf_default_dnums_for_backward_input_));
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
+ default_conv_window_, tf_default_dnums_for_backward_input_,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -402,13 +414,15 @@ TEST_F(CudnnConvolutionRewriterTest,
}
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -449,13 +463,15 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
}
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -502,13 +518,15 @@ TEST_F(CudnnConvolutionRewriterTest,
forward_conv_col_dim->set_base_dilation(2);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
const HloComputation* entry_computation =
@@ -554,13 +572,15 @@ TEST_F(CudnnConvolutionRewriterTest,
forward_conv_col_dim->set_padding_high(2);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index a2be89511b..0a49d85c6d 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -112,8 +112,11 @@ std::unique_ptr<HloModule> MakeBigGraph() {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums));
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfigProto::DEFAULT);
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ vshape, clamp, param_v0, dot_dnums, precision_config));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({dot, param_s, clamp}));
auto scalar = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 5f85f14565..576c5ff7a4 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -353,6 +353,13 @@ TEST_F(HeapSimulatorTest, BufferReusedOnce) {
(neg_buffer == output_buffer_1));
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(HeapSimulatorTest, MultiplyDot) {
auto builder = HloComputation::Builder(TestName());
auto paramA = builder.AddInstruction(
@@ -366,8 +373,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
// The buffer for dot is the output, and it cannot be shared with the buffer
// for mul, since dot isn't elementwise.
@@ -402,8 +409,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
@@ -440,10 +447,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot0 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
- auto dot1 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+ auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
// The buffer for dot1 is the output. No buffers can be shared. The buffer
// for mul is freed before the end, since it's no longer used after dot0
@@ -481,10 +488,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot0 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
- auto dot1 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+ auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index f7ed1b0316..a2c1ce34c6 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -601,8 +601,11 @@ TEST_F(HloComputationTest, Stringification) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -633,8 +636,11 @@ TEST_F(HloComputationTest, StringificationIndent) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -666,8 +672,11 @@ TEST_F(HloComputationTest, StringificationCanonical) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 19ffb465c0..a6ae0337a5 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -61,15 +61,18 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
}
StatusOr<HloInstruction*> MakeConvolveHlo(
- HloInstruction* lhs, HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
- TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape(
- lhs->shape(), rhs->shape(),
- window, dimension_numbers));
+ TF_ASSIGN_OR_RETURN(Shape convolve_shape,
+ ShapeInference::InferConvolveShape(
+ lhs->shape(), rhs->shape(), feature_group_count,
+ window, dimension_numbers));
return computation->AddInstruction(HloInstruction::CreateConvolve(
- convolve_shape, lhs, rhs, window, dimension_numbers));
+ convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+ precision_config));
}
StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
@@ -164,15 +167,17 @@ StatusOr<HloInstruction*> MakeConcatHlo(
HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
}
-StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers) {
+StatusOr<HloInstruction*> MakeDotHlo(
+ HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
- return computation->AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
+ return computation->AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dim_numbers, precision_config));
}
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index a1c4b374d1..1c82956907 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -48,8 +48,9 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeConvolveHlo(
- HloInstruction* lhs, HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config);
// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
@@ -97,8 +98,10 @@ StatusOr<HloInstruction*> MakeConcatHlo(
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
-StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers);
+StatusOr<HloInstruction*> MakeDotHlo(
+ HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config);
// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
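
Callers of these helpers now have to pass a precision configuration explicitly (and, for convolutions, a feature group count). A rough sketch of an adapted MakeDotHlo call site, assuming lhs, rhs, and dim_numbers are already set up as before; this is illustrative only and not part of the patch:

PrecisionConfigProto precision_config;
precision_config.mutable_operand_precision()->Resize(
    2, PrecisionConfigProto::DEFAULT);
TF_ASSIGN_OR_RETURN(HloInstruction * dot,
                    MakeDotHlo(lhs, rhs, dim_numbers, precision_config));
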
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index d1a96c10f8..62eea2b06c 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2334,8 +2334,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 8b2846e0c2..113fd18eae 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -51,6 +51,10 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
return FindOrDefault(instruction_to_domain_, instruction, -1);
}
+int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
+ return FindOrDie(domain_metadata_id_, instruction);
+}
+
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
// We only check operands, so we are sure to not process the empty domain from
@@ -93,6 +97,43 @@ Status HloDomainMap::Populate(HloComputation* computation) {
CreateDomain(instruction, instructions_post_order));
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
+ TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
+ return Status::OK();
+}
+
+Status HloDomainMap::PopulateDomainMetadataMap() {
+ auto hash = [](const DomainMetadata* m) { return m->Hash(); };
+ auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
+ return a->Matches(*b);
+ };
+ tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash),
+ decltype(equal)>
+ domain_metadata(1024, hash, equal);
+
+ for (auto& domain : instruction_domains_) {
+ int64 domain_metadata_id = -1;
+ if (!domain->enter_domains.empty()) {
+ const HloInstruction* domain_instruction = *domain->enter_domains.begin();
+ domain_metadata_id =
+ domain_metadata
+ .insert({&domain_instruction->user_side_metadata(),
+ domain_metadata.size() + 1})
+ .first->second;
+ } else if (!domain->exit_domains.empty()) {
+ const HloInstruction* domain_instruction = *domain->exit_domains.begin();
+ domain_metadata_id =
+ domain_metadata
+ .insert({&domain_instruction->operand_side_metadata(),
+ domain_metadata.size() + 1})
+ .first->second;
+ } else {
+ domain_metadata_id = 0;
+ }
+ TF_RET_CHECK(domain_metadata_id >= 0);
+ for (HloInstruction* instruction : domain->instructions) {
+ domain_metadata_id_[instruction] = domain_metadata_id;
+ }
+ }
return Status::OK();
}
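
PopulateDomainMetadataMap deduplicates metadata by keying a hash map on the DomainMetadata objects themselves, with the new Hash()/Matches() methods supplying hashing and equality, so every domain whose metadata matches gets the same id (0 is reserved for domains with neither enter nor exit instructions). A minimal standalone model of that numbering scheme, using std::unordered_map and a hypothetical Metadata type in place of the real classes:

#include <cstddef>
#include <cstdint>
#include <unordered_map>

// Hypothetical stand-in for DomainMetadata: hashable and comparable.
struct Metadata {
  size_t Hash() const { return static_cast<size_t>(value); }
  bool Matches(const Metadata& other) const { return value == other.value; }
  int value;
};

int main() {
  auto hash = [](const Metadata* m) { return m->Hash(); };
  auto equal = [](const Metadata* a, const Metadata* b) { return a->Matches(*b); };
  std::unordered_map<const Metadata*, int64_t, decltype(hash), decltype(equal)>
      ids(/*bucket_count=*/8, hash, equal);

  Metadata a{1}, b{1}, c{2};
  for (const Metadata* m : {&a, &b, &c}) {
    // insert() is a no-op for keys that compare equal, so a and b share id 1;
    // ids start at 1 because 0 is the sentinel for empty domains.
    ids.insert({m, static_cast<int64_t>(ids.size()) + 1});
  }
  // ids now maps a and b to 1, and c to 2.
}
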
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index 633109249a..56b557d7ce 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -69,6 +69,11 @@ class HloDomainMap {
// instruction is not found within any domain.
int64 GetDomainId(HloInstruction* instruction) const;
+ // Returns the unique id of the domain metadata for the domain the given
+ // instruction belongs to. The given instruction must not be a kDomain
+ // instruction since each domain instruction is associated with 2 domains.
+ int64 GetDomainMetadataId(HloInstruction* instruction) const;
+
private:
// Map used for representing instruction ordering, i.e.
// order_map[a] < order_map[b] means a must be ordered before b.
@@ -109,9 +114,14 @@ class HloDomainMap {
const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order);
+ // Populates domain_metadata_id_ that maps each HloInstruction to the unique
+  // ID of its associated domain metadata.
+ Status PopulateDomainMetadataMap();
+
string domain_kind_;
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
+ tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
index 6c142ee474..302807f816 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -72,6 +72,9 @@ class DomainMetadata {
// two matches.
virtual bool Matches(const DomainMetadata& other) const = 0;
+ // Returns the hash value of the metadata.
+ virtual size_t Hash() const = 0;
+
// Returns a string representation of the metadata.
virtual string ToString() const = 0;
};
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 974ab94467..43e74d2f6f 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -99,6 +99,8 @@ class OpNameMetadata : public DomainMetadata {
static absl::string_view KindName() { return "opname"; }
+ size_t Hash() const override { return std::hash<string>()(opname_); }
+
private:
string opname_;
};
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 441dcad000..ffb3451164 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -53,7 +53,6 @@ namespace xla {
namespace {
-
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
LiteralSlice lhs_literal,
@@ -345,7 +344,8 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.CloneToUnique());
@@ -358,7 +358,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
- dim_numbers);
+ dim_numbers, precision_config);
return Evaluate(cloned_instruction.get());
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index c2d49e56ac..e13af8e999 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -115,7 +115,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
HloOpcode opcode, const Literal& operand);
StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, const Literal& lhs,
const Literal& rhs);
protected:
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 7e490d7f32..f586f253da 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -622,6 +622,13 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
HloComputation::Builder b(TestName());
@@ -649,7 +656,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -694,7 +702,8 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -737,7 +746,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -788,9 +798,10 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
dnums.set_kernel_input_feature_dimension(1);
dnums.add_kernel_spatial_dimensions(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -842,9 +853,10 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -925,9 +937,10 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(1);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1002,9 +1015,10 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(1);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1061,9 +1075,10 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1124,9 +1139,10 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1195,9 +1211,10 @@ TEST_P(HloEvaluatorTest,
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1219,6 +1236,67 @@ TEST_P(HloEvaluatorTest,
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
+TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
+ HloComputation::Builder b(TestName());
+ std::vector<int64> input_dims = {1, 2, 2, 4};
+ std::vector<int64> filter_dims = {2, 2, 2, 8};
+ Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
+ Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);
+ // Tensorflow dimension numbers for 2D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(0);
+ dnums.set_output_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.set_input_feature_dimension(3);
+ dnums.set_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+ Window window;
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(1);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(1);
+ dim.set_base_dilation(1);
+ *window.add_dimensions() = dim;
+ *window.add_dimensions() = dim;
+
+ std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
+ std::iota(input_elems.begin(), input_elems.end(), -7);
+ auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
+ auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ HloInstruction* lhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
+
+ std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+ std::iota(filter_elems.begin(), filter_elems.end(), -31);
+ auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
+ auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ HloInstruction* rhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
+
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
+ b.AddInstruction(HloInstruction::CreateConvolve(
+ shape, lhs_instruction, rhs_instruction,
+ /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2)));
+ module().AddEntryComputation(b.Build());
+
+ std::unique_ptr<Literal> result = Evaluate();
+
+ Array4D<float> expected_array(1, 1, 1, 8);
+ expected_array.FillWithYX(
+ Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+}
+
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
// Tests that Reduce doesn't lose precision when adding many numbers (because
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index cb27e13e99..6a09bb08f4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1021,9 +1021,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_EQ(num_spatial_dims + 2, lhs_rank);
CHECK_EQ(num_spatial_dims + 2, rhs_rank);
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
- window, dnums));
+ TF_ASSIGN_OR_RETURN(
+ auto inferred_return_shape,
+ ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums));
CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
@@ -1046,9 +1047,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto lhs_literal_data = lhs_literal.data<ReturnT>();
auto rhs_literal_data = rhs_literal.data<ReturnT>();
+ int64 feature_group_count = conv->feature_group_count();
+
auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
&lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
- rhs_literal_data](absl::Span<const int64> out_index) {
+ rhs_literal_data,
+ feature_group_count](absl::Span<const int64> out_index) {
// Dimension number applicable for input (lhs).
const int64 input_batch_dim = dnums.input_batch_dimension();
const int64 input_z_dim = dnums.input_feature_dimension();
@@ -1060,6 +1064,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 output_z_dim = dnums.output_feature_dimension();
const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ const int64 output_z_size =
+ ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim);
ElementwiseT result_val = static_cast<ElementwiseT>(0);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
@@ -1068,6 +1074,33 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Convolve input feature with kernel.
do {
for (int64 iz = 0; iz < z_size; ++iz) {
+ int64 rhs_iz = iz;
+ // Handle grouped convolutions.
+ if (feature_group_count > 1) {
+ // The size of a feature group.
+ int64 feature_group_size = z_size / feature_group_count;
+ rhs_iz = iz % feature_group_size;
+
+ // The output feature dimension is a concatenation of convolution
+ // results from the different groups.
+ int64 output_feature_group_size =
+ output_z_size / feature_group_count;
+
+ // Calculate the group index to which the current input feature
+ // index belongs.
+ int64 input_group_index = iz / feature_group_size;
+
+ // Calculate the group index to which the current output index
+ // belongs.
+ int64 output_group_index =
+ out_index[output_z_dim] / output_feature_group_size;
+ if (input_group_index != output_group_index) {
+ // If the current output index does not belong to the current
+ // feature group, skip it.
+ continue;
+ }
+ }
+
int64 lhs_linear_index = 0;
lhs_linear_index += out_index[output_batch_dim] *
lhs_dim_multipliers[input_batch_dim];
@@ -1076,7 +1109,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
int64 rhs_linear_index = 0;
rhs_linear_index += out_index[output_z_dim] *
rhs_dim_multipliers[kernel_output_z_dim];
- rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
+ rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim];
// Find corresponding spatial dimension index for input (lhs).
for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
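
The grouped-convolution bookkeeping above is pure index arithmetic: the kernel is indexed by the input feature modulo the group size, and an (input feature, output feature) pair only contributes when both indices fall in the same group. A small standalone sketch of that mapping, using the sizes from the Conv2DGroupedConvolution test earlier in this diff (4 input features, 8 output features, 2 groups); it is illustrative only:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t z_size = 4;                 // input feature dimension
  const int64_t output_z_size = 8;          // output feature dimension
  const int64_t feature_group_count = 2;
  const int64_t feature_group_size = z_size / feature_group_count;                // 2
  const int64_t output_feature_group_size = output_z_size / feature_group_count;  // 4

  for (int64_t out_z = 0; out_z < output_z_size; ++out_z) {
    for (int64_t iz = 0; iz < z_size; ++iz) {
      const int64_t input_group_index = iz / feature_group_size;
      const int64_t output_group_index = out_z / output_feature_group_size;
      if (input_group_index != output_group_index) continue;  // different group: skip
      const int64_t rhs_iz = iz % feature_group_size;  // kernel input-feature index
      std::printf("output feature %d uses input feature %d via kernel index %d\n",
                  static_cast<int>(out_z), static_cast<int>(iz),
                  static_cast<int>(rhs_iz));
    }
  }
  return 0;
}
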
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 3041d94fa9..0345a2a5f8 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -120,12 +120,19 @@ class NodeFilter {
std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};
+// We arbitrarily set this as the boundary between "large" and "small"
+// instructions.
+bool IsSmall(const HloInstruction* instr) {
+ return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
+}
+
// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
kBlue,
kBrown,
kDarkBlue,
kDarkGreen,
+ kDarkOrange,
kDarkRed,
kGray,
kGreen,
@@ -158,6 +165,10 @@ NodeColors NodeColorsForScheme(ColorScheme color) {
return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
case kDarkGreen:
return NodeColors{"filled", "#2e7d32", "#005005", "white"};
+ case kDarkOrange:
+ // This is more of a "medium" orange, made to look close to kOrange;
+ // there's probably room for a darker weight if desired.
+ return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
case kDarkRed:
return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
case kGray:
@@ -893,7 +904,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
sharding_colors_.emplace(instr->sharding(), color);
return color;
}
- const auto kParameterColor = kOrange;
+
+ // Choose different weights of orange for small vs large parameters. This
+ // distinction is often important, especially in fusion nodes.
+ auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
// Special case: If this instruction has a parameter merged into it, paint it
// the same color as a parameter. Unless the merged-in parameter is a
@@ -905,7 +919,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
ShouldMergeIntoUsers(operand) &&
TryGetFusionParameterConstant(operand) == nullptr;
})) {
- return kParameterColor;
+ return parameter_color;
}
// Pick different colors or shapes for instructions which are particularly
@@ -1015,7 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kReducePrecision:
return kRed;
case HloOpcode::kParameter:
- return kParameterColor;
+ return parameter_color;
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
@@ -1160,20 +1174,6 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
return StrJoin(lines, "<br/>");
}
-// Gets the total number of array elements in the given shape. For tuples, this
-// is the sum of all the sizes of all of the array elements recursively in the
-// tuple.
-static int64 TotalElementsInShape(const Shape& shape) {
- int64 elems = 0;
- ShapeUtil::ForEachSubshape(
- shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) {
- if (ShapeUtil::IsArray(subshape)) {
- elems += ShapeUtil::ElementsIn(subshape);
- }
- });
- return elems;
-}
-
void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
int64 operand_num, bool control_edge = false) {
@@ -1196,14 +1196,11 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
}
// We print "small" arrays using a hollow arrowhead and "large" arrays using
- // a filled arrowhead. For now, we use an arbitrary cutoff for what "big"
- // means.
- bool is_big_array = TotalElementsInShape(from->shape()) >= 4096;
-
+ // a filled arrowhead.
constexpr char kEdgeFmt[] =
R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
- (is_big_array ? "normal" : "empty"),
+ (IsSmall(from) ? "empty" : "normal"),
from->name(), to->name(), edge_label));
};
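
The dumper now funnels both heuristics (hollow vs. filled edge arrowheads, orange vs. dark-orange parameters) through a single IsSmall() predicate, with ShapeUtil::ElementsInRecursive presumably doing what the deleted TotalElementsInShape did: summing the element counts of every array leaf in a possibly-tuple shape. A rough standalone model of that count and cutoff, with a hypothetical Shape type rather than the real ShapeUtil:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical shape: a tuple of sub-shapes, or an array described by dims.
struct Shape {
  std::vector<int64_t> dims;          // used when tuple_elements is empty
  std::vector<Shape> tuple_elements;  // non-empty for tuple shapes
};

int64_t ElementsInRecursive(const Shape& shape) {
  if (shape.tuple_elements.empty()) {
    return std::accumulate(shape.dims.begin(), shape.dims.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }
  int64_t total = 0;
  for (const Shape& element : shape.tuple_elements) {
    total += ElementsInRecursive(element);
  }
  return total;
}

// Same arbitrary 4096-element boundary between "small" and "large".
bool IsSmall(const Shape& shape) { return ElementsInRecursive(shape) < 4096; }
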
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 6d13f85cbb..f25761ac70 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -341,17 +341,21 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
source_target_pairs);
break;
}
- case HloOpcode::kConvolution:
+ case HloOpcode::kConvolution: {
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "Convolution instruction should have 2 operands but sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.has_window());
TF_RET_CHECK(proto.has_convolution_dimension_numbers());
+ PrecisionConfigProto precision_config = proto.precision_config();
+ precision_config.mutable_operand_precision()->Resize(
+ proto.operand_ids_size(), PrecisionConfigProto::DEFAULT);
instruction = CreateConvolve(
- proto.shape(), operands(0), operands(1), proto.window(),
- proto.convolution_dimension_numbers(),
- std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
+ proto.shape(), operands(0), operands(1),
+ std::max<int64>(proto.feature_group_count(), 1), proto.window(),
+ proto.convolution_dimension_numbers(), precision_config);
break;
+ }
case HloOpcode::kReduceWindow:
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "ReduceWindow instruction should have 2 operands but sees "
@@ -468,6 +472,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
computation_map.at(computation_id));
}
}
+ if (instruction->opcode() == HloOpcode::kDot) {
+ instruction->precision_config_ = proto.precision_config();
+ instruction->precision_config_.mutable_operand_precision()->Resize(
+ instruction->operand_count(), PrecisionConfigProto::DEFAULT);
+ TF_RET_CHECK(proto.has_dot_dimension_numbers());
+ instruction->dot_dimension_numbers_ =
+ absl::make_unique<DotDimensionNumbers>(
+ proto.dot_dimension_numbers());
+ } else {
+ TF_RET_CHECK(!proto.has_precision_config())
+ << instruction->opcode() << proto.DebugString();
+ TF_RET_CHECK(!proto.has_dot_dimension_numbers())
+ << instruction->opcode();
+ }
break;
}
}
@@ -476,12 +494,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
- instruction->precision_config_ = proto.precision_config();
-
- if (proto.has_dot_dimension_numbers()) {
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>(proto.dot_dimension_numbers());
- }
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -643,10 +655,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count, const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config) {
return absl::make_unique<HloConvolutionInstruction>(
- shape, lhs, rhs, window, dimension_numbers, feature_group_count);
+ shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+ precision_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
@@ -658,13 +672,15 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dimension_numbers) {
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
instruction->AppendOperand(lhs);
instruction->AppendOperand(rhs);
instruction->dot_dimension_numbers_ =
absl::make_unique<DotDimensionNumbers>(dimension_numbers);
+ instruction->set_precision_config(precision_config);
return instruction;
}
@@ -1057,7 +1073,6 @@ void HloInstruction::SetupDerivedInstruction(
derived_instruction->clear_sharding();
}
derived_instruction->set_metadata(metadata_);
- derived_instruction->set_precision_config(precision_config_);
}
bool HloInstruction::HasSideEffectNoRecurse() const {
@@ -1278,7 +1293,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kDot:
CHECK_EQ(new_operands.size(), 2);
clone = CreateDot(shape, new_operands[0], new_operands[1],
- *dot_dimension_numbers_);
+ *dot_dimension_numbers_, precision_config());
break;
case HloOpcode::kReshape:
CHECK_EQ(new_operands.size(), 1);
@@ -2167,7 +2182,9 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
proto.set_backend_config(backend_config_);
- *proto.mutable_precision_config() = precision_config_;
+ if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) {
+ *proto.mutable_precision_config() = precision_config_;
+ }
if (opcode() != HloOpcode::kFusion) {
for (const HloComputation* computation : called_computations_) {
proto.add_called_computation_ids(computation->unique_id());
@@ -2948,7 +2965,11 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
}
string HloInstruction::PrecisionConfigToString() const {
- if (precision_config_.operand_precision().empty()) {
+ if (absl::c_all_of(
+ precision_config_.operand_precision(), [](int32 precision) {
+ return static_cast<PrecisionConfigProto::Precision>(precision) ==
+ PrecisionConfigProto::DEFAULT;
+ })) {
return "";
}
return StrCat(
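
Since every dot and convolution now carries a fully populated precision config, the print-suppression test changes from "the operand_precision list is empty" to "every entry is DEFAULT". A standalone model of that predicate, using std::all_of over a plain vector of enum values (illustrative only, not the real proto types):

#include <algorithm>
#include <cstdint>
#include <vector>

enum Precision : int32_t { kDefault = 0, kHigh = 1, kHighest = 2 };

// Mirrors PrecisionConfigToString(): print nothing unless some operand
// deviates from the default precision.
bool ShouldPrintOperandPrecision(const std::vector<int32_t>& operand_precision) {
  return !std::all_of(operand_precision.begin(), operand_precision.end(),
                      [](int32_t p) { return p == kDefault; });
}
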
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index cca134e8b4..55d592ff94 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -405,9 +405,9 @@ class HloInstruction {
// and window describes how the filter is applied to lhs.
static std::unique_ptr<HloInstruction> CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
+ int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ const PrecisionConfigProto& precision_config);
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
@@ -418,7 +418,8 @@ class HloInstruction {
// dimensions specified in 'dimension_numbers'.
static std::unique_ptr<HloInstruction> CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dimension_numbers);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config);
// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
// of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 76b0e940a6..b4e302e832 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1122,6 +1122,13 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
}
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
// Fused expression:
//
@@ -1147,8 +1154,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1188,8 +1195,8 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(s, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ s, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1239,8 +1246,8 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2)));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
@@ -1320,8 +1327,8 @@ TEST_F(HloInstructionTest, Stringification) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto options = HloPrintOptions().set_print_metadata(false);
@@ -1485,8 +1492,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto options = HloPrintOptions().Canonical();
@@ -1527,8 +1534,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1583,8 +1590,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e46afa764f..e3683aaec9 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1628,12 +1628,13 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
HloConvolutionInstruction::HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count)
+ int64 feature_group_count, const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config)
: HloInstruction(HloOpcode::kConvolution, shape),
+ feature_group_count_(feature_group_count),
window_(window),
- convolution_dimension_numbers_(dimension_numbers),
- feature_group_count_(feature_group_count) {
+ convolution_dimension_numbers_(dimension_numbers) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1642,6 +1643,7 @@ HloConvolutionInstruction::HloConvolutionInstruction(
}
AppendOperand(lhs);
AppendOperand(rhs);
+ set_precision_config(precision_config);
}
string HloConvolutionInstruction::ToCategory() const {
@@ -1672,7 +1674,9 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
}
extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
convolution_dimension_numbers_)));
- extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ if (feature_group_count_ != 1) {
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ }
return extra;
}
@@ -1697,8 +1701,8 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloConvolutionInstruction>(
- shape, new_operands[0], new_operands[1], window(),
- convolution_dimension_numbers_, feature_group_count_);
+ shape, new_operands[0], new_operands[1], feature_group_count_, window(),
+ convolution_dimension_numbers_, precision_config());
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 3230383579..1c85aa4681 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -942,9 +942,9 @@ class HloConvolutionInstruction : public HloInstruction {
public:
explicit HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
+ int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
+ const PrecisionConfigProto& precision_config);
const Window& window() const override { return window_; }
void set_window(const Window& window) override { window_ = window; }
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -972,12 +972,13 @@ class HloConvolutionInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- Window window_;
- // Describes the dimension numbers used for a convolution.
- ConvolutionDimensionNumbers convolution_dimension_numbers_;
// The number of feature groups. Must be a divisor of the input feature
// dimension and output feature dimension.
int64 feature_group_count_;
+ // Describes the window used for a convolution.
+ Window window_;
+ // Describes the dimension numbers used for a convolution.
+ ConvolutionDimensionNumbers convolution_dimension_numbers_;
};
class HloReduceWindowInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ea8e6a239a..62f01c4adb 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -530,10 +530,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
&backend_config};
- optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
- attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
- &operand_precision};
-
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -913,6 +909,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
+ optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
@@ -923,9 +922,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!feature_group_count) {
feature_group_count = 1;
}
+ PrecisionConfigProto precision_config;
+ if (operand_precision) {
+ *precision_config.mutable_operand_precision() = {
+ operand_precision->begin(), operand_precision->end()};
+ } else {
+ precision_config.mutable_operand_precision()->Resize(
+ operands.size(), PrecisionConfigProto::DEFAULT);
+ }
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
- shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums,
- feature_group_count.value()));
+ shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
+ feature_group_count.value(), *window, *dnums, precision_config));
break;
}
case HloOpcode::kFft: {
@@ -1272,6 +1279,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<std::vector<tensorflow::int64>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
+ optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
@@ -1296,8 +1306,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
rhs_batch_dims->end()};
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateDot(shape, operands[0], operands[1], dnum));
+ PrecisionConfigProto precision_config;
+ if (operand_precision) {
+ *precision_config.mutable_operand_precision() = {
+ operand_precision->begin(), operand_precision->end()};
+ } else {
+ precision_config.mutable_operand_precision()->Resize(
+ operands.size(), PrecisionConfigProto::DEFAULT);
+ }
+
+ instruction = builder->AddInstruction(HloInstruction::CreateDot(
+ shape, operands[0], operands[1], dnum, precision_config));
break;
}
case HloOpcode::kGather: {
@@ -1414,12 +1433,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
}
- if (operand_precision) {
- PrecisionConfigProto precision_config;
- *precision_config.mutable_operand_precision() = {operand_precision->begin(),
- operand_precision->end()};
- instruction->set_precision_config(precision_config);
- }
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
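
On the parsing side, the convolution and dot cases now handle operand_precision locally: if the attribute appears in the HLO text it is copied into the proto, otherwise one DEFAULT entry per operand is synthesized. A tiny standalone model of that fallback (std::optional standing in for the parser's optional attribute; illustrative only):

#include <cstddef>
#include <optional>
#include <vector>

enum Precision { kDefault = 0, kHigh = 1, kHighest = 2 };

std::vector<Precision> OperandPrecisionOrDefault(
    const std::optional<std::vector<Precision>>& parsed, size_t num_operands) {
  if (parsed.has_value()) {
    return *parsed;  // explicit operand_precision={...} from the HLO text
  }
  return std::vector<Precision>(num_operands, kDefault);  // implicit default
}
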
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 759789437c..0dfc0a4d1c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -382,7 +384,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1, operand_precision={high,default}
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
}
)"
@@ -395,7 +397,7 @@ R"(HloModule ConvolveR2_module
ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
%input = f32[1,2]{1,0} parameter(0)
%filter = f32[1,1]{1,0} parameter(1)
- ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1
+ ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
}
)"
@@ -408,7 +410,7 @@ R"(HloModule ConvolveBackward_module
ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
%input = f32[128,7,7,512]{0,3,2,1} parameter(0)
%filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
- ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1
+ ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
}
)"
@@ -1775,5 +1777,18 @@ TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) {
::testing::HasSubstr("Operand broadcast had no shape in HLO text"));
}
+TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
+ const string text =
+ R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Convolution(op::Parameter(0), op::Parameter(1)));
+ auto* convolution =
+ Cast<HloConvolutionInstruction>(computation->root_instruction());
+ EXPECT_EQ(convolution->feature_group_count(), 1);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 34cba6136f..e3f4a9852a 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -422,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const {
: false;
}
+size_t ShardingMetadata::Hash() const {
+ if (sharding_ != nullptr) {
+ return sharding_->Hash();
+ }
+ return static_cast<size_t>(0x297814aaad196e6dULL);
+}
+
string ShardingMetadata::ToString() const {
return sharding_ != nullptr ? sharding_->ToString() : "{}";
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
index cba5db927a..e3ae82a070 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
@@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata {
bool Matches(const DomainMetadata& other) const override;
+ size_t Hash() const override;
+
string ToString() const override;
const HloSharding* sharding() const { return sharding_.get(); }
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 95516dec74..069586a738 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -86,8 +86,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
const Shape expected,
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
- convolution->window(), convolution->convolution_dimension_numbers(),
- convolution->feature_group_count()));
+ convolution->feature_group_count(), convolution->window(),
+ convolution->convolution_dimension_numbers()));
return CheckShape(convolution, expected);
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index a4de02a890..4a71ee909b 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -165,6 +165,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
TF_ASSIGN_OR_RETURN(
computed_array,
ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
+ instr->precision_config(),
FindOrDie(cache_, instr->operand(0)),
FindOrDie(cache_, instr->operand(1))));
} else {
@@ -1030,6 +1031,7 @@ bool CanFoldDotIntoIndexedArray(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config,
ScalarIndexedConstantArray* lhs, ConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
<< ToString(rhs);
@@ -1045,9 +1047,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
new_dim_numbers.set_lhs_contracting_dimensions(
0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
- TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
- TakeOwnership(HloEvaluator{}.EvaluateDotOp(
- new_dim_numbers, lhs->literal(), *rhs->literal())));
+ TF_ASSIGN_OR_RETURN(
+ Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));
// The new source dimension is wherever the non-batch non-contracting LHS
// dimension "went".
@@ -1063,7 +1066,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ConstantArray* lhs, ScalarIndexedConstantArray* rhs) {
+ const PrecisionConfigProto& precision_config, ConstantArray* lhs,
+ ScalarIndexedConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
<< ToString(rhs);
if (!CanFoldDotIntoIndexedArray(
@@ -1079,9 +1083,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
new_dim_numbers.set_rhs_contracting_dimensions(
0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));
- TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
- TakeOwnership(HloEvaluator{}.EvaluateDotOp(
- new_dim_numbers, *lhs->literal(), rhs->literal())));
+ TF_ASSIGN_OR_RETURN(
+ Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));
// The new source dimension is wherever the non-batch non-contracting RHS
// dimension "went".
@@ -1095,8 +1100,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
}
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
- const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs,
- Array* rhs) {
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs) {
// Intuitively, if
//
// - The LHS of a dot product is a gathered sequence of rows from a constant
@@ -1119,6 +1124,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
+ precision_config,
lhs_indexed_array, rhs_constant);
}
}
@@ -1126,7 +1132,8 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
if (auto* rhs_indexed_array =
dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
- return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant,
+ return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
+ precision_config, lhs_constant,
rhs_indexed_array);
}
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index dcfb725535..f21e784a4d 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -267,15 +267,17 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config,
ScalarIndexedConstantArray* lhs, ConstantArray* rhs);
StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ConstantArray* lhs, ScalarIndexedConstantArray* rhs);
+ const PrecisionConfigProto& precision_config, ConstantArray* lhs,
+ ScalarIndexedConstantArray* rhs);
- StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
- const DotDimensionNumbers& dim_numbers,
- Array* lhs, Array* rhs);
+ StatusOr<Array*> ComputeArrayForDot(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs);
// This tries to fold a ScalarIndexedArray which has another
// ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 021fe630ff..69c7e42601 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -874,18 +874,18 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
)";
auto module = ParseHloString(module_str).ValueOrDie();
- module =
+ auto compiled_module =
backend()
.compiler()
->RunHloPasses(std::move(module), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
-
- auto copy = FindInstruction(module.get(), "copy.1");
- auto slice = FindInstruction(module.get(), "slice0");
- EXPECT_EQ(slice->operand(0), copy);
- EXPECT_TRUE(
- LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout()));
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
+ EXPECT_THAT(root, op::Add(op::Parameter(),
+ op::Slice(AllOf(op::Copy(op::Parameter(1)),
+ op::ShapeWithLayout(shape_copy)))));
}
TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
@@ -902,18 +902,20 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
)";
auto module = ParseHloString(module_str).ValueOrDie();
- module =
+ auto compiled_module =
backend()
.compiler()
->RunHloPasses(std::move(module), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
-
- auto copy = FindInstruction(module.get(), "copy.1");
- auto dslice = FindInstruction(module.get(), "dslice0");
- EXPECT_EQ(dslice->operand(0), copy);
- EXPECT_TRUE(
- LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout()));
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
+ EXPECT_THAT(root,
+ op::Add(op::Parameter(),
+ op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)),
+ op::ShapeWithLayout(shape_copy)),
+ op::Parameter(2))));
}
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
@@ -931,18 +933,20 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
)";
auto module = ParseHloString(module_str).ValueOrDie();
- module =
+ auto compiled_module =
backend()
.compiler()
->RunHloPasses(std::move(module), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
-
- auto copy = FindInstruction(module.get(), "copy.1");
- auto concat = FindInstruction(module.get(), "concat0");
- EXPECT_EQ(concat->operand(0), copy);
- EXPECT_TRUE(
- LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout()));
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
+ EXPECT_THAT(root,
+ op::Add(op::Parameter(),
+ op::Concatenate(AllOf(op::Copy(op::Parameter(1)),
+ op::ShapeWithLayout(shape_copy)),
+ op::Parameter(2))));
}
TEST_F(LayoutAssignmentTest,
@@ -960,15 +964,39 @@ TEST_F(LayoutAssignmentTest,
)";
auto module = ParseHloString(module_str).ValueOrDie();
- module =
+ auto compiled_module =
backend()
.compiler()
->RunHloPasses(std::move(module), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
+ const char* module_str = R"(
+ HloModule PropagatingLayoutFromResultToOperand
+
+ ENTRY PropagatingLayoutFromResultToOperand {
+ par0 = f32[4,5]{1,0} parameter(0)
+ ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
+ }
+ )";
- auto copy = FindInstruction(module.get(), "copy.1");
- EXPECT_EQ(copy, nullptr);
+ auto module = ParseHloString(module_str).ValueOrDie();
+ auto compiled_module =
+ backend()
+ .compiler()
+ ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
+ EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)),
+ op::ShapeWithLayout(shape_copy))));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 2611749862..74bdf2a2e3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1552,8 +1552,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
- const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) {
+ const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dnums) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
@@ -1672,6 +1672,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
dnums.DebugString());
}
+ if (kernel_output_features % feature_group_count > 0) {
+ return InvalidArgument(
+ "Expected output feature dimension (value %d) to be divisible by "
+ "feature_group_count (value %d); "
+ "got <conv>(%s, %s)\n"
+ "Dimension numbers: {%s}.",
+ kernel_output_features, feature_group_count,
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
+ dnums.DebugString());
+ }
std::vector<int64> window_dims(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
window_dims[i] = window.dimensions(i).size();
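
For illustration, a rough Python sketch of the constraint that the new check enforces (the helper name below is hypothetical and not part of the patch): the kernel's output feature dimension must split evenly across the feature_group_count groups.

def validate_output_features(kernel_output_features, feature_group_count):
  # e.g. 64 output features with feature_group_count=4 passes (64 % 4 == 0),
  # while 6 output features with feature_group_count=4 is rejected.
  if kernel_output_features % feature_group_count != 0:
    raise ValueError(
        'Expected output feature dimension (%d) to be divisible by '
        'feature_group_count (%d)' %
        (kernel_output_features, feature_group_count))
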
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index a28345acef..96a0ee165d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -108,9 +108,9 @@ class ShapeInference {
// Infers the shape produced by applying the given convolutional
// filter (rhs) to lhs in the way specified by the fields on window.
static StatusOr<Shape> InferConvolveShape(
- const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers);
// Infers the shape produced by the given FFT type on the given operand.
static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index cc92e58ef8..864ed43118 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -419,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) {
dim1->set_padding_high(0);
dim1->set_window_dilation(1);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
@@ -464,8 +464,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(2);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
@@ -509,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(1);
dim1->set_base_dilation(2);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
@@ -547,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
dim1->set_stride(2);
dim1->set_padding_low(1);
dim1->set_padding_high(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("each dimension exactly once"));
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 530f40e4b2..7c1f4b5cc6 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -108,8 +108,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
}
std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
- dot->shape(), new_lhs, new_rhs, new_dim_numbers);
- new_dot->set_precision_config(dot->precision_config());
+ dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config());
return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}
@@ -178,8 +177,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
}
auto new_conv = HloInstruction::CreateConvolve(
- convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
- new_conv->set_precision_config(convolution.precision_config());
+ convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(),
+ convolution.window(), new_dnums, convolution.precision_config());
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 58f767e913..e486a00e53 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -215,6 +215,13 @@ ENTRY entry_computation {
/*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
// Test that a two dimension swap of the kernel gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
auto builder = HloComputation::Builder("entry_computation");
@@ -240,10 +247,12 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- x->shape(), transpose_y->shape(), window, dnums);
+ x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+ conv_shape.ValueOrDie(), x, transpose_y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -293,10 +302,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- x->shape(), transpose_y->shape(), window, dnums);
+ x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+ conv_shape.ValueOrDie(), x, transpose_y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -351,10 +362,12 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- transpose_x->shape(), y->shape(), window, dnums);
+ transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+ conv_shape.ValueOrDie(), transpose_x, y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -415,10 +428,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- transpose_x->shape(), y->shape(), window, dnums);
+ transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+ conv_shape.ValueOrDie(), transpose_x, y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index a32d1f9026..e3328203a6 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1064,8 +1064,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfigProto::DEFAULT);
auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 05f90ba9fb..53b5e933b6 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -47,6 +47,12 @@ limitations under the License.
namespace xla {
namespace {
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
class MultiOutputFusionTest : public HloTestBase {
protected:
@@ -90,8 +96,8 @@ class MultiOutputFusionTest : public HloTestBase {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
@@ -154,7 +160,7 @@ class MultiOutputFusionTest : public HloTestBase {
dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
- dot_dnums));
+ dot_dnums, DefaultPrecisionConfig(2)));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 997880a018..a1001296a1 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -613,7 +613,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2], param.base_bounds[3]);
- input.FillIota(1);
+ input.FillRandom(0.1f, 0.1f);
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
@@ -629,7 +629,14 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
auto init_value =
CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
CHECK(param.reducer == kAdd || param.reducer == kMax);
- auto computation = param.reducer == kAdd
+ auto reducer = param.reducer;
+ if (use_bfloat16() && Product(param.window_bounds) > 128) {
+ // To avoid numerical issues, force the reducer to be kMax for large bf16
+ // windows.
+ reducer = kMax;
+ }
+
+ auto computation = reducer == kAdd
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
ReduceWindowWithGeneralPadding(
@@ -640,8 +647,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*window_strides=*/param.strides,
/*padding=*/padding);
- CHECK(param.reducer == kAdd || param.reducer == kMax);
- auto reduce_func = param.reducer == kAdd
+ CHECK(reducer == kAdd || reducer == kMax);
+ auto reduce_func = reducer == kAdd
? +[](float a, float b) { return a + b; }
: +[](float a, float b) { return std::max(a, b); };
std::unique_ptr<Array4D<float>> expected =
@@ -809,6 +816,22 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*pad_high=*/{1, 0, 0, 0},
/*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
+
+ R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3},
+ /*window_bounds=*/{1, 64, 64, 1},
+ /*strides=*/{1, 64, 64, 1},
+ /*pad_low=*/{0, 0, 0, 0},
+ /*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 0, 2, 1},
+ /*reducer=*/kAdd},
+
+ R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64},
+ /*window_bounds=*/{112, 112, 1, 8},
+ /*strides=*/{112, 112, 1, 8},
+ /*pad_low=*/{0, 0, 0, 0},
+ /*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
+ /*reducer=*/kAdd},
};
INSTANTIATE_TEST_CASE_P(
@@ -930,6 +953,27 @@ struct R3ReduceWindowTestData {
{/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
/*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
/*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
+ /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
};
string R3ReduceWindowTestDataToString(
@@ -956,35 +1000,42 @@ class R3ReduceWindowTest : public ReduceWindowTestBase,
R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
};
-TEST_P(R3ReduceWindowTest, Add) {
+TEST_P(R3ReduceWindowTest, DoIt) {
XlaBuilder b(TestName());
const auto& param = ::testing::get<0>(GetParam());
- CHECK(param.reducer == kAdd);
const float kInitValue = 0.0f;
Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
- param.base_bounds[2], 1.0f);
+ param.base_bounds[2]);
+ input.FillRandom(0.1f, 0.1f);
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR3FromArray3DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
+ auto reducer = param.reducer;
+ if (use_bfloat16()) {
+ input_literal = LiteralUtil::ConvertF32ToBF16(*input_literal);
+ if (Product(param.window_bounds) > 128) {
+ // To avoid numerical issues, force the reducer to be kMax for large bf16
+ // windows.
+ reducer = kMax;
+ }
+ }
- XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
- &b, &parameter);
+ XlaOp parameter = Parameter(&b, 0, input_literal->shape(), "input");
auto init_value =
CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+
+ auto computation = reducer == kAdd
+ ? CreateScalarAddComputation(FloatType(), &b)
+ : CreateScalarMaxComputation(FloatType(), &b);
+
ReduceWindow(/*operand=*/parameter,
/*init_value=*/init_value,
- /*computation=*/CreateScalarAddComputation(FloatType(), &b),
+ /*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
/*window_strides=*/param.strides, /*padding=*/param.padding);
- auto expected = ReferenceUtil::ReduceWindow3DAdd(
- /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
- /*stride=*/param.strides, /*padding=*/param.padding);
-
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
- {input_arg.get()}, DefaultErrorSpec());
+ ComputeAndCompare(&b, {std::move(*input_literal)}, DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(
@@ -1093,7 +1144,6 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
void DoIt() {
XlaBuilder b(TestName());
const auto& param = ::testing::get<0>(GetParam());
- CHECK(param.reducer == kAdd);
const float kInitValue = 0.0f;
Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
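
A rough back-of-the-envelope for the reducer switch above: bfloat16 carries only 8 significand bits, so each accumulation rounds at about 2^-8 (~0.4%) relative error, and summing N window elements can accumulate error on the order of N * 2^-8. For N > 128 the worst case already exceeds 50%, well outside the tests' default error spec, which is presumably why large bf16 windows fall back to kMax.
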
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 66983801bf..798f499870 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -20,13 +20,7 @@ py_library(
),
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
- deps = if_not_windows([
- # TODO(aaroey): tensorrt dependency has to appear before tflite so the
- # build can resolve its flatbuffers symbols within the tensorrt library.
- # This is an issue with the tensorrt static library and will be fixed by
- # the next tensorrt release, so fix the order here after that.
- "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows
- ]) + [
+ deps = [
"//tensorflow/contrib/all_reduce",
"//tensorflow/contrib/batching:batch_py",
"//tensorflow/contrib/bayesflow:bayesflow_py",
@@ -135,6 +129,7 @@ py_library(
]) + if_not_windows([
"//tensorflow/contrib/bigtable", # depends on bigtable
"//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows
+ "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
]),
)
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py
index b26c52294c..29dce13999 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
@@ -31,41 +33,32 @@ class BuiltinFunctionTransformer(converter.Base):
TF equivalent, like `len`.
"""
- def _convert_builtin(self, node):
+ def _convert_builtin(self, f, args, as_expression):
template = """
- ag__.utils.dynamic_builtin(func, args)
+ ag__.func(args)
"""
- return templates.replace(template, func=node.func, args=node.args)[0].value
-
- def _convert_print(self, node):
- template = """
- ag__.utils.dynamic_print(args)
- """
- return templates.replace(template, args=node.args)[0].value
+ if as_expression:
+ return templates.replace_as_expression(
+ template, func=py_builtins.overload_of(f).__name__, args=args)
+ else:
+ return templates.replace(
+ template, func=py_builtins.overload_of(f).__name__, args=args)
def visit_Call(self, node):
- self.generic_visit(node)
- # TODO(mdan): This won't work if the function was hidden.
- # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead.
- if (isinstance(node.func, gast.Name) and
- node.func.id in ('len', 'range', 'xrange', 'float', 'int')):
- return self._convert_builtin(node)
- # Print needs to be handled separately because it can be read as statement.
- if isinstance(node.func, gast.Name) and node.func.id == 'print':
- return self._convert_print(node)
+ node = self.generic_visit(node)
+ if anno.hasanno(node.func, 'live_val'):
+ live_val = anno.getanno(node.func, 'live_val')
+ if live_val in py_builtins.SUPPORTED_BUILTINS:
+ node = self._convert_builtin(live_val, node.args, as_expression=True)
return node
def visit_Print(self, node):
- self.generic_visit(node)
+ node = self.generic_visit(node)
args = node.values
# Following is the case when calling print(a, b)
if len(args) == 1 and isinstance(args[0], gast.Tuple):
args = args[0].elts
- template = """
- fname(args)
- """
- function_call = templates.replace(template, fname='print', args=args)[0]
- return self.visit(function_call)
+ return self._convert_builtin(print, args, as_expression=False)
def transform(node, ctx):
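
For context, a minimal sketch (not part of the patch) of what the rewritten transformer emits: a call whose live value resolves to a supported builtin is redirected to the matching overload in the operators module, so Python-vs-TensorFlow dispatch happens at runtime rather than by name matching in the converter.

from tensorflow.contrib.autograph.operators import py_builtins

def f(a):
  return len(a)

# After conversion the call reads roughly as follows; the overload falls back
# to plain len() when `a` is not a Tensor:
def converted_f(a):
  return py_builtins.len_(a)  # tf.shape(a)[0] (or the static dim) for Tensors
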
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
index d0a0cbbeb6..3e3a04f38b 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
@@ -23,6 +23,7 @@ import six
from tensorflow.contrib.autograph.converters import builtin_functions
from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -34,11 +35,11 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
def test_fn(a):
return len(a)
- with self.converted(test_fn, builtin_functions, {'len': len},
- array_ops.shape) as result:
+ with self.converted(test_fn, builtin_functions, {'len': len}) as result:
with self.cached_session() as sess:
- ops = result.test_fn(constant_op.constant([0, 0, 0]))
- self.assertEqual(sess.run(ops), 3)
+ p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+ ops = result.test_fn(p)
+ self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
def test_print(self):
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index 276a387180..8b38d5d080 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -29,9 +29,9 @@ import six
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.impl import conversion
+from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.utils import builtins
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
@@ -150,7 +150,7 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
unknown_arg_value = object() # Sentinel for arguments of unknown value
if inspect_utils.isbuiltin(f):
- return builtins.dynamic_builtin(f, *args, **kwargs)
+ return py_builtins.overload_of(f)(*args, **kwargs)
if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
# Regular functions
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 332d5dab19..29759bad79 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -22,6 +22,7 @@ py_library(
"__init__.py",
"control_flow.py",
"data_structures.py",
+ "py_builtins.py",
"slices.py",
],
srcs_version = "PY2AND3",
@@ -62,6 +63,16 @@ py_test(
)
py_test(
+ name = "py_builtins_test",
+ srcs = ["py_builtins_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "slices_test",
srcs = ["slices_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index 392cb60bcc..c4fbc260a2 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -45,6 +45,11 @@ from tensorflow.contrib.autograph.operators.data_structures import list_stack
from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
from tensorflow.contrib.autograph.operators.data_structures import new_list
+from tensorflow.contrib.autograph.operators.py_builtins import float_
+from tensorflow.contrib.autograph.operators.py_builtins import int_
+from tensorflow.contrib.autograph.operators.py_builtins import len_
+from tensorflow.contrib.autograph.operators.py_builtins import print_
+from tensorflow.contrib.autograph.operators.py_builtins import range_
from tensorflow.contrib.autograph.operators.slices import get_item
from tensorflow.contrib.autograph.operators.slices import GetItemOpts
from tensorflow.contrib.autograph.operators.slices import set_item
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index 9909e52164..9a66a6bb60 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import builtins
+from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@@ -82,8 +82,8 @@ def _py_for_stmt(iter_, extra_test, body, init_state):
def _known_len_for_stmt(iter_, extra_test, body, init_state):
- """Overload of for_stmt that iterates over objects that define a length."""
- n = builtins.dynamic_len(iter_)
+ """Overload of for_stmt that iterates over objects that admit a length."""
+ n = py_builtins.len_(iter_)
def while_body(iterate_index, *state):
iterate = iter_[iterate_index]
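
As a pure-Python analogue (the function below is illustrative only; the real overload builds a tf.while_loop and threads the loop state explicitly), the known-length for_stmt now sizes the loop with py_builtins.len_, which also understands TensorArrays and tensor lists:

from tensorflow.contrib.autograph.operators import py_builtins

def known_len_for_stmt_sketch(iter_, body, init_state):
  """Plain-Python stand-in for _known_len_for_stmt, for illustration."""
  n = py_builtins.len_(iter_)  # len(), tf.shape()[0], TensorArray.size(), ...
  state = init_state
  for i in range(n):
    state = body(iter_[i], state)
  return state
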
diff --git a/tensorflow/contrib/autograph/operators/py_builtins.py b/tensorflow/contrib/autograph/operators/py_builtins.py
new file mode 100644
index 0000000000..c5730934e7
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/py_builtins.py
@@ -0,0 +1,225 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators corresponding to Python builtin functions.
+
+List of built-in functions: https://docs.python.org/3/library/functions.html
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import gen_string_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import math_ops
+
+
+UNDEFINED = object()
+
+
+def overload_of(f):
+ if f in SUPPORTED_BUILTINS:
+ return BUILTIN_FUNCTIONS_MAP[f.__name__]
+ return f
+
+
+def abs_(x):
+ if tensor_util.is_tensor(x):
+ return _tf_abs(x)
+ return _py_abs(x)
+
+
+def _tf_abs(x):
+ return math_ops.abs(x)
+
+
+def _py_abs(x):
+ return abs(x)
+
+
+def float_(x=0):
+ if tensor_util.is_tensor(x):
+ return _tf_float(x)
+ return _py_float(x)
+
+
+def _tf_float(x):
+ # TODO(mdan): We shouldn't assume float32.
+ if x.dtype == dtypes.string:
+ return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32)
+ return math_ops.cast(x, dtype=dtypes.float32)
+
+
+def _py_float(x):
+ return float(x)
+
+
+def int_(x=0, base=UNDEFINED):
+ if tensor_util.is_tensor(x):
+ return _tf_int(x, base)
+ return _py_int(x, base)
+
+
+def _tf_int(x, base):
+ if base not in (10, UNDEFINED):
+ raise NotImplementedError('base {} not supported for int'.format(base))
+
+ # TODO(mdan): We shouldn't assume int32.
+ if x.dtype == dtypes.string:
+ return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32)
+ return math_ops.cast(x, dtype=dtypes.int32)
+
+
+def _py_int(x, base):
+ if base is UNDEFINED:
+ return int(x)
+ return int(x, base)
+
+
+def len_(s):
+ if tensors.is_tensor_array(s):
+ return _tf_tensor_array_len(s)
+ elif tensors.is_tensor_list(s):
+ return _tf_tensor_list_len(s)
+ elif tensor_util.is_tensor(s):
+ return _tf_tensor_len(s)
+ return _py_len(s)
+
+
+def _tf_tensor_array_len(s):
+ return s.size()
+
+
+def _tf_tensor_list_len(s):
+ return list_ops.tensor_list_length(s)
+
+
+def _tf_tensor_len(s):
+ """Overload of len_ for Tensor arguments."""
+ # Statically shaped tensors: length is known ahead of time.
+ if s.shape.ndims and s.shape[0].value is not None:
+ return s.shape[0].value
+
+ # Static shape of unknown dimensions: use dynamic shape but statically
+ # check that it's not a scalar.
+ shape = array_ops.shape(s)
+
+ assert shape.shape, 'shape tensor of zero size? {}'.format(shape)
+
+ if shape.shape[0] == 0:
+ raise ValueError(
+ 'len requires a non-scalar tensor, got one of shape {}'.format(shape))
+
+ if shape.shape[0].value is not None:
+ return array_ops.shape(s)[0]
+
+ # Fully dynamic shape: use ops.
+ rank = array_ops.rank(s)
+
+ def raise_zero_rank_error():
+ msg = gen_string_ops.string_join(
+ ['len requires non-zero rank, got ',
+ gen_string_ops.as_string(rank)])
+ with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
+ return constant_op.constant(0, dtype=dtypes.int32)
+
+ return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0],
+ raise_zero_rank_error)
+
+
+def _py_len(s):
+ return len(s)
+
+
+def print_(*objects, **kwargs):
+ # Note: Python 2.6 doesn't support explicit keywords after starargs.
+ unknown_kwargs = tuple(
+ set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush')))
+ if unknown_kwargs:
+ raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))
+
+ # TODO(mdan): use logging_ops.Print when py_func is not supported.
+ return _tf_py_func_print(objects, kwargs)
+
+
+def _tf_py_func_print(objects, kwargs):
+ """Overload of print_ as a py_func implementation."""
+ override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
+ if 'flush' not in override_kwargs:
+ # Defaulting to flushing the console in graph mode, which helps reduce
+ # garbled output in IPython.
+ override_kwargs['flush'] = True
+
+ def print_wrapper(*vals):
+ if six.PY3:
+ # TensorFlow doesn't seem to generate Unicode when passing strings to
+ # py_func. This causes the print to add a "b'" wrapper to the output,
+ # which is probably never what you want.
+ vals = tuple(
+ v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
+ six.print_(*vals, **override_kwargs)
+
+ return py_func.wrap_py_func(
+ print_wrapper, None, objects, use_dummy_return=True)
+
+
+def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
+ if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)):
+ return _tf_range(start_or_stop, stop, step)
+ return _py_range(start_or_stop, stop, step)
+
+
+def _tf_range(start_or_stop, stop, step):
+ # TODO(mdan): We should optimize this when a full tensor is not required.
+ if step is not UNDEFINED:
+ return math_ops.range(start_or_stop, stop, step)
+ if stop is not UNDEFINED:
+ return math_ops.range(start_or_stop, stop)
+ return math_ops.range(start_or_stop)
+
+
+def _py_range(start_or_stop, stop, step):
+ if step is not UNDEFINED:
+ return range(start_or_stop, stop, step)
+ if stop is not UNDEFINED:
+ return range(start_or_stop, stop)
+ return range(start_or_stop)
+
+
+SUPPORTED_BUILTINS = set((abs, float, int, len, print, range))
+
+if six.PY2:
+ SUPPORTED_BUILTINS.add(xrange)
+
+BUILTIN_FUNCTIONS_MAP = {
+ 'abs': abs_,
+ 'float': float_,
+ 'int': int_,
+ 'len': len_,
+ 'print': print_,
+ 'range': range_,
+ 'xrange': range_,
+}
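
A short usage sketch of the dispatch table above; it mirrors the tests that follow, so nothing beyond the module itself is assumed:

from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.python.framework import constant_op

assert py_builtins.overload_of(len) is py_builtins.len_  # supported builtin
assert py_builtins.overload_of(sorted) is sorted         # unsupported, returned as-is

# Statically shaped tensors resolve to a plain Python int, no session needed:
assert py_builtins.len_([1, 2, 3]) == 3
assert py_builtins.len_(constant_op.constant([[1], [2], [3]])) == 3
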
diff --git a/tensorflow/contrib/autograph/operators/py_builtins_test.py b/tensorflow/contrib/autograph/operators/py_builtins_test.py
new file mode 100644
index 0000000000..4073c51785
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/py_builtins_test.py
@@ -0,0 +1,131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for py_builtins module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import six
+
+from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class PyBuiltinsTest(test.TestCase):
+
+ def test_abs(self):
+ self.assertEqual(py_builtins.abs_(-1), 1)
+ with self.test_session() as sess:
+ t = py_builtins.abs_(constant_op.constant(-1))
+ self.assertEqual(sess.run(t), 1)
+ t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
+ self.assertAllEqual(sess.run(t), [1, 2, 3])
+
+ def test_float(self):
+ self.assertEqual(py_builtins.float_(10), 10.0)
+ self.assertEqual(py_builtins.float_('10.0'), 10.0)
+ with self.test_session() as sess:
+ t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
+ self.assertEqual(sess.run(t), 1.0)
+ st = py_builtins.float_(constant_op.constant('1.0'))
+ self.assertEqual(sess.run(st), 1.0)
+
+ def test_int(self):
+ self.assertEqual(py_builtins.int_(10.0), 10)
+ self.assertEqual(py_builtins.int_('11', 2), 3)
+ with self.test_session() as sess:
+ t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
+ self.assertEqual(sess.run(t), 1)
+ st = py_builtins.int_(constant_op.constant('1'))
+ self.assertEqual(sess.run(st), 1)
+ st = py_builtins.int_(constant_op.constant('1'), 10)
+ self.assertEqual(sess.run(st), 1)
+
+ def test_int_unsupported_base(self):
+ t = constant_op.constant(1, dtype=dtypes.float64)
+ with self.assertRaises(NotImplementedError):
+ py_builtins.int_(t, 2)
+
+ def test_len(self):
+ self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
+ with self.test_session() as sess:
+ t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
+ self.assertEqual(t, 3)
+ ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
+ self.assertEqual(sess.run(ta), 5)
+ tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
+ self.assertEqual(sess.run(tl), 3)
+
+ def test_len_scalar(self):
+ with self.assertRaises(ValueError):
+ py_builtins.len_(constant_op.constant(1))
+
+ def test_len_dynamic_shape(self):
+ with self.test_session() as sess:
+ p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+ t = py_builtins.len_(p)
+ self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
+
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ t = py_builtins.len_(p)
+ sess.run(t, {p: 1})
+
+ def test_print_tensors(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
+ self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_print_complex(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(
+ py_builtins.print_(constant_op.constant('test message'), [1, 2]))
+ self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_range(self):
+ self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2])
+ self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2])
+ self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
+
+ def test_range_tensor(self):
+ with self.test_session() as sess:
+ r = py_builtins.range_(constant_op.constant(3))
+ self.assertAllEqual(sess.run(r), [0, 1, 2])
+ r = py_builtins.range_(1, constant_op.constant(3))
+ self.assertAllEqual(sess.run(r), [1, 2])
+ r = py_builtins.range_(2, 0, constant_op.constant(-1))
+ self.assertAllEqual(sess.run(r), [2, 1])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD
index d2b399f19b..4504a5c7a3 100644
--- a/tensorflow/contrib/autograph/utils/BUILD
+++ b/tensorflow/contrib/autograph/utils/BUILD
@@ -20,12 +20,12 @@ py_library(
name = "utils",
srcs = [
"__init__.py",
- "builtins.py",
"context_managers.py",
"misc.py",
"multiple_dispatch.py",
"py_func.py",
"tensor_list.py",
+ "tensors.py",
"testing.py",
"type_check.py",
],
@@ -42,17 +42,6 @@ py_library(
)
py_test(
- name = "builtins_test",
- srcs = ["builtins_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"],
- deps = [
- ":utils",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
name = "context_managers_test",
srcs = ["context_managers_test.py"],
srcs_version = "PY2AND3",
@@ -113,3 +102,13 @@ py_test(
"//tensorflow/python:list_ops",
],
)
+
+py_test(
+ name = "tensors_test",
+ srcs = ["tensors_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py
index 57b5f74741..38e0a0a8f0 100644
--- a/tensorflow/contrib/autograph/utils/__init__.py
+++ b/tensorflow/contrib/autograph/utils/__init__.py
@@ -18,9 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin
-from tensorflow.contrib.autograph.utils.builtins import dynamic_print
-from tensorflow.contrib.autograph.utils.builtins import dynamic_range
from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
from tensorflow.contrib.autograph.utils.misc import alias_tensors
from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
deleted file mode 100644
index 4dd440ef19..0000000000
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Builtin conversion utilities."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import py_func
-from tensorflow.contrib.autograph.utils import type_check
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import list_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
-
-
-def dynamic_builtin(f, *args, **kwargs):
- """Converts a builtin function call inline."""
- if f is len:
- return dynamic_len(*args, **kwargs)
- if six.PY2 and f is xrange:
- return dynamic_range(*args, **kwargs)
- if f is range:
- return dynamic_range(*args, **kwargs)
- if f is int:
- return dynamic_int(*args, **kwargs)
- if f is float:
- return dynamic_float(*args, **kwargs)
- if f is abs:
- return dynamic_abs(*args, **kwargs)
-
- raise NotImplementedError(
- 'The "%s" builtin is not yet supported.' % f.__name__)
-
-
-def dynamic_len(list_or_tensor):
- """Implementation of len using dynamic dispatch."""
- if _is_tensor_list(list_or_tensor):
- return list_ops.tensor_list_length(list_or_tensor)
- elif tensor_util.is_tensor(list_or_tensor):
- shape = list_or_tensor.shape
- if not shape.ndims:
- raise ValueError(
- 'len requires non-zero rank for tensor "%s"' % list_or_tensor)
- return array_ops.shape(list_or_tensor)[0]
- return len(list_or_tensor)
-
-
-def _is_tensor_list(list_or_tensor):
- return (tensor_util.is_tensor(list_or_tensor)
- and list_or_tensor.dtype == dtypes.variant)
-
-
-def dynamic_int(num_or_tensor, **kwargs):
- """Implementation of int() using dynamic dispatch."""
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs)
- return int(num_or_tensor)
-
-
-def dynamic_float(num_or_tensor, **kwargs):
- """Implementation of float() using dynamic dispatch."""
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs)
- return float(num_or_tensor)
-
-
-def dynamic_abs(num_or_tensor, **kwargs):
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.abs(num_or_tensor, **kwargs)
- else:
- return abs(num_or_tensor, **kwargs)
-
-
-def dynamic_range(start_or_stop, stop=None, step=None):
- """Implementation of range using dynamic dispatch."""
- if type_check.is_tensor(start_or_stop, stop, step):
- if step is not None:
- return math_ops.range(start_or_stop, stop, step)
- if stop is not None:
- return math_ops.range(start_or_stop, stop)
- return math_ops.range(start_or_stop)
-
- if step is not None:
- return range(start_or_stop, stop, step)
- elif stop is not None:
- return range(start_or_stop, stop)
- return range(start_or_stop)
-
-
-def is_tf_print_compatible(value):
- # TODO(mdan): Enable once we can reliably test this.
- # This is currently disabled because we can't capture the output of
- # op kernels from Python.
- del value
- return False
-
-
-def dynamic_print(*values):
- """Implementation of print using dynamic dispatch.
-
- The function attempts to use tf.Print if all the values are compatible.
- Otherwise, it will fall back to py_func.
-
- Args:
- *values: values to print
- Returns:
- A dummy value indicating the print completed. If tf.
- """
-
- if all(map(is_tf_print_compatible, values)):
- return logging_ops.Print(1, values)
-
- def print_wrapper(*vals):
- if six.PY3:
- # TensorFlow doesn't seem to generate Unicode when passing strings to
- # py_func. This causes the print to add a "b'" wrapper to the output,
- # which is probably never what you want.
- vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals)
- print(*vals)
- # The flush helps avoid garbled output in IPython.
- sys.stdout.flush()
-
- return py_func.wrap_py_func(
- print_wrapper, None, values, use_dummy_return=True)
diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py
deleted file mode 100644
index b1cd5253bc..0000000000
--- a/tensorflow/contrib/autograph/utils/builtins_test.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for builtins module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import builtins
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import test
-
-
-class BuiltinsTest(test.TestCase):
-
- def test_dynamic_len_tf_scalar(self):
- a = constant_op.constant(1)
-
- with self.assertRaisesRegexp(ValueError,
- 'len requires non-zero rank for tensor.*'):
- with self.test_session() as sess:
- sess.run(builtins.dynamic_builtin(len, a))
-
- def test_dynamic_len_tf_array(self):
- a = constant_op.constant([1, 2, 3])
-
- with self.test_session() as sess:
- self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a)))
-
- def test_dynamic_abs_tf_scalar(self):
- a = constant_op.constant(-1)
-
- with self.test_session() as sess:
- self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a)))
-
- def test_dynamic_abs_tf_array(self):
- a = constant_op.constant([-1, 2, -3])
-
- with self.test_session() as sess:
- self.assertListEqual([1, 2, 3],
- list(sess.run(builtins.dynamic_builtin(abs, a))))
-
- def test_dynamic_abs_py_scalar(self):
- a = -1
- self.assertEqual(1, builtins.dynamic_builtin(abs, a))
-
- def test_dynamic_len_tf_matrix(self):
- a = constant_op.constant([[1, 2], [3, 4]])
-
- with self.test_session() as sess:
- self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a)))
-
- def test_dynamic_len_py_list(self):
- a = [3] * 5
-
- self.assertEqual(5, builtins.dynamic_builtin(len, a))
-
- def test_dynamic_range_all_python(self):
- self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2])
- self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1])
-
- def test_dynamic_range_tf(self):
- with self.test_session() as sess:
- self.assertAllEqual(
- sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))),
- [0, 1, 2])
- self.assertAllEqual(
- sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))),
- [1, 2])
- self.assertAllEqual(
- sess.run(
- builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))),
- [2, 1])
-
- def test_dynamic_range_detection(self):
- def range(x): # pylint:disable=redefined-builtin
- return x
-
- # Functions that just have the names of builtins are rejected.
- with self.assertRaises(NotImplementedError):
- self.assertEqual(builtins.dynamic_builtin(range, 1), 1)
- if six.PY2:
- self.assertListEqual(
- list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
-
- def test_casts(self):
- i = constant_op.constant(2, dtype=dtypes.int32)
- f = constant_op.constant(1.0, dtype=dtypes.float32)
-
- self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32)
- self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32)
- self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32)
- self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32)
-
- self.assertEqual(builtins.dynamic_builtin(int, True), 1)
- self.assertEqual(builtins.dynamic_builtin(int, False), 0)
- self.assertEqual(builtins.dynamic_builtin(float, True), 1.0)
- self.assertEqual(builtins.dynamic_builtin(float, False), 0.0)
-
- def test_dynamic_print_tf(self):
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- with self.test_session() as sess:
- sess.run(builtins.dynamic_print('test message', 1))
- self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
- finally:
- sys.stdout = sys.__stdout__
-
- def test_dynamic_print_complex(self):
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- with self.test_session() as sess:
- sess.run(builtins.dynamic_print('test message', [1, 2]))
- self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
- finally:
- sys.stdout = sys.__stdout__
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/autograph/utils/tensors.py b/tensorflow/contrib/autograph/utils/tensors.py
new file mode 100644
index 0000000000..fa5db81a71
--- /dev/null
+++ b/tensorflow/contrib/autograph/utils/tensors.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This module defines tensor utilities not found in TensorFlow.
+
+These utilities are not defined in TensorFlow itself because they may not be
+fully robust, although they work in the vast majority of cases. We define them
+here so that their behavior can be consistently verified.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import tensor_array_ops
+
+
+def is_tensor_array(t):
+ return isinstance(t, tensor_array_ops.TensorArray)
+
+
+def is_tensor_list(t):
+ # TODO(mdan): This is just a heuristic.
+ # With TF lacking support for templated types, this is unfortunately the
+ # closest we can get right now. A dedicated op ought to be possible to
+ # construct.
+ return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and
+ not t.shape.ndims)
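The `is_tensor_list` heuristic above relies on a TensorList being represented as a scalar tensor of dtype `variant`. A minimal sketch of how such a tensor looks, using the same `list_ops` call as the tests that follow (illustrative only, not part of the patch):

```python
# Minimal sketch: a TensorList is a scalar `variant` tensor, which is exactly
# what is_tensor_list() checks for.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import list_ops

t = list_ops.empty_tensor_list(
    element_shape=constant_op.constant([1]), element_dtype=dtypes.int32)

assert tensor_util.is_tensor(t)
assert t.dtype == dtypes.variant
assert not t.shape.ndims  # scalar, so ndims == 0
```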
diff --git a/tensorflow/contrib/autograph/utils/tensors_test.py b/tensorflow/contrib/autograph/utils/tensors_test.py
new file mode 100644
index 0000000000..e855e0b6cb
--- /dev/null
+++ b/tensorflow/contrib/autograph/utils/tensors_test.py
@@ -0,0 +1,57 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensors module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class TensorsTest(test.TestCase):
+
+ def _simple_tensor_array(self):
+ return tensor_array_ops.TensorArray(dtypes.int32, size=3)
+
+ def _simple_tensor_list(self):
+ return list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([1]), element_dtype=dtypes.int32)
+
+ def _simple_list_of_tensors(self):
+ return [constant_op.constant(1), constant_op.constant(2)]
+
+ def test_is_tensor_array(self):
+ self.assertTrue(tensors.is_tensor_array(self._simple_tensor_array()))
+ self.assertFalse(tensors.is_tensor_array(self._simple_tensor_list()))
+ self.assertFalse(tensors.is_tensor_array(constant_op.constant(1)))
+ self.assertFalse(tensors.is_tensor_array(self._simple_list_of_tensors()))
+ self.assertFalse(tensors.is_tensor_array(None))
+
+ def test_is_tensor_list(self):
+ self.assertFalse(tensors.is_tensor_list(self._simple_tensor_array()))
+ self.assertTrue(tensors.is_tensor_list(self._simple_tensor_list()))
+ self.assertFalse(tensors.is_tensor_list(constant_op.constant(1)))
+ self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors()))
+ self.assertFalse(tensors.is_tensor_list(None))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index e6407174b1..35d727482b 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -141,11 +141,18 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
# The bias is computed on gradients and hessians (and not
# filtered_gradients) which have exactly one value per example, so we
# don't double count a gradient in multivalent columns.
+      # Since unsorted_segment_sum can be numerically unstable, accumulate in
+      # a 64-bit operation and cast the result back to 32 bits below.
+ gradients64 = math_ops.cast(gradients, dtypes.float64)
+ hessians64 = math_ops.cast(hessians, dtypes.float64)
per_partition_gradients = math_ops.unsorted_segment_sum(
- gradients, mapped_partitions, array_ops.size(unique_partitions))
+ gradients64, mapped_partitions, array_ops.size(unique_partitions))
per_partition_hessians = math_ops.unsorted_segment_sum(
- hessians, mapped_partitions, array_ops.size(unique_partitions))
-
+ hessians64, mapped_partitions, array_ops.size(unique_partitions))
+ per_partition_gradients = math_ops.cast(per_partition_gradients,
+ dtypes.float32)
+ per_partition_hessians = math_ops.cast(per_partition_hessians,
+ dtypes.float32)
# Prepend a bias feature per partition that accumulates the stats for all
# examples in that partition.
# Bias is added to the stats even if there are no examples with values in
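The cast-accumulate-cast pattern introduced above can be sketched in isolation; a rough standalone example (the gradient values and segment ids are made up, and plain TensorFlow ops stand in for the split handler's internals):

```python
# Rough standalone sketch of accumulating segment sums in float64 and casting
# back to float32; values and segment ids are illustrative only.
import tensorflow as tf

gradients = tf.constant([1e-8, 3e8, -3e8, 2e-8], dtype=tf.float32)
segment_ids = tf.constant([0, 0, 0, 1])

gradients64 = tf.cast(gradients, tf.float64)
per_partition = tf.unsorted_segment_sum(
    gradients64, segment_ids, num_segments=2)
per_partition = tf.cast(per_partition, tf.float32)
```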
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 38c0a09c33..92d4251a86 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -220,6 +220,7 @@ def sample_from_datasets(datasets, weights=None, seed=None):
if weights is None:
# Select inputs with uniform probability.
logits = [[1.0] * num_datasets]
+
else:
# Use the given `weights` as the probability of choosing the respective
# input.
@@ -245,8 +246,11 @@ def sample_from_datasets(datasets, weights=None, seed=None):
return array_ops.squeeze(
stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
- selector_input = random_ops.RandomDataset(seed).batch(2).map(
- select_dataset_constant_logits)
+ selector_input = dataset_ops.MapDataset(
+ random_ops.RandomDataset(seed).batch(2),
+ select_dataset_constant_logits,
+ use_inter_op_parallelism=False)
+
else:
# Use each element of the given `weights` dataset as the probability of
# choosing the respective input.
@@ -259,9 +263,12 @@ def sample_from_datasets(datasets, weights=None, seed=None):
return array_ops.squeeze(
stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
- selector_input = dataset_ops.Dataset.zip(
- (logits_ds, random_ops.RandomDataset(seed).batch(2)
- )).map(select_dataset_varying_logits)
+ logits_and_seeds = dataset_ops.Dataset.zip(
+ (logits_ds, random_ops.RandomDataset(seed).batch(2)))
+ selector_input = dataset_ops.MapDataset(
+ logits_and_seeds,
+ select_dataset_varying_logits,
+ use_inter_op_parallelism=False)
return _DirectedInterleaveDataset(selector_input, datasets)
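For reference, a hedged usage sketch of `sample_from_datasets` as exposed through `tf.contrib.data` (the datasets and weights below are made up):

```python
# Hedged usage sketch; the datasets, weights, and seed are illustrative only.
import tensorflow as tf

datasets = [
    tf.data.Dataset.from_tensors(0).repeat(),
    tf.data.Dataset.from_tensors(1).repeat(),
]
# Roughly 25% of the sampled elements come from the first dataset.
sampled = tf.contrib.data.sample_from_datasets(
    datasets, weights=[0.25, 0.75], seed=42)

next_element = sampled.make_one_shot_iterator().get_next()
```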
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 7f09ba71dc..4c466781f7 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -499,7 +499,8 @@ def make_csv_dataset(
# indefinitely, and all batches will be full-sized.
dataset = dataset.batch(batch_size=batch_size,
drop_remainder=num_epochs is None)
- dataset = dataset.map(map_fn)
+ dataset = dataset_ops.MapDataset(
+ dataset, map_fn, use_inter_op_parallelism=False)
dataset = dataset.prefetch(prefetch_buffer_size)
return dataset
@@ -778,7 +779,8 @@ def make_batched_features_dataset(file_pattern,
# Extract values if the `Example` tensors are stored as key-value tuples.
if dataset.output_types == (dtypes.string, dtypes.string):
- dataset = dataset.map(lambda _, v: v)
+ dataset = dataset_ops.MapDataset(
+ dataset, lambda _, v: v, use_inter_op_parallelism=False)
# Apply dataset repeat and shuffle transformations.
dataset = _maybe_shuffle_and_repeat(
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 4fa8aa06cc..77079d0df9 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -229,6 +229,8 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
if not session_config or not self._cluster_spec:
return
+ session_config.isolate_session_state = True
+
assert self._task_type
assert self._task_id is not None
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index d1235b7afb..0c6805d682 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -572,6 +572,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
task_type=None,
task_id=None):
del task_type, task_id
+
+ if session_config:
+ session_config.isolate_session_state = True
+
if cluster_spec:
self._initialize_multi_worker(self._num_gpus, cluster_spec)
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 88d7768b14..1125d027f6 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -412,6 +412,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
if not session_config or not self._cluster_spec:
return
+ session_config.isolate_session_state = False
+
assert self._cluster_spec
assert self._task_type
assert self._task_id is not None
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 32d7444e42..4fb70ec685 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -311,3 +311,16 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,)
+
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del cluster_spec, task_type, task_id
+ if session_config:
+ session_config.isolate_session_state = True
+ cluster_spec = self._tpu_cluster_resolver.cluster_spec()
+ if cluster_spec:
+ session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+
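The `configure()` overrides added to the strategies above all follow the same pattern of mutating the session `ConfigProto`; a minimal sketch of that pattern outside any strategy (the cluster layout below is made up):

```python
# Minimal sketch, assuming a made-up single-worker cluster layout; only the
# ConfigProto manipulation mirrors the strategy code above.
import tensorflow as tf

cluster_spec = tf.train.ClusterSpec({'tpu_worker': ['localhost:2222']})

session_config = tf.ConfigProto()
session_config.isolate_session_state = True
session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
```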
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
index 315d7a4893..529c99b37c 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
@@ -66,7 +66,7 @@
"\n",
"[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n",
"\n",
- "Our goal is generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n",
+ "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n",
"\n",
"![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n",
"\n",
@@ -128,7 +128,7 @@
"source": [
"## Download and prepare the MS-COCO dataset\n",
"\n",
- "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code code below will download and extract the dataset automatically. \n",
+ "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n",
"\n",
"**Caution: large download ahead**. We'll use the training set, it's a 13GB file."
]
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index aa99616810..dcc7b71d79 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -25,11 +25,14 @@ from tensorflow.contrib.eager.python import metrics
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -244,6 +247,48 @@ class MetricsTest(test.TestCase):
value = m.value()
self.assertEqual(self.evaluate(value), 2.5)
+ @test_util.run_in_graph_and_eager_modes
+ def testGraphAndEagerTensorGlobalVariables(self):
+ m = metrics.Mean(use_global_variables=True)
+ inputs = ops.convert_to_tensor([1.0, 2.0])
+ accumulate = m(inputs)
+ result = m.result()
+ self.evaluate(m.init_variables())
+ self.evaluate(accumulate)
+ self.assertEqual(self.evaluate(result), 1.5)
+ # Second init resets all the variables.
+ self.evaluate(m.init_variables())
+ inputs = ops.convert_to_tensor([2.0, 3.0])
+ self.evaluate(m(inputs))
+ value = m.value()
+ self.assertEqual(self.evaluate(value), 2.5)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testGraphAndEagerTensorWhileLoopDoubleCall(self):
+ m = metrics.Mean()
+ init_value = constant_op.constant(1)
+ cond = lambda i: math_ops.less(i, 3)
+ def body(x):
+ with ops.control_dependencies([m(x)]):
+ return math_ops.add(x, 1)
+ accumulate = control_flow_ops.while_loop(cond, body, [init_value])
+
+ result = m.result()
+ self.evaluate(m.init_variables())
+ self.evaluate(accumulate)
+ self.assertEqual(self.evaluate(result), 1.5)
+ # Second init resets all the variables.
+ self.evaluate(m.init_variables())
+ inputs = ops.convert_to_tensor([2.0, 3.0])
+ self.evaluate(m(inputs))
+ if ops.context.executing_eagerly():
+ self.evaluate(control_flow_ops.while_loop(cond, body, [init_value]))
+ else:
+ # Reuse the loop operators in graph mode
+ self.evaluate(accumulate)
+ value = m.value()
+ self.assertEqual(self.evaluate(value), 2.0)
+
def testTwoMeansGraph(self):
# Verify two metrics with the same name in the same graph raises a
# ValueError.
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py
index ca46c39baa..b82bf1188f 100644
--- a/tensorflow/contrib/factorization/python/ops/wals.py
+++ b/tensorflow/contrib/factorization/python/ops/wals.py
@@ -377,64 +377,68 @@ class WALSMatrixFactorization(estimator.Estimator):
WALS (Weighted Alternating Least Squares) is an algorithm for weighted matrix
factorization. It computes a low-rank approximation of a given sparse (n x m)
- matrix A, by a product of two matrices, U * V^T, where U is a (n x k) matrix
- and V is a (m x k) matrix. Here k is the rank of the approximation, also
- called the embedding dimension. We refer to U as the row factors, and V as the
- column factors.
+ matrix `A`, by a product of two matrices, `U * V^T`, where `U` is a (n x k)
+ matrix and `V` is a (m x k) matrix. Here k is the rank of the approximation,
+ also called the embedding dimension. We refer to `U` as the row factors, and
+ `V` as the column factors.
See tensorflow/contrib/factorization/g3doc/wals.md for the precise problem
formulation.
- The training proceeds in sweeps: during a row_sweep, we fix V and solve for U.
- During a column sweep, we fix U and solve for V. Each one of these problems is
- an unconstrained quadratic minimization problem and can be solved exactly (it
- can also be solved in mini-batches, since the solution decouples nicely).
+ The training proceeds in sweeps: during a row_sweep, we fix `V` and solve for
+ `U`. During a column sweep, we fix `U` and solve for `V`. Each one of these
+ problems is an unconstrained quadratic minimization problem and can be solved
+ exactly (it can also be solved in mini-batches, since the solution decouples
+ across rows of each matrix).
The alternating between sweeps is achieved by using a hook during training,
which is responsible for keeping track of the sweeps and running preparation
ops at the beginning of each sweep. It also updates the global_step variable,
which keeps track of the number of batches processed since the beginning of
training.
The current implementation assumes that the training is run on a single
- machine, and will fail if config.num_worker_replicas is not equal to one.
- Training is done by calling self.fit(input_fn=input_fn), where input_fn
+ machine, and will fail if `config.num_worker_replicas` is not equal to one.
+ Training is done by calling `self.fit(input_fn=input_fn)`, where `input_fn`
provides two tensors: one for rows of the input matrix, and one for rows of
the transposed input matrix (i.e. columns of the original matrix). Note that
during a row sweep, only row batches are processed (ignoring column batches)
and vice-versa.
Also note that every row (respectively every column) of the input matrix
must be processed at least once for the sweep to be considered complete. In
- particular, training will not make progress if input_fn does not generate some
- rows.
-
- For prediction, given a new set of input rows A' (e.g. new rows of the A
- matrix), we compute a corresponding set of row factors U', such that U' * V^T
- is a good approximation of A'. We call this operation a row projection. A
- similar operation is defined for columns.
- Projection is done by calling self.get_projections(input_fn=input_fn), where
- input_fn satisfies the constraints given below.
-
- The input functions must satisfy the following constraints: Calling input_fn
- must return a tuple (features, labels) where labels is None, and features is
- a dict containing the following keys:
+ particular, training will not make progress if some rows are not generated by
+ the `input_fn`.
+
+ For prediction, given a new set of input rows `A'`, we compute a corresponding
+ set of row factors `U'`, such that `U' * V^T` is a good approximation of `A'`.
+ We call this operation a row projection. A similar operation is defined for
+ columns. Projection is done by calling
+ `self.get_projections(input_fn=input_fn)`, where `input_fn` satisfies the
+ constraints given below.
+
+ The input functions must satisfy the following constraints: Calling `input_fn`
+ must return a tuple `(features, labels)` where `labels` is None, and
+ `features` is a dict containing the following keys:
+
TRAIN:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows of the input matrix to process (or to project).
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns of the input matrix to process (or to project), transposed.
+
INFER:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows to project.
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns to project.
- - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
+ * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
the rows or columns.
- - WALSMatrixFactorization.PROJECTION_WEIGHTS (Optional): float32 Tensor
+ * `WALSMatrixFactorization.PROJECTION_WEIGHTS` (Optional): float32 Tensor
(vector). The weights to use in the projection.
+
EVAL:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows to project.
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns to project.
- - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
+ * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
the rows or columns.
"""
# Keys to be used in model_fn
@@ -469,7 +473,7 @@ class WALSMatrixFactorization(estimator.Estimator):
max_sweeps=None,
model_dir=None,
config=None):
- """Creates a model for matrix factorization using the WALS method.
+ r"""Creates a model for matrix factorization using the WALS method.
Args:
num_rows: Total number of rows for input matrix.
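To make the input constraints described in the docstring concrete, here is a rough sketch of a TRAIN-mode `input_fn` (the matrix contents and shapes are made up):

```python
# Rough sketch of a TRAIN-mode input_fn for WALSMatrixFactorization; the
# sparse matrix below is illustrative only.
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import wals as wals_lib

def train_input_fn():
  # A tiny 2x3 sparse matrix and its transpose, both as SparseTensors.
  rows = tf.SparseTensor(
      indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[2, 3])
  cols = tf.SparseTensor(
      indices=[[0, 0], [2, 1]], values=[1.0, 2.0], dense_shape=[3, 2])
  features = {
      wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows,
      wals_lib.WALSMatrixFactorization.INPUT_COLS: cols,
  }
  return features, None
```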
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 36b483c6d7..31820a18b4 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -125,11 +125,13 @@ class WALSMatrixFactorizationTest(test.TestCase):
nz_row_ids = np.arange(np.shape(np_matrix)[0])
nz_col_ids = np.arange(np.shape(np_matrix)[1])
- def extract_features(row_batch, col_batch, shape):
+ def extract_features(row_batch, col_batch, num_rows, num_cols):
row_ids = row_batch[0]
col_ids = col_batch[0]
- rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape)
- cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape)
+ rows = self.remap_sparse_tensor_rows(
+ row_batch[1], row_ids, shape=[num_rows, num_cols])
+ cols = self.remap_sparse_tensor_rows(
+ col_batch[1], col_ids, shape=[num_cols, num_rows])
features = {
wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows,
wals_lib.WALSMatrixFactorization.INPUT_COLS: cols,
@@ -154,7 +156,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
capacity=10,
enqueue_many=True)
- features = extract_features(row_batch, col_batch, sp_mat.dense_shape)
+ features = extract_features(row_batch, col_batch, num_rows, num_cols)
if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL:
self.assertTrue(
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index ab9886580d..7243f150ce 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -184,7 +184,7 @@ class GANEstimator(estimator.Estimator):
return _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn)
+ get_hooks_fn, use_loss_summaries)
super(GANEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
@@ -211,15 +211,17 @@ def _get_gan_model(
def _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn=None):
+ get_hooks_fn=None, use_loss_summaries=True):
"""Get the EstimatorSpec for the current mode."""
if mode == model_fn_lib.ModeKeys.PREDICT:
estimator_spec = model_fn_lib.EstimatorSpec(
mode=mode, predictions=gan_model.generated_data)
else:
gan_loss = tfgan_tuples.GANLoss(
- generator_loss=generator_loss_fn(gan_model),
- discriminator_loss=discriminator_loss_fn(gan_model))
+ generator_loss=generator_loss_fn(
+ gan_model, add_summaries=use_loss_summaries),
+ discriminator_loss=discriminator_loss_fn(
+ gan_model, add_summaries=use_loss_summaries))
if mode == model_fn_lib.ModeKeys.EVAL:
estimator_spec = _get_eval_estimator_spec(
gan_model, gan_loss, get_eval_metric_ops_fn)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 9ac9c6ca9c..83f8dd641f 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -116,7 +116,7 @@ def get_dummy_gan_model():
discriminator_fn=None)
-def dummy_loss_fn(gan_model):
+def dummy_loss_fn(gan_model, add_summaries=True):
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
gan_model.discriminator_gen_outputs)
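Any custom loss function used with `GANEstimator` now needs to accept the `add_summaries` keyword that `_get_estimator_spec` passes through; a hedged sketch of such a function (the loss computation itself is illustrative only):

```python
# Hedged sketch of a loss function matching the new call signature; the
# actual loss shown here is illustrative, not part of the change.
import tensorflow as tf

def my_generator_loss_fn(gan_model, add_summaries=True):
  loss = tf.reduce_mean(
      tf.nn.sigmoid_cross_entropy_with_logits(
          labels=tf.ones_like(gan_model.discriminator_gen_outputs),
          logits=gan_model.discriminator_gen_outputs))
  if add_summaries:
    tf.summary.scalar('generator_loss', loss)
  return loss
```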
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index eee90864b4..52c9c4f3be 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1288,7 +1288,7 @@ class ConvolutionInPlaneTest(test.TestCase):
result = sess.run(vert_gradients)
expected = np.zeros((1, 9, 10, 1))
- self.assertAllEqual(result, expected)
+ self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
def testVertConvWithVaryingImage(self):
image = np.asmatrix(('1.0 2.0 3.0;' '1.1 2.0 4.0;' '-4.3 0.0 8.9'))
diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md
deleted file mode 100644
index 8fd63d5cee..0000000000
--- a/tensorflow/contrib/lite/RELEASE.md
+++ /dev/null
@@ -1,8 +0,0 @@
-# Release 0.1.7
-
-* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit
- fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0).
-* To reproduce the iOS library, it's required to cherry pick git commit
- f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue.
-* The code is based on TensorFlow 1.8.0 release candidate and it's very close
- to TensorFlow 1.8.0 release.
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index fc199f0a0e..0246e7fa30 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -57,6 +57,7 @@ def tflite_linkopts_unstripped():
"-Wl,--as-needed", # Don't link unused libs.
],
"//tensorflow:darwin": [],
+ "//tensorflow:ios": [],
"//tensorflow/contrib/lite:mips": [],
"//tensorflow/contrib/lite:mips64": [],
"//conditions:default": [
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
index c658e43092..7c5099235a 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
@@ -257,6 +257,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} else {
max_coeff = raw_input.maxCoeff();
}
+
+ // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
+ float logsumexp = 0.0;
+ for (int j = 0; j < raw_input.size(); ++j) {
+ logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
+ }
+ logsumexp = Eigen::numext::log(logsumexp);
+ // Final normalization offset to get correct log probabilities.
+ float norm_offset = max_coeff + logsumexp;
+
const float label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
@@ -288,10 +298,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
beam_scorer_->GetStateExpansionScore(b->state, previous));
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
- b->newp.label += raw_input(b->label) - max_coeff;
+ b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
- b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
+ b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@@ -326,6 +336,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
+ // We may compare logits instead of log probabilities,
+ // since the difference is the same in both cases.
if (logit < label_selection_input_min) {
continue;
}
@@ -339,7 +351,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
- c.newp.label = logit - max_coeff +
+ c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
c.newp.total = c.newp.label;
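The `norm_offset` introduced above is the standard log-softmax normalizer, `max_coeff + log(sum(exp(logit - max_coeff)))`; a small numeric sketch of the same computation (the logit values are made up):

```python
# Numeric sketch of the norm_offset computation; the logits are made up.
import numpy as np

logits = np.array([2.0, 1.0, -1.0, 0.5])
max_coeff = logits.max()
logsumexp = np.log(np.sum(np.exp(logits - max_coeff)))
norm_offset = max_coeff + logsumexp

log_probs = logits - norm_offset
assert abs(np.exp(log_probs).sum() - 1.0) < 1e-6  # proper log probabilities
```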
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
index 32458305c4..aa42b495bd 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
@@ -117,7 +117,7 @@ TEST(CTCBeamSearchTest, SimpleTest) {
EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0.32134813})));
+ ElementsAreArray(ArrayFloatNear({-0.357094})));
}
TEST(CTCBeamSearchTest, MultiBatchTest) {
@@ -148,9 +148,8 @@ TEST(CTCBeamSearchTest, MultiBatchTest) {
EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0));
EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2));
// Check log probabilities output.
- EXPECT_THAT(
- m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572})));
+ EXPECT_THAT(m.GetLogProbabilitiesOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958})));
}
TEST(CTCBeamSearchTest, MultiPathsTest) {
@@ -188,8 +187,8 @@ TEST(CTCBeamSearchTest, MultiPathsTest) {
EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear(
- {0.91318405, 0.9060272, 1.0780245, 0.64358956})));
+ ElementsAreArray(
+ ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357})));
}
TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
@@ -223,7 +222,7 @@ TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005})));
+ ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553})));
}
} // namespace
diff --git a/tensorflow/contrib/lite/g3doc/README.md b/tensorflow/contrib/lite/g3doc/README.md
deleted file mode 100644
index e3db478481..0000000000
--- a/tensorflow/contrib/lite/g3doc/README.md
+++ /dev/null
@@ -1,4 +0,0 @@
-This is a *work-in-progress* TF Lite subsite for:
-https://www.tensorflow.org/mobile
-
-DO NOT PUBLISH
diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md
deleted file mode 100644
index 70031a3c3d..0000000000
--- a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md
+++ /dev/null
@@ -1,10 +0,0 @@
-Project: /mobile/_project.yaml
-Book: /mobile/_book.yaml
-page_type: reference
-<style> table img { max-width: 100%; } </style>
-<script src="/_static/js/managed/mathjax/MathJax.js?config=TeX-AMS-MML_SVG"></script>
-
-<!-- DO NOT EDIT! Automatically generated file. -->
-# All symbols in TensorFlow Lite
-
-TEMP PAGE
diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md
index f255017ad9..69616c7b8a 100644
--- a/tensorflow/contrib/lite/g3doc/apis.md
+++ b/tensorflow/contrib/lite/g3doc/apis.md
@@ -37,7 +37,7 @@ float* output = interpreter->typed_output_tensor<float>(0);
```
### Data Alignment
-TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended
+TensorFlow Lite data is usually aligned to 16-byte boundaries. It is recommended
that all data provided to TensorFlow Lite be aligned that way.
### Error Reporting
@@ -112,7 +112,7 @@ below. It should be noted that:
* Tensors are represented by integers, in order to avoid string comparisons
(and any fixed dependency on string libraries).
- * An interpreter must not be accessed from concurrent threads
+ * An interpreter must not be accessed from concurrent threads.
* Memory allocation for input and output tensors must be triggered
by calling AllocateTensors() right after resizing tensors.
@@ -169,7 +169,7 @@ former provides error reporting facilities and access to global objects,
including all the tensors. The latter allows implementations to access their
inputs and outputs.
-When the interpreter loads a model, it calls init() once for each node in the
+When the interpreter loads a model, it calls `init()` once for each node in the
graph. A given `init()` will be called more than once if the op is used
multiple times in the graph. For custom ops a configuration buffer will be
provided, containing a flexbuffer that maps parameter names to their values.
@@ -210,8 +210,9 @@ namespace custom {
Note that registration is not automatic and an explicit call to
`Register_MY_CUSTOM_OP` should be made somewhere. While the standard
-`:builtin_ops` takes care of the registration of builtins, custom ops will have
-to be collected in separated custom libraries.
+`BuiltinOpResolver` (available from the `:builtin_ops` target) takes care of the
+registration of builtins, custom ops will have to be collected in separate
+custom libraries.
### Customizing the kernel library
@@ -232,7 +233,7 @@ class OpResolver {
};
```
-The regular usage will require the developer to use the `BuiltinOpResolver` and
+Regular usage will require the developer to use the `BuiltinOpResolver` and
write:
```c++
@@ -308,18 +309,25 @@ an `IllegalArgumentException` will be thrown.
#### Inputs
-Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of
-the supported primitive types.
+Each input should be an array or multi-dimensional array of the supported
+primitive types, or a raw `ByteBuffer` of the appropriate size. If the input is
+an array or multi-dimensional array, the associated input tensor will be
+implicitly resized to the array's dimensions at inference time. If the input is
+a `ByteBuffer`, the caller should first manually resize the associated input
+tensor (via `Interpreter.resizeInput()`) before running inference.
-The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid
-unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its
-order must be `ByteOrder.nativeOrder()`. After it is used for a model inference,
-it must remain unchanged until the model inference is finished.
+When using a `ByteBuffer`, prefer direct byte buffers, as this allows the
+`Interpreter` to avoid unnecessary copies. If the `ByteBuffer` is a direct byte
+buffer, its order must be `ByteOrder.nativeOrder()`. After it is used for a
+model inference, it must remain unchanged until the model inference is finished.
#### Outputs
-Each output should be an array, or a multi-dimensional array of the supported
-primitive types.
+Each output should be an array or multi-dimensional array of the supported
+primitive types, or a `ByteBuffer` of the appropriate size. Note that some models
+have dynamic outputs, where the shape of output tensors can vary depending on
+the input. There's no straightforward way of handling this with the existing
+Java inference API, but planned extensions will make this possible.
#### Running Model Inference
@@ -339,9 +347,10 @@ interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
where each entry in `inputs` corresponds to an input tensor and
`map_of_indices_to_outputs` maps indices of output tensors to the
corresponding output data. In both cases the tensor indices should correspond to
-the values given to the `TensorFlow Lite Optimized Converter` when the model was
-created. Be aware that the order of tensors in `input` must match the order
-given to the `TensorFlow Lite Optimized Converter`.
+the values given to the [TensorFlow Lite Optimized Converter](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md)
+when the model was created. Be aware that the order of tensors in `input` must
+match the order given to the `TensorFlow Lite Optimized Converter`.
+
The Java API also provides convenient functions for app developers to get the
index of any model input or output using a tensor name:
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 8287115f5c..b7c5cbf207 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -6,7 +6,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_android")
# Suppress warnings that are introduced by Eigen Tensor.
EXTRA_EIGEN_COPTS = select({
@@ -147,7 +147,7 @@ tf_cc_test(
)
cc_library(
- name = "builtin_ops",
+ name = "builtin_op_kernels",
srcs = [
"activations.cc",
"add.cc",
@@ -177,6 +177,7 @@ cc_library(
"gather.cc",
"hashtable_lookup.cc",
"l2norm.cc",
+ "layer_norm_lstm.cc",
"local_response_norm.cc",
"logical.cc",
"lsh_projection.cc",
@@ -191,7 +192,7 @@ cc_library(
"pooling.cc",
"pow.cc",
"reduce.cc",
- "register.cc",
+ "relu1.cc",
"reshape.cc",
"resize_bilinear.cc",
"select.cc",
@@ -216,9 +217,9 @@ cc_library(
],
hdrs = [
"padding.h",
- "register.h",
],
- copts = tflite_copts() + EXTRA_EIGEN_COPTS,
+ copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
+ visibility = ["//visibility:private"],
deps = [
":activation_functor",
":eigen_support",
@@ -242,6 +243,17 @@ cc_library(
],
)
+cc_library(
+ name = "builtin_ops",
+ srcs = ["register.cc"],
+ hdrs = ["register.h"],
+ deps = [
+ ":builtin_op_kernels",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:util",
+ ],
+)
+
tf_cc_test(
name = "audio_spectrogram_test",
size = "small",
@@ -294,6 +306,23 @@ tf_cc_test(
)
tf_cc_test(
+ name = "relu1_test",
+ size = "small",
+ srcs = ["relu1_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "activations_test",
size = "small",
srcs = ["activations_test.cc"],
@@ -904,6 +933,20 @@ tf_cc_test(
)
tf_cc_test(
+ name = "layer_norm_lstm_test",
+ size = "small",
+ srcs = ["layer_norm_lstm_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "lstm_test",
size = "small",
srcs = ["lstm_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 9c891fe904..5cdd9fc94f 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -200,7 +200,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->type, output->type);
const int num_dims = NumDimensions(input);
- TF_LITE_ENSURE(context, num_dims == 1 || num_dims == 2 || num_dims == 4);
+ TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);
if (input->type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -453,6 +453,19 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f);
}
+// Takes a 3D tensor and performs softmax along the last dimension.
+void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int intermediate_size = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ optimized_ops::Softmax(
+ GetTensorData<float>(input),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ params->beta, GetTensorData<float>(output),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+}
+
void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
// TODO(ahentz): this is arguably a dirty trick. Since the implementation
@@ -480,6 +493,19 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
GetTensorShape({batch_size, 1, 1, input_size}));
}
+void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ const int batch_size = input->dims->data[0];
+ const int intermediate_size = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ optimized_ops::Softmax(
+ GetTensorData<uint8_t>(input),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ data->input_multiplier, data->input_left_shift, data->diff_min,
+ GetTensorData<uint8_t>(output),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+}
+
 // Takes a 4D tensor and performs softmax along the fourth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
@@ -515,6 +541,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
Softmax2DFloat(input, output, params);
return kTfLiteOk;
}
+ if (NumDimensions(input) == 3) {
+ Softmax3DFloat(input, output, params);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 4) {
Softmax4DFloat(input, output, params);
return kTfLiteOk;
@@ -533,6 +563,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
Softmax2DQuantized(input, output, params, data);
return kTfLiteOk;
}
+ if (NumDimensions(input) == 3) {
+ Softmax3DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 4) {
Softmax4DQuantized(input, output, params, data);
return kTfLiteOk;
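The new 3D path applies softmax along the last dimension (scaled by `beta`) by reshaping to a 4D shape; a small reference sketch of the expected numbers, mirroring the test added below:

```python
# Reference sketch of softmax over the last axis with beta scaling; the input
# and beta mirror the Softmax3D test added below.
import numpy as np

x = np.array([[[0., -6., 2., 4.],
               [3., -2., 10., 1.]]])  # shape (1, 2, 4)
beta = 0.1
z = beta * x
e = np.exp(z - z.max(axis=-1, keepdims=True))
softmax = e / e.sum(axis=-1, keepdims=True)
# softmax[0, 0] ~= [0.2346, 0.1288, 0.2866, 0.3500]
```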
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index e577e3a762..9fa47e190a 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -339,6 +339,76 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
kQuantizedTolerance)));
}
+TEST(FloatActivationsOpTest, Softmax3D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_FLOAT32, {4, 1, 2}});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax3D) {
+ QuantizedActivationsOpModel m(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10});
+ m.SetInput<uint8_t>({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ },
+ kQuantizedTolerance)));
+
+ // Same input, but a different shape.
+ QuantizedActivationsOpModel m2(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10});
+ m2.SetInput<uint8_t>({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
TEST(FloatActivationsOpTest, Softmax1D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {8}});
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index af47b33922..cde4f55a16 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -108,9 +108,26 @@ constexpr int kBwInputCellStateTensor = 38;
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;
+// Temporary tensors.
+enum TemporaryTensor {
+ // Scratch buffers for input, forget, etc. gates
+ kFwScratchBuffer = 0,
+ kBwScratchBuffer = 1,
+ // Quantized tensors needed for the hybrid kernel.
+ kInputQuantized = 2,
+ kFwActivationStateQuantized = 3,
+ kBwActivationStateQuantized = 4,
+ kFwCellStateQuantized = 5,
+ kBwCellStateQuantized = 6,
+ kScalingFactors = 7,
+ kProductScalingFactors = 8,
+ kRecoveredCellWeights = 9,
+ kNumTemporaryTensors = 10
+};
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -131,7 +148,7 @@ TfLiteStatus CheckLstmTensorDimensions(
int input_gate_bias_tensor, int forget_gate_bias_tensor,
int cell_gate_bias_tensor, int output_gate_bias_tensor,
int projection_weights_tensor, int projection_bias_tensor) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -324,7 +341,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TF_LITE_ENSURE(context, input->dims->size > 1);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -370,11 +388,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, fw_output, fw_output_size));
- // Create a scratch buffer tensor.
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8);
+
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
- node->temporaries->data[0] = *scratch_tensor_index;
- TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, /*index=*/0);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers.
+ }
+ // Create a scratch buffer tensor.
+ node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
+ TfLiteTensor* fw_scratch_buffer =
+ GetTemporary(context, node, kFwScratchBuffer);
fw_scratch_buffer->type = input->type;
fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -435,8 +461,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);
// Create a scratch buffer tensor.
- node->temporaries->data[1] = *(scratch_tensor_index) + 1;
- TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, /*index=*/1);
+ node->temporaries->data[kBwScratchBuffer] =
+ *(scratch_tensor_index) + kBwScratchBuffer;
+ TfLiteTensor* bw_scratch_buffer =
+ GetTemporary(context, node, kBwScratchBuffer);
bw_scratch_buffer->type = input->type;
bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -454,18 +482,441 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
bw_scratch_buffer_size));
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // output_state and cell_state tensors.
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ node->temporaries->data[kFwActivationStateQuantized] =
+ *scratch_tensor_index + kFwActivationStateQuantized;
+ TfLiteTensor* fw_activation_state_quantized =
+ GetTemporary(context, node, kFwActivationStateQuantized);
+ fw_activation_state_quantized->type = kTfLiteUInt8;
+ fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
+ fw_activation_state->dims)) {
+ TfLiteIntArray* fw_activation_state_quantized_size =
+ TfLiteIntArrayCopy(fw_activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, fw_activation_state_quantized,
+ fw_activation_state_quantized_size));
+ }
+ node->temporaries->data[kBwActivationStateQuantized] =
+ *scratch_tensor_index + kBwActivationStateQuantized;
+ TfLiteTensor* bw_activation_state_quantized =
+ GetTemporary(context, node, kBwActivationStateQuantized);
+ bw_activation_state_quantized->type = kTfLiteUInt8;
+ bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
+ bw_activation_state->dims)) {
+ TfLiteIntArray* bw_activation_state_quantized_size =
+ TfLiteIntArrayCopy(bw_activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_activation_state_quantized,
+ bw_activation_state_quantized_size));
+ }
+ node->temporaries->data[kFwCellStateQuantized] =
+ *scratch_tensor_index + kFwCellStateQuantized;
+ TfLiteTensor* fw_cell_state_quantized =
+ GetTemporary(context, node, kFwCellStateQuantized);
+ fw_cell_state_quantized->type = kTfLiteUInt8;
+ fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
+ fw_cell_state->dims)) {
+ TfLiteIntArray* fw_cell_state_quantized_size =
+ TfLiteIntArrayCopy(fw_cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, fw_cell_state_quantized,
+ fw_cell_state_quantized_size));
+ }
+ node->temporaries->data[kBwCellStateQuantized] =
+ *scratch_tensor_index + kBwCellStateQuantized;
+ TfLiteTensor* bw_cell_state_quantized =
+ GetTemporary(context, node, kBwCellStateQuantized);
+ bw_cell_state_quantized->type = kTfLiteUInt8;
+ bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
+ bw_cell_state->dims)) {
+ TfLiteIntArray* bw_cell_state_quantized_size =
+ TfLiteIntArrayCopy(bw_cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, bw_cell_state_quantized,
+ bw_cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+    // factors. The latter is convenience storage that lets us quantize
+ // a vector once (which produces the scaling factors) and multiply it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of the matrix).
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[kProductScalingFactors] =
+ *scratch_tensor_index + kProductScalingFactors;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered cell weights. Since
+    // this is used for diagonal matrices, we only need to store n_cell values.
+ node->temporaries->data[kRecoveredCellWeights] =
+ *scratch_tensor_index + kRecoveredCellWeights;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_fw_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Index the scratch buffer pointers into the global scratch buffer.
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Loop through the sequence.
+ if (forward_sequence) {
+ for (int t = 0; t < max_time; t++) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStep(
+ input_ptr, input_to_input_weights_ptr,
+ input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+ input_to_output_weights->data.f, recurrent_to_input_weights_ptr,
+ recurrent_to_forget_weights->data.f,
+ recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, n_output, activation_state->data.f,
+ cell_state->data.f, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, output_ptr_time);
+ }
+ } else {
+ // Loop through the sequence backwards.
+ for (int t = max_time - 1; t >= 0; t--) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStep(
+ input_ptr, input_to_input_weights_ptr,
+ input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+ input_to_output_weights->data.f, recurrent_to_input_weights_ptr,
+ recurrent_to_forget_weights->data.f,
+ recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, n_output, activation_state->data.f,
+ cell_state->data.f, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, output_ptr_time);
+ }
+ }
+ return kTfLiteOk;
+}
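
The gate-scratch indexing in EvalFloat (and in EvalHybrid below) follows a
fixed layout. A minimal sketch of that partitioning, assuming a flat float
buffer of size (use_cifg ? 3 : 4) * n_cell * n_batch and hypothetical names:

    struct GateScratch {
      float* input_gate;  // nullptr when CIFG couples input and forget gates.
      float* cell;
      float* forget_gate;
      float* output_gate;
    };

    GateScratch PartitionScratch(float* base, int n_cell, int n_batch,
                                 bool use_cifg) {
      const int stride = n_cell * n_batch;
      if (use_cifg) {
        // Three gate buffers: cell, forget, output.
        return {nullptr, base, base + stride, base + 2 * stride};
      }
      // Four gate buffers: input, cell, forget, output.
      return {base, base + stride, base + 2 * stride, base + 3 * stride};
    }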
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* output_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* output_state_ptr = output_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_output_state_ptr =
+ reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+ if (forward_sequence) {
+ // Feed the sequence into the LSTM step-by-step.
+ for (int t = 0; t < max_time; t++) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStep(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr,
+ projection_weights_scale, projection_bias_ptr, params, n_batch,
+ n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, scaling_factors_ptr,
+ prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+ quantized_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr);
+ }
+ } else {
+ // Loop through the sequence backwards.
+ for (int t = max_time - 1; t >= 0; t--) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStep(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr,
+ projection_weights_scale, projection_bias_ptr, params, n_batch,
+ n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, scaling_factors_ptr,
+ prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+ quantized_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr);
+ }
+ }
+
return kTfLiteOk;
}
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
// Tensors for the forward cell.
const TfLiteTensor* fw_input_to_input_weights =
@@ -559,149 +1010,91 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetVariableInput(context, node, kBwInputCellStateTensor);
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- // n_cell and n_output will be the same size when there is no projection.
- const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
- const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
- const bool fw_use_peephole = (fw_cell_to_output_weights != nullptr);
-
- // Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* fw_scratch_buffer =
- &context->tensors[node->temporaries->data[0]];
- float* fw_input_gate_scratch = nullptr;
- float* fw_cell_scratch = nullptr;
- float* fw_forget_gate_scratch = nullptr;
- float* fw_output_gate_scratch = nullptr;
- if (fw_use_cifg) {
- fw_cell_scratch = fw_scratch_buffer->data.f;
- fw_forget_gate_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
- fw_output_gate_scratch =
- fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
- } else {
- fw_input_gate_scratch = fw_scratch_buffer->data.f;
- fw_cell_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
- fw_forget_gate_scratch =
- fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
- fw_output_gate_scratch =
- fw_scratch_buffer->data.f + 3 * n_fw_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* fw_input_to_input_weights_ptr =
- (fw_use_cifg) ? nullptr : fw_input_to_input_weights->data.f;
- const float* fw_recurrent_to_input_weights_ptr =
- (fw_use_cifg) ? nullptr : fw_recurrent_to_input_weights->data.f;
- const float* fw_input_gate_bias_ptr =
- (fw_use_cifg) ? nullptr : fw_input_gate_bias->data.f;
- const float* fw_cell_to_input_weights_ptr =
- (fw_use_peephole && !fw_use_cifg) ? fw_cell_to_input_weights->data.f
- : nullptr;
- const float* fw_cell_to_forget_weights_ptr =
- (fw_use_peephole) ? fw_cell_to_forget_weights->data.f : nullptr;
- const float* fw_cell_to_output_weights_ptr =
- (fw_use_peephole) ? fw_cell_to_output_weights->data.f : nullptr;
- const float* fw_projection_weights_ptr = (fw_projection_weights == nullptr)
- ? nullptr
- : fw_projection_weights->data.f;
- const float* fw_projection_bias_ptr =
- (fw_projection_bias == nullptr) ? nullptr : fw_projection_bias->data.f;
-
- // Loop through the sequence.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, fw_input_to_input_weights_ptr,
- fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f,
- fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr,
- fw_recurrent_to_forget_weights->data.f,
- fw_recurrent_to_cell_weights->data.f,
- fw_recurrent_to_output_weights->data.f, fw_cell_to_input_weights_ptr,
- fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr,
- fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f,
- fw_cell_bias->data.f, fw_output_gate_bias->data.f,
- fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch,
- n_fw_cell, n_input, n_fw_output, fw_activation_state->data.f,
- fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch,
- fw_cell_scratch, fw_output_gate_scratch, output_ptr_time);
- }
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
- const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
- const bool bw_use_peephole = (bw_cell_to_output_weights != nullptr);
-
- // Index the scratch buffers pointers to the global scratch buffer.
+ GetTemporary(context, node, kFwScratchBuffer);
TfLiteTensor* bw_scratch_buffer =
- &context->tensors[node->temporaries->data[1]];
- float* bw_input_gate_scratch = nullptr;
- float* bw_cell_scratch = nullptr;
- float* bw_forget_gate_scratch = nullptr;
- float* bw_output_gate_scratch = nullptr;
- if (bw_use_cifg) {
- bw_cell_scratch = bw_scratch_buffer->data.f;
- bw_forget_gate_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
- bw_output_gate_scratch =
- bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
- } else {
- bw_input_gate_scratch = bw_scratch_buffer->data.f;
- bw_cell_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
- bw_forget_gate_scratch =
- bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
- bw_output_gate_scratch =
- bw_scratch_buffer->data.f + 3 * n_bw_cell * n_batch;
+ GetTemporary(context, node, kBwScratchBuffer);
+
+ switch (fw_input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ TfLiteStatus fw_pass_status = EvalFloat(
+ input, fw_input_to_input_weights, fw_input_to_forget_weights,
+ fw_input_to_cell_weights, fw_input_to_output_weights,
+ fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+ fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+ fw_cell_to_input_weights, fw_cell_to_forget_weights,
+ fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias,
+ fw_cell_bias, fw_output_gate_bias, fw_projection_weights,
+ fw_projection_bias, params, /*forward_sequence=*/true,
+ fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
+ TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+ TfLiteStatus bw_pass_status = EvalFloat(
+ input, bw_input_to_input_weights, bw_input_to_forget_weights,
+ bw_input_to_cell_weights, bw_input_to_output_weights,
+ bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+ bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+ bw_cell_to_input_weights, bw_cell_to_forget_weights,
+ bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias,
+ bw_cell_bias, bw_output_gate_bias, bw_projection_weights,
+ bw_projection_bias, params, /*forward_sequence=*/false,
+ bw_scratch_buffer, bw_activation_state, bw_cell_state, bw_output);
+ TF_LITE_ENSURE_OK(context, bw_pass_status);
+ return kTfLiteOk;
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ TfLiteTensor* fw_activation_state_quantized =
+ GetTemporary(context, node, kFwActivationStateQuantized);
+ TfLiteTensor* bw_activation_state_quantized =
+ GetTemporary(context, node, kBwActivationStateQuantized);
+ TfLiteTensor* fw_cell_state_quantized =
+ GetTemporary(context, node, kFwCellStateQuantized);
+ TfLiteTensor* bw_cell_state_quantized =
+ GetTemporary(context, node, kBwCellStateQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+ TfLiteStatus fw_pass_status = EvalHybrid(
+ input, fw_input_to_input_weights, fw_input_to_forget_weights,
+ fw_input_to_cell_weights, fw_input_to_output_weights,
+ fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+ fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+ fw_cell_to_input_weights, fw_cell_to_forget_weights,
+ fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias,
+ fw_cell_bias, fw_output_gate_bias, fw_projection_weights,
+ fw_projection_bias, params, /*forward_sequence=*/true,
+ fw_scratch_buffer, scaling_factors, prod_scaling_factors,
+ recovered_cell_weights, input_quantized,
+ fw_activation_state_quantized, fw_cell_state_quantized,
+ fw_activation_state, fw_cell_state, fw_output);
+ TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+ TfLiteStatus bw_pass_status = EvalHybrid(
+ input, bw_input_to_input_weights, bw_input_to_forget_weights,
+ bw_input_to_cell_weights, bw_input_to_output_weights,
+ bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+ bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+ bw_cell_to_input_weights, bw_cell_to_forget_weights,
+ bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias,
+ bw_cell_bias, bw_output_gate_bias, bw_projection_weights,
+ bw_projection_bias, params, /*forward_sequence=*/false,
+ bw_scratch_buffer, scaling_factors, prod_scaling_factors,
+ recovered_cell_weights, input_quantized,
+ bw_activation_state_quantized, bw_cell_state_quantized,
+ bw_activation_state, bw_cell_state, bw_output);
+ TF_LITE_ENSURE_OK(context, bw_pass_status);
+ return kTfLiteOk;
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ fw_input_to_output_weights->type);
+ return kTfLiteError;
}
-
- // Check optional tensors, the respective pointers can be null.
- const float* bw_input_to_input_weights_ptr =
- (bw_use_cifg) ? nullptr : bw_input_to_input_weights->data.f;
- const float* bw_recurrent_to_input_weights_ptr =
- (bw_use_cifg) ? nullptr : bw_recurrent_to_input_weights->data.f;
- const float* bw_input_gate_bias_ptr =
- (bw_use_cifg) ? nullptr : bw_input_gate_bias->data.f;
- const float* bw_cell_to_input_weights_ptr =
- (bw_use_peephole && !bw_use_cifg) ? bw_cell_to_input_weights->data.f
- : nullptr;
- const float* bw_cell_to_forget_weights_ptr =
- (bw_use_peephole) ? bw_cell_to_forget_weights->data.f : nullptr;
- const float* bw_cell_to_output_weights_ptr =
- (bw_use_peephole) ? bw_cell_to_output_weights->data.f : nullptr;
- const float* bw_projection_weights_ptr = (bw_projection_weights == nullptr)
- ? nullptr
- : bw_projection_weights->data.f;
- const float* bw_projection_bias_ptr =
- (bw_projection_bias == nullptr) ? nullptr : bw_projection_bias->data.f;
-
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, bw_input_to_input_weights_ptr,
- bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f,
- bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr,
- bw_recurrent_to_forget_weights->data.f,
- bw_recurrent_to_cell_weights->data.f,
- bw_recurrent_to_output_weights->data.f, bw_cell_to_input_weights_ptr,
- bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr,
- bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f,
- bw_cell_bias->data.f, bw_output_gate_bias->data.f,
- bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch,
- n_bw_cell, n_input, n_bw_output, bw_activation_state->data.f,
- bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch,
- bw_cell_scratch, bw_output_gate_scratch, output_ptr_time);
- }
-
- // Backward step.
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index e671624fe7..5ca1b4b76f 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -79,6 +79,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
n_batch, result, result_stride);
}
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -138,6 +143,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
reduction_size);
}
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon) {
+ PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+ normalization_epsilon);
+}
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 70adffda3b..2c8e8f90e3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -43,6 +43,14 @@ namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
+using reference_ops::Broadcast4DSlowGreater;
+using reference_ops::Broadcast4DSlowGreaterEqual;
+using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
+using reference_ops::Broadcast4DSlowGreaterWithScaling;
+using reference_ops::Broadcast4DSlowLess;
+using reference_ops::Broadcast4DSlowLessEqual;
+using reference_ops::Broadcast4DSlowLessEqualWithScaling;
+using reference_ops::Broadcast4DSlowLessWithScaling;
using reference_ops::BroadcastAdd4DSlow;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
@@ -58,8 +66,12 @@ using reference_ops::FakeQuant;
using reference_ops::Gather;
using reference_ops::Greater;
using reference_ops::GreaterEqual;
+using reference_ops::GreaterEqualWithScaling;
+using reference_ops::GreaterWithScaling;
using reference_ops::Less;
using reference_ops::LessEqual;
+using reference_ops::LessEqualWithScaling;
+using reference_ops::LessWithScaling;
using reference_ops::Mean;
using reference_ops::RankOneSelect;
using reference_ops::Relu1;
@@ -67,6 +79,7 @@ using reference_ops::Relu6;
using reference_ops::ReluX;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
+using reference_ops::Split;
using reference_ops::StridedSlice;
using reference_ops::Transpose;
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index 8664ebc4f6..7e53dc2fa2 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -117,6 +117,10 @@ void PortableClipVector(const float* vector, int v_size, float abs_limit,
void NeonClipVector(const float* vector, int v_size, float abs_limit,
float* result);
+// Adds a vector to each batch of the batch vector.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Batch vector initialization with another vector.
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector);
@@ -172,6 +176,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
void NeonReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon);
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
index f882f9910e..544ef16ce1 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
@@ -23,6 +23,32 @@ limitations under the License.
namespace tflite {
+namespace {
+// These constants are used to manipulate the binary representation of doubles.
+// Double-precision binary64 floating point format is:
+// Bit | 63 | 62-52 | 51-0 |
+// | Sign | Exponent | Fraction |
+// To avoid 64-bit integers as much as possible, we break this into high and
+// low 32-bit chunks. High is:
+// Bit | 31 | 30-20 | 19-0 |
+// | Sign | Exponent | High Fraction |
+// Low is:
+// Bit | 31-0 |
+// | Low Fraction |
+// We then access the components through logical bit-wise operations to
+// extract the parts needed, with the positions and masks derived from the
+// layout shown above.
+constexpr uint64_t kSignMask = 0x8000000000000000LL;
+constexpr uint64_t kExponentMask = 0x7ff0000000000000LL;
+constexpr int32_t kExponentShift = 52;
+constexpr int32_t kExponentBias = 1023;
+constexpr uint32_t kExponentIsBadNum = 0x7ff;
+constexpr uint64_t kFractionMask = 0x000fffffffc00000LL;
+constexpr uint32_t kFractionShift = 22;
+constexpr uint32_t kFractionRoundingMask = 0x003fffff;
+constexpr uint32_t kFractionRoundingThreshold = 0x00200000;
+} // namespace
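
As a worked example of this layout: 1.0 is stored as 0x3FF0000000000000
(sign 0, exponent 0x3FF = 1023, i.e. 0 after removing the bias, fraction 0),
and -2.0 is stored as 0xC000000000000000 (sign 1, exponent 0x400 = 1024,
i.e. 1 after removing the bias, fraction 0).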
+
void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
int* shift) {
if (double_multiplier == 0.) {
@@ -30,8 +56,16 @@ void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
*shift = 0;
return;
}
+#ifdef TFLITE_EMULATE_FLOAT
+ // If we're trying to avoid the use of floating-point instructions (for
+ // example on microcontrollers) then use an alternative implementation
+ // that only requires integer and bitwise operations. To enable this, you
+ // need to set the define during the build process for your platform.
+ int64_t q_fixed = IntegerFrExp(double_multiplier, shift);
+#else // TFLITE_EMULATE_FLOAT
const double q = std::frexp(double_multiplier, shift);
auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+#endif // TFLITE_EMULATE_FLOAT
TFLITE_CHECK(q_fixed <= (1ll << 31));
if (q_fixed == (1ll << 31)) {
q_fixed /= 2;
@@ -60,6 +94,163 @@ void QuantizeMultiplierSmallerThanOneExp(double double_multiplier,
*left_shift = shift;
}
+int64_t IntegerFrExp(double input, int* shift) {
+ // Make sure our assumptions about the double layout hold.
+ TFLITE_CHECK_EQ(8, sizeof(double));
+
+ // We want to access the bits of the input double value directly, which is
+ // tricky to do safely, so use a union to handle the casting.
+ union {
+ double double_value;
+ uint64_t double_as_uint;
+ } cast_union;
+ cast_union.double_value = input;
+ const uint64_t u = cast_union.double_as_uint;
+
+ // If the bitfield is all zeros apart from the sign bit, this is a normalized
+ // zero value, so return standard values for this special case.
+ if ((u & ~kSignMask) == 0) {
+ *shift = 0;
+ return 0;
+ }
+
+ // Deal with NaNs and Infs, which are always indicated with a fixed pattern in
+  // the exponent, and distinguished by whether the fraction is zero or
+  // non-zero.
+ const uint32_t exponent_part = ((u & kExponentMask) >> kExponentShift);
+ if (exponent_part == kExponentIsBadNum) {
+ *shift = std::numeric_limits<int>::max();
+ if (u & kFractionMask) {
+ // NaN, so just return zero (with the exponent set to INT_MAX).
+ return 0;
+ } else {
+ // Infinity, so return +/- INT_MAX.
+ if (u & kSignMask) {
+ return std::numeric_limits<int64_t>::min();
+ } else {
+ return std::numeric_limits<int64_t>::max();
+ }
+ }
+ }
+
+ // The shift is fairly easy to extract from the high bits of the double value,
+ // just by masking it out and applying a bias. The std::frexp() implementation
+ // always returns values between 0.5 and 1.0 though, whereas the exponent
+  // assumes 1.0 to 2.0 is the standard range, so we add one to match that
+  // interface.
+ *shift = (exponent_part - kExponentBias) + 1;
+
+ // There's an implicit high bit in the double format definition, so make sure
+ // we include that at the top, and then reconstruct the rest of the fractional
+ // value from the remaining fragments.
+ int64_t fraction = 0x40000000 + ((u & kFractionMask) >> kFractionShift);
+
+ // We're cutting off some bits at the bottom, so to exactly match the standard
+ // frexp implementation here we'll apply rounding by adding one to the least
+ // significant bit of the result if the discarded portion is over half of the
+ // maximum.
+ if ((u & kFractionRoundingMask) > kFractionRoundingThreshold) {
+ fraction += 1;
+ }
+ // Negate the fraction if the sign bit was set.
+ if (u & kSignMask) {
+ fraction *= -1;
+ }
+
+ return fraction;
+}
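
A minimal standalone check of the contract above, assuming only <cmath> and
<cstdio>: IntegerFrExp(x) should be close to round(frac * 2^31) computed from
std::frexp, with the same exponent (compare the 123.45 case in the tests).

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const double x = 123.45;
      int exponent;
      const double frac = std::frexp(x, &exponent);  // ~0.964453, exponent 7
      const int64_t integer_frac =
          static_cast<int64_t>(std::round(frac * (1ll << 31)));  // ~2071147315
      std::printf("%lld %d\n", static_cast<long long>(integer_frac), exponent);
      return 0;
    }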
+
+double DoubleFromFractionAndShift(int64_t fraction, int shift) {
+ union {
+ double double_value;
+ uint64_t double_as_uint;
+ } result;
+
+ // Detect NaNs and infinities.
+ if (shift == std::numeric_limits<int>::max()) {
+ if (fraction == 0) {
+ return NAN;
+ } else if (fraction > 0) {
+ return INFINITY;
+ } else {
+ return -INFINITY;
+ }
+ }
+
+ // Return a normalized zero for a zero fraction.
+ if (fraction == 0) {
+ result.double_as_uint = 0;
+ return result.double_value;
+ }
+
+ bool is_negative = (fraction < 0);
+ int64_t encoded_fraction = is_negative ? -fraction : fraction;
+ int64_t encoded_shift = (shift - 1);
+ while (encoded_fraction < 0x40000000) {
+ encoded_fraction *= 2;
+ encoded_shift -= 1;
+ }
+ while (encoded_fraction > 0x80000000) {
+ encoded_fraction /= 2;
+ encoded_shift += 1;
+ }
+ encoded_fraction -= 0x40000000;
+ if (encoded_shift < -1022) {
+ encoded_shift = -1023;
+ } else if (encoded_shift > 1022) {
+ encoded_shift = 1023;
+ }
+ encoded_shift += kExponentBias;
+ uint64_t encoded_sign = is_negative ? kSignMask : 0;
+ result.double_as_uint = encoded_sign | (encoded_shift << kExponentShift) |
+ (encoded_fraction << kFractionShift);
+ return result.double_value;
+}
+
+double IntegerDoubleMultiply(double a, double b) {
+ int a_shift;
+ const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+ int b_shift;
+ const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+ // Detect NaNs and infinities.
+ if (a_shift == std::numeric_limits<int>::max() ||
+ (b_shift == std::numeric_limits<int>::max())) {
+ return NAN;
+ }
+ const int result_shift = a_shift + b_shift + 1;
+ const int64_t result_fraction = (a_fraction * b_fraction) >> 32;
+ return DoubleFromFractionAndShift(result_fraction, result_shift);
+}
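
Why this works: with value = (fraction / 2^31) * 2^shift, the raw product
a_fraction * b_fraction carries a 2^62 scaling, so shifting it right by 32
leaves a fraction on a 2^30 scale; adding one to the combined shift
(a_shift + b_shift + 1) restores the 2^31 convention. The result therefore
keeps roughly 30 significant bits of precision.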
+
+int IntegerDoubleCompare(double a, double b) {
+ int a_shift;
+ const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+ int b_shift;
+ const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+
+ // Detect NaNs and infinities.
+ if (a_shift == std::numeric_limits<int>::max() ||
+ (b_shift == std::numeric_limits<int>::max())) {
+ return 1;
+ }
+
+ if ((a_fraction == 0) && (b_fraction < 0)) {
+ return 1;
+ } else if ((a_fraction < 0) && (b_fraction == 0)) {
+ return -1;
+ } else if (a_shift < b_shift) {
+ return -1;
+ } else if (a_shift > b_shift) {
+ return 1;
+ } else if (a_fraction < b_fraction) {
+ return -1;
+ } else if (a_fraction > b_fraction) {
+ return 1;
+ } else {
+ return 0;
+ }
+}
+
void PreprocessSoftmaxScaling(double beta, double input_scale,
int input_integer_bits,
int32_t* quantized_multiplier, int* left_shift) {
@@ -72,8 +263,20 @@ void PreprocessSoftmaxScaling(double beta, double input_scale,
// result is double equivalent of Q0.31 (actually with more precision). Thus
// this generates a Q(input_integer_bits).(31-input_integer_bits)
// representation.
+#ifdef TFLITE_EMULATE_FLOAT
+ const double input_beta = IntegerDoubleMultiply(beta, input_scale);
+ int shift;
+ int64_t fraction = IntegerFrExp(input_beta, &shift);
+ shift += (31 - input_integer_bits);
+ double input_beta_real_multiplier =
+ DoubleFromFractionAndShift(fraction, shift);
+ if (IntegerDoubleCompare(input_beta_real_multiplier, (1ll << 31) - 1.0) > 0) {
+ input_beta_real_multiplier = (1ll << 31) - 1.0;
+ }
+#else // TFLITE_EMULATE_FLOAT
const double input_beta_real_multiplier = std::min(
beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
+#endif // TFLITE_EMULATE_FLOAT
QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
quantized_multiplier, left_shift);
@@ -97,6 +300,12 @@ void PreprocessLogSoftmaxScalingExp(double beta, double input_scale,
}
int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
+#ifdef TFLITE_EMULATE_FLOAT
+ int64_t result = (1 << input_integer_bits) - 1;
+ result <<= (31 - input_integer_bits);
+ result >>= input_left_shift;
+ return result;
+#else // TFLITE_EMULATE_FLOAT
const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
(1ll << (31 - input_integer_bits)) /
(1ll << input_left_shift);
@@ -104,6 +313,7 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
// After scaling the difference, the result would be at the maximum. Thus we
// must ensure that our value has lower magnitude.
return static_cast<int>(std::floor(max_input_rescaled));
+#endif // TFLITE_EMULATE_FLOAT
}
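
For example, with input_integer_bits = 4 and input_left_shift = 5, both paths
give ((1 << 4) - 1) << (31 - 4) >> 5 = 15 * 2^22 = 62914560.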
void NudgeQuantizationRange(const float min, const float max,
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
index 9ee4a47fbb..d74a1bac97 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -195,6 +195,44 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier,
void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
int* shift);
+// Splits a double input value into a returned fraction, and a shift value from
+// the exponent, using only bitwise and integer operations to support
+// microcontrollers and other environments without floating-point support.
+//
+// This is designed to be a replacement for how std::frexp() is used within the
+// QuantizeMultiplier() function, and so has a different signature than the
+// standard version, returning a 64-bit integer rather than a double. This
+// result has a maximum value of 1<<31, with the fraction expressed as a
+// proportion of that maximum.
+//
+// std::frexp() returns NaNs and infinities unmodified, but since we're
+// returning integers that can't represent those values, instead we return
+// a shift of std::numeric_limits<int>::max() for all bad numbers, with an int64
+// result of 0 for NaNs, std::numeric_limits<int64_t>::max() for +INFINITY, and
+// std::numeric_limits<int64_t>::min() for -INFINITY. Denormalized inputs will
+// result in return values that end up truncating some bits at the end,
+// reflecting the loss of precision inherent in denormalization.
+int64_t IntegerFrExp(double input, int* shift);
+
+// Converts an integer fraction in the format produced by IntegerFrExp (where
+// 0x40000000 is 1.0) and an exponent shift (between -1022 and +1022) into an
+// IEEE binary64 double format result. The implementation uses only integer and
+// bitwise operators, so no floating point hardware support or emulation is
+// needed. This is here so quantized operations can run non-time-critical
+// preparation calculations on microcontrollers and other platforms without
+// float support.
+double DoubleFromFractionAndShift(int64_t fraction, int shift);
+
+// Performs a multiplication of two numbers in double format, using only integer
+// and bitwise instructions. This is aimed at supporting housekeeping functions
+// for quantized operations on microcontrollers without floating-point hardware.
+double IntegerDoubleMultiply(double a, double b);
+
+// Returns -1 if a is less than b, 0 if a and b are equal, and +1 if a is
+// greater than b. It is implemented using only integer and logical instructions
+// so that it can be easily run on microcontrollers for quantized operations.
+int IntegerDoubleCompare(double a, double b);
+
// This first creates a multiplier in a double equivalent of
// Q(input_integer_bits).(31-input_integer_bits) representation, with extra
// precision in the double's fractional bits. It then splits the result into
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 00fc3e91dc..14281f25c6 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -191,6 +191,139 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
EXPECT_EQ(qp.zero_point, 255);
}
+TEST(QuantizationUtilTest, IntegerFrExp) {
+ int shift;
+ int64_t result = IntegerFrExp(0.0, &shift);
+ EXPECT_EQ(0, result);
+ EXPECT_EQ(0, shift);
+
+ result = IntegerFrExp(1.0, &shift);
+ EXPECT_NEAR(0x40000000, result, 1);
+ EXPECT_EQ(1, shift);
+
+ result = IntegerFrExp(0.25, &shift);
+ EXPECT_NEAR(0x40000000, result, 1);
+ EXPECT_EQ(-1, shift);
+
+ result = IntegerFrExp(-1.0, &shift);
+ EXPECT_NEAR(-(1 << 30), result, 1);
+ EXPECT_EQ(1, shift);
+
+ result = IntegerFrExp(123.45, &shift);
+ EXPECT_NEAR(2071147315, result, 1);
+ EXPECT_EQ(7, shift);
+
+ result = IntegerFrExp(NAN, &shift);
+ EXPECT_NEAR(0, result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+
+ result = IntegerFrExp(INFINITY, &shift);
+ EXPECT_NEAR(std::numeric_limits<int64_t>::max(), result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+
+ result = IntegerFrExp(-INFINITY, &shift);
+ EXPECT_NEAR(std::numeric_limits<int64_t>::min(), result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+}
+
+TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) {
+ int shift;
+ int32_t result = IntegerFrExp(0.0, &shift);
+ EXPECT_EQ(result, 0);
+ EXPECT_EQ(shift, 0);
+
+ int double_shift;
+ double double_result = std::frexp(0.0, &double_shift);
+ EXPECT_EQ(double_result, 0);
+ EXPECT_EQ(double_shift, 0);
+
+ result = IntegerFrExp(1.0, &shift);
+ EXPECT_NEAR(result, 0x40000000, 1);
+ EXPECT_EQ(shift, 1);
+ double_result = std::frexp(1.0, &double_shift);
+ EXPECT_NEAR(double_result, 0.5, 1e-5);
+ EXPECT_EQ(double_shift, 1);
+
+ result = IntegerFrExp(0.25, &shift);
+ EXPECT_NEAR(result, 0x40000000, 1);
+ EXPECT_EQ(shift, -1);
+ double_result = std::frexp(0.25, &double_shift);
+ EXPECT_NEAR(double_result, 0.5, 1e-5);
+ EXPECT_EQ(double_shift, -1);
+
+ result = IntegerFrExp(-1.0, &shift);
+ EXPECT_NEAR(result, -(1 << 30), 1);
+ EXPECT_EQ(shift, 1);
+ double_result = std::frexp(-1.0, &double_shift);
+ EXPECT_NEAR(double_result, -0.5, 1e-5);
+ EXPECT_EQ(double_shift, 1);
+
+ result = IntegerFrExp(123.45, &shift);
+ EXPECT_NEAR(result, (0.964453 * (1L << 31)), 1000);
+ EXPECT_EQ(shift, 7);
+ double_result = std::frexp(123.45, &double_shift);
+ EXPECT_NEAR(double_result, 0.964453, 1e-5);
+ EXPECT_EQ(double_shift, 7);
+}
+
+TEST(QuantizationUtilTest, DoubleFromFractionAndShift) {
+ double result = DoubleFromFractionAndShift(0, 0);
+ EXPECT_EQ(0, result);
+
+ result = DoubleFromFractionAndShift(0x40000000, 1);
+ EXPECT_NEAR(1.0, result, 1e-5);
+
+ result = DoubleFromFractionAndShift(0x40000000, 2);
+ EXPECT_NEAR(2.0, result, 1e-5);
+
+ int shift;
+ int64_t fraction = IntegerFrExp(3.0, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(3.0, result, 1e-5);
+
+ fraction = IntegerFrExp(123.45, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(123.45, result, 1e-5);
+
+ fraction = IntegerFrExp(-23.232323, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(-23.232323, result, 1e-5);
+
+ fraction = IntegerFrExp(NAN, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_TRUE(std::isnan(result));
+
+ fraction = IntegerFrExp(INFINITY, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_FALSE(std::isfinite(result));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleMultiply) {
+ EXPECT_NEAR(1.0, IntegerDoubleMultiply(1.0, 1.0), 1e-5);
+ EXPECT_NEAR(2.0, IntegerDoubleMultiply(1.0, 2.0), 1e-5);
+ EXPECT_NEAR(2.0, IntegerDoubleMultiply(2.0, 1.0), 1e-5);
+ EXPECT_NEAR(4.0, IntegerDoubleMultiply(2.0, 2.0), 1e-5);
+ EXPECT_NEAR(0.5, IntegerDoubleMultiply(1.0, 0.5), 1e-5);
+ EXPECT_NEAR(0.25, IntegerDoubleMultiply(0.5, 0.5), 1e-5);
+ EXPECT_NEAR(-1.0, IntegerDoubleMultiply(1.0, -1.0), 1e-5);
+ EXPECT_NEAR(-1.0, IntegerDoubleMultiply(-1.0, 1.0), 1e-5);
+ EXPECT_NEAR(1.0, IntegerDoubleMultiply(-1.0, -1.0), 1e-5);
+ EXPECT_NEAR(15000000.0, IntegerDoubleMultiply(3000.0, 5000.0), 1e-5);
+ EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(NAN, 5000.0)));
+ EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(3000.0, NAN)));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleCompare) {
+ EXPECT_EQ(-1, IntegerDoubleCompare(0.0, 1.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(1.0, 0.0));
+ EXPECT_EQ(0, IntegerDoubleCompare(1.0, 1.0));
+ EXPECT_EQ(0, IntegerDoubleCompare(0.0, 0.0));
+ EXPECT_EQ(-1, IntegerDoubleCompare(-10.0, 10.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(123.45, 10.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(NAN, INFINITY));
+ EXPECT_EQ(1, IntegerDoubleCompare(INFINITY, NAN));
+}
+
#ifdef GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index e79e75a898..2a30910c3f 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -173,6 +173,16 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
}
}
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int i = 0; i < v_size; ++i) {
+ batch_vector[i] += vector[i];
+ }
+ batch_vector += v_size;
+ }
+}
+
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector) {
for (int b = 0; b < n_batch; b++) {
@@ -243,5 +253,31 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
}
}
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ float sum = 0.0f;
+ float sum_sq = 0.0f;
+ for (int i = 0; i < v_size; ++i) {
+ sum += input_vector[i];
+ sum_sq += input_vector[i] * input_vector[i];
+ }
+ const float mean = sum / v_size;
+ float stddev_inv = 0.0f;
+ const float variance = sum_sq / v_size - mean * mean;
+ if (variance == 0) {
+ stddev_inv = 1.0f / sqrt(normalization_epsilon);
+ } else {
+ stddev_inv = 1.0f / sqrt(variance);
+ }
+ for (int i = 0; i < v_size; ++i) {
+ output_vector[i] = (input_vector[i] - mean) * stddev_inv;
+ }
+ input_vector += v_size;
+ output_vector += v_size;
+ }
+}
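
As a worked example for a single batch {1, 2, 3, 4}: mean = 2.5, variance =
(1 + 4 + 9 + 16) / 4 - 2.5^2 = 1.25, so stddev_inv ≈ 0.894 and the normalized
outputs are roughly {-1.342, -0.447, 0.447, 1.342}.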
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index 3829be0c5e..f5b3a84f07 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -87,6 +87,10 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector);
+// Adds a vector to each batch of the batch vector.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Apply sigmoid to elements of a vector.
void PortableApplySigmoidToVector(const float* vector, int v_size,
float* result);
@@ -125,6 +129,12 @@ void PortableVectorShiftLeft(float* vector, int v_size, float shift_value);
void PortableReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+// Layer norm for each batch.
+// normalization_epsilon is used in place of a zero variance to avoid
+// division by zero.
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon);
+
float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
bool IsZeroVector(const float* vector, int v_size) {
@@ -193,6 +203,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
result, result_stride);
}
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -240,6 +255,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
reduction_size);
}
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon) {
+ PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+ normalization_epsilon);
+}
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 62f7ade7d5..00f9616cc2 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -2524,32 +2524,69 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
}
template <typename Scalar>
+void Split(const SplitParams& params, const RuntimeShape& input_shape,
+ const Scalar* input_data, const RuntimeShape* const* output_shapes,
+ Scalar* const* output_data) {
+ const int concat_dimensions = input_shape.DimensionsCount();
+ int axis = params.axis < 0 ? params.axis + concat_dimensions : params.axis;
+ int outputs_count = params.num_split;
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
+ int64_t concat_size = 0;
+ for (int i = 0; i < outputs_count; i++) {
+ TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*output_shapes[i], j, input_shape, j);
+ }
+ }
+ concat_size += output_shapes[i]->Dims(axis);
+ }
+ TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis));
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= input_shape.Dims(i);
+ }
+ // For all output arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= input_shape.Dims(i);
+ }
+
+ const Scalar* input_ptr = input_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
+ memcpy(output_data[i] + k * copy_size, input_ptr,
+ copy_size * sizeof(Scalar));
+ input_ptr += copy_size;
+ }
+ }
+}
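
A minimal standalone sketch of the same copy pattern: splitting a 2x4
row-major array along its last axis into two 2x2 outputs. This is a
hypothetical example, not the TF Lite entry point.

    #include <cstring>

    void SplitLastAxis2x4(const float* input, float* out0, float* out1) {
      const int outer_size = 2;  // product of dims before the split axis
      const int copy_size = 2;   // per-output size along the axis * inner size
      const float* input_ptr = input;
      for (int k = 0; k < outer_size; ++k) {
        std::memcpy(out0 + k * copy_size, input_ptr, copy_size * sizeof(float));
        input_ptr += copy_size;
        std::memcpy(out1 + k * copy_size, input_ptr, copy_size * sizeof(float));
        input_ptr += copy_size;
      }
    }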
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename Scalar>
void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
int axis, int outputs_count, Scalar* const* output_data,
const Dims<4>* const* output_dims) {
- const int batches = ArraySize(*output_dims[0], 3);
- const int height = ArraySize(*output_dims[0], 2);
- const int width = ArraySize(*output_dims[0], 1);
- const int depth = ArraySize(*output_dims[0], 0);
-
- const int slice_size = ArraySize(*output_dims[0], axis);
-
+ std::vector<RuntimeShape> output_shapes(outputs_count);
+ std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
for (int i = 0; i < outputs_count; ++i) {
- int offset = i * slice_size * input_dims.strides[axis];
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- auto out = Offset(*output_dims[i], c, x, y, b);
- auto in = Offset(input_dims, c, x, y, b);
- output_data[i][out] = input_data[offset + in];
- }
- }
- }
- }
+ ShapeFromDims(*output_dims[i], &output_shapes[i]);
+ output_shapes_indirect[i] = &output_shapes[i];
}
+ tflite::SplitParams op_params;
+ op_params.axis = 3 - axis;
+ op_params.num_split = outputs_count;
+
+ Split(op_params, DimsToShape(input_dims), input_data,
+ output_shapes_indirect.data(), output_data);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
template <FusedActivationFunctionType Ac, typename Scalar>
void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
int outputs_count, Scalar* const* output_data,
@@ -2560,9 +2597,8 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
/* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
/* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
}
- // for now we dont have a model with a TensorFlowSplit
- // with fused activation function.
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ // For now we don't have a model with a Split with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
output_data, output_dims);
@@ -3416,23 +3452,55 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data,
}
template <typename T>
-inline void Gather(const T* input_data, const Dims<4>& input_dims,
- int input_rank, const int32* coords_data,
- const Dims<4>& coords_dims, T* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
- int stride = input_dims.strides[input_rank - 1];
+inline void Gather(const tflite::GatherParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& coords_shape, const int32* coords_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // TODO(b/80418076): Enable these checks when moving legacy ops to
+ // legacy_reference_ops.
+ //
+ // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1);
+ const int input_rank = op_params.input_rank;
+ const int gather_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LE(input_shape.DimensionsCount(), gather_dimensions);
+ const int axis = gather_dimensions - input_rank;
+ TFLITE_DCHECK_LT(axis, gather_dimensions);
+ TFLITE_DCHECK_GE(axis, 0);
+ const int coords_count = coords_shape.FlatSize();
+ TFLITE_DCHECK_EQ(coords_count, output_shape.Dims(axis));
+
+ int64_t stride = 1;
+ for (int i = axis + 1; i < gather_dimensions; ++i) {
+ stride *= input_shape.Dims(i);
+ }
T* out = output_data;
- for (int i = 0; i < coords_dims.sizes[0]; i++) {
+ for (int i = 0; i < coords_count; ++i) {
TFLITE_DCHECK_GE(coords_data[i], 0);
- TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
+ TFLITE_DCHECK_LT(coords_data[i], input_shape.Dims(axis));
const T* in = input_data + coords_data[i] * stride;
memcpy(out, in, sizeof(T) * stride);
out += stride;
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4> version.
+// When moving legacy ops to legacy_reference_ops, replace content with looser
+// implementation.
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::GatherParams op_params;
+ op_params.input_rank = input_rank;
+
+ Gather(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
+ output_data);
+}
+
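For illustration only (not part of the patch): a minimal call to the new GatherParams/RuntimeShape overload of Gather, gathering two rows of a 4x3 matrix. The rank-4 shapes mirror what the legacy Dims<4> wrapper above produces; the reference_ops namespace and the RuntimeShape initializer-list constructor are assumed.

    const float input[4 * 3] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
    const int32_t coords[2] = {2, 0};  // gather rows 2 and 0
    float output[2 * 3];
    tflite::GatherParams op_params;
    op_params.input_rank = 2;  // treat the input as a rank-2 array
    tflite::reference_ops::Gather(
        op_params, tflite::RuntimeShape({1, 1, 4, 3}), input,
        tflite::RuntimeShape({1, 1, 1, 2}), coords,
        tflite::RuntimeShape({1, 1, 2, 3}), output);
    // output now holds {6, 7, 8, 0, 1, 2}.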
template <typename T>
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
@@ -4301,9 +4369,10 @@ template <typename T>
using ComparisonFn = bool (*)(T, T);
template <typename T, ComparisonFn<T> F>
-inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data,
- const RuntimeShape& input2_shape, const T* input2_data,
- const RuntimeShape& output_shape, bool* output_data) {
+inline void ComparisonImpl(
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
const int64_t flatsize =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
@@ -4311,25 +4380,45 @@ inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data,
}
}
+template <ComparisonFn<float> F>
+inline void Comparison(const ComparisonParams& op_params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape, bool* output_data) {
+ ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
+ input2_data, output_shape, output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <typename T, ComparisonFn<T> F>
inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
bool* output_data, const Dims<4>& output_dims) {
- Comparison<T, F>(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
+ ComparisonParams op_params;
+ // No parameters needed.
+ ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
}
template <typename T, ComparisonFn<int32> F>
-inline void Comparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const T* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, bool* output_data,
- const Dims<4>& output_dims) {
+inline void ComparisonWithScaling(
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
+ int left_shift = op_params.left_shift;
+ int32 input1_offset = op_params.input1_offset;
+ int32 input1_multiplier = op_params.input1_multiplier;
+ int input1_shift = op_params.input1_shift;
+ int32 input2_offset = op_params.input2_offset;
+ int32 input2_multiplier = op_params.input2_multiplier;
+ int input2_shift = op_params.input2_shift;
+
const int64_t flatsize =
- MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
const int32 input1_val = input1_offset + input1_data[i];
const int32 input2_val = input2_offset + input2_data[i];
@@ -4337,68 +4426,140 @@ inline void Comparison(int left_shift, const T* input1_data,
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
+ shifted_input2_val, input2_multiplier, input2_shift);
output_data[i] = F(scaled_input1_val, scaled_input2_val);
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<int32> F>
+inline void Comparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, bool* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ComparisonParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T, ComparisonFn<T> F>
-inline void BroadcastComparison(const T* input1_data,
- const Dims<4>& input1_dims,
- const T* input2_data,
- const Dims<4>& input2_dims, bool* output_data,
- const Dims<4>& output_dims) {
+inline void BroadcastComparison4DSlowImpl(
+ const ComparisonParams& op_params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- F(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
+ F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
}
}
}
}
}
+template <ComparisonFn<float> F>
+inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ bool* output_data) {
+ BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
+ input2_shape, input2_data,
+ output_shape, output_data);
+}
-template <typename T, ComparisonFn<int32> F>
-inline void BroadcastComparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(const T* input1_data,
+ const Dims<4>& input1_dims,
const T* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 input2_multiplier, int input2_shift,
- bool* output_data, const Dims<4>& output_dims) {
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
+ input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison4DSlowWithScaling(
+ const ComparisonParams& op_params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ int left_shift = op_params.left_shift;
+ int32 input1_offset = op_params.input1_offset;
+ int32 input1_multiplier = op_params.input1_multiplier;
+ int input1_shift = op_params.input1_shift;
+ int32 input2_offset = op_params.input2_offset;
+ int32 input2_multiplier = op_params.input2_multiplier;
+ int input2_shift = op_params.input2_shift;
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- output_data[Offset(output_dims, c, x, y, b)] =
+ shifted_input2_val, input2_multiplier, input2_shift);
+ output_data[Offset(output_shape, b, y, x, c)] =
F(scaled_input1_val, scaled_input2_val);
}
}
@@ -4406,51 +4567,117 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
}
}
-#define TFLITE_COMPARISON_OP(name) \
- template <typename T> \
- inline void name(const T* input1_data, const Dims<4>& input1_dims, \
- const T* input2_data, const Dims<4>& input2_dims, \
- bool* output_data, const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name); \
- Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
- Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, input1_shift, \
- input2_data, input2_dims, input2_offset, \
- input2_multiplier, input2_shift, output_data, \
- output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
- const Dims<4>& input2_dims, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
- BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
- BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, \
- input1_shift, input2_data, input2_dims, \
- input2_offset, input2_multiplier, \
- input2_shift, output_data, output_dims); \
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 input2_multiplier, int input2_shift,
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ BroadcastComparison4DSlowWithScaling<T, F>(
+ op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+#define TFLITE_COMPARISON_OP(name) \
+ template <typename T> \
+ inline void name(const T* input1_data, const Dims<4>& input1_dims, \
+ const T* input2_data, const Dims<4>& input2_dims, \
+ bool* output_data, const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, input1_shift, \
+ input2_data, input2_dims, input2_offset, \
+ input2_multiplier, input2_shift, output_data, \
+ output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
+ const Dims<4>& input2_dims, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ } \
+ inline void name(const ComparisonParams& op_params, \
+ const RuntimeShape& input1_shape, const float* input1_data, \
+ const RuntimeShape& input2_shape, const float* input2_data, \
+ const RuntimeShape& output_shape, bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
+ input2_data, output_shape, output_data); \
+ } \
+ template <typename T> \
+ inline void name##WithScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
+ inline void Broadcast4DSlow##name( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const float* input1_data, const RuntimeShape& input2_shape, \
+ const float* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
+ template <typename T> \
+ inline void Broadcast4DSlow##name##WithScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
+ op_params, input1_shape, input1_data, input2_shape, input2_data, \
+ output_shape, output_data); \
}
TFLITE_COMPARISON_OP(Equal);
TFLITE_COMPARISON_OP(NotEqual);
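For illustration only: alongside the legacy Dims<4> overloads, TFLITE_COMPARISON_OP now also generates a params-style float entry point. A hedged usage sketch (reference_ops namespace assumed; no ComparisonParams fields are needed on the float path):

    const float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    const float b[4] = {1.0f, 0.0f, 3.0f, 5.0f};
    bool out[4];
    tflite::ComparisonParams op_params;  // unused by the float path
    tflite::reference_ops::Equal(op_params, tflite::RuntimeShape({4}), a,
                                 tflite::RuntimeShape({4}), b,
                                 tflite::RuntimeShape({4}), out);
    // out is {true, false, true, false}.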
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 748356d1bd..1439bf8c37 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -113,6 +113,10 @@ void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result);
+// Add the vector to each batch of the batch vector, in place.
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Batch vector initialization with another vector.
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector);
@@ -152,6 +156,12 @@ void VectorShiftLeft(float* vector, int v_size, float shift_value);
// added to get one element of output.
void ReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+
+// Layer norm for each batch.
+// normalization_epsilon is added to avoid divergence.
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon);
} // namespace tensor_utils
} // namespace tflite
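For illustration only: a portable sketch of the two helpers declared above, consistent with their signatures and with the expectations in the tests below. The real implementations live in the portable/NEON tensor_utils sources and are not part of this hunk.

    #include <cmath>

    namespace tflite {
    namespace tensor_utils {

    // Sketch: add `vector` to every batch of `batch_vector`, in place.
    void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
                              float* batch_vector) {
      for (int b = 0; b < n_batch; ++b) {
        for (int i = 0; i < v_size; ++i) {
          batch_vector[b * v_size + i] += vector[i];
        }
      }
    }

    // Sketch: per-batch mean/stddev normalization. normalization_epsilon
    // stands in for the variance of a constant batch, so all-zero input
    // normalizes to all zeros.
    void MeanStddevNormalization(const float* input_vector,
                                 float* output_vector, int v_size, int n_batch,
                                 float normalization_epsilon) {
      for (int batch = 0; batch < n_batch; ++batch) {
        float sum = 0.0f;
        float sum_sq = 0.0f;
        for (int i = 0; i < v_size; ++i) {
          sum += input_vector[i];
          sum_sq += input_vector[i] * input_vector[i];
        }
        const float mean = sum / v_size;
        const float variance = sum_sq / v_size - mean * mean;
        const float stddev_inv =
            1.0f / std::sqrt(variance > 0.0f ? variance
                                             : normalization_epsilon);
        for (int i = 0; i < v_size; ++i) {
          output_vector[i] = (input_vector[i] - mean) * stddev_inv;
        }
        input_vector += v_size;
        output_vector += v_size;
      }
    }

    }  // namespace tensor_utils
    }  // namespace tflite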
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index 240fb64ca3..dad924fc28 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -496,6 +496,16 @@ TEST(uKernels, VectorVectorCwiseProductAccumulateTest) {
{1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45})));
}
+TEST(uKernels, VectorBatchVectorAddTest) {
+ constexpr int kVectorSize = 3;
+ constexpr int kBatchSize = 2;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0};
+ std::vector<float> output = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ VectorBatchVectorAdd(input, kVectorSize, kBatchSize, output.data());
+ EXPECT_THAT(output,
+ testing::ElementsAreArray({1.0, 1.5, 4.0, 4.0, 4.5, 7.0}));
+}
+
TEST(uKernels, VectorBatchVectorAssignTest) {
constexpr int kVectorSize = 5;
constexpr int kBatchSize = 3;
@@ -712,5 +722,85 @@ TEST(uKernels, ReductionSumVectorTest) {
EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5})));
}
+TEST(uKernels, MeanStddevNormalizationNonZeroInput) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+  // Non-zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.1, 0.2, 0.3, 0.4, // batch 0
+ 0.9, 1.0, 1.1, 1.2, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 0
+ -1.34163153, -0.447210163, 0.447211236, 1.3416326, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationAllZeroInput) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.0, 0.0, 0.0, 0.0, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.0, 0.0, 0.0, 0.0, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationMixed) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Mix of zero and non-zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.1, 0.2, 0.3, 0.4, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationSmallValue) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+  // Small input values, including one exact zero.
+ static float input[kVectorSize * kBatchSize] = {
+ 3e-5, -7e-6, -9e-5, 1e-6, // batch 0
+ 4e-5, 9e-6, 2e-4, 0.0, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 1.04231524, 0.212946132, -1.64753067, 0.392269224, // batch 0
+ -0.275023013, -0.658201098, 1.70267045, -0.769446373, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 3b296f024f..9f6e74a267 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -720,12 +720,12 @@ struct ConcatenationParams {
struct ComparisonParams {
// uint8 inference params.
int left_shift;
- int32 input0_offset;
- int32 input0_multiplier;
- int input0_shift;
int32 input1_offset;
int32 input1_multiplier;
int input1_shift;
+ int32 input2_offset;
+ int32 input2_multiplier;
+ int input2_shift;
// Shape dependent / common to inference types.
bool is_broadcast;
};
@@ -889,6 +889,7 @@ struct SplitParams {
// Graphs that split into, say, 2000 nodes are encountered. The indices in
// OperatorEdges are of type uint16.
uint16 num_split;
+ int16 axis;
};
struct SqueezeParams {
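For illustration only: populating the new axis field of SplitParams. The legacy TensorFlowSplit wrapper earlier in this patch converts the channel-major Dims<4> axis to the RuntimeShape (NHWC, row-major) convention via 3 - axis, so a two-way channel split would be configured as:

    tflite::SplitParams op_params;
    op_params.num_split = 2;  // number of output tensors
    op_params.axis = 3;       // NHWC channel axis (legacy Dims<4> axis 0)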
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
new file mode 100644
index 0000000000..1bbea67b93
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
@@ -0,0 +1,1316 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Layer Normalization LSTM op that applies normalization by mean and standard
+// deviation to the activation of the LSTM layers. Please see
+// https://arxiv.org/abs/1607.06450 for details.
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace layer_norm_lstm {
+
+// Struct to hold Layer Norm LSTM option data.
+struct OpData {
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+ int scratch_tensor_index;
+};
+
+// Input Tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size: {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1; // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9; // Optional
+constexpr int kCellToForgetWeightsTensor = 10; // Optional
+constexpr int kCellToOutputWeightsTensor = 11; // Optional
+
+// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kInputLayerNormWeightsTensor = 12;
+constexpr int kForgetLayerNormWeightsTensor = 13;
+constexpr int kCellLayerNormWeightsTensor = 14;
+constexpr int kOutputLayerNormWeightsTensor = 15;
+
+// Gates bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 16; // Optional
+constexpr int kForgetGateBiasTensor = 17;
+constexpr int kCellGateBiasTensor = 18;
+constexpr int kOutputGateBiasTensor = 19;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 20; // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 21; // Optional
+
+// State tensors.
+constexpr int kInputActivationStateTensor = 22;
+constexpr int kInputCellStateTensor = 23;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+// Total number of scratch tensors for hybrid Op.
+constexpr int kTensorsToAdd = 7;
+
+// Small float to avoid divergence when computing the standard deviation.
+const float kLayerNormEpsilon = 1e-8;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+
+ // Turn custom option data into flexbuffer map format.
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+
+ // Get activation function, cell_clip and proj_clip from the flexbuffer.
+ // TODO(b/113824099): make activation more generic.
+ assert(m["fused_activation_function"].ToString() == "TANH");
+ data->activation = kTfLiteActTanh;
+ data->cell_clip = m["cell_clip"].AsFloat();
+ data->proj_clip = m["proj_clip"].AsFloat();
+
+ // Populate scratch_tensor_index.
+ context->AddTensors(context, /*tensors_to_add=*/kTensorsToAdd,
+ &data->scratch_tensor_index);
+ return data;
+}
+
+// Check that the input tensor dimensions match each other.
+TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
+ TfLiteNode* node, int n_input,
+ int n_output, int n_cell) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ // Making sure clipping parameters have valid values.
+ // == 0 means no clipping
+ // > 0 means clipping
+ TF_LITE_ENSURE(context, op_data->cell_clip >= 0);
+ TF_LITE_ENSURE(context, op_data->proj_clip >= 0);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ if (input_to_input_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
+ }
+
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ if (recurrent_to_input_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
+ n_output);
+ }
+
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
+ n_output);
+
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
+ n_output);
+
+  // We make sure the input-gate's parameters are either both present (regular
+  // LSTM) or both absent (CIFG-LSTM).
+ const bool cifg_weights_all_or_none =
+ ((input_to_input_weights != nullptr) &&
+ (recurrent_to_input_weights != nullptr)) ||
+ ((input_to_input_weights == nullptr) &&
+ (recurrent_to_input_weights == nullptr));
+ TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ if (cell_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ if (cell_to_forget_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+ if (cell_to_output_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
+ }
+
+  // Making sure the peephole weights are either all present or all absent.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool peephole_weights_all_or_none =
+ ((cell_to_input_weights != nullptr || use_cifg) &&
+ (cell_to_forget_weights != nullptr) &&
+ (cell_to_output_weights != nullptr)) ||
+ ((cell_to_input_weights == nullptr) &&
+ (cell_to_forget_weights == nullptr) &&
+ (cell_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
+
+ // Making sure layer norm weights are not null and have the right dimension.
+ const TfLiteTensor* input_layer_norm_weights =
+ GetInput(context, node, kInputLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, input_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* forget_layer_norm_weights =
+ GetInput(context, node, kForgetLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, forget_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* cell_layer_norm_weights =
+ GetInput(context, node, kCellLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, cell_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* output_layer_norm_weights =
+ GetInput(context, node, kOutputLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, output_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->data[0], n_cell);
+
+ // Make sure the input gate bias is present only when not a CIFG-LSTM.
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ if (use_cifg) {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
+ } else {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ if (projection_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
+ }
+
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+ if (projection_bias != nullptr) {
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
+ }
+
+ // Making sure the projection tensors are consistent:
+ // 1) If projection weight is not present, then projection bias should not be
+ // present.
+ // 2) If projection weight is present, then projection bias is optional.
+ const bool projection_tensors_consistent =
+ ((projection_weights != nullptr) || (projection_bias == nullptr));
+ TF_LITE_ENSURE(context, projection_tensors_consistent == true);
+
+ return kTfLiteOk;
+}
+
+// Resize the output and state tensors based on the input tensor sizes.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 24);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ // Inferring batch size, number of outputs and number of cells from the
+ // input tensors.
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE(context, input->dims->size > 1);
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+ const int n_cell = input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+ n_cell);
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Check that the input tensor dimensions match each other.
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
+
+ // Get the pointer to output, activation_state and cell_state tensors.
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ const TfLiteTensor* activation_state =
+ GetInput(context, node, kInputActivationStateTensor);
+ const TfLiteTensor* cell_state =
+ GetInput(context, node, kInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+  // These tensors may be 1D or 2D; either is fine as long as the total size
+  // is correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+ // Resize the output tensors.
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
+ output_size->data[0] = n_batch;
+ output_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
+
+ TfLiteIntArrayFree(node->temporaries);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(7);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
+ node->temporaries->data[0] = op_data->scratch_tensor_index;
+
+ // Create a scratch buffer tensor.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ if (use_cifg) {
+ // Reserving space for Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 3;
+ } else {
+ // Reserving space for Input, Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 4;
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // activation_state and cell_state tensors.
+ node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
+ }
+ node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+    // factors. The latter is convenience storage that lets a vector be
+    // quantized once (producing the scaling factors) and then multiplied with
+    // different matrices (which requires multiplying the scaling factors by
+    // the scaling factor of each matrix).
+ node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered weights. Since
+    // this is used for diagonal matrices, only n_cell values are needed.
+ node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
+ TfLiteTensor* recovered_weights = GetTemporary(context, node, /*index=*/6);
+ recovered_weights->type = kTfLiteFloat32;
+ recovered_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_weights_size = TfLiteIntArrayCreate(1);
+ recovered_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_weights->dims, recovered_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_weights,
+ recovered_weights_size));
+ }
+ }
+ return kTfLiteOk;
+}
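For illustration only: the arithmetic behind the product scaling factors scratch tensor allocated in Prepare above. With symmetric quantization each operand is an int8 buffer plus a float scale, so a float matrix-vector product is recovered by scaling the int32 accumulator of quantized products by the product of the two scales (numbers below are hypothetical):

    const float weights_scale = 0.02f;  // scale of the int8 weight matrix
    const float input_scale = 0.05f;    // per-batch scale from SymmetricQuantizeFloats
    const float product_scale = weights_scale * input_scale;  // 0.001f
    // float_result ≈ product_scale * (sum of int8 weight * int8 input products),
    // which is the per-batch factor MatrixBatchVectorMultiplyAccumulate applies.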
+
+void LayerNormLstmStep(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr,
+ const float* input_layer_norm_weight_ptr,
+ const float* forget_layer_norm_weight_ptr,
+ const float* cell_layer_norm_weight_ptr,
+ const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, float cell_clip, float proj_clip,
+ const TfLiteFusedActivation& activation, int n_batch, int n_cell,
+ int n_input, int n_output, float* output_state_ptr, float* cell_state_ptr,
+ float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* output_ptr_batch) {
+  // Since we have already checked that the weights are either all present or
+  // all absent, checking the existence of just one is sufficient.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+
+ // Initialize scratch buffers with 0.
+ if (!use_cifg) {
+ tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+ }
+ tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(input_gate_scratch,
+ input_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
+ n_cell, input_gate_scratch,
+ n_batch, input_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+ forget_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
+ n_cell, forget_gate_scratch,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+ n_batch, kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(
+ cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
+ cell_state_ptr);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(output_gate_scratch,
+ output_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
+ n_cell, output_gate_scratch,
+ n_batch, output_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+ output_ptr_batch, /*result_stride=*/1);
+ if (proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
+ output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+void LayerNormLstmStep(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale,
+ const float* input_layer_norm_weight_ptr,
+ const float* forget_layer_norm_weight_ptr,
+ const float* cell_layer_norm_weight_ptr,
+ const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ int n_batch, int n_cell, int n_input, int n_output,
+ float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* scaling_factors,
+ float* product_scaling_factors, float* recovered_weights,
+ int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch) {
+  // Since we have already checked that the weights are either all present or
+  // all absent, checking the existence of just one is sufficient.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+
+ // Initialize scratch buffers with 0.
+ if (!use_cifg) {
+ tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+ }
+ tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset,
+ &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // When the cell state is all zeros, the peephole contributions below would
+ // be zero as well, so skip them.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(input_gate_scratch,
+ input_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
+ n_cell, input_gate_scratch,
+ n_batch, input_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+ forget_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
+ n_cell, forget_gate_scratch,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
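+ // new_cell = forget_gate .* old_cell + gate .* g(cell_input), where g is the
+ // fused activation and gate is input_gate (or 1 - forget_gate under CIFG).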
+ tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+ n_batch, kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(
+ cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
+ cell_state_ptr);
+ }
+
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(output_gate_scratch,
+ output_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
+ n_cell, output_gate_scratch,
+ n_batch, output_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
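+ // output = sigmoid(output_gate) .* g(new_cell); the activated cell state is
+ // staged in cell_scratch so cell_state_ptr keeps the unactivated value.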
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+ product_scaling_factors, n_batch, output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
+ output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
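+ // The current output also becomes the recurrent output state for the next
+ // invocation.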
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+// The LayerNormLSTM Op engine.
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_layer_norm_weights,
+ const TfLiteTensor* forget_layer_norm_weights,
+ const TfLiteTensor* cell_layer_norm_weights,
+ const TfLiteTensor* output_layer_norm_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
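+ // Scratch layout: one n_cell * n_batch block per gate, in the order
+ // [input,] cell, forget, output (the input block is absent under CIFG).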
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors; the respective pointers can be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors; pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
+ const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
+ const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
+ const float* recurrent_to_forget_weights_ptr =
+ recurrent_to_forget_weights->data.f;
+ const float* recurrent_to_cell_weights_ptr =
+ recurrent_to_cell_weights->data.f;
+ const float* recurrent_to_output_weights_ptr =
+ recurrent_to_output_weights->data.f;
+ const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+ const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+ const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+ const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* activation_state_ptr = activation_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ LayerNormLstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+ input_to_cell_weights_ptr, input_to_output_weights_ptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+ recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+ cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+ cell_to_output_weights_ptr, input_layer_norm_weight_ptr,
+ forget_layer_norm_weight_ptr, cell_layer_norm_weight_ptr,
+ output_layer_norm_weight_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+ n_input, n_output, activation_state_ptr, cell_state_ptr,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_layer_norm_weights,
+ const TfLiteTensor* forget_layer_norm_weights,
+ const TfLiteTensor* cell_layer_norm_weights,
+ const TfLiteTensor* output_layer_norm_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors; the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors; pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+ const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+ const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+ const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* activation_state_ptr = activation_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_weights_ptr = recovered_weights->data.f;
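+ // recovered_weights holds the peephole weights dequantized back to float
+ // inside LayerNormLstmStep.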
+
+ LayerNormLstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_layer_norm_weight_ptr, forget_layer_norm_weight_ptr,
+ cell_layer_norm_weight_ptr, output_layer_norm_weight_ptr,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+ projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+ n_input, n_output, input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_weights_ptr, quantized_input_ptr,
+ quantized_activation_state_ptr, quantized_cell_state_ptr,
+ activation_state_ptr, cell_state_ptr, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ const TfLiteTensor* input_layer_norm_weights =
+ GetInput(context, node, kInputLayerNormWeightsTensor);
+ const TfLiteTensor* forget_layer_norm_weights =
+ GetInput(context, node, kForgetLayerNormWeightsTensor);
+ const TfLiteTensor* cell_layer_norm_weights =
+ GetInput(context, node, kCellLayerNormWeightsTensor);
+ const TfLiteTensor* output_layer_norm_weights =
+ GetInput(context, node, kOutputLayerNormWeightsTensor);
+
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ // Index the scratch buffer pointers into the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* activation_state =
+ &context->tensors[node->inputs->data[kInputActivationStateTensor]];
+ TfLiteTensor* cell_state =
+ &context->tensors[node->inputs->data[kInputCellStateTensor]];
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights,
+ cell_to_output_weights, input_layer_norm_weights,
+ forget_layer_norm_weights, cell_layer_norm_weights,
+ output_layer_norm_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, op_data->cell_clip,
+ op_data->proj_clip, op_data->activation, scratch_buffer,
+ activation_state, cell_state, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ input_layer_norm_weights, forget_layer_norm_weights,
+ cell_layer_norm_weights, output_layer_norm_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, op_data->cell_clip, op_data->proj_clip,
+ op_data->activation, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_weights, input_quantized,
+ activation_state_quantized, cell_state_quantized, activation_state,
+ cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+} // namespace layer_norm_lstm
+
+TfLiteRegistration* Register_LAYER_NORM_LSTM() {
+ static TfLiteRegistration r = {layer_norm_lstm::Init, layer_norm_lstm::Free,
+ layer_norm_lstm::Prepare,
+ layer_norm_lstm::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
new file mode 100644
index 0000000000..abc229f85a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
@@ -0,0 +1,664 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Layer Norm LSTM op.
+
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LayerNormLSTMOpModel : public SingleOpModel {
+ public:
+ LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights, bool use_projection_bias,
+ float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes,
+ const TensorType& weight_type = TensorType_FLOAT32)
+ : n_batch_(n_batch),
+ n_input_(n_input),
+ n_cell_(n_cell),
+ n_output_(n_output) {
+ input_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_to_input_weights_ = AddNullInput();
+ } else {
+ input_to_input_weights_ = AddInput(weight_type);
+ }
+
+ input_to_forget_weights_ = AddInput(weight_type);
+ input_to_cell_weights_ = AddInput(weight_type);
+ input_to_output_weights_ = AddInput(weight_type);
+
+ if (use_cifg) {
+ recurrent_to_input_weights_ = AddNullInput();
+ } else {
+ recurrent_to_input_weights_ = AddInput(weight_type);
+ }
+
+ recurrent_to_forget_weights_ = AddInput(weight_type);
+ recurrent_to_cell_weights_ = AddInput(weight_type);
+ recurrent_to_output_weights_ = AddInput(weight_type);
+
+ if (use_peephole) {
+ if (use_cifg) {
+ cell_to_input_weights_ = AddNullInput();
+ } else {
+ cell_to_input_weights_ = AddInput(weight_type);
+ }
+ cell_to_forget_weights_ = AddInput(weight_type);
+ cell_to_output_weights_ = AddInput(weight_type);
+ } else {
+ cell_to_input_weights_ = AddNullInput();
+ cell_to_forget_weights_ = AddNullInput();
+ cell_to_output_weights_ = AddNullInput();
+ }
+
+ input_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ forget_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ cell_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ output_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_gate_bias_ = AddNullInput();
+ } else {
+ input_gate_bias_ = AddInput(TensorType_FLOAT32);
+ }
+ forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+ cell_bias_ = AddInput(TensorType_FLOAT32);
+ output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+ if (use_projection_weights) {
+ projection_weights_ = AddInput(weight_type);
+ if (use_projection_bias) {
+ projection_bias_ = AddInput(TensorType_FLOAT32);
+ } else {
+ projection_bias_ = AddNullInput();
+ }
+ } else {
+ projection_weights_ = AddNullInput();
+ projection_bias_ = AddNullInput();
+ }
+
+ // Adding the 2 state tensors.
+ output_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ // Set up and pass in custom options using flexbuffer.
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {
+ fbb.Int("cell_clip", cell_clip);
+ fbb.Int("proj_clip", proj_clip);
+ fbb.String("fused_activation_function", "TANH");
+ });
+ fbb.Finish();
+ SetCustomOp("LAYER_NORM_LSTM", fbb.GetBuffer(), Register_LAYER_NORM_LSTM);
+ BuildInterpreter(input_shapes);
+ }
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_output_weights_, f);
+ }
+
+ void SetInputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_layer_norm_weights_, f);
+ }
+
+ void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(forget_layer_norm_weights_, f);
+ }
+
+ void SetCellLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_layer_norm_weights_, f);
+ }
+
+ void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(output_layer_norm_weights_, f);
+ }
+
+ void SetInputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(input_gate_bias_, f);
+ }
+
+ void SetForgetGateBias(std::initializer_list<float> f) {
+ PopulateTensor(forget_gate_bias_, f);
+ }
+
+ void SetCellBias(std::initializer_list<float> f) {
+ PopulateTensor(cell_bias_, f);
+ }
+
+ void SetOutputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(output_gate_bias_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ PopulateTensor(projection_weights_, f);
+ }
+
+ void SetProjectionBias(std::initializer_list<float> f) {
+ PopulateTensor(projection_bias_, f);
+ }
+
+ void SetInput(int offset, const float* begin, const float* end) {
+ PopulateTensor(input_, offset, const_cast<float*>(begin),
+ const_cast<float*>(end));
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int num_inputs() { return n_input_; }
+ int num_outputs() { return n_output_; }
+ int num_cells() { return n_cell_; }
+ int num_batches() { return n_batch_; }
+
+ protected:
+ int input_;
+ int input_to_input_weights_;
+ int input_to_forget_weights_;
+ int input_to_cell_weights_;
+ int input_to_output_weights_;
+
+ int recurrent_to_input_weights_;
+ int recurrent_to_forget_weights_;
+ int recurrent_to_cell_weights_;
+ int recurrent_to_output_weights_;
+
+ int cell_to_input_weights_;
+ int cell_to_forget_weights_;
+ int cell_to_output_weights_;
+
+ int input_layer_norm_weights_;
+ int forget_layer_norm_weights_;
+ int cell_layer_norm_weights_;
+ int output_layer_norm_weights_;
+
+ int input_gate_bias_;
+ int forget_gate_bias_;
+ int cell_bias_;
+ int output_gate_bias_;
+
+ int projection_weights_;
+ int projection_bias_;
+
+ int output_state_;
+ int cell_state_;
+
+ int output_;
+
+ int n_batch_;
+ int n_input_;
+ int n_cell_;
+ int n_output_;
+};
+
+class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel {
+ public:
+ HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights,
+ bool use_projection_bias, float cell_clip,
+ float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg,
+ use_peephole, use_projection_weights,
+ use_projection_bias, cell_clip, proj_clip,
+ input_shapes, TensorType_UINT8) {}
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+ }
+
+ void SetInputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_layer_norm_weights_, f);
+ }
+
+ void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(forget_layer_norm_weights_, f);
+ }
+
+ void SetCellLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_layer_norm_weights_, f);
+ }
+
+ void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(output_layer_norm_weights_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(projection_weights_, f);
+ }
+};
+
+class BaseLayerNormLstmTest : public ::testing::Test {
+ protected:
+ // Weights of the Layer Norm LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_;
+ std::initializer_list<float> input_to_cell_weights_;
+ std::initializer_list<float> input_to_forget_weights_;
+ std::initializer_list<float> input_to_output_weights_;
+ std::initializer_list<float> input_gate_bias_;
+ std::initializer_list<float> cell_gate_bias_;
+ std::initializer_list<float> forget_gate_bias_;
+ std::initializer_list<float> output_gate_bias_;
+ std::initializer_list<float> recurrent_to_input_weights_;
+ std::initializer_list<float> recurrent_to_cell_weights_;
+ std::initializer_list<float> recurrent_to_forget_weights_;
+ std::initializer_list<float> recurrent_to_output_weights_;
+ std::initializer_list<float> cell_to_input_weights_;
+ std::initializer_list<float> cell_to_forget_weights_;
+ std::initializer_list<float> cell_to_output_weights_;
+ std::initializer_list<float> input_layer_norm_weights_;
+ std::initializer_list<float> forget_layer_norm_weights_;
+ std::initializer_list<float> cell_layer_norm_weights_;
+ std::initializer_list<float> output_layer_norm_weights_;
+ std::initializer_list<float> projection_weights_;
+
+ // Layer Norm LSTM input is stored as a num_batch x num_inputs vector.
+ std::vector<std::vector<float>> layer_norm_lstm_input_;
+
+ // Runs the layer_norm_lstm on the given input and compares its output
+ // against the expected output, up to the given tolerance.
+ void VerifyGoldens(const std::vector<std::vector<float>>& input,
+ const std::vector<std::vector<float>>& output,
+ LayerNormLSTMOpModel* layer_norm_lstm,
+ float tolerance = 1e-5) {
+ const int num_batches = input.size();
+ EXPECT_GT(num_batches, 0);
+ const int num_inputs = layer_norm_lstm->num_inputs();
+ EXPECT_GT(num_inputs, 0);
+ const int input_sequence_size = input[0].size() / num_inputs;
+ EXPECT_GT(input_sequence_size, 0);
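+ // Feed one time step at a time for every batch, invoke the op, and compare
+ // against the golden output for that step.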
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* batch_start = input[b].data() + i * num_inputs;
+ const float* batch_end = batch_start + num_inputs;
+
+ layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
+ batch_start, batch_end);
+ }
+
+ layer_norm_lstm->Invoke();
+
+ const int num_outputs = layer_norm_lstm->num_outputs();
+ std::vector<float> expected;
+ for (int b = 0; b < num_batches; ++b) {
+ const float* golden_start_batch = output[b].data() + i * num_outputs;
+ const float* golden_end_batch = golden_start_batch + num_outputs;
+ expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+ }
+ EXPECT_THAT(layer_norm_lstm->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+ }
+};
+
+class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
+ : public BaseLayerNormLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
+ 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
+ -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
+
+ input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
+ -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
+ -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
+
+ input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
+ -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
+ -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
+
+ input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
+ -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
+ -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
+
+ input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
+
+ forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
+
+ cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
+
+ output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
+
+ recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9,
+ -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
+
+ recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
+ -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
+
+ recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
+ 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
+
+ recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
+ -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
+
+ cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
+
+ cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
+
+ cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
+
+ input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
+ forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
+ cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
+ output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};
+
+ projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
+ 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
+
+ layer_norm_lstm_input_ = {
+ {// Batch0: 3 (input_sequence_size) * 5 (n_input)
+ 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
+ 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
+ 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
+
+ {// Batch1: 3 (input_sequence_size) * 5 (n_input)
+ 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
+ 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
+ 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
+ };
+ }
+};
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+ LayerNormLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 4;
+ const int n_output = 3;
+ const float cell_clip = 0.0;
+ const float proj_clip = 0.0;
+
+ LayerNormLSTMOpModel layer_norm_lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false, cell_clip, proj_clip,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_layer_norm_weight tensor
+ {n_cell}, // forget_layer_norm_weight tensor
+ {n_cell}, // cell_layer_norm_weight tensor
+ {n_cell}, // output_layer_norm_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+ layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+ layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+ layer_norm_lstm.SetCellBias(cell_gate_bias_);
+ layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+ layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+ layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+ layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+ layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+ layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+ layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+ layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+ // Verify the final output.
+ const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+ {
+ // Batch0: 3 (input_sequence_size) * 3 (n_output)
+ 0.0244077, 0.128027, -0.00170918, // seq 0
+ 0.0137642, 0.140751, 0.0395835, // seq 1
+ -0.00459231, 0.155278, 0.0837377, // seq 2
+ },
+ {
+ // Batch1: 3 (input_sequence_size) * 3 (n_output)
+ -0.00692428, 0.0848741, 0.063445, // seq 0
+ -0.00403912, 0.139963, 0.072681, // seq 1
+ 0.00752706, 0.161903, 0.0561371, // seq 2
+ }};
+
+ VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+ &layer_norm_lstm);
+}
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+ HybridLayerNormLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 4;
+ const int n_output = 3;
+ const float cell_clip = 0.0;
+ const float proj_clip = 0.0;
+
+ HybridLayerNormLSTMOpModel layer_norm_lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false, cell_clip, proj_clip,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_layer_norm_weight tensor
+ {n_cell}, // forget_layer_norm_weight tensor
+ {n_cell}, // cell_layer_norm_weight tensor
+ {n_cell}, // output_layer_norm_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+ layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+ layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+ layer_norm_lstm.SetCellBias(cell_gate_bias_);
+ layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+ layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+ layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+ layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+ layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+ layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+ layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+ layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+ const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+ {
+ // Batch0: 3 (input_sequence_size) * 3 (n_output)
+ 0.0244576, 0.127847, -0.00181765, // seq 0
+ 0.0137518, 0.140892, 0.0402234, // seq 1
+ -0.0048839, 0.155096, 0.0840309, // seq 2
+ },
+ {
+ // Batch1: 3 (input_sequence_size) * 3 (n_output)
+ -0.00728636, 0.0843957, 0.0634786, // seq 0
+ -0.00448382, 0.139278, 0.0737372, // seq 1
+ 0.00734616, 0.161793, 0.0560238, // seq 2
+ }};
+
+ VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+ &layer_norm_lstm);
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 55bcf3b533..3bce05353d 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -92,8 +92,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
op_context.constant_values->type);
}
- // TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
- TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
+ // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D.
+ TF_LITE_ENSURE(context, op_context.dims <= 4);
// Exit early if paddings is a non-const tensor. Set output tensor to
// dynamic so output size can be determined in Eval.
@@ -134,21 +134,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
after_padding.push_back(paddings_data[idx * 2 + 1]);
}
-#define TF_LITE_PAD(type, scalar, pad_value) \
- TF_LITE_ENSURE_EQ(context, before_padding.size(), 4); \
- TF_LITE_ENSURE_EQ(context, after_padding.size(), 4); \
- tflite::PadParams op_params; \
- op_params.left_padding_count = 4; \
- op_params.right_padding_count = 4; \
- for (int i = 0; i < 4; ++i) { \
- op_params.left_padding[i] = before_padding[3 - i]; \
- op_params.right_padding[i] = after_padding[3 - i]; \
- } \
- const scalar pad_value_copy = pad_value; \
- \
- type::Pad(op_params, GetTensorShape(op_context.input), \
- GetTensorData<scalar>(op_context.input), &pad_value_copy, \
- GetTensorShape(op_context.output), \
+#define TF_LITE_PAD(type, scalar, pad_value) \
+ TF_LITE_ENSURE(context, before_padding.size() <= 4); \
+ TF_LITE_ENSURE(context, after_padding.size() <= 4); \
+ tflite::PadParams op_params; \
+ op_params.left_padding_count = before_padding.size(); \
+ op_params.right_padding_count = after_padding.size(); \
+ for (int i = 0; i < op_context.dims; ++i) { \
+ op_params.left_padding[i] = before_padding[op_context.dims - 1 - i]; \
+ op_params.right_padding[i] = after_padding[op_context.dims - 1 - i]; \
+ } \
+ const scalar pad_value_copy = pad_value; \
+ \
+ type::Pad(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), &pad_value_copy, \
+ GetTensorShape(op_context.output), \
GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: {
diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc
index f8b9064fbb..f663899713 100644
--- a/tensorflow/contrib/lite/kernels/pad_test.cc
+++ b/tensorflow/contrib/lite/kernels/pad_test.cc
@@ -193,7 +193,7 @@ TEST(PadOpTest, TooManyDimensions) {
PadOpConstModel({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9},
{TensorType_FLOAT32}),
- "dims != 4");
+ "dims <= 4");
}
TEST(PadOpTest, UnequalDimensions) {
@@ -221,6 +221,15 @@ TEST(PadOpTest, SimpleConstTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
+TEST(PadOpTest, SimpleConst1DTest) {
+ PadOpConstModel m({TensorType_FLOAT32, {2}}, {1, 2}, {1, 2},
+ {TensorType_FLOAT32});
+ m.SetInput({2, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 3, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5}));
+}
+
TEST(PadOpTest, SimpleDynamicTest) {
PadOpDynamicModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
{TensorType_FLOAT32});
@@ -334,7 +343,7 @@ TEST(PadV2OpTest, TooManyDimensions) {
{TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0,
{TensorType_FLOAT32}),
- "dims != 4");
+ "dims <= 4");
}
TEST(PadV2OpTest, UnequalDimensions) {
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 7b859dc332..c66959fdf4 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -22,8 +22,10 @@ namespace ops {
namespace custom {
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
+TfLiteRegistration* Register_RELU_1();
} // namespace custom
@@ -247,6 +249,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
+ AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
+ AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
diff --git a/tensorflow/contrib/lite/kernels/relu1.cc b/tensorflow/contrib/lite/kernels/relu1.cc
new file mode 100644
index 0000000000..abafee2d57
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1.cc
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace relu1 {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ output->type = input->type;
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+// This is derived from lite/kernels/activations.cc.
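+// Clamps each float element to [0, 1]: out[i] = min(max(in[i], 0), 1).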
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ const int elements = NumElements(input);
+ const float* in = input->data.f;
+ const float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; ++in, ++out) {
+ *out = std::min(std::max(0.f, *in), 1.f);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace relu1
+
+TfLiteRegistration* Register_RELU_1() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ relu1::Prepare, relu1::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/contrib/lite/kernels/relu1_test.cc
new file mode 100644
index 0000000000..c1e0149c20
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1_test.cc
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_RELU_1();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseActivationsOpModel : public SingleOpModel {
+ public:
+ explicit BaseActivationsOpModel(const TensorData& input) {
+ input_ = AddInput(input);
+ output_ = AddOutput({input.type, {}});
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {});
+ fbb.Finish();
+ SetCustomOp("RELU_1", fbb.GetBuffer(), Register_RELU_1);
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class FloatActivationsOpModel : public BaseActivationsOpModel {
+ public:
+ using BaseActivationsOpModel::BaseActivationsOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST(FloatActivationsOpTest, Relu1) {
+ FloatActivationsOpModel m(/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0.0, -0.6, 0.2, -0.4, //
+ 0.3, -2.0, 1.1, -0.1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0.0, 0.0, 0.2, 0.0, //
+ 0.3, 0.0, 1.0, 0.0, //
+ }));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 57134ccd15..32f02a4f6c 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1679,6 +1679,7 @@ def make_pad_tests(zip_path):
# TODO(nupurgarg): Add test for tf.uint8.
test_parameters = [
+ # 4D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
"input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
@@ -1686,13 +1687,20 @@ def make_pad_tests(zip_path):
[0, 0], [2, 3]]],
"constant_paddings": [True, False],
},
- # Non-4D use case.
+ # 2D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
- "input_shape": [[1, 2], [0, 1, 2]],
+ "input_shape": [[1, 2]],
"paddings": [[[0, 1], [2, 3]]],
"constant_paddings": [True, False],
},
+ # 1D:
+ {
+ "dtype": [tf.int32],
+ "input_shape": [[1]],
+ "paddings": [[[1, 2]]],
+ "constant_paddings": [False],
+ },
]
def build_graph(parameters):
@@ -1730,6 +1738,7 @@ def make_padv2_tests(zip_path):
# TODO(nupurgarg): Add test for tf.uint8.
test_parameters = [
+ # 4D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
"input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
@@ -1738,14 +1747,22 @@ def make_padv2_tests(zip_path):
"constant_paddings": [True, False],
"constant_values": [0, 2],
},
- # Non-4D use case.
+ # 2D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
- "input_shape": [[1, 2], [0, 1, 2]],
+ "input_shape": [[1, 2]],
"paddings": [[[0, 1], [2, 3]]],
"constant_paddings": [True, False],
"constant_values": [0, 2],
},
+ # 1D:
+ {
+ "dtype": [tf.int32],
+ "input_shape": [[1]],
+ "paddings": [[[0, 1]]],
+ "constant_paddings": [False],
+ "constant_values": [0, 2],
+ },
]
def build_graph(parameters):
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 37c7ae0e1c..349aa5a3b4 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -58,12 +58,6 @@ tensorflow::Env* env = tensorflow::Env::Default();
// Key is a substring of the test name and value is a bug number.
// TODO(ahentz): make sure we clean this list up frequently.
std::map<string, string> kBrokenTests = {
- // Pad and PadV2 only supports 4D tensors.
- {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
- "70527055"},
- {R"(^\/padv2.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
- "70527055"},
-
// L2Norm only supports tensors with 4D or fewer.
{R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 84f71dc7a7..f14dbc258b 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -247,6 +247,10 @@ struct ParsedTocoFlags {
Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64);
Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
+ // WARNING: Experimental interface, subject to change
+ Arg<bool> allow_eager_ops = Arg<bool>(false);
+ // WARNING: Experimental interface, subject to change
+ Arg<bool> force_eager_ops = Arg<bool>(false);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 502de88f7c..3114fa93e8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -63,6 +63,25 @@ bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) {
return true;
}
+bool HardcodeInputMinMaxFromOutput(Model* model, Operator* op) {
+ auto& input = model->GetArray(op->inputs[0]);
+ if (input.minmax) {
+ const auto* minmax = input.minmax.get();
+ if (minmax) {
+ return false;
+ }
+ }
+ auto& output = model->GetArray(op->outputs[0]);
+ if (output.minmax) {
+ const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
+ if (minmax) {
+ input.GetOrCreateMinMax() = *minmax;
+ return true;
+ }
+ }
+ return false;
+}
+
bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
// Do not early return if the output already has min/max:
// we may still need to adjust the inputs min/max.
@@ -366,6 +385,16 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
changed = HardcodeMinMaxForL2Normalization(model, op);
break;
+ case OperatorType::kRelu:
+      // For any normalization other than batch norm, the quantization ranges
+      // before and after relu are expected to be known. Having a quantization
+      // op before relu would cut the number of bits of precision for the
+      // activation in half, so we deduce the range before the relu from the
+      // range after it. This eliminates the need for two fake-quantization
+      // nodes without reducing the bits of precision available for activation.
+ changed = HardcodeInputMinMaxFromOutput(model, op);
+ break;
+
case OperatorType::kConcatenation:
changed = HardcodeMinMaxForConcatenation(model, op);
break;
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index cb6da21039..9bc23c4b3c 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -2061,8 +2061,14 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
}
Model* model = new Model;
- const internal::ConverterMapType& converter_map =
- internal::GetTensorFlowNodeConverterMap();
+ internal::ConverterMapType converter_map;
+
+ // This is used for the TFLite "Full Eager Mode" conversion. All the ops are
+ // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
+ // converted to TFLite Eager ops.
+ if (!tf_import_flags.import_all_ops_as_unsupported) {
+ converter_map = internal::GetTensorFlowNodeConverterMap();
+ }
for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
index 2177872334..7db23f2d44 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.h
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -27,6 +27,11 @@ struct TensorFlowImportFlags {
// If true, control dependencies will be dropped immediately
// during the import of the TensorFlow GraphDef.
bool drop_control_dependency = false;
+
+ // Do not recognize any op and import all ops as
+  // `TensorFlowUnsupportedOperator`. This is populated from the
+  // `force_eager_ops` flag.
+ bool import_all_ops_as_unsupported = false;
};
std::unique_ptr<Model> ImportTensorFlowGraphDef(
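A minimal sketch (illustrative caller, not part of the change; GraphDef contents assumed already read into a string) of importing with every op mapped to `TensorFlowUnsupportedOperator`, mirroring how toco_tooling.cc wires this flag up from force_eager_ops:

// Sketch only: import a GraphDef with all ops treated as unsupported.
#include <memory>
#include <string>

#include "tensorflow/contrib/lite/toco/import_tensorflow.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"

std::unique_ptr<toco::Model> ImportAllOpsAsUnsupported(
    const toco::ModelFlags& model_flags,
    const std::string& graphdef_contents) {
  toco::TensorFlowImportFlags tf_import_flags;
  tf_import_flags.import_all_ops_as_unsupported = true;  // as with --force_eager_ops
  return toco::ImportTensorFlowGraphDef(model_flags, tf_import_flags,
                                        graphdef_contents);
}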
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index c79469f59b..fee10b1dff 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -49,12 +49,21 @@ namespace {
details::OperatorKey GetOperatorKey(
const ::toco::Operator& op,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops) {
string custom_code;
if (op.type == OperatorType::kUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
- custom_code = unsupported_op.tensorflow_op;
+
+      // TODO(b/113715895): When `allow_eager_ops` is on, there is currently no
+      // way to populate a regular custom op. We need to find a way to fix this.
+ if (allow_eager_ops) {
+ custom_code = string(::tflite::kEagerCustomCodePrefix) +
+ unsupported_op.tensorflow_op;
+ } else {
+ custom_code = unsupported_op.tensorflow_op;
+ }
}
int version = 1;
if (ops_by_type.count(op.type) != 0) {
@@ -91,11 +100,12 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops) {
// First find a list of unique operator types.
std::set<OperatorKey> keys;
for (const auto& op : model.operators) {
- keys.insert(GetOperatorKey(*op, ops_by_type));
+ keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops));
}
// Now assign indices to them and fill in the map.
int index = 0;
@@ -189,7 +199,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
- std::set<string>* error_summary) {
+ std::set<string>* error_summary, const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -205,7 +215,8 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
std::map<int, Offset<OperatorCode>> ordered_opcodes;
for (const auto& op : model.operators) {
- const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type);
+ const details::OperatorKey operator_key =
+ GetOperatorKey(*op, ops_by_type, params.allow_eager_ops);
int op_index = operators_map.at(operator_key);
int op_version = operator_key.version;
@@ -252,7 +263,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map,
const details::TensorsMap& tensors_map, FlatBufferBuilder* builder,
- std::set<int32_t>* variable_tensor_indices) {
+ std::set<int32_t>* variable_tensor_indices, const ExportParams& params) {
variable_tensor_indices->clear();
// The operators are in execution order, so we just follow tf.mini order.
@@ -269,7 +280,8 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
outputs.push_back(tensors_map.at(output));
}
- int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type));
+ int op_index = operators_map.at(
+ GetOperatorKey(*op, ops_by_type, params.allow_eager_ops));
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -320,16 +332,15 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
return builder->CreateVector(buffer_vector);
}
-void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
- string* output_file_contents) {
- const auto ops_by_type = BuildOperatorByTypeMap();
- Export(model, allow_custom_ops, quantize_weights, output_file_contents,
- ops_by_type);
+void Export(const Model& model, string* output_file_contents,
+ const ExportParams& params) {
+ const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops);
+ Export(model, output_file_contents, params, ops_by_type);
}
void Export(
- const Model& model, bool allow_custom_ops, bool quantize_weights,
- string* output_file_contents,
+ const Model& model, string* output_file_contents,
+ const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
@@ -337,7 +348,8 @@ void Export(
details::LoadTensorsMap(model, &tensors_map);
details::OperatorsMap operators_map;
- details::LoadOperatorsMap(model, &operators_map, ops_by_type);
+ details::LoadOperatorsMap(model, &operators_map, ops_by_type,
+ params.allow_eager_ops);
std::vector<const Array*> buffers_to_write;
Array empty_array;
@@ -345,7 +357,7 @@ void Export(
std::set<string> error_summary;
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
- &builder, &error_summary);
+ &builder, &error_summary, params);
for (const auto& op : model.operators) {
if (op->type == OperatorType::kFakeQuant) {
@@ -355,7 +367,7 @@ void Export(
"for --std_values and --mean_values.";
}
}
- if (!allow_custom_ops && !error_summary.empty()) {
+ if (!params.allow_custom_ops && !error_summary.empty()) {
// Remove ExpandDims and ReorderAxes from unimplemented list unless they
// compose the list. Both ops are removed during graph transformations.
// However, if an op is unimplemented earlier in the model, the graph
@@ -383,7 +395,7 @@ void Export(
std::set<int32_t> variable_tensor_indices;
auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map,
- &builder, &variable_tensor_indices);
+ &builder, &variable_tensor_indices, params);
auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write,
variable_tensor_indices);
@@ -402,7 +414,7 @@ void Export(
builder.CreateVector(subgraphs), description, buffers);
::tflite::FinishModelBuffer(builder, new_model_location);
- if (quantize_weights) {
+ if (params.quantize_weights) {
// Call the quantize_weights tool.
LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. "
"dump_graphviz will only output the model before this "
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 915d5dd3d6..b070a38768 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -23,22 +23,54 @@ namespace toco {
namespace tflite {
+// The parameters for exporting a TFLite model.
+struct ExportParams {
+ bool allow_custom_ops = false;
+ bool allow_eager_ops = false;
+ bool quantize_weights = false;
+};
+
// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
// result in the given string.
-void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
- string* output_file_contents);
+void Export(const Model& model, string* output_file_contents,
+ const ExportParams& params);
+
+// Export API with custom TFLite operator mapping.
+void Export(
+ const Model& model, string* output_file_contents,
+ const ExportParams& params,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
-// This if backward-compatibility.
+// This is for backward-compatibility.
// TODO(ycling): Remove the deprecated entry functions.
-inline void Export(const Model& model, string* output_file_contents) {
- Export(model, true, false, output_file_contents);
+inline void Export(const Model& model, bool allow_custom_ops,
+ bool quantize_weights, string* output_file_contents) {
+ ExportParams params;
+ params.allow_custom_ops = allow_custom_ops;
+ params.quantize_weights = quantize_weights;
+ Export(model, output_file_contents, params);
}
-// Export API with custom TFLite operator mapping.
-void Export(
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(
const Model& model, bool allow_custom_ops, bool quantize_weights,
string* output_file_contents,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ ExportParams params;
+ params.allow_custom_ops = allow_custom_ops;
+ params.quantize_weights = quantize_weights;
+ Export(model, output_file_contents, params, ops_by_type);
+}
+
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(const Model& model, string* output_file_contents) {
+ ExportParams params;
+ params.allow_custom_ops = true;
+  Export(model, output_file_contents, params);
+}
namespace details {
@@ -88,7 +120,8 @@ using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops);
} // namespace details
} // namespace tflite
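A minimal sketch (not part of the change; flag values are illustrative, not defaults) of the new ExportParams-based entry point:

// Sketch only: export a tf.mini model with the new ExportParams API.
#include <string>

#include "tensorflow/contrib/lite/toco/tflite/export.h"

void ExportWithEagerOps(const toco::Model& model, std::string* output) {
  toco::tflite::ExportParams params;
  params.allow_custom_ops = true;   // eager ops are serialized as custom ops
  params.allow_eager_ops = true;    // experimental, see toco_flags.proto
  params.quantize_weights = false;
  toco::tflite::Export(model, output, params);
}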
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 4994ea30de..8d4d197c46 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -105,7 +105,8 @@ TEST_F(ExportTest, LoadOperatorsMap) {
details::OperatorsMap operators;
const auto ops_by_type = BuildOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ // TODO(ycling): Add a test for allow_eager_ops.
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
@@ -253,7 +254,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
@@ -264,7 +265,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
@@ -276,7 +277,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(2, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index a314c8d53a..eb0f7c443a 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1149,7 +1149,9 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
class TensorFlowUnsupported : public BaseOperator {
public:
- using BaseOperator::BaseOperator;
+ TensorFlowUnsupported(const string& name, OperatorType type,
+ bool allow_eager_ops)
+ : BaseOperator(name, type), allow_eager_ops_(allow_eager_ops) {}
Options Serialize(const Operator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
@@ -1165,6 +1167,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<Operator> Deserialize(
const BuiltinOptions* builtin_options,
const CustomOptions* custom_options) const override {
+    // Deserializing Eager ops is not supported yet.
+ // TODO(ycling): Revisit and decide if we should fix the flow for importing
+ // TFLite models with Eager ops.
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
if (custom_options) {
auto flexbuffer_map =
@@ -1185,6 +1190,16 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
+ if (allow_eager_ops_) {
+ fbb->Vector([&]() {
+ fbb->String(node_def.op());
+ fbb->String(op.tensorflow_node_def);
+ });
+ fbb->Finish();
+ LOG(INFO) << "Writing eager op: " << node_def.op();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+ }
+
bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
@@ -1285,11 +1300,15 @@ class TensorFlowUnsupported : public BaseOperator {
// custom ops.
return 1;
}
+
+ private:
+ const bool allow_eager_ops_;
};
namespace {
// Build a vector containing all the known operators.
-std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
+std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
+ bool allow_eager_ops = false) {
std::vector<std::unique_ptr<BaseOperator>> ops;
using tensorflow::MakeUnique;
// Builtin Operators.
@@ -1400,8 +1419,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
"CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
- ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
- OperatorType::kUnsupported));
+ ops.push_back(MakeUnique<TensorFlowUnsupported>(
+ "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops));
// There operators are supported by Toco, but not by TF Lite, and has no
// attributes.
@@ -1474,10 +1493,12 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
}
} // namespace
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+ bool allow_eager_ops) {
std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
- std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ std::vector<std::unique_ptr<BaseOperator>> ops =
+ BuildOperatorList(allow_eager_ops);
for (auto& op : ops) {
result[op->type()] = std::move(op);
}
@@ -1485,10 +1506,12 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
return result;
}
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+ bool allow_eager_ops) {
std::map<string, std::unique_ptr<BaseOperator>> result;
- std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ std::vector<std::unique_ptr<BaseOperator>> ops =
+ BuildOperatorList(allow_eager_ops);
for (auto& op : ops) {
result[op->name()] = std::move(op);
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index d9ea23edf2..702fb28ea6 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -26,11 +26,15 @@ namespace tflite {
class BaseOperator;
// Return a map contained all know TF Lite Operators, keyed by their names.
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap();
+// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops)
+// is ugly here. Consider refactoring.
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+ bool allow_eager_ops = false);
// Return a map contained all know TF Lite Operators, keyed by the type of
// their tf.mini counterparts.
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap();
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+ bool allow_eager_ops = false);
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index f83a290195..b6aebc0470 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -165,7 +165,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.post_training_quantize.default_value(),
"Boolean indicating whether to quantize the weights of the "
"converted float model. Model size will be reduced and there will "
- "be latency improvements (at the cost of accuracy).")};
+ "be latency improvements (at the cost of accuracy)."),
+ // WARNING: Experimental interface, subject to change
+ Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(),
+ parsed_flags.allow_eager_ops.default_value(), ""),
+ // WARNING: Experimental interface, subject to change
+ Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(),
+ parsed_flags.force_eager_ops.default_value(), "")};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
if (asked_for_help) {
@@ -260,6 +266,16 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone);
+ READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone);
+
+ if (parsed_toco_flags.force_eager_ops.value() &&
+ !parsed_toco_flags.allow_eager_ops.value()) {
+    // TODO(ycling): Consider enforcing `allow_eager_ops` when
+ // `force_eager_ops` is true.
+ LOG(WARNING) << "--force_eager_ops should always be used with "
+ "--allow_eager_ops.";
+ }
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index c1dd621429..53d60fed05 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 27.
+// Next ID to use: 29.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -189,4 +189,17 @@ message TocoFlags {
// model. Model size will be reduced and there will be latency improvements
// (at the cost of accuracy).
optional bool post_training_quantize = 26 [default = false];
+
+ // When enabled, unsupported ops will be converted to TFLite Eager ops.
+  // TODO(ycling): Consider renaming the following 2 flags and not calling
+  // them "Eager".
+ // `allow_eager_ops` should always be used with `allow_custom_ops`.
+ // WARNING: Experimental interface, subject to change
+ optional bool allow_eager_ops = 27 [default = false];
+
+ // When enabled, all TensorFlow ops will be converted to TFLite Eager
+ // ops directly. This will force `allow_eager_ops` to true.
+ // `force_eager_ops` should always be used with `allow_eager_ops`.
+ // WARNING: Experimental interface, subject to change
+ optional bool force_eager_ops = 28 [default = false];
}
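A minimal sketch (illustrative only, not part of the change) of enabling the new flags programmatically through the proto2-generated C++ setters; the tooling layer then maps them onto ExportParams as shown in toco_tooling.cc below:

// Sketch only: enable the experimental eager-op conversion via TocoFlags.
#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"

toco::TocoFlags MakeEagerConversionFlags() {
  toco::TocoFlags toco_flags;
  toco_flags.set_allow_eager_ops(true);
  // force_eager_ops additionally imports every TF op as unsupported so that
  // all of them end up converted to Eager ops.
  toco_flags.set_force_eager_ops(true);
  return toco_flags;
}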
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 7db7acb44d..a7c17156b1 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -197,6 +197,10 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
toco_flags.has_drop_control_dependency()
? toco_flags.drop_control_dependency()
: (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
+
+ tf_import_flags.import_all_ops_as_unsupported =
+ toco_flags.force_eager_ops();
+
model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
input_file_contents);
break;
@@ -397,11 +401,21 @@ void Export(const TocoFlags& toco_flags, const Model& model,
case TENSORFLOW_GRAPHDEF:
ExportTensorFlowGraphDef(model, output_file_contents);
break;
- case TFLITE:
- toco::tflite::Export(model, allow_custom_ops,
- toco_flags.post_training_quantize(),
- output_file_contents);
- break;
+ case TFLITE: {
+ toco::tflite::ExportParams params;
+
+ // Always allow custom ops when eager ops are allowed.
+ if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) {
+ params.allow_eager_ops = true;
+ params.allow_custom_ops = true;
+ } else if (allow_custom_ops) {
+ params.allow_custom_ops = true;
+ }
+
+ params.quantize_weights = toco_flags.post_training_quantize();
+
+ toco::tflite::Export(model, output_file_contents, params);
+ } break;
case GRAPHVIZ_DOT:
DumpGraphviz(model, output_file_contents);
break;
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
index e0ed7c7946..e5bb3c990a 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -42,10 +42,9 @@ typedef struct {
bool eval_hybrid;
} TensorInfo;
-// The minimum number of elements a weights array must have to be quantized
-// by this transformation.
-// TODO(suharshs): Make this configurable.
-const int kWeightsMinSize = 1024;
+// The default minimum number of elements a weights array must have to be
+// quantized by this transformation.
+const int kWeightsMinNumElementsDefault = 1024;
// Nudge min and max so that floating point 0 falls exactly on a quantized
// value, returning the nudges scale and zero_point.
@@ -158,42 +157,45 @@ bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) {
// Returns a vector of TensorInfos for each input tensor of op that should be
// quantized.
-std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model,
- const OperatorT* op) {
+std::vector<TensorInfo> GetQuantizableTensorsFromOperator(
+ const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements,
+ bool use_hybrid_evaluation) {
SubGraphT* subgraph = model->subgraphs.at(0).get();
const BuiltinOperator op_code =
model->operator_codes[op->opcode_index]->builtin_code;
std::vector<TensorInfo> tensor_infos;
- bool eval_hybrid = IsHybridEvaluationOp(op, op_code);
+ bool eval_hybrid = use_hybrid_evaluation && IsHybridEvaluationOp(op, op_code);
bool skipped_tensor = false;
std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
for (const int32_t op_input_idx : op_input_indices) {
int32_t tensor_idx = op->inputs[op_input_idx];
+ TensorT* tensor = subgraph->tensors[tensor_idx].get();
// TODO(suharshs): Support shared weights, i.e. If two tensors share the
// same weight array, things may break. (i.e. SSD object detection)
- if (CountTensorConsumers(model, subgraph, tensor_idx) != 1) {
- LOG(INFO) << "Skipping quantization of tensor that is shared between "
- "multiple multiple operations.";
+ if (!eval_hybrid &&
+ CountTensorConsumers(model, subgraph, tensor_idx) != 1) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " that is shared between multiple multiple operations.";
skipped_tensor = true;
continue;
}
- TensorT* tensor = subgraph->tensors[tensor_idx].get();
-
if (tensor->type != TensorType_FLOAT32) {
- LOG(INFO) << "Skipping quantization of tensor that is not type float.";
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " that is not type float.";
skipped_tensor = true;
continue;
}
const uint64_t num_elements = NumElements(tensor);
- if (num_elements < kWeightsMinSize) {
- LOG(INFO) << "Skipping quantization of tensor because it has fewer than "
- << kWeightsMinSize << " elements (" << num_elements << ").";
+ if (num_elements < weights_min_num_elements) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " because it has fewer than " << weights_min_num_elements
+ << " elements (" << num_elements << ").";
skipped_tensor = true;
continue;
}
@@ -331,11 +333,10 @@ void MakeTensor(const string& name, const std::vector<int32_t>& shape,
tensor->reset(tensor_raw);
}
-} // namespace
-
-TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
- const Model* input_model,
- bool use_hybrid_evaluation) {
+TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ bool use_hybrid_evaluation,
+ uint64_t weights_min_num_elements) {
std::unique_ptr<ModelT> model;
model.reset(input_model->UnPack());
@@ -352,11 +353,11 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
for (int i = 0; i < subgraph->operators.size(); ++i) {
OperatorT* op = subgraph->operators[i].get();
- std::vector<TensorInfo> tensor_infos =
- GetQuantizableTensorsFromOperator(model.get(), op);
+ std::vector<TensorInfo> tensor_infos = GetQuantizableTensorsFromOperator(
+ model.get(), op, weights_min_num_elements, use_hybrid_evaluation);
for (const TensorInfo& tensor_info : tensor_infos) {
- if (use_hybrid_evaluation && tensor_info.eval_hybrid) {
+ if (tensor_info.eval_hybrid) {
// Quantize the tensor.
TF_LITE_ENSURE_STATUS(
SymmetricQuantizeTensor(model.get(), tensor_info.tensor));
@@ -399,9 +400,32 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
return kTfLiteOk;
}
+} // namespace
+
+namespace internal {
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ bool use_hybrid_evaluation) {
+ // By default we require that only weights with more than
+  // kWeightsMinNumElementsDefault elements are quantized.
+ return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation,
+ kWeightsMinNumElementsDefault);
+}
+} // namespace internal
+
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ uint64_t weights_min_num_elements) {
+ return QuantizeWeightsInternal(builder, input_model, true,
+ weights_min_num_elements);
+}
+
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model) {
- return QuantizeWeights(builder, input_model, true);
+ // By default we require that only weights with more than
+  // kWeightsMinNumElementsDefault elements are quantized.
+ return QuantizeWeightsInternal(builder, input_model, true,
+ kWeightsMinNumElementsDefault);
}
} // namespace optimize
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
index 3743c0ce53..706f10b87b 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
@@ -25,6 +25,8 @@ namespace tflite {
namespace optimize {
// Quantizes input_model and populates the provided builder with the new model.
+// By default, only weight tensors with more than 1024 elements will be
+// quantized.
//
// A tflite::Model can be obtained from the builder with:
// const uint8_t* buffer = builder->GetBufferPointer();
@@ -32,11 +34,22 @@ namespace optimize {
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model);
-// Same as above, but if use_hybrid_evaluation is false, will disable using
-// hybrid eval for operations that support it.
+// Same as above, but only weight tensors with at least
+// weights_min_num_elements elements will be quantized.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ uint64_t weights_min_num_elements);
+
+namespace internal {
+// If use_hybrid_evaluation is false, will disable using hybrid eval for
+// operations that support it.
+//
+// We use this internal QuantizeWeights call to test models with hybrid
+// evaluation disabled.
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model,
bool use_hybrid_evaluation);
+} // namespace internal
} // namespace optimize
} // namespace tflite
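A minimal sketch (not part of the change; the 512-element threshold is illustrative, the default remains 1024) of calling the new overload:

// Sketch only: quantize weights, but only for tensors with at least 512
// elements.
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"

TfLiteStatus QuantizeWithLowThreshold(const tflite::Model* input_model,
                                      flatbuffers::FlatBufferBuilder* builder) {
  return tflite::optimize::QuantizeWeights(builder, input_model,
                                           /*weights_min_num_elements=*/512);
}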
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
index efaf9929e9..387b3471c2 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
@@ -76,7 +76,8 @@ class QuantizeWeightsTest : public ::testing::Test {
void CheckWeights(const Model* input_model_packed,
const Model* output_model_packed,
- bool use_hybrid_evaluation) {
+ bool use_hybrid_evaluation,
+ uint64_t weights_min_num_elements = 1024) {
std::unique_ptr<ModelT> input_model;
input_model.reset(input_model_packed->UnPack());
@@ -113,8 +114,9 @@ class QuantizeWeightsTest : public ::testing::Test {
int tensor_size = GetElementsNum(tensor);
// If the tensor_size is less than 1024 we expect the tensor to remain
// unquantized.
- if (tensor_size < 1024) {
- ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name;
+ if (tensor_size < weights_min_num_elements) {
+ ASSERT_TRUE(tensor->type == TensorType_FLOAT32)
+ << tensor->name << " of type " << tensor->type;
const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
// The weight tensor should not come from a dequantize op.
ASSERT_TRUE(preceding_op == nullptr);
@@ -183,7 +185,7 @@ TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) {
flatbuffers::FlatBufferBuilder builder;
// Disable hybrid evaluation.
- EXPECT_EQ(QuantizeWeights(&builder, input_model, false), kTfLiteOk);
+ EXPECT_EQ(internal::QuantizeWeights(&builder, input_model, false), kTfLiteOk);
const uint8_t* buffer = builder.GetBufferPointer();
const Model* output_model = GetModel(buffer);
@@ -191,6 +193,26 @@ TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) {
CheckWeights(input_model, output_model, false);
}
+TEST_F(QuantizeWeightsTest, SimpleTestWithWeightsMinNumElements) {
+ string model_path =
+ "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+ "mobilenet_v1_0.25_128.tflite";
+ std::unique_ptr<FlatBufferModel> input_fb =
+ FlatBufferModel::BuildFromFile(model_path.data());
+ const Model* input_model = input_fb->GetModel();
+
+ flatbuffers::FlatBufferBuilder builder;
+  // Make weights_min_num_elements sufficiently large such that no quantization
+  // should happen, i.e. the output model is the same size as the input model.
+ const uint64_t kWeightsMinNumElements = 1000000;
+ EXPECT_EQ(QuantizeWeights(&builder, input_model, kWeightsMinNumElements),
+ kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+ CheckWeights(input_model, output_model, true, kWeightsMinNumElements);
+}
+
// TODO(suharshs): Add tests that run the resulting model.
} // namespace
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index bbafd59aae..6c203e5519 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -128,12 +128,14 @@ class ElasticAverageCustomGetter(object):
= list(global_center_variable)[i]
return local_var
else:
- return getter(
- name,
- trainable=trainable,
- collections=collections,
- *args,
- **kwargs)
+ kwargs['trainable'] = trainable
+ kwargs['collections'] = collections
+ if ops.GraphKeys.LOCAL_VARIABLES in collections:
+ with ops.device(self._worker_device):
+ return getter(name, *args, **kwargs)
+ else:
+ return getter(name, *args, **kwargs)
+
class ElasticAverageOptimizer(optimizer.Optimizer):
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
index 72117c1e81..f026f437dc 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
@@ -25,9 +25,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
@@ -46,7 +48,12 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
may lead to different empirical results.
"""
- def _apply_sparse(self, grad, var):
+ def _apply_sparse_shared(self,
+ grad,
+ var,
+ indices,
+ scatter_update,
+ scatter_sub):
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
@@ -58,23 +65,51 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
# \\(m := beta1 * m + (1 - beta1) * g_t\\)
m = self.get_slot(var, "m")
- m_t = state_ops.scatter_update(m, grad.indices,
- beta1_t * array_ops.gather(m, grad.indices) +
- (1 - beta1_t) * grad.values,
- use_locking=self._use_locking)
+ m_t = scatter_update(m, indices,
+ beta1_t * array_ops.gather(m, indices) +
+ (1 - beta1_t) * grad)
# \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
v = self.get_slot(var, "v")
- v_t = state_ops.scatter_update(v, grad.indices,
- beta2_t * array_ops.gather(v, grad.indices) +
- (1 - beta2_t) * math_ops.square(grad.values),
- use_locking=self._use_locking)
+ v_t = scatter_update(v, indices,
+ beta2_t * array_ops.gather(v, indices) +
+ (1 - beta2_t) * math_ops.square(grad))
# \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
- m_t_slice = array_ops.gather(m_t, grad.indices)
- v_t_slice = array_ops.gather(v_t, grad.indices)
+ m_t_slice = array_ops.gather(m_t, indices)
+ v_t_slice = array_ops.gather(v_t, indices)
denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
- var_update = state_ops.scatter_sub(var, grad.indices,
- lr * m_t_slice / denominator_slice,
- use_locking=self._use_locking)
+ var_update = scatter_sub(var, indices,
+ lr * m_t_slice / denominator_slice)
return control_flow_ops.group(var_update, m_t, v_t)
+
+ def _apply_sparse(self, grad, var):
+ return self._apply_sparse_shared(
+ grad.values, var, grad.indices,
+ self._scatter_update,
+ self._scatter_sub)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ return self._apply_sparse_shared(
+ grad, var, indices,
+ self._resource_scatter_update,
+ self._resource_scatter_sub)
+
+ # Utility functions for updating resource or non-resource variables.
+ def _scatter_update(self, x, i, v):
+ return state_ops.scatter_update(
+ x, i, v, use_locking=self._use_locking)
+
+ def _scatter_sub(self, x, i, v):
+ return state_ops.scatter_sub(
+ x, i, v, use_locking=self._use_locking)
+
+ def _resource_scatter_update(self, x, i, v):
+ update_op = resource_variable_ops.resource_scatter_update(x.handle, i, v)
+ with ops.control_dependencies([update_op]):
+ return x.value()
+
+ def _resource_scatter_sub(self, x, i, v):
+ sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v)
+ with ops.control_dependencies([sub_op]):
+ return x.value()
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index dc4c462ce4..d3e9e89502 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -51,7 +52,7 @@ def adam_update_numpy(param,
class AdamOptimizerTest(test.TestCase):
- def testSparse(self):
+ def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
# Initialize variables for numpy implementation.
@@ -61,8 +62,12 @@ class AdamOptimizerTest(test.TestCase):
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
constant_op.constant(grads0_np),
@@ -94,6 +99,12 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ def testSparse(self):
+ self.doTestSparse(use_resource=False)
+
+ def testResourceSparse(self):
+ self.doTestSparse(use_resource=True)
+
def testSparseDevicePlacement(self):
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(force_gpu=test.is_gpu_available()):
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
index b6b10e500b..746df77ba2 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
@@ -89,7 +89,13 @@ class ModelAverageCustomGetter(object):
self._local_2_global[local_var] = global_variable
return local_var
else:
- return getter(name, trainable, collections, *args, **kwargs)
+ kwargs['trainable'] = trainable
+ kwargs['collections'] = collections
+ if ops.GraphKeys.LOCAL_VARIABLES in collections:
+ with ops.device(self._worker_device):
+ return getter(name, *args, **kwargs)
+ else:
+ return getter(name, *args, **kwargs)
class ModelAverageOptimizer(optimizer.Optimizer):
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
index 3acd940268..b1fc50a21f 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
@@ -80,28 +80,28 @@ def _get_workers(num_workers, steps, workers):
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
- with ops.device("/job:worker/task:" + str(worker_id)):
- if worker_id == 0:
- grads_0 = constant_op.constant(-1.0)
- grads_1 = constant_op.constant(-1.0)
- else:
- grads_0 = constant_op.constant(-2.0)
- grads_1 = constant_op.constant(-2.0)
- sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
- opt = model_average_optimizer.ModelAverageOptimizer(
- opt=sgd_opt,
- num_worker=num_workers,
- ma_custom_getter=ma_coustom,
- is_chief=is_chief,
- interval_steps=steps)
- train_op = [
- opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
- global_step)
- ]
- easgd_hook = opt.make_session_run_hook()
+ with ops.device("/job:worker/task:" + str(worker_id)):
+ if worker_id == 0:
+ grads_0 = constant_op.constant(-1.0)
+ grads_1 = constant_op.constant(-1.0)
+ else:
+ grads_0 = constant_op.constant(-2.0)
+ grads_1 = constant_op.constant(-2.0)
+ sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ opt = model_average_optimizer.ModelAverageOptimizer(
+ opt=sgd_opt,
+ num_worker=num_workers,
+ ma_custom_getter=ma_coustom,
+ is_chief=is_chief,
+ interval_steps=steps)
+ train_op = [
+ opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
+ global_step)
+ ]
+ ma_hook = opt.make_session_run_hook()
# Creates MonitoredSession
sess = training.MonitoredTrainingSession(
- workers[worker_id].target, hooks=[easgd_hook])
+ workers[worker_id].target, hooks=[ma_hook])
sessions.append(sess)
graphs.append(graph)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 15ce9d1ce7..be0306cb07 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -48,7 +48,7 @@ Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
class RNNCellTest(test.TestCase):
def testLinear(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(1.0)):
x = array_ops.zeros([1, 2])
@@ -69,7 +69,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(len(variables_lib.trainable_variables()), 2)
def testBasicRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -89,7 +89,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellNotTrainable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def not_trainable_getter(getter, *args, **kwargs):
kwargs["trainable"] = False
@@ -116,7 +116,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testIndRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -137,7 +137,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testGRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -165,7 +165,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.156736, 0.156736]])
def testIndyGRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -193,7 +193,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.155127, 0.157328]])
def testSRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -208,7 +208,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.509682, 0.509682]])
def testSRUCellWithDiffSize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -288,7 +288,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCellDimension0Error(self):
"""Tests that dimension 0 in both(x and m) shape must be equal."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
@@ -309,7 +309,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCellStateSizeError(self):
"""Tests that state_size must be num_units * 2."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
@@ -329,7 +329,7 @@ class RNNCellTest(test.TestCase):
})
def testBasicLSTMCellStateTupleType(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -360,7 +360,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))
def testBasicLSTMCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -459,7 +459,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(len(res), 2)
def testLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
num_proj = 6
state_size = num_units + num_proj
@@ -494,7 +494,7 @@ class RNNCellTest(test.TestCase):
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testLSTMCellVariables(self):
- with self.test_session():
+ with self.cached_session():
num_units = 8
num_proj = 6
state_size = num_units + num_proj
@@ -517,7 +517,7 @@ class RNNCellTest(test.TestCase):
"root/lstm_cell/projection/kernel")
def testLSTMCellLayerNorm(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
num_proj = 3
batch_size = 1
@@ -562,22 +562,21 @@ class RNNCellTest(test.TestCase):
rnn_cell_impl.DropoutWrapper,
rnn_cell_impl.ResidualWrapper,
lambda cell: rnn_cell_impl.MultiRNNCell([cell])]:
- with self.test_session():
- cell = rnn_cell_impl.BasicRNNCell(1)
- wrapper = wrapper_type(cell)
- wrapper(array_ops.ones([1, 1]),
- state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
- self.evaluate([v.initializer for v in cell.variables])
- checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
- prefix = os.path.join(self.get_temp_dir(), "ckpt")
- self.evaluate(cell._bias.assign([40.]))
- save_path = checkpoint.save(prefix)
- self.evaluate(cell._bias.assign([0.]))
- checkpoint.restore(save_path).assert_consumed().run_restore_ops()
- self.assertAllEqual([40.], self.evaluate(cell._bias))
+ cell = rnn_cell_impl.BasicRNNCell(1)
+ wrapper = wrapper_type(cell)
+ wrapper(array_ops.ones([1, 1]),
+ state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
+ self.evaluate([v.initializer for v in cell.variables])
+ checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
+ prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ self.evaluate(cell._bias.assign([40.]))
+ save_path = checkpoint.save(prefix)
+ self.evaluate(cell._bias.assign([0.]))
+ checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+ self.assertAllEqual([40.], self.evaluate(cell._bias))
def testOutputProjectionWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -594,7 +593,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.231907, 0.231907]])
def testInputProjectionWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -612,7 +611,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
def testResidualWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -638,7 +637,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[2], res[3])
def testResidualWrapperWithSlice(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 5])
@@ -716,7 +715,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
def testEmbeddingWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 1], dtype=dtypes.int32)
@@ -735,7 +734,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.17139, 0.17139]])
def testEmbeddingWrapperWithDynamicRnn(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root"):
inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
@@ -753,7 +752,7 @@ class RNNCellTest(test.TestCase):
sess.run(outputs)
def testMultiRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -770,7 +769,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])
def testMultiRNNCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -809,7 +808,7 @@ class DropoutWrapperTest(test.TestCase):
time_steps=None,
parallel_iterations=None,
**kwargs):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
if batch_size is None and time_steps is None:
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
index 3c616c555b..ea4d41d43b 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
@@ -30,6 +30,7 @@ cc_library(
hdrs = ["signature_def_utils.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
@@ -42,6 +43,7 @@ tf_cc_test(
srcs = ["signature_def_utils_test.cc"],
deps = [
":signature_def_utils",
+ "//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
index a45908d272..e87e497e5f 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+#include "tensorflow/cc/saved_model/signature_constants.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -33,6 +35,79 @@ Status FindInProtobufMap(StringPiece description,
*value = &it->second;
return Status::OK();
}
+
+// Looks up the TensorInfo for the given key in the given map and verifies
+// that its dtype matches `correct_dtype`.
+bool VerifyTensorInfoForKeyInMap(const protobuf::Map<string, TensorInfo>& map,
+ const string& key, DataType correct_dtype) {
+ const TensorInfo* tensor_info;
+ const Status& status = FindInProtobufMap("", map, key, &tensor_info);
+ if (!status.ok()) {
+ return false;
+ }
+ if (tensor_info->dtype() != correct_dtype) {
+ return false;
+ }
+ return true;
+}
+
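+// Returns true if `signature_def` uses the Predict method and declares at
+// least one input and one output.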
+bool IsValidPredictSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kPredictMethodName) {
+ return false;
+ }
+ if (signature_def.inputs().empty()) {
+ return false;
+ }
+ if (signature_def.outputs().empty()) {
+ return false;
+ }
+ return true;
+}
+
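+// Returns true if `signature_def` uses the Regress method with a DT_STRING
+// regression input and a DT_FLOAT regression output.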
+bool IsValidRegressionSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kRegressMethodName) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kRegressInputs,
+ DT_STRING)) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.outputs(), kRegressOutputs,
+ DT_FLOAT)) {
+ return false;
+ }
+ return true;
+}
+
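+// Returns true if `signature_def` uses the Classify method with a DT_STRING
+// classification input and only classes (DT_STRING) and/or scores (DT_FLOAT)
+// outputs.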
+bool IsValidClassificationSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kClassifyMethodName) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kClassifyInputs,
+ DT_STRING)) {
+ return false;
+ }
+ if (signature_def.outputs().empty()) {
+ return false;
+ }
+ for (auto const& output : signature_def.outputs()) {
+ const string& key = output.first;
+ const TensorInfo& tensor_info = output.second;
+ if (key == kClassifyOutputClasses) {
+ if (tensor_info.dtype() != DT_STRING) {
+ return false;
+ }
+ } else if (key == kClassifyOutputScores) {
+ if (tensor_info.dtype() != DT_FLOAT) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
} // namespace
Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def,
@@ -74,4 +149,10 @@ Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
return Status::OK();
}
+bool IsValidSignature(const SignatureDef& signature_def) {
+ return IsValidClassificationSignature(signature_def) ||
+ IsValidRegressionSignature(signature_def) ||
+ IsValidPredictSignature(signature_def);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
index b732cdd41e..bb24faa989 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
@@ -64,6 +64,9 @@ Status FindInputTensorNameByKey(const SignatureDef& signature_def,
Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
const string& tensor_info_key, string* name);
+// Determines whether a SignatureDef can be served by TensorFlow Serving.
+bool IsValidSignature(const SignatureDef& signature_def);
+
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
index a063e95696..c743112ce0 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -22,7 +23,7 @@ limitations under the License.
namespace tensorflow {
-class SignatureDefUtilsTest : public ::testing::Test {
+class FindByKeyTest : public ::testing::Test {
protected:
MetaGraphDef MakeSampleMetaGraphDef() {
MetaGraphDef result;
@@ -32,13 +33,23 @@ class SignatureDefUtilsTest : public ::testing::Test {
return result;
}
+ void SetInputNameForKey(const string& key, const string& name,
+ SignatureDef* signature_def) {
+ (*signature_def->mutable_inputs())[key].set_name(name);
+ }
+
+ void SetOutputNameForKey(const string& key, const string& name,
+ SignatureDef* signature_def) {
+ (*signature_def->mutable_outputs())[key].set_name(name);
+ }
+
SignatureDef MakeSampleSignatureDef() {
SignatureDef result;
result.set_method_name(kMethodName);
- (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name);
- (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name);
- (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name);
- (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name);
+ SetInputNameForKey(kInput1Key, kInput1Name, &result);
+ SetInputNameForKey(kInput2Key, kInput2Name, &result);
+ SetOutputNameForKey(kOutput1Key, kOutput1Name, &result);
+ SetOutputNameForKey(kOutput2Key, kOutput2Name, &result);
return result;
}
@@ -54,7 +65,7 @@ class SignatureDefUtilsTest : public ::testing::Test {
const string kOutput2Name = "output_two";
};
-TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
+TEST_F(FindByKeyTest, FindSignatureDefByKey) {
const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef();
const SignatureDef* signature_def;
// Succeeds for an existing signature.
@@ -67,7 +78,7 @@ TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
.ok());
}
-TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindInputTensorNameByKey) {
const SignatureDef signature_def = MakeSampleSignatureDef();
string name;
// Succeeds for an existing input.
@@ -78,7 +89,7 @@ TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok());
}
-TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindOutputTensorNameByKey) {
const SignatureDef signature_def = MakeSampleSignatureDef();
string name;
// Succeeds for an existing output.
@@ -89,4 +100,100 @@ TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok());
}
+class IsValidSignatureTest : public ::testing::Test {
+ protected:
+ void SetInputDataTypeForKey(const string& key, DataType dtype) {
+ (*signature_def_.mutable_inputs())[key].set_dtype(dtype);
+ }
+
+ void SetOutputDataTypeForKey(const string& key, DataType dtype) {
+ (*signature_def_.mutable_outputs())[key].set_dtype(dtype);
+ }
+
+ void EraseOutputKey(const string& key) {
+ (*signature_def_.mutable_outputs()).erase(key);
+ }
+
+ void ExpectInvalidSignature() {
+ EXPECT_FALSE(IsValidSignature(signature_def_));
+ }
+
+ void ExpectValidSignature() { EXPECT_TRUE(IsValidSignature(signature_def_)); }
+
+ SignatureDef signature_def_;
+};
+
+TEST_F(IsValidSignatureTest, IsValidPredictSignature) {
+ signature_def_.set_method_name("not_kPredictMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kPredictMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kPredictInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kPredictOutputs, DT_STRING);
+ ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidRegressionSignature) {
+ signature_def_.set_method_name("not_kRegressMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kRegressMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kRegressInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kRegressOutputs, DT_STRING);
+ // Incorrect data type
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kRegressOutputs, DT_FLOAT);
+ ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidClassificationSignature) {
+ signature_def_.set_method_name("not_kClassifyMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kClassifyMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kClassifyInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey("invalidKey", DT_FLOAT);
+ // Invalid key
+ ExpectInvalidSignature();
+
+ EraseOutputKey("invalidKey");
+ SetOutputDataTypeForKey(kClassifyOutputClasses, DT_FLOAT);
+ // Invalid dtype for classes
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputClasses, DT_STRING);
+ // Valid without scores
+ ExpectValidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputScores, DT_STRING);
+ // Invalid dtype for scores
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputScores, DT_FLOAT);
+ // Valid with both classes and scores
+ ExpectValidSignature();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 652f709fe2..00c855daa3 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -462,7 +462,10 @@ py_test(
size = "small",
srcs = ["python/kernel_tests/scatter_add_ndim_op_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip_gpu"],
+ tags = [
+ "no_gpu",
+ "no_pip_gpu",
+ ],
deps = [
":tensor_forest_ops_py",
"//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index 2b13343efa..f88dc51636 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -79,12 +79,15 @@ message StepInfoResult {
// The step duration in picoseconds.
optional uint64 duration_ps = 2;
// The infeed duration in picoseconds.
- // Can turn into a map if we want a variable number of ops.
optional uint64 infeed_duration_ps = 3;
+ // The outfeed duration in picoseconds.
+ optional uint64 host_outfeed_ps = 8;
// The start time of this step in picoseconds.
optional uint64 begin_ps = 4;
// The waiting time within this step in picoseconds.
optional uint64 wait_duration_ps = 5;
+ // The unit b outfeed duration in picoseconds.
+ optional uint64 unit_b_outfeed_ps = 9;
// The time spent on cross-replica-sum in picoseconds.
optional uint64 crs_duration_ps = 6;
// Percentage of unit b time spent on infeed.
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index bf807af68b..cbf6809257 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -18,8 +18,10 @@ message DynamicLearningRate {
message LearningRate {
oneof learning_rate {
float constant = 1;
- DynamicLearningRate dynamic = 2;
+ // DynamicLearningRate dynamic = 2; -- disabled while code is being
+ // rewritten.
}
+ reserved 2;
}
message AdagradParameters {
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index ff88508d03..dd7f8b678f 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -170,11 +170,41 @@ class TPUDistributionStrategy(object):
worker_re = re.compile('/job:([^/]+)')
for device in metadata.devices:
if 'TPU:0' in device.name:
- self.worker_name = worker_re.search(device.name).group(1)
+ self._worker_name = worker_re.search(device.name).group(1)
break
+ def _make_assignment_for_model(self, cpu_model):
+ """Makes a `TPUAssignment` for the passed in `cpu_model`."""
+ num_cores = self._num_cores
+ if num_cores > 1 and cpu_model.stateful:
+ logging.warning(
+ 'Model replication does not currently support stateful models. '
+ 'Degrading to a single core.')
+ num_cores = 1
+
+ return TPUAssignment(
+ worker_name=self._worker_name, num_cores=num_cores)
+
+
+class TPUAssignment(object):
+  """Holds the TPU resources assignment for a concrete model.
+
+  `TPUDistributionStrategy` is responsible for creating the instance of
+  `TPUAssignment`, so it can dynamically adjust the `num_cores` to use based
+  on the model and input batch sizes.
+ """
+
+ def __init__(self, worker_name, num_cores):
+ self._worker_name = worker_name
+ self._num_cores = num_cores
+
+ @property
+ def worker_name(self):
+ return self._worker_name
+
@property
def num_towers(self):
+    # TODO(xiejw): Support automatically assigning num_cores based on inputs.
return self._num_cores
@@ -495,8 +525,8 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_dict[tensor] = value
return infeed_dict
- def __init__(self, distribution_strategy):
- self._strategy = distribution_strategy
+ def __init__(self, tpu_assignment):
+ self._tpu_assignment = tpu_assignment
def _split_tensors(self, inputs):
"""Split input data across shards.
@@ -509,16 +539,16 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
Returns:
List of lists containing the input to feed to each TPU shard.
"""
- if self._strategy.num_towers == 1:
+ if self._tpu_assignment.num_towers == 1:
return [inputs]
batch_size = inputs[0].shape[0]
- assert batch_size % self._strategy.num_towers == 0, (
- 'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
- (batch_size, self._strategy.num_towers))
- shard_size = batch_size // self._strategy.num_towers
+ assert batch_size % self._tpu_assignment.num_towers == 0, (
+ 'batch_size must be divisible by the number of TPU cores in use (%s '
+ 'vs %s)' % (batch_size, self._tpu_assignment.num_towers))
+ shard_size = batch_size // self._tpu_assignment.num_towers
input_list = []
- for index in range(self._strategy.num_towers):
+ for index in range(self._tpu_assignment.num_towers):
shard_inputs = [
x[index * shard_size:(index + 1) * shard_size] for x in inputs
]
@@ -533,8 +563,9 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_op = []
shard_infeed_tensors = []
- for shard_id in range(self._strategy.num_towers):
- with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
+ for shard_id in range(self._tpu_assignment.num_towers):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
infeed_tensors = []
with ops.device('/device:TPU:%d' % shard_id):
for spec in input_specs:
@@ -573,30 +604,31 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
# TODO(saeta): Verify tpu_model_op is as expected!
return {}
- def __init__(self, dataset, distribution_strategy, tpu_session):
+ # pylint: disable=redefined-outer-name
+ def __init__(self, dataset, tpu_assignment, tpu_session):
"""Constructs a TPUDatasetInfeedManager.
Must be called within a `KerasTPUModel.tpu_session` context!
Args:
dataset: A `tf.data.Dataset` to infeed.
- distribution_strategy: The `TPUDistributionStrategy` used to configure the
+ tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
tpu_session: The `tf.Session` object used for running the TPU model.
"""
self._verify_dataset_shape(dataset)
self._dataset = dataset
- self._strategy = distribution_strategy
+ self._tpu_assignment = tpu_assignment
dummy_x_shape = dataset.output_shapes[0].as_list()
- dummy_x_shape[0] *= distribution_strategy.num_towers
+ dummy_x_shape[0] *= tpu_assignment.num_towers
dummy_y_shape = dataset.output_shapes[1].as_list()
- dummy_y_shape[0] *= distribution_strategy.num_towers
+ dummy_y_shape[0] *= tpu_assignment.num_towers
self._iterator = dataset.make_initializable_iterator()
tpu_session.run(self._iterator.initializer)
self._get_next_ops = []
ctrl_deps = []
- for i in range(distribution_strategy.num_towers):
+ for i in range(tpu_assignment.num_towers):
with ops.control_dependencies(ctrl_deps): # Ensure deterministic
# TODO(saeta): Ensure correct placement!
get_next_op = self._iterator.get_next()
@@ -676,10 +708,11 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
def build_infeed_from_input_specs(self, input_specs, execution_mode):
shard_infeed_tensors = self._get_next_ops
- assert len(shard_infeed_tensors) == self._strategy.num_towers
+ assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers
infeed_ops = []
- for shard_id in range(self._strategy.num_towers):
- with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
+ for shard_id in range(self._tpu_assignment.num_towers):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
infeed_ops.append(
tpu_ops.infeed_enqueue_tuple(
shard_infeed_tensors[shard_id],
@@ -702,10 +735,10 @@ class TPUFunction(object):
instead of being injected as `feed_dict` items or fetches.
"""
- def __init__(self, model, execution_mode, strategy):
+ def __init__(self, model, execution_mode, tpu_assignment):
self.model = model
self.execution_mode = execution_mode
- self._strategy = strategy
+ self._tpu_assignment = tpu_assignment
self._compilation_cache = {}
self._cloned_model = None
@@ -757,7 +790,8 @@ class TPUFunction(object):
# Clone our CPU model, running within the TPU device context.
with TPURewriteContext(tpu_input_map):
with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
- with keras_tpu_variables.replicated_scope(self._strategy.num_towers):
+ with keras_tpu_variables.replicated_scope(
+ self._tpu_assignment.num_towers):
self._cloned_model = models.clone_model(self.model)
# Create a copy of the optimizer for this graph.
@@ -827,7 +861,7 @@ class TPUFunction(object):
# `execute op` replicates `_model_fn` `num_replicas` times, with each shard
# running on a different logical core.
compile_op, execute_op = tpu.split_compile_and_replicate(
- _model_fn, inputs=[[]] * self._strategy.num_towers)
+ _model_fn, inputs=[[]] * self._tpu_assignment.num_towers)
# Generate CPU side operations to enqueue features/labels and dequeue
# outputs from the model call.
@@ -835,8 +869,9 @@ class TPUFunction(object):
input_specs, self.execution_mode)
# Build output ops.
outfeed_op = []
- for shard_id in range(self._strategy.num_towers):
- with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
+ for shard_id in range(self._tpu_assignment.num_towers):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
outfeed_op.extend(
tpu_ops.outfeed_dequeue_tuple(
dtypes=[spec.dtype for spec in self._outfeed_spec],
@@ -886,7 +921,7 @@ class TPUFunction(object):
for x, mgr in self.model._numpy_to_infeed_manager_list:
if inputs[0] is x:
return mgr
- return TPUNumpyInfeedManager(self.model._strategy)
+ return TPUNumpyInfeedManager(self.model._tpu_assignment)
def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
"""Looks up the corresponding `TPUModelOp` for a given `input_specs`.
@@ -958,7 +993,7 @@ class TPUFunction(object):
outputs = [[]] * len(self._outfeed_spec)
outputs_per_replica = len(self._outfeed_spec)
- for i in range(self._strategy.num_towers):
+ for i in range(self._tpu_assignment.num_towers):
output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
outputs_per_replica]
for j in range(outputs_per_replica):
@@ -967,7 +1002,7 @@ class TPUFunction(object):
return [np.concatenate(group) for group in outputs]
else:
return outfeed_outputs[:len(outfeed_outputs) //
- self._strategy.num_towers]
+ self._tpu_assignment.num_towers]
def __call__(self, inputs):
"""__call__ executes the function on the computational hardware.
@@ -1119,11 +1154,11 @@ class KerasTPUModel(models.Model):
self.predict_function = None
self.test_function = None
self.train_function = None
- self._strategy = strategy
- cluster_resolver = self._strategy._tpu_cluster_resolver
+ cluster_resolver = strategy._tpu_cluster_resolver
self._tpu_name_or_address = cluster_resolver.get_master()
self._cpu_model = cpu_model
+ self._tpu_assignment = strategy._make_assignment_for_model(cpu_model)
self._tpu_model = None
self._tpu_weights_initialized = False
@@ -1146,7 +1181,7 @@ class KerasTPUModel(models.Model):
return {
'cpu_model': self._cpu_model,
'tpu_name_or_address': self._tpu_name_or_address,
- 'strategy': self._strategy,
+ 'tpu_assignment': self._tpu_assignment,
}
def compile(self,
@@ -1207,7 +1242,7 @@ class KerasTPUModel(models.Model):
'/keras')
if callable(x):
with self.tpu_session() as sess,\
- ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
+ ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
dataset = x()
if steps_per_epoch is None:
raise ValueError('When using tf.data as input to a model, you '
@@ -1215,7 +1250,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -1236,7 +1272,8 @@ class KerasTPUModel(models.Model):
if validation_steps is None:
raise ValueError('When using tf.data as validation for a model, you '
'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
val_x = infeed_manager.dummy_x
@@ -1313,7 +1350,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -1740,20 +1778,24 @@ class KerasTPUModel(models.Model):
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
- self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy)
+ self,
+ model_fn_lib.ModeKeys.TRAIN,
+ tpu_assignment=self._tpu_assignment)
return self.train_function
def _make_test_function(self):
if not self.test_function:
self.test_function = TPUFunction(
- self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy)
+ self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment)
return self.test_function
def _make_predict_function(self):
if not self.predict_function:
self.predict_function = TPUFunction(
- self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy)
+ self,
+ model_fn_lib.ModeKeys.PREDICT,
+ tpu_assignment=self._tpu_assignment)
return self.predict_function
def _initialize_weights(self, cloned_model):
@@ -1825,6 +1867,7 @@ class KerasTPUModel(models.Model):
self._session.close()
+# pylint: disable=bad-continuation
def _validate_shapes(model):
"""Validate that all layers in `model` have constant shape."""
for layer in model.layers:
@@ -1852,10 +1895,13 @@ Layer: %(layer)s
Input shape: %(input_shape)s
Output shape: %(output_shape)s
""" % {
- 'layer': layer,
- 'input_shape': layer.input_shape,
- 'output_shape': layer.output_shape
- })
+ 'layer': layer,
+ 'input_shape': layer.input_shape,
+ 'output_shape': layer.output_shape
+ })
+
+
+# pylint: enable=bad-continuation
@experimental
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5c314f359c..c06fea130f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -695,6 +695,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":lib_internal",
+ "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
@@ -3220,7 +3221,6 @@ tf_cc_tests(
"lib/gtl/edit_distance_test.cc",
"lib/gtl/flatmap_test.cc",
"lib/gtl/flatset_test.cc",
- "lib/gtl/inlined_vector_test.cc",
"lib/gtl/int_type_test.cc",
"lib/gtl/iterator_range_test.cc",
"lib/gtl/manual_constructor_test.cc",
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 3bf0532491..84c6285bbe 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -596,7 +596,7 @@ string BFCAllocator::RenderOccupancy() {
region_offset += region.memory_size();
}
- return std::string(rendered, resolution);
+ return string(rendered, resolution);
}
void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 39a3b49cd1..879a794368 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -36,22 +36,34 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
EagerContext::EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy,
- bool async, std::unique_ptr<DeviceMgr> device_mgr,
+ bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
Rendezvous* rendezvous)
+ : EagerContext(opts, default_policy, async, device_mgr.release(),
+ /*device_mgr_owned*/ true, rendezvous) {}
+
+EagerContext::EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy,
+ bool async, const DeviceMgr* device_mgr,
+ bool device_mgr_owned, Rendezvous* rendezvous)
: policy_(default_policy),
- local_device_manager_(std::move(device_mgr)),
- local_unowned_device_manager_(nullptr),
- devices_(local_device_manager_->ListDevices()),
+ devices_(device_mgr->ListDevices()),
rendezvous_(rendezvous),
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
pflr_(new ProcessFunctionLibraryRuntime(
- local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
- &func_lib_def_, {}, thread_pool_.get())),
+ device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
+ thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
num_active_steps_(0),
async_default_(async),
env_(opts.env),
use_send_tensor_rpc_(false) {
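+  // Either take ownership of the provided DeviceMgr or keep an unowned
+  // pointer to it, depending on `device_mgr_owned`.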
+ if (device_mgr_owned) {
+ local_device_manager_.reset(device_mgr);
+ local_unowned_device_manager_ = nullptr;
+ } else {
+ local_unowned_device_manager_ = device_mgr;
+ }
InitDeviceMapAndAsync();
if (opts.config.inter_op_parallelism_threads() > 0) {
runner_ = [this](std::function<void()> closure) {
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 3c95ac590d..eb6eb0d55a 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -65,10 +65,17 @@ enum ContextDevicePlacementPolicy {
class EagerContext {
public:
- explicit EagerContext(const SessionOptions& opts,
- ContextDevicePlacementPolicy default_policy, bool async,
- std::unique_ptr<DeviceMgr> device_mgr,
- Rendezvous* rendezvous);
+ // TODO: remove this constructor once we migrate all callers to the next one.
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
+ Rendezvous* rendezvous);
+
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ const DeviceMgr* device_mgr, bool device_mgr_owned,
+ Rendezvous* rendezvous);
+
~EagerContext();
// Returns the function library runtime for the given device.
@@ -207,8 +214,8 @@ class EagerContext {
thread_local_policies_ GUARDED_BY(policy_map_mu_);
// Only one of the below is set.
- std::unique_ptr<DeviceMgr> local_device_manager_;
- DeviceMgr* local_unowned_device_manager_;
+ std::unique_ptr<const DeviceMgr> local_device_manager_;
+ const DeviceMgr* local_unowned_device_manager_;
std::unique_ptr<DeviceMgr> remote_device_manager_;
// Devices owned by device_manager
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index 0a1797fa19..f9aef3af70 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -56,7 +56,7 @@ class SimpleRendezvous : public Rendezvous {
}
mutex_lock l(mu_);
- string edge_name = std::string(parsed.edge_name);
+ string edge_name(parsed.edge_name);
if (table_.count(edge_name) > 0) {
return errors::Internal("Send of an already sent tensor");
}
@@ -69,7 +69,7 @@ class SimpleRendezvous : public Rendezvous {
Tensor tensor;
Status status = Status::OK();
{
- string key = std::string(parsed.edge_name);
+ string key(parsed.edge_name);
mutex_lock l(mu_);
if (table_.count(key) <= 0) {
status = errors::Internal("Did not find key ", key);
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 7f3c25d81d..3b59995433 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -254,9 +254,11 @@ class ColocationGraph {
old_root_member.device_name,
allow_soft_placement_);
if (!s.ok()) {
- return errors::InvalidArgument("Cannot colocate nodes '", x.name(),
- "' and '", y.name(), ": ",
- s.error_message());
+ return errors::InvalidArgument(
+ "Cannot colocate nodes ",
+ errors::FormatColocationNodeForError(x.name()), " and ",
+ errors::FormatColocationNodeForError(y.name()), ": ",
+ s.error_message());
}
// Ensure that the common root has at least one supported device
@@ -267,8 +269,10 @@ class ColocationGraph {
old_root_member.supported_device_types);
if (new_root_member.supported_device_types.empty()) {
return errors::InvalidArgument(
- "Cannot colocate nodes '", x.name(), "' and '", y.name(),
- "' because no device type supports both of those nodes and the "
+ "Cannot colocate nodes ",
+ errors::FormatColocationNodeForError(x.name()), " and ",
+ errors::FormatColocationNodeForError(y.name()),
+ " because no device type supports both of those nodes and the "
"other nodes colocated with them.",
DebugInfo(x_root), DebugInfo(y_root));
}
@@ -376,8 +380,9 @@ class ColocationGraph {
// merged set device is different, so print both.
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
- node->requested_device(),
- "' because the node was colocated with a group of nodes that "
+ node->requested_device(), "' because the node ",
+ errors::FormatColocationNodeForError(node->name()),
+ " was colocated with a group of nodes that ",
"required incompatible device '",
DeviceNameUtils::ParsedNameToString(
members_[node_root].device_name),
@@ -809,10 +814,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(errors::InvalidArgument(
- "Cannot assign a device for operation ",
- RichNodeName(node), ": ", status.error_message()),
- *node);
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device for operation ",
+ node->name(), ": ", status.error_message()),
+ *node);
}
// Returns the first device in sorted devices list so we will always
@@ -856,10 +861,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(errors::InvalidArgument(
- "Cannot assign a device for operation ",
- RichNodeName(node), ": ", status.error_message()),
- *node);
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device for operation ",
+ node->name(), ": ", status.error_message()),
+ *node);
}
int assigned_device = -1;
@@ -925,21 +930,4 @@ void Placer::LogDeviceAssignment(const Node* node) const {
}
}
-bool Placer::ClientHandlesErrorFormatting() const {
- return options_ != nullptr &&
- options_->config.experimental().client_handles_error_formatting();
-}
-
-// Returns the node name in single quotes. If the client handles formatted
-// errors, appends a formatting tag which the client will reformat into, for
-// example, " (defined at filename:123)".
-// TODO(shikharagarwal): Remove this function once
-// client_handles_error_formatting flag is removed.
-string Placer::RichNodeName(const Node* node) const {
- if (ClientHandlesErrorFormatting()) {
- return errors::FormatNodeNameForError(node->name());
- }
- return strings::StrCat("'", node->name(), "'");
-}
-
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h
index cefcdd25db..f97ffe7372 100644
--- a/tensorflow/core/common_runtime/placer.h
+++ b/tensorflow/core/common_runtime/placer.h
@@ -87,8 +87,6 @@ class Placer {
// placement if the SessionOptions entry in 'options_' requests it.
void AssignAndLog(int assigned_device, Node* node) const;
void LogDeviceAssignment(const Node* node) const;
- bool ClientHandlesErrorFormatting() const;
- string RichNodeName(const Node* node) const;
Graph* const graph_; // Not owned.
const DeviceSet* const devices_; // Not owned.
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 83d27e2730..9b8a95e3b6 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -800,11 +800,11 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) {
}
Status s = Place(&g);
- EXPECT_TRUE(
- str_util::StrContains(s.error_message(),
- "Cannot colocate nodes 'foo' and 'in' because no "
- "device type supports both of those nodes and the "
- "other nodes colocated with them"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Cannot colocate nodes {{colocation_node foo}} and "
+ "{{colocation_node in}} because no device type supports both of those "
+ "nodes and the other nodes colocated with them"));
}
TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
@@ -867,9 +867,9 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) {
Status s = Place(&g);
EXPECT_TRUE(str_util::StrContains(
s.error_message(),
- "Cannot colocate nodes 'var3' and 'assign3' because no "
- "device type supports both of those nodes and the other "
- "nodes colocated with them."));
+ "Cannot colocate nodes {{colocation_node var3}} and {{colocation_node "
+ "assign3}} because no device type supports both of those nodes and the "
+ "other nodes colocated with them."));
}
TEST_F(PlacerTest, TestColocationAndReferenceConnections) {
@@ -1154,35 +1154,12 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) {
}
SessionOptions options;
- options.config.mutable_experimental()->set_client_handles_error_formatting(
- true);
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
LOG(WARNING) << s.error_message();
- EXPECT_TRUE(str_util::StrContains(
- s.error_message(), "Cannot assign a device for operation {{node in}}"));
-}
-
-// Test that the "Cannot assign a device" error message does not contain a
-// format tag when not it shouldn't
-TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) {
- Graph g(OpRegistry::Global());
- { // Scope for temporary variables used to construct g.
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
- ops::SourceOp("TestDevice",
- b.opts().WithName("in").WithDevice("/device:fakegpu:11"));
- TF_EXPECT_OK(BuildGraph(b, &g));
- }
-
- SessionOptions options;
- options.config.mutable_experimental()->set_client_handles_error_formatting(
- false);
- Status s = Place(&g, &options);
- EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(str_util::StrContains(
- s.error_message(), "Cannot assign a device for operation 'in'"));
- EXPECT_FALSE(str_util::StrContains(
- s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "Cannot assign a device for operation in"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "{{node in}}"));
}
// Test that placement fails when a node requests an explicit device that is not
@@ -1288,8 +1265,9 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(str_util::StrContains(
- s.error_message(), "Cannot colocate nodes 'var' and 'assign'"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "Cannot colocate nodes {{colocation_node "
+ "var}} and {{colocation_node assign}}"));
}
// Test that a generator node follows its consumers (where there are several
diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index 10a24ed14c..fdad8de8d6 100644
--- a/tensorflow/core/common_runtime/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc
index 65ff356e73..5b1915755d 100644
--- a/tensorflow/core/common_runtime/session_state.cc
+++ b/tensorflow/core/common_runtime/session_state.cc
@@ -70,7 +70,7 @@ Status TensorStore::SaveTensors(const std::vector<string>& output_names,
// Save only the tensors in output_names in the session.
for (const string& name : output_names) {
TensorId id(ParseTensorName(name));
- const string& op_name = std::string(id.first);
+ const string op_name(id.first);
auto it = tensors_.find(op_name);
if (it != tensors_.end()) {
// Save the tensor to the session state.
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 9c2510e6a9..836cb8ed14 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -176,7 +176,7 @@ static int ExtractGpuWithStreamAll(string device_name) {
} else {
// Convert the captured string into an integer. But first we need to put
// the digits back in order
- string ordered_capture = std::string(capture);
+ string ordered_capture(capture);
std::reverse(ordered_capture.begin(), ordered_capture.end());
int gpu_id;
CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -205,7 +205,7 @@ static int ExtractGpuWithoutStream(string device_name) {
} else {
// Convert the captured string into an integer. But first we need to put
// the digits back in order
- string ordered_capture = std::string(capture);
+ string ordered_capture(capture);
std::reverse(ordered_capture.begin(), ordered_capture.end());
int gpu_id;
CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -252,7 +252,7 @@ void StepStatsCollector::BuildCostModel(
for (auto& itr : per_device_stats) {
const StringPiece device_name = itr.first;
- const int gpu_id = ExtractGpuWithoutStream(std::string(device_name));
+ const int gpu_id = ExtractGpuWithoutStream(string(device_name));
if (gpu_id >= 0) {
// Reference the gpu hardware stats in addition to the regular stats
// for this gpu device if they're available.
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index ea7788f654..0a38aa1c91 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) {
return ret;
}
+Node* CheckNumerics(Graph* g, Node* in, const string& message) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
+ .Input(in)
+ .Attr("message", message)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Arg(Graph* g, int64 index, DataType type) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
+ .Attr("T", type)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Retval(Graph* g, int64 index, Node* in) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
+ .Input(in)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
} // end namespace graph
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 8585b35a19..bd0284d43a 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type);
// Add a DiagPart node in "g".
Node* DiagPart(Graph* g, Node* in, DataType type);
+// Add a CheckNumerics node in "g".
+Node* CheckNumerics(Graph* g, Node* in, const string& message);
+
+// Add an _Arg node in "g".
+Node* Arg(Graph* g, int64 index, DataType type);
+
+// Add a _Retval node in "g".
+Node* Retval(Graph* g, int64 index, Node* in);
+
} // end namespace graph
} // end namespace test
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 653b088b1d..e78239bd43 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -135,16 +135,37 @@ bool IsDequeueOp(const NodeDef& node) {
bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
-bool IsElementWiseMonotonic(const NodeDef& node) {
- static const std::unordered_set<string>* element_wise_monotonic_ops =
+// Returns true if node represents a unary elementwise function that is
+// monotonic. If *is_non_decreasing is true, the function is non-decreasing,
+// e.g. sqrt, exp. If *is_non_decreasing is false, the function is
+// non-increasing, e.g. inv.
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
+ static const std::unordered_set<string>* monotonic_non_decreasing_ops =
CHECK_NOTNULL((new std::unordered_set<string>{
- "Relu",
- "Relu6",
- "Sigmoid",
- "Sqrt",
- "Tanh",
+ "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
+          "Floor", "Log", "Log1p", "Relu", "Relu6", "Rint",
+ "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
+ }));
+ static const std::unordered_set<string>* monotonic_non_increasing_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "Inv",
+ "Reciprocal",
+ "Erfc",
+ "Rsqrt",
+ "Neg",
}));
- return element_wise_monotonic_ops->count(node.op()) > 0;
+ if (monotonic_non_decreasing_ops->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = true;
+ }
+ return true;
+ } else if (monotonic_non_increasing_ops->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = false;
+ }
+ return true;
+ }
+ return false;
}
bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 94439265c9..25ab6b65ac 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -55,7 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
-bool IsElementWiseMonotonic(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 4fed88d536..65947ddce5 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2706,8 +2706,9 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
// 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
+ bool is_non_decreasing = false;
if (!IsInPreserveSet(*inner_function) &&
- IsElementWiseMonotonic(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
@@ -2719,7 +2720,12 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
UpdateConsumers(reduction_node, inner_function->name());
ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
reduction_node->name());
-
+ if (!is_non_decreasing) {
+ // Flip Min<->Max if the function is non-increasing, e.g.
+ // Max(Neg(x)) = Neg(Min(x)).
+ const string opposite = IsMax(*reduction_node) ? "Min" : "Max";
+ reduction_node->set_op(opposite);
+ }
AddToOptimizationQueue(reduction_node);
AddToOptimizationQueue(inner_function);
AddToOptimizationQueue(inner_input);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index bfccc0affd..39517edc06 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -3248,6 +3248,48 @@ TEST_F(ArithmeticOptimizerTest,
VerifyGraphsMatch(item.graph, output, __LINE__);
}
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output neg = ops::Neg(s.WithOpName("neg"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ // Check if the inputs are switched
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "neg") {
+ EXPECT_EQ("Neg", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("reduce_max", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "reduce_max") {
+ EXPECT_EQ("Min", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+}
+
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index a2c363ea6e..a428aea7f5 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -304,21 +304,21 @@ Status GrapplerFunctionItemInstantiation::GetArgType(
}
GrapplerFunctionItem::GrapplerFunctionItem(
- const string& func_name, const string& description,
- const AttrValueMap& func_attr,
- const std::vector<InputArgExpansion>& input_arg_expansions,
- const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, const int graph_def_version,
- bool is_stateful, GraphDef&& function_body)
- : description_(description),
- func_attr_(func_attr),
- input_arg_expansions_(input_arg_expansions),
- output_arg_expansions_(output_arg_expansions),
+ string func_name, string description, AttrValueMap func_attr,
+ std::vector<InputArgExpansion> input_arg_expansions,
+ std::vector<OutputArgExpansion> output_arg_expansions,
+ std::vector<string> keep_nodes, const int graph_def_version,
+ const bool is_stateful, GraphDef&& function_body)
+ : description_(std::move(description)),
+ func_attr_(std::move(func_attr)),
+ input_arg_expansions_(std::move(input_arg_expansions)),
+ output_arg_expansions_(std::move(output_arg_expansions)),
is_stateful_(is_stateful) {
- id = func_name;
- keep_ops = keep_nodes;
- // Swap the graph body.
- graph.Swap(&function_body);
+ // Move assign GrapplerItem members.
+ keep_ops = std::move(keep_nodes);
+ id = std::move(func_name);
+ graph = std::move(function_body);
+
graph.mutable_versions()->set_producer(graph_def_version);
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
@@ -598,8 +598,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
*item = GrapplerFunctionItem(
/*func_name=*/signature.name(), /*description=*/signature.description(),
/*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()),
- inputs, outputs, keep_nodes, graph_def_version, is_stateful,
- std::move(function_body));
+ std::move(inputs), std::move(outputs), std::move(keep_nodes),
+ graph_def_version, is_stateful, std::move(function_body));
return Status::OK();
}
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 61588ceb83..733caf325f 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -136,13 +136,12 @@ class GrapplerFunctionItemInstantiation {
class GrapplerFunctionItem : public GrapplerItem {
public:
GrapplerFunctionItem() = default;
- GrapplerFunctionItem(
- const string& func_name, const string& description,
- const AttrValueMap& func_attr,
- const std::vector<InputArgExpansion>& input_arg_expansions,
- const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, const int versions,
- bool is_stateful, GraphDef&& function_body);
+ GrapplerFunctionItem(string func_name, string description,
+ AttrValueMap func_attr,
+ std::vector<InputArgExpansion> input_arg_expansions,
+ std::vector<OutputArgExpansion> output_arg_expansions,
+ std::vector<string> keep_nodes, int graph_def_version,
+ bool is_stateful, GraphDef&& function_body);
const string& description() const;
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index e7b3d0c92f..3a1ac73f64 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -51,6 +51,7 @@ cc_library(
hdrs = ["captured_function.h"],
deps = [
":dataset",
+ ":single_threaded_executor",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -61,6 +62,42 @@ cc_library(
)
cc_library(
+ name = "single_threaded_executor",
+ srcs = ["single_threaded_executor.cc"],
+ hdrs = ["single_threaded_executor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "single_threaded_executor_test",
+ srcs = ["single_threaded_executor_test.cc"],
+ deps = [
+ ":single_threaded_executor",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:array",
+ "//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:function_ops",
+ "//tensorflow/core/kernels:math",
+ "//tensorflow/core/kernels:random_ops",
+ "//tensorflow/core/kernels:state",
+ ],
+)
+
+cc_library(
name = "window_dataset",
srcs = ["window_dataset.cc"],
hdrs = ["window_dataset.h"],
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index abdf6ee4e8..186740c2ac 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -28,7 +28,16 @@ namespace tensorflow {
Status CapturedFunction::Create(
const NameAttrList& func, std::vector<Tensor> captured_inputs,
std::unique_ptr<CapturedFunction>* out_function) {
- out_function->reset(new CapturedFunction(func, std::move(captured_inputs)));
+ return Create(func, std::move(captured_inputs), true, out_function);
+}
+
+/* static */
+Status CapturedFunction::Create(
+ const NameAttrList& func, std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism,
+ std::unique_ptr<CapturedFunction>* out_function) {
+ out_function->reset(new CapturedFunction(func, std::move(captured_inputs),
+ use_inter_op_parallelism));
return Status::OK();
}
@@ -272,6 +281,9 @@ Status CapturedFunction::Instantiate(IteratorContext* ctx) {
inst_opts.overlay_lib = ctx->function_library().get();
inst_opts.state_handle = std::to_string(random::New64());
inst_opts.create_kernels_eagerly = true;
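+  // When inter-op parallelism is disabled, run this function with the
+  // single-threaded executor, which is optimized for small functions.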
+ if (!use_inter_op_parallelism_) {
+ inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
+ }
Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
inst_opts, &f_handle_));
TF_RETURN_IF_ERROR(s);
@@ -398,10 +410,12 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
}
CapturedFunction::CapturedFunction(const NameAttrList& func,
- std::vector<Tensor> captured_inputs)
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism)
: func_(func),
lib_(nullptr),
f_handle_(kInvalidHandle),
- captured_inputs_(std::move(captured_inputs)) {}
+ captured_inputs_(std::move(captured_inputs)),
+ use_inter_op_parallelism_(use_inter_op_parallelism) {}
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index c95f2b1c01..ae6bdfc2a0 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -48,6 +48,15 @@ class CapturedFunction {
std::vector<Tensor> captured_inputs,
std::unique_ptr<CapturedFunction>* out_function);
+ // Creates a new instance from a list of named attributes and captured inputs.
+ //
+ // If `use_inter_op_parallelism` is false, the runtime may use an executor
+ // that is optimized for small functions.
+ static Status Create(const NameAttrList& func,
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism,
+ std::unique_ptr<CapturedFunction>* out_function);
+
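+ // A minimal usage sketch (hypothetical call site; the variable names are
+ // illustrative only, but the pattern matches the dataset kernels updated in
+ // this change):
+ //
+ //   std::unique_ptr<CapturedFunction> captured_func;
+ //   TF_RETURN_IF_ERROR(CapturedFunction::Create(
+ //       func, std::move(other_arguments),
+ //       /*use_inter_op_parallelism=*/false, &captured_func));
+ //
+ // Passing `false` causes Instantiate() to request the single-threaded
+ // executor for this function.
+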
// Creates a new instance using a list of named attributes, fetching captured
// inputs from a context argument.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
@@ -114,7 +123,8 @@ class CapturedFunction {
private:
CapturedFunction(const NameAttrList& func,
- std::vector<Tensor> captured_inputs);
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism);
Status GetHandle(IteratorContext* ctx,
FunctionLibraryRuntime::Handle* out_handle);
@@ -126,6 +136,7 @@ class CapturedFunction {
const std::vector<Tensor> captured_inputs_;
DataTypeSlice ret_types_;
std::function<void(std::function<void()>)> captured_runner_ = nullptr;
+ const bool use_inter_op_parallelism_;
TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 7f8182d917..6c45fcafcc 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -34,6 +34,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
@@ -48,7 +50,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ func_, std::move(other_arguments),
+ use_inter_op_parallelism_, &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
@@ -187,6 +190,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList func_;
+ bool use_inter_op_parallelism_;
};
REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc
new file mode 100644
index 0000000000..e785b8b4d5
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.cc
@@ -0,0 +1,378 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace {
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+class SingleThreadedExecutorImpl : public Executor {
+ public:
+ explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params)
+ : params_(params) {}
+
+ ~SingleThreadedExecutorImpl() override {
+ for (const KernelState& kernel_state : kernels_) {
+ params_.delete_kernel(kernel_state.kernel);
+ }
+ }
+
+ Status Initialize(const Graph& graph) {
+ // Topologically sort `graph` to get a sequence of OpKernels.
+ std::vector<Node*> ordered_nodes;
+ ordered_nodes.reserve(graph.num_nodes());
+ GetReversePostOrder(graph, &ordered_nodes);
+
+ if (ordered_nodes.size() != graph.num_nodes()) {
+ return errors::InvalidArgument("Graph had ", graph.num_nodes(),
+ " but reverse post-order had ",
+ ordered_nodes.size());
+ }
+
+ kernels_.resize(ordered_nodes.size());
+
+ std::unordered_map<Node*, size_t> node_to_index_map;
+
+ // Create the kernel and input-related structures for each node in `graph`.
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ Node* n = ordered_nodes[i];
+ node_to_index_map[n] = i;
+
+ for (DataType dt : n->output_types()) {
+ if (IsRefType(dt)) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support reference-typed "
+ "edges.");
+ }
+ }
+
+ if (n->IsControlFlow()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support control flow.");
+ }
+ if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support partitioned graphs.");
+ }
+ if (n->IsCollective()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support collective ops.");
+ }
+
+ KernelState& kernel_state = kernels_[i];
+ TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel));
+ kernel_state.num_inputs = n->num_inputs();
+ kernel_state.num_outputs = n->num_outputs();
+
+ if (i == 0) {
+ kernel_state.input_start_index = 0;
+ } else {
+ const KernelState& previous_kernel_state = kernels_[i - 1];
+ kernel_state.input_start_index =
+ previous_kernel_state.input_start_index +
+ previous_kernel_state.num_inputs;
+ }
+ }
+
+ // Build the mapping from each node output to the input slot for the
+ // corresponding destination node.
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ Node* n = ordered_nodes[i];
+ KernelState& kernel_state = kernels_[i];
+ kernel_state.output_locations.resize(kernel_state.num_outputs);
+ for (const Edge* e : n->out_edges()) {
+ if (!e->IsControlEdge()) {
+ kernel_state.output_locations[e->src_output()].push_back(
+ kernels_[node_to_index_map[e->dst()]].input_start_index +
+ e->dst_input());
+ }
+ }
+
+ // Compute allocator attributes for each node output, and corresponding
+ // node input.
+ kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs);
+ AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data();
+
+ OpKernel* op_kernel = kernel_state.kernel;
+ for (int out = 0; out < n->num_outputs(); out++) {
+ DCHECK_LT(out, op_kernel->output_memory_types().size());
+ bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
+ if (on_host) {
+ AllocatorAttributes h;
+ h.set_on_host(on_host);
+ attrs[out].Merge(h);
+ }
+ }
+ }
+
+ if (!kernels_.empty()) {
+ const KernelState& last_kernel_state = kernels_.back();
+ total_num_inputs_ =
+ last_kernel_state.input_start_index + last_kernel_state.num_inputs;
+ input_alloc_attrs_.resize(total_num_inputs_);
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) {
+ for (size_t output_location : kernels_[i].output_locations[j]) {
+ input_alloc_attrs_[output_location] =
+ kernels_[i].output_alloc_attrs[j];
+ }
+ }
+ }
+ } else {
+ total_num_inputs_ = 0;
+ }
+ return Status::OK();
+ }
+
+ // TODO(mrry): Consider specializing the implementation of Executor::Run()
+ // instead, to avoid unnecessary atomic operations in the callback when
+ // running synchronously.
+ void RunAsync(const Args& args, DoneCallback done) override {
+ // The inputs to each kernel are stored contiguously in `inputs`.
+ //
+ // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
+ // determine the range of elements in this vector that correspond to
+ // the inputs of `kernels_[i]`.
+ //
+ // This vector has the following layout:
+ //
+ // * Kernel 0, input 0.
+ // * Kernel 0, input 1.
+ // * ...
+ // * Kernel 0, input `kernels_[0].num_inputs - 1`.
+ // * Kernel 1, input 0.
+ // * ...
+ // * Kernel 1, input `kernels_[1].num_inputs - 1`.
+ // * ...
+ // * Kernel `kernels_.size() - 1`, input 0.
+ // * ...
+ // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`.
+ //
+ // Note that kernels with zero inputs do not correspond to any elements in
+ // this vector.
+ //
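+ // For example (illustrative sizes only): if `kernels_[0].num_inputs == 2`
+ // and `kernels_[1].num_inputs == 1`, then `kernels_[1].input_start_index`
+ // is 2, `kernels_[2].input_start_index` is 3, and the single input of
+ // kernel 1 lives at `inputs[2]`.
+ //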
+ // We use `ManualConstructor<Tensor>` to avoid the overhead of
+ // default-constructing an invalid `Tensor` for each slot at the beginning
+ // of execution:
+ // * Elements are initialized when the outputs of a kernel execution are
+ // propagated to the inputs of kernels that depend on them.
+ // * The elements corresponding to the inputs for kernel `i` are destroyed
+ // after kernel `i` executes.
+ // * In an error case (see below), we use the connectivity information in
+ // `KernelState::output_locations` to determine which locations have been
+ // initialized, and manually destroy them.
+ std::vector<ManualConstructor<Tensor>> inputs(total_num_inputs_);
+
+ // TODO(mrry): Can we avoid copying into these vectors? Consider modifying
+ // OpKernelContext to take the TensorValueVec as a pointer into `inputs`.
+ TensorValueVec node_inputs;
+ DeviceContextVec input_device_contexts;
+ AllocatorAttributeVec input_alloc_attrs;
+
+ // Prepare the parameters that will be the same for all kernels.
+ OpKernelContext::Params params;
+ params.step_id = args.step_id;
+ Device* device = params_.device;
+ params.device = device;
+ params.log_memory = false; // TODO(mrry): Too severe?
+ params.record_tensor_accesses = false; // TODO(mrry): Too severe?
+ params.rendezvous = args.rendezvous;
+ params.session_state = args.session_state;
+ params.tensor_store = args.tensor_store;
+ params.cancellation_manager = args.cancellation_manager;
+ // TODO(mrry): ArgOp is a relatively expensive OpKernel due to the Tensor
+ // allocations that it performs. Consider specializing its handling in the
+ // executor.
+ params.call_frame = args.call_frame;
+ params.function_library = params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_container = args.step_container;
+ params.slice_reader_cache = nullptr; // TODO(mrry): Too severe?
+ params.inputs = &node_inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.input_alloc_attrs = &input_alloc_attrs;
+
+ Args::Runner runner_copy = args.runner;
+ params.runner = &runner_copy;
+ params.stats_collector = args.stats_collector;
+
+ // NOTE(mrry): We are assuming that the graph contains no loops or conditionals.
+ params.frame_iter = FrameAndIter(0, 0);
+ params.is_input_dead = false;
+
+ // TODO(mrry): Add non-default device context inference.
+ params.op_device_context = nullptr;
+ // TODO(mrry): Consider implementing forwarding.
+ params.forward_from_array = nullptr;
+
+ // Execute the kernels one-at-a-time in topological order.
+ for (size_t i = 0; i < kernels_.size(); ++i) {
+ const KernelState& kernel_state = kernels_[i];
+
+ // Prepare the per-kernel parameters.
+ const size_t input_start_index = kernel_state.input_start_index;
+ const size_t num_inputs = kernel_state.num_inputs;
+ const size_t num_outputs = kernel_state.num_outputs;
+
+ node_inputs.clear();
+ node_inputs.resize(num_inputs);
+ input_alloc_attrs.clear();
+ input_alloc_attrs.resize(num_inputs);
+ for (size_t j = 0; j < num_inputs; ++j) {
+ auto t = inputs[input_start_index + j].get();
+ node_inputs[j].tensor = t;
+ input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j];
+ }
+ params.op_kernel = kernel_state.kernel;
+ input_device_contexts.clear();
+ input_device_contexts.resize(num_inputs);
+ params.output_attr_array = kernel_state.output_alloc_attrs.data();
+ OpKernelContext ctx(&params, num_outputs);
+
+ // Actually execute the kernel.
+ device->Compute(kernel_state.kernel, &ctx);
+
+ if (!ctx.status().ok()) {
+ // On failure, we must manually free all intermediate tensors. We have
+ // already freed all the inputs for kernels up to (but not including)
+ // the `i`th kernel. We scan through the previously executed kernels and
+ // destroy any tensors that were destined to be the input for a kernel
+ // that has not yet executed.
+ for (size_t j = 0; j < i; ++j) {
+ const KernelState& executed_kernel_state = kernels_[j];
+ for (size_t k = 0; k < executed_kernel_state.num_outputs; ++k) {
+ for (size_t output_location :
+ executed_kernel_state.output_locations[k]) {
+ if (output_location >= input_start_index) {
+ // Only destroy an output location if it is an input to an
+ // operation that has not yet executed.
+ inputs[output_location].Destroy();
+ }
+ }
+ }
+ }
+ done(ctx.status());
+ return;
+ }
+
+ // Free the inputs to the current kernel.
+ for (size_t j = 0; j < num_inputs; ++j) {
+ inputs[input_start_index + j].Destroy();
+ }
+
+ // Forward the outputs of the kernel to the inputs of subsequent kernels.
+ for (size_t j = 0; j < num_outputs; ++j) {
+ TensorValue val = ctx.release_output(j);
+ // TODO(mrry): Consider flattening the `output_locations` vector
+ // to improve the cache-friendliness of this loop.
+ for (size_t output_location : kernel_state.output_locations[j]) {
+ // TODO(mrry): Validate that the types match the expected values or
+ // ensure that the necessary validation has already happened.
+ inputs[output_location].Init(*val.tensor);
+ }
+ delete val.tensor;
+ }
+ }
+ done(Status::OK());
+ }
+
+ private:
+ const LocalExecutorParams params_;
+
+ // All following members are read-only after Initialize().
+
+ // The sum of the number of inputs for each node in the graph. This determines
+ // the length of the flat `inputs` vector. See comment at the beginning of
+ // `RunAsync()` for details.
+ size_t total_num_inputs_;
+
+ // Represents cached graph structure state for each kernel.
+ struct KernelState {
+ // The kernel object. Not owned.
+ //
+ // This pointer is managed by `params_.create_kernel()` and
+ // `params_.delete_kernel()`.
+ OpKernel* kernel;
+
+ // These fields determine the range of elements in `inputs` that corresponds
+ // to the inputs of `kernel`.
+ size_t input_start_index;
+ size_t num_inputs;
+
+ size_t num_outputs;
+
+ // For the `j`th output of `kernel`, `output_locations[j]` contains the
+ // locations in the flat `inputs` vector to which that output must be
+ // copied. See comment at the beginning of `RunAsync()` for details.
+ std::vector<std::vector<size_t>>
+ output_locations; // Length = `num_outputs`.
+
+ // Memory space information for each output of `kernel`.
+ std::vector<AllocatorAttributes>
+ output_alloc_attrs; // Length = `num_outputs`.
+ };
+ std::vector<KernelState> kernels_;
+
+ // Memory space information for each input. This information is stored in the
+ // same order as the flat `inputs` vector. See comment at the beginning of
+ // `RunAsync()` for details.
+ std::vector<AllocatorAttributes>
+ input_alloc_attrs_; // Length = `total_num_inputs_`.
+};
+
+class SingleThreadedExecutorRegistrar {
+ public:
+ SingleThreadedExecutorRegistrar() {
+ ExecutorFactory::Register("SINGLE_THREADED_EXECUTOR", new Factory());
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ Executor* ret;
+ TF_RETURN_IF_ERROR(
+ NewSingleThreadedExecutor(params, std::move(graph), &ret));
+ out_executor->reset(ret);
+ return Status::OK();
+ }
+ };
+};
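+
+// Instantiating the registrar below registers the factory above under the
+// name "SINGLE_THREADED_EXECUTOR"; callers can request it by setting
+// `FunctionLibraryRuntime::InstantiateOptions::executor_type` to that name,
+// as `CapturedFunction::Instantiate()` does when inter-op parallelism is
+// disabled.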
+static SingleThreadedExecutorRegistrar registrar;
+
+} // namespace
+
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ Executor** executor) {
+ std::unique_ptr<SingleThreadedExecutorImpl> impl(
+ new SingleThreadedExecutorImpl(params));
+ TF_RETURN_IF_ERROR(impl->Initialize(*graph));
+ *executor = impl.release();
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h
new file mode 100644
index 0000000000..15836b24c9
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.h
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+
+#include "tensorflow/core/common_runtime/executor.h"
+
+namespace tensorflow {
+
+// Creates a new `Executor` for executing `graph` synchronously on the caller
+// thread.
+//
+// NOTE(mrry): The returned executor is optimized to impose low overhead on
+// graphs that perform a small amount of work (e.g. <15us of work per graph on
+// present architectures). It eschews concurrency, because issuing work to
+// multiple threads can dominate the cost of executing small ops synchronously,
+// and because contention in the executor data structures can reduce throughput
+// (in terms of ops executed per unit time).
+//
+// However, the current implementation has the following limitations:
+//
+// 1. Reference-typed tensors are not supported and will not be supported in
+// the future.
+// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not
+// currently supported. The current plan is to extend support to "functional"
+// control flow after the TensorFlow APIs transition to building graphs in
+// that form (e.g. `tf.cond_v2()`).
+// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported.
+// The present implementation executes kernels one at a time in topological
+// order, and cannot currently distinguish between disconnected subgraphs
+// that are logically connected by subgraphs on a different device.
+// 4. Memory logging is not currently supported.
+// 5. Allocation forwarding is not currently supported.
+// 6. Non-default device contexts are not currently supported. In effect, this
+// limits the executor to CPU devices.
+// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null
+// are not currently supported.
+//
+// The single-threaded executor is primarily suitable for executing simple
+// TensorFlow functions, such as one might find in a `tf.data` pipeline.
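+//
+// A minimal sketch of typical usage (the surrounding setup is illustrative;
+// `params` must be populated with `device`, `create_kernel`, and
+// `delete_kernel` as for any executor):
+//
+//   Executor* executor;
+//   TF_RETURN_IF_ERROR(
+//       NewSingleThreadedExecutor(params, std::move(graph), &executor));
+//   Executor::Args args;
+//   args.call_frame = &call_frame;
+//   args.runner = [](std::function<void()> fn) { fn(); };
+//   Status s = executor->Run(args);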
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ Executor** executor);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
new file mode 100644
index 0000000000..f8b5769197
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
@@ -0,0 +1,330 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include <algorithm>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+class ExecutorTest : public ::testing::Test {
+ protected:
+ ExecutorTest()
+ : device_(DeviceFactory::NewDevice("CPU", {},
+ "/job:localhost/replica:0/task:0")) {}
+
+ ~ExecutorTest() override {
+ // There should always be exactly one Ref left on the Rendezvous
+ // when the test completes.
+ CHECK(rendez_->Unref());
+ delete exec_;
+ delete device_;
+ }
+
+ // Resets `exec_` with a new executor built from `graph`.
+ void Create(std::unique_ptr<const Graph> graph) {
+ const int version = graph->versions().producer();
+ LocalExecutorParams params;
+ params.device = device_;
+ params.create_kernel = [this, version](const NodeDef& ndef,
+ OpKernel** kernel) {
+ return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+ };
+ params.delete_kernel = [](OpKernel* kernel) {
+ DeleteNonCachedKernel(kernel);
+ };
+ delete exec_;
+ TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_));
+ runner_ = [](std::function<void()> fn) { fn(); };
+ rendez_ = NewLocalRendezvous();
+ }
+
+ Status Run(Rendezvous* rendez) {
+ Executor::Args args;
+ args.rendezvous = rendez;
+ args.runner = runner_;
+ return exec_->Run(args);
+ }
+
+ Status Run(CallFrameInterface* call_frame) {
+ Executor::Args args;
+ args.call_frame = call_frame;
+ args.runner = runner_;
+ return exec_->Run(args);
+ }
+
+ Device* device_ = nullptr;
+ Executor* exec_ = nullptr;
+ Executor::Args::Runner runner_;
+ Rendezvous* rendez_ = nullptr;
+};
+
+// A float val -> Tensor<float>
+Tensor V(const float val) {
+ Tensor tensor(DT_FLOAT, TensorShape({}));
+ tensor.scalar<float>()() = val;
+ return tensor;
+}
+
+// An int32 val -> Tensor<int32>
+Tensor VI(const int32 val) {
+ Tensor tensor(DT_INT32, TensorShape({}));
+ tensor.scalar<int32>()() = val;
+ return tensor;
+}
+
+// A bool val -> Tensor<bool>
+Tensor VB(const bool val) {
+ Tensor tensor(DT_BOOL, TensorShape({}));
+ tensor.scalar<bool>()() = val;
+ return tensor;
+}
+
+// A double val -> Tensor<double>
+Tensor VD(const double val) {
+ Tensor tensor(DT_DOUBLE, TensorShape({}));
+ tensor.scalar<double>()() = val;
+ return tensor;
+}
+
+// Tensor<float> -> a float val.
+float V(const Tensor& tensor) {
+ CHECK_EQ(tensor.dtype(), DT_FLOAT);
+ CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+ return tensor.scalar<float>()();
+}
+
+Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
+ const string& receiver, const string& name) {
+ Rendezvous::ParsedKey result;
+ TF_CHECK_OK(
+ Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
+ name, FrameAndIter(0, 0)),
+ &result));
+ return result;
+}
+
+TEST_F(ExecutorTest, SimpleAdd) {
+ // c = a + b
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+ auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT);
+ auto tmp = test::graph::Add(g.get(), in0, in1);
+ test::graph::Retval(g.get(), 0, tmp);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(2.0, V(retvals[0])); // out = 1.0 + 1.0 = 2.0
+}
+
+TEST_F(ExecutorTest, SelfAdd) {
+ // v0 <- a
+ // v1 = v0 + v0
+ // v2 = v1 + v1
+ // ... ...
+ // v10 = v9 + v9
+ //
+ // b <- v10
+ // All nodes are executed by one thread.
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto v = test::graph::Arg(g.get(), 0, DT_FLOAT);
+ const int N = 10;
+ for (int i = 1; i <= N; ++i) {
+ v = test::graph::Add(g.get(), v, v);
+ }
+ // out <- v10
+ test::graph::Retval(g.get(), 0, v);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+ // a = 1.0
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(1024.0, V(retvals[0])); // b=v10=2*v9=4*v8=...=1024*a=1024.0
+}
+
+// Builds a graph which adds N copies of one variable "in". I.e.,
+// a + a + a + ... + a
+// The returned graph is parenthesized randomly. I.e.,
+// a + ((a + a) + a)
+// (a + a) + (a + a)
+// ((a + a) + a) + a
+// are all possibly generated.
+void BuildTree(int N, Graph* g) {
+ CHECK_GT(N, 1);
+ // A single input node "in".
+ auto in = test::graph::Arg(g, 0, DT_FLOAT);
+ std::vector<Node*> nodes;
+ int i = 0;
+ // Duplicate "in" N times. Each copies is named as l0, l1, l2, ....
+ for (; i < N; ++i) {
+ nodes.push_back(test::graph::Identity(g, in, 0));
+ }
+ random::PhiloxRandom philox(0, 17);
+ random::SimplePhilox rnd(&philox);
+ while (nodes.size() > 1) {
+ // Randomly pick two from nodes and add them. The resulting node
+ // is named like n10, n11, ..., and is put back into "nodes".
+ int x = rnd.Uniform(nodes.size());
+ auto in0 = nodes[x];
+ nodes[x] = nodes.back();
+ nodes.resize(nodes.size() - 1);
+ x = rnd.Uniform(nodes.size());
+ auto in1 = nodes[x];
+ // node = in0 + in1.
+ nodes[x] = test::graph::Add(g, in0, in1);
+ }
+ // The final output node "out".
+ test::graph::Retval(g, 0, nodes.back());
+ FixupSourceAndSinkEdges(g);
+}
+
+TEST_F(ExecutorTest, RandomTree) {
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ BuildTree(4096, g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(4096.0, V(retvals[0]));
+}
+
+TEST_F(ExecutorTest, OpError) {
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto zero = test::graph::Constant(g.get(), V(0.0));
+ auto inf = test::graph::Unary(g.get(), "Reciprocal", zero);
+ auto check = test::graph::CheckNumerics(g.get(), inf, "message");
+ auto two = test::graph::Constant(g.get(), V(2.0));
+ test::graph::Binary(g.get(), "Mul", check, two);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({}, {});
+ // Fails because CheckNumerics reports the Inf value as InvalidArgument.
+ EXPECT_TRUE(errors::IsInvalidArgument(Run(&call_frame)));
+}
+
+static void BM_executor(int iters, int width, int depth) {
+#ifdef PLATFORM_GOOGLE
+ BenchmarkUseRealTime();
+#endif // PLATFORM_GOOGLE
+ Graph* g = new Graph(OpRegistry::Global());
+ random::PhiloxRandom philox(1729, 17);
+ random::SimplePhilox rand(&philox);
+ uint64 cur = 0;
+ uint32 r = 1 + rand.Rand32() % width;
+ std::vector<Node*> ready_nodes;
+ for (int i = 0; i < r; ++i) {
+ ready_nodes.push_back(test::graph::NoOp(g, {}));
+ ++cur;
+ }
+ for (int i = 0; i < depth; ++i) {
+ std::random_shuffle(ready_nodes.begin(), ready_nodes.end());
+ r = 1 + rand.Rand32() % (ready_nodes.size());
+ std::vector<Node*> control_inputs;
+ for (int j = 0; j < r; ++j) {
+ control_inputs.push_back(ready_nodes.back());
+ ready_nodes.pop_back();
+ }
+ Node* n = test::graph::NoOp(g, control_inputs);
+ ++cur;
+ r = 1 + rand.Rand32() % width;
+ for (int j = 0; j < r; ++j) {
+ ready_nodes.push_back(test::graph::NoOp(g, {n}));
+ ++cur;
+ }
+ }
+ FixupSourceAndSinkEdges(g);
+#ifdef PLATFORM_GOOGLE
+ SetBenchmarkLabel(strings::StrCat("Nodes = ", cur));
+ SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters));
+#endif // PLATFORM_GOOGLE
+ test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+ "SINGLE_THREADED_EXECUTOR")
+ .Run(iters);
+}
+
+// Tall skinny graphs
+BENCHMARK(BM_executor)->ArgPair(16, 1024);
+BENCHMARK(BM_executor)->ArgPair(32, 8192);
+
+// Short fat graphs
+BENCHMARK(BM_executor)->ArgPair(1024, 16);
+BENCHMARK(BM_executor)->ArgPair(8192, 32);
+
+// Tall fat graph
+BENCHMARK(BM_executor)->ArgPair(1024, 1024);
+
+// TODO(mrry): This benchmark currently crashes with a use-after-free, because
+// test::Benchmark::RunWithArgs() assumes that the executor will take ownership
+// of the given graph, *and* keep its nodes (`x`, `y` and `z`) alive for the
+// duration of the benchmark. Since the single-threaded executor does not retain
+// a copy of the graph, this fails.
+//
+// TODO(mrry): Add support for Arg/Retval "function call convention" in
+// `test::Benchmark::RunWithArgs()`.
+#if 0
+#define ALICE "/job:j/replica:0/task:0/cpu:0"
+#define BOB "/job:j/replica:0/task:0/gpu:0"
+
+static void BM_FeedInputFetchOutput(int iters) {
+ Graph* g = new Graph(OpRegistry::Global());
+ // z = x + y: x and y are provided as benchmark inputs. z is the
+ // output of the benchmark. Conceptually, the caller is ALICE, the
+ // benchmark is BOB.
+ Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
+ Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
+ Node* sum = test::graph::Add(g, x, y);
+ Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);
+ FixupSourceAndSinkEdges(g);
+ Tensor val(DT_FLOAT, TensorShape({}));
+ val.scalar<float>()() = 3.14;
+ SetBenchmarkItemsProcessed(static_cast<int64>(iters));
+ test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+ "SINGLE_THREADED_EXECUTOR")
+ .RunWithArgs({{x, val}, {y, val}}, {z}, iters);
+}
+BENCHMARK(BM_FeedInputFetchOutput);
+#endif
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 33ed5522d0..d705e82b0d 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -255,7 +255,7 @@ class DebugNanCountOp : public BaseDebugOp {
TensorShape shape({1});
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor));
output_tensor->vec<int64>()(0) = nan_count;
- PublishTensor(*output_tensor);
+ OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
}
};
@@ -380,7 +380,7 @@ class DebugNumericSummaryOp : public BaseDebugOp {
bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 &&
positive_inf_count == 0;
if (!mute) {
- PublishTensor(*output_tensor);
+ OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
}
}
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index e13e548f86..3ebeb7be2b 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -323,47 +323,34 @@ CuboidConvolutionBackwardInput(
template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::Layout == ColMajor,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index, 5>,
- const TensorReverseOp<
- const array<bool, 5>,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 5>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
const TensorReshapingOp<
- const DSizes<typename internal::traits<OutputBackward>::Index,
- 5>,
- const TensorContractionOp<
- const array<
- IndexPair<typename internal::traits<Input>::Index>, 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index,
- 3>,
- const Input>,
- const TensorReshapingOp<
- const DSizes<
- typename internal::traits<OutputBackward>::Index,
- 4>,
- const TensorVolumePatchOp<
- Dynamic, Dynamic, Dynamic,
- const OutputBackward> > > > > >,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index, 5>,
- const TensorReverseOp<
- const array<bool, 5>,
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const OutputBackward>,
+ const TensorShufflingOp<
+ const array<typename internal::traits<OutputBackward>::Index,
+ 2>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
+ const Input> > > > >,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 5>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const TensorShufflingOp<
+ const array<typename internal::traits<OutputBackward>::Index,
+ 2>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
+ const Input> > >,
const TensorReshapingOp<
- const DSizes<typename internal::traits<OutputBackward>::Index,
- 5>,
- const TensorContractionOp<
- const array<
- IndexPair<typename internal::traits<Input>::Index>, 2>,
- const TensorReshapingOp<
- const DSizes<
- typename internal::traits<OutputBackward>::Index,
- 4>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> >,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index,
- 3>,
- const Input> > > > > >::type
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const OutputBackward> > > >::type
CuboidConvolutionBackwardKernel(
const Input& input, const OutputBackward& output_backward,
typename internal::traits<Input>::Index kernelPlanes,
@@ -406,213 +393,114 @@ CuboidConvolutionBackwardKernel(
const TensorIndex outputCols =
isColMajor ? out.dimension(3) : out.dimension(NumDims - 4);
+ // Number of filters. This is the same as the output depth.
const TensorIndex kernelFilters =
isColMajor ? out.dimension(0) : out.dimension(NumDims - 1);
+ // Number of channels. This is the same as the input depth.
const TensorIndex kernelChannels =
isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
- TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
- const TensorIndex size_z =
- Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
- const TensorIndex size_y =
- Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
- const TensorIndex size_x =
- Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
-
- // Infer padding type.
- if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
- // SAME padding.
- const TensorIndex dz = numext::maxi<TensorIndex>(
- 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
- const TensorIndex dy = numext::maxi<TensorIndex>(
- 0, (size_y - 1) * strideRows + kernelRows - inputRows);
- const TensorIndex dx = numext::maxi<TensorIndex>(
- 0, (size_x - 1) * strideCols + kernelCols - inputCols);
-
- forward_pad_z = dz / 2;
- forward_pad_y = dy / 2;
- forward_pad_x = dx / 2;
- } else {
- // VALID padding.
- forward_pad_z = 0;
- forward_pad_y = 0;
- forward_pad_x = 0;
- }
-
- const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
- const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
- const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
-
- const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
- (outputPlanes - 1) * stridePlanes - 1 -
- padding_ztop;
- const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
- (outputRows - 1) * strideRows - 1 -
- padding_top;
- const TensorIndex padding_right = inputCols + kernelCols - 1 -
- (outputCols - 1) * strideCols - 1 -
- padding_left;
-
- eigen_assert(padding_ztop >= 0);
- eigen_assert(padding_zbottom >= 0);
- eigen_assert(padding_top >= 0);
- eigen_assert(padding_left >= 0);
- eigen_assert(padding_bottom >= 0);
- eigen_assert(padding_right >= 0);
-
- // The output_backward has dimensions out_depth X out_plaens X out_rows X
- // out_cols X OTHERS
- // When we extract the image patches from output_backward (with input as the
- // kernel), it will have dimensions
- // (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes *
- // kernel_rows * kernel_cols) X OTHERS
- DSizes<TensorIndex, 4> pre_contract_dims;
+ // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
+ // effective kernel planes/rows/cols are always the same as the kernel itself
+ // (see eigen_spatial_convolutions for details).
+ const TensorIndex kernelPlanesEff = kernelPlanes;
+ const TensorIndex kernelRowsEff = kernelRows;
+ const TensorIndex kernelColsEff = kernelCols;
+
+ const TensorIndex padPlanes = numext::maxi<Index>(
+ 0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes);
+ const TensorIndex padRows = numext::maxi<Index>(
+ 0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows);
+ const TensorIndex padCols = numext::maxi<Index>(
+ 0, (outputCols - 1) * strideCols + kernelColsEff - inputCols);
+
+ const TensorIndex padding_top_z = padPlanes / 2;
+ const TensorIndex padding_bottom_z = padPlanes - padding_top_z;
+ const TensorIndex padding_top = padRows / 2;
+ const TensorIndex padding_bottom = padRows - padding_top;
+ const TensorIndex padding_left = padCols / 2;
+ const TensorIndex padding_right = padCols - padding_left;
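+
+  // For example (illustrative sizes): with inputRows = outputRows = 4,
+  // strideRows = 1 and kernelRowsEff = 3, padRows = max(0, 3 * 1 + 3 - 4) = 2,
+  // so padding_top = 1 and padding_bottom = 1 (i.e. SAME padding).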
+
+ // Reshaped output_backward before contraction.
+ DSizes<TensorIndex, 2> output_dims;
if (isColMajor) {
- pre_contract_dims[0] = kernelFilters;
- pre_contract_dims[1] = inputRows * inputCols * inputPlanes;
- pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[3] = 1;
+ output_dims[0] = kernelFilters;
+ output_dims[1] = outputPlanes * outputRows * outputCols;
for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[3] *= out.dimension(i);
+ output_dims[1] *= out.dimension(i);
}
} else {
- pre_contract_dims[3] = kernelFilters;
- pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[0] = 1;
+ output_dims[1] = kernelFilters;
+ output_dims[0] = outputCols * outputRows * outputPlanes;
for (int i = 0; i < NumDims - 4; ++i) {
- pre_contract_dims[0] *= out.dimension(i);
+ output_dims[0] *= out.dimension(i);
}
}
- // The input has dimensions in_depth X (input_planes * input_rows *
- // input_cols) X OTHERS
- DSizes<TensorIndex, 3> input_dims;
+ // Reshaped extract_volume_patches(in)
+ DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
- input_dims[0] = kernelChannels;
- input_dims[1] = inputRows * inputCols * inputPlanes;
- input_dims[2] = 1;
+ pre_contract_dims[0] =
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = outputPlanes * outputRows * outputCols;
for (int i = 4; i < NumDims; ++i) {
- input_dims[2] *= in.dimension(i);
+ pre_contract_dims[1] *= in.dimension(i);
}
- eigen_assert(input_dims[2] == pre_contract_dims[3]);
+ eigen_assert(output_dims[1] == pre_contract_dims[1]);
} else {
- input_dims[2] = kernelChannels;
- input_dims[1] = inputRows * inputCols * inputPlanes;
- input_dims[0] = 1;
+ pre_contract_dims[1] =
+ kernelCols * kernelRows * kernelPlanes * kernelChannels;
+ pre_contract_dims[0] = outputCols * outputRows * outputPlanes;
for (int i = 0; i < NumDims - 4; ++i) {
- input_dims[0] *= in.dimension(i);
+ pre_contract_dims[0] *= in.dimension(i);
}
- eigen_assert(input_dims[0] == pre_contract_dims[0]);
+ eigen_assert(output_dims[0] == pre_contract_dims[0]);
}
- // We will contract along dimensions (1, 2) in and (1, 3) in out, if
- // this is col-major.
- // For row-major, it's dimensions (0, 1) in and (0, 2) in out.
- array<IndexPair<TensorIndex>, 2> contract_dims;
- if (isColMajor) {
- // col-major: in.contract(output.patches)
- contract_dims[0] = IndexPair<TensorIndex>(1, 1);
- contract_dims[1] = IndexPair<TensorIndex>(2, 3);
- } else {
- // row-major: output.patches.contract(in)
- contract_dims[0] = IndexPair<TensorIndex>(0, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 1);
- }
+ array<TensorIndex, 2> shuffle_dims;
+ shuffle_dims[0] = 1;
+ shuffle_dims[1] = 0;
- // After the contraction, the kernel will have dimension
- // in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols
- // We will need to shuffle the first two dimensions and reverse the spatial
- // dimensions.
- // The end shape is:
- // out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols
+ array<IndexPair<TensorIndex>, 1> contract_dims;
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- // This is the shape of the kernel *before* the shuffling.
DSizes<TensorIndex, 5> kernel_dims;
if (isColMajor) {
- kernel_dims[0] = kernelChannels;
- kernel_dims[1] = kernelFilters;
+ kernel_dims[0] = kernelFilters;
+ kernel_dims[1] = kernelChannels;
kernel_dims[2] = kernelPlanes;
kernel_dims[3] = kernelRows;
kernel_dims[4] = kernelCols;
} else {
- kernel_dims[0] = kernelCols;
- kernel_dims[1] = kernelRows;
+ kernel_dims[4] = kernelFilters;
+ kernel_dims[3] = kernelChannels;
kernel_dims[2] = kernelPlanes;
- kernel_dims[3] = kernelFilters;
- kernel_dims[4] = kernelChannels;
- }
-
- // Flip filters and channels.
- array<TensorIndex, 5> kernel_shuffle;
- if (isColMajor) {
- kernel_shuffle[0] = 1;
- kernel_shuffle[1] = 0;
- kernel_shuffle[2] = 2;
- kernel_shuffle[3] = 3;
- kernel_shuffle[4] = 4;
- } else {
- kernel_shuffle[0] = 0;
- kernel_shuffle[1] = 1;
- kernel_shuffle[2] = 2;
- kernel_shuffle[3] = 4;
- kernel_shuffle[4] = 3;
- }
-
- // Reverse the spatial dimensions.
- array<bool, 5> kernel_reverse;
- if (isColMajor) {
- kernel_reverse[0] = false;
- kernel_reverse[1] = false;
- kernel_reverse[2] = true;
- kernel_reverse[3] = true;
- kernel_reverse[4] = true;
- } else {
- kernel_reverse[0] = true;
- kernel_reverse[1] = true;
- kernel_reverse[2] = true;
- kernel_reverse[3] = false;
- kernel_reverse[4] = false;
+ kernel_dims[1] = kernelRows;
+ kernel_dims[0] = kernelCols;
}
- DSizes<TensorIndex, NumDims> strides;
- for (int i = 0; i < NumDims; i++) {
- strides[i] = 1;
- }
- if (isColMajor) {
- strides[1] = stridePlanes;
- strides[2] = strideRows;
- strides[3] = strideCols;
- } else {
- strides[NumDims - 2] = stridePlanes;
- strides[NumDims - 3] = strideRows;
- strides[NumDims - 4] = strideCols;
- }
return choose(
Cond<internal::traits<Input>::Layout == ColMajor>(),
- input.reshape(input_dims)
- .contract(output_backward
+ output_backward.reshape(output_dims)
+ .contract(input
.extract_volume_patches(
- inputPlanes, inputRows, inputCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols,
-
- padding_ztop, padding_zbottom, padding_top,
- padding_bottom, padding_left, padding_right)
- .reshape(pre_contract_dims),
+ kernelPlanes, kernelRows, kernelCols, stridePlanes,
+ strideRows, strideCols, 1, 1, 1, padding_top_z,
+ padding_bottom_z, padding_top, padding_bottom,
+ padding_left, padding_right)
+ .reshape(pre_contract_dims)
+ .shuffle(shuffle_dims),
contract_dims)
- .reshape(kernel_dims)
- .reverse(kernel_reverse)
- .shuffle(kernel_shuffle),
- output_backward
- .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols,
- padding_ztop, padding_zbottom, padding_top,
+ .reshape(kernel_dims),
+ input
+ .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
+ stridePlanes, strideRows, strideCols, 1, 1, 1,
+ padding_top_z, padding_bottom_z, padding_top,
padding_bottom, padding_left, padding_right)
.reshape(pre_contract_dims)
- .contract(input.reshape(input_dims), contract_dims)
- .reshape(kernel_dims)
- .reverse(kernel_reverse)
- .shuffle(kernel_shuffle));
+ .shuffle(shuffle_dims)
+ .contract(output_backward.reshape(output_dims), contract_dims)
+ .reshape(kernel_dims));
}
} // end namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h
index 46ad38fb77..87e41b89b3 100644
--- a/tensorflow/core/kernels/eigen_benchmark.h
+++ b/tensorflow/core/kernels/eigen_benchmark.h
@@ -76,6 +76,9 @@ class SpatialConvolutionBenchmarksSuite {
void SpatialConvolutionBackwardInput(Dimensions input_dims,
Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 4>::ConstTensor;
+ using InputBackward = TTypes<float, 4>::Tensor;
+
Dimensions output_dims(input_dims[0], // batch
input_dims[1], // input_height
input_dims[2], // input_width
@@ -85,37 +88,37 @@ class SpatialConvolutionBenchmarksSuite {
Eigen::Index input_rows = input_dims[1];
Eigen::Index input_cols = input_dims[2];
- Scalar* input_data =
- static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
Scalar* filter_data =
static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
- Scalar* output_data =
+ Scalar* output_backward_data =
static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* input_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
- device_.memset(input_data, 123, BufferSize(input_dims));
device_.memset(filter_data, 123, BufferSize(filter_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
- Input input(input_data, input_dims);
Filter filter(filter_data, filter_dims);
- Output output(output_data, output_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ InputBackward input_backward(input_backward_data, input_dims);
::tensorflow::testing::StartTiming();
for (int i = 0; i < iters_; ++i) {
- output.device(device_) = Eigen::SpatialConvolutionBackwardInput(
- filter, input, input_rows, input_cols);
- tensorflow::testing::DoNotOptimize(output);
+ input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput(
+ filter, output_backward, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(input_backward);
}
::tensorflow::testing::StopTiming();
- device_.deallocate(input_data);
device_.deallocate(filter_data);
- device_.deallocate(output_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(input_backward_data);
}
void SpatialConvolutionBackwardKernel(Dimensions input_dims,
Dimensions filter_dims) {
using OutputBackward = TTypes<float, 4>::ConstTensor;
- using FilterGrad = TTypes<float, 4>::Tensor;
+ using FilterBackward = TTypes<float, 4>::Tensor;
Dimensions output_dims(input_dims[0], // batch
input_dims[1], // input_height
@@ -130,7 +133,7 @@ class SpatialConvolutionBenchmarksSuite {
static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
Scalar* output_backward_data =
static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
- Scalar* filter_data =
+ Scalar* filter_backward_data =
static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
device_.memset(input_data, 123, BufferSize(input_dims));
@@ -138,19 +141,19 @@ class SpatialConvolutionBenchmarksSuite {
Input input(input_data, input_dims);
OutputBackward output_backward(output_backward_data, input_dims);
- FilterGrad filter_grad(filter_data, filter_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
::tensorflow::testing::StartTiming();
for (int i = 0; i < iters_; ++i) {
- filter_grad.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
+ filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
input, output_backward, filter_rows, filter_cols);
- tensorflow::testing::DoNotOptimize(filter_grad);
+ tensorflow::testing::DoNotOptimize(filter_backward);
}
::tensorflow::testing::StopTiming();
device_.deallocate(input_data);
device_.deallocate(output_backward_data);
- device_.deallocate(filter_data);
+ device_.deallocate(filter_backward_data);
}
private:
@@ -215,42 +218,45 @@ class CuboidConvolutionBenchmarksSuite {
input_dims[3], // input_planes
filter_dims[4]); // filter_count
+ using OutputBackward = TTypes<float, 5>::ConstTensor;
+ using InputBackward = TTypes<float, 5>::Tensor;
+
// Assuming that the convolution had SAME padding.
Eigen::Index input_rows = input_dims[1];
Eigen::Index input_cols = input_dims[2];
Eigen::Index input_planes = input_dims[3];
- Scalar* input_data =
- static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
Scalar* filter_data =
static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
- Scalar* output_data =
+ Scalar* output_backward_data =
static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* input_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
- device_.memset(input_data, 123, BufferSize(input_dims));
device_.memset(filter_data, 123, BufferSize(filter_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
- Input input(input_data, input_dims);
Filter filter(filter_data, filter_dims);
- Output output(output_data, output_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ InputBackward input_backward(input_backward_data, input_dims);
::tensorflow::testing::StartTiming();
for (int i = 0; i < iters_; ++i) {
- output.device(device_) = Eigen::CuboidConvolutionBackwardInput(
- filter, input, input_planes, input_rows, input_cols);
- tensorflow::testing::DoNotOptimize(output);
+ input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput(
+ filter, output_backward, input_planes, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(input_backward);
}
::tensorflow::testing::StopTiming();
- device_.deallocate(input_data);
device_.deallocate(filter_data);
- device_.deallocate(output_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(input_backward_data);
}
void CuboidConvolutionBackwardKernel(Dimensions input_dims,
Dimensions filter_dims) {
using OutputBackward = TTypes<float, 5>::ConstTensor;
- using FilterGrad = TTypes<float, 5>::Tensor;
+ using FilterBackward = TTypes<float, 5>::Tensor;
Dimensions output_dims(input_dims[0], // batch
input_dims[1], // input_height
@@ -267,7 +273,7 @@ class CuboidConvolutionBenchmarksSuite {
static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
Scalar* output_backward_data =
static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
- Scalar* filter_data =
+ Scalar* filter_backward_data =
static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
device_.memset(input_data, 123, BufferSize(input_dims));
@@ -275,19 +281,19 @@ class CuboidConvolutionBenchmarksSuite {
Input input(input_data, input_dims);
OutputBackward output_backward(output_backward_data, output_dims);
- FilterGrad filter_grad(filter_data, filter_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
::tensorflow::testing::StartTiming();
for (int i = 0; i < iters_; ++i) {
- filter_grad.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
+ filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
input, output_backward, filter_planes, filter_rows, filter_cols);
- tensorflow::testing::DoNotOptimize(filter_grad);
+ tensorflow::testing::DoNotOptimize(filter_backward);
}
::tensorflow::testing::StopTiming();
device_.deallocate(input_data);
device_.deallocate(output_backward_data);
- device_.deallocate(filter_data);
+ device_.deallocate(filter_backward_data);
}
private:
diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
index 2a8308ef9a..7c2bbb8148 100644
--- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
+++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
@@ -123,6 +123,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads,
#define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \
static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \
FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \
} \
BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW))
@@ -130,6 +131,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads,
#define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \
static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \
FH, FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \
} \
BENCHMARK( \
@@ -138,6 +140,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads,
#define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \
static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
FH, FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \
} \
BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
@@ -348,6 +351,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads,
#define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \
FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
} \
BENCHMARK( \
@@ -356,6 +360,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads,
#define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
} \
BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
@@ -365,6 +370,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads,
LABEL) \
static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \
FC, FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
} \
BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \
@@ -395,8 +401,11 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads,
BM_CuboidConvolutions(8, // batch size
25, 25, 25, 4, // input: height, width, panes, depth
16, 5, 5, 5, // filter: count, height, width, panes
- "conv3d");
+ "conv3d_depth4");
+BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
-BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d");
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
-BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d");
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h
index c7dbefa0b4..86146f75f4 100644
--- a/tensorflow/core/kernels/gpu_utils.h
+++ b/tensorflow/core/kernels/gpu_utils.h
@@ -123,8 +123,7 @@ class AutoTuneMap {
string GetActionSummary(StringPiece action, const Parameters& params,
const Config& config) {
return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
- std::string(action).c_str(),
- params.ToString().c_str(),
+ string(action).c_str(), params.ToString().c_str(),
config.ToString().c_str());
}
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 066a1d603b..72581c9293 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -374,7 +374,12 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
y->tensors.reserve(x.tensors.size());
for (const Tensor& t : x.tensors) {
Tensor out_tensor;
- TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor));
+ AllocatorAttributes attr;
+ if (t.dtype() == DT_VARIANT) {
+ attr.set_on_host(true);
+ }
+ TF_RETURN_IF_ERROR(
+ c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr));
switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype) \
case DataTypeToEnum<dtype>::value: \
@@ -385,6 +390,20 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
TF_CALL_POD_TYPES(DTYPE_CASE)
#undef DTYPE_CASE
+
+ case DataTypeToEnum<Variant>::value: {
+ const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>();
+ if (inner_x == nullptr) {
+ return errors::InvalidArgument("Input handle is not a list. Saw: '",
+ t.scalar<Variant>()().DebugString(),
+ "'");
+ }
+ TensorList inner_y;
+ TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y));
+ out_tensor.scalar<Variant>()() = std::move(inner_y);
+ break;
+ }
+
default:
return errors::InvalidArgument(
"Trying to compute zeros_like for unsupported dtype ",
diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
index 10e468ce46..693ed8a8f0 100644
--- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
+++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
@@ -114,9 +114,7 @@ class MergeV2CheckpointsOpTest : public OpsTestBase {
// Exercises "delete_old_dirs".
for (int i = 0; i < 2; ++i) {
int directory_found =
- Env::Default()
- ->IsDirectory(std::string(io::Dirname(prefixes[i])))
- .code();
+ Env::Default()->IsDirectory(string(io::Dirname(prefixes[i]))).code();
if (delete_old_dirs) {
EXPECT_EQ(error::NOT_FOUND, directory_found);
} else {
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index 194a711d98..26f107f940 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -47,7 +47,7 @@ std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
std::unordered_set<string> retval;
for (const string& node_name_and_port : node_names_and_ports) {
const TensorId tid = ParseTensorName(node_name_and_port);
- retval.emplace(std::string(tid.first));
+ retval.emplace(tid.first);
}
return retval;
}
@@ -64,7 +64,7 @@ Node* FindMutableNodeByName(const string& name, Graph* graph) {
const NodeDef* FindNodeDefByName(const string& input,
const GraphDef& graph_def) {
const TensorId tid = ParseTensorName(input);
- const string name = std::string(tid.first);
+ const string name = string(tid.first);
for (const NodeDef& node_def : graph_def.node()) {
if (node_def.name() == name) {
return &node_def;
@@ -423,7 +423,7 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
std::vector<DataType> data_types;
std::vector<TensorShape> shapes;
const TensorId tid = ParseTensorName(name_and_port);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
const int port = tid.second;
const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
CHECK_NOTNULL(node_def);
@@ -522,8 +522,7 @@ RemoteFusedGraphExecuteUtils::GetTensorShapeType(
const TensorShapeMap& tensor_shape_map, const string& node_name) {
if (node_name.find(':') != string::npos) {
const TensorId tid = ParseTensorName(node_name);
- return GetTensorShapeType(tensor_shape_map, std::string(tid.first),
- tid.second);
+ return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
} else {
return GetTensorShapeType(tensor_shape_map, node_name, 0);
}
@@ -570,7 +569,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
const TensorId tid = ParseTensorName(name);
CHECK_EQ(tensor_shape_map->count(name), 0);
tensor_shape_map->emplace(
- std::string(tid.first),
+ string(tid.first),
std::make_pair(tid.second,
std::make_pair(tensor.dtype(), tensor.shape())));
}
@@ -692,7 +691,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
std::vector<NodeBuilder::NodeOut> node_out_list;
for (const string& input : inputs) {
const TensorId tid = ParseTensorName(input);
- Node* node = FindMutableNodeByName(std::string(tid.first), graph);
+ Node* node = FindMutableNodeByName(string(tid.first), graph);
CHECK_NOTNULL(node);
node_out_list.emplace_back(node, tid.second);
}
@@ -848,7 +847,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
for (const string& subgraph_input : std::get<1>(cluster)) {
const TensorId tid = ParseTensorName(subgraph_input);
- const string subgraph_input_name = std::string(tid.first);
+ const string subgraph_input_name(tid.first);
const int subgraph_input_port = tid.second;
const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
CHECK_NOTNULL(node_def);
@@ -895,7 +894,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
std::deque<const Node*> queue;
for (const string& output : border_outputs) {
const TensorId tid = ParseTensorName(output);
- const string& output_node_name = std::string(tid.first);
+ const string output_node_name(tid.first);
for (const Node* node : graph.nodes()) {
if (output_node_name == node->name()) {
queue.push_back(node);
@@ -975,7 +974,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
for (int j = 0; j < border_outputs.size(); ++j) {
const string& output = border_outputs.at(j);
const TensorId tid = ParseTensorName(output);
- const string output_name = std::string(tid.first);
+ const string output_name(tid.first);
Node* src_node = edge->src();
if (src_node != nullptr && src_node->name() == output_name &&
edge->src_output() == tid.second) {
@@ -995,12 +994,11 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
// RemoteFusedGraphExecuteOpNode
for (const string& output : outputs) {
const TensorId output_tid = ParseTensorName(output);
- const string output_name = std::string(output_tid.first);
+ const string output_name(output_tid.first);
for (size_t i = 0; i < border_outputs.size(); ++i) {
const TensorId subgraph_output_tid =
ParseTensorName(border_outputs.at(i));
- const string& subgraph_output_name =
- std::string(subgraph_output_tid.first);
+ const string subgraph_output_name(subgraph_output_tid.first);
if (output_name == subgraph_output_name) {
LOG(INFO) << "As graph output and subgraph output are same, "
<< "the graph output node is replaced by identity node";
@@ -1435,7 +1433,7 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
GraphDef* graph_def) {
const TensorId tid = ParseTensorName(input);
CHECK_EQ(0, tid.second);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
for (NodeDef& node : *graph_def->mutable_node()) {
if (node.name() != node_name) {
continue;
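The repeated edits in this file (and in the surrounding kernel and platform changes) all replace std::string(...) conversions of StringPiece values with the tensorflow string alias, constructing the owned copy explicitly at the point where it is needed. A small illustration of the pattern with std::string_view standing in for StringPiece and a toy ParseTensorName (hypothetical, not the TensorFlow helper):

#include <string>
#include <string_view>
#include <utility>

std::pair<std::string_view, int> ParseTensorName(std::string_view s) {
  const size_t pos = s.rfind(':');
  if (pos == std::string_view::npos) return {s, 0};
  return {s.substr(0, pos), std::stoi(std::string(s.substr(pos + 1)))};
}

int main() {
  const auto tid = ParseTensorName("conv1/weights:1");
  const std::string node_name(tid.first);  // explicit owned copy of the view
  return node_name == "conv1/weights" && tid.second == 1 ? 0 : 1;
}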
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index e335e38bdc..82546d581a 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -161,9 +161,12 @@ void RestoreTensor(OpKernelContext* context,
// If we cannot find a cached reader we will allocate our own.
std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
- const checkpoint::TensorSliceReader* reader =
- context->slice_reader_cache()->GetReader(file_pattern, open_func,
- preferred_shard);
+ const checkpoint::TensorSliceReader* reader = nullptr;
+
+ if (context->slice_reader_cache()) {
+ reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
+ preferred_shard);
+ }
if (!reader) {
allocated_reader.reset(new checkpoint::TensorSliceReader(
file_pattern, open_func, preferred_shard));
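RestoreTensor now tolerates a null slice_reader_cache(): it only asks the cache for a reader when the cache exists, and otherwise falls back to the locally allocated reader it already supported. A stand-alone sketch of that lookup-then-fallback shape with illustrative types (not the checkpoint::TensorSliceReader API):

#include <memory>
#include <string>

struct Reader { std::string pattern; };

struct ReaderCache {
  Reader* GetReader(const std::string& pattern) {
    cached_.pattern = pattern;
    return &cached_;
  }
  Reader cached_;
};

const Reader* AcquireReader(ReaderCache* cache, const std::string& pattern,
                            std::unique_ptr<Reader>* owned) {
  const Reader* reader = nullptr;
  if (cache) {                    // the cache may legitimately be absent
    reader = cache->GetReader(pattern);
  }
  if (!reader) {                  // fall back to a reader we own ourselves
    owned->reset(new Reader{pattern});
    reader = owned->get();
  }
  return reader;
}

int main() {
  std::unique_ptr<Reader> owned;
  return AcquireReader(nullptr, "ckpt-*", &owned) != nullptr ? 0 : 1;
}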
diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index ab4de6c815..180eb3ca34 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -220,9 +220,9 @@ class MergeV2Checkpoints : public OpKernel {
context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
if (delete_old_dirs_) {
- const string& merged_dir = std::string(io::Dirname(merged_prefix));
+ const string merged_dir(io::Dirname(merged_prefix));
for (const string& input_prefix : input_prefixes) {
- const string& dirname = std::string(io::Dirname(input_prefix));
+ const string dirname(io::Dirname(input_prefix));
if (dirname == merged_dir) continue;
Status status = env->DeleteDir(dirname);
// For sharded save, only the first delete will go through and all
diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc
index 2aeafa28c4..544dca96ba 100644
--- a/tensorflow/core/kernels/string_strip_op.cc
+++ b/tensorflow/core/kernels/string_strip_op.cc
@@ -43,7 +43,7 @@ class StringStripOp : public OpKernel {
for (int64 i = 0; i < input.size(); ++i) {
StringPiece entry(input(i));
str_util::RemoveWhitespaceContext(&entry);
- output(i) = std::string(entry);
+ output(i) = string(entry);
}
}
};
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 632b65e9b6..2ec2651c04 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -297,7 +297,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
resource.name());
}
tensor_array_name =
- std::string(StringPiece(resource.name()).substr(container.size()));
+ string(StringPiece(resource.name()).substr(container.size()));
}
auto output_handle = tensor_array_output_handle->flat<string>();
diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc
index ed2bf3e8e2..1bf46b5e46 100644
--- a/tensorflow/core/kernels/whole_file_read_ops.cc
+++ b/tensorflow/core/kernels/whole_file_read_ops.cc
@@ -134,7 +134,7 @@ class WriteFileOp : public OpKernel {
"Contents tensor must be scalar, but had shape: ",
contents_input->shape().DebugString()));
const string& filename = filename_input->scalar<string>()();
- const string dir = std::string(io::Dirname(filename));
+ const string dir(io::Dirname(filename));
if (!context->env()->FileExists(dir).ok()) {
OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir));
}
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index 982901a39c..d5cbe6c616 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -136,11 +136,9 @@ string FormatNodeNamesForError(const T& names) {
::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s));
});
}
-// TODO(b/113350742): Consolidate the two different formats `{{key value}}` and
-// `^^key:value^^` in a follow-on CL.
// LINT.IfChange
inline string FormatColocationNodeForError(const string& name) {
- return strings::StrCat("^^colocation_node:", name, "^^");
+ return strings::StrCat("{{colocation_node ", name, "}}");
}
// LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py)
template <typename T>
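The colocation marker emitted into error messages switches from the ^^colocation_node:name^^ form to {{colocation_node name}}, which is why the TODO about consolidating the two formats can be dropped; the LINT pragma ties the format to the Python-side error interpolation. A tiny check of the new format, with std::string concatenation standing in for strings::StrCat:

#include <cassert>
#include <string>

std::string FormatColocationNodeForError(const std::string& name) {
  // Previously: "^^colocation_node:" + name + "^^"
  return "{{colocation_node " + name + "}}";
}

int main() {
  assert(FormatColocationNodeForError("var0") == "{{colocation_node var0}}");
  return 0;
}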
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
index c18dc9ad1a..2d622dc229 100644
--- a/tensorflow/core/lib/gtl/inlined_vector.h
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -13,674 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage
-// for sequences of length <= N are provided inline without requiring
-// any heap allocation. Typically N is very small (e.g., 4) so that
-// sequences that are expected to be short do not require allocations.
-//
-// Only some of the std::vector<> operations are currently implemented.
-// Other operations may be added as needed to facilitate migrating
-// code that uses std::vector<> to InlinedVector<>.
-//
-// NOTE: If you want an inlined version to replace use of a
-// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS>
-// in util/bitmap/inlined_bitvector.h
-//
-// TODO(billydonahue): change size_t to size_type where appropriate.
-
#ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
#define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
-#include <stddef.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/types.h>
-#include <algorithm>
-#include <cstddef>
-#include <iterator>
-#include <memory>
-#include <type_traits>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/manual_constructor.h"
-#include "tensorflow/core/platform/byte_order.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mem.h"
+#include "absl/container/inlined_vector.h"
+// TODO(kramerb): This is kept only because lots of targets transitively depend
+// on it. Remove all targets' dependencies.
+#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
-#include <initializer_list> // NOLINT(build/include_order)
-
namespace tensorflow {
namespace gtl {
-template <typename T, int N>
-class InlinedVector {
- public:
- typedef T value_type;
- typedef T* pointer;
- typedef const T* const_pointer;
- typedef T& reference;
- typedef const T& const_reference;
- typedef size_t size_type;
- typedef std::ptrdiff_t difference_type;
- typedef pointer iterator;
- typedef const_pointer const_iterator;
-
- // Create an empty vector
- InlinedVector();
-
- // Create a vector with n copies of value_type().
- explicit InlinedVector(size_t n);
-
- // Create a vector with n copies of elem
- InlinedVector(size_t n, const value_type& elem);
-
- // Create and initialize with the elements [range_start .. range_end).
- // The unused enable_if argument restricts this constructor so that it is
- // elided when value_type is an integral type. This prevents ambiguous
- // interpretation between a call to this constructor with two integral
- // arguments and a call to the preceding (n, elem) constructor.
- template <typename InputIterator>
- InlinedVector(
- InputIterator range_start, InputIterator range_end,
- typename std::enable_if<!std::is_integral<InputIterator>::value>::type* =
- NULL) {
- InitRep();
- AppendRange(range_start, range_end);
- }
-
- InlinedVector(std::initializer_list<value_type> init) {
- InitRep();
- AppendRange(init.begin(), init.end());
- }
-
- InlinedVector(const InlinedVector& v);
-
- ~InlinedVector() { clear(); }
-
- InlinedVector& operator=(const InlinedVector& v) {
- // Optimized to avoid reallocation.
- // Prefer reassignment to copy construction for elements.
- const size_t s = size();
- const size_t vs = v.size();
- if (s < vs) { // grow
- reserve(vs);
- if (s) std::copy(v.begin(), v.begin() + s, begin());
- std::copy(v.begin() + s, v.end(), std::back_inserter(*this));
- } else { // maybe shrink
- erase(begin() + vs, end());
- std::copy(v.begin(), v.end(), begin());
- }
- return *this;
- }
-
- size_t size() const { return size_internal(); }
-
- bool empty() const { return (size() == 0); }
-
- // Return number of elements that can be stored in vector
- // without requiring a reallocation of underlying memory
- size_t capacity() const {
- if (is_inline()) {
- return kFit;
- } else {
- return static_cast<size_t>(1) << u_.data[kSize - 2];
- }
- }
-
- // Return a pointer to the underlying array.
- // Only result[0,size()-1] are defined.
- pointer data() {
- if (is_inline()) {
- return reinterpret_cast<T*>(u_.data);
- } else {
- return outofline_pointer();
- }
- }
- const_pointer data() const {
- return const_cast<InlinedVector<T, N>*>(this)->data();
- }
-
- // Remove all elements
- void clear() {
- DiscardStorage();
- u_.data[kSize - 1] = 0;
- }
-
- // Return the ith element
- // REQUIRES: 0 <= i < size()
- const value_type& at(size_t i) const {
- DCHECK_LT(i, size());
- return data()[i];
- }
- const value_type& operator[](size_t i) const {
- DCHECK_LT(i, size());
- return data()[i];
- }
-
- // Return a non-const reference to the ith element
- // REQUIRES: 0 <= i < size()
- value_type& at(size_t i) {
- DCHECK_LT(i, size());
- return data()[i];
- }
- value_type& operator[](size_t i) {
- DCHECK_LT(i, size());
- return data()[i];
- }
-
- value_type& back() {
- DCHECK(!empty());
- return at(size() - 1);
- }
-
- const value_type& back() const {
- DCHECK(!empty());
- return at(size() - 1);
- }
-
- value_type& front() {
- DCHECK(!empty());
- return at(0);
- }
-
- const value_type& front() const {
- DCHECK(!empty());
- return at(0);
- }
-
- // Append a T constructed with args to the vector.
- // Increases size() by one.
- // Amortized complexity: O(1)
- // Worst-case complexity: O(size())
- template <typename... Args>
- void emplace_back(Args&&... args) {
- size_t s = size();
- DCHECK_LE(s, capacity());
- if (s < capacity()) {
- new (data() + s) T(std::forward<Args>(args)...);
- set_size_internal(s + 1);
- } else {
- EmplaceBackSlow(std::forward<Args>(args)...);
- }
- }
-
- // Append t to the vector.
- // Increases size() by one.
- // Amortized complexity: O(1)
- // Worst-case complexity: O(size())
- void push_back(const value_type& t) { emplace_back(t); }
- void push_back(value_type&& t) { emplace_back(std::move(t)); }
-
- inline void pop_back() {
- DCHECK(!empty());
- const size_t s = size();
- Destroy(data() + s - 1, 1);
- set_size_internal(s - 1);
- }
-
- // Resizes the vector to contain "n" elements.
- // If "n" is smaller than the initial size, extra elements are destroyed.
- // If "n" is larger than the initial size, enough copies of "elem"
- // are appended to increase the size to "n". If "elem" is omitted,
- // new elements are value-initialized.
- void resize(size_t n) { Resize<ValueInit>(n, nullptr); }
- void resize(size_t n, const value_type& elem) { Resize<Fill>(n, &elem); }
-
- iterator begin() { return data(); }
- const_iterator begin() const { return data(); }
-
- iterator end() { return data() + size(); }
- const_iterator end() const { return data() + size(); }
-
- iterator insert(iterator pos, const value_type& v);
-
- iterator erase(iterator pos) {
- DCHECK_LT(pos, end());
- DCHECK_GE(pos, begin());
- std::copy(pos + 1, end(), pos);
- pop_back();
- return pos;
- }
-
- iterator erase(iterator first, iterator last);
-
- // Enlarges the underlying representation so it can hold at least
- // "n" elements without reallocation.
- // Does not change size() or the actual contents of the vector.
- void reserve(size_t n) {
- if (n > capacity()) {
- // Make room for new elements
- Grow<Move>(n);
- }
- }
-
- // Swap the contents of *this with other.
- // REQUIRES: value_type is swappable and copyable.
- void swap(InlinedVector& other);
-
- private:
- // Representation can either be inlined or out-of-line.
- // In either case, at least sizeof(void*) + 8 bytes are available.
- //
- // Inlined:
- // Last byte holds the length.
- // First (length*sizeof(T)) bytes stores the elements.
- // Outlined:
- // Last byte holds kSentinel.
- // Second-last byte holds lg(capacity)
- // Preceding 6 bytes hold size.
- // First sizeof(T*) bytes hold pointer.
-
- // Compute rep size.
- static const size_t kSizeUnaligned = N * sizeof(T) + 1; // Room for tag
- static const size_t kSize = ((kSizeUnaligned + 15) / 16) * 16; // Align
-
- // See how many fit T we can fit inside kSize, but no more than 254
- // since 255 is used as sentinel tag for out-of-line allocation.
- static const unsigned int kSentinel = 255;
- static const size_t kFit1 = (kSize - 1) / sizeof(T);
- static const size_t kFit = (kFit1 >= kSentinel) ? (kSentinel - 1) : kFit1;
-
- union {
- unsigned char data[kSize];
- // Force data to be aligned enough for a pointer.
- T* unused_aligner;
- } u_;
-
- inline void InitRep() { u_.data[kSize - 1] = 0; }
- inline bool is_inline() const { return u_.data[kSize - 1] != kSentinel; }
-
- inline T* outofline_pointer() const {
- T* ptr;
- memcpy(&ptr, &u_.data[0], sizeof(ptr));
- return ptr;
- }
-
- inline void set_outofline_pointer(T* p) {
- memcpy(&u_.data[0], &p, sizeof(p));
- }
-
- inline uint64_t outofline_word() const {
- uint64_t word;
- memcpy(&word, &u_.data[kSize - 8], sizeof(word));
- return word;
- }
-
- inline void set_outofline_word(uint64_t w) {
- memcpy(&u_.data[kSize - 8], &w, sizeof(w));
- }
-
- inline size_t size_internal() const {
- uint8_t s = static_cast<uint8_t>(u_.data[kSize - 1]);
- if (s != kSentinel) {
- return static_cast<size_t>(s);
- } else {
- const uint64_t word = outofline_word();
- if (port::kLittleEndian) {
- // The sentinel and capacity bits are most-significant bits in word.
- return static_cast<size_t>(word & 0xffffffffffffull);
- } else {
- // The sentinel and capacity bits are least-significant bits in word.
- return static_cast<size_t>(word >> 16);
- }
- }
- }
-
- void set_size_internal(size_t n) {
- if (is_inline()) {
- DCHECK_LT(n, kSentinel);
- u_.data[kSize - 1] = static_cast<unsigned char>(n);
- } else {
- uint64_t word;
- if (port::kLittleEndian) {
- // The sentinel and capacity bits are most-significant bits in word.
- word = (static_cast<uint64_t>(n) |
- (static_cast<uint64_t>(u_.data[kSize - 2]) << 48) |
- (static_cast<uint64_t>(kSentinel) << 56));
- } else {
- // The sentinel and capacity bits are least-significant bits in word.
- word = ((static_cast<uint64_t>(n) << 16) |
- (static_cast<uint64_t>(u_.data[kSize - 2]) << 8) |
- (static_cast<uint64_t>(kSentinel)));
- }
- set_outofline_word(word);
- DCHECK_EQ(u_.data[kSize - 1], kSentinel) << n;
- }
- }
-
- void DiscardStorage() {
- T* base = data();
- size_t n = size();
- Destroy(base, n);
- if (!is_inline()) {
- port::Free(base);
- }
- }
-
- template <typename... Args>
- void EmplaceBackSlow(Args&&... args) {
- const size_t s = size();
- DCHECK_EQ(s, capacity());
- Grow<Move, Construct>(s + 1, std::forward<Args>(args)...);
- set_size_internal(s + 1);
- }
-
- // Movers for Grow
- // Does nothing.
- static void Nop(T* src, size_t n, T* dst) {}
-
- // Moves srcs[0,n-1] contents to dst[0,n-1].
- static void Move(T* src, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T(std::move(*(src + i)));
- }
- }
-
- // Initializers for Resize.
- // Initializes dst[0,n-1] with empty constructor.
- static void ValueInit(const T*, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T();
- }
- }
-
- // Initializes dst[0,n-1] with copies of *src.
- static void Fill(const T* src, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T(*src);
- }
- }
-
- void Destroy(T* src, int n) {
- if (!std::is_trivially_destructible<T>::value) {
- for (int i = 0; i < n; i++) {
- (src + i)->~T();
- }
- }
- }
-
- // Initialization methods for Grow.
- // 1) Leave uninitialized memory.
- struct Uninitialized {
- void operator()(T*) const {}
- };
- // 2) Construct a T with args at not-yet-initialized memory pointed by dst.
- struct Construct {
- template <class... Args>
- void operator()(T* dst, Args&&... args) const {
- new (dst) T(std::forward<Args>(args)...);
- }
- };
-
- // Grow so that capacity >= n. Uses Mover to move existing elements
- // to new buffer, and possibly initialize the new element according
- // to InitType.
- // We pass the InitType and Mover as template arguments so that
- // this code compiles even if T does not support copying or default
- // construction.
- template <void(Mover)(T*, size_t, T*), class InitType = Uninitialized,
- class... Args>
- void Grow(size_t n, Args&&... args) {
- size_t s = size();
- DCHECK_LE(s, capacity());
-
- // Compute new capacity by repeatedly doubling current capacity
- size_t target = 1;
- size_t target_lg = 0;
- while (target < kFit || target < n) {
- // TODO(psrc): Check and avoid overflow?
- target_lg++;
- target <<= 1;
- }
-
- T* src = data();
- T* dst = static_cast<T*>(port::Malloc(target * sizeof(T)));
-
- // Need to copy elem before discarding src since it might alias src.
- InitType{}(dst + s, std::forward<Args>(args)...);
- Mover(src, s, dst);
- DiscardStorage();
-
- u_.data[kSize - 1] = kSentinel;
- u_.data[kSize - 2] = static_cast<unsigned char>(target_lg);
- set_size_internal(s);
- DCHECK_EQ(capacity(), target);
- set_outofline_pointer(dst);
- }
-
- // Resize to size n. Any new elements are initialized by passing
- // elem and the destination to Initializer. We pass the Initializer
- // as a template argument so that this code compiles even if T does
- // not support copying.
- template <void(Initializer)(const T*, size_t, T*)>
- void Resize(size_t n, const T* elem) {
- size_t s = size();
- if (n <= s) {
- Destroy(data() + n, s - n);
- set_size_internal(n);
- return;
- }
- reserve(n);
- DCHECK_GE(capacity(), n);
- set_size_internal(n);
- Initializer(elem, n - s, data() + s);
- }
-
- template <typename Iter>
- void AppendRange(Iter first, Iter last, std::input_iterator_tag);
-
- // Faster path for forward iterators.
- template <typename Iter>
- void AppendRange(Iter first, Iter last, std::forward_iterator_tag);
-
- template <typename Iter>
- void AppendRange(Iter first, Iter last);
-};
-
-// Provide linkage for constants.
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSizeUnaligned;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSize;
-template <typename T, int N>
-const unsigned int InlinedVector<T, N>::kSentinel;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit1;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit;
-
-template <typename T, int N>
-inline void swap(InlinedVector<T, N>& a, InlinedVector<T, N>& b) {
- a.swap(b);
-}
-
-template <typename T, int N>
-inline bool operator==(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin());
-}
-
-template <typename T, int N>
-inline bool operator!=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(a == b);
-}
-
-template <typename T, int N>
-inline bool operator<(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
-}
-
-template <typename T, int N>
-inline bool operator>(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return b < a;
-}
-
-template <typename T, int N>
-inline bool operator<=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(b < a);
-}
-
-template <typename T, int N>
-inline bool operator>=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(a < b);
-}
-
-// ========================================
-// Implementation
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector() {
- InitRep();
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n) {
- InitRep();
- if (n > capacity()) {
- Grow<Nop>(n); // Must use Nop in case T is not copyable
- }
- set_size_internal(n);
- ValueInit(nullptr, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n, const value_type& elem) {
- InitRep();
- if (n > capacity()) {
- Grow<Nop>(n); // Can use Nop since we know we have nothing to copy
- }
- set_size_internal(n);
- Fill(&elem, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(const InlinedVector& v) {
- InitRep();
- *this = v;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::insert(
- iterator pos, const value_type& v) {
- DCHECK_GE(pos, begin());
- DCHECK_LE(pos, end());
- if (pos == end()) {
- push_back(v);
- return end() - 1;
- }
- size_t s = size();
- size_t idx = std::distance(begin(), pos);
- if (s == capacity()) {
- Grow<Move>(s + 1);
- }
- CHECK_LT(s, capacity());
- pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator.
- Fill(data() + s - 1, 1, data() + s); // data[s] = data[s-1]
- std::copy_backward(pos, data() + s - 1, data() + s);
- *pos = v;
-
- set_size_internal(s + 1);
- return pos;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::erase(
- iterator first, iterator last) {
- DCHECK_LE(begin(), first);
- DCHECK_LE(first, last);
- DCHECK_LE(last, end());
-
- size_t s = size();
- ptrdiff_t erase_gap = std::distance(first, last);
- std::copy(last, data() + s, first);
- Destroy(data() + s - erase_gap, erase_gap);
- set_size_internal(s - erase_gap);
- return first;
-}
-
-template <typename T, int N>
-void InlinedVector<T, N>::swap(InlinedVector& other) {
- using std::swap; // Augment ADL with std::swap.
- if (&other == this) {
- return;
- }
-
- InlinedVector* a = this;
- InlinedVector* b = &other;
-
- const bool a_inline = a->is_inline();
- const bool b_inline = b->is_inline();
-
- if (!a_inline && !b_inline) {
- // Just swap the top-level representations.
- T* aptr = a->outofline_pointer();
- T* bptr = b->outofline_pointer();
- a->set_outofline_pointer(bptr);
- b->set_outofline_pointer(aptr);
-
- uint64_t aword = a->outofline_word();
- uint64_t bword = b->outofline_word();
- a->set_outofline_word(bword);
- b->set_outofline_word(aword);
- return;
- }
-
- // Make a the larger of the two to reduce number of cases.
- size_t a_size = a->size();
- size_t b_size = b->size();
- if (a->size() < b->size()) {
- swap(a, b);
- swap(a_size, b_size);
- }
- DCHECK_GE(a_size, b_size);
-
- if (b->capacity() < a_size) {
- b->Grow<Move>(a_size);
- }
-
- // One is inline and one is not.
- // 'a' is larger. Swap the elements up to the smaller array size.
- std::swap_ranges(a->data(), a->data() + b_size, b->data());
- std::uninitialized_copy(a->data() + b_size, a->data() + a_size,
- b->data() + b_size);
- Destroy(a->data() + b_size, a_size - b_size);
- a->set_size_internal(b_size);
- b->set_size_internal(a_size);
- DCHECK_EQ(b->size(), a_size);
- DCHECK_EQ(a->size(), b_size);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
- std::input_iterator_tag) {
- std::copy(first, last, std::back_inserter(*this));
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
- std::forward_iterator_tag) {
- typedef typename std::iterator_traits<Iter>::difference_type Length;
- Length length = std::distance(first, last);
- size_t s = size();
- reserve(s + length);
- std::uninitialized_copy_n(first, length, data() + s);
- set_size_internal(s + length);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last) {
- typedef typename std::iterator_traits<Iter>::iterator_category IterTag;
- AppendRange(first, last, IterTag());
-}
+using absl::InlinedVector;
} // namespace gtl
} // namespace tensorflow
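With this change tensorflow::gtl::InlinedVector is no longer a hand-rolled container: the header keeps only the using-declaration onto absl::InlinedVector, so existing call sites compile unchanged against the Abseil implementation. A minimal usage sketch, assuming the Abseil headers are on the include path:

#include "absl/container/inlined_vector.h"

int main() {
  absl::InlinedVector<int, 8> v;  // up to 8 ints live inline, then the heap
  for (int i = 0; i < 10; ++i) v.push_back(i);
  return v.size() == 10 ? 0 : 1;
}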
diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc
deleted file mode 100644
index 2721885c4a..0000000000
--- a/tensorflow/core/lib/gtl/inlined_vector_test.cc
+++ /dev/null
@@ -1,898 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-
-#include <list>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-typedef tensorflow::gtl::InlinedVector<int, 8> IntVec;
-
-// A type that counts number of live occurrences of the type
-static int64 instances = 0;
-class Instance {
- public:
- int value_;
- explicit Instance(int x) : value_(x) { instances++; }
- Instance(const Instance& x) : value_(x.value_) { instances++; }
- ~Instance() { instances--; }
-
- friend inline void swap(Instance& a, Instance& b) {
- using std::swap;
- swap(a.value_, b.value_);
- }
-
- friend std::ostream& operator<<(std::ostream& o, const Instance& v) {
- return o << "[value:" << v.value_ << "]";
- }
-};
-
-typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec;
-
-// A simple reference counted class to make sure that the proper elements are
-// destroyed in the erase(begin, end) test.
-class RefCounted {
- public:
- RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); }
-
- RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) {
- VLOG(5) << "[RefCounted: copy"
- << " from count @" << v.count_ << "]";
- Ref();
- }
-
- ~RefCounted() {
- Unref();
- count_ = nullptr;
- }
-
- friend void swap(RefCounted& a, RefCounted& b) {
- using std::swap;
- swap(a.value_, b.value_);
- swap(a.count_, b.count_);
- }
-
- RefCounted& operator=(RefCounted v) {
- using std::swap;
- swap(*this, v);
- return *this;
- }
-
- void Ref() const {
- CHECK(count_ != nullptr);
- ++(*count_);
- VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]";
- }
-
- void Unref() const {
- --(*count_);
- CHECK_GE(*count_, 0);
- VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]";
- }
-
- int count() const { return *count_; }
-
- friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) {
- return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]";
- }
-
- int value_;
- int* count_;
-};
-
-typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec;
-
-// A class with a vtable pointer
-class Dynamic {
- public:
- virtual ~Dynamic() {}
-
- friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) {
- return o << "[Dynamic]";
- }
-};
-
-typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec;
-
-// Append 0..len-1 to *v
-static void Fill(IntVec* v, int len, int offset = 0) {
- for (int i = 0; i < len; i++) {
- v->push_back(i + offset);
- }
-}
-
-static IntVec Fill(int len, int offset = 0) {
- IntVec v;
- Fill(&v, len, offset);
- return v;
-}
-
-TEST(IntVec, SimpleOps) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- const IntVec& cv = v; // const alias
-
- Fill(&v, len);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(i, v[i]);
- }
- EXPECT_EQ(v.begin(), v.data());
- EXPECT_EQ(cv.begin(), cv.data());
-
- int counter = 0;
- for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) {
- EXPECT_EQ(counter, *iter);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- counter = 0;
- for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) {
- EXPECT_EQ(counter, *iter);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- EXPECT_EQ(0, v.front());
- EXPECT_EQ(len - 1, v.back());
- v.pop_back();
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i, v[i]);
- }
- }
- }
-}
-
-TEST(IntVec, Erase) {
- for (int len = 1; len < 20; len++) {
- for (int i = 0; i < len; ++i) {
- IntVec v;
- Fill(&v, len);
- v.erase(v.begin() + i);
- EXPECT_EQ(len - 1, v.size());
- for (int j = 0; j < i; ++j) {
- EXPECT_EQ(j, v[j]);
- }
- for (int j = i; j < len - 1; ++j) {
- EXPECT_EQ(j + 1, v[j]);
- }
- }
- }
-}
-
-// At the end of this test loop, the elements between [erase_begin, erase_end)
-// should have reference counts == 0, and all others elements should have
-// reference counts == 1.
-TEST(RefCountedVec, EraseBeginEnd) {
- for (int len = 1; len < 20; ++len) {
- for (int erase_begin = 0; erase_begin < len; ++erase_begin) {
- for (int erase_end = erase_begin; erase_end <= len; ++erase_end) {
- std::vector<int> counts(len, 0);
- RefCountedVec v;
- for (int i = 0; i < len; ++i) {
- v.push_back(RefCounted(i, &counts[i]));
- }
-
- int erase_len = erase_end - erase_begin;
-
- v.erase(v.begin() + erase_begin, v.begin() + erase_end);
-
- EXPECT_EQ(len - erase_len, v.size());
-
- // Check the elements before the first element erased.
- for (int i = 0; i < erase_begin; ++i) {
- EXPECT_EQ(i, v[i].value_);
- }
-
- // Check the elements after the first element erased.
- for (size_t i = erase_begin; i < v.size(); ++i) {
- EXPECT_EQ(i + erase_len, v[i].value_);
- }
-
- // Check that the elements at the beginning are preserved.
- for (int i = 0; i < erase_begin; ++i) {
- EXPECT_EQ(1, counts[i]);
- }
-
- // Check that the erased elements are destroyed
- for (int i = erase_begin; i < erase_end; ++i) {
- EXPECT_EQ(0, counts[i]);
- }
-
- // Check that the elements at the end are preserved.
- for (int i = erase_end; i < len; ++i) {
- EXPECT_EQ(1, counts[i]);
- }
- }
- }
- }
-}
-
-struct NoDefaultCtor {
- explicit NoDefaultCtor(int) {}
-};
-struct NoCopy {
- NoCopy() {}
- NoCopy(const NoCopy&) = delete;
-};
-struct NoAssign {
- NoAssign() {}
- NoAssign& operator=(const NoAssign&) = delete;
-};
-struct MoveOnly {
- MoveOnly() {}
- MoveOnly(MoveOnly&&) = default;
- MoveOnly& operator=(MoveOnly&&) = default;
-};
-TEST(InlinedVectorTest, NoDefaultCtor) {
- tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2));
- (void)v;
-}
-TEST(InlinedVectorTest, NoCopy) {
- tensorflow::gtl::InlinedVector<NoCopy, 1> v(10);
- (void)v;
-}
-TEST(InlinedVectorTest, NoAssign) {
- tensorflow::gtl::InlinedVector<NoAssign, 1> v(10);
- (void)v;
-}
-TEST(InlinedVectorTest, MoveOnly) {
- gtl::InlinedVector<MoveOnly, 2> v;
- v.push_back(MoveOnly{});
- v.push_back(MoveOnly{});
- v.push_back(MoveOnly{});
-}
-
-TEST(IntVec, Insert) {
- for (int len = 0; len < 20; len++) {
- for (int pos = 0; pos <= len; pos++) {
- IntVec v;
- Fill(&v, len);
- v.insert(v.begin() + pos, 9999);
- EXPECT_EQ(v.size(), len + 1);
- for (int i = 0; i < pos; i++) {
- EXPECT_EQ(v[i], i);
- }
- EXPECT_EQ(v[pos], 9999);
- for (size_t i = pos + 1; i < v.size(); i++) {
- EXPECT_EQ(v[i], i - 1);
- }
- }
- }
-}
-
-TEST(RefCountedVec, InsertConstructorDestructor) {
- // Make sure the proper construction/destruction happen during insert
- // operations.
- for (int len = 0; len < 20; len++) {
- SCOPED_TRACE(len);
- for (int pos = 0; pos <= len; pos++) {
- SCOPED_TRACE(pos);
- std::vector<int> counts(len, 0);
- int inserted_count = 0;
- RefCountedVec v;
- for (int i = 0; i < len; ++i) {
- SCOPED_TRACE(i);
- v.push_back(RefCounted(i, &counts[i]));
- }
-
- for (auto elem : counts) {
- EXPECT_EQ(1, elem);
- }
-
- RefCounted insert_element(9999, &inserted_count);
- EXPECT_EQ(1, inserted_count);
- v.insert(v.begin() + pos, insert_element);
- EXPECT_EQ(2, inserted_count);
- // Check that the elements at the end are preserved.
- for (auto elem : counts) {
- EXPECT_EQ(1, elem);
- }
- EXPECT_EQ(2, inserted_count);
- }
- }
-}
-
-TEST(IntVec, Resize) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
-
- // Try resizing up and down by k elements
- static const int kResizeElem = 1000000;
- for (int k = 0; k < 10; k++) {
- // Enlarging resize
- v.resize(len + k, kResizeElem);
- EXPECT_EQ(len + k, v.size());
- EXPECT_LE(len + k, v.capacity());
- for (int i = 0; i < len + k; i++) {
- if (i < len) {
- EXPECT_EQ(i, v[i]);
- } else {
- EXPECT_EQ(kResizeElem, v[i]);
- }
- }
-
- // Shrinking resize
- v.resize(len, kResizeElem);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(i, v[i]);
- }
- }
- }
-}
-
-TEST(IntVec, InitWithLength) {
- for (int len = 0; len < 20; len++) {
- IntVec v(len, 7);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(7, v[i]);
- }
- }
-}
-
-TEST(IntVec, CopyConstructorAndAssignment) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
-
- IntVec v2(v);
- EXPECT_EQ(v, v2);
-
- for (int start_len = 0; start_len < 20; start_len++) {
- IntVec v3;
- Fill(&v3, start_len, 99); // Add dummy elements that should go away
- v3 = v;
- EXPECT_EQ(v, v3);
- }
- }
-}
-
-TEST(OverheadTest, Storage) {
- // Check for size overhead.
- using tensorflow::gtl::InlinedVector;
- EXPECT_EQ(2 * sizeof(int*), sizeof(InlinedVector<int*, 1>));
- EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 2>));
- EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 3>));
- EXPECT_EQ(6 * sizeof(int*), sizeof(InlinedVector<int*, 4>));
-
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 1>));
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 2>));
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 3>));
- EXPECT_EQ(2 * sizeof(char*),
- sizeof(InlinedVector<char, 2 * sizeof(char*) - 1>));
- EXPECT_EQ(4 * sizeof(char*), sizeof(InlinedVector<char, 2 * sizeof(char*)>));
-}
-
-TEST(IntVec, Clear) {
- for (int len = 0; len < 20; len++) {
- SCOPED_TRACE(len);
- IntVec v;
- Fill(&v, len);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(IntVec, Reserve) {
- for (size_t len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
-
- for (size_t newlen = 0; newlen < 100; newlen++) {
- const int* start_rep = v.data();
- v.reserve(newlen);
- const int* final_rep = v.data();
- if (newlen <= len) {
- EXPECT_EQ(start_rep, final_rep);
- }
- EXPECT_LE(newlen, v.capacity());
-
- // Filling up to newlen should not change rep
- while (v.size() < newlen) {
- v.push_back(0);
- }
- EXPECT_EQ(final_rep, v.data());
- }
- }
-}
-
-template <typename T>
-static std::vector<typename T::value_type> Vec(const T& src) {
- std::vector<typename T::value_type> result;
- for (const auto& elem : src) {
- result.push_back(elem);
- }
- return result;
-}
-
-TEST(IntVec, SelfRefPushBack) {
- std::vector<string> std_v;
- tensorflow::gtl::InlinedVector<string, 4> v;
- const string s = "A quite long string to ensure heap.";
- std_v.push_back(s);
- v.push_back(s);
- for (int i = 0; i < 20; ++i) {
- EXPECT_EQ(std_v, Vec(v));
-
- v.push_back(v.back());
- std_v.push_back(std_v.back());
- }
- EXPECT_EQ(std_v, Vec(v));
-}
-
-TEST(IntVec, SelfRefPushBackWithMove) {
- std::vector<string> std_v;
- gtl::InlinedVector<string, 4> v;
- const string s = "A quite long string to ensure heap.";
- std_v.push_back(s);
- v.push_back(s);
- for (int i = 0; i < 20; ++i) {
- EXPECT_EQ(v.back(), std_v.back());
-
- v.push_back(std::move(v.back()));
- std_v.push_back(std::move(std_v.back()));
- }
- EXPECT_EQ(v.back(), std_v.back());
-}
-
-TEST(IntVec, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- SCOPED_TRACE(l1);
- for (int l2 = 0; l2 < 20; l2++) {
- SCOPED_TRACE(l2);
- IntVec a = Fill(l1, 0);
- IntVec b = Fill(l2, 100);
- {
- using std::swap;
- swap(a, b);
- }
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- SCOPED_TRACE(i);
- EXPECT_EQ(i, b[i]);
- }
- for (int i = 0; i < l2; i++) {
- SCOPED_TRACE(i);
- EXPECT_EQ(100 + i, a[i]);
- }
- }
- }
-}
-
-TEST(InstanceVec, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- InstanceVec a, b;
- for (int i = 0; i < l1; i++) a.push_back(Instance(i));
- for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i));
- EXPECT_EQ(l1 + l2, instances);
- {
- using std::swap;
- swap(a, b);
- }
- EXPECT_EQ(l1 + l2, instances);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(i, b[i].value_);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(100 + i, a[i].value_);
- }
- }
- }
-}
-
-TEST(IntVec, EqualAndNotEqual) {
- IntVec a, b;
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- a.push_back(3);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- b.push_back(3);
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- b.push_back(7);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- a.push_back(6);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- a.clear();
- b.clear();
- for (int i = 0; i < 100; i++) {
- a.push_back(i);
- b.push_back(i);
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- b[i] = b[i] + 1;
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- b[i] = b[i] - 1; // Back to before
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
- }
-}
-
-TEST(IntVec, RelationalOps) {
- IntVec a, b;
- EXPECT_FALSE(a < b);
- EXPECT_FALSE(b < a);
- EXPECT_FALSE(a > b);
- EXPECT_FALSE(b > a);
- EXPECT_TRUE(a <= b);
- EXPECT_TRUE(b <= a);
- EXPECT_TRUE(a >= b);
- EXPECT_TRUE(b >= a);
- b.push_back(3);
- EXPECT_TRUE(a < b);
- EXPECT_FALSE(b < a);
- EXPECT_FALSE(a > b);
- EXPECT_TRUE(b > a);
- EXPECT_TRUE(a <= b);
- EXPECT_FALSE(b <= a);
- EXPECT_FALSE(a >= b);
- EXPECT_TRUE(b >= a);
-}
-
-TEST(InstanceVec, CountConstructorsDestructors) {
- const int start = instances;
- for (int len = 0; len < 20; len++) {
- InstanceVec v;
- for (int i = 0; i < len; i++) {
- v.push_back(Instance(i));
- }
- EXPECT_EQ(start + len, instances);
-
- { // Copy constructor should create 'len' more instances.
- InstanceVec v_copy(v);
- EXPECT_EQ(start + len + len, instances);
- }
- EXPECT_EQ(start + len, instances);
-
- // Enlarging resize() must construct some objects
- v.resize(len + 10, Instance(100));
- EXPECT_EQ(start + len + 10, instances);
-
- // Shrinking resize() must destroy some objects
- v.resize(len, Instance(100));
- EXPECT_EQ(start + len, instances);
-
- // reserve() must not increase the number of initialized objects
- v.reserve(len + 1000);
- EXPECT_EQ(start + len, instances);
-
- // pop_back() and erase() must destroy one object
- if (len > 0) {
- v.pop_back();
- EXPECT_EQ(start + len - 1, instances);
- if (!v.empty()) {
- v.erase(v.begin());
- EXPECT_EQ(start + len - 2, instances);
- }
- }
- }
- EXPECT_EQ(start, instances);
-}
-
-TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) {
- const int start = instances;
- for (int len = 0; len < 20; len++) {
- for (int longorshort = 0; longorshort <= 1; ++longorshort) {
- InstanceVec longer, shorter;
- for (int i = 0; i < len; i++) {
- longer.push_back(Instance(i));
- shorter.push_back(Instance(i));
- }
- longer.push_back(Instance(len));
- EXPECT_EQ(start + len + len + 1, instances);
-
- if (longorshort) {
- shorter = longer;
- EXPECT_EQ(start + (len + 1) + (len + 1), instances);
- } else {
- longer = shorter;
- EXPECT_EQ(start + len + len, instances);
- }
- }
- }
- EXPECT_EQ(start, instances);
-}
-
-TEST(RangedConstructor, SimpleType) {
- std::vector<int> source_v = {4, 5, 6, 7};
- // First try to fit in inline backing
- tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end());
- tensorflow::gtl::InlinedVector<int, 4> empty4;
- EXPECT_EQ(4, v.size());
- EXPECT_EQ(empty4.capacity(), v.capacity()); // Must still be inline
- EXPECT_EQ(4, v[0]);
- EXPECT_EQ(5, v[1]);
- EXPECT_EQ(6, v[2]);
- EXPECT_EQ(7, v[3]);
-
- // Now, force a re-allocate
- tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(),
- source_v.end());
- tensorflow::gtl::InlinedVector<int, 2> empty2;
- EXPECT_EQ(4, realloc_v.size());
- EXPECT_LT(empty2.capacity(), realloc_v.capacity());
- EXPECT_EQ(4, realloc_v[0]);
- EXPECT_EQ(5, realloc_v[1]);
- EXPECT_EQ(6, realloc_v[2]);
- EXPECT_EQ(7, realloc_v[3]);
-}
-
-TEST(RangedConstructor, ComplexType) {
- // We also use a list here to pass a different flavor of iterator (e.g. not
- // random-access).
- std::list<Instance> source_v = {Instance(0)};
-
- // First try to fit in inline backing
- tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(),
- source_v.end());
- tensorflow::gtl::InlinedVector<Instance, 1> empty1;
- EXPECT_EQ(1, v.size());
- EXPECT_EQ(empty1.capacity(), v.capacity()); // Must still be inline
- EXPECT_EQ(0, v[0].value_);
-
- std::list<Instance> source_v2 = {Instance(0), Instance(1), Instance(2),
- Instance(3)};
- // Now, force a re-allocate
- tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(),
- source_v2.end());
- EXPECT_EQ(4, realloc_v.size());
- EXPECT_LT(empty1.capacity(), realloc_v.capacity());
- EXPECT_EQ(0, realloc_v[0].value_);
- EXPECT_EQ(1, realloc_v[1].value_);
- EXPECT_EQ(2, realloc_v[2].value_);
- EXPECT_EQ(3, realloc_v[3].value_);
-}
-
-TEST(RangedConstructor, ElementsAreConstructed) {
- std::vector<string> source_v = {"cat", "dog"};
-
- // Force expansion and re-allocation of v. Ensures that when the vector is
- // expanded that new elements are constructed.
- tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end());
- EXPECT_EQ("cat", v[0]);
- EXPECT_EQ("dog", v[1]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) {
- auto vec = tensorflow::gtl::InlinedVector<int, 3>{4, 5, 6};
- EXPECT_EQ(3, vec.size());
- EXPECT_EQ(3, vec.capacity());
- EXPECT_EQ(4, vec[0]);
- EXPECT_EQ(5, vec[1]);
- EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) {
- auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6};
- EXPECT_EQ(3, vec.size());
- EXPECT_LE(3, vec.capacity());
- EXPECT_EQ(4, vec[0]);
- EXPECT_EQ(5, vec[1]);
- EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, DisparateTypesInList) {
- EXPECT_EQ((std::vector<int>{-7, 8}),
- Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL}));
-
- EXPECT_EQ(
- (std::vector<string>{"foo", "bar"}),
- Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")}));
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) {
- tensorflow::gtl::InlinedVector<Instance, 1> empty;
- auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)};
- EXPECT_EQ(1, vec.size());
- EXPECT_EQ(empty.capacity(), vec.capacity());
- EXPECT_EQ(0, vec[0].value_);
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) {
- auto vec =
- tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)};
- EXPECT_EQ(2, vec.size());
- EXPECT_LE(2, vec.capacity());
- EXPECT_EQ(0, vec[0].value_);
- EXPECT_EQ(1, vec[1].value_);
-}
-
-TEST(DynamicVec, DynamicVecCompiles) {
- DynamicVec v;
- (void)v;
-}
-
-static void BM_InlinedVectorFill(int iters, int len) {
- for (int i = 0; i < iters; i++) {
- IntVec v;
- for (int j = 0; j < len; j++) {
- v.push_back(j);
- }
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024);
-
-static void BM_InlinedVectorFillRange(int iters, int len) {
- std::unique_ptr<int[]> ia(new int[len]);
- for (int j = 0; j < len; j++) {
- ia[j] = j;
- }
- for (int i = 0; i < iters; i++) {
- IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len);
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024);
-
-static void BM_StdVectorFill(int iters, int len) {
- for (int i = 0; i < iters; i++) {
- std::vector<int> v;
- v.reserve(len);
- for (int j = 0; j < len; j++) {
- v.push_back(j);
- }
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_StdVectorFill)->Range(0, 1024);
-
-bool StringRepresentedInline(string s) {
- const char* chars = s.data();
- string s1 = std::move(s);
- return s1.data() != chars;
-}
-
-static void BM_InlinedVectorFillString(int iters, int len) {
- string strings[4] = {"a quite long string", "another long string",
- "012345678901234567", "to cause allocation"};
- for (int i = 0; i < iters; i++) {
- gtl::InlinedVector<string, 8> v;
- for (int j = 0; j < len; j++) {
- v.push_back(strings[j & 3]);
- }
- }
- testing::ItemsProcessed(int64{iters} * len);
-}
-BENCHMARK(BM_InlinedVectorFillString)->Range(0, 1024);
-
-static void BM_StdVectorFillString(int iters, int len) {
- string strings[4] = {"a quite long string", "another long string",
- "012345678901234567", "to cause allocation"};
- for (int i = 0; i < iters; i++) {
- std::vector<string> v;
- v.reserve(len);
- for (int j = 0; j < len; j++) {
- v.push_back(strings[j & 3]);
- }
- }
- testing::ItemsProcessed(int64{iters} * len);
- // The purpose of the benchmark is to verify that inlined vector is
- // efficient when moving is more efficient than copying. To do so, we
- // use strings that are larger than the small string optimization.
- CHECK(!StringRepresentedInline(strings[0]));
-}
-BENCHMARK(BM_StdVectorFillString)->Range(0, 1024);
-
-namespace {
-struct Buffer { // some arbitrary structure for benchmarking.
- char* base;
- int length;
- int capacity;
- void* user_data;
-};
-} // anonymous namespace
-
-static void BM_InlinedVectorTenAssignments(int iters, int len) {
- typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec;
-
- BufferVec src;
- src.resize(len);
-
- iters *= 10;
- BufferVec dst;
- for (int i = 0; i < iters; i++) {
- dst = src;
- }
-}
-BENCHMARK(BM_InlinedVectorTenAssignments)
- ->Arg(0)
- ->Arg(1)
- ->Arg(2)
- ->Arg(3)
- ->Arg(4)
- ->Arg(20);
-
-static void BM_CreateFromInitializerList(int iters) {
- for (; iters > 0; iters--) {
- tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3};
- (void)x[0];
- }
-}
-BENCHMARK(BM_CreateFromInitializerList);
-
-namespace {
-
-struct LargeSwappable {
- LargeSwappable() : d_(1024, 17) {}
- ~LargeSwappable() {}
- LargeSwappable(const LargeSwappable& o) : d_(o.d_) {}
-
- friend void swap(LargeSwappable& a, LargeSwappable& b) {
- using std::swap;
- swap(a.d_, b.d_);
- }
-
- LargeSwappable& operator=(LargeSwappable o) {
- using std::swap;
- swap(*this, o);
- return *this;
- }
-
- std::vector<int> d_;
-};
-
-} // namespace
-
-static void BM_LargeSwappableElements(int iters, int len) {
- typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec;
- Vec a(len);
- Vec b;
- while (--iters >= 0) {
- using std::swap;
- swap(a, b);
- }
-}
-BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024);
-
-} // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index cb0cb46752..9836f784ab 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -29381,6 +29381,49 @@ op {
}
}
op {
+ name: "MapDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "MapDefun"
input_arg {
name: "arguments"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index f03639e833..1a5ad8f421 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -198,6 +198,7 @@ REGISTER_OP("MapDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ParallelMapDataset")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 4419f93d0c..28b25fdeae 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -14542,6 +14542,13 @@ op {
has_minimum: true
minimum: 1
}
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "MapDefun"
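MapDataset gains a use_inter_op_parallelism attr in the op registration, with matching entries in ops_history.v1.pbtxt and ops.pbtxt; because it defaults to true, graphs serialized before the attr existed keep their old behavior. A generic stand-alone illustration of why the default matters (hypothetical helper, not the TensorFlow attr machinery):

#include <map>
#include <string>

bool GetBoolAttr(const std::map<std::string, bool>& attrs,
                 const std::string& name, bool default_value) {
  const auto it = attrs.find(name);
  return it == attrs.end() ? default_value : it->second;
}

int main() {
  std::map<std::string, bool> old_graph_attrs;  // written before the attr existed
  return GetBoolAttr(old_graph_attrs, "use_inter_op_parallelism", true) ? 0 : 1;
}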
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index a1be4aacce..5e1eabee5b 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -394,9 +394,9 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size,
.StopCapture()
.OneLiteral(": ")
.GetResult(&value, &name)) {
- string str_value = std::string(value);
+ string str_value(value);
str_util::StripTrailingWhitespace(&str_value);
- that->response_headers_[std::string(name)] = str_value;
+ that->response_headers_[string(name)] = str_value;
}
return size * nmemb;
}
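HeaderCallback keeps the same parse and only constructs the owned strings directly from the captured pieces. A minimal sketch of the "Name: value" split it performs, with std::string_view standing in for StringPiece and a plain map standing in for response_headers_:

#include <map>
#include <string>
#include <string_view>

void AddHeader(std::map<std::string, std::string>* headers, std::string_view line) {
  const size_t sep = line.find(": ");
  if (sep == std::string_view::npos) return;
  const std::string_view name = line.substr(0, sep);
  std::string_view value = line.substr(sep + 2);
  while (!value.empty() &&
         (value.back() == '\r' || value.back() == '\n' || value.back() == ' ')) {
    value.remove_suffix(1);  // strip trailing whitespace from the value
  }
  (*headers)[std::string(name)] = std::string(value);
}

int main() {
  std::map<std::string, std::string> headers;
  AddHeader(&headers, "Content-Type: text/plain\r\n");
  return headers["Content-Type"] == "text/plain" ? 0 : 1;
}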
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 9d33787bd5..8f959c018e 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -179,13 +179,13 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
fname);
}
- *bucket = std::string(bucketp);
+ *bucket = string(bucketp);
if (bucket->empty() || *bucket == ".") {
return errors::InvalidArgument("GCS path doesn't contain a bucket name: ",
fname);
}
str_util::ConsumePrefix(&objectp, "/");
- *object = std::string(objectp);
+ *object = string(objectp);
if (!empty_object_ok && object->empty()) {
return errors::InvalidArgument("GCS path doesn't contain an object name: ",
fname);
@@ -224,7 +224,7 @@ std::set<string> AddAllSubpaths(const std::vector<string>& paths) {
for (const string& path : paths) {
StringPiece subpath = io::Dirname(path);
while (!subpath.empty()) {
- result.emplace(std::string(subpath));
+ result.emplace(string(subpath));
subpath = io::Dirname(subpath);
}
}
@@ -723,7 +723,7 @@ GcsFileSystem::GcsFileSystem() {
if (!header_name.empty() && !header_value.empty()) {
additional_header_.reset(new std::pair<const string, const string>(
- std::string(header_name), std::string(header_value)));
+ string(header_name), string(header_value)));
VLOG(1) << "GCS additional header ENABLED. "
<< "Name: " << additional_header_->first << ", "
@@ -1229,7 +1229,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern,
// Find the fixed prefix by looking for the first wildcard.
const string& fixed_prefix =
pattern.substr(0, pattern.find_first_of("*?[\\"));
- const string& dir = std::string(io::Dirname(fixed_prefix));
+ const string dir(io::Dirname(fixed_prefix));
if (dir.empty()) {
return errors::InvalidArgument(
"A GCS pattern doesn't have a bucket name: ", pattern);
@@ -1326,7 +1326,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
" doesn't match the prefix ", object_prefix));
}
if (!relative_path.empty() || include_self_directory_marker) {
- result->emplace_back(std::string(relative_path));
+ result->emplace_back(relative_path);
}
if (++retrieved_results >= max_results) {
return Status::OK();
@@ -1354,7 +1354,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
"Unexpected response: the returned folder name ", prefix_str,
" doesn't match the prefix ", object_prefix);
}
- result->emplace_back(std::string(relative_path));
+ result->emplace_back(relative_path);
if (++retrieved_results >= max_results) {
return Status::OK();
}
diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc
index ee6ba7b041..9b85cae9b9 100644
--- a/tensorflow/core/platform/cloud/oauth_client.cc
+++ b/tensorflow/core/platform/cloud/oauth_client.cc
@@ -216,7 +216,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson(
// Send the request to the Google OAuth 2.0 server to get the token.
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
std::vector<char> response_buffer;
- request->SetUri(std::string(oauth_server_uri));
+ request->SetUri(string(oauth_server_uri));
request->SetPostFromBuffer(request_body.c_str(), request_body.size());
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
@@ -248,7 +248,7 @@ Status OAuthClient::GetTokenFromRefreshTokenJson(
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
std::vector<char> response_buffer;
- request->SetUri(std::string(oauth_server_uri));
+ request->SetUri(string(oauth_server_uri));
request->SetPostFromBuffer(request_body.c_str(), request_body.size());
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index 4ffa72288b..1cd0641cd3 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -126,9 +126,9 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer",
grant_type);
- int last_dot = std::string(assertion).find_last_of(".");
- string header_dot_claim = std::string(assertion.substr(0, last_dot));
- string signature_encoded = std::string(assertion.substr(last_dot + 1));
+ int last_dot = assertion.rfind('.');
+ string header_dot_claim(assertion.substr(0, last_dot));
+ string signature_encoded(assertion.substr(last_dot + 1));
// Check that 'signature' signs 'header_dot_claim'.
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 07b2e3426b..bb841aeab7 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -625,6 +625,7 @@ def tf_additional_lib_deps():
"""Additional dependencies needed to build TF libraries."""
return [
"@com_google_absl//absl/base:base",
+ "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:optional",
] + if_static(
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index da3a99565e..625d5649e6 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -390,9 +390,12 @@ message ConfigProto {
message Experimental {
// Task name for group resolution.
string collective_group_leader = 1;
- // Whether the client will format templated errors. For example, the string:
- // "The node was defined on ^^node:Foo:${file}:${line}^^".
- bool client_handles_error_formatting = 2;
+
+ // We removed the flag client_handles_error_formatting. Marking the tag
+ // number as reserved.
+ // TODO(shikharagarwal): Should we just remove this tag so that it can be
+ // used in the future for another purpose?
+ reserved 2;
// Which executor to use, the default executor will be used
// if it is an empty string or "DEFAULT"
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 1841dd998b..ae0ad27f15 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1132,7 +1132,7 @@ class BaseSession(SessionInterface):
for details of the allowable fetch types.
feed_list: (Optional.) A list of `feed_dict` keys. See
`tf.Session.run` for details of the allowable feed key types.
- accept_options: (Optional.) Iff `True`, the returned `Callable` will be
+ accept_options: (Optional.) If `True`, the returned `Callable` will be
able to accept `tf.RunOptions` and `tf.RunMetadata` as optional
keyword arguments `options` and `run_metadata`, respectively, with
the same syntax and semantics as `tf.Session.run`, which is useful
@@ -1302,9 +1302,7 @@ class BaseSession(SessionInterface):
node_def = op.node_def
except KeyError:
pass
- if (self._config is not None and
- self._config.experimental.client_handles_error_formatting):
- message = error_interpolation.interpolate(message, self._graph)
+ message = error_interpolation.interpolate(message, self._graph)
raise type(e)(node_def, op, message)
def _extend_graph(self):
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 459f494b48..586f4c6936 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 4)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 5)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index b0414ad655..671e5d4812 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -91,7 +91,7 @@ class IteratorTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(14):
for i in range(7):
result = sess.run(get_next)
@@ -117,7 +117,7 @@ class IteratorTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(14):
for i in range(7):
result = sess.run(get_next)
@@ -208,7 +208,7 @@ class IteratorTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(next_element)
@@ -216,7 +216,7 @@ class IteratorTest(test.TestCase):
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(next_element)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def consumer_thread():
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
@@ -287,7 +287,7 @@ class IteratorTest(test.TestCase):
.make_initializable_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.FailedPreconditionError,
"iterator has not been initialized"):
sess.run(get_next)
@@ -308,7 +308,7 @@ class IteratorTest(test.TestCase):
self.assertEqual(dataset_4.output_types, iterator.output_types)
self.assertEqual([None], iterator.output_shapes.as_list())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The iterator is initially uninitialized.
with self.assertRaises(errors.FailedPreconditionError):
sess.run(get_next)
@@ -380,7 +380,7 @@ class IteratorTest(test.TestCase):
self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
self.assertEqual([], feedable_iterator.output_shapes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
iterator_3_handle = sess.run(iterator_3.string_handle())
iterator_4_handle = sess.run(iterator_4.string_handle())
@@ -436,7 +436,7 @@ class IteratorTest(test.TestCase):
self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
self.assertEqual([], feedable_iterator.output_shapes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
iterator_3_handle = sess.run(iterator_3.string_handle())
iterator_4_handle = sess.run(iterator_4.string_handle())
@@ -524,7 +524,7 @@ class IteratorTest(test.TestCase):
feedable_int_any = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
handle_int_scalar = sess.run(
dataset_int_scalar.make_one_shot_iterator().string_handle())
handle_float_vector = sess.run(
@@ -687,7 +687,7 @@ class IteratorTest(test.TestCase):
f=_remote_fn,
target=target_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
elem = sess.run(
remote_op,
feed_dict={
@@ -803,16 +803,15 @@ class IteratorCheckpointingTest(test.TestCase):
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
- with self.test_session() as sess:
- self.assertAllEqual([1, 4], get_next())
- save_path = checkpoint.save(checkpoint_prefix)
- self.assertAllEqual([9, 16], get_next())
- self.assertAllEqual([25, 36], get_next())
- checkpoint.restore(save_path).run_restore_ops(sess)
- self.assertAllEqual([9, 16], get_next())
- self.assertAllEqual([25, 36], get_next())
- with self.assertRaises(errors.OutOfRangeError):
- get_next()
+ self.assertAllEqual([1, 4], get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ checkpoint.restore(save_path).run_restore_ops()
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
@test_util.run_in_graph_and_eager_modes
def testSaveRestoreMultipleIterator(self):
@@ -833,19 +832,18 @@ class IteratorCheckpointingTest(test.TestCase):
) else functools.partial(self.evaluate, iterator_3.get_next())
checkpoint = checkpointable_utils.Checkpoint(
iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
- with self.test_session() as sess:
- self.assertAllEqual([1, 4], get_next_1())
- self.assertAllEqual(0, get_next_3())
- self.assertAllEqual(1, get_next_3())
- self.assertAllEqual(2, get_next_3())
- save_path = checkpoint.save(checkpoint_prefix)
- self.assertAllEqual([1, 4], get_next_2())
- self.assertAllEqual([9, 16], get_next_2())
- self.assertAllEqual(3, get_next_3())
- checkpoint.restore(save_path).run_restore_ops(sess)
- self.assertAllEqual([9, 16], get_next_1())
- self.assertAllEqual([1, 4], get_next_2())
- self.assertAllEqual(3, get_next_3())
+ self.assertAllEqual([1, 4], get_next_1())
+ self.assertAllEqual(0, get_next_3())
+ self.assertAllEqual(1, get_next_3())
+ self.assertAllEqual(2, get_next_3())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual([9, 16], get_next_2())
+ self.assertAllEqual(3, get_next_3())
+ checkpoint.restore(save_path).run_restore_ops()
+ self.assertAllEqual([9, 16], get_next_1())
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual(3, get_next_3())
@test_util.run_in_graph_and_eager_modes
def testRestoreExhaustedIterator(self):
@@ -856,17 +854,16 @@ class IteratorCheckpointingTest(test.TestCase):
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
- with self.test_session() as sess:
- self.assertAllEqual(0, get_next())
- self.assertAllEqual(1, get_next())
- save_path = checkpoint.save(checkpoint_prefix)
- self.assertAllEqual(2, get_next())
- checkpoint.restore(save_path).run_restore_ops(sess)
- self.assertAllEqual(2, get_next())
- save_path = checkpoint.save(checkpoint_prefix)
- checkpoint.restore(save_path).run_restore_ops(sess)
- with self.assertRaises(errors.OutOfRangeError):
- get_next()
+ self.assertAllEqual(0, get_next())
+ self.assertAllEqual(1, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual(2, get_next())
+ checkpoint.restore(save_path).run_restore_ops()
+ self.assertAllEqual(2, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ checkpoint.restore(save_path).run_restore_ops()
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
def testRestoreInReconstructedIteratorInitializable(self):
checkpoint_directory = self.get_temp_dir()
@@ -876,7 +873,7 @@ class IteratorCheckpointingTest(test.TestCase):
get_next = iterator.get_next()
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
for i in range(5):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
checkpoint.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory)).initialize_or_restore(sess)
for j in range(2):
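The hunks above replace `self.test_session()` with `self.cached_session()` throughout these graph-mode tests; `cached_session` returns a session that is cached and reused within a single test instead of constructing a fresh one for every `with` block. A minimal sketch of the pattern, with the test case name (`SquareTest`) chosen here purely for illustration:

    import tensorflow as tf

    class SquareTest(tf.test.TestCase):

      def testSquare(self):
        # cached_session() hands back the same underlying session for every
        # call within this test, so repeated `with` blocks avoid paying
        # session-construction cost.
        with self.cached_session() as sess:
          x = tf.constant([2, 3])
          y = tf.square(x)
          self.assertAllEqual([4, 9], sess.run(y))

    if __name__ == "__main__":
      tf.test.main()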
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 52b4320bf1..df2c9b170a 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -711,57 +711,74 @@ class MapDatasetBenchmark(test.Benchmark):
def benchmarkChainOfMaps(self):
chain_lengths = [0, 1, 2, 5, 10, 20, 50]
for chain_length in chain_lengths:
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
+ for use_inter_op_parallelism in [False, True]:
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset_ops.MapDataset(
+ dataset,
+ lambda x: x,
+ use_inter_op_parallelism=use_inter_op_parallelism)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- print("Map dataset chain length: %d Median wall time: %f"
- % (chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000, wall_time=median_wall_time,
- name="benchmark_map_dataset_chain_latency_%d" % chain_length)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ print("Map dataset chain length%s: %d Median wall time: %f" %
+ (" (single threaded mode)" if not use_inter_op_parallelism
+ else "", chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_map_dataset_chain_latency_%d%s" %
+ (chain_length, "_single_threaded"
+ if not use_inter_op_parallelism else ""))
def benchmarkMapFanOut(self):
fan_outs = [1, 2, 5, 10, 20, 50, 100]
for fan_out in fan_outs:
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(
- tuple(0 for _ in range(fan_out))).repeat(None).map(lambda *xs: xs)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element[0].op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
+ for use_inter_op_parallelism in [False, True]:
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(
+ tuple(0 for _ in range(fan_out))).repeat(None)
+ dataset = dataset_ops.MapDataset(
+ dataset,
+ lambda *xs: xs,
+ use_inter_op_parallelism=use_inter_op_parallelism)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
sess.run(next_element[0].op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- print("Map dataset fan out: %d Median wall time: %f"
- % (fan_out, median_wall_time))
- self.report_benchmark(
- iters=1000, wall_time=median_wall_time,
- name="benchmark_map_dataset_fan_out_%d" % fan_out)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element[0].op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ print("Map dataset fan out%s: %d Median wall time: %f" %
+ (" (single threaded mode)" if not use_inter_op_parallelism
+ else "", fan_out, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_map_dataset_fan_out_%d%s" %
+ (fan_out, "_single_threaded"
+ if not use_inter_op_parallelism else ""))
if __name__ == "__main__":
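The reworked benchmark warms up with five `sess.run` calls, then times 100 trials of 100 steps each and reports the median trial time divided by 100 as the per-element wall time. The measurement loop, pulled out into a standalone helper (the name `measure_per_element_time` is used here only for illustration):

    import time
    import numpy as np

    def measure_per_element_time(run_step, num_warmup=5, num_trials=100,
                                 steps_per_trial=100):
      """Returns the median per-step wall time of calling `run_step`."""
      for _ in range(num_warmup):
        run_step()
      deltas = []
      for _ in range(num_trials):
        start = time.time()
        for _ in range(steps_per_trial):
          run_step()
        deltas.append(time.time() - start)
      return np.median(deltas) / steps_per_trial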
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 8c37b1871b..6205ee392e 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -2207,10 +2207,11 @@ def _warn_if_collections(transformation_name):
class MapDataset(Dataset):
"""A `Dataset` that maps a function over elements in its input."""
- def __init__(self, input_dataset, map_func):
+ def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
super(MapDataset, self).__init__()
self._input_dataset = input_dataset
+ self._use_inter_op_parallelism = use_inter_op_parallelism
wrapped_func = StructuredFunctionWrapper(
map_func, "Dataset.map()", input_dataset)
@@ -2225,6 +2226,7 @@ class MapDataset(Dataset):
input_t,
self._map_func.captured_inputs,
f=self._map_func,
+ use_inter_op_parallelism=self._use_inter_op_parallelism,
**flat_structure(self))
@property
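With the new constructor argument, the single-threaded map path is exercised by instantiating `MapDataset` directly, which is exactly what the benchmark above does rather than going through `Dataset.map()`. A short sketch mirroring that usage:

    from tensorflow.python.data.ops import dataset_ops

    dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
    # use_inter_op_parallelism=False requests that the map function for this
    # stage run without using the inter-op thread pool.
    dataset = dataset_ops.MapDataset(
        dataset, lambda x: x, use_inter_op_parallelism=False)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()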
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index caf36b6a36..6673178ee7 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -64,7 +64,7 @@ class BackpropTest(test.TestCase):
grad = backprop.gradients_function(fn, [0])(var)[0]
grad = self.evaluate(ops.convert_to_tensor(grad))
- with context.graph_mode(), self.test_session():
+ with context.graph_mode():
tf_var = array_ops.constant(var_np, dtypes.float32)
tf_ind1 = array_ops.constant([0, 1])
tf_ind2 = array_ops.constant([2, 3])
@@ -79,7 +79,7 @@ class BackpropTest(test.TestCase):
tf_dense_grad = math_ops.unsorted_segment_sum(
tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
- self.assertAllClose(grad, tf_dense_grad.eval())
+ self.assertAllClose(grad, self.evaluate(tf_dense_grad))
def testImplicitGradWithResourceVariable(self):
x = resource_variable_ops.ResourceVariable(
@@ -198,7 +198,7 @@ class BackpropTest(test.TestCase):
grad = backprop.implicit_grad(f)()[0][0]
opt = training.GradientDescentOptimizer(lrn_rate)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_x = array_ops.ones((batch_size), dtypes.int64)
# TODO(ashankar,apassos): Change to ResourceVariable.
tf_embedding = variables.Variable(
@@ -941,7 +941,7 @@ class BackpropTest(test.TestCase):
def testZerosCacheDoesntLeakAcrossGraphs(self):
with context.graph_mode():
def get_grad():
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
with backprop.GradientTape() as tape:
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index c08cf61220..1c0c4581c0 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -142,7 +142,7 @@ def _dnn_model_fn(features,
dropout=None,
input_layer_partitioner=None,
config=None,
- tpu_estimator_spec=False,
+ use_tpu=False,
batch_norm=False):
"""Deep Neural Net model_fn.
@@ -164,8 +164,8 @@ def _dnn_model_fn(features,
input_layer_partitioner: Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
- tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or
- or `model_fn.EstimatorSpec` instance.
+ use_tpu: Whether the DNN model should be built to run on TPU. If `True`, the
+ function returns a `_TPUEstimatorSpec` instance and disables variable partitioning.
batch_norm: Whether to use batch normalization after each hidden layer.
Returns:
@@ -182,13 +182,15 @@ def _dnn_model_fn(features,
optimizer, learning_rate=_LEARNING_RATE)
num_ps_replicas = config.num_ps_replicas if config else 0
- partitioner = partitioned_variables.min_max_variable_partitioner(
- max_partitions=num_ps_replicas)
+ partitioner = (None if use_tpu else
+ partitioned_variables.min_max_variable_partitioner(
+ max_partitions=num_ps_replicas))
with variable_scope.variable_scope(
'dnn',
values=tuple(six.itervalues(features)),
partitioner=partitioner):
input_layer_partitioner = input_layer_partitioner or (
+ None if use_tpu else
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
@@ -203,7 +205,7 @@ def _dnn_model_fn(features,
batch_norm=batch_norm)
logits = logit_fn(features=features, mode=mode)
- if tpu_estimator_spec:
+ if use_tpu:
return head._create_tpu_estimator_spec( # pylint: disable=protected-access
features=features,
mode=mode,
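The renamed `use_tpu` flag now also controls partitioning: when it is set, both the variable-scope partitioner and the input-layer partitioner are disabled. The selection logic added above reduces to the following sketch; `_choose_partitioners` is named here only for illustration:

    from tensorflow.python.ops import partitioned_variables

    def _choose_partitioners(use_tpu, num_ps_replicas,
                             input_layer_partitioner=None):
      # On TPU, variable partitioning is disabled entirely.
      partitioner = (None if use_tpu else
                     partitioned_variables.min_max_variable_partitioner(
                         max_partitions=num_ps_replicas))
      input_layer_partitioner = input_layer_partitioner or (
          None if use_tpu else
          partitioned_variables.min_max_variable_partitioner(
              max_partitions=num_ps_replicas, min_slice_size=64 << 20))
      return partitioner, input_layer_partitioner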
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index a69018d00d..46bda2e621 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -15,7 +15,7 @@
"""Function for interpolating formatted errors from the TensorFlow runtime.
Exposes the function `interpolate` to interpolate messages with tags of the form
-^^type:name:format^^.
+{{type name}}.
"""
from __future__ import absolute_import
@@ -32,7 +32,7 @@ import six
from tensorflow.python.util import tf_stack
_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
-_TAG_REGEX = r"\^\^({name}):({name})\^\^".format(name=_NAME_REGEX)
+_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)
_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
@@ -48,8 +48,8 @@ def _parse_message(message):
"""Parses the message.
Splits the message into separators and tags. Tags are named tuples
- representing the string ^^type:name^^ and they are separated by
- separators. For example, in "123^^node:Foo^^456^^node:Bar^^789", there are
+ representing the string {{type name}} and they are separated by
+ separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are
two tags and three separators. The separators are the numeric characters.
Args:
@@ -58,7 +58,7 @@ def _parse_message(message):
Returns:
(list of separator strings, list of _ParseTags).
- For example, if message is "123^^node:Foo^^456" then this function
+ For example, if message is "123{{node Foo}}456" then this function
returns (["123", "456"], [_ParseTag("node", "Foo")])
"""
seps = []
@@ -276,7 +276,7 @@ def interpolate(error_message, graph):
message.
Returns:
- The string with tags of the form ^^type:name^^ interpolated.
+ The string with tags of the form {{type name}} interpolated.
"""
seps, tags = _parse_message(error_message)
subs = []
@@ -288,7 +288,7 @@ def interpolate(error_message, graph):
except KeyError:
op = None
- msg = "^^%s:%s^^" % (t.type, t.name)
+ msg = "{{%s %s}}" % (t.type, t.name)
if op is not None:
field_dict = compute_field_dict(op)
if t.type == "node":
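Error tags now take the `{{type name}}` form instead of `^^type:name^^`, and `_TAG_REGEX` captures the type and the name separated by a space. A quick way to see what the new pattern matches; the regex is copied from the change above and the message string is illustrative:

    import re

    _NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
    _TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)

    message = "123{{node Foo}}456{{node Bar}}789"
    # Each match yields a (type, name) pair.
    print(re.findall(_TAG_REGEX, message))  # [('node', 'Foo'), ('node', 'Bar')]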
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
index a7c7bbf28b..d312b825d2 100644
--- a/tensorflow/python/framework/error_interpolation_test.py
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -167,20 +167,20 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
self.assertEqual(interpolated_string, normal_string)
def testOneTagWithAFakeNameResultsInPlaceholders(self):
- one_tag_string = "^^node:MinusOne^^"
+ one_tag_string = "{{node MinusOne}}"
interpolated_string = error_interpolation.interpolate(
one_tag_string, self.graph)
self.assertEqual(one_tag_string, interpolated_string)
def testTwoTagsNoSeps(self):
- two_tags_no_seps = "^^node:One^^^^node:Three^^"
+ two_tags_no_seps = "{{node One}}{{node Three}}"
interpolated_string = error_interpolation.interpolate(
two_tags_no_seps, self.graph)
self.assertRegexpMatches(interpolated_string,
"constant_op.py:[0-9]+.*constant_op.py:[0-9]+")
def testTwoTagsWithSeps(self):
- two_tags_with_seps = ";;;^^node:Two^^,,,^^node:Three^^;;;"
+ two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, self.graph)
expected_regex = (
@@ -206,23 +206,23 @@ class InterpolateDeviceSummaryTest(test.TestCase):
self.graph = self.three.graph
def testNodeZeroHasNoDeviceSummaryInfo(self):
- message = "^^colocation_node:zero^^"
+ message = "{{colocation_node zero}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No device assignments were active", result)
def testNodeOneHasExactlyOneInterpolatedDevice(self):
- message = "^^colocation_node:one^^"
+ message = "{{colocation_node one}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertEqual(2, result.count("tf.device(/cpu)"))
def testNodeTwoHasTwoInterpolatedDevice(self):
- message = "^^colocation_node:two^^"
+ message = "{{colocation_node two}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertEqual(2, result.count("tf.device(/cpu)"))
self.assertEqual(2, result.count("tf.device(/cpu:0)"))
def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
- message = "^^colocation_node:three^^"
+ message = "{{colocation_node three}}"
result = error_interpolation.interpolate(message, self.graph)
num_devices = result.count("tf.device")
self.assertEqual(2, num_devices)
@@ -256,12 +256,12 @@ class InterpolateColocationSummaryTest(test.TestCase):
self.graph = node_three.graph
def testNodeThreeHasColocationInterpolation(self):
- message = "^^colocation_node:Three_with_one^^"
+ message = "{{colocation_node Three_with_one}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
- message = "^^colocation_node:Four_with_three^^"
+ message = "{{colocation_node Four_with_three}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(Three_with_one)", result)
self.assertNotIn(
@@ -269,13 +269,13 @@ class InterpolateColocationSummaryTest(test.TestCase):
"Node One should not appear in Four_with_three's summary:\n%s" % result)
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
- message = "^^colocation_node:Five_with_one_with_two^^"
+ message = "{{colocation_node Five_with_one_with_two}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
self.assertIn("colocate_with(Two)", result)
def testColocationInterpolationForNodeLackingColocation(self):
- message = "^^colocation_node:One^^"
+ message = "{{colocation_node One}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No node-device colocations", result)
self.assertNotIn("Two", result)
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index b14290c203..26170b000d 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -367,7 +367,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
A `TensorProto`. Depending on the type, it may contain data in the
"tensor_content" attribute, which is not directly useful to Python programs.
To access the values you should convert the proto back to a numpy ndarray
- with `tensor_util.MakeNdarray(proto)`.
+ with `tf.make_ndarray(proto)`.
If `values` is a `TensorProto`, it is immediately returned; `dtype` and
`shape` are ignored.
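The docstring now points at the public `tf.make_ndarray` alias. Assuming the TF 1.x public API, a round trip between a numpy array and a `TensorProto` looks like this:

    import numpy as np
    import tensorflow as tf

    values = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    proto = tf.make_tensor_proto(values)   # numpy ndarray -> TensorProto
    recovered = tf.make_ndarray(proto)     # TensorProto -> numpy ndarray
    assert np.array_equal(values, recovered)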
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b5388ad0b2..3b63e49a84 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -535,15 +535,16 @@ def assert_no_new_tensors(f):
tensors_before = set(
id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
- if context.executing_eagerly():
- f(self, **kwargs)
- ops.reset_default_graph()
- else:
- # Run the test in a new graph so that collections get cleared when it's
- # done, but inherit the graph key so optimizers behave.
- outside_graph_key = ops.get_default_graph()._graph_key
- with ops.Graph().as_default():
- ops.get_default_graph()._graph_key = outside_graph_key
+ outside_executed_eagerly = context.executing_eagerly()
+ # Run the test in a new graph so that collections get cleared when it's
+ # done, but inherit the graph key so optimizers behave.
+ outside_graph_key = ops.get_default_graph()._graph_key
+ with ops.Graph().as_default():
+ ops.get_default_graph()._graph_key = outside_graph_key
+ if outside_executed_eagerly:
+ with context.eager_mode():
+ f(self, **kwargs)
+ else:
f(self, **kwargs)
# Make an effort to clear caches, which would otherwise look like leaked
# Tensors.
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index b52ab7f05c..7768caeaf0 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -443,13 +443,7 @@ def get_session():
session = default_session
else:
if _SESSION is None:
- if not os.environ.get('OMP_NUM_THREADS'):
- config = config_pb2.ConfigProto(allow_soft_placement=True)
- else:
- num_thread = int(os.environ.get('OMP_NUM_THREADS'))
- config = config_pb2.ConfigProto(
- intra_op_parallelism_threads=num_thread, allow_soft_placement=True)
- _SESSION = session_module.Session(config=config)
+ _SESSION = session_module.Session(config=get_default_session_config())
session = _SESSION
if not _MANUAL_VAR_INIT:
with session.graph.as_default():
@@ -468,6 +462,16 @@ def set_session(session):
_SESSION = session
+def get_default_session_config():
+ if not os.environ.get('OMP_NUM_THREADS'):
+ config = config_pb2.ConfigProto(allow_soft_placement=True)
+ else:
+ num_thread = int(os.environ.get('OMP_NUM_THREADS'))
+ config = config_pb2.ConfigProto(
+ intra_op_parallelism_threads=num_thread, allow_soft_placement=True)
+ return config
+
+
# DEVICE MANIPULATION
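The session-config logic (honouring `OMP_NUM_THREADS` for intra-op parallelism and enabling soft placement) is now factored into `get_default_session_config()`, so callers such as the distribution-strategy path can build their own session from it. A sketch of direct use, mirroring the code added in this change:

    from tensorflow.python.client import session as session_module
    from tensorflow.python.keras import backend as K

    config = K.get_default_session_config()
    # A caller may adjust the config further (e.g. let a strategy configure
    # it) before creating the session and registering it with Keras.
    sess = session_module.Session(config=config)
    K.set_session(sess)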
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index fcb073322c..c1c4970025 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -17,8 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.client import session as session_module
from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras import backend
+from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
@@ -46,7 +47,7 @@ def set_weights(distribution_strategy, dist_model, weights):
assign_ops.append(distribution_strategy.unwrap(sw.assign(w)))
weights = weights[num_param:]
- backend.get_session().run(assign_ops)
+ K.get_session().run(assign_ops)
def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
@@ -269,3 +270,20 @@ def validate_all_tensor_shapes(x, x_values):
if x_shape != x_values[i].get_shape().as_list():
raise ValueError('Input tensor shapes do not match for distributed tensor'
' inputs {}'.format(x))
+
+
+def configure_and_create_session(distribution_strategy):
+ """Configure session config and create a session with it."""
+ # TODO(priyag): Throw error if a session already exists.
+ session_config = K.get_default_session_config()
+ distribution_strategy.configure(session_config)
+
+ if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+ # TODO(priyag): Remove this workaround when Distributed Coordinator is
+ # integrated with keras and we can create a session from there.
+ master = distribution_strategy._tpu_cluster_resolver.master() # pylint: disable=protected-access
+ session = session_module.Session(config=session_config, target=master)
+ else:
+ session = session_module.Session(config=session_config)
+
+ K.set_session(session)
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index cd74e36e68..f8c23ed124 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1355,7 +1355,9 @@ class Network(base_layer.Layer):
```
"""
if not self._is_graph_network:
- raise NotImplementedError
+ raise NotImplementedError(
+ 'Currently `save` requires model to be a graph network. Consider '
+ 'using `save_weights` to save the weights of the model.')
from tensorflow.python.keras.models import save_model # pylint: disable=g-import-not-at-top
save_model(self, filepath, overwrite, include_optimizer)
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 85d25411b4..966b446f22 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -405,20 +405,9 @@ class Model(Network):
# Set DistributionStrategy specific parameters.
self._distribution_strategy = distribute
if self._distribution_strategy is not None:
- self._grouped_model = self._compile_distributed_model(
+ self._grouped_model = None
+ distributed_training_utils.configure_and_create_session(
self._distribution_strategy)
- with self._distribution_strategy.scope():
- first_replicated_model = self._distribution_strategy.unwrap(
- self._grouped_model)[0]
- # If the specified metrics in `compile` are stateful, raise an error
- # since we currently don't support stateful metrics.
- if first_replicated_model.stateful_metric_names:
- raise NotImplementedError('Stateful metrics are not supported with '
- 'DistributionStrategy.')
-
- # We initialize the callback model with the first replicated model.
- self._replicated_model = DistributedCallbackModel(first_replicated_model)
- self._replicated_model.set_original_model(self)
if not self.built:
# Model is not compilable because it does not know its number of inputs
# and outputs, nor their shapes and names. We will compile after the first
@@ -636,6 +625,12 @@ class Model(Network):
skip_target_indices=skip_target_indices,
sample_weights=self.sample_weights)
+ # If using distribution strategy and stateful_metrics, raise an error
+ # since we currently don't support stateful metrics.
+ if self._distribution_strategy is not None and self.stateful_metric_names:
+ raise NotImplementedError('Stateful metrics are not supported with '
+ 'DistributionStrategy.')
+
# Prepare gradient updates and state updates.
self.total_loss = total_loss
@@ -652,19 +647,6 @@ class Model(Network):
trainable_weights = self.trainable_weights
self._collected_trainable_weights = trainable_weights
- def _compile_distributed_model(self, distribution_strategy):
- # TODO(anjalisridhar): Can we move the clone_and_build_model to outside the
- # model?
- def _clone_model_per_tower(model):
- new_model = training_distributed.clone_and_build_model(model)
- return new_model
-
- with distribution_strategy.scope():
- # Create a copy of this model on each of the devices.
- grouped_models = distribution_strategy.call_for_each_tower(
- _clone_model_per_tower, self)
- return grouped_models
-
def _check_trainable_weights_consistency(self):
"""Check trainable weights count consistency.
@@ -790,10 +772,7 @@ class Model(Network):
Fraction of the training data to be used as validation data.
Returns:
- A tuple of 3 lists: input arrays, target arrays, sample-weight arrays.
- If the model's input and targets are symbolic, these lists are empty
- (since the model takes no user-provided data, instead the data comes
- from the symbolic inputs/targets).
+ Iterator for reading the dataset `x`.
Raises:
ValueError: In case of invalid user-provided data.
@@ -828,30 +807,7 @@ class Model(Network):
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
- # x an y may be PerDevice objects with an input and output tensor
- # corresponding to each device. For example, x could be
- # PerDevice:{device: get_next tensor,...}.
- next_element = iterator.get_next()
-
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide model inputs as a list or tuple of 2 '
- 'elements: input and target pair. '
- 'Received %s' % next_element)
- x, y = next_element
- # Validate that all the elements in x and y are of the same type and shape.
- # We can then pass the first element of x and y to `_standardize_weights`
- # below and be confident of the output. We need to reopen the scope since
- # we unwrap values when we validate x and y.
- with self._distribution_strategy.scope():
- x_values, y_values = distributed_training_utils.\
- validate_distributed_dataset_inputs(self._distribution_strategy, x, y)
-
- _, _, sample_weights = self._standardize_weights(x_values,
- y_values,
- sample_weight,
- class_weight,
- batch_size)
- return x, y, sample_weights
+ return iterator
def _standardize_user_data(self,
x,
@@ -916,7 +872,7 @@ class Model(Network):
RuntimeError: If the model was never compiled.
"""
if self._distribution_strategy:
- return self._distribution_standardize_user_data(
+ iterator = self._distribution_standardize_user_data(
x,
y,
sample_weight=sample_weight,
@@ -926,6 +882,7 @@ class Model(Network):
steps_name=steps_name,
steps=steps,
validation_split=validation_split)
+ return iterator, None, None
if isinstance(x, dataset_ops.Dataset):
if context.executing_eagerly():
@@ -982,6 +939,7 @@ class Model(Network):
def _standardize_weights(self, x, y, sample_weight=None, class_weight=None,
batch_size=None,):
+ # TODO(sourabhbajaj): Split input validation from weight standardization.
if sample_weight is not None and class_weight is not None:
logging.warning(
'Received both a `sample_weight` and `class_weight` argument. '
@@ -1566,12 +1524,11 @@ class Model(Network):
validation_steps=validation_steps)
elif self._distribution_strategy:
return training_distributed.fit_loop(
- self, x, y,
+ self, x,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
- val_inputs=val_x,
- val_targets=val_y,
+ val_iterator=val_x,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps)
@@ -1677,8 +1634,7 @@ class Model(Network):
elif self._distribution_strategy:
return training_distributed.test_loop(
self,
- inputs=x,
- targets=y,
+ iterator=x,
verbose=verbose,
steps=steps)
else:
@@ -2188,6 +2144,13 @@ class Model(Network):
return self.callback_model
return self
+ def _make_callback_model(self):
+ first_replicated_model = self._distribution_strategy.unwrap(
+ self._grouped_model)[0]
+ # We initialize the callback model with the first replicated model.
+ self._replicated_model = DistributedCallbackModel(first_replicated_model)
+ self._replicated_model.set_original_model(self)
+
class DistributedCallbackModel(Model):
"""Model that is used for callbacks with DistributionStrategy."""
@@ -2225,6 +2188,6 @@ class DistributedCallbackModel(Model):
# Whitelisted atttributes of the model that can be accessed by the user
# during a callback.
if item not in ['_setattr_tracking']:
- logging.warning('You are accessing attribute ' + item + 'of the'
- 'DistributedCallbackModel that may not have been set'
+ logging.warning('You are accessing attribute ' + item + 'of the '
+ 'DistributedCallbackModel that may not have been set '
'correctly.')
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 85f1d6299f..a7bb1f8177 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -30,13 +30,11 @@ from tensorflow.python.platform import tf_logging as logging
def fit_loop(
model,
- inputs,
- targets,
+ iterator,
epochs=100,
verbose=1,
callbacks=None,
- val_inputs=None,
- val_targets=None,
+ val_iterator=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None):
@@ -44,13 +42,11 @@ def fit_loop(
Arguments:
model: Keras Model instance.
- inputs: List of input arrays.
- targets: List of target arrays.
+ iterator: Iterator for input data.
epochs: Number of times to iterate over the data
verbose: Verbosity mode, 0, 1 or 2
callbacks: List of callbacks to be called during training
- val_inputs: List of input arrays.
- val_targets: List of target arrays.
+ val_iterator: Iterator for validation data.
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
steps_per_epoch: Total number of steps (batches of samples)
@@ -67,6 +63,10 @@ def fit_loop(
ValueError: in case of invalid arguments.
"""
current_strategy = model._distribution_strategy
+
+ clone_model_on_towers(
+ model, current_strategy, make_callback_model=True)
+
def _per_device_train_function(model):
model._make_train_function()
return (model.train_function.inputs,
@@ -74,6 +74,7 @@ def fit_loop(
model.train_function.updates_op,
model.train_function.session_kwargs)
+ inputs, targets = _get_input_from_iterator(iterator, model)
with current_strategy.scope():
# Create train ops on each of the devices when we call
# `_per_device_train_function`.
@@ -169,8 +170,7 @@ def fit_loop(
if do_validation:
val_outs = test_loop(
model,
- val_inputs,
- val_targets,
+ val_iterator,
steps=validation_steps,
verbose=0)
if not isinstance(val_outs, list):
@@ -192,13 +192,12 @@ def fit_loop(
return model.history
-def test_loop(model, inputs, targets, verbose=0, steps=None):
+def test_loop(model, iterator, verbose=0, steps=None):
"""evaluate method to validate a model that uses DistributionStrategy.
Arguments:
model: Keras Model instance.
- inputs: List of input arrays.
- targets: List of target arrays.
+ iterator: Iterator for input data.
verbose: verbosity mode.
steps: Total number of steps (batches of samples)
before declaring predictions finished.
@@ -211,6 +210,9 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
the display labels for the scalar outputs.
"""
current_strategy = model._distribution_strategy
+
+ clone_model_on_towers(model, current_strategy)
+
def _per_device_test_function(model):
model._make_test_function()
return (model.test_function.inputs,
@@ -218,6 +220,7 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
model.test_function.updates_op,
model.test_function.session_kwargs)
+ inputs, targets = _get_input_from_iterator(iterator, model)
with current_strategy.scope():
(grouped_inputs, grouped_outputs, grouped_updates,
grouped_session_args) = current_strategy.call_for_each_tower(
@@ -284,12 +287,12 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
return outs
-def predict_loop(model, inputs, verbose=0, steps=None):
+def predict_loop(model, iterator, verbose=0, steps=None):
"""Abstract method to loop over some data in batches.
Arguments:
model: Keras Model instance.
- inputs: list of tensors to be fed to `f`.
+ iterator: Iterator for input data.
verbose: verbosity mode.
steps: Total number of steps (batches of samples)
before declaring `_predict_loop` finished.
@@ -301,6 +304,9 @@ def predict_loop(model, inputs, verbose=0, steps=None):
(if the model has multiple outputs).
"""
current_strategy = model._distribution_strategy
+
+ clone_model_on_towers(model, current_strategy)
+
def _per_device_predict_function(model):
model._make_predict_function()
return (model.predict_function.inputs,
@@ -308,6 +314,7 @@ def predict_loop(model, inputs, verbose=0, steps=None):
model.predict_function.updates_op,
model.predict_function.session_kwargs)
+ inputs, _ = _get_input_from_iterator(iterator, model)
with current_strategy.scope():
(grouped_inputs, grouped_outputs, grouped_updates,
grouped_session_args) = current_strategy.call_for_each_tower(
@@ -366,7 +373,7 @@ def predict_loop(model, inputs, verbose=0, steps=None):
]
-def clone_and_build_model(model):
+def _clone_and_build_model(model):
"""Clone and build the given keras_model."""
# We need to set the import here since we run into a circular dependency
# error.
@@ -390,6 +397,16 @@ def clone_and_build_model(model):
return cloned_model
+def clone_model_on_towers(model, strategy, make_callback_model=False):
+ """Create a cloned model on each tower, unless already created."""
+ if not model._grouped_model:
+ with strategy.scope():
+ model._grouped_model = strategy.call_for_each_tower(
+ _clone_and_build_model, model)
+ if make_callback_model:
+ model._make_callback_model()
+
+
def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
"""Aggregate metrics values across all towers.
@@ -419,3 +436,25 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
merged_output.append(m)
current_index += num_devices
return merged_output
+
+
+def _get_input_from_iterator(iterator, model):
+ """Get elements from the iterator and verify the input shape and type."""
+ next_element = iterator.get_next()
+ # TODO(anjalisridhar): Support predict input correctly as it will not contain
+ # targets, only inputs.
+ if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+ raise ValueError('Please provide model inputs as a list or tuple of 2 '
+ 'elements: input and target pair. '
+ 'Received %s' % next_element)
+
+ x, y = next_element
+ # Validate that all the elements in x and y are of the same type and shape.
+ # We can then pass the first element of x and y to `_standardize_weights`
+ # below and be confident of the output.
+ x_values, y_values = distributed_training_utils.\
+ validate_distributed_dataset_inputs(model._distribution_strategy, x, y)
+ # TODO(sourabhbajaj): Add support for sample weights in distribution
+ # strategy.
+ model._standardize_weights(x_values, y_values)
+ return x, y
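`fit_loop`, `test_loop`, and `predict_loop` now take a dataset iterator and pull `(inputs, targets)` pairs from it via `_get_input_from_iterator`, so a dataset fed to a distributed model must yield exactly a two-element input/target pair per step. A sketch of a conforming input pipeline; the array shapes and batch size below are illustrative:

    import numpy as np
    import tensorflow as tf

    features = np.random.random((64, 10)).astype(np.float32)
    labels = np.random.random((64, 1)).astype(np.float32)

    # Each element is an (inputs, targets) pair, which is what
    # _get_input_from_iterator expects from iterator.get_next().
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.repeat().batch(8)
    iterator = dataset.make_one_shot_iterator()
    x, y = iterator.get_next()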
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index 05f998d0d2..680d0c97cc 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -116,7 +116,7 @@ class AssertEqualTest(test.TestCase):
check_ops.assert_equal(static_big, static_small, message="fail")
def test_raises_when_greater_dynamic(self):
- with self.test_session():
+ with self.cached_session():
small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big")
with ops.control_dependencies(
@@ -194,7 +194,7 @@ First 2 elements of y:
check_ops.assert_equal(static_big, static_small, message="fail")
def test_raises_when_less_dynamic(self):
- with self.test_session():
+ with self.cached_session():
small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big")
with ops.control_dependencies([check_ops.assert_equal(small, big)]):
@@ -271,30 +271,28 @@ class AssertNoneEqualTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
- with self.test_session():
- small = constant_op.constant([1, 1, 1], name="small")
- big = constant_op.constant([10, 10], name="big")
- # The exception in eager and non-eager mode is different because
- # eager mode relies on shape check done as part of the C++ op, while
- # graph mode does shape checks when creating the `Operation` instance.
- with self.assertRaisesRegexp(
- (ValueError, errors.InvalidArgumentError),
- (r"Incompatible shapes: \[3\] vs. \[2\]|"
- r"Dimensions must be equal, but are 3 and 2")):
- with ops.control_dependencies(
- [check_ops.assert_none_equal(small, big)]):
- out = array_ops.identity(small)
- self.evaluate(out)
+ small = constant_op.constant([1, 1, 1], name="small")
+ big = constant_op.constant([10, 10], name="big")
+ # The exception in eager and non-eager mode is different because
+ # eager mode relies on shape check done as part of the C++ op, while
+ # graph mode does shape checks when creating the `Operation` instance.
+ with self.assertRaisesRegexp(
+ (ValueError, errors.InvalidArgumentError),
+ (r"Incompatible shapes: \[3\] vs. \[2\]|"
+ r"Dimensions must be equal, but are 3 and 2")):
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(small, big)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
@test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_both_empty(self):
- with self.test_session():
- larry = constant_op.constant([])
- curly = constant_op.constant([])
- with ops.control_dependencies(
- [check_ops.assert_none_equal(larry, curly)]):
- out = array_ops.identity(larry)
- self.evaluate(out)
+ larry = constant_op.constant([])
+ curly = constant_op.constant([])
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(larry, curly)]):
+ out = array_ops.identity(larry)
+ self.evaluate(out)
def test_returns_none_with_eager(self):
with context.eager_mode():
@@ -905,7 +903,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 1
with ops.control_dependencies(
@@ -923,7 +921,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 0
with ops.control_dependencies(
@@ -940,7 +938,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 0
with ops.control_dependencies(
@@ -957,7 +955,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 1
with ops.control_dependencies(
@@ -974,7 +972,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 2
with ops.control_dependencies(
@@ -989,7 +987,7 @@ class AssertRankTest(test.TestCase):
check_ops.assert_rank(tensor, np.array([], dtype=np.int32))
def test_raises_if_rank_is_not_scalar_dynamic(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant(
[1, 2], dtype=dtypes.float32, name="my_tensor")
rank_tensor = array_ops.placeholder(dtypes.int32, name="rank_tensor")
@@ -1006,7 +1004,7 @@ class AssertRankTest(test.TestCase):
check_ops.assert_rank(tensor, .5)
def test_raises_if_rank_is_not_integer_dynamic(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant(
[1, 2], dtype=dtypes.float32, name="my_tensor")
rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
@@ -1029,7 +1027,7 @@ class AssertRankInTest(test.TestCase):
self.evaluate(array_ops.identity(tensor_rank0))
def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
with ops.control_dependencies([
check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]):
@@ -1045,7 +1043,7 @@ class AssertRankInTest(test.TestCase):
self.evaluate(array_ops.identity(tensor_rank0))
def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
with ops.control_dependencies([
@@ -1061,7 +1059,7 @@ class AssertRankInTest(test.TestCase):
self.evaluate(array_ops.identity(tensor_rank1))
def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
with ops.control_dependencies([
@@ -1079,7 +1077,7 @@ class AssertRankInTest(test.TestCase):
self.evaluate(array_ops.identity(tensor_rank1))
def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
with ops.control_dependencies([
check_ops.assert_rank_in(tensor_rank1, (0, 2))]):
@@ -1098,7 +1096,7 @@ class AssertRankInTest(test.TestCase):
check_ops.assert_rank_in(tensor, desired_ranks)
def test_raises_if_rank_is_not_scalar_dynamic(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant(
(42, 43), dtype=dtypes.float32, name="my_tensor")
desired_ranks = (
@@ -1120,7 +1118,7 @@ class AssertRankInTest(test.TestCase):
check_ops.assert_rank_in(tensor, (1, .5,))
def test_raises_if_rank_is_not_integer_dynamic(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant(
(42, 43), dtype=dtypes.float32, name="my_tensor")
rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
@@ -1143,7 +1141,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 1
with ops.control_dependencies(
@@ -1160,7 +1158,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 0
with ops.control_dependencies(
@@ -1176,7 +1174,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 0
with ops.control_dependencies(
@@ -1192,7 +1190,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 1
with ops.control_dependencies(
@@ -1209,7 +1207,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 2
with ops.control_dependencies(
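
The check_ops_test hunks above all follow one pattern: tests that feed a placeholder keep an explicit session scope but switch from self.test_session() to self.cached_session(), which reuses a single session across the whole test method instead of constructing a fresh one for every `with` block. The snippet below is a minimal sketch of that style, not part of this commit; it assumes a TensorFlow 1.x release whose tf.test.TestCase provides cached_session(), and the test name and fed values are invented for illustration.

# Illustrative sketch only -- not part of this commit. Assumes TensorFlow 1.x,
# where placeholders and graph-mode sessions are available.
import tensorflow as tf


class CachedSessionStyleTest(tf.test.TestCase):

  def test_dynamic_rank_assertion(self):
    # cached_session() hands back one session reused for the whole test
    # method, so repeated `with` blocks do not build and tear down sessions.
    with self.cached_session():
      tensor = tf.placeholder(tf.float32, name="my_tensor")
      with tf.control_dependencies(
          [tf.assert_rank(tensor, 1, message="fail")]):
        out = tf.identity(tensor)
      # The placeholder still has to be fed, which is why these tests keep an
      # explicit session instead of relying on self.evaluate() alone.
      out.eval(feed_dict={tensor: [1.0, 2.0]})


if __name__ == "__main__":
  tf.test.main()
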
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 9ad77a54cb..26d013bccb 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -62,59 +62,50 @@ class BernoulliTest(test.TestCase):
def testP(self):
p = [0.2, 0.4]
dist = bernoulli.Bernoulli(probs=p)
- with self.test_session():
- self.assertAllClose(p, self.evaluate(dist.probs))
+ self.assertAllClose(p, self.evaluate(dist.probs))
@test_util.run_in_graph_and_eager_modes
def testLogits(self):
logits = [-42., 42.]
dist = bernoulli.Bernoulli(logits=logits)
- with self.test_session():
- self.assertAllClose(logits, self.evaluate(dist.logits))
+ self.assertAllClose(logits, self.evaluate(dist.logits))
if not special:
return
- with self.test_session():
- self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))
+ self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))
p = [0.01, 0.99, 0.42]
dist = bernoulli.Bernoulli(probs=p)
- with self.test_session():
- self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
+ self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
@test_util.run_in_graph_and_eager_modes
def testInvalidP(self):
invalid_ps = [1.01, 2.]
for p in invalid_ps:
- with self.test_session():
- with self.assertRaisesOpError("probs has components greater than 1"):
- dist = bernoulli.Bernoulli(probs=p, validate_args=True)
- self.evaluate(dist.probs)
+ with self.assertRaisesOpError("probs has components greater than 1"):
+ dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+ self.evaluate(dist.probs)
invalid_ps = [-0.01, -3.]
for p in invalid_ps:
- with self.test_session():
- with self.assertRaisesOpError("Condition x >= 0"):
- dist = bernoulli.Bernoulli(probs=p, validate_args=True)
- self.evaluate(dist.probs)
+ with self.assertRaisesOpError("Condition x >= 0"):
+ dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+ self.evaluate(dist.probs)
valid_ps = [0.0, 0.5, 1.0]
for p in valid_ps:
- with self.test_session():
- dist = bernoulli.Bernoulli(probs=p)
- self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail
+ dist = bernoulli.Bernoulli(probs=p)
+ self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail
@test_util.run_in_graph_and_eager_modes
def testShapes(self):
- with self.test_session():
- for batch_shape in ([], [1], [2, 3, 4]):
- dist = make_bernoulli(batch_shape)
- self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
- self.assertAllEqual(batch_shape,
- self.evaluate(dist.batch_shape_tensor()))
- self.assertAllEqual([], dist.event_shape.as_list())
- self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+ for batch_shape in ([], [1], [2, 3, 4]):
+ dist = make_bernoulli(batch_shape)
+ self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
+ self.assertAllEqual(batch_shape, self.evaluate(dist.batch_shape_tensor()))
+ self.assertAllEqual([], dist.event_shape.as_list())
+ self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
@test_util.run_in_graph_and_eager_modes
def testDtype(self):
@@ -137,31 +128,29 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def _testPmf(self, **kwargs):
dist = bernoulli.Bernoulli(**kwargs)
- with self.test_session():
- # pylint: disable=bad-continuation
- xs = [
- 0,
- [1],
- [1, 0],
- [[1, 0]],
- [[1, 0], [1, 1]],
- ]
- expected_pmfs = [
- [[0.8, 0.6], [0.7, 0.4]],
- [[0.2, 0.4], [0.3, 0.6]],
- [[0.2, 0.6], [0.3, 0.4]],
- [[0.2, 0.6], [0.3, 0.4]],
- [[0.2, 0.6], [0.3, 0.6]],
- ]
- # pylint: enable=bad-continuation
-
- for x, expected_pmf in zip(xs, expected_pmfs):
- self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
- self.assertAllClose(
- self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
+ # pylint: disable=bad-continuation
+ xs = [
+ 0,
+ [1],
+ [1, 0],
+ [[1, 0]],
+ [[1, 0], [1, 1]],
+ ]
+ expected_pmfs = [
+ [[0.8, 0.6], [0.7, 0.4]],
+ [[0.2, 0.4], [0.3, 0.6]],
+ [[0.2, 0.6], [0.3, 0.4]],
+ [[0.2, 0.6], [0.3, 0.4]],
+ [[0.2, 0.6], [0.3, 0.6]],
+ ]
+ # pylint: enable=bad-continuation
+
+ for x, expected_pmf in zip(xs, expected_pmfs):
+ self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
+ self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
def testPmfCorrectBroadcastDynamicShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtype=dtypes.float32)
dist = bernoulli.Bernoulli(probs=p)
event1 = [1, 0, 1]
@@ -178,12 +167,11 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testPmfInvalid(self):
p = [0.1, 0.2, 0.7]
- with self.test_session():
- dist = bernoulli.Bernoulli(probs=p, validate_args=True)
- with self.assertRaisesOpError("must be non-negative."):
- self.evaluate(dist.prob([1, 1, -1]))
- with self.assertRaisesOpError("Elements cannot exceed 1."):
- self.evaluate(dist.prob([2, 0, 1]))
+ dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+ with self.assertRaisesOpError("must be non-negative."):
+ self.evaluate(dist.prob([1, 1, -1]))
+ with self.assertRaisesOpError("Elements cannot exceed 1."):
+ self.evaluate(dist.prob([2, 0, 1]))
@test_util.run_in_graph_and_eager_modes
def testPmfWithP(self):
@@ -194,7 +182,7 @@ class BernoulliTest(test.TestCase):
self._testPmf(logits=special.logit(p))
def testBroadcasting(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes.float32)
dist = bernoulli.Bernoulli(probs=p)
self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
@@ -208,70 +196,63 @@ class BernoulliTest(test.TestCase):
}))
def testPmfShapes(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
dist = bernoulli.Bernoulli(probs=p)
self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))
- with self.test_session():
dist = bernoulli.Bernoulli(probs=0.5)
self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))
- with self.test_session():
dist = bernoulli.Bernoulli(probs=0.5)
self.assertEqual((), dist.log_prob(1).get_shape())
self.assertEqual((1), dist.log_prob([1]).get_shape())
self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())
- with self.test_session():
dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
self.assertEqual((2, 1), dist.log_prob(1).get_shape())
@test_util.run_in_graph_and_eager_modes
def testBoundaryConditions(self):
- with self.test_session():
- dist = bernoulli.Bernoulli(probs=1.0)
- self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
- self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
+ dist = bernoulli.Bernoulli(probs=1.0)
+ self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
+ self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
@test_util.run_in_graph_and_eager_modes
def testEntropyNoBatch(self):
p = 0.2
dist = bernoulli.Bernoulli(probs=p)
- with self.test_session():
- self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
+ self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
@test_util.run_in_graph_and_eager_modes
def testEntropyWithBatch(self):
p = [[0.1, 0.7], [0.2, 0.6]]
dist = bernoulli.Bernoulli(probs=p, validate_args=False)
- with self.test_session():
- self.assertAllClose(
- self.evaluate(dist.entropy()),
- [[entropy(0.1), entropy(0.7)], [entropy(0.2),
- entropy(0.6)]])
+ self.assertAllClose(
+ self.evaluate(dist.entropy()),
+ [[entropy(0.1), entropy(0.7)], [entropy(0.2),
+ entropy(0.6)]])
@test_util.run_in_graph_and_eager_modes
def testSampleN(self):
- with self.test_session():
- p = [0.2, 0.6]
- dist = bernoulli.Bernoulli(probs=p)
- n = 100000
- samples = dist.sample(n)
- samples.set_shape([n, 2])
- self.assertEqual(samples.dtype, dtypes.int32)
- sample_values = self.evaluate(samples)
- self.assertTrue(np.all(sample_values >= 0))
- self.assertTrue(np.all(sample_values <= 1))
- # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
- # n). This means that the tolerance is very sensitive to the value of p
- # as well as n.
- self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
- self.assertEqual(set([0, 1]), set(sample_values.flatten()))
- # In this test we're just interested in verifying there isn't a crash
- # owing to mismatched types. b/30940152
- dist = bernoulli.Bernoulli(np.log([.2, .4]))
- self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
+ p = [0.2, 0.6]
+ dist = bernoulli.Bernoulli(probs=p)
+ n = 100000
+ samples = dist.sample(n)
+ samples.set_shape([n, 2])
+ self.assertEqual(samples.dtype, dtypes.int32)
+ sample_values = self.evaluate(samples)
+ self.assertTrue(np.all(sample_values >= 0))
+ self.assertTrue(np.all(sample_values <= 1))
+ # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
+ # n). This means that the tolerance is very sensitive to the value of p
+ # as well as n.
+ self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
+ self.assertEqual(set([0, 1]), set(sample_values.flatten()))
+ # In this test we're just interested in verifying there isn't a crash
+ # owing to mismatched types. b/30940152
+ dist = bernoulli.Bernoulli(np.log([.2, .4]))
+ self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
@test_util.run_in_graph_and_eager_modes
def testNotReparameterized(self):
@@ -284,7 +265,7 @@ class BernoulliTest(test.TestCase):
self.assertIsNone(grad_p)
def testSampleActsLikeSampleN(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = [0.2, 0.6]
dist = bernoulli.Bernoulli(probs=p)
n = 1000
@@ -299,27 +280,24 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testMean(self):
- with self.test_session():
- p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
- dist = bernoulli.Bernoulli(probs=p)
- self.assertAllEqual(self.evaluate(dist.mean()), p)
+ p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
+ dist = bernoulli.Bernoulli(probs=p)
+ self.assertAllEqual(self.evaluate(dist.mean()), p)
@test_util.run_in_graph_and_eager_modes
def testVarianceAndStd(self):
var = lambda p: p * (1. - p)
- with self.test_session():
- p = [[0.2, 0.7], [0.5, 0.4]]
- dist = bernoulli.Bernoulli(probs=p)
- self.assertAllClose(
- self.evaluate(dist.variance()),
- np.array(
- [[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32))
- self.assertAllClose(
- self.evaluate(dist.stddev()),
- np.array(
- [[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
- [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
- dtype=np.float32))
+ p = [[0.2, 0.7], [0.5, 0.4]]
+ dist = bernoulli.Bernoulli(probs=p)
+ self.assertAllClose(
+ self.evaluate(dist.variance()),
+ np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
+ dtype=np.float32))
+ self.assertAllClose(
+ self.evaluate(dist.stddev()),
+ np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
+ [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
+ dtype=np.float32))
@test_util.run_in_graph_and_eager_modes
def testBernoulliBernoulliKL(self):
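
Tests decorated with @test_util.run_in_graph_and_eager_modes, by contrast, drop the session block entirely in these hunks, because self.evaluate() already returns concrete values under eager execution and runs the tensor against a cached session in graph mode. The sketch below illustrates that style; it is not part of the commit, and it assumes a TensorFlow 1.x build where tf.distributions.Bernoulli and eager execution are both available (the test name and numbers are invented).

# Illustrative sketch only -- not part of this commit. self.evaluate() returns
# numpy values in eager mode and evaluates the tensor in a cached session in
# graph mode, so no explicit `with self.cached_session():` is needed here.
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import test_util


class EvaluateOnlyStyleTest(tf.test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_mean_matches_probs(self):
    probs = np.array([0.2, 0.7], dtype=np.float32)
    dist = tf.distributions.Bernoulli(probs=probs)
    # The same assertion works whether the test runs eagerly or builds a graph.
    self.assertAllClose(probs, self.evaluate(dist.mean()))


if __name__ == "__main__":
  tf.test.main()
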
diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py
index 36f3ffc333..d580a415dd 100644
--- a/tensorflow/python/kernel_tests/distributions/beta_test.py
+++ b/tensorflow/python/kernel_tests/distributions/beta_test.py
@@ -20,7 +20,6 @@ import importlib
import numpy as np
-from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
@@ -51,237 +50,215 @@ stats = try_import("scipy.stats")
class BetaTest(test.TestCase):
def testSimpleShapes(self):
- with self.test_session():
- a = np.random.rand(3)
- b = np.random.rand(3)
- dist = beta_lib.Beta(a, b)
- self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
- self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
+ a = np.random.rand(3)
+ b = np.random.rand(3)
+ dist = beta_lib.Beta(a, b)
+ self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
def testComplexShapes(self):
- with self.test_session():
- a = np.random.rand(3, 2, 2)
- b = np.random.rand(3, 2, 2)
- dist = beta_lib.Beta(a, b)
- self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
- self.assertEqual(
- tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(3, 2, 2)
+ dist = beta_lib.Beta(a, b)
+ self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
def testComplexShapesBroadcast(self):
- with self.test_session():
- a = np.random.rand(3, 2, 2)
- b = np.random.rand(2, 2)
- dist = beta_lib.Beta(a, b)
- self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
- self.assertEqual(
- tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(2, 2)
+ dist = beta_lib.Beta(a, b)
+ self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
def testAlphaProperty(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
- with self.test_session():
- dist = beta_lib.Beta(a, b)
- self.assertEqual([1, 3], dist.concentration1.get_shape())
- self.assertAllClose(a, self.evaluate(dist.concentration1))
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual([1, 3], dist.concentration1.get_shape())
+ self.assertAllClose(a, self.evaluate(dist.concentration1))
def testBetaProperty(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
- with self.test_session():
- dist = beta_lib.Beta(a, b)
- self.assertEqual([1, 3], dist.concentration0.get_shape())
- self.assertAllClose(b, self.evaluate(dist.concentration0))
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual([1, 3], dist.concentration0.get_shape())
+ self.assertAllClose(b, self.evaluate(dist.concentration0))
def testPdfXProper(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
- with self.test_session():
- dist = beta_lib.Beta(a, b, validate_args=True)
- self.evaluate(dist.prob([.1, .3, .6]))
- self.evaluate(dist.prob([.2, .3, .5]))
- # Either condition can trigger.
- with self.assertRaisesOpError("sample must be positive"):
- self.evaluate(dist.prob([-1., 0.1, 0.5]))
- with self.assertRaisesOpError("sample must be positive"):
- self.evaluate(dist.prob([0., 0.1, 0.5]))
- with self.assertRaisesOpError("sample must be less than `1`"):
- self.evaluate(dist.prob([.1, .2, 1.2]))
- with self.assertRaisesOpError("sample must be less than `1`"):
- self.evaluate(dist.prob([.1, .2, 1.0]))
+ dist = beta_lib.Beta(a, b, validate_args=True)
+ self.evaluate(dist.prob([.1, .3, .6]))
+ self.evaluate(dist.prob([.2, .3, .5]))
+ # Either condition can trigger.
+ with self.assertRaisesOpError("sample must be positive"):
+ self.evaluate(dist.prob([-1., 0.1, 0.5]))
+ with self.assertRaisesOpError("sample must be positive"):
+ self.evaluate(dist.prob([0., 0.1, 0.5]))
+ with self.assertRaisesOpError("sample must be less than `1`"):
+ self.evaluate(dist.prob([.1, .2, 1.2]))
+ with self.assertRaisesOpError("sample must be less than `1`"):
+ self.evaluate(dist.prob([.1, .2, 1.0]))
def testPdfTwoBatches(self):
- with self.test_session():
- a = [1., 2]
- b = [1., 2]
- x = [.5, .5]
- dist = beta_lib.Beta(a, b)
- pdf = dist.prob(x)
- self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
- self.assertEqual((2,), pdf.get_shape())
+ a = [1., 2]
+ b = [1., 2]
+ x = [.5, .5]
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+ self.assertEqual((2,), pdf.get_shape())
def testPdfTwoBatchesNontrivialX(self):
- with self.test_session():
- a = [1., 2]
- b = [1., 2]
- x = [.3, .7]
- dist = beta_lib.Beta(a, b)
- pdf = dist.prob(x)
- self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
- self.assertEqual((2,), pdf.get_shape())
+ a = [1., 2]
+ b = [1., 2]
+ x = [.3, .7]
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
+ self.assertEqual((2,), pdf.get_shape())
def testPdfUniformZeroBatch(self):
- with self.test_session():
- # This is equivalent to a uniform distribution
- a = 1.
- b = 1.
- x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
- dist = beta_lib.Beta(a, b)
- pdf = dist.prob(x)
- self.assertAllClose([1.] * 5, self.evaluate(pdf))
- self.assertEqual((5,), pdf.get_shape())
+ # This is equivalent to a uniform distribution
+ a = 1.
+ b = 1.
+ x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([1.] * 5, self.evaluate(pdf))
+ self.assertEqual((5,), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
- a = [[1., 2]]
- b = [[1., 2]]
- x = [[.5, .5], [.3, .7]]
- dist = beta_lib.Beta(a, b)
- pdf = dist.prob(x)
- self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
- self.assertEqual((2, 2), pdf.get_shape())
+ a = [[1., 2]]
+ b = [[1., 2]]
+ x = [[.5, .5], [.3, .7]]
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
+ self.assertEqual((2, 2), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
- a = [1., 2]
- b = [1., 2]
- x = [[.5, .5], [.2, .8]]
- pdf = beta_lib.Beta(a, b).prob(x)
- self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
- self.assertEqual((2, 2), pdf.get_shape())
+ a = [1., 2]
+ b = [1., 2]
+ x = [[.5, .5], [.2, .8]]
+ pdf = beta_lib.Beta(a, b).prob(x)
+ self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
+ self.assertEqual((2, 2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
- a = [[1., 2], [2., 3]]
- b = [[1., 2], [2., 3]]
- x = [[.5, .5]]
- pdf = beta_lib.Beta(a, b).prob(x)
- self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
- self.assertEqual((2, 2), pdf.get_shape())
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [[.5, .5]]
+ pdf = beta_lib.Beta(a, b).prob(x)
+ self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
+ self.assertEqual((2, 2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
- a = [[1., 2], [2., 3]]
- b = [[1., 2], [2., 3]]
- x = [.5, .5]
- pdf = beta_lib.Beta(a, b).prob(x)
- self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
- self.assertEqual((2, 2), pdf.get_shape())
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [.5, .5]
+ pdf = beta_lib.Beta(a, b).prob(x)
+ self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
+ self.assertEqual((2, 2), pdf.get_shape())
def testBetaMean(self):
- with session.Session():
- a = [1., 2, 3]
- b = [2., 4, 1.2]
- dist = beta_lib.Beta(a, b)
- self.assertEqual(dist.mean().get_shape(), (3,))
- if not stats:
- return
- expected_mean = stats.beta.mean(a, b)
- self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_mean = stats.beta.mean(a, b)
+ self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
def testBetaVariance(self):
- with session.Session():
- a = [1., 2, 3]
- b = [2., 4, 1.2]
- dist = beta_lib.Beta(a, b)
- self.assertEqual(dist.variance().get_shape(), (3,))
- if not stats:
- return
- expected_variance = stats.beta.var(a, b)
- self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variance = stats.beta.var(a, b)
+ self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
def testBetaMode(self):
- with session.Session():
- a = np.array([1.1, 2, 3])
- b = np.array([2., 4, 1.2])
- expected_mode = (a - 1) / (a + b - 2)
- dist = beta_lib.Beta(a, b)
- self.assertEqual(dist.mode().get_shape(), (3,))
- self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+ a = np.array([1.1, 2, 3])
+ b = np.array([2., 4, 1.2])
+ expected_mode = (a - 1) / (a + b - 2)
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.mode().get_shape(), (3,))
+ self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
def testBetaModeInvalid(self):
- with session.Session():
- a = np.array([1., 2, 3])
- b = np.array([2., 4, 1.2])
- dist = beta_lib.Beta(a, b, allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
- self.evaluate(dist.mode())
-
- a = np.array([2., 2, 3])
- b = np.array([1., 4, 1.2])
- dist = beta_lib.Beta(a, b, allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
- self.evaluate(dist.mode())
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ self.evaluate(dist.mode())
+
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ self.evaluate(dist.mode())
def testBetaModeEnableAllowNanStats(self):
- with session.Session():
- a = np.array([1., 2, 3])
- b = np.array([2., 4, 1.2])
- dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=True)
- expected_mode = (a - 1) / (a + b - 2)
- expected_mode[0] = np.nan
- self.assertEqual((3,), dist.mode().get_shape())
- self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+ expected_mode = (a - 1) / (a + b - 2)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
- a = np.array([2., 2, 3])
- b = np.array([1., 4, 1.2])
- dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=True)
- expected_mode = (a - 1) / (a + b - 2)
- expected_mode[0] = np.nan
- self.assertEqual((3,), dist.mode().get_shape())
- self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+ expected_mode = (a - 1) / (a + b - 2)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
def testBetaEntropy(self):
- with session.Session():
- a = [1., 2, 3]
- b = [2., 4, 1.2]
- dist = beta_lib.Beta(a, b)
- self.assertEqual(dist.entropy().get_shape(), (3,))
- if not stats:
- return
- expected_entropy = stats.beta.entropy(a, b)
- self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = stats.beta.entropy(a, b)
+ self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
def testBetaSample(self):
- with self.test_session():
- a = 1.
- b = 2.
- beta = beta_lib.Beta(a, b)
- n = constant_op.constant(100000)
- samples = beta.sample(n)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000,))
- self.assertFalse(np.any(sample_values < 0.0))
- if not stats:
- return
- self.assertLess(
- stats.kstest(
- # Beta is a univariate distribution.
- sample_values,
- stats.beta(a=1., b=2.).cdf)[0],
- 0.01)
- # The standard error of the sample mean is 1 / (sqrt(18 * n))
- self.assertAllClose(
- sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
- self.assertAllClose(
- np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
+ a = 1.
+ b = 2.
+ beta = beta_lib.Beta(a, b)
+ n = constant_op.constant(100000)
+ samples = beta.sample(n)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000,))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ self.assertLess(
+ stats.kstest(
+ # Beta is a univariate distribution.
+ sample_values,
+ stats.beta(a=1., b=2.).cdf)[0],
+ 0.01)
+ # The standard error of the sample mean is 1 / (sqrt(18 * n))
+ self.assertAllClose(
+ sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
+ self.assertAllClose(
+ np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
def testBetaFullyReparameterized(self):
a = constant_op.constant(1.0)
@@ -297,78 +274,71 @@ class BetaTest(test.TestCase):
# Test that sampling with the same seed twice gives the same results.
def testBetaSampleMultipleTimes(self):
- with self.test_session():
- a_val = 1.
- b_val = 2.
- n_val = 100
+ a_val = 1.
+ b_val = 2.
+ n_val = 100
- random_seed.set_random_seed(654321)
- beta1 = beta_lib.Beta(concentration1=a_val,
- concentration0=b_val,
- name="beta1")
- samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
+ random_seed.set_random_seed(654321)
+ beta1 = beta_lib.Beta(
+ concentration1=a_val, concentration0=b_val, name="beta1")
+ samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
- random_seed.set_random_seed(654321)
- beta2 = beta_lib.Beta(concentration1=a_val,
- concentration0=b_val,
- name="beta2")
- samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
+ random_seed.set_random_seed(654321)
+ beta2 = beta_lib.Beta(
+ concentration1=a_val, concentration0=b_val, name="beta2")
+ samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
- self.assertAllClose(samples1, samples2)
+ self.assertAllClose(samples1, samples2)
def testBetaSampleMultidimensional(self):
- with self.test_session():
- a = np.random.rand(3, 2, 2).astype(np.float32)
- b = np.random.rand(3, 2, 2).astype(np.float32)
- beta = beta_lib.Beta(a, b)
- n = constant_op.constant(100000)
- samples = beta.sample(n)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
- self.assertFalse(np.any(sample_values < 0.0))
- if not stats:
- return
- self.assertAllClose(
- sample_values[:, 1, :].mean(axis=0),
- stats.beta.mean(a, b)[1, :],
- atol=1e-1)
+ a = np.random.rand(3, 2, 2).astype(np.float32)
+ b = np.random.rand(3, 2, 2).astype(np.float32)
+ beta = beta_lib.Beta(a, b)
+ n = constant_op.constant(100000)
+ samples = beta.sample(n)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values[:, 1, :].mean(axis=0),
+ stats.beta.mean(a, b)[1, :],
+ atol=1e-1)
def testBetaCdf(self):
- with self.test_session():
- shape = (30, 40, 50)
- for dt in (np.float32, np.float64):
- a = 10. * np.random.random(shape).astype(dt)
- b = 10. * np.random.random(shape).astype(dt)
- x = np.random.random(shape).astype(dt)
- actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
- self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
- self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
- if not stats:
- return
- self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+ shape = (30, 40, 50)
+ for dt in (np.float32, np.float64):
+ a = 10. * np.random.random(shape).astype(dt)
+ b = 10. * np.random.random(shape).astype(dt)
+ x = np.random.random(shape).astype(dt)
+ actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+ if not stats:
+ return
+ self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
def testBetaLogCdf(self):
- with self.test_session():
- shape = (30, 40, 50)
- for dt in (np.float32, np.float64):
- a = 10. * np.random.random(shape).astype(dt)
- b = 10. * np.random.random(shape).astype(dt)
- x = np.random.random(shape).astype(dt)
- actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
- self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
- self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
- if not stats:
- return
- self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+ shape = (30, 40, 50)
+ for dt in (np.float32, np.float64):
+ a = 10. * np.random.random(shape).astype(dt)
+ b = 10. * np.random.random(shape).astype(dt)
+ x = np.random.random(shape).astype(dt)
+ actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+ if not stats:
+ return
+ self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
def testBetaWithSoftplusConcentration(self):
- with self.test_session():
- a, b = -4.2, -9.1
- dist = beta_lib.BetaWithSoftplusConcentration(a, b)
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))
+ a, b = -4.2, -9.1
+ dist = beta_lib.BetaWithSoftplusConcentration(a, b)
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))
def testBetaBetaKL(self):
for shape in [(10,), (4, 5)]:
diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py
index 8b11556330..e20f59f48a 100644
--- a/tensorflow/python/kernel_tests/distributions/bijector_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py
@@ -36,11 +36,10 @@ class BaseBijectorTest(test.TestCase):
"""Tests properties of the Bijector base-class."""
def testIsAbstract(self):
- with self.test_session():
- with self.assertRaisesRegexp(TypeError,
- ("Can't instantiate abstract class Bijector "
- "with abstract methods __init__")):
- bijector.Bijector() # pylint: disable=abstract-class-instantiated
+ with self.assertRaisesRegexp(TypeError,
+ ("Can't instantiate abstract class Bijector "
+ "with abstract methods __init__")):
+ bijector.Bijector() # pylint: disable=abstract-class-instantiated
def testDefaults(self):
class _BareBonesBijector(bijector.Bijector):
@@ -136,7 +135,7 @@ class BijectorTestEventNdims(test.TestCase):
def testBijectorDynamicEventNdims(self):
bij = BrokenBijector(validate_args=True)
event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Expected scalar"):
bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({
event_ndims: (1, 2)})
@@ -308,7 +307,7 @@ class BijectorReduceEventDimsTest(test.TestCase):
event_ndims = array_ops.placeholder(dtype=np.int32, shape=[])
bij = ExpOnlyJacobian(forward_min_event_ndims=1)
bij.inverse_log_det_jacobian(x, event_ndims=event_ndims)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims),
feed_dict={event_ndims: 1})
self.assertAllClose(-np.log(x_), ildj)
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index 67ed0447ed..cace5b3ba2 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -49,115 +49,102 @@ stats = try_import("scipy.stats")
class DirichletTest(test.TestCase):
def testSimpleShapes(self):
- with self.test_session():
- alpha = np.random.rand(3)
- dist = dirichlet_lib.Dirichlet(alpha)
- self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
- self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
+ alpha = np.random.rand(3)
+ dist = dirichlet_lib.Dirichlet(alpha)
+ self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
def testComplexShapes(self):
- with self.test_session():
- alpha = np.random.rand(3, 2, 2)
- dist = dirichlet_lib.Dirichlet(alpha)
- self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
- self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
+ alpha = np.random.rand(3, 2, 2)
+ dist = dirichlet_lib.Dirichlet(alpha)
+ self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
def testConcentrationProperty(self):
alpha = [[1., 2, 3]]
- with self.test_session():
- dist = dirichlet_lib.Dirichlet(alpha)
- self.assertEqual([1, 3], dist.concentration.get_shape())
- self.assertAllClose(alpha, self.evaluate(dist.concentration))
+ dist = dirichlet_lib.Dirichlet(alpha)
+ self.assertEqual([1, 3], dist.concentration.get_shape())
+ self.assertAllClose(alpha, self.evaluate(dist.concentration))
def testPdfXProper(self):
alpha = [[1., 2, 3]]
- with self.test_session():
- dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
- self.evaluate(dist.prob([.1, .3, .6]))
- self.evaluate(dist.prob([.2, .3, .5]))
- # Either condition can trigger.
- with self.assertRaisesOpError("samples must be positive"):
- self.evaluate(dist.prob([-1., 1.5, 0.5]))
- with self.assertRaisesOpError("samples must be positive"):
- self.evaluate(dist.prob([0., .1, .9]))
- with self.assertRaisesOpError(
- "sample last-dimension must sum to `1`"):
- self.evaluate(dist.prob([.1, .2, .8]))
+ dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
+ self.evaluate(dist.prob([.1, .3, .6]))
+ self.evaluate(dist.prob([.2, .3, .5]))
+ # Either condition can trigger.
+ with self.assertRaisesOpError("samples must be positive"):
+ self.evaluate(dist.prob([-1., 1.5, 0.5]))
+ with self.assertRaisesOpError("samples must be positive"):
+ self.evaluate(dist.prob([0., .1, .9]))
+ with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
+ self.evaluate(dist.prob([.1, .2, .8]))
def testPdfZeroBatches(self):
- with self.test_session():
- alpha = [1., 2]
- x = [.5, .5]
- dist = dirichlet_lib.Dirichlet(alpha)
- pdf = dist.prob(x)
- self.assertAllClose(1., self.evaluate(pdf))
- self.assertEqual((), pdf.get_shape())
+ alpha = [1., 2]
+ x = [.5, .5]
+ dist = dirichlet_lib.Dirichlet(alpha)
+ pdf = dist.prob(x)
+ self.assertAllClose(1., self.evaluate(pdf))
+ self.assertEqual((), pdf.get_shape())
def testPdfZeroBatchesNontrivialX(self):
- with self.test_session():
- alpha = [1., 2]
- x = [.3, .7]
- dist = dirichlet_lib.Dirichlet(alpha)
- pdf = dist.prob(x)
- self.assertAllClose(7. / 5, self.evaluate(pdf))
- self.assertEqual((), pdf.get_shape())
+ alpha = [1., 2]
+ x = [.3, .7]
+ dist = dirichlet_lib.Dirichlet(alpha)
+ pdf = dist.prob(x)
+ self.assertAllClose(7. / 5, self.evaluate(pdf))
+ self.assertEqual((), pdf.get_shape())
def testPdfUniformZeroBatches(self):
- with self.test_session():
- # Corresponds to a uniform distribution
- alpha = [1., 1, 1]
- x = [[.2, .5, .3], [.3, .4, .3]]
- dist = dirichlet_lib.Dirichlet(alpha)
- pdf = dist.prob(x)
- self.assertAllClose([2., 2.], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ # Corresponds to a uniform distribution
+ alpha = [1., 1, 1]
+ x = [[.2, .5, .3], [.3, .4, .3]]
+ dist = dirichlet_lib.Dirichlet(alpha)
+ pdf = dist.prob(x)
+ self.assertAllClose([2., 2.], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
- alpha = [[1., 2]]
- x = [[.5, .5], [.3, .7]]
- dist = dirichlet_lib.Dirichlet(alpha)
- pdf = dist.prob(x)
- self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ alpha = [[1., 2]]
+ x = [[.5, .5], [.3, .7]]
+ dist = dirichlet_lib.Dirichlet(alpha)
+ pdf = dist.prob(x)
+ self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
- alpha = [1., 2]
- x = [[.5, .5], [.2, .8]]
- pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
- self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ alpha = [1., 2]
+ x = [[.5, .5], [.2, .8]]
+ pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+ self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
- alpha = [[1., 2], [2., 3]]
- x = [[.5, .5]]
- pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
- self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ alpha = [[1., 2], [2., 3]]
+ x = [[.5, .5]]
+ pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+ self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
- alpha = [[1., 2], [2., 3]]
- x = [.5, .5]
- pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
- self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ alpha = [[1., 2], [2., 3]]
+ x = [.5, .5]
+ pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+ self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testMean(self):
- with self.test_session():
- alpha = [1., 2, 3]
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
- self.assertEqual(dirichlet.mean().get_shape(), [3])
- if not stats:
- return
- expected_mean = stats.dirichlet.mean(alpha)
- self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)
+ alpha = [1., 2, 3]
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.mean().get_shape(), [3])
+ if not stats:
+ return
+ expected_mean = stats.dirichlet.mean(alpha)
+ self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)
def testCovarianceFromSampling(self):
alpha = np.array([[1., 2, 3],
@@ -197,73 +184,66 @@ class DirichletTest(test.TestCase):
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
def testVariance(self):
- with self.test_session():
- alpha = [1., 2, 3]
- denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
- self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
- if not stats:
- return
- expected_covariance = np.diag(stats.dirichlet.var(alpha))
- expected_covariance += [[0., -2, -3], [-2, 0, -6],
- [-3, -6, 0]] / denominator
- self.assertAllClose(
- self.evaluate(dirichlet.covariance()), expected_covariance)
+ alpha = [1., 2, 3]
+ denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
+ if not stats:
+ return
+ expected_covariance = np.diag(stats.dirichlet.var(alpha))
+ expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]
+ ] / denominator
+ self.assertAllClose(
+ self.evaluate(dirichlet.covariance()), expected_covariance)
def testMode(self):
- with self.test_session():
- alpha = np.array([1.1, 2, 3])
- expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
- self.assertEqual(dirichlet.mode().get_shape(), [3])
- self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
+ alpha = np.array([1.1, 2, 3])
+ expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.mode().get_shape(), [3])
+ self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
def testModeInvalid(self):
- with self.test_session():
- alpha = np.array([1., 2, 3])
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
- allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
- self.evaluate(dirichlet.mode())
+ alpha = np.array([1., 2, 3])
+ dirichlet = dirichlet_lib.Dirichlet(
+ concentration=alpha, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ self.evaluate(dirichlet.mode())
def testModeEnableAllowNanStats(self):
- with self.test_session():
- alpha = np.array([1., 2, 3])
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
- allow_nan_stats=True)
- expected_mode = np.zeros_like(alpha) + np.nan
+ alpha = np.array([1., 2, 3])
+ dirichlet = dirichlet_lib.Dirichlet(
+ concentration=alpha, allow_nan_stats=True)
+ expected_mode = np.zeros_like(alpha) + np.nan
- self.assertEqual(dirichlet.mode().get_shape(), [3])
- self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
+ self.assertEqual(dirichlet.mode().get_shape(), [3])
+ self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
def testEntropy(self):
- with self.test_session():
- alpha = [1., 2, 3]
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
- self.assertEqual(dirichlet.entropy().get_shape(), ())
- if not stats:
- return
- expected_entropy = stats.dirichlet.entropy(alpha)
- self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
+ alpha = [1., 2, 3]
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.entropy().get_shape(), ())
+ if not stats:
+ return
+ expected_entropy = stats.dirichlet.entropy(alpha)
+ self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
def testSample(self):
- with self.test_session():
- alpha = [1., 2]
- dirichlet = dirichlet_lib.Dirichlet(alpha)
- n = constant_op.constant(100000)
- samples = dirichlet.sample(n)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000, 2))
- self.assertTrue(np.all(sample_values > 0.0))
- if not stats:
- return
- self.assertLess(
- stats.kstest(
- # Beta is a univariate distribution.
- sample_values[:, 0],
- stats.beta(
- a=1., b=2.).cdf)[0],
- 0.01)
+ alpha = [1., 2]
+ dirichlet = dirichlet_lib.Dirichlet(alpha)
+ n = constant_op.constant(100000)
+ samples = dirichlet.sample(n)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000, 2))
+ self.assertTrue(np.all(sample_values > 0.0))
+ if not stats:
+ return
+ self.assertLess(
+ stats.kstest(
+ # Beta is a univariate distribution.
+ sample_values[:, 0],
+ stats.beta(a=1., b=2.).cdf)[0],
+ 0.01)
def testDirichletFullyReparameterized(self):
alpha = constant_op.constant([1.0, 2.0, 3.0])
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py
index 850da3e969..27d1291912 100644
--- a/tensorflow/python/kernel_tests/distributions/exponential_test.py
+++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py
@@ -22,7 +22,6 @@ import importlib
import numpy as np
-from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
@@ -48,121 +47,108 @@ stats = try_import("scipy.stats")
class ExponentialTest(test.TestCase):
def testExponentialLogPDF(self):
- with session.Session():
- batch_size = 6
- lam = constant_op.constant([2.0] * batch_size)
- lam_v = 2.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- exponential = exponential_lib.Exponential(rate=lam)
+ batch_size = 6
+ lam = constant_op.constant([2.0] * batch_size)
+ lam_v = 2.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ exponential = exponential_lib.Exponential(rate=lam)
- log_pdf = exponential.log_prob(x)
- self.assertEqual(log_pdf.get_shape(), (6,))
+ log_pdf = exponential.log_prob(x)
+ self.assertEqual(log_pdf.get_shape(), (6,))
- pdf = exponential.prob(x)
- self.assertEqual(pdf.get_shape(), (6,))
+ pdf = exponential.prob(x)
+ self.assertEqual(pdf.get_shape(), (6,))
- if not stats:
- return
- expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
- self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
- self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ if not stats:
+ return
+ expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+ self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testExponentialCDF(self):
- with session.Session():
- batch_size = 6
- lam = constant_op.constant([2.0] * batch_size)
- lam_v = 2.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ batch_size = 6
+ lam = constant_op.constant([2.0] * batch_size)
+ lam_v = 2.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- exponential = exponential_lib.Exponential(rate=lam)
+ exponential = exponential_lib.Exponential(rate=lam)
- cdf = exponential.cdf(x)
- self.assertEqual(cdf.get_shape(), (6,))
+ cdf = exponential.cdf(x)
+ self.assertEqual(cdf.get_shape(), (6,))
- if not stats:
- return
- expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
- self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ if not stats:
+ return
+ expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testExponentialMean(self):
- with session.Session():
- lam_v = np.array([1.0, 4.0, 2.5])
- exponential = exponential_lib.Exponential(rate=lam_v)
- self.assertEqual(exponential.mean().get_shape(), (3,))
- if not stats:
- return
- expected_mean = stats.expon.mean(scale=1 / lam_v)
- self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
+ lam_v = np.array([1.0, 4.0, 2.5])
+ exponential = exponential_lib.Exponential(rate=lam_v)
+ self.assertEqual(exponential.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_mean = stats.expon.mean(scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
def testExponentialVariance(self):
- with session.Session():
- lam_v = np.array([1.0, 4.0, 2.5])
- exponential = exponential_lib.Exponential(rate=lam_v)
- self.assertEqual(exponential.variance().get_shape(), (3,))
- if not stats:
- return
- expected_variance = stats.expon.var(scale=1 / lam_v)
- self.assertAllClose(
- self.evaluate(exponential.variance()), expected_variance)
+ lam_v = np.array([1.0, 4.0, 2.5])
+ exponential = exponential_lib.Exponential(rate=lam_v)
+ self.assertEqual(exponential.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variance = stats.expon.var(scale=1 / lam_v)
+ self.assertAllClose(
+ self.evaluate(exponential.variance()), expected_variance)
def testExponentialEntropy(self):
- with session.Session():
- lam_v = np.array([1.0, 4.0, 2.5])
- exponential = exponential_lib.Exponential(rate=lam_v)
- self.assertEqual(exponential.entropy().get_shape(), (3,))
- if not stats:
- return
- expected_entropy = stats.expon.entropy(scale=1 / lam_v)
- self.assertAllClose(
- self.evaluate(exponential.entropy()), expected_entropy)
+ lam_v = np.array([1.0, 4.0, 2.5])
+ exponential = exponential_lib.Exponential(rate=lam_v)
+ self.assertEqual(exponential.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = stats.expon.entropy(scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy)
def testExponentialSample(self):
- with self.test_session():
- lam = constant_op.constant([3.0, 4.0])
- lam_v = [3.0, 4.0]
- n = constant_op.constant(100000)
- exponential = exponential_lib.Exponential(rate=lam)
-
- samples = exponential.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000, 2))
- self.assertFalse(np.any(sample_values < 0.0))
- if not stats:
- return
- for i in range(2):
- self.assertLess(
- stats.kstest(
- sample_values[:, i], stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
- 0.01)
+ lam = constant_op.constant([3.0, 4.0])
+ lam_v = [3.0, 4.0]
+ n = constant_op.constant(100000)
+ exponential = exponential_lib.Exponential(rate=lam)
+
+ samples = exponential.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000, 2))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ for i in range(2):
+ self.assertLess(
+ stats.kstest(sample_values[:, i],
+ stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
def testExponentialSampleMultiDimensional(self):
- with self.test_session():
- batch_size = 2
- lam_v = [3.0, 22.0]
- lam = constant_op.constant([lam_v] * batch_size)
+ batch_size = 2
+ lam_v = [3.0, 22.0]
+ lam = constant_op.constant([lam_v] * batch_size)
- exponential = exponential_lib.Exponential(rate=lam)
+ exponential = exponential_lib.Exponential(rate=lam)
+
+ n = 100000
+ samples = exponential.sample(n, seed=138)
+ self.assertEqual(samples.get_shape(), (n, batch_size, 2))
+
+ sample_values = self.evaluate(samples)
- n = 100000
- samples = exponential.sample(n, seed=138)
- self.assertEqual(samples.get_shape(), (n, batch_size, 2))
-
- sample_values = self.evaluate(samples)
-
- self.assertFalse(np.any(sample_values < 0.0))
- if not stats:
- return
- for i in range(2):
- self.assertLess(
- stats.kstest(
- sample_values[:, 0, i],
- stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
- 0.01)
- self.assertLess(
- stats.kstest(
- sample_values[:, 1, i],
- stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
- 0.01)
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ for i in range(2):
+ self.assertLess(
+ stats.kstest(sample_values[:, 0, i],
+ stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
+ self.assertLess(
+ stats.kstest(sample_values[:, 1, i],
+ stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
def testFullyReparameterized(self):
lam = constant_op.constant([0.1, 1.0])
@@ -174,11 +160,10 @@ class ExponentialTest(test.TestCase):
self.assertIsNotNone(grad_lam)
def testExponentialWithSoftplusRate(self):
- with self.test_session():
- lam = [-2.2, -3.4]
- exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
+ lam = [-2.2, -3.4]
+ exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py
index 297e20264c..4eff40b029 100644
--- a/tensorflow/python/kernel_tests/distributions/gamma_test.py
+++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py
@@ -50,221 +50,203 @@ stats = try_import("scipy.stats")
class GammaTest(test.TestCase):
def testGammaShape(self):
- with self.test_session():
- alpha = constant_op.constant([3.0] * 5)
- beta = constant_op.constant(11.0)
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ alpha = constant_op.constant([3.0] * 5)
+ beta = constant_op.constant(11.0)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
- self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
- self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
+ self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
+ self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
+ self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
+ self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
def testGammaLogPDF(self):
- with self.test_session():
- batch_size = 6
- alpha = constant_op.constant([2.0] * batch_size)
- beta = constant_op.constant([3.0] * batch_size)
- alpha_v = 2.0
- beta_v = 3.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- log_pdf = gamma.log_prob(x)
- self.assertEqual(log_pdf.get_shape(), (6,))
- pdf = gamma.prob(x)
- self.assertEqual(pdf.get_shape(), (6,))
- if not stats:
- return
- expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
- self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ batch_size = 6
+ alpha = constant_op.constant([2.0] * batch_size)
+ beta = constant_op.constant([3.0] * batch_size)
+ alpha_v = 2.0
+ beta_v = 3.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ log_pdf = gamma.log_prob(x)
+ self.assertEqual(log_pdf.get_shape(), (6,))
+ pdf = gamma.prob(x)
+ self.assertEqual(pdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+ self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensional(self):
- with self.test_session():
- batch_size = 6
- alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
- beta = constant_op.constant([[3.0, 4.0]] * batch_size)
- alpha_v = np.array([2.0, 4.0])
- beta_v = np.array([3.0, 4.0])
- x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- log_pdf = gamma.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
- pdf = gamma.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
- if not stats:
- return
- expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(log_pdf_values, expected_log_pdf)
- self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+ batch_size = 6
+ alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
+ beta = constant_op.constant([[3.0, 4.0]] * batch_size)
+ alpha_v = np.array([2.0, 4.0])
+ beta_v = np.array([3.0, 4.0])
+ x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ log_pdf = gamma.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+ pdf = gamma.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+ if not stats:
+ return
+ expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+ self.assertAllClose(log_pdf_values, expected_log_pdf)
+ self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensionalBroadcasting(self):
- with self.test_session():
- batch_size = 6
- alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
- beta = constant_op.constant(3.0)
- alpha_v = np.array([2.0, 4.0])
- beta_v = 3.0
- x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- log_pdf = gamma.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
- pdf = gamma.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
-
- if not stats:
- return
- expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(log_pdf_values, expected_log_pdf)
- self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+ batch_size = 6
+ alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
+ beta = constant_op.constant(3.0)
+ alpha_v = np.array([2.0, 4.0])
+ beta_v = 3.0
+ x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ log_pdf = gamma.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+ pdf = gamma.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
- def testGammaCDF(self):
- with self.test_session():
- batch_size = 6
- alpha = constant_op.constant([2.0] * batch_size)
- beta = constant_op.constant([3.0] * batch_size)
- alpha_v = 2.0
- beta_v = 3.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ if not stats:
+ return
+ expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+ self.assertAllClose(log_pdf_values, expected_log_pdf)
+ self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- cdf = gamma.cdf(x)
- self.assertEqual(cdf.get_shape(), (6,))
- if not stats:
- return
- expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ def testGammaCDF(self):
+ batch_size = 6
+ alpha = constant_op.constant([2.0] * batch_size)
+ beta = constant_op.constant([3.0] * batch_size)
+ alpha_v = 2.0
+ beta_v = 3.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ cdf = gamma.cdf(x)
+ self.assertEqual(cdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testGammaMean(self):
- with self.test_session():
- alpha_v = np.array([1.0, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- self.assertEqual(gamma.mean().get_shape(), (3,))
- if not stats:
- return
- expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
+ alpha_v = np.array([1.0, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ self.assertEqual(gamma.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
- with self.test_session():
- alpha_v = np.array([5.5, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- expected_modes = (alpha_v - 1) / beta_v
- self.assertEqual(gamma.mode().get_shape(), (3,))
- self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
+ alpha_v = np.array([5.5, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ expected_modes = (alpha_v - 1) / beta_v
+ self.assertEqual(gamma.mode().get_shape(), (3,))
+ self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
- with self.test_session():
- # Mode will not be defined for the first entry.
- alpha_v = np.array([0.5, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- allow_nan_stats=False)
- with self.assertRaisesOpError("x < y"):
- self.evaluate(gamma.mode())
+ # Mode will not be defined for the first entry.
+ alpha_v = np.array([0.5, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(
+ concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ self.evaluate(gamma.mode())
def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self):
- with self.test_session():
- # Mode will not be defined for the first entry.
- alpha_v = np.array([0.5, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- allow_nan_stats=True)
- expected_modes = (alpha_v - 1) / beta_v
- expected_modes[0] = np.nan
- self.assertEqual(gamma.mode().get_shape(), (3,))
- self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
+ # Mode will not be defined for the first entry.
+ alpha_v = np.array([0.5, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(
+ concentration=alpha_v, rate=beta_v, allow_nan_stats=True)
+ expected_modes = (alpha_v - 1) / beta_v
+ expected_modes[0] = np.nan
+ self.assertEqual(gamma.mode().get_shape(), (3,))
+ self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
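The two mode tests above rely on the closed-form Gamma mode (concentration - 1) / rate, which only exists when concentration > 1 (hence the error / NaN handling for the 0.5 entry). A minimal NumPy/SciPy check of that identity, assuming scipy.stats is importable as in the try_import near the top of the file; illustrative only, not part of the change:

import numpy as np
from scipy import stats

alpha, beta = 5.5, 4.0                    # concentration > 1, so a mode exists
mode = (alpha - 1.0) / beta               # closed-form mode of Gamma(alpha, rate=beta)
grid = np.linspace(0.01, 5.0, 100000)
pdf = stats.gamma.pdf(grid, alpha, scale=1.0 / beta)
assert abs(grid[np.argmax(pdf)] - mode) < 1e-3   # numeric argmax agrees with the formula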
def testGammaVariance(self):
- with self.test_session():
- alpha_v = np.array([1.0, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- self.assertEqual(gamma.variance().get_shape(), (3,))
- if not stats:
- return
- expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
+ alpha_v = np.array([1.0, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ self.assertEqual(gamma.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
def testGammaStd(self):
- with self.test_session():
- alpha_v = np.array([1.0, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- self.assertEqual(gamma.stddev().get_shape(), (3,))
- if not stats:
- return
- expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
- self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
+ alpha_v = np.array([1.0, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ self.assertEqual(gamma.stddev().get_shape(), (3,))
+ if not stats:
+ return
+ expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
+ self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
def testGammaEntropy(self):
- with self.test_session():
- alpha_v = np.array([1.0, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- self.assertEqual(gamma.entropy().get_shape(), (3,))
- if not stats:
- return
- expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
+ alpha_v = np.array([1.0, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ self.assertEqual(gamma.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
def testGammaSampleSmallAlpha(self):
- with self.test_session():
- alpha_v = 0.05
- beta_v = 1.0
- alpha = constant_op.constant(alpha_v)
- beta = constant_op.constant(beta_v)
- n = 100000
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- samples = gamma.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n,))
- self.assertEqual(sample_values.shape, (n,))
- self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(),
- stats.gamma.mean(
- alpha_v, scale=1 / beta_v),
- atol=.01)
- self.assertAllClose(
- sample_values.var(),
- stats.gamma.var(alpha_v, scale=1 / beta_v),
- atol=.15)
+ alpha_v = 0.05
+ beta_v = 1.0
+ alpha = constant_op.constant(alpha_v)
+ beta = constant_op.constant(beta_v)
+ n = 100000
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ samples = gamma.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n,))
+ self.assertEqual(sample_values.shape, (n,))
+ self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(),
+ stats.gamma.mean(alpha_v, scale=1 / beta_v),
+ atol=.01)
+ self.assertAllClose(
+ sample_values.var(),
+ stats.gamma.var(alpha_v, scale=1 / beta_v),
+ atol=.15)
def testGammaSample(self):
- with self.test_session():
- alpha_v = 4.0
- beta_v = 3.0
- alpha = constant_op.constant(alpha_v)
- beta = constant_op.constant(beta_v)
- n = 100000
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- samples = gamma.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n,))
- self.assertEqual(sample_values.shape, (n,))
- self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(),
- stats.gamma.mean(
- alpha_v, scale=1 / beta_v),
- atol=.01)
- self.assertAllClose(
- sample_values.var(),
- stats.gamma.var(alpha_v, scale=1 / beta_v),
- atol=.15)
+ alpha_v = 4.0
+ beta_v = 3.0
+ alpha = constant_op.constant(alpha_v)
+ beta = constant_op.constant(beta_v)
+ n = 100000
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ samples = gamma.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n,))
+ self.assertEqual(sample_values.shape, (n,))
+ self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(),
+ stats.gamma.mean(alpha_v, scale=1 / beta_v),
+ atol=.01)
+ self.assertAllClose(
+ sample_values.var(),
+ stats.gamma.var(alpha_v, scale=1 / beta_v),
+ atol=.15)
def testGammaFullyReparameterized(self):
alpha = constant_op.constant(4.0)
@@ -279,37 +261,37 @@ class GammaTest(test.TestCase):
self.assertIsNotNone(grad_beta)
def testGammaSampleMultiDimensional(self):
- with self.test_session():
- alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
- beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- n = 10000
- samples = gamma.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n, 10, 100))
- self.assertEqual(sample_values.shape, (n, 10, 100))
- zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100
- alpha_bc = alpha_v + zeros
- beta_bc = beta_v + zeros
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(axis=0),
- stats.gamma.mean(
- alpha_bc, scale=1 / beta_bc),
- atol=0., rtol=.05)
- self.assertAllClose(
- sample_values.var(axis=0),
- stats.gamma.var(alpha_bc, scale=1 / beta_bc),
- atol=10.0, rtol=0.)
- fails = 0
- trials = 0
- for ai, a in enumerate(np.reshape(alpha_v, [-1])):
- for bi, b in enumerate(np.reshape(beta_v, [-1])):
- s = sample_values[:, bi, ai]
- trials += 1
- fails += 0 if self._kstest(a, b, s) else 1
- self.assertLess(fails, trials * 0.03)
+ alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
+ beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ n = 10000
+ samples = gamma.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n, 10, 100))
+ self.assertEqual(sample_values.shape, (n, 10, 100))
+ zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100
+ alpha_bc = alpha_v + zeros
+ beta_bc = beta_v + zeros
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(axis=0),
+ stats.gamma.mean(alpha_bc, scale=1 / beta_bc),
+ atol=0.,
+ rtol=.05)
+ self.assertAllClose(
+ sample_values.var(axis=0),
+ stats.gamma.var(alpha_bc, scale=1 / beta_bc),
+ atol=10.0,
+ rtol=0.)
+ fails = 0
+ trials = 0
+ for ai, a in enumerate(np.reshape(alpha_v, [-1])):
+ for bi, b in enumerate(np.reshape(beta_v, [-1])):
+ s = sample_values[:, bi, ai]
+ trials += 1
+ fails += 0 if self._kstest(a, b, s) else 1
+ self.assertLess(fails, trials * 0.03)
def _kstest(self, alpha, beta, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
@@ -320,30 +302,29 @@ class GammaTest(test.TestCase):
return ks < 0.02
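The _kstest helper used throughout these sample tests compares the empirical distribution of the draws against the analytic Gamma CDF and treats a Kolmogorov-Smirnov statistic above 0.02 as a poor fit. A rough standalone sketch of that idea with scipy.stats.kstest, assuming SciPy is available; the helper's exact body is not reproduced here:

from scipy import stats

def kolmogorov_smirnov_ok(alpha, beta, samples, threshold=0.02):
  # Small KS statistic => the empirical CDF of `samples` tracks the
  # analytic Gamma(alpha, rate=beta) CDF closely.
  ks, _ = stats.kstest(samples, stats.gamma(alpha, scale=1.0 / beta).cdf)
  return ks < threshold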
def testGammaPdfOfSampleMultiDims(self):
- with self.test_session():
- gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
- num = 50000
- samples = gamma.sample(num, seed=137)
- pdfs = gamma.prob(samples)
- sample_vals, pdf_vals = self.evaluate([samples, pdfs])
- self.assertEqual(samples.get_shape(), (num, 2, 2))
- self.assertEqual(pdfs.get_shape(), (num, 2, 2))
- self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
- self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
- self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
- self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
- if not stats:
- return
- self.assertAllClose(
- stats.gamma.mean(
- [[7., 11.], [7., 11.]], scale=1 / np.array([[5., 5.], [6., 6.]])),
- sample_vals.mean(axis=0),
- atol=.1)
- self.assertAllClose(
- stats.gamma.var([[7., 11.], [7., 11.]],
- scale=1 / np.array([[5., 5.], [6., 6.]])),
- sample_vals.var(axis=0),
- atol=.1)
+ gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
+ num = 50000
+ samples = gamma.sample(num, seed=137)
+ pdfs = gamma.prob(samples)
+ sample_vals, pdf_vals = self.evaluate([samples, pdfs])
+ self.assertEqual(samples.get_shape(), (num, 2, 2))
+ self.assertEqual(pdfs.get_shape(), (num, 2, 2))
+ self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
+ self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
+ self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
+ self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
+ if not stats:
+ return
+ self.assertAllClose(
+ stats.gamma.mean([[7., 11.], [7., 11.]],
+ scale=1 / np.array([[5., 5.], [6., 6.]])),
+ sample_vals.mean(axis=0),
+ atol=.1)
+ self.assertAllClose(
+ stats.gamma.var([[7., 11.], [7., 11.]],
+ scale=1 / np.array([[5., 5.], [6., 6.]])),
+ sample_vals.var(axis=0),
+ atol=.1)
def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
s_p = zip(sample_vals, pdf_vals)
@@ -356,32 +337,29 @@ class GammaTest(test.TestCase):
self.assertNear(1., total, err=err)
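The _assertIntegral helper referenced above checks that the density evaluated at the drawn samples integrates to roughly one. A hedged sketch of the underlying idea (sort the samples, then Riemann-sum pdf times gap), not the helper's exact implementation:

import numpy as np

def empirical_pdf_integral(sample_vals, pdf_vals):
  # Sort the samples and accumulate pdf * gap between consecutive samples;
  # with enough draws from a well-behaved density this sum is close to 1.
  order = np.argsort(sample_vals)
  xs, ps = sample_vals[order], pdf_vals[order]
  return float(np.sum(ps[:-1] * np.diff(xs)))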
def testGammaNonPositiveInitializationParamsRaises(self):
- with self.test_session():
- alpha_v = constant_op.constant(0.0, name="alpha")
- beta_v = constant_op.constant(1.0, name="beta")
- with self.assertRaisesOpError("x > 0"):
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- validate_args=True)
- self.evaluate(gamma.mean())
- alpha_v = constant_op.constant(1.0, name="alpha")
- beta_v = constant_op.constant(0.0, name="beta")
- with self.assertRaisesOpError("x > 0"):
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- validate_args=True)
- self.evaluate(gamma.mean())
+ alpha_v = constant_op.constant(0.0, name="alpha")
+ beta_v = constant_op.constant(1.0, name="beta")
+ with self.assertRaisesOpError("x > 0"):
+ gamma = gamma_lib.Gamma(
+ concentration=alpha_v, rate=beta_v, validate_args=True)
+ self.evaluate(gamma.mean())
+ alpha_v = constant_op.constant(1.0, name="alpha")
+ beta_v = constant_op.constant(0.0, name="beta")
+ with self.assertRaisesOpError("x > 0"):
+ gamma = gamma_lib.Gamma(
+ concentration=alpha_v, rate=beta_v, validate_args=True)
+ self.evaluate(gamma.mean())
def testGammaWithSoftplusConcentrationRate(self):
- with self.test_session():
- alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
- beta_v = constant_op.constant([1.0, -3.6], name="beta")
- gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
- concentration=alpha_v, rate=beta_v)
- self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)),
- self.evaluate(gamma.concentration))
- self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)),
- self.evaluate(gamma.rate))
+ alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
+ beta_v = constant_op.constant([1.0, -3.6], name="beta")
+ gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
+ concentration=alpha_v, rate=beta_v)
+ self.assertAllEqual(
+ self.evaluate(nn_ops.softplus(alpha_v)),
+ self.evaluate(gamma.concentration))
+ self.assertAllEqual(
+ self.evaluate(nn_ops.softplus(beta_v)), self.evaluate(gamma.rate))
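As the assertions above show, GammaWithSoftplusConcentrationRate passes both parameters through softplus(x) = log(1 + exp(x)), which maps any real input (including the negative -2.1 and -3.6 used here) to a strictly positive value. A small NumPy illustration of the function itself in the usual numerically stable form; this is a generic sketch, not TensorFlow's implementation:

import numpy as np

def softplus(x):
  # log(1 + exp(x)), rearranged so large |x| does not overflow.
  x = np.asarray(x, dtype=np.float64)
  return np.maximum(x, 0.0) + np.log1p(np.exp(-np.abs(x)))

print(softplus([0.0, -2.1]))   # both outputs are strictly positive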
def testGammaGammaKL(self):
alpha0 = np.array([3.])
@@ -391,15 +369,14 @@ class GammaTest(test.TestCase):
beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])
# Build graph.
- with self.test_session():
- g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
- g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
- x = g0.sample(int(1e4), seed=0)
- kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
- kl_actual = kullback_leibler.kl_divergence(g0, g1)
-
- # Execute graph.
- [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
+ g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
+ g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
+ x = g0.sample(int(1e4), seed=0)
+ kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
+ kl_actual = kullback_leibler.kl_divergence(g0, g1)
+
+ # Execute graph.
+ [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
self.assertEqual(beta0.shape, kl_actual.get_shape())
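testGammaGammaKL compares the analytic kl_divergence against a Monte-Carlo estimate: for x drawn from g0, KL(g0 || g1) is approximately mean(log p0(x) - log p1(x)). A SciPy sketch of the same estimator for a single Gamma pair, together with the standard closed form for rate-parameterised Gammas; the parameter pair and sample size are illustrative, not the test's values:

import numpy as np
from scipy import stats
from scipy.special import digamma, gammaln

a0, b0, a1, b1 = 3.0, 1.0, 4.0, 2.0   # shape (concentration) and rate for g0, g1

# Standard closed-form KL(Gamma(a0, b0) || Gamma(a1, b1)), rate parameterisation.
kl_exact = ((a0 - a1) * digamma(a0) - gammaln(a0) + gammaln(a1)
            + a1 * (np.log(b0) - np.log(b1)) + a0 * (b1 - b0) / b0)

# Monte-Carlo estimate of the same quantity, mirroring the test's construction.
rng = np.random.RandomState(0)
x = stats.gamma(a0, scale=1.0 / b0).rvs(size=100000, random_state=rng)
kl_mc = np.mean(stats.gamma(a0, scale=1.0 / b0).logpdf(x)
                - stats.gamma(a1, scale=1.0 / b1).logpdf(x))
# kl_mc approximates kl_exact up to sampling noise.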
diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py
index 24b243f647..630c2cb424 100644
--- a/tensorflow/python/kernel_tests/distributions/laplace_test.py
+++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py
@@ -21,7 +21,6 @@ import importlib
import numpy as np
-from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
@@ -49,212 +48,198 @@ stats = try_import("scipy.stats")
class LaplaceTest(test.TestCase):
def testLaplaceShape(self):
- with self.test_session():
- loc = constant_op.constant([3.0] * 5)
- scale = constant_op.constant(11.0)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ loc = constant_op.constant([3.0] * 5)
+ scale = constant_op.constant(11.0)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
- self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
- self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
+ self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
+ self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
+ self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
+ self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
def testLaplaceLogPDF(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([2.0] * batch_size)
- scale = constant_op.constant([3.0] * batch_size)
- loc_v = 2.0
- scale_v = 3.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- log_pdf = laplace.log_prob(x)
- self.assertEqual(log_pdf.get_shape(), (6,))
- if not stats:
- return
- expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+ batch_size = 6
+ loc = constant_op.constant([2.0] * batch_size)
+ scale = constant_op.constant([3.0] * batch_size)
+ loc_v = 2.0
+ scale_v = 3.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ log_pdf = laplace.log_prob(x)
+ self.assertEqual(log_pdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
- pdf = laplace.prob(x)
- self.assertEqual(pdf.get_shape(), (6,))
- self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ pdf = laplace.prob(x)
+ self.assertEqual(pdf.get_shape(), (6,))
+ self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testLaplaceLogPDFMultidimensional(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([[2.0, 4.0]] * batch_size)
- scale = constant_op.constant([[3.0, 4.0]] * batch_size)
- loc_v = np.array([2.0, 4.0])
- scale_v = np.array([3.0, 4.0])
- x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- log_pdf = laplace.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
-
- pdf = laplace.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
- if not stats:
- return
- expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
- self.assertAllClose(log_pdf_values, expected_log_pdf)
- self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+ batch_size = 6
+ loc = constant_op.constant([[2.0, 4.0]] * batch_size)
+ scale = constant_op.constant([[3.0, 4.0]] * batch_size)
+ loc_v = np.array([2.0, 4.0])
+ scale_v = np.array([3.0, 4.0])
+ x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ log_pdf = laplace.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+
+ pdf = laplace.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+ if not stats:
+ return
+ expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(log_pdf_values, expected_log_pdf)
+ self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testLaplaceLogPDFMultidimensionalBroadcasting(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([[2.0, 4.0]] * batch_size)
- scale = constant_op.constant(3.0)
- loc_v = np.array([2.0, 4.0])
- scale_v = 3.0
- x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- log_pdf = laplace.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
-
- pdf = laplace.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
- if not stats:
- return
- expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
- self.assertAllClose(log_pdf_values, expected_log_pdf)
- self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+ batch_size = 6
+ loc = constant_op.constant([[2.0, 4.0]] * batch_size)
+ scale = constant_op.constant(3.0)
+ loc_v = np.array([2.0, 4.0])
+ scale_v = 3.0
+ x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ log_pdf = laplace.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+
+ pdf = laplace.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+ if not stats:
+ return
+ expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(log_pdf_values, expected_log_pdf)
+ self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testLaplaceCDF(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([2.0] * batch_size)
- scale = constant_op.constant([3.0] * batch_size)
- loc_v = 2.0
- scale_v = 3.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ batch_size = 6
+ loc = constant_op.constant([2.0] * batch_size)
+ scale = constant_op.constant([3.0] * batch_size)
+ loc_v = 2.0
+ scale_v = 3.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- cdf = laplace.cdf(x)
- self.assertEqual(cdf.get_shape(), (6,))
- if not stats:
- return
- expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ cdf = laplace.cdf(x)
+ self.assertEqual(cdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testLaplaceLogCDF(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([2.0] * batch_size)
- scale = constant_op.constant([3.0] * batch_size)
- loc_v = 2.0
- scale_v = 3.0
- x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ batch_size = 6
+ loc = constant_op.constant([2.0] * batch_size)
+ scale = constant_op.constant([3.0] * batch_size)
+ loc_v = 2.0
+ scale_v = 3.0
+ x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- cdf = laplace.log_cdf(x)
- self.assertEqual(cdf.get_shape(), (6,))
- if not stats:
- return
- expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ cdf = laplace.log_cdf(x)
+ self.assertEqual(cdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testLaplaceLogSurvivalFunction(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([2.0] * batch_size)
- scale = constant_op.constant([3.0] * batch_size)
- loc_v = 2.0
- scale_v = 3.0
- x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ batch_size = 6
+ loc = constant_op.constant([2.0] * batch_size)
+ scale = constant_op.constant([3.0] * batch_size)
+ loc_v = 2.0
+ scale_v = 3.0
+ x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- sf = laplace.log_survival_function(x)
- self.assertEqual(sf.get_shape(), (6,))
- if not stats:
- return
- expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(sf), expected_sf)
+ sf = laplace.log_survival_function(x)
+ self.assertEqual(sf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(sf), expected_sf)
def testLaplaceMean(self):
- with self.test_session():
- loc_v = np.array([1.0, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.mean().get_shape(), (3,))
- if not stats:
- return
- expected_means = stats.laplace.mean(loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
+ loc_v = np.array([1.0, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_means = stats.laplace.mean(loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
def testLaplaceMode(self):
- with self.test_session():
- loc_v = np.array([0.5, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.mode().get_shape(), (3,))
- self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
+ loc_v = np.array([0.5, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.mode().get_shape(), (3,))
+ self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
def testLaplaceVariance(self):
- with self.test_session():
- loc_v = np.array([1.0, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.variance().get_shape(), (3,))
- if not stats:
- return
- expected_variances = stats.laplace.var(loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
+ loc_v = np.array([1.0, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variances = stats.laplace.var(loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
def testLaplaceStd(self):
- with self.test_session():
- loc_v = np.array([1.0, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.stddev().get_shape(), (3,))
- if not stats:
- return
- expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
+ loc_v = np.array([1.0, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.stddev().get_shape(), (3,))
+ if not stats:
+ return
+ expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
def testLaplaceEntropy(self):
- with self.test_session():
- loc_v = np.array([1.0, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.entropy().get_shape(), (3,))
- if not stats:
- return
- expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
+ loc_v = np.array([1.0, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
def testLaplaceSample(self):
- with session.Session():
- loc_v = 4.0
- scale_v = 3.0
- loc = constant_op.constant(loc_v)
- scale = constant_op.constant(scale_v)
- n = 100000
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- samples = laplace.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n,))
- self.assertEqual(sample_values.shape, (n,))
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(),
- stats.laplace.mean(
- loc_v, scale=scale_v),
- rtol=0.05,
- atol=0.)
- self.assertAllClose(
- sample_values.var(),
- stats.laplace.var(loc_v, scale=scale_v),
- rtol=0.05,
- atol=0.)
- self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
+ loc_v = 4.0
+ scale_v = 3.0
+ loc = constant_op.constant(loc_v)
+ scale = constant_op.constant(scale_v)
+ n = 100000
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ samples = laplace.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n,))
+ self.assertEqual(sample_values.shape, (n,))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(),
+ stats.laplace.mean(loc_v, scale=scale_v),
+ rtol=0.05,
+ atol=0.)
+ self.assertAllClose(
+ sample_values.var(),
+ stats.laplace.var(loc_v, scale=scale_v),
+ rtol=0.05,
+ atol=0.)
+ self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
def testLaplaceFullyReparameterized(self):
loc = constant_op.constant(4.0)
@@ -269,39 +254,37 @@ class LaplaceTest(test.TestCase):
self.assertIsNotNone(grad_scale)
def testLaplaceSampleMultiDimensional(self):
- with session.Session():
- loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
- scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- n = 10000
- samples = laplace.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n, 10, 100))
- self.assertEqual(sample_values.shape, (n, 10, 100))
- zeros = np.zeros_like(loc_v + scale_v) # 10 x 100
- loc_bc = loc_v + zeros
- scale_bc = scale_v + zeros
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(axis=0),
- stats.laplace.mean(
- loc_bc, scale=scale_bc),
- rtol=0.35,
- atol=0.)
- self.assertAllClose(
- sample_values.var(axis=0),
- stats.laplace.var(loc_bc, scale=scale_bc),
- rtol=0.10,
- atol=0.)
- fails = 0
- trials = 0
- for ai, a in enumerate(np.reshape(loc_v, [-1])):
- for bi, b in enumerate(np.reshape(scale_v, [-1])):
- s = sample_values[:, bi, ai]
- trials += 1
- fails += 0 if self._kstest(a, b, s) else 1
- self.assertLess(fails, trials * 0.03)
+ loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
+ scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ n = 10000
+ samples = laplace.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n, 10, 100))
+ self.assertEqual(sample_values.shape, (n, 10, 100))
+ zeros = np.zeros_like(loc_v + scale_v) # 10 x 100
+ loc_bc = loc_v + zeros
+ scale_bc = scale_v + zeros
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(axis=0),
+ stats.laplace.mean(loc_bc, scale=scale_bc),
+ rtol=0.35,
+ atol=0.)
+ self.assertAllClose(
+ sample_values.var(axis=0),
+ stats.laplace.var(loc_bc, scale=scale_bc),
+ rtol=0.10,
+ atol=0.)
+ fails = 0
+ trials = 0
+ for ai, a in enumerate(np.reshape(loc_v, [-1])):
+ for bi, b in enumerate(np.reshape(scale_v, [-1])):
+ s = sample_values[:, bi, ai]
+ trials += 1
+ fails += 0 if self._kstest(a, b, s) else 1
+ self.assertLess(fails, trials * 0.03)
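Both multi-dimensional sample tests (Gamma earlier and Laplace here) build their parameter grids by broadcasting a (1, 100) array against a (10, 1) array, so every loc or concentration pairs with every scale or rate in a (10, 100) batch. A plain NumPy illustration of that broadcasting, mirroring the shapes in the test:

import numpy as np

loc_v = np.array([np.arange(1, 101, dtype=np.float32)])      # shape (1, 100)
scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T   # shape (10, 1)
batch = np.zeros_like(loc_v + scale_v)                       # shape (10, 100)
print(batch.shape)   # (10, 100): each loc pairs with each scale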
def _kstest(self, loc, scale, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
@@ -349,30 +332,26 @@ class LaplaceTest(test.TestCase):
self.assertNear(1., total, err=err)
def testLaplaceNonPositiveInitializationParamsRaises(self):
- with self.test_session():
- loc_v = constant_op.constant(0.0, name="loc")
- scale_v = constant_op.constant(-1.0, name="scale")
- with self.assertRaisesOpError(
- "Condition x > 0 did not hold element-wise"):
- laplace = laplace_lib.Laplace(
- loc=loc_v, scale=scale_v, validate_args=True)
- self.evaluate(laplace.mean())
- loc_v = constant_op.constant(1.0, name="loc")
- scale_v = constant_op.constant(0.0, name="scale")
- with self.assertRaisesOpError(
- "Condition x > 0 did not hold element-wise"):
- laplace = laplace_lib.Laplace(
- loc=loc_v, scale=scale_v, validate_args=True)
- self.evaluate(laplace.mean())
+ loc_v = constant_op.constant(0.0, name="loc")
+ scale_v = constant_op.constant(-1.0, name="scale")
+ with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
+ laplace = laplace_lib.Laplace(
+ loc=loc_v, scale=scale_v, validate_args=True)
+ self.evaluate(laplace.mean())
+ loc_v = constant_op.constant(1.0, name="loc")
+ scale_v = constant_op.constant(0.0, name="scale")
+ with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
+ laplace = laplace_lib.Laplace(
+ loc=loc_v, scale=scale_v, validate_args=True)
+ self.evaluate(laplace.mean())
def testLaplaceWithSoftplusScale(self):
- with self.test_session():
- loc_v = constant_op.constant([0.0, 1.0], name="loc")
- scale_v = constant_op.constant([-1.0, 2.0], name="scale")
- laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
- self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
+ loc_v = constant_op.constant([0.0, 1.0], name="loc")
+ scale_v = constant_op.constant([-1.0, 2.0], name="scale")
+ laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
+ self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index 7ff48c0c10..de73a40b23 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -61,16 +61,15 @@ class NormalTest(test.TestCase):
self.assertAllEqual(all_true, is_finite)
def _testParamShapes(self, sample_shape, expected):
- with self.test_session():
- param_shapes = normal_lib.Normal.param_shapes(sample_shape)
- mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
- self.assertAllEqual(expected, self.evaluate(mu_shape))
- self.assertAllEqual(expected, self.evaluate(sigma_shape))
- mu = array_ops.zeros(mu_shape)
- sigma = array_ops.ones(sigma_shape)
- self.assertAllEqual(
- expected,
- self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))
+ param_shapes = normal_lib.Normal.param_shapes(sample_shape)
+ mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
+ self.assertAllEqual(expected, self.evaluate(mu_shape))
+ self.assertAllEqual(expected, self.evaluate(sigma_shape))
+ mu = array_ops.zeros(mu_shape)
+ sigma = array_ops.ones(sigma_shape)
+ self.assertAllEqual(
+ expected,
+ self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))
def _testParamStaticShapes(self, sample_shape, expected):
param_shapes = normal_lib.Normal.param_static_shapes(sample_shape)
@@ -91,156 +90,150 @@ class NormalTest(test.TestCase):
self._testParamStaticShapes(
tensor_shape.TensorShape(sample_shape), sample_shape)
- @test_util.run_in_graph_and_eager_modes
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNormalWithSoftplusScale(self):
- with self.test_session():
- mu = array_ops.zeros((10, 3))
- rho = array_ops.ones((10, 3)) * -2.
- normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
- self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
- self.assertAllEqual(
- self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
+ mu = array_ops.zeros((10, 3))
+ rho = array_ops.ones((10, 3)) * -2.
+ normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
+ self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
+ self.assertAllEqual(
+ self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
@test_util.run_in_graph_and_eager_modes
def testNormalLogPDF(self):
- with self.test_session():
- batch_size = 6
- mu = constant_op.constant([3.0] * batch_size)
- sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
- x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
-
- log_pdf = normal.log_prob(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(log_pdf).shape)
- self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+ batch_size = 6
+ mu = constant_op.constant([3.0] * batch_size)
+ sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
+ x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- pdf = normal.prob(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(pdf).shape)
- self.assertAllEqual(normal.batch_shape, pdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)
-
- if not stats:
- return
- expected_log_pdf = stats.norm(self.evaluate(mu),
- self.evaluate(sigma)).logpdf(x)
- self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
- self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
+ log_pdf = normal.log_prob(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(log_pdf).shape)
+ self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+
+ pdf = normal.prob(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(pdf).shape)
+ self.assertAllEqual(normal.batch_shape, pdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)
+
+ if not stats:
+ return
+ expected_log_pdf = stats.norm(self.evaluate(mu),
+ self.evaluate(sigma)).logpdf(x)
+ self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
+ self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
@test_util.run_in_graph_and_eager_modes
def testNormalLogPDFMultidimensional(self):
- with self.test_session():
- batch_size = 6
- mu = constant_op.constant([[3.0, -3.0]] * batch_size)
- sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] *
- batch_size)
- x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
- normal = normal_lib.Normal(loc=mu, scale=sigma)
-
- log_pdf = normal.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(log_pdf).shape)
- self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
-
- pdf = normal.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
- self.assertAllEqual(normal.batch_shape, pdf.get_shape())
- self.assertAllEqual(normal.batch_shape, pdf_values.shape)
+ batch_size = 6
+ mu = constant_op.constant([[3.0, -3.0]] * batch_size)
+ sigma = constant_op.constant(
+ [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
+ x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- if not stats:
- return
- expected_log_pdf = stats.norm(self.evaluate(mu),
- self.evaluate(sigma)).logpdf(x)
- self.assertAllClose(expected_log_pdf, log_pdf_values)
- self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+ log_pdf = normal.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(log_pdf).shape)
+ self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+
+ pdf = normal.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
+ self.assertAllEqual(normal.batch_shape, pdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, pdf_values.shape)
+
+ if not stats:
+ return
+ expected_log_pdf = stats.norm(self.evaluate(mu),
+ self.evaluate(sigma)).logpdf(x)
+ self.assertAllClose(expected_log_pdf, log_pdf_values)
+ self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
@test_util.run_in_graph_and_eager_modes
def testNormalCDF(self):
- with self.test_session():
- batch_size = 50
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+ batch_size = 50
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
- cdf = normal.cdf(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(cdf).shape)
- self.assertAllEqual(normal.batch_shape, cdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
- if not stats:
- return
- expected_cdf = stats.norm(mu, sigma).cdf(x)
- self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ cdf = normal.cdf(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(cdf).shape)
+ self.assertAllEqual(normal.batch_shape, cdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
+ if not stats:
+ return
+ expected_cdf = stats.norm(mu, sigma).cdf(x)
+ self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
@test_util.run_in_graph_and_eager_modes
def testNormalSurvivalFunction(self):
- with self.test_session():
- batch_size = 50
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+ batch_size = 50
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- sf = normal.survival_function(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(sf).shape)
- self.assertAllEqual(normal.batch_shape, sf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
- if not stats:
- return
- expected_sf = stats.norm(mu, sigma).sf(x)
- self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
+ sf = normal.survival_function(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(sf).shape)
+ self.assertAllEqual(normal.batch_shape, sf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
+ if not stats:
+ return
+ expected_sf = stats.norm(mu, sigma).sf(x)
+ self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
@test_util.run_in_graph_and_eager_modes
def testNormalLogCDF(self):
- with self.test_session():
- batch_size = 50
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
+ batch_size = 50
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- cdf = normal.log_cdf(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(cdf).shape)
- self.assertAllEqual(normal.batch_shape, cdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
+ cdf = normal.log_cdf(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(cdf).shape)
+ self.assertAllEqual(normal.batch_shape, cdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
- if not stats:
- return
- expected_cdf = stats.norm(mu, sigma).logcdf(x)
- self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
+ if not stats:
+ return
+ expected_cdf = stats.norm(mu, sigma).logcdf(x)
+ self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
def testFiniteGradientAtDifficultPoints(self):
for dtype in [np.float32, np.float64]:
@@ -256,7 +249,7 @@ class NormalTest(test.TestCase):
]:
value = func(x)
grads = gradients_impl.gradients(value, [mu, sigma])
- with self.test_session(graph=g):
+ with self.session(graph=g):
variables.global_variables_initializer().run()
self.assertAllFinite(value)
self.assertAllFinite(grads[0])
@@ -264,112 +257,106 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNormalLogSurvivalFunction(self):
- with self.test_session():
- batch_size = 50
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
+ batch_size = 50
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- sf = normal.log_survival_function(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(sf).shape)
- self.assertAllEqual(normal.batch_shape, sf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
+ sf = normal.log_survival_function(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(sf).shape)
+ self.assertAllEqual(normal.batch_shape, sf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
- if not stats:
- return
- expected_sf = stats.norm(mu, sigma).logsf(x)
- self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)
+ if not stats:
+ return
+ expected_sf = stats.norm(mu, sigma).logsf(x)
+ self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)
@test_util.run_in_graph_and_eager_modes
def testNormalEntropyWithScalarInputs(self):
# Scipy.stats.norm cannot deal with the shapes in the other test.
- with self.test_session():
- mu_v = 2.34
- sigma_v = 4.56
- normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
-
- entropy = normal.entropy()
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(entropy).shape)
- self.assertAllEqual(normal.batch_shape, entropy.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
- # scipy.stats.norm cannot deal with these shapes.
- if not stats:
- return
- expected_entropy = stats.norm(mu_v, sigma_v).entropy()
- self.assertAllClose(expected_entropy, self.evaluate(entropy))
+ mu_v = 2.34
+ sigma_v = 4.56
+ normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
+
+ entropy = normal.entropy()
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(entropy).shape)
+ self.assertAllEqual(normal.batch_shape, entropy.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
+ # scipy.stats.norm cannot deal with these shapes.
+ if not stats:
+ return
+ expected_entropy = stats.norm(mu_v, sigma_v).entropy()
+ self.assertAllClose(expected_entropy, self.evaluate(entropy))
@test_util.run_in_graph_and_eager_modes
def testNormalEntropy(self):
- with self.test_session():
- mu_v = np.array([1.0, 1.0, 1.0])
- sigma_v = np.array([[1.0, 2.0, 3.0]]).T
- normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
-
- # scipy.stats.norm cannot deal with these shapes.
- sigma_broadcast = mu_v * sigma_v
- expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**
- 2)
- entropy = normal.entropy()
- np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(entropy).shape)
- self.assertAllEqual(normal.batch_shape, entropy.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
-
- @test_util.run_in_graph_and_eager_modes
+ mu_v = np.array([1.0, 1.0, 1.0])
+ sigma_v = np.array([[1.0, 2.0, 3.0]]).T
+ normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
+
+ # scipy.stats.norm cannot deal with these shapes.
+ sigma_broadcast = mu_v * sigma_v
+ expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2)
+ entropy = normal.entropy()
+ np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(entropy).shape)
+ self.assertAllEqual(normal.batch_shape, entropy.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
+
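The expected_entropy line above is the closed-form differential entropy of a Normal, H = 0.5 * log(2 * pi * e * sigma^2). A quick SciPy cross-check of that formula; illustrative only, since the test itself uses just the NumPy expression shown:

import numpy as np
from scipy import stats

sigma = np.array([1.0, 2.0, 3.0])
analytic = 0.5 * np.log(2.0 * np.pi * np.e * sigma**2)
assert np.allclose(analytic, stats.norm(loc=0.0, scale=sigma).entropy())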
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNormalMeanAndMode(self):
- with self.test_session():
- # Mu will be broadcast to [7, 7, 7].
- mu = [7.]
- sigma = [11., 12., 13.]
+ # Mu will be broadcast to [7, 7, 7].
+ mu = [7.]
+ sigma = [11., 12., 13.]
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- self.assertAllEqual((3,), normal.mean().get_shape())
- self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))
+ self.assertAllEqual((3,), normal.mean().get_shape())
+ self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))
- self.assertAllEqual((3,), normal.mode().get_shape())
- self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))
+ self.assertAllEqual((3,), normal.mode().get_shape())
+ self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))
@test_util.run_in_graph_and_eager_modes
def testNormalQuantile(self):
- with self.test_session():
- batch_size = 52
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
- # Quantile performs piecewise rational approximation so adding some
- # special input values to make sure we hit all the pieces.
- p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
+ batch_size = 52
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
+ # Quantile performs piecewise rational approximation so adding some
+ # special input values to make sure we hit all the pieces.
+ p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
- normal = normal_lib.Normal(loc=mu, scale=sigma)
- x = normal.quantile(p)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ x = normal.quantile(p)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), x.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(x).shape)
- self.assertAllEqual(normal.batch_shape, x.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), x.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(x).shape)
+ self.assertAllEqual(normal.batch_shape, x.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)
- if not stats:
- return
- expected_x = stats.norm(mu, sigma).ppf(p)
- self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
+ if not stats:
+ return
+ expected_x = stats.norm(mu, sigma).ppf(p)
+ self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
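testNormalQuantile appends p values of exp(-33) and 1 - exp(-33) so that every branch of the piecewise rational approximation behind Normal.quantile is exercised. The property the test ultimately checks is that quantile is the inverse CDF; a small SciPy round-trip sketch of that relationship over the central range (the extreme tail values are left to the real test):

import numpy as np
from scipy import stats

mu, sigma = 0.3, 1.7          # arbitrary illustrative parameters
dist = stats.norm(mu, sigma)
p = np.linspace(0.01, 0.99, 50)
x = dist.ppf(p)               # quantile, i.e. inverse CDF
assert np.allclose(dist.cdf(x), p)   # cdf(quantile(p)) ~= p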
def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
g = ops.Graph()
@@ -385,7 +372,7 @@ class NormalTest(test.TestCase):
value = dist.quantile(p)
grads = gradients_impl.gradients(value, [mu, p])
- with self.test_session(graph=g):
+ with self.cached_session(graph=g):
variables.global_variables_initializer().run()
self.assertAllFinite(grads[0])
self.assertAllFinite(grads[1])
@@ -398,61 +385,58 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNormalVariance(self):
- with self.test_session():
- # sigma will be broadcast to [7, 7, 7]
- mu = [1., 2., 3.]
- sigma = [7.]
+ # sigma will be broadcast to [7, 7, 7]
+ mu = [1., 2., 3.]
+ sigma = [7.]
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- self.assertAllEqual((3,), normal.variance().get_shape())
- self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))
+ self.assertAllEqual((3,), normal.variance().get_shape())
+ self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))
@test_util.run_in_graph_and_eager_modes
def testNormalStandardDeviation(self):
- with self.test_session():
- # sigma will be broadcast to [7, 7, 7]
- mu = [1., 2., 3.]
- sigma = [7.]
+ # sigma will be broadcast to [7, 7, 7]
+ mu = [1., 2., 3.]
+ sigma = [7.]
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- self.assertAllEqual((3,), normal.stddev().get_shape())
- self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))
+ self.assertAllEqual((3,), normal.stddev().get_shape())
+ self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))
@test_util.run_in_graph_and_eager_modes
def testNormalSample(self):
- with self.test_session():
- mu = constant_op.constant(3.0)
- sigma = constant_op.constant(math.sqrt(3.0))
- mu_v = 3.0
- sigma_v = np.sqrt(3.0)
- n = constant_op.constant(100000)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
- samples = normal.sample(n)
- sample_values = self.evaluate(samples)
- # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
- # The sample variance similarly is dependent on sigma and n.
- # Thus, the tolerances below are very sensitive to number of samples
- # as well as the variances chosen.
- self.assertEqual(sample_values.shape, (100000,))
- self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
- self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
-
- expected_samples_shape = tensor_shape.TensorShape(
- [self.evaluate(n)]).concatenate(
- tensor_shape.TensorShape(
- self.evaluate(normal.batch_shape_tensor())))
-
- self.assertAllEqual(expected_samples_shape, samples.get_shape())
- self.assertAllEqual(expected_samples_shape, sample_values.shape)
-
- expected_samples_shape = (
- tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
- normal.batch_shape))
-
- self.assertAllEqual(expected_samples_shape, samples.get_shape())
- self.assertAllEqual(expected_samples_shape, sample_values.shape)
+ mu = constant_op.constant(3.0)
+ sigma = constant_op.constant(math.sqrt(3.0))
+ mu_v = 3.0
+ sigma_v = np.sqrt(3.0)
+ n = constant_op.constant(100000)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ samples = normal.sample(n)
+ sample_values = self.evaluate(samples)
+ # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
+ # The sample variance similarly is dependent on sigma and n.
+ # Thus, the tolerances below are very sensitive to number of samples
+ # as well as the variances chosen.
+ self.assertEqual(sample_values.shape, (100000,))
+ self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
+ self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
+
+ expected_samples_shape = tensor_shape.TensorShape(
+ [self.evaluate(n)]).concatenate(
+ tensor_shape.TensorShape(
+ self.evaluate(normal.batch_shape_tensor())))
+
+ self.assertAllEqual(expected_samples_shape, samples.get_shape())
+ self.assertAllEqual(expected_samples_shape, sample_values.shape)
+
+ expected_samples_shape = (
+ tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
+ normal.batch_shape))
+
+ self.assertAllEqual(expected_samples_shape, samples.get_shape())
+ self.assertAllEqual(expected_samples_shape, sample_values.shape)
def testNormalFullyReparameterized(self):
mu = constant_op.constant(4.0)
@@ -468,66 +452,63 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNormalSampleMultiDimensional(self):
- with self.test_session():
- batch_size = 2
- mu = constant_op.constant([[3.0, -3.0]] * batch_size)
- sigma = constant_op.constant([[math.sqrt(2.0), math.sqrt(3.0)]] *
- batch_size)
- mu_v = [3.0, -3.0]
- sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
- n = constant_op.constant(100000)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
- samples = normal.sample(n)
- sample_values = self.evaluate(samples)
- # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
- # The sample variance similarly is dependent on sigma and n.
- # Thus, the tolerances below are very sensitive to number of samples
- # as well as the variances chosen.
- self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
- self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
- self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
- self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
- self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)
-
- expected_samples_shape = tensor_shape.TensorShape(
- [self.evaluate(n)]).concatenate(
- tensor_shape.TensorShape(
- self.evaluate(normal.batch_shape_tensor())))
- self.assertAllEqual(expected_samples_shape, samples.get_shape())
- self.assertAllEqual(expected_samples_shape, sample_values.shape)
-
- expected_samples_shape = (
- tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
- normal.batch_shape))
- self.assertAllEqual(expected_samples_shape, samples.get_shape())
- self.assertAllEqual(expected_samples_shape, sample_values.shape)
+ batch_size = 2
+ mu = constant_op.constant([[3.0, -3.0]] * batch_size)
+ sigma = constant_op.constant(
+ [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size)
+ mu_v = [3.0, -3.0]
+ sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
+ n = constant_op.constant(100000)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ samples = normal.sample(n)
+ sample_values = self.evaluate(samples)
+ # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
+ # The sample variance similarly is dependent on sigma and n.
+ # Thus, the tolerances below are very sensitive to number of samples
+ # as well as the variances chosen.
+ self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
+ self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
+ self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
+ self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
+ self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)
+
+ expected_samples_shape = tensor_shape.TensorShape(
+ [self.evaluate(n)]).concatenate(
+ tensor_shape.TensorShape(
+ self.evaluate(normal.batch_shape_tensor())))
+ self.assertAllEqual(expected_samples_shape, samples.get_shape())
+ self.assertAllEqual(expected_samples_shape, sample_values.shape)
+
+ expected_samples_shape = (
+ tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
+ normal.batch_shape))
+ self.assertAllEqual(expected_samples_shape, samples.get_shape())
+ self.assertAllEqual(expected_samples_shape, sample_values.shape)
@test_util.run_in_graph_and_eager_modes
def testNegativeSigmaFails(self):
- with self.test_session():
- with self.assertRaisesOpError("Condition x > 0 did not hold"):
- normal = normal_lib.Normal(
- loc=[1.], scale=[-5.], validate_args=True, name="G")
- self.evaluate(normal.mean())
+ with self.assertRaisesOpError("Condition x > 0 did not hold"):
+ normal = normal_lib.Normal(
+ loc=[1.], scale=[-5.], validate_args=True, name="G")
+ self.evaluate(normal.mean())
@test_util.run_in_graph_and_eager_modes
def testNormalShape(self):
- with self.test_session():
- mu = constant_op.constant([-3.0] * 5)
- sigma = constant_op.constant(11.0)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ mu = constant_op.constant([-3.0] * 5)
+ sigma = constant_op.constant(11.0)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
- self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
- self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))
+ self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
+ self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
+ self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
+ self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))
def testNormalShapeWithPlaceholders(self):
mu = array_ops.placeholder(dtype=dtypes.float32)
sigma = array_ops.placeholder(dtype=dtypes.float32)
normal = normal_lib.Normal(loc=mu, scale=sigma)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# get_batch_shape should return an "<unknown>" tensor.
self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None))
self.assertEqual(normal.event_shape, ())
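The normal_test.py hunks above all apply one mechanical migration: tests decorated with @test_util.run_in_graph_and_eager_modes drop the "with self.test_session():" wrapper and read tensors through self.evaluate, while graph-only tests (for example, ones feeding placeholders) switch to the reusable self.cached_session(). A minimal sketch of that before/after shape follows; the test class, the values, and the import paths are illustrative assumptions, not lines taken from the patch.

from tensorflow.python.framework import test_util
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test


class SessionMigrationSketchTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testMean(self):
    # After the migration: no session wrapper; self.evaluate works in both
    # graph and eager mode, so the body moves flush under the method header.
    normal = normal_lib.Normal(loc=[0.], scale=[1.])
    self.assertAllClose([0.], self.evaluate(normal.mean()))

  def testMeanGraphOnly(self):
    # Graph-only tests keep an explicit session, but reuse the cached one
    # instead of opening a fresh test_session per test.
    with self.cached_session():
      normal = normal_lib.Normal(loc=[0.], scale=[1.])
      self.assertAllClose([0.], normal.mean().eval())


if __name__ == "__main__":
  test.main()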
diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py
index a634194ce5..cc43e12168 100644
--- a/tensorflow/python/kernel_tests/distributions/special_math_test.py
+++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py
@@ -92,22 +92,21 @@ class NdtriTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNdtri(self):
"""Verifies that ndtri computation is correct."""
- with self.test_session():
- if not special:
- return
+ if not special:
+ return
- p = np.linspace(0., 1.0, 50).astype(np.float64)
- # Quantile performs piecewise rational approximation so adding some
- # special input values to make sure we hit all the pieces.
- p = np.hstack((p, np.exp(-32), 1. - np.exp(-32),
- np.exp(-2), 1. - np.exp(-2)))
- expected_x = special.ndtri(p)
- x = special_math.ndtri(p)
- self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
+ p = np.linspace(0., 1.0, 50).astype(np.float64)
+ # Quantile performs piecewise rational approximation so adding some
+ # special input values to make sure we hit all the pieces.
+ p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), np.exp(-2),
+ 1. - np.exp(-2)))
+ expected_x = special.ndtri(p)
+ x = special_math.ndtri(p)
+ self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
def testNdtriDynamicShape(self):
"""Verifies that ndtri computation is correct."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if not special:
return
@@ -286,7 +285,7 @@ class NdtrGradientTest(test.TestCase):
def _test_grad_accuracy(self, dtype, grid_spec, error_spec):
raw_grid = _make_grid(dtype, grid_spec)
grid = ops.convert_to_tensor(raw_grid)
- with self.test_session():
+ with self.cached_session():
fn = sm.log_ndtr if self._use_log else sm.ndtr
# If there are N points in the grid,
@@ -355,7 +354,7 @@ class LogNdtrGradientTest(NdtrGradientTest):
class ErfInvTest(test.TestCase):
def testErfInvValues(self):
- with self.test_session():
+ with self.cached_session():
if not special:
return
@@ -366,7 +365,7 @@ class ErfInvTest(test.TestCase):
self.assertAllClose(expected_x, x.eval(), atol=0.)
def testErfInvIntegerInput(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
x = np.array([1, 2, 3]).astype(np.int32)
@@ -397,7 +396,7 @@ class LogCDFLaplaceTest(test.TestCase):
self.assertAllEqual(np.ones_like(x, dtype=np.bool), x)
def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec):
- with self.test_session():
+ with self.cached_session():
grid = _make_grid(dtype, grid_spec)
actual = sm.log_cdf_laplace(grid).eval()
@@ -439,7 +438,7 @@ class LogCDFLaplaceTest(test.TestCase):
ErrorSpec(rtol=0.05, atol=0))
def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# On the lower branch, log_cdf_laplace(x) = x, so we know this will be
# fine, but test to -200 anyways.
grid = _make_grid(
@@ -458,7 +457,7 @@ class LogCDFLaplaceTest(test.TestCase):
self.assertFalse(np.any(grad_ == 0))
def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# On the lower branch, log_cdf_laplace(x) = x, so we know this will be
# fine, but test to -200 anyways.
grid = _make_grid(
diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py
index 05590542ef..b34b538160 100644
--- a/tensorflow/python/kernel_tests/distributions/student_t_test.py
+++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py
@@ -50,100 +50,96 @@ stats = try_import("scipy.stats")
class StudentTTest(test.TestCase):
def testStudentPDFAndLogPDF(self):
- with self.test_session():
- batch_size = 6
- df = constant_op.constant([3.] * batch_size)
- mu = constant_op.constant([7.] * batch_size)
- sigma = constant_op.constant([8.] * batch_size)
- df_v = 3.
- mu_v = 7.
- sigma_v = 8.
- t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
- student = student_t.StudentT(df, loc=mu, scale=-sigma)
-
- log_pdf = student.log_prob(t)
- self.assertEquals(log_pdf.get_shape(), (6,))
- log_pdf_values = self.evaluate(log_pdf)
- pdf = student.prob(t)
- self.assertEquals(pdf.get_shape(), (6,))
- pdf_values = self.evaluate(pdf)
-
- if not stats:
- return
-
- expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
- expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
- self.assertAllClose(expected_log_pdf, log_pdf_values)
- self.assertAllClose(np.log(expected_pdf), log_pdf_values)
- self.assertAllClose(expected_pdf, pdf_values)
- self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+ batch_size = 6
+ df = constant_op.constant([3.] * batch_size)
+ mu = constant_op.constant([7.] * batch_size)
+ sigma = constant_op.constant([8.] * batch_size)
+ df_v = 3.
+ mu_v = 7.
+ sigma_v = 8.
+ t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+ student = student_t.StudentT(df, loc=mu, scale=-sigma)
+
+ log_pdf = student.log_prob(t)
+ self.assertEquals(log_pdf.get_shape(), (6,))
+ log_pdf_values = self.evaluate(log_pdf)
+ pdf = student.prob(t)
+ self.assertEquals(pdf.get_shape(), (6,))
+ pdf_values = self.evaluate(pdf)
+
+ if not stats:
+ return
+
+ expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+ expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+ self.assertAllClose(expected_log_pdf, log_pdf_values)
+ self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+ self.assertAllClose(expected_pdf, pdf_values)
+ self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testStudentLogPDFMultidimensional(self):
- with self.test_session():
- batch_size = 6
- df = constant_op.constant([[1.5, 7.2]] * batch_size)
- mu = constant_op.constant([[3., -3.]] * batch_size)
- sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] *
- batch_size)
- df_v = np.array([1.5, 7.2])
- mu_v = np.array([3., -3.])
- sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
- t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
- student = student_t.StudentT(df, loc=mu, scale=sigma)
- log_pdf = student.log_prob(t)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
- pdf = student.prob(t)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
-
- if not stats:
- return
- expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
- expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
- self.assertAllClose(expected_log_pdf, log_pdf_values)
- self.assertAllClose(np.log(expected_pdf), log_pdf_values)
- self.assertAllClose(expected_pdf, pdf_values)
- self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+ batch_size = 6
+ df = constant_op.constant([[1.5, 7.2]] * batch_size)
+ mu = constant_op.constant([[3., -3.]] * batch_size)
+ sigma = constant_op.constant(
+ [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size)
+ df_v = np.array([1.5, 7.2])
+ mu_v = np.array([3., -3.])
+ sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
+ t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
+ student = student_t.StudentT(df, loc=mu, scale=sigma)
+ log_pdf = student.log_prob(t)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+ pdf = student.prob(t)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+
+ if not stats:
+ return
+ expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+ expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+ self.assertAllClose(expected_log_pdf, log_pdf_values)
+ self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+ self.assertAllClose(expected_pdf, pdf_values)
+ self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testStudentCDFAndLogCDF(self):
- with self.test_session():
- batch_size = 6
- df = constant_op.constant([3.] * batch_size)
- mu = constant_op.constant([7.] * batch_size)
- sigma = constant_op.constant([-8.] * batch_size)
- df_v = 3.
- mu_v = 7.
- sigma_v = 8.
- t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
- student = student_t.StudentT(df, loc=mu, scale=sigma)
-
- log_cdf = student.log_cdf(t)
- self.assertEquals(log_cdf.get_shape(), (6,))
- log_cdf_values = self.evaluate(log_cdf)
- cdf = student.cdf(t)
- self.assertEquals(cdf.get_shape(), (6,))
- cdf_values = self.evaluate(cdf)
-
- if not stats:
- return
- expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
- expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
- self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
- self.assertAllClose(
- np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
- self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
- self.assertAllClose(
- np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
+ batch_size = 6
+ df = constant_op.constant([3.] * batch_size)
+ mu = constant_op.constant([7.] * batch_size)
+ sigma = constant_op.constant([-8.] * batch_size)
+ df_v = 3.
+ mu_v = 7.
+ sigma_v = 8.
+ t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+ student = student_t.StudentT(df, loc=mu, scale=sigma)
+
+ log_cdf = student.log_cdf(t)
+ self.assertEquals(log_cdf.get_shape(), (6,))
+ log_cdf_values = self.evaluate(log_cdf)
+ cdf = student.cdf(t)
+ self.assertEquals(cdf.get_shape(), (6,))
+ cdf_values = self.evaluate(cdf)
+
+ if not stats:
+ return
+ expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
+ expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
+ self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
+ self.assertAllClose(
+ np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
+ self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
+ self.assertAllClose(
+ np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
def testStudentEntropy(self):
df_v = np.array([[2., 3., 7.]]) # 1x3
mu_v = np.array([[1., -1, 0]]) # 1x3
sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1
- with self.test_session():
- student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
- ent = student.entropy()
- ent_values = self.evaluate(ent)
+ student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
+ ent = student.entropy()
+ ent_values = self.evaluate(ent)
# Help scipy broadcast to 3x3
ones = np.array([[1, 1, 1]])
@@ -160,90 +156,81 @@ class StudentTTest(test.TestCase):
self.assertAllClose(expected_entropy, ent_values)
def testStudentSample(self):
- with self.test_session():
- df = constant_op.constant(4.)
- mu = constant_op.constant(3.)
- sigma = constant_op.constant(-math.sqrt(10.))
- df_v = 4.
- mu_v = 3.
- sigma_v = np.sqrt(10.)
- n = constant_op.constant(200000)
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- samples = student.sample(n, seed=123456)
- sample_values = self.evaluate(samples)
- n_val = 200000
- self.assertEqual(sample_values.shape, (n_val,))
- self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
- self.assertAllClose(
- sample_values.var(),
- sigma_v**2 * df_v / (df_v - 2),
- rtol=0.1,
- atol=0)
- self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
+ df = constant_op.constant(4.)
+ mu = constant_op.constant(3.)
+ sigma = constant_op.constant(-math.sqrt(10.))
+ df_v = 4.
+ mu_v = 3.
+ sigma_v = np.sqrt(10.)
+ n = constant_op.constant(200000)
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ samples = student.sample(n, seed=123456)
+ sample_values = self.evaluate(samples)
+ n_val = 200000
+ self.assertEqual(sample_values.shape, (n_val,))
+ self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
+ self.assertAllClose(
+ sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0)
+ self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
# Test that sampling with the same seed twice gives the same results.
def testStudentSampleMultipleTimes(self):
- with self.test_session():
- df = constant_op.constant(4.)
- mu = constant_op.constant(3.)
- sigma = constant_op.constant(math.sqrt(10.))
- n = constant_op.constant(100)
+ df = constant_op.constant(4.)
+ mu = constant_op.constant(3.)
+ sigma = constant_op.constant(math.sqrt(10.))
+ n = constant_op.constant(100)
- random_seed.set_random_seed(654321)
- student = student_t.StudentT(
- df=df, loc=mu, scale=sigma, name="student_t1")
- samples1 = self.evaluate(student.sample(n, seed=123456))
+ random_seed.set_random_seed(654321)
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
+ samples1 = self.evaluate(student.sample(n, seed=123456))
- random_seed.set_random_seed(654321)
- student2 = student_t.StudentT(
- df=df, loc=mu, scale=sigma, name="student_t2")
- samples2 = self.evaluate(student2.sample(n, seed=123456))
+ random_seed.set_random_seed(654321)
+ student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
+ samples2 = self.evaluate(student2.sample(n, seed=123456))
- self.assertAllClose(samples1, samples2)
+ self.assertAllClose(samples1, samples2)
def testStudentSampleSmallDfNoNan(self):
- with self.test_session():
- df_v = [1e-1, 1e-5, 1e-10, 1e-20]
- df = constant_op.constant(df_v)
- n = constant_op.constant(200000)
- student = student_t.StudentT(df=df, loc=1., scale=1.)
- samples = student.sample(n, seed=123456)
- sample_values = self.evaluate(samples)
- n_val = 200000
- self.assertEqual(sample_values.shape, (n_val, 4))
- self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
+ df_v = [1e-1, 1e-5, 1e-10, 1e-20]
+ df = constant_op.constant(df_v)
+ n = constant_op.constant(200000)
+ student = student_t.StudentT(df=df, loc=1., scale=1.)
+ samples = student.sample(n, seed=123456)
+ sample_values = self.evaluate(samples)
+ n_val = 200000
+ self.assertEqual(sample_values.shape, (n_val, 4))
+ self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
def testStudentSampleMultiDimensional(self):
- with self.test_session():
- batch_size = 7
- df = constant_op.constant([[5., 7.]] * batch_size)
- mu = constant_op.constant([[3., -3.]] * batch_size)
- sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] *
- batch_size)
- df_v = [5., 7.]
- mu_v = [3., -3.]
- sigma_v = [np.sqrt(10.), np.sqrt(15.)]
- n = constant_op.constant(200000)
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- samples = student.sample(n, seed=123456)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
- self.assertAllClose(
- sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
- self.assertAllClose(
- sample_values[:, 0, 0].var(),
- sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
- rtol=0.2,
- atol=0)
- self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
- self.assertAllClose(
- sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
- self.assertAllClose(
- sample_values[:, 0, 1].var(),
- sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
- rtol=0.2,
- atol=0)
- self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
+ batch_size = 7
+ df = constant_op.constant([[5., 7.]] * batch_size)
+ mu = constant_op.constant([[3., -3.]] * batch_size)
+ sigma = constant_op.constant(
+ [[math.sqrt(10.), math.sqrt(15.)]] * batch_size)
+ df_v = [5., 7.]
+ mu_v = [3., -3.]
+ sigma_v = [np.sqrt(10.), np.sqrt(15.)]
+ n = constant_op.constant(200000)
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ samples = student.sample(n, seed=123456)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
+ self.assertAllClose(
+ sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
+ self.assertAllClose(
+ sample_values[:, 0, 0].var(),
+ sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
+ rtol=0.2,
+ atol=0)
+ self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
+ self.assertAllClose(
+ sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
+ self.assertAllClose(
+ sample_values[:, 0, 1].var(),
+ sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
+ rtol=0.2,
+ atol=0)
+ self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
def _checkKLApprox(self, df, mu, sigma, samples):
n = samples.size
@@ -325,114 +312,102 @@ class StudentTTest(test.TestCase):
_check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
- with self.test_session():
- mu = [1., 3.3, 4.4]
- student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
- mean = self.evaluate(student.mean())
- self.assertAllClose([1., 3.3, 4.4], mean)
+ mu = [1., 3.3, 4.4]
+ student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
+ mean = self.evaluate(student.mean())
+ self.assertAllClose([1., 3.3, 4.4], mean)
def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
- with self.test_session():
- mu = [1., 3.3, 4.4]
- student = student_t.StudentT(
- df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
- allow_nan_stats=False)
- with self.assertRaisesOpError("x < y"):
- self.evaluate(student.mean())
+ mu = [1., 3.3, 4.4]
+ student = student_t.StudentT(
+ df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ self.evaluate(student.mean())
def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
- with self.test_session():
- mu = [-2, 0., 1., 3.3, 4.4]
- sigma = [5., 4., 3., 2., 1.]
- student = student_t.StudentT(
- df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
- allow_nan_stats=True)
- mean = self.evaluate(student.mean())
- self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
+ mu = [-2, 0., 1., 3.3, 4.4]
+ sigma = [5., 4., 3., 2., 1.]
+ student = student_t.StudentT(
+ df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True)
+ mean = self.evaluate(student.mean())
+ self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
- with self.test_session():
- # df = 0.5 ==> undefined mean ==> undefined variance.
- # df = 1.5 ==> infinite variance.
- df = [0.5, 1.5, 3., 5., 7.]
- mu = [-2, 0., 1., 3.3, 4.4]
- sigma = [5., 4., 3., 2., 1.]
- student = student_t.StudentT(
- df=df, loc=mu, scale=sigma, allow_nan_stats=True)
- var = self.evaluate(student.variance())
- ## scipy uses inf for variance when the mean is undefined. When mean is
- # undefined we say variance is undefined as well. So test the first
- # member of var, making sure it is NaN, then replace with inf and compare
- # to scipy.
- self.assertTrue(np.isnan(var[0]))
- var[0] = np.inf
-
- if not stats:
- return
- expected_var = [
- stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
- ]
- self.assertAllClose(expected_var, var)
+ # df = 0.5 ==> undefined mean ==> undefined variance.
+ # df = 1.5 ==> infinite variance.
+ df = [0.5, 1.5, 3., 5., 7.]
+ mu = [-2, 0., 1., 3.3, 4.4]
+ sigma = [5., 4., 3., 2., 1.]
+ student = student_t.StudentT(
+ df=df, loc=mu, scale=sigma, allow_nan_stats=True)
+ var = self.evaluate(student.variance())
+ ## scipy uses inf for variance when the mean is undefined. When mean is
+ # undefined we say variance is undefined as well. So test the first
+ # member of var, making sure it is NaN, then replace with inf and compare
+ # to scipy.
+ self.assertTrue(np.isnan(var[0]))
+ var[0] = np.inf
+
+ if not stats:
+ return
+ expected_var = [
+ stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+ ]
+ self.assertAllClose(expected_var, var)
def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
self):
- with self.test_session():
- # df = 1.5 ==> infinite variance.
- df = [1.5, 3., 5., 7.]
- mu = [0., 1., 3.3, 4.4]
- sigma = [4., 3., 2., 1.]
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- var = self.evaluate(student.variance())
+ # df = 1.5 ==> infinite variance.
+ df = [1.5, 3., 5., 7.]
+ mu = [0., 1., 3.3, 4.4]
+ sigma = [4., 3., 2., 1.]
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ var = self.evaluate(student.variance())
- if not stats:
- return
- expected_var = [
- stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
- ]
- self.assertAllClose(expected_var, var)
+ if not stats:
+ return
+ expected_var = [
+ stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+ ]
+ self.assertAllClose(expected_var, var)
def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
- with self.test_session():
- # df <= 1 ==> variance not defined
- student = student_t.StudentT(
- df=1., loc=0., scale=1., allow_nan_stats=False)
- with self.assertRaisesOpError("x < y"):
- self.evaluate(student.variance())
+ # df <= 1 ==> variance not defined
+ student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ self.evaluate(student.variance())
- with self.test_session():
- # df <= 1 ==> variance not defined
- student = student_t.StudentT(
- df=0.5, loc=0., scale=1., allow_nan_stats=False)
- with self.assertRaisesOpError("x < y"):
- self.evaluate(student.variance())
+ # df <= 1 ==> variance not defined
+ student = student_t.StudentT(
+ df=0.5, loc=0., scale=1., allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ self.evaluate(student.variance())
def testStd(self):
- with self.test_session():
- # Defined for all batch members.
- df = [3.5, 5., 3., 5., 7.]
- mu = [-2.2]
- sigma = [5., 4., 3., 2., 1.]
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- # Test broadcast of mu across shape of df/sigma
- stddev = self.evaluate(student.stddev())
- mu *= len(df)
+ # Defined for all batch members.
+ df = [3.5, 5., 3., 5., 7.]
+ mu = [-2.2]
+ sigma = [5., 4., 3., 2., 1.]
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ # Test broadcast of mu across shape of df/sigma
+ stddev = self.evaluate(student.stddev())
+ mu *= len(df)
- if not stats:
- return
- expected_stddev = [
- stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
- ]
- self.assertAllClose(expected_stddev, stddev)
+ if not stats:
+ return
+ expected_stddev = [
+ stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+ ]
+ self.assertAllClose(expected_stddev, stddev)
def testMode(self):
- with self.test_session():
- df = [0.5, 1., 3]
- mu = [-1, 0., 1]
- sigma = [5., 4., 3.]
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- # Test broadcast of mu across shape of df/sigma
- mode = self.evaluate(student.mode())
- self.assertAllClose([-1., 0, 1], mode)
+ df = [0.5, 1., 3]
+ mu = [-1, 0., 1]
+ sigma = [5., 4., 3.]
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ # Test broadcast of mu across shape of df/sigma
+ mode = self.evaluate(student.mode())
+ self.assertAllClose([-1., 0, 1], mode)
def testPdfOfSample(self):
student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
@@ -510,25 +485,23 @@ class StudentTTest(test.TestCase):
self.assertNear(1., total, err=err)
def testNegativeDofFails(self):
- with self.test_session():
- with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
- student = student_t.StudentT(
- df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
- self.evaluate(student.mean())
+ with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
+ student = student_t.StudentT(
+ df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
+ self.evaluate(student.mean())
def testStudentTWithAbsDfSoftplusScale(self):
- with self.test_session():
- df = constant_op.constant([-3.2, -4.6])
- mu = constant_op.constant([-4.2, 3.4])
- sigma = constant_op.constant([-6.4, -8.8])
- student = student_t.StudentTWithAbsDfSoftplusScale(
- df=df, loc=mu, scale=sigma)
- self.assertAllClose(
- math_ops.floor(self.evaluate(math_ops.abs(df))),
- self.evaluate(student.df))
- self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))
+ df = constant_op.constant([-3.2, -4.6])
+ mu = constant_op.constant([-4.2, 3.4])
+ sigma = constant_op.constant([-6.4, -8.8])
+ student = student_t.StudentTWithAbsDfSoftplusScale(
+ df=df, loc=mu, scale=sigma)
+ self.assertAllClose(
+ math_ops.floor(self.evaluate(math_ops.abs(df))),
+ self.evaluate(student.df))
+ self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))
if __name__ == "__main__":
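Several student_t_test.py hunks keep the "if not stats: return" guard in place: the module-level stats = try_import("scipy.stats") visible in the hunk header leaves stats as None when scipy is missing, so each test still runs its shape checks and only skips the reference comparison. Below is a rough standalone sketch of that guard, with a hypothetical try_import helper standing in for the one defined at the top of the test module and an assumed import path for student_t.

import importlib

from tensorflow.python.ops.distributions import student_t
from tensorflow.python.platform import test


def try_import(name):
  # Stand-in for the helper defined in the test module: return the module if
  # it imports cleanly, otherwise None.
  try:
    return importlib.import_module(name)
  except ImportError:
    return None


stats = try_import("scipy.stats")


class ScipyGuardSketchTest(test.TestCase):

  def testVarianceAgainstScipy(self):
    student = student_t.StudentT(df=[3.], loc=[0.], scale=[1.])
    var = self.evaluate(student.variance())
    self.assertEqual((1,), var.shape)
    if not stats:
      return  # scipy unavailable: keep the shape check, skip the comparison.
    # Student-t variance is scale**2 * df / (df - 2) = 3 for df=3, scale=1.
    self.assertAllClose(stats.t.var(3., loc=0., scale=1.), var[0])


if __name__ == "__main__":
  test.main()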
diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py
index bc9c267b9a..9cdcd369c1 100644
--- a/tensorflow/python/kernel_tests/distributions/uniform_test.py
+++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py
@@ -50,255 +50,239 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testUniformRange(self):
- with self.test_session():
- a = 3.0
- b = 10.0
- uniform = uniform_lib.Uniform(low=a, high=b)
- self.assertAllClose(a, self.evaluate(uniform.low))
- self.assertAllClose(b, self.evaluate(uniform.high))
- self.assertAllClose(b - a, self.evaluate(uniform.range()))
+ a = 3.0
+ b = 10.0
+ uniform = uniform_lib.Uniform(low=a, high=b)
+ self.assertAllClose(a, self.evaluate(uniform.low))
+ self.assertAllClose(b, self.evaluate(uniform.high))
+ self.assertAllClose(b - a, self.evaluate(uniform.range()))
@test_util.run_in_graph_and_eager_modes
def testUniformPDF(self):
- with self.test_session():
- a = constant_op.constant([-3.0] * 5 + [15.0])
- b = constant_op.constant([11.0] * 5 + [20.0])
- uniform = uniform_lib.Uniform(low=a, high=b)
+ a = constant_op.constant([-3.0] * 5 + [15.0])
+ b = constant_op.constant([11.0] * 5 + [20.0])
+ uniform = uniform_lib.Uniform(low=a, high=b)
- a_v = -3.0
- b_v = 11.0
- x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
+ a_v = -3.0
+ b_v = 11.0
+ x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
- def _expected_pdf():
- pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
- pdf[x > b_v] = 0.0
- pdf[x < a_v] = 0.0
- pdf[5] = 1.0 / (20.0 - 15.0)
- return pdf
+ def _expected_pdf():
+ pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
+ pdf[x > b_v] = 0.0
+ pdf[x < a_v] = 0.0
+ pdf[5] = 1.0 / (20.0 - 15.0)
+ return pdf
- expected_pdf = _expected_pdf()
+ expected_pdf = _expected_pdf()
- pdf = uniform.prob(x)
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ pdf = uniform.prob(x)
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
- log_pdf = uniform.log_prob(x)
- self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))
+ log_pdf = uniform.log_prob(x)
+ self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))
@test_util.run_in_graph_and_eager_modes
def testUniformShape(self):
- with self.test_session():
- a = constant_op.constant([-3.0] * 5)
- b = constant_op.constant(11.0)
- uniform = uniform_lib.Uniform(low=a, high=b)
+ a = constant_op.constant([-3.0] * 5)
+ b = constant_op.constant(11.0)
+ uniform = uniform_lib.Uniform(low=a, high=b)
- self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,))
- self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
- self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
+ self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,))
+ self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
+ self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
+ self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
@test_util.run_in_graph_and_eager_modes
def testUniformPDFWithScalarEndpoint(self):
- with self.test_session():
- a = constant_op.constant([0.0, 5.0])
- b = constant_op.constant(10.0)
- uniform = uniform_lib.Uniform(low=a, high=b)
+ a = constant_op.constant([0.0, 5.0])
+ b = constant_op.constant(10.0)
+ uniform = uniform_lib.Uniform(low=a, high=b)
- x = np.array([0.0, 8.0], dtype=np.float32)
- expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
+ x = np.array([0.0, 8.0], dtype=np.float32)
+ expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
- pdf = uniform.prob(x)
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ pdf = uniform.prob(x)
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
@test_util.run_in_graph_and_eager_modes
def testUniformCDF(self):
- with self.test_session():
- batch_size = 6
- a = constant_op.constant([1.0] * batch_size)
- b = constant_op.constant([11.0] * batch_size)
- a_v = 1.0
- b_v = 11.0
- x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
+ batch_size = 6
+ a = constant_op.constant([1.0] * batch_size)
+ b = constant_op.constant([11.0] * batch_size)
+ a_v = 1.0
+ b_v = 11.0
+ x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
- uniform = uniform_lib.Uniform(low=a, high=b)
+ uniform = uniform_lib.Uniform(low=a, high=b)
- def _expected_cdf():
- cdf = (x - a_v) / (b_v - a_v)
- cdf[x >= b_v] = 1
- cdf[x < a_v] = 0
- return cdf
+ def _expected_cdf():
+ cdf = (x - a_v) / (b_v - a_v)
+ cdf[x >= b_v] = 1
+ cdf[x < a_v] = 0
+ return cdf
- cdf = uniform.cdf(x)
- self.assertAllClose(_expected_cdf(), self.evaluate(cdf))
+ cdf = uniform.cdf(x)
+ self.assertAllClose(_expected_cdf(), self.evaluate(cdf))
- log_cdf = uniform.log_cdf(x)
- self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
+ log_cdf = uniform.log_cdf(x)
+ self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
@test_util.run_in_graph_and_eager_modes
def testUniformEntropy(self):
- with self.test_session():
- a_v = np.array([1.0, 1.0, 1.0])
- b_v = np.array([[1.5, 2.0, 3.0]])
- uniform = uniform_lib.Uniform(low=a_v, high=b_v)
+ a_v = np.array([1.0, 1.0, 1.0])
+ b_v = np.array([[1.5, 2.0, 3.0]])
+ uniform = uniform_lib.Uniform(low=a_v, high=b_v)
- expected_entropy = np.log(b_v - a_v)
- self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))
+ expected_entropy = np.log(b_v - a_v)
+ self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))
@test_util.run_in_graph_and_eager_modes
def testUniformAssertMaxGtMin(self):
- with self.test_session():
- a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
- b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+ a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
+ b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- "x < y"):
- uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
- self.evaluate(uniform.low)
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ "x < y"):
+ uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
+ self.evaluate(uniform.low)
@test_util.run_in_graph_and_eager_modes
def testUniformSample(self):
- with self.test_session():
- a = constant_op.constant([3.0, 4.0])
- b = constant_op.constant(13.0)
- a1_v = 3.0
- a2_v = 4.0
- b_v = 13.0
- n = constant_op.constant(100000)
- uniform = uniform_lib.Uniform(low=a, high=b)
-
- samples = uniform.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000, 2))
- self.assertAllClose(
- sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.)
- self.assertAllClose(
- sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.)
- self.assertFalse(
- np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
- self.assertFalse(
- np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
+ a = constant_op.constant([3.0, 4.0])
+ b = constant_op.constant(13.0)
+ a1_v = 3.0
+ a2_v = 4.0
+ b_v = 13.0
+ n = constant_op.constant(100000)
+ uniform = uniform_lib.Uniform(low=a, high=b)
+
+ samples = uniform.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000, 2))
+ self.assertAllClose(
+ sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.)
+ self.assertAllClose(
+ sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.)
+ self.assertFalse(
+ np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
+ self.assertFalse(
+ np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
@test_util.run_in_graph_and_eager_modes
def _testUniformSampleMultiDimensional(self):
# DISABLED: Please enable this test once b/issues/30149644 is resolved.
- with self.test_session():
- batch_size = 2
- a_v = [3.0, 22.0]
- b_v = [13.0, 35.0]
- a = constant_op.constant([a_v] * batch_size)
- b = constant_op.constant([b_v] * batch_size)
-
- uniform = uniform_lib.Uniform(low=a, high=b)
-
- n_v = 100000
- n = constant_op.constant(n_v)
- samples = uniform.sample(n)
- self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
-
- sample_values = self.evaluate(samples)
-
- self.assertFalse(
- np.any(sample_values[:, 0, 0] < a_v[0]) or
- np.any(sample_values[:, 0, 0] >= b_v[0]))
- self.assertFalse(
- np.any(sample_values[:, 0, 1] < a_v[1]) or
- np.any(sample_values[:, 0, 1] >= b_v[1]))
-
- self.assertAllClose(
- sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2)
- self.assertAllClose(
- sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
+ batch_size = 2
+ a_v = [3.0, 22.0]
+ b_v = [13.0, 35.0]
+ a = constant_op.constant([a_v] * batch_size)
+ b = constant_op.constant([b_v] * batch_size)
+
+ uniform = uniform_lib.Uniform(low=a, high=b)
+
+ n_v = 100000
+ n = constant_op.constant(n_v)
+ samples = uniform.sample(n)
+ self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
+
+ sample_values = self.evaluate(samples)
+
+ self.assertFalse(
+ np.any(sample_values[:, 0, 0] < a_v[0]) or
+ np.any(sample_values[:, 0, 0] >= b_v[0]))
+ self.assertFalse(
+ np.any(sample_values[:, 0, 1] < a_v[1]) or
+ np.any(sample_values[:, 0, 1] >= b_v[1]))
+
+ self.assertAllClose(
+ sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2)
+ self.assertAllClose(
+ sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
@test_util.run_in_graph_and_eager_modes
def testUniformMean(self):
- with self.test_session():
- a = 10.0
- b = 100.0
- uniform = uniform_lib.Uniform(low=a, high=b)
- if not stats:
- return
- s_uniform = stats.uniform(loc=a, scale=b - a)
- self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())
+ a = 10.0
+ b = 100.0
+ uniform = uniform_lib.Uniform(low=a, high=b)
+ if not stats:
+ return
+ s_uniform = stats.uniform(loc=a, scale=b - a)
+ self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())
@test_util.run_in_graph_and_eager_modes
def testUniformVariance(self):
- with self.test_session():
- a = 10.0
- b = 100.0
- uniform = uniform_lib.Uniform(low=a, high=b)
- if not stats:
- return
- s_uniform = stats.uniform(loc=a, scale=b - a)
- self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())
+ a = 10.0
+ b = 100.0
+ uniform = uniform_lib.Uniform(low=a, high=b)
+ if not stats:
+ return
+ s_uniform = stats.uniform(loc=a, scale=b - a)
+ self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())
@test_util.run_in_graph_and_eager_modes
def testUniformStd(self):
- with self.test_session():
- a = 10.0
- b = 100.0
- uniform = uniform_lib.Uniform(low=a, high=b)
- if not stats:
- return
- s_uniform = stats.uniform(loc=a, scale=b - a)
- self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())
+ a = 10.0
+ b = 100.0
+ uniform = uniform_lib.Uniform(low=a, high=b)
+ if not stats:
+ return
+ s_uniform = stats.uniform(loc=a, scale=b - a)
+ self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())
@test_util.run_in_graph_and_eager_modes
def testUniformNans(self):
- with self.test_session():
- a = 10.0
- b = [11.0, 100.0]
- uniform = uniform_lib.Uniform(low=a, high=b)
+ a = 10.0
+ b = [11.0, 100.0]
+ uniform = uniform_lib.Uniform(low=a, high=b)
- no_nans = constant_op.constant(1.0)
- nans = constant_op.constant(0.0) / constant_op.constant(0.0)
- self.assertTrue(self.evaluate(math_ops.is_nan(nans)))
- with_nans = array_ops.stack([no_nans, nans])
+ no_nans = constant_op.constant(1.0)
+ nans = constant_op.constant(0.0) / constant_op.constant(0.0)
+ self.assertTrue(self.evaluate(math_ops.is_nan(nans)))
+ with_nans = array_ops.stack([no_nans, nans])
- pdf = uniform.prob(with_nans)
+ pdf = uniform.prob(with_nans)
- is_nan = self.evaluate(math_ops.is_nan(pdf))
- self.assertFalse(is_nan[0])
- self.assertTrue(is_nan[1])
+ is_nan = self.evaluate(math_ops.is_nan(pdf))
+ self.assertFalse(is_nan[0])
+ self.assertTrue(is_nan[1])
@test_util.run_in_graph_and_eager_modes
def testUniformSamplePdf(self):
- with self.test_session():
- a = 10.0
- b = [11.0, 100.0]
- uniform = uniform_lib.Uniform(a, b)
- self.assertTrue(
- self.evaluate(
- math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))
+ a = 10.0
+ b = [11.0, 100.0]
+ uniform = uniform_lib.Uniform(a, b)
+ self.assertTrue(
+ self.evaluate(
+ math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))
@test_util.run_in_graph_and_eager_modes
def testUniformBroadcasting(self):
- with self.test_session():
- a = 10.0
- b = [11.0, 20.0]
- uniform = uniform_lib.Uniform(a, b)
+ a = 10.0
+ b = [11.0, 20.0]
+ uniform = uniform_lib.Uniform(a, b)
- pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
- expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
+ expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
@test_util.run_in_graph_and_eager_modes
def testUniformSampleWithShape(self):
- with self.test_session():
- a = 10.0
- b = [11.0, 20.0]
- uniform = uniform_lib.Uniform(a, b)
-
- pdf = uniform.prob(uniform.sample((2, 3)))
- # pylint: disable=bad-continuation
- expected_pdf = [
- [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
- [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
- ]
- # pylint: enable=bad-continuation
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
-
- pdf = uniform.prob(uniform.sample())
- expected_pdf = [1.0, 0.1]
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ a = 10.0
+ b = [11.0, 20.0]
+ uniform = uniform_lib.Uniform(a, b)
+
+ pdf = uniform.prob(uniform.sample((2, 3)))
+ # pylint: disable=bad-continuation
+ expected_pdf = [
+ [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
+ [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
+ ]
+ # pylint: enable=bad-continuation
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
+
+ pdf = uniform.prob(uniform.sample())
+ expected_pdf = [1.0, 0.1]
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
def testFullyReparameterized(self):
a = constant_op.constant(0.1)
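The validate_args tests above (testUniformAssertMaxGtMin, and earlier testNegativeSigmaFails and testNegativeDofFails) share one shape: the distribution construction and the self.evaluate call both sit inside the assertion context, because eager mode raises while the argument checks run in the constructor, whereas graph mode defers the failure until the checked tensor is evaluated. A small sketch of that pattern follows, under the same import-path assumptions as the sketches above; the concrete values are illustrative.

from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops.distributions import uniform as uniform_lib
from tensorflow.python.platform import test


class ValidateArgsSketchTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testLowNotLessThanHighFails(self):
    # Construction and evaluation both sit inside the assertion block: eager
    # mode raises during __init__ when the assert_less check executes, graph
    # mode raises only once the checked tensor is evaluated.
    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                             "x < y"):
      uniform = uniform_lib.Uniform(low=3., high=1., validate_args=True)
      self.evaluate(uniform.low)


if __name__ == "__main__":
  test.main()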
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index 61faa8466e..27d652c2c6 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -69,7 +69,7 @@ class AssertCloseTest(test.TestCase):
w = array_ops.placeholder(dtypes.float32)
feed_dict = {x: [1., 5, 10, 15, 20], y: [1.1, 5, 10, 15, 20],
z: [1.0001, 5, 10, 15, 20], w: [1e-8, 5, 10, 15, 20]}
- with self.test_session():
+ with self.cached_session():
with ops.control_dependencies([du.assert_integer_form(x)]):
array_ops.identity(x).eval(feed_dict=feed_dict)
@@ -122,58 +122,52 @@ class GetLogitsAndProbsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testImproperArguments(self):
- with self.test_session():
- with self.assertRaises(ValueError):
- du.get_logits_and_probs(logits=None, probs=None)
+ with self.assertRaises(ValueError):
+ du.get_logits_and_probs(logits=None, probs=None)
- with self.assertRaises(ValueError):
- du.get_logits_and_probs(logits=[0.1], probs=[0.1])
+ with self.assertRaises(ValueError):
+ du.get_logits_and_probs(logits=[0.1], probs=[0.1])
@test_util.run_in_graph_and_eager_modes
def testLogits(self):
p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
logits = _logit(p)
- with self.test_session():
- new_logits, new_p = du.get_logits_and_probs(
- logits=logits, validate_args=True)
+ new_logits, new_p = du.get_logits_and_probs(
+ logits=logits, validate_args=True)
- self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
- self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)
+ self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
+ self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)
@test_util.run_in_graph_and_eager_modes
def testLogitsMultidimensional(self):
p = np.array([0.2, 0.3, 0.5], dtype=np.float32)
logits = np.log(p)
- with self.test_session():
- new_logits, new_p = du.get_logits_and_probs(
- logits=logits, multidimensional=True, validate_args=True)
+ new_logits, new_p = du.get_logits_and_probs(
+ logits=logits, multidimensional=True, validate_args=True)
- self.assertAllClose(self.evaluate(new_p), p)
- self.assertAllClose(self.evaluate(new_logits), logits)
+ self.assertAllClose(self.evaluate(new_p), p)
+ self.assertAllClose(self.evaluate(new_logits), logits)
@test_util.run_in_graph_and_eager_modes
def testProbability(self):
p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
- with self.test_session():
- new_logits, new_p = du.get_logits_and_probs(
- probs=p, validate_args=True)
+ new_logits, new_p = du.get_logits_and_probs(probs=p, validate_args=True)
- self.assertAllClose(_logit(p), self.evaluate(new_logits))
- self.assertAllClose(p, self.evaluate(new_p))
+ self.assertAllClose(_logit(p), self.evaluate(new_logits))
+ self.assertAllClose(p, self.evaluate(new_p))
@test_util.run_in_graph_and_eager_modes
def testProbabilityMultidimensional(self):
p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
- with self.test_session():
- new_logits, new_p = du.get_logits_and_probs(
- probs=p, multidimensional=True, validate_args=True)
+ new_logits, new_p = du.get_logits_and_probs(
+ probs=p, multidimensional=True, validate_args=True)
- self.assertAllClose(np.log(p), self.evaluate(new_logits))
- self.assertAllClose(p, self.evaluate(new_p))
+ self.assertAllClose(np.log(p), self.evaluate(new_logits))
+ self.assertAllClose(p, self.evaluate(new_p))
@test_util.run_in_graph_and_eager_modes
def testProbabilityValidateArgs(self):
@@ -183,29 +177,23 @@ class GetLogitsAndProbsTest(test.TestCase):
# Component greater than 1.
p3 = [2, 0.2, 0.5, 0.3, .2]
- with self.test_session():
- _, prob = du.get_logits_and_probs(
- probs=p, validate_args=True)
- self.evaluate(prob)
-
- with self.assertRaisesOpError("Condition x >= 0"):
- _, prob = du.get_logits_and_probs(
- probs=p2, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(probs=p, validate_args=True)
+ self.evaluate(prob)
- _, prob = du.get_logits_and_probs(
- probs=p2, validate_args=False)
+ with self.assertRaisesOpError("Condition x >= 0"):
+ _, prob = du.get_logits_and_probs(probs=p2, validate_args=True)
self.evaluate(prob)
- with self.assertRaisesOpError("probs has components greater than 1"):
- _, prob = du.get_logits_and_probs(
- probs=p3, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(probs=p2, validate_args=False)
+ self.evaluate(prob)
- _, prob = du.get_logits_and_probs(
- probs=p3, validate_args=False)
+ with self.assertRaisesOpError("probs has components greater than 1"):
+ _, prob = du.get_logits_and_probs(probs=p3, validate_args=True)
self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(probs=p3, validate_args=False)
+ self.evaluate(prob)
+
@test_util.run_in_graph_and_eager_modes
def testProbabilityValidateArgsMultidimensional(self):
p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
@@ -216,41 +204,39 @@ class GetLogitsAndProbsTest(test.TestCase):
# Does not sum to 1.
p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)
- with self.test_session():
- _, prob = du.get_logits_and_probs(
- probs=p, multidimensional=True)
- self.evaluate(prob)
-
- with self.assertRaisesOpError("Condition x >= 0"):
- _, prob = du.get_logits_and_probs(
- probs=p2, multidimensional=True, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(probs=p, multidimensional=True)
+ self.evaluate(prob)
+ with self.assertRaisesOpError("Condition x >= 0"):
_, prob = du.get_logits_and_probs(
- probs=p2, multidimensional=True, validate_args=False)
+ probs=p2, multidimensional=True, validate_args=True)
self.evaluate(prob)
- with self.assertRaisesOpError(
- "(probs has components greater than 1|probs does not sum to 1)"):
- _, prob = du.get_logits_and_probs(
- probs=p3, multidimensional=True, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(
+ probs=p2, multidimensional=True, validate_args=False)
+ self.evaluate(prob)
+ with self.assertRaisesOpError(
+ "(probs has components greater than 1|probs does not sum to 1)"):
_, prob = du.get_logits_and_probs(
- probs=p3, multidimensional=True, validate_args=False)
+ probs=p3, multidimensional=True, validate_args=True)
self.evaluate(prob)
- with self.assertRaisesOpError("probs does not sum to 1"):
- _, prob = du.get_logits_and_probs(
- probs=p4, multidimensional=True, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(
+ probs=p3, multidimensional=True, validate_args=False)
+ self.evaluate(prob)
+ with self.assertRaisesOpError("probs does not sum to 1"):
_, prob = du.get_logits_and_probs(
- probs=p4, multidimensional=True, validate_args=False)
+ probs=p4, multidimensional=True, validate_args=True)
self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(
+ probs=p4, multidimensional=True, validate_args=False)
+ self.evaluate(prob)
+
def testProbsMultidimShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
p = array_ops.ones([int(2**11+1)], dtype=np.float16)
du.get_logits_and_probs(
@@ -264,7 +250,7 @@ class GetLogitsAndProbsTest(test.TestCase):
prob.eval(feed_dict={p: np.ones([int(2**11+1)])})
def testLogitsMultidimShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
l = array_ops.ones([int(2**11+1)], dtype=np.float16)
du.get_logits_and_probs(
@@ -281,7 +267,7 @@ class GetLogitsAndProbsTest(test.TestCase):
class EmbedCheckCategoricalEventShapeTest(test.TestCase):
def testTooSmall(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
param = array_ops.ones([1], dtype=np.float16)
checked_param = du.embed_check_categorical_event_shape(
@@ -295,7 +281,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
checked_param.eval(feed_dict={param: np.ones([1])})
def testTooLarge(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
param = array_ops.ones([int(2**11+1)], dtype=dtypes.float16)
checked_param = du.embed_check_categorical_event_shape(
@@ -310,18 +296,17 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testUnsupportedDtype(self):
- with self.test_session():
- param = ops.convert_to_tensor(
- np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
- dtype=dtypes.qint16)
- with self.assertRaises(TypeError):
- du.embed_check_categorical_event_shape(param)
+ param = ops.convert_to_tensor(
+ np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
+ dtype=dtypes.qint16)
+ with self.assertRaises(TypeError):
+ du.embed_check_categorical_event_shape(param)
class EmbedCheckIntegerCastingClosedTest(test.TestCase):
def testCorrectlyAssertsNonnegative(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Elements must be non-negative"):
x = array_ops.placeholder(dtype=dtypes.float16)
x_checked = du.embed_check_integer_casting_closed(
@@ -329,7 +314,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.float16)})
def testCorrectlyAssersIntegerForm(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Elements must be int16-equivalent."):
x = array_ops.placeholder(dtype=dtypes.float16)
x_checked = du.embed_check_integer_casting_closed(
@@ -337,7 +322,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, 1.5], dtype=np.float16)})
def testCorrectlyAssertsLargestPossibleInteger(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Elements cannot exceed 32767."):
x = array_ops.placeholder(dtype=dtypes.int32)
x_checked = du.embed_check_integer_casting_closed(
@@ -345,7 +330,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, 2**15], dtype=np.int32)})
def testCorrectlyAssertsSmallestPossibleInteger(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Elements cannot be smaller than 0."):
x = array_ops.placeholder(dtype=dtypes.int32)
x_checked = du.embed_check_integer_casting_closed(
@@ -365,29 +350,27 @@ class LogCombinationsTest(test.TestCase):
log_combs = np.log(special.binom(n, k))
- with self.test_session():
- n = np.array(n, dtype=np.float32)
- counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
- log_binom = du.log_combinations(n, counts)
- self.assertEqual([4], log_binom.get_shape())
- self.assertAllClose(log_combs, self.evaluate(log_binom))
+ n = np.array(n, dtype=np.float32)
+ counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
+ log_binom = du.log_combinations(n, counts)
+ self.assertEqual([4], log_binom.get_shape())
+ self.assertAllClose(log_combs, self.evaluate(log_binom))
def testLogCombinationsShape(self):
# Shape [2, 2]
n = [[2, 5], [12, 15]]
- with self.test_session():
- n = np.array(n, dtype=np.float32)
- # Shape [2, 2, 4]
- counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
- log_binom = du.log_combinations(n, counts)
- self.assertEqual([2, 2], log_binom.get_shape())
+ n = np.array(n, dtype=np.float32)
+ # Shape [2, 2, 4]
+ counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
+ log_binom = du.log_combinations(n, counts)
+ self.assertEqual([2, 2], log_binom.get_shape())
class DynamicShapeTest(test.TestCase):
def testSameDynamicShape(self):
- with self.test_session():
+ with self.cached_session():
scalar = constant_op.constant(2.0)
scalar1 = array_ops.placeholder(dtype=dtypes.float32)
@@ -497,22 +480,21 @@ class RotateTransposeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testRollStatic(self):
- with self.test_session():
- if context.executing_eagerly():
- error_message = r"Attempt to convert a value \(None\)"
- else:
- error_message = "None values not supported."
- with self.assertRaisesRegexp(ValueError, error_message):
- du.rotate_transpose(None, 1)
- for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
- for shift in np.arange(-5, 5):
- y = du.rotate_transpose(x, shift)
- self.assertAllEqual(
- self._np_rotate_transpose(x, shift), self.evaluate(y))
- self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())
+ if context.executing_eagerly():
+ error_message = r"Attempt to convert a value \(None\)"
+ else:
+ error_message = "None values not supported."
+ with self.assertRaisesRegexp(ValueError, error_message):
+ du.rotate_transpose(None, 1)
+ for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
+ for shift in np.arange(-5, 5):
+ y = du.rotate_transpose(x, shift)
+ self.assertAllEqual(
+ self._np_rotate_transpose(x, shift), self.evaluate(y))
+ self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())
def testRollDynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
shift = array_ops.placeholder(dtypes.int32)
for x_value in (np.ones(
@@ -530,7 +512,7 @@ class RotateTransposeTest(test.TestCase):
class PickVectorTest(test.TestCase):
def testCorrectlyPicksVector(self):
- with self.test_session():
+ with self.cached_session():
x = np.arange(10, 12)
y = np.arange(15, 18)
self.assertAllEqual(
@@ -568,19 +550,19 @@ class PreferStaticRankTest(test.TestCase):
def testDynamicRankEndsUpBeingNonEmpty(self):
x = array_ops.placeholder(np.float64, shape=None)
rank = du.prefer_static_rank(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(2, rank.eval(feed_dict={x: np.zeros((2, 3))}))
def testDynamicRankEndsUpBeingEmpty(self):
x = array_ops.placeholder(np.int32, shape=None)
rank = du.prefer_static_rank(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(1, rank.eval(feed_dict={x: []}))
def testDynamicRankEndsUpBeingScalar(self):
x = array_ops.placeholder(np.int32, shape=None)
rank = du.prefer_static_rank(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(0, rank.eval(feed_dict={x: 1}))
@@ -607,19 +589,19 @@ class PreferStaticShapeTest(test.TestCase):
def testDynamicShapeEndsUpBeingNonEmpty(self):
x = array_ops.placeholder(np.float64, shape=None)
shape = du.prefer_static_shape(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual((2, 3), shape.eval(feed_dict={x: np.zeros((2, 3))}))
def testDynamicShapeEndsUpBeingEmpty(self):
x = array_ops.placeholder(np.int32, shape=None)
shape = du.prefer_static_shape(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.array([0]), shape.eval(feed_dict={x: []}))
def testDynamicShapeEndsUpBeingScalar(self):
x = array_ops.placeholder(np.int32, shape=None)
shape = du.prefer_static_shape(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1}))
@@ -646,20 +628,20 @@ class PreferStaticValueTest(test.TestCase):
def testDynamicValueEndsUpBeingNonEmpty(self):
x = array_ops.placeholder(np.float64, shape=None)
value = du.prefer_static_value(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.zeros((2, 3)),
value.eval(feed_dict={x: np.zeros((2, 3))}))
def testDynamicValueEndsUpBeingEmpty(self):
x = array_ops.placeholder(np.int32, shape=None)
value = du.prefer_static_value(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.array([]), value.eval(feed_dict={x: []}))
def testDynamicValueEndsUpBeingScalar(self):
x = array_ops.placeholder(np.int32, shape=None)
value = du.prefer_static_value(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1}))
@@ -691,7 +673,7 @@ class FillTriangularTest(test.TestCase):
def _run_test(self, x_, use_deferred_shape=False, **kwargs):
x_ = np.asarray(x_)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
static_shape = None if use_deferred_shape else x_.shape
x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
# Add `zeros_like(x)` such that x's value and gradient are identical. We
@@ -761,7 +743,7 @@ class FillTriangularInverseTest(FillTriangularTest):
def _run_test(self, x_, use_deferred_shape=False, **kwargs):
x_ = np.asarray(x_)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
static_shape = None if use_deferred_shape else x_.shape
x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.)
@@ -795,7 +777,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
logx_ = np.array([[0., -1, 1000.],
[0, 1, -1000.],
[-5, 0, 5]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logx = constant_op.constant(logx_)
expected = math_ops.reduce_logsumexp(logx, axis=-1)
grad_expected = gradients_impl.gradients(expected, logx)[0]
@@ -818,7 +800,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
[1, -2, 1],
[1, 0, 1]])
expected, _ = self._reduce_weighted_logsumexp(logx_, w_, axis=-1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logx = constant_op.constant(logx_)
w = constant_op.constant(w_)
actual, actual_sgn = du.reduce_weighted_logsumexp(
@@ -836,7 +818,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
[1, 0, 1]])
expected, _ = self._reduce_weighted_logsumexp(
logx_, w_, axis=-1, keep_dims=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logx = constant_op.constant(logx_)
w = constant_op.constant(w_)
actual, actual_sgn = du.reduce_weighted_logsumexp(
@@ -848,7 +830,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
def testDocString(self):
"""This test verifies the correctness of the docstring examples."""
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([[0., 0, 0],
[0, 0, 0]])
@@ -952,7 +934,7 @@ class SoftplusTest(test.TestCase):
use_gpu=True)
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -968,7 +950,7 @@ class SoftplusTest(test.TestCase):
self.assertLess(err, 1e-4)
def testInverseSoftplusGradientNeverNan(self):
- with self.test_session():
+ with self.cached_session():
# Note that this range contains both zero and inf.
x = constant_op.constant(np.logspace(-8, 6).astype(np.float16))
y = du.softplus_inverse(x)
@@ -977,7 +959,7 @@ class SoftplusTest(test.TestCase):
self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads))
def testInverseSoftplusGradientFinite(self):
- with self.test_session():
+ with self.cached_session():
# This range of x is all finite, and so is 1 / x. So the
# gradient and its approximations should be finite as well.
x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16))
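The recurring change in the hunks above (and in the test files below) swaps `self.test_session()` for `self.cached_session()`, which reuses one session per test case instead of building and re-optimizing a fresh one, and drops the session block entirely from tests decorated with `@test_util.run_in_graph_and_eager_modes`, which rely on `self.evaluate()` instead. A minimal sketch of the before/after pattern, using a hypothetical `DoublingTest` case that is not part of this change:

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class DoublingTest(test.TestCase):
  # Hypothetical test case, shown only to illustrate the migration pattern.

  def testGraphOnly(self):
    # Graph-only tests keep a session block, but reuse the cached session
    # instead of building (and re-optimizing) a fresh session per call.
    with self.cached_session():
      x = constant_op.constant([1.0, 2.0])
      self.assertAllEqual([2.0, 4.0], (2.0 * x).eval())

  @test_util.run_in_graph_and_eager_modes
  def testGraphAndEager(self):
    # Dual-mode tests drop the session block entirely; self.evaluate() works
    # in both graph and eager execution.
    x = constant_op.constant([1.0, 2.0])
    self.assertAllEqual([2.0, 4.0], self.evaluate(2.0 * x))


if __name__ == "__main__":
  test.main()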
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 1e76ad7476..3ddb5e06c9 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -59,42 +59,48 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFoldl_Simple(self):
- with self.test_session():
- elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
+ elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
- r = functional_ops.foldl(
- lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
- elems)
- self.assertAllEqual(208, self.evaluate(r))
+ r = functional_ops.foldl(
+ lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+ elems)
+ self.assertAllEqual(208, self.evaluate(r))
- r = functional_ops.foldl(
- lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
- elems,
- initializer=10)
- self.assertAllEqual(880, self.evaluate(r))
+ r = functional_ops.foldl(
+ lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+ elems,
+ initializer=10)
+ self.assertAllEqual(880, self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testFoldl_SingleInputMultiOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array([1, -1.0])
- r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
- r_value = self.evaluate(r)
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array([1, -1.0])
+ r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
+ r_value = self.evaluate(r)
- self.assertAllEqual(22, r_value[0])
- self.assertAllEqual(20, r_value[1])
+ self.assertAllEqual(22, r_value[0])
+ self.assertAllEqual(20, r_value[1])
@test_util.run_in_graph_and_eager_modes
def testFoldl_MultiInputSingleOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array(1.0)
- r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
- initializer)
- self.assertAllEqual(1, self.evaluate(r))
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+ r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
+ initializer)
+ self.assertAllEqual(1, self.evaluate(r))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testFoldl_MultiInputDifferentDimsSingleOutput(self):
+ elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]])
+ other_elems = np.array([-1.0, 1.0])
+ initializer = np.array([0.0, 0.0, 0.0])
+ r = functional_ops.foldl(lambda a, x: a + x[0] * x[1],
+ (elems, other_elems), initializer)
+ self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r))
def testFoldl_Scoped(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root") as varscope:
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -114,42 +120,39 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFoldr_Simple(self):
- with self.test_session():
- elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
+ elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
- r = functional_ops.foldr(
- lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
- elems)
- self.assertAllEqual(450, self.evaluate(r))
+ r = functional_ops.foldr(
+ lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+ elems)
+ self.assertAllEqual(450, self.evaluate(r))
- r = functional_ops.foldr(
- lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
- elems,
- initializer=10)
- self.assertAllEqual(1282, self.evaluate(r))
+ r = functional_ops.foldr(
+ lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+ elems,
+ initializer=10)
+ self.assertAllEqual(1282, self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testFoldr_SingleInputMultiOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array([1, -1.0])
- r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
- r_value = self.evaluate(r)
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array([1, -1.0])
+ r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
+ r_value = self.evaluate(r)
- self.assertAllEqual(22, r_value[0])
- self.assertAllEqual(20, r_value[1])
+ self.assertAllEqual(22, r_value[0])
+ self.assertAllEqual(20, r_value[1])
@test_util.run_in_graph_and_eager_modes
def testFoldr_MultiInputSingleOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array(1.0)
- r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
- initializer)
- self.assertAllEqual(1, self.evaluate(r))
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+ r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
+ initializer)
+ self.assertAllEqual(1, self.evaluate(r))
def testFoldr_Scoped(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root") as varscope:
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -169,7 +172,7 @@ class FunctionalOpsTest(test.TestCase):
# pylint: disable=unnecessary-lambda
def testFold_Grad(self):
- with self.test_session():
+ with self.cached_session():
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = constant_op.constant(2.0, name="v")
r = functional_ops.foldl(
@@ -185,16 +188,15 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testMap_Simple(self):
- with self.test_session():
- nums = [1, 2, 3, 4, 5, 6]
- elems = constant_op.constant(nums, name="data")
- r = functional_ops.map_fn(
- lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems)
- self.assertAllEqual(
- np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
+ nums = [1, 2, 3, 4, 5, 6]
+ elems = constant_op.constant(nums, name="data")
+ r = functional_ops.map_fn(
+ lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems)
+ self.assertAllEqual(
+ np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
def testMapSparseTensor(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
functional_ops.map_fn(
lambda x: x,
@@ -211,7 +213,7 @@ class FunctionalOpsTest(test.TestCase):
functional_ops.map_fn(lambda x: x, 1)
def testMap_Scoped(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def double_scoped(x):
"""2x with a dummy 2 that is scoped."""
@@ -242,7 +244,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(doubles, self.evaluate(r))
def testMap_Grad(self):
- with self.test_session():
+ with self.cached_session():
param = constant_op.constant(2.0)
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
y = functional_ops.map_fn(
@@ -254,142 +256,131 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testMap_SimpleNotTensor(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- r = functional_ops.map_fn(
- lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums)
- self.assertAllEqual(
- np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ r = functional_ops.map_fn(
+ lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums)
+ self.assertAllEqual(
+ np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testMap_SingleInputMultiOutput(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- r = functional_ops.map_fn(
- lambda x: ((x + 3) * 2, -(x + 3) * 2),
- nums,
- dtype=(dtypes.int64, dtypes.int64))
- self.assertEqual(2, len(r))
- self.assertEqual((6,), r[0].get_shape())
- self.assertEqual((6,), r[1].get_shape())
- received = self.evaluate(r)
- self.assertAllEqual((nums + 3) * 2, received[0])
- self.assertAllEqual(-(nums + 3) * 2, received[1])
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ r = functional_ops.map_fn(
+ lambda x: ((x + 3) * 2, -(x + 3) * 2),
+ nums,
+ dtype=(dtypes.int64, dtypes.int64))
+ self.assertEqual(2, len(r))
+ self.assertEqual((6,), r[0].get_shape())
+ self.assertEqual((6,), r[1].get_shape())
+ received = self.evaluate(r)
+ self.assertAllEqual((nums + 3) * 2, received[0])
+ self.assertAllEqual(-(nums + 3) * 2, received[1])
@test_util.run_in_graph_and_eager_modes
def testMap_MultiOutputMismatchedDtype(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- with self.assertRaisesRegexp(
- TypeError, r"two structures don't have the same nested structure"):
- # lambda emits tuple, but dtype is a list
- functional_ops.map_fn(
- lambda x: ((x + 3) * 2, -(x + 3) * 2),
- nums,
- dtype=[dtypes.int64, dtypes.int64])
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ with self.assertRaisesRegexp(
+ TypeError, r"two structures don't have the same nested structure"):
+ # lambda emits tuple, but dtype is a list
+ functional_ops.map_fn(
+ lambda x: ((x + 3) * 2, -(x + 3) * 2),
+ nums,
+ dtype=[dtypes.int64, dtypes.int64])
@test_util.run_in_graph_and_eager_modes
def testMap_MultiInputSingleOutput(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- r = functional_ops.map_fn(
- lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)),
- dtype=dtypes.int64)
- self.assertEqual((6,), r.get_shape())
- received = self.evaluate(r)
- self.assertAllEqual(nums * nums + (-nums), received)
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ r = functional_ops.map_fn(
+ lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)),
+ dtype=dtypes.int64)
+ self.assertEqual((6,), r.get_shape())
+ received = self.evaluate(r)
+ self.assertAllEqual(nums * nums + (-nums), received)
@test_util.run_in_graph_and_eager_modes
def testMap_MultiInputSameStructureOutput(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])),
- (nums, (2 * nums, -nums)))
- r = [r[0], r[1][0], r[1][1]]
- self.assertEqual((6,), r[0].get_shape())
- self.assertEqual((6,), r[1].get_shape())
- self.assertEqual((6,), r[2].get_shape())
- received = self.evaluate(r)
- self.assertAllEqual(2 * nums, received[0])
- self.assertAllEqual(-nums, received[1])
- self.assertAllEqual(nums, received[2])
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])),
+ (nums, (2 * nums, -nums)))
+ r = [r[0], r[1][0], r[1][1]]
+ self.assertEqual((6,), r[0].get_shape())
+ self.assertEqual((6,), r[1].get_shape())
+ self.assertEqual((6,), r[2].get_shape())
+ received = self.evaluate(r)
+ self.assertAllEqual(2 * nums, received[0])
+ self.assertAllEqual(-nums, received[1])
+ self.assertAllEqual(nums, received[2])
@test_util.run_in_graph_and_eager_modes
def testScan_Simple(self):
- with self.test_session():
- elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
- v = constant_op.constant(2.0, name="v")
+ elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
+ v = constant_op.constant(2.0, name="v")
- # pylint: disable=unnecessary-lambda
- r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
- self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))
+ # pylint: disable=unnecessary-lambda
+ r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
+ self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))
- r = functional_ops.scan(
- lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
- self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
- # pylint: enable=unnecessary-lambda
+ r = functional_ops.scan(
+ lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
+ self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
+ # pylint: enable=unnecessary-lambda
@test_util.run_in_graph_and_eager_modes
def testScan_Reverse(self):
- with self.test_session():
- elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
- v = constant_op.constant(2.0, name="v")
-
- # pylint: disable=unnecessary-lambda
- r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
- reverse=True)
- self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
- r = functional_ops.scan(
- lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
- reverse=True)
- self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
- self.evaluate(r))
- # pylint: enable=unnecessary-lambda
+ elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
+ v = constant_op.constant(2.0, name="v")
+
+ # pylint: disable=unnecessary-lambda
+ r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
+ reverse=True)
+ self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
+ r = functional_ops.scan(
+ lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
+ reverse=True)
+ self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
+ self.evaluate(r))
+ # pylint: enable=unnecessary-lambda
@test_util.run_in_graph_and_eager_modes
def testScan_SingleInputMultiOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = (np.array(1.0), np.array(-1.0))
- r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
- initializer)
- r_value = self.evaluate(r)
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = (np.array(1.0), np.array(-1.0))
+ r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
+ initializer)
+ r_value = self.evaluate(r)
- self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
- self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])
+ self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
+ self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])
@test_util.run_in_graph_and_eager_modes
def testScan_MultiInputSingleOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array(1.0)
- # Multiply a * 1 each time
- r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
- (elems + 1, -elems), initializer)
- self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+ # Multiply a * 1 each time
+ r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
+ (elems + 1, -elems), initializer)
+ self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testScan_MultiInputSameTypeOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
- (elems, -elems))
- r_value = self.evaluate(r)
- self.assertAllEqual(np.cumsum(elems), r_value[0])
- self.assertAllEqual(np.cumsum(-elems), r_value[1])
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
+ (elems, -elems))
+ r_value = self.evaluate(r)
+ self.assertAllEqual(np.cumsum(elems), r_value[0])
+ self.assertAllEqual(np.cumsum(-elems), r_value[1])
@test_util.run_in_graph_and_eager_modes
def testScan_MultiOutputMismatchedInitializer(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array(1.0)
- # Multiply a * 1 each time
- with self.assertRaisesRegexp(
- ValueError, "two structures don't have the same nested structure"):
- functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+ # Multiply a * 1 each time
+ with self.assertRaisesRegexp(
+ ValueError, "two structures don't have the same nested structure"):
+ functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
def testScan_Scoped(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root") as varscope:
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -411,30 +402,29 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testScanFoldl_Nested(self):
- with self.test_session():
- elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
- inner_elems = constant_op.constant([0.5, 0.5], name="data")
-
- def r_inner(a, x):
- return functional_ops.foldl(
- lambda b, y: b * y * x, inner_elems, initializer=a)
-
- r = functional_ops.scan(r_inner, elems)
-
- # t == 0 (returns 1)
- # t == 1, a == 1, x == 2 (returns 1)
- # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
- # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1
- # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
- # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
- # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5
- # t == 3, a == 2.25, x == 4 (returns 9)
- # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
- # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9
- self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))
+ elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
+ inner_elems = constant_op.constant([0.5, 0.5], name="data")
+
+ def r_inner(a, x):
+ return functional_ops.foldl(
+ lambda b, y: b * y * x, inner_elems, initializer=a)
+
+ r = functional_ops.scan(r_inner, elems)
+
+ # t == 0 (returns 1)
+ # t == 1, a == 1, x == 2 (returns 1)
+ # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
+ # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1
+ # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
+ # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
+ # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5
+ # t == 3, a == 2.25, x == 4 (returns 9)
+ # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
+ # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9
+ self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))
def testScan_Control(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
s = array_ops.placeholder(dtypes.float32, shape=[None])
b = array_ops.placeholder(dtypes.bool)
@@ -445,7 +435,7 @@ class FunctionalOpsTest(test.TestCase):
b: True}))
def testScan_Grad(self):
- with self.test_session():
+ with self.cached_session():
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = constant_op.constant(2.0, name="v")
@@ -470,22 +460,20 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFoldShape(self):
- with self.test_session():
- x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
- def fn(_, current_input):
- return current_input
+ def fn(_, current_input):
+ return current_input
- initializer = constant_op.constant([0, 0, 0])
- y = functional_ops.foldl(fn, x, initializer=initializer)
- self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
+ initializer = constant_op.constant([0, 0, 0])
+ y = functional_ops.foldl(fn, x, initializer=initializer)
+ self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
@test_util.run_in_graph_and_eager_modes
def testMapShape(self):
- with self.test_session():
- x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
- y = functional_ops.map_fn(lambda e: e, x)
- self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+ y = functional_ops.map_fn(lambda e: e, x)
+ self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
def testMapUnknownShape(self):
x = array_ops.placeholder(dtypes.float32)
@@ -494,15 +482,14 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testMapEmptyScalar(self):
- with self.test_session():
- map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([]))
- self.assertAllEqual([0], map_return.get_shape().dims)
- self.assertAllEqual([0], self.evaluate(map_return).shape)
+ map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([]))
+ self.assertAllEqual([0], map_return.get_shape().dims)
+ self.assertAllEqual([0], self.evaluate(map_return).shape)
  # TODO(akshayka): this test fails in eager: the iterable is of length 0 so
  # the body of the while loop never executes
def testMapEmptyTensor(self):
- with self.test_session():
+ with self.cached_session():
map_return = functional_ops.map_fn(lambda x: array_ops.zeros([3, 2]),
constant_op.constant([]))
self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
@@ -510,20 +497,19 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testScanShape(self):
- with self.test_session():
- x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
- def fn(_, current_input):
- return current_input
+ def fn(_, current_input):
+ return current_input
- initializer = constant_op.constant([0, 0, 0])
- y = functional_ops.scan(fn, x, initializer=initializer)
- self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
+ initializer = constant_op.constant([0, 0, 0])
+ y = functional_ops.scan(fn, x, initializer=initializer)
+ self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
  # TODO(akshayka): this test fails in eager: the iterable is of length 0 so
  # the body of the while loop never executes
def testScanEmptyTensor(self):
- with self.test_session():
+ with self.cached_session():
x = functional_ops.scan(
lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4]))
self.assertAllEqual([0, 2, 4], x.get_shape())
@@ -540,7 +526,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertIs(None, y.get_shape().dims)
def testScanVaryingShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2])
x_t = array_ops.transpose(x)
# scan over dimension 0 (with shape None)
@@ -619,7 +605,7 @@ class FunctionalOpsTest(test.TestCase):
remote_op = functional_ops.remote_call(
args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, [6])
@@ -643,7 +629,7 @@ class FunctionalOpsTest(test.TestCase):
f=_remote_fn,
target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, 9.0)
@@ -667,7 +653,7 @@ class FunctionalOpsTest(test.TestCase):
f=_remote_fn,
target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, 9.0)
@@ -686,7 +672,7 @@ class FunctionalOpsTest(test.TestCase):
remote_op = functional_ops.remote_call(
args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ret = sess.run(remote_op)
self.assertAllEqual(ret, [b"a"])
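The constants asserted in the foldl/foldr hunks above (208, 880, 450, 1282) and the scan outputs fall out of the reduction order: foldl consumes elements left to right, foldr right to left, and an explicit initializer becomes the starting accumulator. A plain-Python check of that arithmetic, kept separate from the TensorFlow ops purely for readability:

from functools import reduce

elems = [1, 2, 3, 4, 5, 6]
step = lambda a, x: (a + x) * 2

# foldl: left to right; without an initializer the first element seeds the
# accumulator, with initializer=10 it starts from 10.
assert reduce(step, elems) == 208
assert reduce(step, elems, 10) == 880

# foldr: right to left, so the same step over the reversed sequence.
assert reduce(step, list(reversed(elems))) == 450
assert reduce(step, list(reversed(elems)), 10) == 1282

# scan: emits every intermediate accumulator (a running product here), which
# is why testScan_Simple expects the factorials 1, 2, 6, 24, 120, 720.
acc, prods = None, []
for x in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]:
  acc = x if acc is None else acc * x
  prods.append(acc)
assert prods == [1.0, 2.0, 6.0, 24.0, 120.0, 720.0]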
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 9b6aee64aa..0f5607712b 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -170,9 +170,8 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_pop_back(
l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
- @test_util.run_in_graph_and_eager_modes
def testGraphStack(self):
- with context.graph_mode(), self.test_session():
+ with self.cached_session():
tl = list_ops.empty_tensor_list(
element_shape=constant_op.constant([1], dtype=dtypes.int32),
element_dtype=dtypes.int32)
@@ -182,9 +181,8 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
[[1]])
- @test_util.run_in_graph_and_eager_modes
def testGraphStackInLoop(self):
- with context.graph_mode(), self.test_session():
+ with self.cached_session():
t1 = list_ops.empty_tensor_list(
element_shape=constant_op.constant([], dtype=dtypes.int32),
element_dtype=dtypes.int32)
@@ -200,9 +198,8 @@ class ListOpsTest(test_util.TensorFlowTestCase):
s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32)
self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3])
- @test_util.run_in_graph_and_eager_modes
def testGraphStackSwitchDtype(self):
- with context.graph_mode(), self.test_session():
+ with self.cached_session():
list_ = list_ops.empty_tensor_list(
element_shape=constant_op.constant([], dtype=dtypes.int32),
element_dtype=dtypes.int32)
@@ -222,9 +219,8 @@ class ListOpsTest(test_util.TensorFlowTestCase):
np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
self.assertAllEqual(self.evaluate(s1), np_s1)
- @test_util.run_in_graph_and_eager_modes
def testGraphStackInLoopSwitchDtype(self):
- with context.graph_mode(), self.test_session():
+ with self.cached_session():
t1 = list_ops.empty_tensor_list(
element_shape=constant_op.constant([], dtype=dtypes.int32),
element_dtype=dtypes.int32)
@@ -476,6 +472,47 @@ class ListOpsTest(test_util.TensorFlowTestCase):
self.evaluate(t_full_zeros), np.zeros(
(2,), dtype=dtype.as_numpy_dtype))
+ @test_util.run_in_graph_and_eager_modes
+ def testZerosLikeVariant(self):
+ for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
+ dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
+ dtypes.float64, dtypes.complex64, dtypes.complex128,
+ dtypes.bool):
+ l = list_ops.empty_tensor_list(
+ element_dtype=dtypes.variant, element_shape=scalar_shape())
+
+ sub_l = list_ops.empty_tensor_list(
+ element_dtype=dtype, element_shape=scalar_shape())
+ l = list_ops.tensor_list_push_back(l, sub_l)
+ sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast(
+ 1, dtype=dtype))
+ l = list_ops.tensor_list_push_back(l, sub_l)
+ sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast(
+ 2, dtype=dtype))
+ l = list_ops.tensor_list_push_back(l, sub_l)
+
+ # l : [[],
+ # [1],
+ # [1, 2]]
+ #
+ # l_zeros : [[],
+ # [0],
+ # [0, 0]]
+ l_zeros = array_ops.zeros_like(l)
+
+ outputs = []
+ for _ in range(3):
+ l_zeros, out = list_ops.tensor_list_pop_back(
+ l_zeros, element_dtype=dtypes.variant)
+ outputs.append(list_ops.tensor_list_stack(out, element_dtype=dtype))
+
+ # Note: `outputs` contains popped values so the order is reversed.
+ self.assertAllEqual(self.evaluate(outputs[2]), [])
+ self.assertAllEqual(
+ self.evaluate(outputs[1]), np.zeros((1,), dtype=dtype.as_numpy_dtype))
+ self.assertAllEqual(
+ self.evaluate(outputs[0]), np.zeros((2,), dtype=dtype.as_numpy_dtype))
+
if __name__ == "__main__":
test.main()
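The new testZerosLikeVariant above builds a list-of-lists (a variant TensorList whose elements are themselves TensorLists) and checks that zeros_like preserves the nested lengths while zeroing the values; popping from the back then yields the zeroed sub-lists in reverse order. The same invariant in plain Python, shown only as an analogy for the expected values:

# Analogy only: the real test goes through list_ops.empty_tensor_list,
# tensor_list_push_back and tensor_list_pop_back on variant tensors.
l = [[], [1], [1, 2]]
l_zeros = [[0 for _ in sub] for sub in l]
assert l_zeros == [[], [0], [0, 0]]

outputs = []
while l_zeros:
  outputs.append(l_zeros.pop())          # pop_back: last sub-list first
assert outputs == [[0, 0], [0], []]      # hence outputs[2] is the empty list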
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 50154a45a8..79fcbaad43 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -61,7 +61,7 @@ class PyFuncTest(test.TestCase):
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
dtypes.int32, dtypes.int64]:
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1, dtype=dtype)
y = constant_op.constant(2, dtype=dtype)
z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
@@ -71,7 +71,7 @@ class PyFuncTest(test.TestCase):
def sub_func(x, y):
return x - y
for dtype in [dtypes.complex64, dtypes.complex128]:
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1 + 1j, dtype=dtype)
y = constant_op.constant(2 - 2j, dtype=dtype)
z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
@@ -81,21 +81,21 @@ class PyFuncTest(test.TestCase):
def and_func(x, y):
return x and y
dtype = dtypes.bool
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(True, dtype=dtype)
y = constant_op.constant(False, dtype=dtype)
z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
self.assertEqual(z, False)
def testSingleType(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32)
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))
def testScalar(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32)
z = self.evaluate(
@@ -103,7 +103,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))
def testArray(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([1.0, 2.0], dtypes.float64)
y = constant_op.constant([2.0, 3.0], dtypes.float64)
z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
@@ -111,14 +111,14 @@ class PyFuncTest(test.TestCase):
np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
def testComplexType(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1 + 2j, dtypes.complex64)
y = constant_op.constant(3 + 4j, dtypes.complex64)
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))
def testRFFT(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)
def rfft(x):
@@ -128,7 +128,7 @@ class PyFuncTest(test.TestCase):
self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))
def testPythonLiteral(self):
- with self.test_session():
+ with self.cached_session():
def literal(x):
return 1.0 if float(x) == 0.0 else 0.0
@@ -138,7 +138,7 @@ class PyFuncTest(test.TestCase):
self.assertAllClose(y, 1.0)
def testList(self):
- with self.test_session():
+ with self.cached_session():
def list_func(x):
return [x, x + 1]
@@ -150,7 +150,7 @@ class PyFuncTest(test.TestCase):
def testTuple(self):
# returns a tuple
- with self.test_session():
+ with self.cached_session():
def tuple_func(x):
return x, x + 1
@@ -161,7 +161,7 @@ class PyFuncTest(test.TestCase):
self.assertAllClose(y, [0.0, 1.0])
# returns a tuple, Tout and inp a tuple
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(0.0, dtypes.float64)
y = self.evaluate(
script_ops.py_func(tuple_func, (x,),
@@ -176,7 +176,7 @@ class PyFuncTest(test.TestCase):
def read_and_return_strings(x, y):
return x + y
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([b"hello", b"hi"], dtypes.string)
y = self.evaluate(
script_ops.py_func(read_fixed_length_numpy_strings, [],
@@ -193,7 +193,7 @@ class PyFuncTest(test.TestCase):
def read_and_return_strings(x, y):
return x + y
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(["hello", "hi"], dtypes.string)
y = self.evaluate(
script_ops.py_func(read_fixed_length_numpy_strings, [],
@@ -210,7 +210,7 @@ class PyFuncTest(test.TestCase):
def read_and_return_strings(x, y):
return x + y
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(["hello", "hi"], dtypes.string)
y, = script_ops.py_func(read_object_array, [],
[dtypes.string])
@@ -219,19 +219,19 @@ class PyFuncTest(test.TestCase):
def testStringPadding(self):
correct = [b"this", b"is", b"a", b"test"]
- with self.test_session():
+ with self.cached_session():
s, = script_ops.py_func(lambda: [correct], [], [dtypes.string])
self.assertAllEqual(s.eval(), correct)
def testStringPaddingAreConvertedToBytes(self):
inp = ["this", "is", "a", "test"]
correct = [b"this", b"is", b"a", b"test"]
- with self.test_session():
+ with self.cached_session():
s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
self.assertAllEqual(s.eval(), correct)
def testLarge(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.zeros([1000000], dtype=np.float32)
y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32])
z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32])
@@ -239,12 +239,12 @@ class PyFuncTest(test.TestCase):
sess.run([y[0].op, z[0].op])
def testNoInput(self):
- with self.test_session():
+ with self.cached_session():
x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
self.assertAllClose(x, 42.0)
def testAlias(self):
- with self.test_session():
+ with self.cached_session():
np_array = np.array([1.0, 2.0], dtype=np.float32)
tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32])
value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32)
@@ -252,7 +252,7 @@ class PyFuncTest(test.TestCase):
self.assertAllEqual(np_array, [1.0, 2.0])
def testReturnUnicodeString(self):
- with self.test_session():
+ with self.cached_session():
correct = u"你好 世界"
def unicode_string():
@@ -262,7 +262,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(z.eval(), correct.encode("utf8"))
def testBadNumpyReturnType(self):
- with self.test_session():
+ with self.cached_session():
def bad():
# Structured numpy arrays aren't supported.
@@ -275,7 +275,7 @@ class PyFuncTest(test.TestCase):
y.eval()
def testBadReturnType(self):
- with self.test_session():
+ with self.cached_session():
def bad():
# Non-string python objects aren't supported.
@@ -288,7 +288,7 @@ class PyFuncTest(test.TestCase):
z.eval()
def testReturnInput(self):
- with self.test_session():
+ with self.cached_session():
def ident(x):
return x[0]
@@ -303,7 +303,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))
def testStateful(self):
- # Not using self.test_session(), which disables optimization.
+ # Not using self.cached_session(), which disables optimization.
with session_lib.Session() as sess:
producer = iter(range(3))
x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64])
@@ -312,7 +312,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(sess.run(x), 2)
def testStateless(self):
- # Not using self.test_session(), which disables optimization.
+ # Not using self.cached_session(), which disables optimization.
with session_lib.Session() as sess:
producer = iter(range(3))
x, = script_ops.py_func(
@@ -331,7 +331,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(None, ops.get_gradient_function(y.op))
def testCOrder(self):
- with self.test_session():
+ with self.cached_session():
val = [[1, 2], [3, 4]]
x, = script_ops.py_func(lambda: np.array(val, order="F"), [],
[dtypes.int64])
@@ -339,7 +339,7 @@ class PyFuncTest(test.TestCase):
def testParallel(self):
# Tests that tf.py_func's can run in parallel if they release the GIL.
- with self.test_session() as session:
+ with self.cached_session() as session:
q = queue.Queue(1)
def blocking_put():
@@ -375,7 +375,7 @@ class PyFuncTest(test.TestCase):
def value(self):
return self._value
- with self.test_session():
+ with self.cached_session():
s = State()
op = s.increment(constant_op.constant(2, dtypes.int64))
ret = self.evaluate(op)
@@ -389,7 +389,7 @@ class PyFuncTest(test.TestCase):
f = script_ops.py_func(
do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(f), [])
def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
@@ -417,21 +417,22 @@ class PyFuncTest(test.TestCase):
else:
f = script_ops.py_func(raise_exception, [], [])
- with self.test_session():
- with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
- self.evaluate(f)
+ with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
+ self.evaluate(f)
def testExceptionHandling(self):
- self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
- self._testExceptionHandling(TypeError, errors.InvalidArgumentError)
- self._testExceptionHandling(StopIteration, errors.OutOfRangeError)
- self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError)
- self._testExceptionHandling(NotImplementedError, errors.UnimplementedError)
+ with self.cached_session():
+ self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
+ self._testExceptionHandling(TypeError, errors.InvalidArgumentError)
+ self._testExceptionHandling(StopIteration, errors.OutOfRangeError)
+ self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError)
+ self._testExceptionHandling(NotImplementedError,
+ errors.UnimplementedError)
- class WeirdError(Exception):
- pass
+ class WeirdError(Exception):
+ pass
- self._testExceptionHandling(WeirdError, errors.UnknownError)
+ self._testExceptionHandling(WeirdError, errors.UnknownError)
# ----- Tests shared by py_func and eager_py_func -----
def testCleanup(self):
@@ -452,7 +453,7 @@ class PyFuncTest(test.TestCase):
# (see #18292)
_ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
_ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
-
+
# Call garbage collector to enforce deletion.
make_graphs()
ops.reset_default_graph()
@@ -610,7 +611,7 @@ class PyFuncTest(test.TestCase):
func=log_huber, inp=[x, m], Tout=dtypes.float32)
dy_dx = gradients_impl.gradients(y, x)[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Takes the first branch of log_huber.
y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
self.assertEqual(y, 1.0)
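The py_func hunks above mostly rename the session helper, but the tests themselves document the contract: the wrapped callable receives NumPy values, must return values convertible to Tout, and Python exceptions surface as the TensorFlow errors listed in testExceptionHandling. A minimal usage sketch in the same style as these tests (the `np_add` helper and `PyFuncSketchTest` case are illustrative, not taken from the file):

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test


class PyFuncSketchTest(test.TestCase):
  # Illustrative only; not part of this change.

  def testAddViaNumpy(self):
    def np_add(x, y):
      # Runs as ordinary Python: x and y arrive as NumPy values.
      return (x + y).astype(np.float32)

    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(script_ops.py_func(np_add, [x, y], dtypes.float32))
      self.assertEqual(z, 3.0)


if __name__ == "__main__":
  test.main()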
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index d0ed08933d..f90545f84c 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -54,7 +54,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(0, len(gc.garbage))
def testHandleDtypeShapeMatch(self):
- with self.test_session():
+ with self.cached_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
with self.assertRaises(ValueError):
resource_variable_ops.assign_variable_op(
@@ -123,7 +123,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertFalse(np.allclose(variable.numpy(), copied_variable.numpy()))
def testGraphDeepCopy(self):
- with self.test_session():
+ with self.cached_session():
init_value = np.ones((4, 4, 4))
variable = resource_variable_ops.ResourceVariable(init_value,
name="init")
@@ -145,13 +145,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# variable graph.
def testFetchHandle(self):
- with self.test_session():
+ with self.cached_session():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1], name="foo")
self.assertGreater(len(handle.eval()), 0)
def testCachedValueReadBeforeWrite(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0")
sess.run(v.initializer)
value, _ = sess.run([v, v.assign_add(1.0)])
@@ -492,7 +492,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# TODO(alive): how should this work in Eager mode?
def testInitFn(self):
- with self.test_session():
+ with self.cached_session():
v = resource_variable_ops.ResourceVariable(
initial_value=lambda: 1, dtype=dtypes.float32)
self.assertEqual(v.handle.op.colocation_groups(),
@@ -569,11 +569,11 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(2.0, self.evaluate(v.value()))
def testVariableDefInitializedInstances(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v_def = resource_variable_ops.ResourceVariable(
initial_value=constant_op.constant(3.0)).to_proto()
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# v describes a VariableDef-based variable without an initial value.
v = resource_variable_ops.ResourceVariable(variable_def=v_def)
self.assertEqual(3.0, sess.run(v.initialized_value()))
@@ -584,7 +584,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(1.0, v.initialized_value().eval())
v_def.ClearField("initial_value_name")
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# Restoring a legacy VariableDef proto that does not have
# initial_value_name set should still work.
v = resource_variable_ops.ResourceVariable(variable_def=v_def)
@@ -615,17 +615,16 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes
def testSparseRead(self):
- with self.test_session():
- init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
- v = resource_variable_ops.ResourceVariable(
- constant_op.constant(init_value, dtype=dtypes.int32), name="var3")
- self.evaluate(variables.global_variables_initializer())
+ init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
+ v = resource_variable_ops.ResourceVariable(
+ constant_op.constant(init_value, dtype=dtypes.int32), name="var3")
+ self.evaluate(variables.global_variables_initializer())
- value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
- self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value)
+ value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
+ self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value)
def testToFromProto(self):
- with self.test_session():
+ with self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
@@ -686,7 +685,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
handle, ignore_lookup_error=True))
def testAssignDifferentShapes(self):
- with self.test_session() as sess, variable_scope.variable_scope(
+ with self.cached_session() as sess, variable_scope.variable_scope(
"foo", use_resource=True):
var = variable_scope.get_variable("x", shape=[1, 1], dtype=dtypes.float32)
placeholder = array_ops.placeholder(dtypes.float32)
@@ -728,7 +727,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
_ = w.value().op.get_attr("_class")
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
variables.global_variables_initializer().run()
@@ -746,7 +745,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
def testSharedNameWithNamescope(self):
- with self.test_session():
+ with self.cached_session():
with ops.name_scope("foo"):
v = resource_variable_ops.ResourceVariable(300.0, name="var6")
self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access
@@ -774,7 +773,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
def testSetInitialValue(self):
- with self.test_session():
+ with self.cached_session():
# Initialize variable with a value different from the initial value passed
# in the constructor.
v = resource_variable_ops.ResourceVariable(2.0)
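The reworked testSparseRead above drops the session block and compares `v.sparse_read([0, 3, 1, 2])` against NumPy fancy indexing on the first axis. The NumPy side alone, spelled out to make the expected shape explicit (assumes nothing beyond standard NumPy):

import numpy as np

init_value = np.reshape(np.arange(4 ** 3), (4, 4, 4))

# sparse_read gathers whole slices along axis 0, so the result keeps the
# trailing (4, 4) dimensions and only reorders the leading one.
gathered = init_value[[0, 3, 1, 2], ...]
assert gathered.shape == (4, 4, 4)
assert (gathered[1] == init_value[3]).all()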
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 562d11f0b0..a28cdc3b26 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -197,7 +197,7 @@ class RNNTest(test.TestCase):
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
- with self.test_session() as sess:
+ with self.cached_session(use_gpu=True) as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
if not in_eager_mode:
@@ -217,7 +217,7 @@ class RNNTest(test.TestCase):
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
- with self.test_session() as sess:
+ with self.cached_session(use_gpu=True) as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
if not in_eager_mode:
@@ -246,7 +246,7 @@ class RNNTest(test.TestCase):
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
- with self.test_session() as sess:
+ with self.cached_session(use_gpu=True) as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
state = (state[0], state[1].stack())
@@ -321,7 +321,7 @@ class RNNTest(test.TestCase):
self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3)
def testRNNWithKerasSimpleRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
@@ -354,7 +354,7 @@ class RNNTest(test.TestCase):
self.assertEqual(len(state), batch)
def testRNNWithKerasGRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
@@ -387,7 +387,7 @@ class RNNTest(test.TestCase):
self.assertEqual(len(state), batch)
def testRNNWithKerasLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
@@ -424,7 +424,7 @@ class RNNTest(test.TestCase):
self.assertEqual(len(state[1]), batch)
def testRNNWithStackKerasCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
@@ -465,7 +465,7 @@ class RNNTest(test.TestCase):
self.assertEqual(len(s), batch)
def testStaticRNNWithKerasSimpleRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_shape = 10
output_shape = 5
timestep = 4
@@ -567,7 +567,7 @@ class RNNTest(test.TestCase):
rnn_cell_impl.GRUCell(
32, kernel_initializer="ones", dtype=dtypes.float32)
]:
- with self.test_session():
+ with self.cached_session():
x = keras.Input((None, 5))
layer = keras.layers.RNN(cell)
y = layer(x)
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index c4e9c982b5..c6a6b2a7fa 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -180,16 +180,16 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
def _get_func_graphs(if_op):
- """Returns `_FuncGraph`s for the input op branches.
+ """Returns `FuncGraph`s for the input op branches.
Args:
if_op: The _If Operation.
Returns:
- A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch.
+ A 2-tuple of the `FuncGraph`s of the then_branch and else_branch.
"""
def _get_func_graph_for_branch(branch_name):
- """Generates and returns a _FuncGraph for the given branch."""
+ """Generates and returns a FuncGraph for the given branch."""
inputs = if_op.inputs[1:] # First input is pred.
input_shapes = [t.shape for t in inputs]
func_name = if_op.get_attr(branch_name).name
@@ -197,7 +197,7 @@ def _get_func_graphs(if_op):
# `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
# in the case of nested if ops or when the gradient is being computed
# from inside a Defun. We build the `func_graph` with `if_op.graph` as its
- # `outer_graph`. This resembles how the `_FuncGraph` was built in the
+ # `outer_graph`. This resembles how the `FuncGraph` was built in the
# forward pass. We need this so that we can resolve references to tensors
# in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
with if_op.graph.as_default():
@@ -221,7 +221,7 @@ def _grad_fn(func_graph, grads):
func_graph's outputs w.r.t. its inputs.
Args:
- func_graph: function._FuncGraph. The corresponding forward-pass function.
+ func_graph: function.FuncGraph. The corresponding forward-pass function.
grads: The list of input gradient Tensors.
Returns:
@@ -259,7 +259,7 @@ def _grad_fn(func_graph, grads):
def _create_grad_func(func_graph, grads, name):
- """Returns the _FuncGraph representation of _grad_fn."""
+ """Returns the FuncGraph representation of _grad_fn."""
return _function.func_graph_from_py_func(
name, lambda: _grad_fn(func_graph, grads), [], {})
@@ -277,8 +277,8 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
functions, this is always possible.
Args:
- cond_graph: function._FuncGraph. The forward-pass function.
- grad_graph: function._FuncGraph. The gradients function.
+ cond_graph: function.FuncGraph. The forward-pass function.
+ grad_graph: function.FuncGraph. The gradients function.
Returns:
A list of inputs tensors to be passed to grad_graph.
@@ -313,7 +313,7 @@ def _create_new_tf_function(func_graph):
"""Converts func_graph to a TF_Function and adds it to the current graph.
Args:
- func_graph: function._FuncGraph
+ func_graph: function.FuncGraph
Returns:
The name of the new TF_Function.
@@ -365,8 +365,8 @@ def _pad_params(true_graph, false_graph, true_params, false_params):
There is no merging of params.
Args:
- true_graph: function._FuncGraph
- false_graph: function._FuncGraph
+ true_graph: function.FuncGraph
+ false_graph: function.FuncGraph
true_params: a list of Tensors from true_graph
false_params: a list of Tensors from false_graph
@@ -391,8 +391,8 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
graph to avoid duplicating shared arguments.
Args:
- true_graph: function._FuncGraph
- false_graph: function._FuncGraph
+ true_graph: function.FuncGraph
+ false_graph: function.FuncGraph
true_inputs: a list of Tensors in the outer graph. The inputs for
true_graph.
false_inputs: a list of Tensors in the outer graph. The inputs for
@@ -421,7 +421,7 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
_create_dummy_params(false_graph, true_only_inputs) +
[false_input_to_param[t] for t in false_only_inputs])
- # Rewrite the _FuncGraphs' state to reflect the new inputs.
+ # Rewrite the FuncGraphs' state to reflect the new inputs.
true_graph.captures = collections.OrderedDict(zip(new_inputs,
true_graph.inputs))
false_graph.captures = collections.OrderedDict(zip(new_inputs,
@@ -434,7 +434,7 @@ def _create_dummy_params(func_graph, template_tensors):
"""Creates tensors in func_graph to represent template_tensors.
Args:
- func_graph: function._FuncGraph.
+ func_graph: function.FuncGraph.
template_tensors: a list of tensors in the outer graph.
Returns:
@@ -451,27 +451,16 @@ def _get_grad_fn_name(func_graph):
Ensures this name is unique in the entire hierarchy.
Args:
- func_graph: The _FuncGraph.
+ func_graph: The FuncGraph.
Returns:
A string, the name to use for the gradient function.
"""
name = "%s_grad" % func_graph.name
-
- base_name = name
- counter = 1
- has_conflict = True
- while has_conflict:
- curr_graph = func_graph.outer_graph
- has_conflict = curr_graph._is_function(name)
- while not has_conflict and isinstance(curr_graph, _function.FuncGraph):
- curr_graph = curr_graph.outer_graph
- has_conflict = curr_graph._is_function(name)
- if has_conflict:
- name = "%s_%s" % (base_name, counter)
- counter += 1
-
- return name
+ outer_most_graph = func_graph
+ while isinstance(outer_most_graph, _function.FuncGraph):
+ outer_most_graph = outer_most_graph.outer_graph
+ return outer_most_graph.unique_name(name)
def _check_same_outputs(true_graph, false_graph):
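The `_get_grad_fn_name` rewrite above drops the hand-rolled conflict loop: it now walks `outer_graph` links until it leaves the nested `FuncGraph`s and asks the outermost graph for a unique name. A minimal standalone sketch of that lookup, using hypothetical stand-in classes rather than TensorFlow's real `FuncGraph`/`Graph` types:

    class OuterGraph(object):
      """Stand-in for the outermost graph, which owns the name registry."""

      def __init__(self):
        self._names = set()

      def unique_name(self, name):
        candidate, counter = name, 1
        while candidate in self._names:
          candidate = '%s_%d' % (name, counter)
          counter += 1
        self._names.add(candidate)
        return candidate

    class NestedFuncGraph(object):
      """Stand-in for function.FuncGraph: only keeps a link to its parent."""

      def __init__(self, name, outer_graph):
        self.name = name
        self.outer_graph = outer_graph

    def get_grad_fn_name(func_graph):
      # Walk out of every nested graph; only the outermost one hands out
      # names that are unique across the whole hierarchy.
      name = '%s_grad' % func_graph.name
      outer_most_graph = func_graph
      while isinstance(outer_most_graph, NestedFuncGraph):
        outer_most_graph = outer_most_graph.outer_graph
      return outer_most_graph.unique_name(name)

    root = OuterGraph()
    nested = NestedFuncGraph('cond_true', NestedFuncGraph('outer_fn', root))
    print(get_grad_fn_name(nested))  # cond_true_grad
    print(get_grad_fn_name(nested))  # cond_true_grad_1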
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index ddf9442cd2..578e7b7dd2 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -446,6 +446,24 @@ class Distribution(_BaseDistribution):
self._graph_parents = graph_parents
self._name = name
+ @property
+ def _parameters(self):
+ return self._parameter_dict
+
+ @_parameters.setter
+ def _parameters(self, value):
+ """Intercept assignments to self._parameters to avoid reference cycles.
+
+ Parameters are often created using locals(), so we need to clean out any
+ references to `self` before assigning it to an attribute.
+
+ Args:
+ value: A dictionary of parameters to assign to the `_parameters` property.
+ """
+ if "self" in value:
+ del value["self"]
+ self._parameter_dict = value
+
@classmethod
def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
"""Shapes of parameters given the desired shape of a call to `sample()`.
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 7f851e3646..f25ed700d6 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -41,6 +41,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
#include <complex>
+#include <vector>
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
index eb41deee13..9f6dcd8fdb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
@@ -9,16 +9,14 @@ tf_proto {
type: TYPE_STRING
}
field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
name: "executor_type"
number: 3
label: LABEL_OPTIONAL
type: TYPE_STRING
}
+ reserved_range {
+ start: 2
+ end: 3
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
index e565b903d2..f3a515163d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
@@ -132,17 +132,15 @@ tf_proto {
type: TYPE_STRING
}
field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
name: "executor_type"
number: 3
label: LABEL_OPTIONAL
type: TYPE_STRING
}
+ reserved_range {
+ start: 2
+ end: 3
+ }
}
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
index eb41deee13..9f6dcd8fdb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
@@ -9,16 +9,14 @@ tf_proto {
type: TYPE_STRING
}
field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
name: "executor_type"
number: 3
label: LABEL_OPTIONAL
type: TYPE_STRING
}
+ reserved_range {
+ start: 2
+ end: 3
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
index e565b903d2..f3a515163d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
@@ -132,17 +132,15 @@ tf_proto {
type: TYPE_STRING
}
field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
name: "executor_type"
number: 3
label: LABEL_OPTIONAL
type: TYPE_STRING
}
+ reserved_range {
+ start: 2
+ end: 3
+ }
}
}
}
diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu b/tensorflow/tools/ci_build/Dockerfile.gpu
index f05c7a4809..a4cad4b6c6 100644
--- a/tensorflow/tools/ci_build/Dockerfile.gpu
+++ b/tensorflow/tools/ci_build/Dockerfile.gpu
@@ -30,3 +30,4 @@ RUN mkdir /usr/local/cuda-9.0/lib && \
# Configure the build for our CUDA configuration.
ENV TF_NEED_CUDA 1
+ENV TF_NEED_TENSORRT 1
diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh
index 9640810533..179fc42d60 100755
--- a/tensorflow/tools/ci_build/install/install_deb_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh
@@ -67,6 +67,12 @@ apt-get install -y --no-install-recommends \
zip \
zlib1g-dev
+apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install libnvinfer4=4.1.2-1+cuda9.0 && \
+ apt-get install libnvinfer-dev=4.1.2-1+cuda9.0
+
# populate the database
updatedb
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
index f958b3c9b7..60c974c36b 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
@@ -52,6 +52,7 @@ ${DOCKER_BINARY} run \
-e "PYTHON_BIN_PATH=/usr/bin/python" \
-e "TF_NEED_HDFS=0" \
-e "TF_NEED_CUDA=${TF_NEED_CUDA}" \
+ -e "TF_NEED_TENSORRT=${TF_NEED_CUDA}" \
-e "TF_NEED_OPENCL_SYCL=0" \
"${DOCKER_IMAGE}" \
"/workspace/tensorflow/tools/ci_build/linux/libtensorflow.sh"
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index 997afc6ac7..549056c6c4 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -947,6 +947,7 @@ class _ClassPageInfo(object):
self._aliases = None
self._doc = None
self._guides = None
+ self._namedtuplefields = None
self._bases = None
self._properties = []
@@ -1030,6 +1031,17 @@ class _ClassPageInfo(object):
self._guides = guides
@property
+ def namedtuplefields(self):
+ return self._namedtuplefields
+
+ def set_namedtuplefields(self, py_class):
+ if issubclass(py_class, tuple):
+ if all(
+ hasattr(py_class, attr)
+ for attr in ('_asdict', '_fields', '_make', '_replace')):
+ self._namedtuplefields = py_class._fields
+
+ @property
def bases(self):
"""Returns a list of `_LinkInfo` objects pointing to the class' parents."""
return self._bases
@@ -1066,7 +1078,15 @@ class _ClassPageInfo(object):
@property
def properties(self):
"""Returns a list of `_PropertyInfo` describing the class' properties."""
- return self._properties
+ props_dict = {prop.short_name: prop for prop in self._properties}
+ props = []
+ if self.namedtuplefields:
+ for field in self.namedtuplefields:
+ props.append(props_dict.pop(field))
+
+ props.extend(sorted(props_dict.values()))
+
+ return props
def _add_property(self, short_name, full_name, obj, doc):
"""Adds a `_PropertyInfo` entry to the `properties` list.
@@ -1077,6 +1097,9 @@ class _ClassPageInfo(object):
obj: The property object itself
doc: The property's parsed docstring, a `_DocstringInfo`.
"""
+ # Hide useless namedtuple docstrings
+ if re.match('Alias for field number [0-9]+', doc.docstring):
+ doc = doc._replace(docstring='', brief='')
property_info = _PropertyInfo(short_name, full_name, obj, doc)
self._properties.append(property_info)
@@ -1156,6 +1179,7 @@ class _ClassPageInfo(object):
py_class: The class object being documented
parser_config: An instance of ParserConfig.
"""
+ self.set_namedtuplefields(py_class)
doc_path = documentation_path(self.full_name)
relative_path = os.path.relpath(
path='.', start=os.path.dirname(doc_path) or '.')
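The parser changes above treat any `tuple` subclass that exposes the standard `_asdict`/`_fields`/`_make`/`_replace` helpers as a namedtuple, list its fields first in declaration order, and append the remaining properties sorted. A minimal sketch of that ordering rule in isolation (the `fields_then_sorted` helper is illustrative, not part of the docs generator):

    import collections

    def namedtuple_fields(py_class):
      # Same detection as above: a tuple subclass with the namedtuple helpers.
      if issubclass(py_class, tuple) and all(
          hasattr(py_class, attr)
          for attr in ('_asdict', '_fields', '_make', '_replace')):
        return py_class._fields
      return None

    def fields_then_sorted(py_class, property_names):
      fields = namedtuple_fields(py_class) or ()
      ordered = [name for name in fields if name in property_names]
      # Everything that is not a namedtuple field comes afterwards, sorted.
      ordered.extend(sorted(set(property_names) - set(fields)))
      return ordered

    Point = collections.namedtuple('Point', ['z', 'x', 'y'])
    print(fields_then_sorted(Point, ['count', 'x', 'y', 'z']))
    # ['z', 'x', 'y', 'count']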
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py
index 9f6b185e81..71e96afa10 100644
--- a/tensorflow/tools/docs/parser_test.py
+++ b/tensorflow/tools/docs/parser_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import functools
import os
import sys
@@ -190,6 +191,50 @@ class ParserTest(googletest.TestCase):
# Make sure this file is contained as the definition location.
self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)
+ def test_namedtuple_field_order(self):
+ namedtupleclass = collections.namedtuple('namedtupleclass',
+ {'z', 'y', 'x', 'w', 'v', 'u'})
+
+ index = {
+ 'namedtupleclass': namedtupleclass,
+ 'namedtupleclass.u': namedtupleclass.u,
+ 'namedtupleclass.v': namedtupleclass.v,
+ 'namedtupleclass.w': namedtupleclass.w,
+ 'namedtupleclass.x': namedtupleclass.x,
+ 'namedtupleclass.y': namedtupleclass.y,
+ 'namedtupleclass.z': namedtupleclass.z,
+ }
+
+ visitor = DummyVisitor(index=index, duplicate_of={})
+
+ reference_resolver = parser.ReferenceResolver.from_visitor(
+ visitor=visitor, doc_index={}, py_module_names=['tf'])
+
+ tree = {'namedtupleclass': {'u', 'v', 'w', 'x', 'y', 'z'}}
+ parser_config = parser.ParserConfig(
+ reference_resolver=reference_resolver,
+ duplicates={},
+ duplicate_of={},
+ tree=tree,
+ index=index,
+ reverse_index={},
+ guide_index={},
+ base_dir='/')
+
+ page_info = parser.docs_for_object(
+ full_name='namedtupleclass',
+ py_object=namedtupleclass,
+ parser_config=parser_config)
+
+ # Each namedtuple field has a docstring of the form:
+ # 'Alias for field number ##'. These props are returned sorted.
+
+ def sort_key(prop_info):
+ return int(prop_info.obj.__doc__.split(' ')[-1])
+
+ self.assertSequenceEqual(page_info.properties,
+ sorted(page_info.properties, key=sort_key))
+
def test_docs_for_class_should_skip(self):
class Parent(object):
@@ -736,6 +781,5 @@ class TestGenerateSignature(googletest.TestCase):
sig = parser._generate_signature(example_fun, reverse_index={})
self.assertEqual(sig, ['arg1=a.b.c.d', 'arg2=a.b.c.d(1, 2)', "arg3=e['f']"])
-
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py
index aecf753a58..448f246e0e 100644
--- a/tensorflow/tools/docs/pretty_docs.py
+++ b/tensorflow/tools/docs/pretty_docs.py
@@ -136,7 +136,7 @@ def _build_class_page(page_info):
if page_info.properties:
parts.append('## Properties\n\n')
- for prop_info in sorted(page_info.properties):
+ for prop_info in page_info.properties:
h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n'
parts.append(h3.format(short_name=prop_info.short_name))
diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
index c8dc2a7c4d..d97496cbeb 100644
--- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
+++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
@@ -92,7 +92,7 @@ Status ExtractMinMaxRecords(const string& log_file_name,
if (!str_util::EndsWith(name_string, print_suffix)) {
continue;
}
- string name = std::string(
+ string name(
name_string.substr(0, name_string.size() - print_suffix.size()));
records->push_back({name, min, max});
}
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
index dd95779a1f..b8d6ba00de 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
@@ -42,8 +42,8 @@ class SparsifyGatherTest : public ::testing::Test {
const std::vector<NodeDef*>& inputs, GraphDef* graph_def,
bool control_dep = false) {
NodeDef* node_def = graph_def->add_node();
- node_def->set_name(std::string(name));
- node_def->set_op(std::string(op));
+ node_def->set_name(string(name));
+ node_def->set_op(string(op));
if (!control_dep) {
std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
node_def->add_input(input->name());
diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc
index 5cae8f8d8f..7efe450710 100644
--- a/tensorflow/tools/graph_transforms/transform_graph.cc
+++ b/tensorflow/tools/graph_transforms/transform_graph.cc
@@ -65,19 +65,19 @@ Status ParseTransformParameters(const string& transforms_string,
.GetResult(&remaining, &transform_name);
if (!found_transform_name) {
return errors::InvalidArgument("Looking for transform name, but found ",
- std::string(remaining).c_str());
+ string(remaining).c_str());
}
if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) {
state = TRANSFORM_PARAM_NAME;
} else {
// Add a transform with no parameters.
- params_list->push_back({std::string(transform_name), func_parameters});
+ params_list->push_back({string(transform_name), func_parameters});
transform_name = "";
state = TRANSFORM_NAME;
}
} else if (state == TRANSFORM_PARAM_NAME) {
if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) {
- params_list->push_back({std::string(transform_name), func_parameters});
+ params_list->push_back({string(transform_name), func_parameters});
transform_name = "";
state = TRANSFORM_NAME;
} else {
@@ -92,13 +92,13 @@ Status ParseTransformParameters(const string& transforms_string,
if (!found_parameter_name) {
return errors::InvalidArgument(
"Looking for parameter name, but found ",
- std::string(remaining).c_str());
+ string(remaining).c_str());
}
if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) {
state = TRANSFORM_PARAM_VALUE;
} else {
return errors::InvalidArgument("Looking for =, but found ",
- std::string(remaining).c_str());
+ string(remaining).c_str());
}
}
} else if (state == TRANSFORM_PARAM_VALUE) {
@@ -120,10 +120,9 @@ Status ParseTransformParameters(const string& transforms_string,
}
if (!found_parameter_value) {
return errors::InvalidArgument("Looking for parameter name, but found ",
- std::string(remaining).c_str());
+ string(remaining).c_str());
}
- func_parameters[std::string(parameter_name)].push_back(
- std::string(parameter_value));
+ func_parameters[string(parameter_name)].emplace_back(parameter_value);
// Eat up any trailing quotes.
Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match);
Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match);
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index cb084e49b7..c715380aae 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -93,7 +93,7 @@ void NodeNamePartsFromInput(const string& input_name, string* prefix,
} else {
*prefix = "";
}
- *node_name = std::string(node_name_piece);
+ *node_name = string(node_name_piece);
}
string NodeNameFromInput(const string& input_name) {
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 997725d865..742f33f68e 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -491,11 +491,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz",
],
- sha256 = "b8f4ffbcaeea345e2245fd7028c7e960d71c2a2007c20bbfc5d79ecc86992a5e",
- strip_prefix = "llvm-67bd0d9a0f5597f57f272061fd70f24dffb3d223",
+ sha256 = "c7252290a113f694cccbb4b325c67b56f3aa6f5b3044524302c0e79db2da7e2a",
+ strip_prefix = "llvm-dc6d9ec3646865125d057b6f515b4543df79920a",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)